Merge branch 'dev' into feature/cog-3698-enable-batch-queries-in-all-graph-completion-retrievers

Commit 6070f9f71f by lxobr on 2026-01-20 19:42:26 +01:00, committed via GitHub.
91 changed files with 2286 additions and 139 deletions

View file

@ -659,3 +659,51 @@ jobs:
EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }}
EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
run: uv run python ./cognee/tests/test_pipeline_cache.py
run_usage_logger_test:
name: Usage logger test (API/MCP)
runs-on: ubuntu-latest
defaults:
run:
shell: bash
services:
redis:
image: redis:7
ports:
- 6379:6379
options: >-
--health-cmd "redis-cli ping"
--health-interval 5s
--health-timeout 3s
--health-retries 5
steps:
- name: Checkout repository
uses: actions/checkout@v4
- name: Cognee Setup
uses: ./.github/actions/cognee_setup
with:
python-version: '3.11.x'
extra-dependencies: "redis"
- name: Install cognee-mcp (local version)
shell: bash
run: |
uv pip install -e ./cognee-mcp
- name: Run api/tool usage logger
env:
ENV: dev
LLM_MODEL: ${{ secrets.LLM_MODEL }}
LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
LLM_API_VERSION: ${{ secrets.LLM_API_VERSION }}
EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }}
EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }}
EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }}
EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
GRAPH_DATABASE_PROVIDER: 'kuzu'
USAGE_LOGGING: true
CACHE_BACKEND: 'redis'
run: uv run pytest cognee/tests/test_usage_logger_e2e.py -v --log-level=INFO

View file

@ -34,10 +34,6 @@ COPY README.md pyproject.toml uv.lock entrypoint.sh ./
RUN --mount=type=cache,target=/root/.cache/uv \
uv sync --extra debug --extra api --extra postgres --extra neo4j --extra llama-index --extra ollama --extra mistral --extra groq --extra anthropic --extra chromadb --frozen --no-install-project --no-dev --no-editable
# Copy Alembic configuration
COPY alembic.ini /app/alembic.ini
COPY alembic/ /app/alembic
# Then, add the rest of the project source code and install it
# Installing separately from its dependencies allows optimal layer caching
COPY ./cognee /app/cognee

View file

@ -34,8 +34,8 @@ RUN --mount=type=cache,target=/root/.cache/uv \
uv sync --frozen --no-install-project --no-dev --no-editable
# Copy Alembic configuration
COPY alembic.ini /app/alembic.ini
COPY alembic/ /app/alembic
COPY cognee/alembic.ini /app/cognee/alembic.ini
COPY cognee/alembic/ /app/cognee/alembic
# Then, add the rest of the project source code and install it
# Installing separately from its dependencies allows optimal layer caching

View file

@ -56,24 +56,22 @@ if [ -n "$API_URL" ]; then
echo "Skipping database migrations (API server handles its own database)"
else
echo "Direct mode: Using local cognee instance"
# Run Alembic migrations with proper error handling.
# Note on UserAlreadyExists error handling:
# During database migrations, we attempt to create a default user. If this user
# already exists (e.g., from a previous deployment or migration), it's not a
# critical error and shouldn't prevent the application from starting. This is
# different from other migration errors which could indicate database schema
# inconsistencies and should cause the startup to fail. This check allows for
# smooth redeployments and container restarts while maintaining data integrity.
echo "Running database migrations..."
MIGRATION_OUTPUT=$(alembic upgrade head)
set +e # Disable exit on error to handle specific migration errors
MIGRATION_OUTPUT=$(cd cognee && alembic upgrade head)
MIGRATION_EXIT_CODE=$?
set -e
if [[ $MIGRATION_EXIT_CODE -ne 0 ]]; then
if [[ "$MIGRATION_OUTPUT" == *"UserAlreadyExists"* ]] || [[ "$MIGRATION_OUTPUT" == *"User default_user@example.com already exists"* ]]; then
echo "Warning: Default user already exists, continuing startup..."
else
echo "Migration failed with unexpected error."
echo "Migration failed with unexpected error. Trying to run Cognee without migrations."
echo "Initializing database tables..."
python /app/src/run_cognee_database_setup.py
INIT_EXIT_CODE=$?
if [[ $INIT_EXIT_CODE -ne 0 ]]; then
echo "Database initialization failed!"
exit 1
fi
fi

View file

@ -8,7 +8,6 @@ requires-python = ">=3.10"
dependencies = [
# For local cognee repo usage, remove the comment below and add the absolute path to cognee. Then run `uv sync --reinstall` in the mcp folder after local cognee changes.
#"cognee[postgres,codegraph,gemini,huggingface,docs,neo4j] @ file:/Users/igorilic/Desktop/cognee",
# TODO: Remove gemini from optional dependencies for new Cognee version after 0.3.4
"cognee[postgres,docs,neo4j]==0.5.0",
"fastmcp>=2.10.0,<3.0.0",
"mcp>=1.12.0,<2.0.0",

View file

@ -0,0 +1,5 @@
from cognee.modules.engine.operations.setup import setup
import asyncio
if __name__ == "__main__":
asyncio.run(setup())

View file

@ -8,6 +8,7 @@ from pathlib import Path
from typing import Optional
from cognee.shared.logging_utils import get_logger, setup_logging, get_log_file_location
from cognee.shared.usage_logger import log_usage
import importlib.util
from contextlib import redirect_stdout
import mcp.types as types
@ -91,6 +92,7 @@ async def health_check(request):
@mcp.tool()
@log_usage(function_name="MCP cognify", log_type="mcp_tool")
async def cognify(
data: str, graph_model_file: str = None, graph_model_name: str = None, custom_prompt: str = None
) -> list:
@ -257,6 +259,7 @@ async def cognify(
@mcp.tool(
name="save_interaction", description="Logs user-agent interactions and query-answer pairs"
)
@log_usage(function_name="MCP save_interaction", log_type="mcp_tool")
async def save_interaction(data: str) -> list:
"""
Transform and save a user-agent interaction into structured knowledge.
@ -316,6 +319,7 @@ async def save_interaction(data: str) -> list:
@mcp.tool()
@log_usage(function_name="MCP search", log_type="mcp_tool")
async def search(search_query: str, search_type: str, top_k: int = 10) -> list:
"""
Search and query the knowledge graph for insights, information, and connections.
@ -496,6 +500,7 @@ async def search(search_query: str, search_type: str, top_k: int = 10) -> list:
@mcp.tool()
@log_usage(function_name="MCP list_data", log_type="mcp_tool")
async def list_data(dataset_id: str = None) -> list:
"""
List all datasets and their data items with IDs for deletion operations.
@ -624,6 +629,7 @@ async def list_data(dataset_id: str = None) -> list:
@mcp.tool()
@log_usage(function_name="MCP delete", log_type="mcp_tool")
async def delete(data_id: str, dataset_id: str, mode: str = "soft") -> list:
"""
Delete specific data from a dataset in the Cognee knowledge graph.
@ -703,6 +709,7 @@ async def delete(data_id: str, dataset_id: str, mode: str = "soft") -> list:
@mcp.tool()
@log_usage(function_name="MCP prune", log_type="mcp_tool")
async def prune():
"""
Reset the Cognee knowledge graph by removing all stored information.
@ -739,6 +746,7 @@ async def prune():
@mcp.tool()
@log_usage(function_name="MCP cognify_status", log_type="mcp_tool")
async def cognify_status():
"""
Get the current status of the cognify pipeline.
@ -884,26 +892,11 @@ async def main():
await setup()
# Run Alembic migrations from the main cognee directory where alembic.ini is located
# Run Cognee migrations
logger.info("Running database migrations...")
migration_result = subprocess.run(
["python", "-m", "alembic", "upgrade", "head"],
capture_output=True,
text=True,
cwd=Path(__file__).resolve().parent.parent.parent,
)
from cognee.run_migrations import run_migrations
if migration_result.returncode != 0:
migration_output = migration_result.stderr + migration_result.stdout
# Check for the expected UserAlreadyExists error (which is not critical)
if (
"UserAlreadyExists" in migration_output
or "User default_user@example.com already exists" in migration_output
):
logger.warning("Warning: Default user already exists, continuing startup...")
else:
logger.error(f"Migration failed with unexpected error: {migration_output}")
sys.exit(1)
await run_migrations()
logger.info("Database migrations done.")
elif args.api_url:

View file

@ -33,3 +33,5 @@ from .api.v1.ui import start_ui
# Pipelines
from .modules import pipelines
from cognee.run_migrations import run_migrations

View file

@ -10,6 +10,7 @@ from cognee.modules.users.methods import get_authenticated_user
from cognee.shared.utils import send_telemetry
from cognee.modules.pipelines.models import PipelineRunErrored
from cognee.shared.logging_utils import get_logger
from cognee.shared.usage_logger import log_usage
from cognee import __version__ as cognee_version
logger = get_logger()
@ -19,6 +20,7 @@ def get_add_router() -> APIRouter:
router = APIRouter()
@router.post("", response_model=dict)
@log_usage(function_name="POST /v1/add", log_type="api_endpoint")
async def add(
data: List[UploadFile] = File(default=None),
datasetName: Optional[str] = Form(default=None),

View file

@ -29,6 +29,7 @@ from cognee.modules.pipelines.queues.pipeline_run_info_queues import (
)
from cognee.shared.logging_utils import get_logger
from cognee.shared.utils import send_telemetry
from cognee.shared.usage_logger import log_usage
from cognee import __version__ as cognee_version
logger = get_logger("api.cognify")
@ -57,6 +58,7 @@ def get_cognify_router() -> APIRouter:
router = APIRouter()
@router.post("", response_model=dict)
@log_usage(function_name="POST /v1/cognify", log_type="api_endpoint")
async def cognify(payload: CognifyPayloadDTO, user: User = Depends(get_authenticated_user)):
"""
Transform datasets into structured knowledge graphs through cognitive processing.

View file

@ -12,6 +12,7 @@ from cognee.modules.users.methods import get_authenticated_user
from cognee.shared.utils import send_telemetry
from cognee.modules.pipelines.models import PipelineRunErrored
from cognee.shared.logging_utils import get_logger
from cognee.shared.usage_logger import log_usage
from cognee import __version__ as cognee_version
logger = get_logger()
@ -35,6 +36,7 @@ def get_memify_router() -> APIRouter:
router = APIRouter()
@router.post("", response_model=dict)
@log_usage(function_name="POST /v1/memify", log_type="api_endpoint")
async def memify(payload: MemifyPayloadDTO, user: User = Depends(get_authenticated_user)):
"""
Enrichment pipeline in Cognee, can work with already built graphs. If no data is provided existing knowledge graph will be used as data,

View file

@ -13,6 +13,7 @@ from cognee.modules.users.models import User
from cognee.modules.search.operations import get_history
from cognee.modules.users.methods import get_authenticated_user
from cognee.shared.utils import send_telemetry
from cognee.shared.usage_logger import log_usage
from cognee import __version__ as cognee_version
from cognee.infrastructure.databases.exceptions import DatabaseNotCreatedError
from cognee.exceptions import CogneeValidationError
@ -75,6 +76,7 @@ def get_search_router() -> APIRouter:
return JSONResponse(status_code=500, content={"error": str(error)})
@router.post("", response_model=Union[List[SearchResult], List])
@log_usage(function_name="POST /v1/search", log_type="api_endpoint")
async def search(payload: SearchPayloadDTO, user: User = Depends(get_authenticated_user)):
"""
Search for nodes in the graph database.

View file

@ -8,10 +8,13 @@ class CacheDBInterface(ABC):
Provides a common interface for lock acquisition, release, and context-managed locking.
"""
def __init__(self, host: str, port: int, lock_key: str):
def __init__(
self, host: str, port: int, lock_key: str = "default_lock", log_key: str = "usage_logs"
):
self.host = host
self.port = port
self.lock_key = lock_key
self.log_key = log_key
self.lock = None
@abstractmethod
@ -77,3 +80,37 @@ class CacheDBInterface(ABC):
Gracefully close any async connections.
"""
pass
@abstractmethod
async def log_usage(
self,
user_id: str,
log_entry: dict,
ttl: int | None = 604800,
):
"""
Log usage information (API endpoint calls, MCP tool invocations) to cache.
Args:
user_id: The user ID.
log_entry: Dictionary containing usage log information.
ttl: Optional time-to-live (seconds). If provided, the log list expires after this time.
Raises:
CacheConnectionError: If cache connection fails or times out.
"""
pass
@abstractmethod
async def get_usage_logs(self, user_id: str, limit: int = 100):
"""
Retrieve usage logs for a given user.
Args:
user_id: The user ID.
limit: Maximum number of logs to retrieve (default: 100).
Returns:
List of usage log entries, most recent first.
"""
pass

View file

@ -13,6 +13,8 @@ class CacheConfig(BaseSettings):
- cache_port: Port number for the cache service.
- agentic_lock_expire: Automatic lock expiration time (in seconds).
- agentic_lock_timeout: Maximum time (in seconds) to wait for the lock release.
- usage_logging: Enable/disable usage logging for API endpoints and MCP tools.
- usage_logging_ttl: Time-to-live for usage logs in seconds (default: 7 days).
"""
cache_backend: Literal["redis", "fs"] = "fs"
@ -24,6 +26,8 @@ class CacheConfig(BaseSettings):
cache_password: Optional[str] = None
agentic_lock_expire: int = 240
agentic_lock_timeout: int = 300
usage_logging: bool = False
usage_logging_ttl: int = 604800
model_config = SettingsConfigDict(env_file=".env", extra="allow")
@ -38,6 +42,8 @@ class CacheConfig(BaseSettings):
"cache_password": self.cache_password,
"agentic_lock_expire": self.agentic_lock_expire,
"agentic_lock_timeout": self.agentic_lock_timeout,
"usage_logging": self.usage_logging,
"usage_logging_ttl": self.usage_logging_ttl,
}
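A minimal sketch of how the new settings can be read in application code, assuming USAGE_LOGGING, USAGE_LOGGING_TTL, and CACHE_BACKEND are set in the environment or .env (the import path is the one used elsewhere in this diff):

from cognee.infrastructure.databases.cache.config import get_cache_config

# Assumes e.g. USAGE_LOGGING=true, USAGE_LOGGING_TTL=86400, CACHE_BACKEND=redis in the environment or .env
config = get_cache_config()
print(config.usage_logging, config.usage_logging_ttl, config.cache_backend)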

View file

@ -89,6 +89,27 @@ class FSCacheAdapter(CacheDBInterface):
return None
return json.loads(value)
async def log_usage(
self,
user_id: str,
log_entry: dict,
ttl: int | None = 604800,
):
"""
Usage logging is not supported in filesystem cache backend.
This method is a no-op to satisfy the interface.
"""
logger.warning("Usage logging not supported in FSCacheAdapter, skipping")
pass
async def get_usage_logs(self, user_id: str, limit: int = 100):
"""
Usage logging is not supported in filesystem cache backend.
This method returns an empty list to satisfy the interface.
"""
logger.warning("Usage logging not supported in FSCacheAdapter, returning empty list")
return []
async def close(self):
if self.cache is not None:
self.cache.expire()

View file

@ -1,7 +1,6 @@
"""Factory to get the appropriate cache coordination engine (e.g., Redis)."""
from functools import lru_cache
import os
from typing import Optional
from cognee.infrastructure.databases.cache.config import get_cache_config
from cognee.infrastructure.databases.cache.cache_db_interface import CacheDBInterface
@ -17,6 +16,7 @@ def create_cache_engine(
cache_username: str,
cache_password: str,
lock_key: str,
log_key: str,
agentic_lock_expire: int = 240,
agentic_lock_timeout: int = 300,
):
@ -30,6 +30,7 @@ def create_cache_engine(
- cache_username: Username to authenticate with.
- cache_password: Password to authenticate with.
- lock_key: Identifier used for the locking resource.
- log_key: Identifier used for usage logging.
- agentic_lock_expire: Duration to hold the lock after acquisition.
- agentic_lock_timeout: Max time to wait for the lock before failing.
@ -37,7 +38,7 @@ def create_cache_engine(
--------
- CacheDBInterface: An instance of the appropriate cache adapter.
"""
if config.caching:
if config.caching or config.usage_logging:
from cognee.infrastructure.databases.cache.redis.RedisAdapter import RedisAdapter
if config.cache_backend == "redis":
@ -47,6 +48,7 @@ def create_cache_engine(
username=cache_username,
password=cache_password,
lock_name=lock_key,
log_key=log_key,
timeout=agentic_lock_expire,
blocking_timeout=agentic_lock_timeout,
)
@ -61,7 +63,10 @@ def create_cache_engine(
return None
def get_cache_engine(lock_key: Optional[str] = None) -> CacheDBInterface:
def get_cache_engine(
lock_key: Optional[str] = "default_lock",
log_key: Optional[str] = "usage_logs",
) -> Optional[CacheDBInterface]:
"""
Returns a cache adapter instance using current context configuration.
"""
@ -72,6 +77,7 @@ def get_cache_engine(lock_key: Optional[str] = None) -> CacheDBInterface:
cache_username=config.cache_username,
cache_password=config.cache_password,
lock_key=lock_key,
log_key=log_key,
agentic_lock_expire=config.agentic_lock_expire,
agentic_lock_timeout=config.agentic_lock_timeout,
)

View file

@ -17,13 +17,14 @@ class RedisAdapter(CacheDBInterface):
host,
port,
lock_name="default_lock",
log_key="usage_logs",
username=None,
password=None,
timeout=240,
blocking_timeout=300,
connection_timeout=30,
):
super().__init__(host, port, lock_name)
super().__init__(host, port, lock_name, log_key)
self.host = host
self.port = port
@ -177,6 +178,64 @@ class RedisAdapter(CacheDBInterface):
entries = await self.async_redis.lrange(session_key, 0, -1)
return [json.loads(e) for e in entries]
async def log_usage(
self,
user_id: str,
log_entry: dict,
ttl: int | None = 604800,
):
"""
Log usage information (API endpoint calls, MCP tool invocations) to Redis.
Args:
user_id: The user ID.
log_entry: Dictionary containing usage log information.
ttl: Optional time-to-live (seconds). If provided, the log list expires after this time.
Raises:
CacheConnectionError: If Redis connection fails or times out.
"""
try:
usage_logs_key = f"{self.log_key}:{user_id}"
await self.async_redis.rpush(usage_logs_key, json.dumps(log_entry))
if ttl is not None:
await self.async_redis.expire(usage_logs_key, ttl)
except (redis.ConnectionError, redis.TimeoutError) as e:
error_msg = f"Redis connection error while logging usage: {str(e)}"
logger.error(error_msg)
raise CacheConnectionError(error_msg) from e
except Exception as e:
error_msg = f"Unexpected error while logging usage to Redis: {str(e)}"
logger.error(error_msg)
raise CacheConnectionError(error_msg) from e
async def get_usage_logs(self, user_id: str, limit: int = 100):
"""
Retrieve usage logs for a given user.
Args:
user_id: The user ID.
limit: Maximum number of logs to retrieve (default: 100).
Returns:
List of usage log entries, most recent first.
"""
try:
usage_logs_key = f"{self.log_key}:{user_id}"
entries = await self.async_redis.lrange(usage_logs_key, -limit, -1)
return [json.loads(e) for e in reversed(entries)] if entries else []
except (redis.ConnectionError, redis.TimeoutError) as e:
error_msg = f"Redis connection error while retrieving usage logs: {str(e)}"
logger.error(error_msg)
raise CacheConnectionError(error_msg) from e
except Exception as e:
error_msg = f"Unexpected error while retrieving usage logs from Redis: {str(e)}"
logger.error(error_msg)
raise CacheConnectionError(error_msg) from e
async def close(self):
"""
Gracefully close the async Redis connection.
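A minimal sketch of exercising these two methods through the cache engine factory, assuming CACHE_BACKEND=redis, USAGE_LOGGING=true, and a reachable Redis instance (the user id and log entry contents are illustrative only):

import asyncio
from cognee.infrastructure.databases.cache.get_cache_engine import get_cache_engine

async def demo():
    cache_engine = get_cache_engine()  # returns None when neither caching nor usage logging is enabled
    if cache_engine is None:
        return
    await cache_engine.log_usage(
        user_id="demo-user",
        log_entry={"function_name": "demo", "success": True},
        ttl=3600,
    )
    logs = await cache_engine.get_usage_logs("demo-user", limit=10)
    print(logs)  # most recent entry first
    await cache_engine.close()

asyncio.run(demo())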

View file

@ -24,7 +24,6 @@ async def get_graph_engine() -> GraphDBInterface:
return graph_client
@lru_cache
def create_graph_engine(
graph_database_provider,
graph_file_path,
@ -35,6 +34,35 @@ def create_graph_engine(
graph_database_port="",
graph_database_key="",
graph_dataset_database_handler="",
):
"""
Wrapper function to call create graph engine with caching.
For a detailed description, see _create_graph_engine.
"""
return _create_graph_engine(
graph_database_provider,
graph_file_path,
graph_database_url,
graph_database_name,
graph_database_username,
graph_database_password,
graph_database_port,
graph_database_key,
graph_dataset_database_handler,
)
@lru_cache
def _create_graph_engine(
graph_database_provider,
graph_file_path,
graph_database_url="",
graph_database_name="",
graph_database_username="",
graph_database_password="",
graph_database_port="",
graph_database_key="",
graph_dataset_database_handler="",
):
"""
Create a graph engine based on the specified provider type.

View file

@ -236,6 +236,7 @@ class NeptuneAnalyticsAdapter(NeptuneGraphDB, VectorDBInterface):
query_vector: Optional[List[float]] = None,
limit: Optional[int] = None,
with_vector: bool = False,
include_payload: bool = False, # TODO: Add support for this parameter
):
"""
Perform a search in the specified collection using either a text query or a vector
@ -319,7 +320,12 @@ class NeptuneAnalyticsAdapter(NeptuneGraphDB, VectorDBInterface):
self._na_exception_handler(e, query_string)
async def batch_search(
self, collection_name: str, query_texts: List[str], limit: int, with_vectors: bool = False
self,
collection_name: str,
query_texts: List[str],
limit: int,
with_vectors: bool = False,
include_payload: bool = False,
):
"""
Perform a batch search using multiple text queries against a collection.
@ -342,7 +348,14 @@ class NeptuneAnalyticsAdapter(NeptuneGraphDB, VectorDBInterface):
data_vectors = await self.embedding_engine.embed_text(query_texts)
return await asyncio.gather(
*[
self.search(collection_name, None, vector, limit, with_vectors)
self.search(
collection_name,
None,
vector,
limit,
with_vectors,
include_payload=include_payload,
)
for vector in data_vectors
]
)

View file

@ -355,6 +355,7 @@ class ChromaDBAdapter(VectorDBInterface):
limit: Optional[int] = 15,
with_vector: bool = False,
normalized: bool = True,
include_payload: bool = False, # TODO: Add support for this parameter when set to False
):
"""
Search for items in a collection using either a text or a vector query.
@ -441,6 +442,7 @@ class ChromaDBAdapter(VectorDBInterface):
query_texts: List[str],
limit: int = 5,
with_vectors: bool = False,
include_payload: bool = False,
):
"""
Perform multiple searches in a single request for efficiency, returning results for each

View file

@ -7,7 +7,6 @@ from cognee.infrastructure.databases.graph.config import get_graph_context_confi
from functools import lru_cache
@lru_cache
def create_vector_engine(
vector_db_provider: str,
vector_db_url: str,
@ -15,6 +14,29 @@ def create_vector_engine(
vector_db_port: str = "",
vector_db_key: str = "",
vector_dataset_database_handler: str = "",
):
"""
Wrapper function to call create vector engine with caching.
For a detailed description, see _create_vector_engine.
"""
return _create_vector_engine(
vector_db_provider,
vector_db_url,
vector_db_name,
vector_db_port,
vector_db_key,
vector_dataset_database_handler,
)
@lru_cache
def _create_vector_engine(
vector_db_provider: str,
vector_db_url: str,
vector_db_name: str,
vector_db_port: str = "",
vector_db_key: str = "",
vector_dataset_database_handler: str = "",
):
"""
Create a vector database engine based on the specified provider.
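The same wrapper-plus-private-cached-helper pattern is applied here as for the graph engine earlier in this diff; a small sketch of how test teardown can now reset both cached factories (mirroring the test changes further down):

from cognee.infrastructure.databases.vector.create_vector_engine import _create_vector_engine
from cognee.infrastructure.databases.graph.get_graph_engine import _create_graph_engine

# The public create_* wrappers delegate to these lru_cache'd helpers,
# so clearing the helpers drops any previously constructed engines.
_create_vector_engine.cache_clear()
_create_graph_engine.cache_clear()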

View file

@ -231,6 +231,7 @@ class LanceDBAdapter(VectorDBInterface):
limit: Optional[int] = 15,
with_vector: bool = False,
normalized: bool = True,
include_payload: bool = False,
):
if query_text is None and query_vector is None:
raise MissingQueryParameterError()
@ -247,17 +248,27 @@ class LanceDBAdapter(VectorDBInterface):
if limit <= 0:
return []
result_values = await collection.vector_search(query_vector).limit(limit).to_list()
# Note: Exclude payload if not needed to optimize performance
select_columns = (
["id", "vector", "payload", "_distance"]
if include_payload
else ["id", "vector", "_distance"]
)
result_values = (
await collection.vector_search(query_vector)
.select(select_columns)
.limit(limit)
.to_list()
)
if not result_values:
return []
normalized_values = normalize_distances(result_values)
return [
ScoredResult(
id=parse_id(result["id"]),
payload=result["payload"],
payload=result["payload"] if include_payload else None,
score=normalized_values[value_index],
)
for value_index, result in enumerate(result_values)
@ -269,6 +280,7 @@ class LanceDBAdapter(VectorDBInterface):
query_texts: List[str],
limit: Optional[int] = None,
with_vectors: bool = False,
include_payload: bool = False,
):
query_vectors = await self.embedding_engine.embed_text(query_texts)
@ -279,6 +291,7 @@ class LanceDBAdapter(VectorDBInterface):
query_vector=query_vector,
limit=limit,
with_vector=with_vectors,
include_payload=include_payload,
)
for query_vector in query_vectors
]

View file

@ -1,4 +1,4 @@
from typing import Any, Dict
from typing import Any, Dict, Optional
from uuid import UUID
from pydantic import BaseModel
@ -12,10 +12,10 @@ class ScoredResult(BaseModel):
- id (UUID): Unique identifier for the scored result.
- score (float): The score associated with the result, where a lower score indicates a
better outcome.
- payload (Dict[str, Any]): Additional information related to the score, stored as
- payload (Optional[Dict[str, Any]]): Additional information related to the score, stored as
key-value pairs in a dictionary.
"""
id: UUID
score: float # Lower score is better
payload: Dict[str, Any]
payload: Optional[Dict[str, Any]] = None

View file

@ -301,6 +301,7 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
query_vector: Optional[List[float]] = None,
limit: Optional[int] = 15,
with_vector: bool = False,
include_payload: bool = False,
) -> List[ScoredResult]:
if query_text is None and query_vector is None:
raise MissingQueryParameterError()
@ -324,10 +325,16 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
# NOTE: This needs to be initialized in case search doesn't return a value
closest_items = []
# Note: Exclude payload from returned columns if not needed to optimize performance
select_columns = (
[PGVectorDataPoint]
if include_payload
else [PGVectorDataPoint.c.id, PGVectorDataPoint.c.vector]
)
# Use async session to connect to the database
async with self.get_async_session() as session:
query = select(
PGVectorDataPoint,
*select_columns,
PGVectorDataPoint.c.vector.cosine_distance(query_vector).label("similarity"),
).order_by("similarity")
@ -344,7 +351,7 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
vector_list.append(
{
"id": parse_id(str(vector.id)),
"payload": vector.payload,
"payload": vector.payload if include_payload else None,
"_distance": vector.similarity,
}
)
@ -359,7 +366,11 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
# Create and return ScoredResult objects
return [
ScoredResult(id=row.get("id"), payload=row.get("payload"), score=row.get("score"))
ScoredResult(
id=row.get("id"),
payload=row.get("payload") if include_payload else None,
score=row.get("score"),
)
for row in vector_list
]
@ -369,6 +380,7 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
query_texts: List[str],
limit: int = None,
with_vectors: bool = False,
include_payload: bool = False,
):
query_vectors = await self.embedding_engine.embed_text(query_texts)
@ -379,6 +391,7 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
query_vector=query_vector,
limit=limit,
with_vector=with_vectors,
include_payload=include_payload,
)
for query_vector in query_vectors
]

View file

@ -87,6 +87,7 @@ class VectorDBInterface(Protocol):
query_vector: Optional[List[float]],
limit: Optional[int],
with_vector: bool = False,
include_payload: bool = False,
):
"""
Perform a search in the specified collection using either a text query or a vector
@ -103,6 +104,9 @@ class VectorDBInterface(Protocol):
- limit (Optional[int]): The maximum number of results to return from the search.
- with_vector (bool): Whether to return the vector representations with search
results. (default False)
- include_payload (bool): Whether to include the payload data with each result. Search is faster when set to False.
The payload holds the data point's metadata; it is needed for searches based purely on embedding distances,
such as the RAG_COMPLETION search type, but not when the search also draws on graph data.
"""
raise NotImplementedError
@ -113,6 +117,7 @@ class VectorDBInterface(Protocol):
query_texts: List[str],
limit: Optional[int],
with_vectors: bool = False,
include_payload: bool = False,
):
"""
Perform a batch search using multiple text queries against a collection.
@ -125,6 +130,9 @@ class VectorDBInterface(Protocol):
- limit (Optional[int]): The maximum number of results to return for each query.
- with_vectors (bool): Whether to include vector representations with search
results. (default False)
- include_payload (bool): Whether to include the payload data with each result. Search is faster when set to False.
The payload holds the data point's metadata; it is needed for searches based purely on embedding distances,
such as the RAG_COMPLETION search type, but not when the search also draws on graph data.
"""
raise NotImplementedError
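A short sketch of the new flag from a caller's perspective, assuming an already-populated "DocumentChunk_text" collection (the query string is illustrative):

from cognee.infrastructure.databases.vector import get_vector_engine

async def payload_demo(query: str):
    vector_engine = get_vector_engine()
    # Faster: ids and scores only, payload left as None on each ScoredResult
    slim = await vector_engine.search("DocumentChunk_text", query, limit=5)
    # Slower but self-contained: the payload (e.g. the chunk text) is returned as well
    full = await vector_engine.search("DocumentChunk_text", query, limit=5, include_payload=True)
    return slim, full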

View file

@ -1,5 +1,6 @@
import time
from cognee.shared.logging_utils import get_logger
from cognee.modules.engine.utils.generate_edge_id import generate_edge_id
from typing import List, Dict, Union, Optional, Type, Iterable, Tuple, Callable, Any
from cognee.modules.graph.exceptions import (
@ -44,6 +45,12 @@ class CogneeGraph(CogneeAbstractGraph):
def add_edge(self, edge: Edge) -> None:
self.edges.append(edge)
edge_text = edge.attributes.get("edge_text") or edge.attributes.get("relationship_type")
edge.attributes["edge_type_id"] = (
generate_edge_id(edge_id=edge_text) if edge_text else None
) # Update edge with generated edge_type_id
edge.node1.add_skeleton_edge(edge)
edge.node2.add_skeleton_edge(edge)
key = edge.get_distance_key()
@ -284,13 +291,7 @@ class CogneeGraph(CogneeAbstractGraph):
for query_index, scored_results in enumerate(per_query_scored_results):
for result in scored_results:
payload = getattr(result, "payload", None)
if not isinstance(payload, dict):
continue
text = payload.get("text")
if not text:
continue
matching_edges = self.edges_by_distance_key.get(str(text))
matching_edges = self.edges_by_distance_key.get(str(result.id))
if not matching_edges:
continue
for edge in matching_edges:

View file

@ -141,7 +141,7 @@ class Edge:
self.status = np.ones(dimension, dtype=int)
def get_distance_key(self) -> Optional[str]:
key = self.attributes.get("edge_text") or self.attributes.get("relationship_type")
key = self.attributes.get("edge_type_id")
if key is None:
return None
return str(key)

View file

@ -47,7 +47,9 @@ class ChunksRetriever(BaseRetriever):
vector_engine = get_vector_engine()
try:
found_chunks = await vector_engine.search("DocumentChunk_text", query, limit=self.top_k)
found_chunks = await vector_engine.search(
"DocumentChunk_text", query, limit=self.top_k, include_payload=True
)
logger.info(f"Found {len(found_chunks)} chunks from vector search")
await update_node_access_timestamps(found_chunks)

View file

@ -62,7 +62,9 @@ class CompletionRetriever(BaseRetriever):
vector_engine = get_vector_engine()
try:
found_chunks = await vector_engine.search("DocumentChunk_text", query, limit=self.top_k)
found_chunks = await vector_engine.search(
"DocumentChunk_text", query, limit=self.top_k, include_payload=True
)
if len(found_chunks) == 0:
return ""

View file

@ -52,7 +52,7 @@ class SummariesRetriever(BaseRetriever):
try:
summaries_results = await vector_engine.search(
"TextSummary_text", query, limit=self.top_k
"TextSummary_text", query, limit=self.top_k, include_payload=True
)
logger.info(f"Found {len(summaries_results)} summaries from vector search")

View file

@ -98,7 +98,7 @@ class TemporalRetriever(GraphCompletionRetriever):
async def filter_top_k_events(self, relevant_events, scored_results):
# Build a score lookup from vector search results
score_lookup = {res.payload["id"]: res.score for res in scored_results}
score_lookup = {res.id: res.score for res in scored_results}
events_with_scores = []
for event in relevant_events[0]["events"]:

View file

@ -67,7 +67,9 @@ class TripletRetriever(BaseRetriever):
"In order to use TRIPLET_COMPLETION first use the create_triplet_embeddings memify pipeline. "
)
found_triplets = await vector_engine.search("Triplet_text", query, limit=self.top_k)
found_triplets = await vector_engine.search(
"Triplet_text", query, limit=self.top_k, include_payload=True
)
if len(found_triplets) == 0:
return ""

cognee/run_migrations.py (new file, 48 lines)
View file

@ -0,0 +1,48 @@
import os
import sys
import subprocess
from pathlib import Path
import importlib.resources as pkg_resources
# The Alembic migrations ship inside the installed 'cognee' package under 'cognee/alembic'.
MIGRATIONS_PACKAGE = "cognee"
MIGRATIONS_DIR_NAME = "alembic"
async def run_migrations():
"""
Finds the Alembic configuration within the installed package and
programmatically executes 'alembic upgrade head'.
"""
# 1. Locate the base path of the installed package.
# This reliably finds the root directory of the installed 'cognee' package.
# The alembic.ini and the 'alembic' scripts directory live directly under this package root.
package_root = str(pkg_resources.files(MIGRATIONS_PACKAGE))
# 2. Define the paths for config and scripts
alembic_ini_path = os.path.join(package_root, "alembic.ini")
script_location_path = os.path.join(package_root, MIGRATIONS_DIR_NAME)
if not os.path.exists(alembic_ini_path):
raise FileNotFoundError(
f"Error: alembic.ini not found at expected locations for package '{MIGRATIONS_PACKAGE}'."
)
if not os.path.exists(script_location_path):
raise FileNotFoundError(
f"Error: Migrations directory not found at expected locations for package '{MIGRATIONS_PACKAGE}'."
)
migration_result = subprocess.run(
["python", "-m", "alembic", "upgrade", "head"],
capture_output=True,
text=True,
cwd=Path(package_root),
)
if migration_result.returncode != 0:
migration_output = migration_result.stderr + migration_result.stdout
print(f"Migration failed with unexpected error: {migration_output}")
sys.exit(1)
print("Migration completed successfully.")
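A minimal sketch of invoking this helper directly, as the MCP server entrypoint now does (assumes the installed cognee package ships alembic.ini and the alembic/ directory as added in the Dockerfile changes above):

import asyncio
from cognee.run_migrations import run_migrations

# Runs `python -m alembic upgrade head` from the installed package root and
# exits with an error message if the migration subprocess fails.
asyncio.run(run_migrations())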

View file

@ -4,6 +4,4 @@ Custom exceptions for the Cognee API.
This module defines a set of exceptions for handling various shared utility errors
"""
from .exceptions import (
IngestionError,
)
from .exceptions import IngestionError, UsageLoggerError

View file

@ -1,4 +1,4 @@
from cognee.exceptions import CogneeValidationError
from cognee.exceptions import CogneeConfigurationError, CogneeValidationError
from fastapi import status
@ -10,3 +10,13 @@ class IngestionError(CogneeValidationError):
status_code=status.HTTP_422_UNPROCESSABLE_CONTENT,
):
super().__init__(message, name, status_code)
class UsageLoggerError(CogneeConfigurationError):
def __init__(
self,
message: str = "Usage logging configuration is invalid.",
name: str = "UsageLoggerError",
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
):
super().__init__(message, name, status_code)

View file

@ -0,0 +1,332 @@
import asyncio
import inspect
import os
from datetime import datetime, timezone
from functools import singledispatch, wraps
from typing import Any, Callable, Optional
from uuid import UUID
from cognee.infrastructure.databases.cache.config import get_cache_config
from cognee.infrastructure.databases.cache.get_cache_engine import get_cache_engine
from cognee.shared.exceptions import UsageLoggerError
from cognee.shared.logging_utils import get_logger
from cognee import __version__ as cognee_version
logger = get_logger("usage_logger")
@singledispatch
def _sanitize_value(value: Any) -> Any:
"""Default handler for JSON serialization - converts to string."""
try:
str_repr = str(value)
if str_repr.startswith("<") and str_repr.endswith(">"):
return f"<cannot be serialized: {type(value).__name__}>"
return str_repr
except Exception:
return f"<cannot be serialized: {type(value).__name__}>"
@_sanitize_value.register(type(None))
def _(value: None) -> None:
"""Handle None values - returns None as-is."""
return None
@_sanitize_value.register(str)
@_sanitize_value.register(int)
@_sanitize_value.register(float)
@_sanitize_value.register(bool)
def _(value: str | int | float | bool) -> str | int | float | bool:
"""Handle primitive types - returns value as-is since they're JSON-serializable."""
return value
@_sanitize_value.register(UUID)
def _(value: UUID) -> str:
"""Convert UUID to string representation."""
return str(value)
@_sanitize_value.register(datetime)
def _(value: datetime) -> str:
"""Convert datetime to ISO format string."""
return value.isoformat()
@_sanitize_value.register(list)
@_sanitize_value.register(tuple)
def _(value: list | tuple) -> list:
"""Recursively sanitize list or tuple elements."""
return [_sanitize_value(v) for v in value]
@_sanitize_value.register(dict)
def _(value: dict) -> dict:
"""Recursively sanitize dictionary keys and values."""
sanitized = {}
for k, v in value.items():
key_str = k if isinstance(k, str) else _sanitize_dict_key(k)
sanitized[key_str] = _sanitize_value(v)
return sanitized
def _sanitize_dict_key(key: Any) -> str:
"""Convert a non-string dict key to a string."""
sanitized_key = _sanitize_value(key)
if isinstance(sanitized_key, str):
if sanitized_key.startswith("<cannot be serialized"):
return f"<key:{type(key).__name__}>"
return sanitized_key
return str(sanitized_key)
def _get_param_names(func: Callable) -> list[str]:
"""Get parameter names from function signature."""
try:
return list(inspect.signature(func).parameters.keys())
except Exception:
return []
def _get_param_defaults(func: Callable) -> dict[str, Any]:
"""Get parameter defaults from function signature."""
try:
sig = inspect.signature(func)
defaults = {}
for param_name, param in sig.parameters.items():
if param.default != inspect.Parameter.empty:
defaults[param_name] = param.default
return defaults
except Exception:
return {}
def _extract_user_id(args: tuple, kwargs: dict, param_names: list[str]) -> Optional[str]:
"""Extract user_id from function arguments if available."""
try:
if "user" in kwargs and kwargs["user"] is not None:
user = kwargs["user"]
if hasattr(user, "id"):
return str(user.id)
for i, param_name in enumerate(param_names):
if i < len(args) and param_name == "user":
user = args[i]
if user is not None and hasattr(user, "id"):
return str(user.id)
return None
except Exception:
return None
def _extract_parameters(args: tuple, kwargs: dict, param_names: list[str], func: Callable) -> dict:
"""Extract function parameters - captures all parameters including defaults, sanitizes for JSON."""
params = {}
for key, value in kwargs.items():
if key != "user":
params[key] = _sanitize_value(value)
if param_names:
for i, param_name in enumerate(param_names):
if i < len(args) and param_name != "user" and param_name not in kwargs:
params[param_name] = _sanitize_value(args[i])
else:
for i, arg_value in enumerate(args):
params[f"arg_{i}"] = _sanitize_value(arg_value)
if param_names:
defaults = _get_param_defaults(func)
for param_name in param_names:
if param_name != "user" and param_name not in params and param_name in defaults:
params[param_name] = _sanitize_value(defaults[param_name])
return params
async def _log_usage_async(
function_name: str,
log_type: str,
user_id: Optional[str],
parameters: dict,
result: Any,
success: bool,
error: Optional[str],
duration_ms: float,
start_time: datetime,
end_time: datetime,
):
"""Asynchronously log function usage to Redis.
Args:
function_name: Name of the function being logged.
log_type: Type of log entry (e.g., "api_endpoint", "mcp_tool", "function").
user_id: User identifier, or None to use "unknown".
parameters: Dictionary of function parameters (sanitized).
result: Function return value (will be sanitized).
success: Whether the function executed successfully.
error: Error message if function failed, None otherwise.
duration_ms: Execution duration in milliseconds.
start_time: Function start timestamp.
end_time: Function end timestamp.
Note:
This function silently handles errors to avoid disrupting the original
function execution. Logs are written to Redis with TTL from config.
"""
try:
logger.debug(f"Starting to log usage for {function_name} at {start_time.isoformat()}")
config = get_cache_config()
if not config.usage_logging:
logger.debug("Usage logging disabled, skipping log")
return
logger.debug(f"Getting cache engine for {function_name}")
cache_engine = get_cache_engine()
if cache_engine is None:
logger.warning(
f"Cache engine not available for usage logging (function: {function_name})"
)
return
logger.debug(f"Cache engine obtained for {function_name}")
if user_id is None:
user_id = "unknown"
logger.debug(f"No user_id provided, using 'unknown' for {function_name}")
log_entry = {
"timestamp": start_time.isoformat(),
"type": log_type,
"function_name": function_name,
"user_id": user_id,
"parameters": parameters,
"result": _sanitize_value(result),
"success": success,
"error": error,
"duration_ms": round(duration_ms, 2),
"start_time": start_time.isoformat(),
"end_time": end_time.isoformat(),
"metadata": {
"cognee_version": cognee_version,
"environment": os.getenv("ENV", "prod"),
},
}
logger.debug(f"Calling log_usage for {function_name}, user_id={user_id}")
await cache_engine.log_usage(
user_id=user_id,
log_entry=log_entry,
ttl=config.usage_logging_ttl,
)
logger.info(f"Successfully logged usage for {function_name} (user_id={user_id})")
except Exception as e:
logger.error(f"Failed to log usage for {function_name}: {str(e)}", exc_info=True)
def log_usage(function_name: Optional[str] = None, log_type: str = "function"):
"""
Decorator to log function usage to Redis.
This decorator is completely transparent - it doesn't change function behavior.
It logs function name, parameters, result, timing, and user (if available).
Args:
function_name: Optional name for the function (defaults to func.__name__)
log_type: Type of log entry (e.g., "api_endpoint", "mcp_tool")
Usage:
@log_usage(function_name="MCP my_mcp_tool", log_type="mcp_tool")
async def my_mcp_tool(...):
# mcp code
@log_usage(function_name="POST API /v1/add", log_type="api_endpoint")
async def add(...):
# endpoint code
"""
def decorator(func: Callable) -> Callable:
"""Inner decorator that wraps the function with usage logging.
Args:
func: The async function to wrap with usage logging.
Returns:
Callable: The wrapped function with usage logging enabled.
Raises:
UsageLoggerError: If the function is not async.
"""
if not inspect.iscoroutinefunction(func):
raise UsageLoggerError(
f"@log_usage requires an async function. Got {func.__name__} which is not async."
)
@wraps(func)
async def async_wrapper(*args, **kwargs):
"""Wrapper function that executes the original function and logs usage.
This wrapper:
- Extracts user ID and parameters from function arguments
- Executes the original function
- Captures result, success status, and any errors
- Logs usage information asynchronously without blocking
Args:
*args: Positional arguments passed to the original function.
**kwargs: Keyword arguments passed to the original function.
Returns:
Any: The return value of the original function.
Raises:
Any exception raised by the original function (re-raised after logging).
"""
config = get_cache_config()
if not config.usage_logging:
return await func(*args, **kwargs)
start_time = datetime.now(timezone.utc)
param_names = _get_param_names(func)
user_id = _extract_user_id(args, kwargs, param_names)
parameters = _extract_parameters(args, kwargs, param_names, func)
result = None
success = True
error = None
try:
result = await func(*args, **kwargs)
return result
except Exception as e:
success = False
error = str(e)
raise
finally:
end_time = datetime.now(timezone.utc)
duration_ms = (end_time - start_time).total_seconds() * 1000
try:
await _log_usage_async(
function_name=function_name or func.__name__,
log_type=log_type,
user_id=user_id,
parameters=parameters,
result=result,
success=success,
error=error,
duration_ms=duration_ms,
start_time=start_time,
end_time=end_time,
)
except Exception as e:
logger.error(
f"Failed to log usage for {function_name or func.__name__}: {str(e)}",
exc_info=True,
)
return async_wrapper
return decorator
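A runnable sketch of the decorator on a standalone async function (function and parameter names are illustrative); with USAGE_LOGGING disabled it is a transparent pass-through, and applying it to a non-async function raises UsageLoggerError:

import asyncio
from cognee.shared.usage_logger import log_usage

@log_usage(function_name="demo_tool", log_type="function")
async def demo_tool(query: str, user=None):
    # The wrapper extracts `user.id` (when a user object is passed) and logs
    # parameters, result, duration and success/failure via the cache engine.
    return {"answer": query.upper()}

print(asyncio.run(demo_tool("hello")))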

View file

@ -0,0 +1,255 @@
"""Integration tests for usage logger with real Redis components."""
import os
import pytest
import asyncio
from datetime import datetime, timezone
from types import SimpleNamespace
from uuid import UUID
from unittest.mock import patch
from cognee.shared.usage_logger import log_usage
from cognee.infrastructure.databases.cache.config import get_cache_config
from cognee.infrastructure.databases.cache.get_cache_engine import (
get_cache_engine,
create_cache_engine,
)
@pytest.fixture
def usage_logging_config():
"""Fixture to enable usage logging via environment variables."""
original_env = os.environ.copy()
os.environ["USAGE_LOGGING"] = "true"
os.environ["CACHE_BACKEND"] = "redis"
os.environ["CACHE_HOST"] = "localhost"
os.environ["CACHE_PORT"] = "6379"
get_cache_config.cache_clear()
create_cache_engine.cache_clear()
yield
os.environ.clear()
os.environ.update(original_env)
get_cache_config.cache_clear()
create_cache_engine.cache_clear()
@pytest.fixture
def usage_logging_disabled():
"""Fixture to disable usage logging via environment variables."""
original_env = os.environ.copy()
os.environ["USAGE_LOGGING"] = "false"
os.environ["CACHE_BACKEND"] = "redis"
get_cache_config.cache_clear()
create_cache_engine.cache_clear()
yield
os.environ.clear()
os.environ.update(original_env)
get_cache_config.cache_clear()
create_cache_engine.cache_clear()
@pytest.fixture
def redis_adapter():
"""Real RedisAdapter instance for testing."""
from cognee.infrastructure.databases.cache.redis.RedisAdapter import RedisAdapter
try:
yield RedisAdapter(host="localhost", port=6379, log_key="test_usage_logs")
except Exception as e:
pytest.skip(f"Redis not available: {e}")
@pytest.fixture
def test_user():
"""Test user object."""
return SimpleNamespace(id="test-user-123")
class TestDecoratorBehavior:
"""Test decorator behavior with real components."""
@pytest.mark.asyncio
async def test_decorator_configuration(
self, usage_logging_disabled, usage_logging_config, redis_adapter
):
"""Test decorator skips when disabled and logs when enabled."""
# Test disabled
call_count = 0
@log_usage(function_name="test_func", log_type="test")
async def test_func():
nonlocal call_count
call_count += 1
return "result"
assert await test_func() == "result"
assert call_count == 1
# Test enabled with cache engine None
with patch("cognee.shared.usage_logger.get_cache_engine") as mock_get:
mock_get.return_value = None
assert await test_func() == "result"
@pytest.mark.asyncio
async def test_decorator_logging(self, usage_logging_config, redis_adapter, test_user):
"""Test decorator logs to Redis with correct structure."""
@log_usage(function_name="test_func", log_type="test")
async def test_func(param1: str, param2: int = 42, user=None):
await asyncio.sleep(0.01)
return {"result": f"{param1}_{param2}"}
with patch("cognee.shared.usage_logger.get_cache_engine") as mock_get:
mock_get.return_value = redis_adapter
result = await test_func("value1", user=test_user)
assert result == {"result": "value1_42"}
logs = await redis_adapter.get_usage_logs("test-user-123", limit=10)
log = logs[0]
assert log["function_name"] == "test_func"
assert log["type"] == "test"
assert log["user_id"] == "test-user-123"
assert log["parameters"]["param1"] == "value1"
assert log["parameters"]["param2"] == 42
assert log["success"] is True
assert all(
field in log
for field in [
"timestamp",
"result",
"error",
"duration_ms",
"start_time",
"end_time",
"metadata",
]
)
assert "cognee_version" in log["metadata"]
@pytest.mark.asyncio
async def test_multiple_calls(self, usage_logging_config, redis_adapter, test_user):
"""Test multiple consecutive calls are all logged."""
@log_usage(function_name="multi_test", log_type="test")
async def multi_func(call_num: int, user=None):
return {"call": call_num}
with patch("cognee.shared.usage_logger.get_cache_engine") as mock_get:
mock_get.return_value = redis_adapter
for i in range(3):
await multi_func(i, user=test_user)
logs = await redis_adapter.get_usage_logs("test-user-123", limit=10)
assert len(logs) >= 3
call_nums = {log["parameters"]["call_num"] for log in logs[:3]}
assert call_nums == {0, 1, 2}
class TestRealRedisIntegration:
"""Test real Redis integration."""
@pytest.mark.asyncio
async def test_redis_storage_retrieval_and_ttl(
self, usage_logging_config, redis_adapter, test_user
):
"""Test logs are stored, retrieved with correct order/limits, and TTL is set."""
@log_usage(function_name="redis_test", log_type="test")
async def redis_func(data: str, user=None):
return {"processed": data}
@log_usage(function_name="order_test", log_type="test")
async def order_func(num: int, user=None):
return {"num": num}
with patch("cognee.shared.usage_logger.get_cache_engine") as mock_get:
mock_get.return_value = redis_adapter
# Storage
await redis_func("test_data", user=test_user)
logs = await redis_adapter.get_usage_logs("test-user-123", limit=10)
assert logs[0]["function_name"] == "redis_test"
assert logs[0]["parameters"]["data"] == "test_data"
# Order (most recent first)
for i in range(3):
await order_func(i, user=test_user)
await asyncio.sleep(0.01)
logs = await redis_adapter.get_usage_logs("test-user-123", limit=10)
assert [log["parameters"]["num"] for log in logs[:3]] == [2, 1, 0]
# Limit
assert len(await redis_adapter.get_usage_logs("test-user-123", limit=2)) == 2
# TTL
ttl = await redis_adapter.async_redis.ttl("test_usage_logs:test-user-123")
assert 0 < ttl <= 604800
class TestEdgeCases:
"""Test edge cases in integration tests."""
@pytest.mark.asyncio
async def test_edge_cases(self, usage_logging_config, redis_adapter, test_user):
"""Test no params, defaults, complex structures, exceptions, None, circular refs."""
@log_usage(function_name="no_params", log_type="test")
async def no_params_func(user=None):
return "result"
@log_usage(function_name="defaults_only", log_type="test")
async def defaults_only_func(param1: str = "default1", param2: int = 42, user=None):
return {"param1": param1, "param2": param2}
@log_usage(function_name="complex_test", log_type="test")
async def complex_func(user=None):
return {
"nested": {
"list": [1, 2, 3],
"uuid": UUID("123e4567-e89b-12d3-a456-426614174000"),
"datetime": datetime(2024, 1, 15, tzinfo=timezone.utc),
}
}
@log_usage(function_name="exception_test", log_type="test")
async def exception_func(user=None):
raise RuntimeError("Test exception")
@log_usage(function_name="none_test", log_type="test")
async def none_func(user=None):
return None
with patch("cognee.shared.usage_logger.get_cache_engine") as mock_get:
mock_get.return_value = redis_adapter
# No parameters
await no_params_func(user=test_user)
logs = await redis_adapter.get_usage_logs("test-user-123", limit=10)
assert logs[0]["parameters"] == {}
# Default parameters
await defaults_only_func(user=test_user)
logs = await redis_adapter.get_usage_logs("test-user-123", limit=10)
assert logs[0]["parameters"]["param1"] == "default1"
assert logs[0]["parameters"]["param2"] == 42
# Complex nested structures
await complex_func(user=test_user)
logs = await redis_adapter.get_usage_logs("test-user-123", limit=10)
assert isinstance(logs[0]["result"]["nested"]["uuid"], str)
assert isinstance(logs[0]["result"]["nested"]["datetime"], str)
# Exception handling
with pytest.raises(RuntimeError):
await exception_func(user=test_user)
logs = await redis_adapter.get_usage_logs("test-user-123", limit=10)
assert logs[0]["success"] is False
assert "Test exception" in logs[0]["error"]
# None return value
assert await none_func(user=test_user) is None
logs = await redis_adapter.get_usage_logs("test-user-123", limit=10)
assert logs[0]["result"] is None

View file

@ -97,7 +97,7 @@ async def test_vector_engine_search_none_limit():
query_vector = (await vector_engine.embedding_engine.embed_text([query_text]))[0]
result = await vector_engine.search(
collection_name=collection_name, query_vector=query_vector, limit=None
collection_name=collection_name, query_vector=query_vector, limit=None, include_payload=True
)
# Check that we did not accidentally use any default value for limit

View file

@ -70,7 +70,9 @@ async def main():
from cognee.infrastructure.databases.vector import get_vector_engine
vector_engine = get_vector_engine()
random_node = (await vector_engine.search("Entity_name", "Quantum computer"))[0]
random_node = (
await vector_engine.search("Entity_name", "Quantum computer", include_payload=True)
)[0]
random_node_name = random_node.payload["text"]
search_results = await cognee.search(

View file

@ -149,7 +149,9 @@ async def main():
await test_getting_of_documents(dataset_name_1)
vector_engine = get_vector_engine()
random_node = (await vector_engine.search("Entity_name", "Quantum computer"))[0]
random_node = (
await vector_engine.search("Entity_name", "Quantum computer", include_payload=True)
)[0]
random_node_name = random_node.payload["text"]
search_results = await cognee.search(

View file

@ -48,7 +48,7 @@ async def main():
from cognee.infrastructure.databases.vector import get_vector_engine
vector_engine = get_vector_engine()
random_node = (await vector_engine.search("Entity_name", "AI"))[0]
random_node = (await vector_engine.search("Entity_name", "AI", include_payload=True))[0]
random_node_name = random_node.payload["text"]
search_results = await cognee.search(

View file

@ -63,7 +63,9 @@ async def main():
from cognee.infrastructure.databases.vector import get_vector_engine
vector_engine = get_vector_engine()
random_node = (await vector_engine.search("Entity_name", "Quantum computer"))[0]
random_node = (
await vector_engine.search("Entity_name", "Quantum computer", include_payload=True)
)[0]
random_node_name = random_node.payload["text"]
search_results = await cognee.search(

View file

@ -52,7 +52,9 @@ async def main():
await cognee.cognify([dataset_name])
vector_engine = get_vector_engine()
random_node = (await vector_engine.search("Entity_name", "Quantum computer"))[0]
random_node = (
await vector_engine.search("Entity_name", "Quantum computer", include_payload=True)
)[0]
random_node_name = random_node.payload["text"]
search_results = await cognee.search(

View file

@ -41,14 +41,14 @@ async def _reset_engines_and_prune() -> None:
except Exception:
pass
from cognee.infrastructure.databases.graph.get_graph_engine import create_graph_engine
from cognee.infrastructure.databases.relational.create_relational_engine import (
create_relational_engine,
)
from cognee.infrastructure.databases.vector.create_vector_engine import create_vector_engine
from cognee.infrastructure.databases.vector.create_vector_engine import _create_vector_engine
from cognee.infrastructure.databases.graph.get_graph_engine import _create_graph_engine
create_graph_engine.cache_clear()
create_vector_engine.cache_clear()
_create_graph_engine.cache_clear()
_create_vector_engine.cache_clear()
create_relational_engine.cache_clear()
await cognee.prune.prune_data()

View file

@ -163,7 +163,9 @@ async def main():
await test_getting_of_documents(dataset_name_1)
vector_engine = get_vector_engine()
random_node = (await vector_engine.search("Entity_name", "Quantum computer"))[0]
random_node = (
await vector_engine.search("Entity_name", "Quantum computer", include_payload=True)
)[0]
random_node_name = random_node.payload["text"]
search_results = await cognee.search(

View file

@ -58,7 +58,9 @@ async def main():
from cognee.infrastructure.databases.vector import get_vector_engine
vector_engine = get_vector_engine()
random_node = (await vector_engine.search("Entity_name", "Quantum computer"))[0]
random_node = (
await vector_engine.search("Entity_name", "Quantum computer", include_payload=True)
)[0]
random_node_name = random_node.payload["text"]
search_results = await cognee.search(

View file

@ -43,7 +43,7 @@ async def main():
from cognee.infrastructure.databases.vector import get_vector_engine
vector_engine = get_vector_engine()
random_node = (await vector_engine.search("Entity_name", "AI"))[0]
random_node = (await vector_engine.search("Entity_name", "AI", include_payload=True))[0]
random_node_name = random_node.payload["text"]
search_results = await cognee.search(

View file

@ -48,14 +48,14 @@ async def _reset_engines_and_prune() -> None:
# Engine might not exist yet
pass
from cognee.infrastructure.databases.graph.get_graph_engine import create_graph_engine
from cognee.infrastructure.databases.vector.create_vector_engine import create_vector_engine
from cognee.infrastructure.databases.graph.get_graph_engine import _create_graph_engine
from cognee.infrastructure.databases.vector.create_vector_engine import _create_vector_engine
from cognee.infrastructure.databases.relational.create_relational_engine import (
create_relational_engine,
)
create_graph_engine.cache_clear()
create_vector_engine.cache_clear()
_create_graph_engine.cache_clear()
_create_vector_engine.cache_clear()
create_relational_engine.cache_clear()
await cognee.prune.prune_data()

View file

@ -0,0 +1,268 @@
import os
import pytest
import pytest_asyncio
import asyncio
from fastapi.testclient import TestClient
import cognee
from cognee.api.client import app
from cognee.modules.users.methods import get_default_user, get_authenticated_user
async def _reset_engines_and_prune():
"""Reset db engine caches and prune data/system."""
try:
from cognee.infrastructure.databases.vector import get_vector_engine
vector_engine = get_vector_engine()
if hasattr(vector_engine, "engine") and hasattr(vector_engine.engine, "dispose"):
await vector_engine.engine.dispose(close=True)
except Exception:
pass
await cognee.prune.prune_data()
await cognee.prune.prune_system(metadata=True)
@pytest.fixture(scope="session")
def event_loop():
"""Use a single asyncio event loop for this test module."""
loop = asyncio.new_event_loop()
try:
yield loop
finally:
loop.close()
@pytest.fixture(scope="session")
def e2e_config():
"""Configure environment for E2E tests."""
original_env = os.environ.copy()
os.environ["USAGE_LOGGING"] = "true"
os.environ["CACHE_BACKEND"] = "redis"
os.environ["CACHE_HOST"] = "localhost"
os.environ["CACHE_PORT"] = "6379"
yield
os.environ.clear()
os.environ.update(original_env)
@pytest.fixture(scope="session")
def authenticated_client(test_client):
"""Override authentication to use default user."""
async def override_get_authenticated_user():
return await get_default_user()
app.dependency_overrides[get_authenticated_user] = override_get_authenticated_user
yield test_client
app.dependency_overrides.pop(get_authenticated_user, None)
@pytest_asyncio.fixture(scope="session")
async def test_data_setup():
"""Set up test data: prune first, then add file and cognify."""
await _reset_engines_and_prune()
dataset_name = "test_e2e_dataset"
test_text = "Germany is located in Europe right next to the Netherlands."
await cognee.add(test_text, dataset_name)
await cognee.cognify([dataset_name])
yield dataset_name
await _reset_engines_and_prune()
@pytest_asyncio.fixture
async def mcp_data_setup():
"""Set up test data for MCP tests: prune first, then add file and cognify."""
await _reset_engines_and_prune()
dataset_name = "test_mcp_dataset"
test_text = "Germany is located in Europe right next to the Netherlands."
await cognee.add(test_text, dataset_name)
await cognee.cognify([dataset_name])
yield dataset_name
await _reset_engines_and_prune()
@pytest.fixture(scope="session")
def test_client():
"""TestClient instance for API calls."""
with TestClient(app) as client:
yield client
@pytest_asyncio.fixture
async def cache_engine(e2e_config):
"""Get cache engine for log verification in test's event loop."""
from cognee.infrastructure.databases.cache.redis.RedisAdapter import RedisAdapter
from cognee.infrastructure.databases.cache.config import get_cache_config
config = get_cache_config()
if not config.usage_logging or config.cache_backend != "redis":
pytest.skip("Redis usage logging not configured")
engine = RedisAdapter(
host=config.cache_host,
port=config.cache_port,
username=config.cache_username,
password=config.cache_password,
log_key="usage_logs",
)
return engine
@pytest.mark.asyncio
async def test_api_endpoint_logging(e2e_config, authenticated_client, cache_engine):
"""Test that API endpoints succeed and log to Redis."""
user = await get_default_user()
dataset_name = "test_e2e_api_dataset"
add_response = authenticated_client.post(
"/api/v1/add",
data={"datasetName": dataset_name},
files=[
(
"data",
(
"test.txt",
b"Germany is located in Europe right next to the Netherlands.",
"text/plain",
),
)
],
)
assert add_response.status_code in [200, 201], f"Add endpoint failed: {add_response.text}"
cognify_response = authenticated_client.post(
"/api/v1/cognify",
json={"datasets": [dataset_name], "run_in_background": False},
)
assert cognify_response.status_code in [200, 201], (
f"Cognify endpoint failed: {cognify_response.text}"
)
search_response = authenticated_client.post(
"/api/v1/search",
json={"query": "Germany", "search_type": "GRAPH_COMPLETION", "datasets": [dataset_name]},
)
assert search_response.status_code == 200, f"Search endpoint failed: {search_response.text}"
logs = await cache_engine.get_usage_logs(str(user.id), limit=20)
add_logs = [log for log in logs if log.get("function_name") == "POST /v1/add"]
assert len(add_logs) > 0
assert add_logs[0]["type"] == "api_endpoint"
assert add_logs[0]["user_id"] == str(user.id)
assert add_logs[0]["success"] is True
cognify_logs = [log for log in logs if log.get("function_name") == "POST /v1/cognify"]
assert len(cognify_logs) > 0
assert cognify_logs[0]["type"] == "api_endpoint"
assert cognify_logs[0]["user_id"] == str(user.id)
assert cognify_logs[0]["success"] is True
search_logs = [log for log in logs if log.get("function_name") == "POST /v1/search"]
assert len(search_logs) > 0
assert search_logs[0]["type"] == "api_endpoint"
assert search_logs[0]["user_id"] == str(user.id)
assert search_logs[0]["success"] is True
@pytest.mark.asyncio
async def test_mcp_tool_logging(e2e_config, cache_engine):
"""Test that MCP tools succeed and log to Redis."""
import sys
import importlib.util
from pathlib import Path
await _reset_engines_and_prune()
repo_root = Path(__file__).parent.parent.parent
mcp_src_path = repo_root / "cognee-mcp" / "src"
mcp_server_path = mcp_src_path / "server.py"
if not mcp_server_path.exists():
pytest.skip(f"MCP server not found at {mcp_server_path}")
if str(mcp_src_path) not in sys.path:
sys.path.insert(0, str(mcp_src_path))
spec = importlib.util.spec_from_file_location("mcp_server_module", mcp_server_path)
mcp_server_module = importlib.util.module_from_spec(spec)
import os
original_cwd = os.getcwd()
try:
os.chdir(str(mcp_src_path))
spec.loader.exec_module(mcp_server_module)
finally:
os.chdir(original_cwd)
if mcp_server_module.cognee_client is None:
cognee_client_path = mcp_src_path / "cognee_client.py"
if cognee_client_path.exists():
spec_client = importlib.util.spec_from_file_location(
"cognee_client", cognee_client_path
)
cognee_client_module = importlib.util.module_from_spec(spec_client)
spec_client.loader.exec_module(cognee_client_module)
CogneeClient = cognee_client_module.CogneeClient
mcp_server_module.cognee_client = CogneeClient()
else:
pytest.skip(f"CogneeClient not found at {cognee_client_path}")
test_text = "Germany is located in Europe right next to the Netherlands."
await mcp_server_module.cognify(data=test_text)
await asyncio.sleep(30.0)
list_result = await mcp_server_module.list_data()
assert list_result is not None, "List data should return results"
search_result = await mcp_server_module.search(
search_query="Germany", search_type="GRAPH_COMPLETION", top_k=5
)
assert search_result is not None, "Search should return results"
interaction_data = "User: What is Germany?\nAgent: Germany is a country in Europe."
await mcp_server_module.save_interaction(data=interaction_data)
await asyncio.sleep(30.0)
status_result = await mcp_server_module.cognify_status()
assert status_result is not None, "Cognify status should return results"
await mcp_server_module.prune()
await asyncio.sleep(0.5)
logs = await cache_engine.get_usage_logs("unknown", limit=50)
mcp_logs = [log for log in logs if log.get("type") == "mcp_tool"]
assert len(mcp_logs) > 0, (
f"Should have MCP tool logs with user_id='unknown'. Found logs: {[log.get('function_name') for log in logs[:5]]}"
)
assert len(mcp_logs) == 6
function_names = [log.get("function_name") for log in mcp_logs]
expected_tools = [
"MCP cognify",
"MCP list_data",
"MCP search",
"MCP save_interaction",
"MCP cognify_status",
"MCP prune",
]
for expected_tool in expected_tools:
assert expected_tool in function_names, (
f"Should have {expected_tool} log. Found: {function_names}"
)
for log in mcp_logs:
assert log["type"] == "mcp_tool"
assert log["user_id"] == "unknown"
assert log["success"] is True

View file

@ -62,6 +62,8 @@ def test_cache_config_to_dict():
"cache_password": None,
"agentic_lock_expire": 100,
"agentic_lock_timeout": 200,
"usage_logging": False,
"usage_logging_ttl": 604800,
}

View file

@ -1,6 +1,7 @@
import pytest
from unittest.mock import AsyncMock
from cognee.modules.engine.utils.generate_edge_id import generate_edge_id
from cognee.modules.graph.exceptions import EntityNotFoundError, EntityAlreadyExistsError
from cognee.modules.graph.cognee_graph.CogneeGraph import CogneeGraph
from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge, Node
@ -379,7 +380,7 @@ async def test_map_vector_distances_to_graph_edges_with_payload(setup_graph):
graph.add_edge(edge)
edge_distances = [
MockScoredResult("e1", 0.92, payload={"text": "CONNECTS_TO"}),
MockScoredResult(generate_edge_id("CONNECTS_TO"), 0.92, payload={"text": "CONNECTS_TO"}),
]
await graph.map_vector_distances_to_graph_edges(edge_distances=edge_distances)
@ -404,8 +405,9 @@ async def test_map_vector_distances_partial_edge_coverage(setup_graph):
graph.add_edge(edge1)
graph.add_edge(edge2)
edge_1_text = "CONNECTS_TO"
edge_distances = [
MockScoredResult("e1", 0.92, payload={"text": "CONNECTS_TO"}),
MockScoredResult(generate_edge_id(edge_1_text), 0.92, payload={"text": edge_1_text}),
]
await graph.map_vector_distances_to_graph_edges(edge_distances=edge_distances)
@ -431,8 +433,9 @@ async def test_map_vector_distances_edges_fallback_to_relationship_type(setup_gr
)
graph.add_edge(edge)
edge_text = "KNOWS"
edge_distances = [
MockScoredResult("e1", 0.85, payload={"text": "KNOWS"}),
MockScoredResult(generate_edge_id(edge_text), 0.85, payload={"text": edge_text}),
]
await graph.map_vector_distances_to_graph_edges(edge_distances=edge_distances)
@ -457,8 +460,9 @@ async def test_map_vector_distances_no_edge_matches(setup_graph):
)
graph.add_edge(edge)
edge_text = "SOME_OTHER_EDGE"
edge_distances = [
MockScoredResult("e1", 0.92, payload={"text": "SOME_OTHER_EDGE"}),
MockScoredResult(generate_edge_id(edge_text), 0.92, payload={"text": edge_text}),
]
await graph.map_vector_distances_to_graph_edges(edge_distances=edge_distances)
@ -511,9 +515,15 @@ async def test_map_vector_distances_to_graph_edges_multi_query(setup_graph):
graph.add_edge(edge1)
graph.add_edge(edge2)
edge_1_text = "A"
edge_2_text = "B"
edge_distances = [
[MockScoredResult("e1", 0.1, payload={"text": "A"})], # query 0
[MockScoredResult("e2", 0.2, payload={"text": "B"})], # query 1
[
MockScoredResult(generate_edge_id(edge_1_text), 0.1, payload={"text": edge_1_text})
], # query 0
[
MockScoredResult(generate_edge_id(edge_2_text), 0.2, payload={"text": edge_2_text})
], # query 1
]
await graph.map_vector_distances_to_graph_edges(
@ -541,8 +551,11 @@ async def test_map_vector_distances_to_graph_edges_preserves_unmapped_indices(se
graph.add_edge(edge1)
graph.add_edge(edge2)
edge_1_text = "A"
edge_distances = [
[MockScoredResult("e1", 0.1, payload={"text": "A"})], # query 0: only edge1 mapped
[
MockScoredResult(generate_edge_id(edge_1_text), 0.1, payload={"text": edge_1_text})
], # query 0: only edge1 mapped
[], # query 1: no edges mapped
]

View file

@ -35,7 +35,9 @@ async def test_get_context_success(mock_vector_engine):
assert len(context) == 2
assert context[0]["text"] == "Steve Rodger"
assert context[1]["text"] == "Mike Broski"
mock_vector_engine.search.assert_awaited_once_with("DocumentChunk_text", "test query", limit=5)
mock_vector_engine.search.assert_awaited_once_with(
"DocumentChunk_text", "test query", limit=5, include_payload=True
)
@pytest.mark.asyncio
@ -87,7 +89,9 @@ async def test_get_context_top_k_limit(mock_vector_engine):
context = await retriever.get_context("test query")
assert len(context) == 3
mock_vector_engine.search.assert_awaited_once_with("DocumentChunk_text", "test query", limit=3)
mock_vector_engine.search.assert_awaited_once_with(
"DocumentChunk_text", "test query", limit=3, include_payload=True
)
@pytest.mark.asyncio
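
As a plain-code view of what these updated assertions expect (a sketch, not part of this commit): retrievers now pass include_payload=True, so each scored result carries an id, a score, and a payload dict. The collection name, query, and limit below are illustrative.

import asyncio

from cognee.infrastructure.databases.vector import get_vector_engine


async def main():
    vector_engine = get_vector_engine()
    # Request payloads explicitly; without include_payload=True the results
    # would only expose ids and scores (assumption based on the diffs above).
    results = await vector_engine.search(
        "DocumentChunk_text", "test query", limit=5, include_payload=True
    )
    for result in results:
        print(result.id, result.score, result.payload.get("text"))


if __name__ == "__main__":
    asyncio.run(main())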

View file

@ -33,7 +33,9 @@ async def test_get_context_success(mock_vector_engine):
context = await retriever.get_context("test query")
assert context == "Steve Rodger\nMike Broski"
mock_vector_engine.search.assert_awaited_once_with("DocumentChunk_text", "test query", limit=2)
mock_vector_engine.search.assert_awaited_once_with(
"DocumentChunk_text", "test query", limit=2, include_payload=True
)
@pytest.mark.asyncio
@ -85,7 +87,9 @@ async def test_get_context_top_k_limit(mock_vector_engine):
context = await retriever.get_context("test query")
assert context == "Chunk 0\nChunk 1"
mock_vector_engine.search.assert_awaited_once_with("DocumentChunk_text", "test query", limit=2)
mock_vector_engine.search.assert_awaited_once_with(
"DocumentChunk_text", "test query", limit=2, include_payload=True
)
@pytest.mark.asyncio

View file

@ -35,7 +35,9 @@ async def test_get_context_success(mock_vector_engine):
assert len(context) == 2
assert context[0]["text"] == "S.R."
assert context[1]["text"] == "M.B."
mock_vector_engine.search.assert_awaited_once_with("TextSummary_text", "test query", limit=5)
mock_vector_engine.search.assert_awaited_once_with(
"TextSummary_text", "test query", limit=5, include_payload=True
)
@pytest.mark.asyncio
@ -87,7 +89,9 @@ async def test_get_context_top_k_limit(mock_vector_engine):
context = await retriever.get_context("test query")
assert len(context) == 3
mock_vector_engine.search.assert_awaited_once_with("TextSummary_text", "test query", limit=3)
mock_vector_engine.search.assert_awaited_once_with(
"TextSummary_text", "test query", limit=3, include_payload=True
)
@pytest.mark.asyncio

View file

@ -63,8 +63,8 @@ async def test_filter_top_k_events_sorts_and_limits():
]
scored_results = [
SimpleNamespace(payload={"id": "e2"}, score=0.10),
SimpleNamespace(payload={"id": "e1"}, score=0.20),
SimpleNamespace(id="e2", payload={"id": "e2"}, score=0.10),
SimpleNamespace(id="e1", payload={"id": "e1"}, score=0.20),
]
top = await tr.filter_top_k_events(relevant_events, scored_results)
@ -91,8 +91,8 @@ async def test_filter_top_k_events_includes_unknown_as_infinite_but_not_in_top_k
]
scored_results = [
SimpleNamespace(payload={"id": "known2"}, score=0.05),
SimpleNamespace(payload={"id": "known1"}, score=0.50),
SimpleNamespace(id="known2", payload={"id": "known2"}, score=0.05),
SimpleNamespace(id="known1", payload={"id": "known1"}, score=0.50),
]
top = await tr.filter_top_k_events(relevant_events, scored_results)
@ -119,8 +119,8 @@ async def test_filter_top_k_events_limits_when_top_k_exceeds_events():
tr = TemporalRetriever(top_k=10)
relevant_events = [{"events": [{"id": "a"}, {"id": "b"}]}]
scored_results = [
SimpleNamespace(payload={"id": "a"}, score=0.1),
SimpleNamespace(payload={"id": "b"}, score=0.2),
SimpleNamespace(id="a", payload={"id": "a"}, score=0.1),
SimpleNamespace(id="b", payload={"id": "b"}, score=0.2),
]
out = await tr.filter_top_k_events(relevant_events, scored_results)
assert [e["id"] for e in out] == ["a", "b"]
@ -179,8 +179,8 @@ async def test_get_context_with_time_range(mock_graph_engine, mock_vector_engine
}
]
mock_result1 = SimpleNamespace(payload={"id": "e2"}, score=0.05)
mock_result2 = SimpleNamespace(payload={"id": "e1"}, score=0.10)
mock_result1 = SimpleNamespace(id="e2", payload={"id": "e2"}, score=0.05)
mock_result2 = SimpleNamespace(id="e1", payload={"id": "e1"}, score=0.10)
mock_vector_engine.search.return_value = [mock_result1, mock_result2]
with (
@ -279,7 +279,7 @@ async def test_get_context_time_from_only(mock_graph_engine, mock_vector_engine)
}
]
mock_result = SimpleNamespace(payload={"id": "e1"}, score=0.05)
mock_result = SimpleNamespace(id="e1", payload={"id": "e1"}, score=0.05)
mock_vector_engine.search.return_value = [mock_result]
with (
@ -313,7 +313,7 @@ async def test_get_context_time_to_only(mock_graph_engine, mock_vector_engine):
}
]
mock_result = SimpleNamespace(payload={"id": "e1"}, score=0.05)
mock_result = SimpleNamespace(id="e1", payload={"id": "e1"}, score=0.05)
mock_vector_engine.search.return_value = [mock_result]
with (
@ -347,7 +347,7 @@ async def test_get_completion_without_context(mock_graph_engine, mock_vector_eng
}
]
mock_result = SimpleNamespace(payload={"id": "e1"}, score=0.05)
mock_result = SimpleNamespace(id="e1", payload={"id": "e1"}, score=0.05)
mock_vector_engine.search.return_value = [mock_result]
with (
@ -416,7 +416,7 @@ async def test_get_completion_with_session(mock_graph_engine, mock_vector_engine
}
]
mock_result = SimpleNamespace(payload={"id": "e1"}, score=0.05)
mock_result = SimpleNamespace(id="e1", payload={"id": "e1"}, score=0.05)
mock_vector_engine.search.return_value = [mock_result]
mock_user = MagicMock()
@ -481,7 +481,7 @@ async def test_get_completion_with_session_no_user_id(mock_graph_engine, mock_ve
}
]
mock_result = SimpleNamespace(payload={"id": "e1"}, score=0.05)
mock_result = SimpleNamespace(id="e1", payload={"id": "e1"}, score=0.05)
mock_vector_engine.search.return_value = [mock_result]
with (
@ -570,7 +570,7 @@ async def test_get_completion_with_response_model(mock_graph_engine, mock_vector
}
]
mock_result = SimpleNamespace(payload={"id": "e1"}, score=0.05)
mock_result = SimpleNamespace(id="e1", payload={"id": "e1"}, score=0.05)
mock_vector_engine.search.return_value = [mock_result]
with (

View file

@ -6,6 +6,7 @@ from cognee.modules.retrieval.utils.brute_force_triplet_search import (
get_memory_fragment,
format_triplets,
)
from cognee.modules.engine.utils.generate_edge_id import generate_edge_id
from cognee.modules.graph.cognee_graph.CogneeGraph import CogneeGraph
from cognee.modules.graph.exceptions.exceptions import EntityNotFoundError
from cognee.infrastructure.databases.vector.exceptions.exceptions import CollectionNotFoundError
@ -1036,9 +1037,11 @@ async def test_cognee_graph_mapping_batch_shapes():
]
}
edge_1_text = "relates_to"
edge_2_text = "relates_to"
edge_distances_batch = [
[MockScoredResult("edge1", 0.92, payload={"text": "relates_to"})],
[MockScoredResult("edge2", 0.88, payload={"text": "relates_to"})],
[MockScoredResult(generate_edge_id(edge_1_text), 0.92, payload={"text": edge_1_text})],
[MockScoredResult(generate_edge_id(edge_2_text), 0.88, payload={"text": edge_2_text})],
]
await graph.map_vector_distances_to_graph_nodes(

View file

@ -34,7 +34,9 @@ async def test_get_context_success(mock_vector_engine):
context = await retriever.get_context("test query")
assert context == "Alice knows Bob\nBob works at Tech Corp"
mock_vector_engine.search.assert_awaited_once_with("Triplet_text", "test query", limit=5)
mock_vector_engine.search.assert_awaited_once_with(
"Triplet_text", "test query", limit=5, include_payload=True
)
@pytest.mark.asyncio

View file

@ -0,0 +1,241 @@
"""Unit tests for usage logger core functions."""
import pytest
from datetime import datetime, timezone
from uuid import UUID
from types import SimpleNamespace
from cognee.shared.usage_logger import (
_sanitize_value,
_sanitize_dict_key,
_get_param_names,
_get_param_defaults,
_extract_user_id,
_extract_parameters,
log_usage,
)
from cognee.shared.exceptions import UsageLoggerError
class TestSanitizeValue:
"""Test _sanitize_value function."""
@pytest.mark.parametrize(
"value,expected",
[
(None, None),
("string", "string"),
(42, 42),
(3.14, 3.14),
(True, True),
(False, False),
],
)
def test_basic_types(self, value, expected):
assert _sanitize_value(value) == expected
def test_uuid_and_datetime(self):
"""Test UUID and datetime serialization."""
uuid_val = UUID("123e4567-e89b-12d3-a456-426614174000")
dt = datetime(2024, 1, 15, 12, 30, 45, tzinfo=timezone.utc)
assert _sanitize_value(uuid_val) == "123e4567-e89b-12d3-a456-426614174000"
assert _sanitize_value(dt) == "2024-01-15T12:30:45+00:00"
def test_collections(self):
"""Test list, tuple, and dict serialization."""
assert _sanitize_value(
[1, "string", UUID("123e4567-e89b-12d3-a456-426614174000"), None]
) == [1, "string", "123e4567-e89b-12d3-a456-426614174000", None]
assert _sanitize_value((1, "string", True)) == [1, "string", True]
assert _sanitize_value({"key": UUID("123e4567-e89b-12d3-a456-426614174000")}) == {
"key": "123e4567-e89b-12d3-a456-426614174000"
}
assert _sanitize_value([]) == []
assert _sanitize_value({}) == {}
def test_nested_and_complex(self):
"""Test nested structures and non-serializable types."""
# Nested structure
nested = {"level1": {"level2": {"level3": [1, 2, {"nested": "value"}]}}}
assert _sanitize_value(nested)["level1"]["level2"]["level3"][2]["nested"] == "value"
# Non-serializable
class CustomObject:
def __str__(self):
return "<CustomObject instance>"
result = _sanitize_value(CustomObject())
assert isinstance(result, str)
assert "<cannot be serialized" in result or "<CustomObject" in result
class TestSanitizeDictKey:
"""Test _sanitize_dict_key function."""
@pytest.mark.parametrize(
"key,expected_contains",
[
("simple_key", "simple_key"),
(UUID("123e4567-e89b-12d3-a456-426614174000"), "123e4567-e89b-12d3-a456-426614174000"),
((1, 2, 3), ["1", "2"]),
],
)
def test_key_types(self, key, expected_contains):
result = _sanitize_dict_key(key)
assert isinstance(result, str)
if isinstance(expected_contains, list):
assert all(item in result for item in expected_contains)
else:
assert expected_contains in result
def test_non_serializable_key(self):
class BadKey:
def __str__(self):
return "<BadKey instance>"
result = _sanitize_dict_key(BadKey())
assert isinstance(result, str)
assert "<key:" in result or "<BadKey" in result
class TestGetParamNames:
"""Test _get_param_names function."""
@pytest.mark.parametrize(
"func_def,expected",
[
(lambda a, b, c: None, ["a", "b", "c"]),
(lambda a, b=42, c="default": None, ["a", "b", "c"]),
(lambda a, **kwargs: None, ["a", "kwargs"]),
(lambda *args: None, ["args"]),
],
)
def test_param_extraction(self, func_def, expected):
assert _get_param_names(func_def) == expected
def test_async_function(self):
async def func(a, b):
pass
assert _get_param_names(func) == ["a", "b"]
class TestGetParamDefaults:
"""Test _get_param_defaults function."""
@pytest.mark.parametrize(
"func_def,expected",
[
(lambda a, b=42, c="default", d=None: None, {"b": 42, "c": "default", "d": None}),
(lambda a, b, c: None, {}),
(lambda a, b=10, c="test", d=None: None, {"b": 10, "c": "test", "d": None}),
],
)
def test_default_extraction(self, func_def, expected):
assert _get_param_defaults(func_def) == expected
class TestExtractUserId:
"""Test _extract_user_id function."""
def test_user_extraction(self):
"""Test extracting user_id from kwargs and args."""
user1 = SimpleNamespace(id=UUID("123e4567-e89b-12d3-a456-426614174000"))
user2 = SimpleNamespace(id="user-123")
# From kwargs
assert (
_extract_user_id((), {"user": user1}, ["user", "other"])
== "123e4567-e89b-12d3-a456-426614174000"
)
# From args
assert _extract_user_id((user2, "other"), {}, ["user", "other"]) == "user-123"
# Not present
assert _extract_user_id(("arg1",), {}, ["param1"]) is None
# None value
assert _extract_user_id((None,), {}, ["user"]) is None
# No id attribute
assert _extract_user_id((SimpleNamespace(name="test"),), {}, ["user"]) is None
class TestExtractParameters:
"""Test _extract_parameters function."""
def test_parameter_extraction(self):
"""Test parameter extraction with various scenarios."""
def func1(param1, param2, user=None):
pass
def func2(param1, param2=42, param3="default", user=None):
pass
def func3():
pass
def func4(param1, user):
pass
# Kwargs only
result = _extract_parameters(
(), {"param1": "v1", "param2": 42}, _get_param_names(func1), func1
)
assert result == {"param1": "v1", "param2": 42}
assert "user" not in result
# Args only
result = _extract_parameters(("v1", 42), {}, _get_param_names(func1), func1)
assert result == {"param1": "v1", "param2": 42}
# Mixed args/kwargs
result = _extract_parameters(("v1",), {"param3": "v3"}, _get_param_names(func2), func2)
assert result["param1"] == "v1" and result["param3"] == "v3"
# Defaults included
result = _extract_parameters(("v1",), {}, _get_param_names(func2), func2)
assert result["param1"] == "v1" and result["param2"] == 42 and result["param3"] == "default"
# No parameters
assert _extract_parameters((), {}, _get_param_names(func3), func3) == {}
# User excluded
user = SimpleNamespace(id="user-123")
result = _extract_parameters(("v1", user), {}, _get_param_names(func4), func4)
assert result == {"param1": "v1"} and "user" not in result
# Fallback when inspection fails
class BadFunc:
pass
result = _extract_parameters(("arg1", "arg2"), {}, [], BadFunc())
assert "arg_0" in result or "arg_1" in result
class TestDecoratorValidation:
"""Test decorator validation and behavior."""
def test_decorator_validation(self):
"""Test decorator validation and metadata preservation."""
# Sync function raises error
with pytest.raises(UsageLoggerError, match="requires an async function"):
@log_usage()
def sync_func():
pass
# Async function accepted
@log_usage()
async def async_func():
pass
assert callable(async_func)
# Metadata preserved
@log_usage(function_name="test_func", log_type="test")
async def test_func(param1: str, param2: int = 42):
"""Test docstring."""
return param1
assert test_func.__name__ == "test_func"
assert "Test docstring" in test_func.__doc__

View file

@ -11,34 +11,24 @@ echo "Debug port: $DEBUG_PORT"
echo "HTTP port: $HTTP_PORT"
# Run Alembic migrations with proper error handling.
# Note on UserAlreadyExists error handling:
# During database migrations, we attempt to create a default user. If this user
# already exists (e.g., from a previous deployment or migration), it's not a
# critical error and shouldn't prevent the application from starting. This is
# different from other migration errors which could indicate database schema
# inconsistencies and should cause the startup to fail. This check allows for
# smooth redeployments and container restarts while maintaining data integrity.
echo "Running database migrations..."
# Move to the cognee directory to run alembic migrations from there
set +e # Disable exit on error to handle specific migration errors
MIGRATION_OUTPUT=$(alembic upgrade head)
MIGRATION_OUTPUT=$(cd cognee && alembic upgrade head)
MIGRATION_EXIT_CODE=$?
set -e
if [[ $MIGRATION_EXIT_CODE -ne 0 ]]; then
if [[ "$MIGRATION_OUTPUT" == *"UserAlreadyExists"* ]] || [[ "$MIGRATION_OUTPUT" == *"User default_user@example.com already exists"* ]]; then
echo "Warning: Default user already exists, continuing startup..."
else
echo "Migration failed with unexpected error. Trying to run Cognee without migrations."
echo "Migration failed with unexpected error. Trying to run Cognee without migrations."
echo "Initializing database tables..."
python /app/cognee/modules/engine/operations/setup.py
INIT_EXIT_CODE=$?
echo "Initializing database tables..."
python /app/cognee/modules/engine/operations/setup.py
INIT_EXIT_CODE=$?
if [[ $INIT_EXIT_CODE -ne 0 ]]; then
echo "Database initialization failed!"
exit 1
fi
if [[ $INIT_EXIT_CODE -ne 0 ]]; then
echo "Database initialization failed!"
exit 1
fi
else
echo "Database migrations done."

View file

@ -0,0 +1,39 @@
import asyncio
from typing import Any
from pydantic import SkipValidation
import cognee
from cognee.infrastructure.engine import DataPoint
from cognee.infrastructure.engine.models.Edge import Edge
from cognee.tasks.storage import add_data_points
class Person(DataPoint):
name: str
# Keep it simple for forward refs / mixed values
knows: SkipValidation[Any] = None # single Person or list[Person]
# Recommended: specify which fields to index for search
metadata: dict = {"index_fields": ["name"]}
async def main():
# Start clean (optional in your app)
await cognee.prune.prune_data()
await cognee.prune.prune_system(metadata=True)
alice = Person(name="Alice")
bob = Person(name="Bob")
charlie = Person(name="Charlie")
# Create relationships - field name becomes edge label
alice.knows = bob
# You can also do lists: alice.knows = [bob, charlie]
# Optional: add weights and custom relationship types
bob.knows = (Edge(weight=0.9, relationship_type="friend_of"), charlie)
await add_data_points([alice, bob, charlie])
if __name__ == "__main__":
asyncio.run(main())
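
One possible follow-up to this example (not part of the added file): render the resulting graph with the visualize_graph helper used by another example in this commit. The output path is illustrative.

import asyncio

from cognee.api.v1.visualize.visualize import visualize_graph


async def visualize():
    # Writes whatever is currently in the graph store to a standalone HTML file.
    await visualize_graph("./people_graph.html")


if __name__ == "__main__":
    asyncio.run(visualize())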

View file

@ -0,0 +1,30 @@
import asyncio
import cognee
from cognee.api.v1.search import SearchType
custom_prompt = """
Extract only people and cities as entities.
Connect people to cities with the relationship "lives_in".
Ignore all other entities.
"""
async def main():
await cognee.add(
[
"Alice moved to Paris in 2010, while Bob has always lived in New York.",
"Andreas was born in Venice, but later settled in Lisbon.",
"Diana and Tom were born and raised in Helsinki. Diana currently resides in Berlin, while Tom never moved.",
]
)
await cognee.cognify(custom_prompt=custom_prompt)
res = await cognee.search(
query_type=SearchType.GRAPH_COMPLETION,
query_text="Where does Alice live?",
)
print(res)
if __name__ == "__main__":
asyncio.run(main())

View file

@ -0,0 +1,53 @@
import asyncio
from typing import Any, Dict, List
from pydantic import BaseModel, SkipValidation
import cognee
from cognee.modules.engine.operations.setup import setup
from cognee.infrastructure.llm.LLMGateway import LLMGateway
from cognee.infrastructure.engine import DataPoint
from cognee.tasks.storage import add_data_points
from cognee.modules.pipelines import Task, run_pipeline
class Person(DataPoint):
name: str
# Optional relationships (we'll let the LLM populate this)
knows: List["Person"] = []
# Make names searchable in the vector store
metadata: Dict[str, Any] = {"index_fields": ["name"]}
class People(BaseModel):
persons: List[Person]
async def extract_people(text: str) -> List[Person]:
system_prompt = (
"Extract people mentioned in the text. "
"Return as `persons: Person[]` with each Person having `name` and optional `knows` relations. "
"If the text says someone knows someone set `knows` accordingly. "
"Only include facts explicitly stated."
)
people = await LLMGateway.acreate_structured_output(text, system_prompt, People)
return people.persons
async def main():
await cognee.prune.prune_data()
await cognee.prune.prune_system(metadata=True)
await setup()
text = "Alice knows Bob."
tasks = [
Task(extract_people), # input: text -> output: list[Person]
Task(add_data_points), # input: list[Person] -> output: list[Person]
]
async for _ in run_pipeline(tasks=tasks, data=text, datasets=["people_demo"]):
pass
if __name__ == "__main__":
asyncio.run(main())
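
A possible follow-up (not part of the added file): once the people_demo pipeline has run, the graph can be queried with the same search call the other examples in this commit use. The question text and dataset filter are illustrative.

import asyncio

import cognee
from cognee.api.v1.search import SearchType


async def ask():
    # Assumes the pipeline above has already populated the "people_demo" dataset.
    results = await cognee.search(
        query_type=SearchType.GRAPH_COMPLETION,
        query_text="Who does Alice know?",
        datasets=["people_demo"],
    )
    print(results)


if __name__ == "__main__":
    asyncio.run(ask())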

View file

@ -0,0 +1,14 @@
import asyncio
import cognee
from cognee.api.v1.visualize.visualize import visualize_graph
async def main():
await cognee.add(["Alice knows Bob.", "NLP is a subfield of CS."])
await cognee.cognify()
await visualize_graph("./graph_after_cognify.html")
if __name__ == "__main__":
asyncio.run(main())

View file

@ -0,0 +1,31 @@
import asyncio
from pydantic import BaseModel
from typing import List
from cognee.infrastructure.llm.LLMGateway import LLMGateway
class MiniEntity(BaseModel):
name: str
type: str
class MiniGraph(BaseModel):
nodes: List[MiniEntity]
async def main():
system_prompt = (
"Extract entities as nodes with name and type. "
"Use concise, literal values present in the text."
)
text = "Apple develops iPhone; Audi produces the R8."
result = await LLMGateway.acreate_structured_output(text, system_prompt, MiniGraph)
print(result)
# MiniGraph(nodes=[MiniEntity(name='Apple', type='Organization'), ...])
if __name__ == "__main__":
asyncio.run(main())

View file

@ -0,0 +1,30 @@
import asyncio
import cognee
from cognee.modules.search.types import SearchType
async def main():
# 1) Add two short chats and build a graph
await cognee.add(
[
"We follow PEP8. Add type hints and docstrings.",
"Releases should not be on Friday. Susan must review PRs.",
],
dataset_name="rules_demo",
)
await cognee.cognify(datasets=["rules_demo"]) # builds graph
# 2) Enrich the graph (uses default memify tasks)
await cognee.memify(dataset="rules_demo")
# 3) Query the new coding rules
rules = await cognee.search(
query_type=SearchType.CODING_RULES,
query_text="List coding rules",
node_name=["coding_agent_rules"],
)
print("Rules:", rules)
if __name__ == "__main__":
asyncio.run(main())

View file

@ -0,0 +1,290 @@
<?xml version="1.0" encoding="UTF-8"?>
<rdf:RDF
xmlns:ns1="http://example.org/ontology#"
xmlns:rdf="http://www.w3.org/1999/02/22-rdf-syntax-ns#"
xmlns:rdfs="http://www.w3.org/2000/01/rdf-schema#"
>
<rdf:Description rdf:about="http://example.org/ontology#Volkswagen">
<rdfs:comment>Created for making cars accessible to everyone.</rdfs:comment>
<ns1:produces rdf:resource="http://example.org/ontology#VW_Golf"/>
<ns1:produces rdf:resource="http://example.org/ontology#VW_ID4"/>
<ns1:produces rdf:resource="http://example.org/ontology#VW_Touareg"/>
<rdf:type rdf:resource="http://example.org/ontology#CarManufacturer"/>
<rdf:type rdf:resource="http://www.w3.org/2002/07/owl#NamedIndividual"/>
</rdf:Description>
<rdf:Description rdf:about="http://example.org/ontology#Azure">
<rdf:type rdf:resource="http://example.org/ontology#CloudServiceProvider"/>
<rdf:type rdf:resource="http://www.w3.org/2002/07/owl#NamedIndividual"/>
</rdf:Description>
<rdf:Description rdf:about="http://example.org/ontology#Porsche">
<ns1:produces rdf:resource="http://example.org/ontology#Porsche_Cayenne"/>
<ns1:produces rdf:resource="http://example.org/ontology#Porsche_Taycan"/>
<ns1:produces rdf:resource="http://example.org/ontology#Porsche_911"/>
<rdf:type rdf:resource="http://example.org/ontology#CarManufacturer"/>
<rdf:type rdf:resource="http://www.w3.org/2002/07/owl#NamedIndividual"/>
<rdfs:comment>Famous for high-performance sports cars.</rdfs:comment>
</rdf:Description>
<rdf:Description rdf:about="http://example.org/ontology#Meta">
<rdf:type rdf:resource="http://example.org/ontology#TechnologyCompany"/>
<rdf:type rdf:resource="http://www.w3.org/2002/07/owl#NamedIndividual"/>
<ns1:develops rdf:resource="http://example.org/ontology#Instagram"/>
<ns1:develops rdf:resource="http://example.org/ontology#Facebook"/>
<ns1:develops rdf:resource="http://example.org/ontology#Oculus"/>
<ns1:develops rdf:resource="http://example.org/ontology#WhatsApp"/>
<rdfs:comment>Pioneering social media and virtual reality technology.</rdfs:comment>
</rdf:Description>
<rdf:Description rdf:about="http://example.org/ontology#TechnologyCompany">
<rdf:type rdf:resource="http://www.w3.org/2002/07/owl#Class"/>
</rdf:Description>
<rdf:Description rdf:about="http://example.org/ontology#Apple">
<rdf:type rdf:resource="http://example.org/ontology#TechnologyCompany"/>
<rdf:type rdf:resource="http://www.w3.org/2002/07/owl#NamedIndividual"/>
<rdfs:comment>Known for its innovative consumer electronics and software.</rdfs:comment>
<ns1:develops rdf:resource="http://example.org/ontology#iPad"/>
<ns1:develops rdf:resource="http://example.org/ontology#iPhone"/>
<ns1:develops rdf:resource="http://example.org/ontology#AppleWatch"/>
<ns1:develops rdf:resource="http://example.org/ontology#MacBook"/>
</rdf:Description>
<rdf:Description rdf:about="http://example.org/ontology#Audi">
<ns1:produces rdf:resource="http://example.org/ontology#Audi_eTron"/>
<ns1:produces rdf:resource="http://example.org/ontology#Audi_R8"/>
<ns1:produces rdf:resource="http://example.org/ontology#Audi_A8"/>
<rdf:type rdf:resource="http://example.org/ontology#CarManufacturer"/>
<rdf:type rdf:resource="http://www.w3.org/2002/07/owl#NamedIndividual"/>
<rdfs:comment>Known for its modern designs and technology.</rdfs:comment>
</rdf:Description>
<rdf:Description rdf:about="http://example.org/ontology#AmazonEcho">
<rdf:type rdf:resource="http://example.org/ontology#HardwareCompany"/>
<rdf:type rdf:resource="http://www.w3.org/2002/07/owl#NamedIndividual"/>
</rdf:Description>
<rdf:Description rdf:about="http://example.org/ontology#Porsche_Taycan">
<rdf:type rdf:resource="http://example.org/ontology#ElectricCar"/>
<rdf:type rdf:resource="http://www.w3.org/2002/07/owl#NamedIndividual"/>
</rdf:Description>
<rdf:Description rdf:about="http://example.org/ontology#BMW">
<ns1:produces rdf:resource="http://example.org/ontology#BMW_7Series"/>
<ns1:produces rdf:resource="http://example.org/ontology#BMW_M4"/>
<ns1:produces rdf:resource="http://example.org/ontology#BMW_iX"/>
<rdf:type rdf:resource="http://example.org/ontology#CarManufacturer"/>
<rdf:type rdf:resource="http://www.w3.org/2002/07/owl#NamedIndividual"/>
<rdfs:comment>Focused on performance and driving pleasure.</rdfs:comment>
</rdf:Description>
<rdf:Description rdf:about="http://example.org/ontology#VW_Touareg">
<rdf:type rdf:resource="http://example.org/ontology#SUV"/>
<rdf:type rdf:resource="http://www.w3.org/2002/07/owl#NamedIndividual"/>
</rdf:Description>
<rdf:Description rdf:about="http://example.org/ontology#SportsCar">
<rdf:type rdf:resource="http://www.w3.org/2002/07/owl#Class"/>
<rdfs:subClassOf rdf:resource="http://example.org/ontology#Car"/>
</rdf:Description>
<rdf:Description rdf:about="http://example.org/ontology#ElectricCar">
<rdfs:subClassOf rdf:resource="http://example.org/ontology#Car"/>
<rdf:type rdf:resource="http://www.w3.org/2002/07/owl#Class"/>
</rdf:Description>
<rdf:Description rdf:about="http://example.org/ontology#Google">
<ns1:develops rdf:resource="http://example.org/ontology#GooglePixel"/>
<ns1:develops rdf:resource="http://example.org/ontology#GoogleCloud"/>
<ns1:develops rdf:resource="http://example.org/ontology#Android"/>
<ns1:develops rdf:resource="http://example.org/ontology#GoogleSearch"/>
<rdf:type rdf:resource="http://example.org/ontology#TechnologyCompany"/>
<rdf:type rdf:resource="http://www.w3.org/2002/07/owl#NamedIndividual"/>
<rdfs:comment>Started as a search engine and expanded into cloud computing and AI.</rdfs:comment>
</rdf:Description>
<rdf:Description rdf:about="http://example.org/ontology#AmazonPrime">
<rdf:type rdf:resource="http://example.org/ontology#SoftwareCompany"/>
<rdf:type rdf:resource="http://www.w3.org/2002/07/owl#NamedIndividual"/>
</rdf:Description>
<rdf:Description rdf:about="http://example.org/ontology#Car">
<rdf:type rdf:resource="http://www.w3.org/2002/07/owl#Class"/>
</rdf:Description>
<rdf:Description rdf:about="http://example.org/ontology#WindowsOS">
<rdf:type rdf:resource="http://example.org/ontology#SoftwareCompany"/>
<rdf:type rdf:resource="http://www.w3.org/2002/07/owl#NamedIndividual"/>
</rdf:Description>
<rdf:Description rdf:about="http://example.org/ontology#Android">
<rdf:type rdf:resource="http://example.org/ontology#SoftwareCompany"/>
<rdf:type rdf:resource="http://www.w3.org/2002/07/owl#NamedIndividual"/>
</rdf:Description>
<rdf:Description rdf:about="http://example.org/ontology#Oculus">
<rdf:type rdf:resource="http://example.org/ontology#HardwareCompany"/>
<rdf:type rdf:resource="http://www.w3.org/2002/07/owl#NamedIndividual"/>
</rdf:Description>
<rdf:Description rdf:about="http://example.org/ontology#GoogleCloud">
<rdf:type rdf:resource="http://example.org/ontology#CloudServiceProvider"/>
<rdf:type rdf:resource="http://www.w3.org/2002/07/owl#NamedIndividual"/>
</rdf:Description>
<rdf:Description rdf:about="http://example.org/ontology#Microsoft">
<ns1:develops rdf:resource="http://example.org/ontology#Surface"/>
<ns1:develops rdf:resource="http://example.org/ontology#WindowsOS"/>
<ns1:develops rdf:resource="http://example.org/ontology#Azure"/>
<ns1:develops rdf:resource="http://example.org/ontology#Xbox"/>
<rdf:type rdf:resource="http://example.org/ontology#TechnologyCompany"/>
<rdf:type rdf:resource="http://www.w3.org/2002/07/owl#NamedIndividual"/>
<rdfs:comment>Dominant in software, cloud computing, and gaming.</rdfs:comment>
</rdf:Description>
<rdf:Description rdf:about="http://example.org/ontology#GoogleSearch">
<rdf:type rdf:resource="http://example.org/ontology#SoftwareCompany"/>
<rdf:type rdf:resource="http://www.w3.org/2002/07/owl#NamedIndividual"/>
</rdf:Description>
<rdf:Description rdf:about="http://example.org/ontology#Mercedes_SClass">
<rdf:type rdf:resource="http://example.org/ontology#LuxuryCar"/>
<rdf:type rdf:resource="http://www.w3.org/2002/07/owl#NamedIndividual"/>
</rdf:Description>
<rdf:Description rdf:about="http://example.org/ontology#Audi_A8">
<rdf:type rdf:resource="http://example.org/ontology#LuxuryCar"/>
<rdf:type rdf:resource="http://www.w3.org/2002/07/owl#NamedIndividual"/>
</rdf:Description>
<rdf:Description rdf:about="http://example.org/ontology#Sedan">
<rdfs:subClassOf rdf:resource="http://example.org/ontology#Car"/>
<rdf:type rdf:resource="http://www.w3.org/2002/07/owl#Class"/>
</rdf:Description>
<rdf:Description rdf:about="http://example.org/ontology#VW_Golf">
<rdf:type rdf:resource="http://example.org/ontology#Sedan"/>
<rdf:type rdf:resource="http://www.w3.org/2002/07/owl#NamedIndividual"/>
</rdf:Description>
<rdf:Description rdf:about="http://example.org/ontology#Facebook">
<rdf:type rdf:resource="http://example.org/ontology#SoftwareCompany"/>
<rdf:type rdf:resource="http://www.w3.org/2002/07/owl#NamedIndividual"/>
</rdf:Description>
<rdf:Description rdf:about="http://example.org/ontology#WhatsApp">
<rdf:type rdf:resource="http://example.org/ontology#SoftwareCompany"/>
<rdf:type rdf:resource="http://www.w3.org/2002/07/owl#NamedIndividual"/>
</rdf:Description>
<rdf:Description rdf:about="http://example.org/ontology#produces">
<rdfs:domain rdf:resource="http://example.org/ontology#CarManufacturer"/>
<rdfs:range rdf:resource="http://example.org/ontology#Car"/>
<rdf:type rdf:resource="http://www.w3.org/2002/07/owl#ObjectProperty"/>
</rdf:Description>
<rdf:Description rdf:about="http://example.org/ontology#BMW_7Series">
<rdf:type rdf:resource="http://example.org/ontology#LuxuryCar"/>
<rdf:type rdf:resource="http://www.w3.org/2002/07/owl#NamedIndividual"/>
</rdf:Description>
<rdf:Description rdf:about="http://example.org/ontology#BMW_M4">
<rdf:type rdf:resource="http://example.org/ontology#SportsCar"/>
<rdf:type rdf:resource="http://www.w3.org/2002/07/owl#NamedIndividual"/>
</rdf:Description>
<rdf:Description rdf:about="http://example.org/ontology#Audi_eTron">
<rdf:type rdf:resource="http://example.org/ontology#ElectricCar"/>
<rdf:type rdf:resource="http://www.w3.org/2002/07/owl#NamedIndividual"/>
</rdf:Description>
<rdf:Description rdf:about="http://example.org/ontology#Kindle">
<rdf:type rdf:resource="http://example.org/ontology#HardwareCompany"/>
<rdf:type rdf:resource="http://www.w3.org/2002/07/owl#NamedIndividual"/>
</rdf:Description>
<rdf:Description rdf:about="http://example.org/ontology#BMW_iX">
<rdf:type rdf:resource="http://example.org/ontology#ElectricCar"/>
<rdf:type rdf:resource="http://www.w3.org/2002/07/owl#NamedIndividual"/>
</rdf:Description>
<rdf:Description rdf:about="http://example.org/ontology#SoftwareCompany">
<rdfs:subClassOf rdf:resource="http://example.org/ontology#TechnologyCompany"/>
<rdf:type rdf:resource="http://www.w3.org/2002/07/owl#Class"/>
</rdf:Description>
<rdf:Description rdf:about="http://example.org/ontology#Audi_R8">
<rdf:type rdf:resource="http://example.org/ontology#SportsCar"/>
<rdf:type rdf:resource="http://www.w3.org/2002/07/owl#NamedIndividual"/>
</rdf:Description>
<rdf:Description rdf:about="http://example.org/ontology#Xbox">
<rdf:type rdf:resource="http://example.org/ontology#HardwareCompany"/>
<rdf:type rdf:resource="http://www.w3.org/2002/07/owl#NamedIndividual"/>
</rdf:Description>
<rdf:Description rdf:about="http://example.org/ontology#Technology">
<rdf:type rdf:resource="http://www.w3.org/2002/07/owl#Class"/>
</rdf:Description>
<rdf:Description rdf:about="http://example.org/ontology#Mercedes_EQS">
<rdf:type rdf:resource="http://example.org/ontology#ElectricCar"/>
<rdf:type rdf:resource="http://www.w3.org/2002/07/owl#NamedIndividual"/>
</rdf:Description>
<rdf:Description rdf:about="http://example.org/ontology#Porsche_911">
<rdf:type rdf:resource="http://example.org/ontology#SportsCar"/>
<rdf:type rdf:resource="http://www.w3.org/2002/07/owl#NamedIndividual"/>
</rdf:Description>
<rdf:Description rdf:about="http://example.org/ontology#HardwareCompany">
<rdf:type rdf:resource="http://www.w3.org/2002/07/owl#Class"/>
<rdfs:subClassOf rdf:resource="http://example.org/ontology#TechnologyCompany"/>
</rdf:Description>
<rdf:Description rdf:about="http://example.org/ontology#MercedesBenz">
<ns1:produces rdf:resource="http://example.org/ontology#Mercedes_SClass"/>
<ns1:produces rdf:resource="http://example.org/ontology#Mercedes_EQS"/>
<ns1:produces rdf:resource="http://example.org/ontology#Mercedes_AMG_GT"/>
<rdfs:comment>Synonymous with luxury and quality.</rdfs:comment>
<rdf:type rdf:resource="http://example.org/ontology#CarManufacturer"/>
<rdf:type rdf:resource="http://www.w3.org/2002/07/owl#NamedIndividual"/>
</rdf:Description>
<rdf:Description rdf:about="http://example.org/ontology#Amazon">
<rdf:type rdf:resource="http://example.org/ontology#TechnologyCompany"/>
<rdf:type rdf:resource="http://www.w3.org/2002/07/owl#NamedIndividual"/>
<ns1:develops rdf:resource="http://example.org/ontology#Kindle"/>
<ns1:develops rdf:resource="http://example.org/ontology#AmazonEcho"/>
<ns1:develops rdf:resource="http://example.org/ontology#AWS"/>
<ns1:develops rdf:resource="http://example.org/ontology#AmazonPrime"/>
<rdfs:comment>From e-commerce to cloud computing giant with AWS.</rdfs:comment>
</rdf:Description>
<rdf:Description rdf:about="http://example.org/ontology#Instagram">
<rdf:type rdf:resource="http://example.org/ontology#SoftwareCompany"/>
<rdf:type rdf:resource="http://www.w3.org/2002/07/owl#NamedIndividual"/>
</rdf:Description>
<rdf:Description rdf:about="http://example.org/ontology#AWS">
<rdf:type rdf:resource="http://example.org/ontology#CloudServiceProvider"/>
<rdf:type rdf:resource="http://www.w3.org/2002/07/owl#NamedIndividual"/>
</rdf:Description>
<rdf:Description rdf:about="http://example.org/ontology#SUV">
<rdf:type rdf:resource="http://www.w3.org/2002/07/owl#Class"/>
<rdfs:subClassOf rdf:resource="http://example.org/ontology#Car"/>
</rdf:Description>
<rdf:Description rdf:about="http://example.org/ontology#VW_ID4">
<rdf:type rdf:resource="http://example.org/ontology#ElectricCar"/>
<rdf:type rdf:resource="http://www.w3.org/2002/07/owl#NamedIndividual"/>
</rdf:Description>
<rdf:Description rdf:about="http://example.org/ontology#CloudServiceProvider">
<rdfs:subClassOf rdf:resource="http://example.org/ontology#TechnologyCompany"/>
<rdf:type rdf:resource="http://www.w3.org/2002/07/owl#Class"/>
</rdf:Description>
<rdf:Description rdf:about="http://example.org/ontology#Surface">
<rdf:type rdf:resource="http://example.org/ontology#HardwareCompany"/>
<rdf:type rdf:resource="http://www.w3.org/2002/07/owl#NamedIndividual"/>
</rdf:Description>
<rdf:Description rdf:about="http://example.org/ontology#iPad">
<rdf:type rdf:resource="http://example.org/ontology#HardwareCompany"/>
<rdf:type rdf:resource="http://www.w3.org/2002/07/owl#NamedIndividual"/>
</rdf:Description>
<rdf:Description rdf:about="http://example.org/ontology#iPhone">
<rdf:type rdf:resource="http://example.org/ontology#HardwareCompany"/>
<rdf:type rdf:resource="http://www.w3.org/2002/07/owl#NamedIndividual"/>
</rdf:Description>
<rdf:Description rdf:about="http://example.org/ontology#Mercedes_AMG_GT">
<rdf:type rdf:resource="http://example.org/ontology#SportsCar"/>
<rdf:type rdf:resource="http://www.w3.org/2002/07/owl#NamedIndividual"/>
</rdf:Description>
<rdf:Description rdf:about="http://example.org/ontology#MacBook">
<rdf:type rdf:resource="http://example.org/ontology#HardwareCompany"/>
<rdf:type rdf:resource="http://www.w3.org/2002/07/owl#NamedIndividual"/>
</rdf:Description>
<rdf:Description rdf:about="http://example.org/ontology#develops">
<rdf:type rdf:resource="http://www.w3.org/2002/07/owl#ObjectProperty"/>
<rdfs:range rdf:resource="http://example.org/ontology#Technology"/>
<rdfs:domain rdf:resource="http://example.org/ontology#TechnologyCompany"/>
</rdf:Description>
<rdf:Description rdf:about="http://example.org/ontology#LuxuryCar">
<rdf:type rdf:resource="http://www.w3.org/2002/07/owl#Class"/>
<rdfs:subClassOf rdf:resource="http://example.org/ontology#Car"/>
</rdf:Description>
<rdf:Description rdf:about="http://example.org/ontology#AppleWatch">
<rdf:type rdf:resource="http://example.org/ontology#HardwareCompany"/>
<rdf:type rdf:resource="http://www.w3.org/2002/07/owl#NamedIndividual"/>
</rdf:Description>
<rdf:Description rdf:about="http://example.org/ontology#Porsche_Cayenne">
<rdf:type rdf:resource="http://example.org/ontology#SUV"/>
<rdf:type rdf:resource="http://www.w3.org/2002/07/owl#NamedIndividual"/>
</rdf:Description>
<rdf:Description rdf:about="http://example.org/ontology#GooglePixel">
<rdf:type rdf:resource="http://example.org/ontology#HardwareCompany"/>
<rdf:type rdf:resource="http://www.w3.org/2002/07/owl#NamedIndividual"/>
</rdf:Description>
<rdf:Description rdf:about="http://example.org/ontology#Company">
<rdf:type rdf:resource="http://www.w3.org/2002/07/owl#Class"/>
</rdf:Description>
<rdf:Description rdf:about="http://example.org/ontology#CarManufacturer">
<rdf:type rdf:resource="http://www.w3.org/2002/07/owl#Class"/>
<rdfs:subClassOf rdf:resource="http://example.org/ontology#Company"/>
</rdf:Description>
</rdf:RDF>

View file

@ -0,0 +1,29 @@
import asyncio
import cognee
import os
from cognee.modules.ontology.ontology_config import Config
from cognee.modules.ontology.rdf_xml.RDFLibOntologyResolver import RDFLibOntologyResolver
async def main():
texts = ["Audi produces the R8 and e-tron.", "Apple develops iPhone and MacBook."]
await cognee.add(texts)
# or: await cognee.add("/path/to/folder/of/files")
ontology_path = os.path.join(
os.path.dirname(os.path.abspath(__file__)), "ontology_input_example/basic_ontology.owl"
)
# Create full config structure manually
config: Config = {
"ontology_config": {
"ontology_resolver": RDFLibOntologyResolver(ontology_file=ontology_path)
}
}
await cognee.cognify(config=config)
if __name__ == "__main__":
asyncio.run(main())

View file

@ -0,0 +1,25 @@
import asyncio
import cognee
async def main():
# Single file
await cognee.add("s3://cognee-s3-small-test/Natural_language_processing.txt")
# Folder/prefix (recursively expands)
await cognee.add("s3://cognee-s3-small-test")
# Mixed list
await cognee.add(
[
"s3://cognee-s3-small-test/Natural_language_processing.txt",
"Some inline text to ingest",
]
)
# Process the data
await cognee.cognify()
if __name__ == "__main__":
asyncio.run(main())

View file

@ -0,0 +1,52 @@
import asyncio
import cognee
from cognee.api.v1.search import SearchType
async def main():
await cognee.prune.prune_data()
await cognee.prune.prune_system(metadata=True)
text = """
Natural language processing (NLP) is an interdisciplinary
subfield of computer science and information retrieval.
First rule of coding: Do not talk about coding.
"""
text2 = """
Sandwiches are best served toasted with cheese, ham, mayo,
lettuce, mustard, and salt & pepper.
"""
await cognee.add(text, dataset_name="NLP_coding")
await cognee.add(text2, dataset_name="Sandwiches")
await cognee.add(text2)
await cognee.cognify()
# Make sure you've already run cognee.cognify(...) so the graph has content
answers = await cognee.search(query_text="What are the main themes in my data?")
assert len(answers) > 0
answers = await cognee.search(
query_text="List coding guidelines",
query_type=SearchType.CODING_RULES,
)
assert len(answers) > 0
answers = await cognee.search(
query_text="Give me a confident answer: What is NLP?",
system_prompt="Answer succinctly and state confidence at the end.",
)
assert len(answers) > 0
answers = await cognee.search(
query_text="Tell me about NLP",
only_context=True,
)
assert len(answers) > 0
if __name__ == "__main__":
asyncio.run(main())

View file

@ -0,0 +1,27 @@
import asyncio
import cognee
async def main():
# Start clean (optional in your app)
await cognee.prune.prune_data()
await cognee.prune.prune_system(metadata=True)
# Prepare knowledge base
await cognee.add(
[
"Alice moved to Paris in 2010. She works as a software engineer.",
"Bob lives in New York. He is a data scientist.",
"Alice and Bob met at a conference in 2015.",
]
)
await cognee.cognify()
# Make sure you've already run cognee.cognify(...) so the graph has content
answers = await cognee.search(query_text="What are the main themes in my data?")
for answer in answers:
print(answer)
if __name__ == "__main__":
asyncio.run(main())

View file

@ -0,0 +1,57 @@
import asyncio
import cognee
async def main():
text = """
In 1998 the project launched. In 2001 version 1.0 shipped. In 2004 the team merged
with another group. In 2010 support for v1 ended.
"""
await cognee.add(text, dataset_name="timeline_demo")
await cognee.cognify(datasets=["timeline_demo"], temporal_cognify=True)
from cognee.api.v1.search import SearchType
# Before / after queries
result = await cognee.search(
query_type=SearchType.TEMPORAL, query_text="What happened before 2000?", top_k=10
)
assert result != []
result = await cognee.search(
query_type=SearchType.TEMPORAL, query_text="What happened after 2010?", top_k=10
)
assert result != []
# Between queries
result = await cognee.search(
query_type=SearchType.TEMPORAL, query_text="Events between 2001 and 2004", top_k=10
)
assert result != []
# Scoped descriptions
result = await cognee.search(
query_type=SearchType.TEMPORAL,
query_text="Key project milestones between 1998 and 2010",
top_k=10,
)
assert result != []
result = await cognee.search(
query_type=SearchType.TEMPORAL,
query_text="What happened after 2004?",
datasets=["timeline_demo"],
top_k=10,
)
assert result != []
if __name__ == "__main__":
asyncio.run(main())