Merge branch 'dev' into feature/cog-3698-enable-batch-queries-in-all-graph-completion-retrievers

commit 6070f9f71f
91 changed files with 2286 additions and 139 deletions

.github/workflows/e2e_tests.yml (vendored, 48 changes)
@@ -659,3 +659,51 @@ jobs:
          EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }}
          EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
        run: uv run python ./cognee/tests/test_pipeline_cache.py

  run_usage_logger_test:
    name: Usage logger test (API/MCP)
    runs-on: ubuntu-latest
    defaults:
      run:
        shell: bash
    services:
      redis:
        image: redis:7
        ports:
          - 6379:6379
        options: >-
          --health-cmd "redis-cli ping"
          --health-interval 5s
          --health-timeout 3s
          --health-retries 5

    steps:
      - name: Checkout repository
        uses: actions/checkout@v4

      - name: Cognee Setup
        uses: ./.github/actions/cognee_setup
        with:
          python-version: '3.11.x'
          extra-dependencies: "redis"

      - name: Install cognee-mcp (local version)
        shell: bash
        run: |
          uv pip install -e ./cognee-mcp

      - name: Run api/tool usage logger
        env:
          ENV: dev
          LLM_MODEL: ${{ secrets.LLM_MODEL }}
          LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
          LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
          LLM_API_VERSION: ${{ secrets.LLM_API_VERSION }}
          EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }}
          EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }}
          EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }}
          EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
          GRAPH_DATABASE_PROVIDER: 'kuzu'
          USAGE_LOGGING: true
          CACHE_BACKEND: 'redis'
        run: uv run pytest cognee/tests/test_usage_logger_e2e.py -v --log-level=INFO
@ -34,10 +34,6 @@ COPY README.md pyproject.toml uv.lock entrypoint.sh ./
|
|||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
uv sync --extra debug --extra api --extra postgres --extra neo4j --extra llama-index --extra ollama --extra mistral --extra groq --extra anthropic --extra chromadb --frozen --no-install-project --no-dev --no-editable
|
||||
|
||||
# Copy Alembic configuration
|
||||
COPY alembic.ini /app/alembic.ini
|
||||
COPY alembic/ /app/alembic
|
||||
|
||||
# Then, add the rest of the project source code and install it
|
||||
# Installing separately from its dependencies allows optimal layer caching
|
||||
COPY ./cognee /app/cognee
|
||||
|
|
|
|||
|
|
@ -34,8 +34,8 @@ RUN --mount=type=cache,target=/root/.cache/uv \
|
|||
uv sync --frozen --no-install-project --no-dev --no-editable
|
||||
|
||||
# Copy Alembic configuration
|
||||
COPY alembic.ini /app/alembic.ini
|
||||
COPY alembic/ /app/alembic
|
||||
COPY cognee/alembic.ini /app/cognee/alembic.ini
|
||||
COPY cognee/alembic/ /app/cognee/alembic
|
||||
|
||||
# Then, add the rest of the project source code and install it
|
||||
# Installing separately from its dependencies allows optimal layer caching
|
||||
|
|
|
|||
|
|
@ -56,24 +56,22 @@ if [ -n "$API_URL" ]; then
|
|||
echo "Skipping database migrations (API server handles its own database)"
|
||||
else
|
||||
echo "Direct mode: Using local cognee instance"
|
||||
# Run Alembic migrations with proper error handling.
|
||||
# Note on UserAlreadyExists error handling:
|
||||
# During database migrations, we attempt to create a default user. If this user
|
||||
# already exists (e.g., from a previous deployment or migration), it's not a
|
||||
# critical error and shouldn't prevent the application from starting. This is
|
||||
# different from other migration errors which could indicate database schema
|
||||
# inconsistencies and should cause the startup to fail. This check allows for
|
||||
# smooth redeployments and container restarts while maintaining data integrity.
|
||||
echo "Running database migrations..."
|
||||
|
||||
MIGRATION_OUTPUT=$(alembic upgrade head)
|
||||
set +e # Disable exit on error to handle specific migration errors
|
||||
MIGRATION_OUTPUT=$(cd cognee && alembic upgrade head)
|
||||
MIGRATION_EXIT_CODE=$?
|
||||
set -e
|
||||
|
||||
if [[ $MIGRATION_EXIT_CODE -ne 0 ]]; then
|
||||
if [[ "$MIGRATION_OUTPUT" == *"UserAlreadyExists"* ]] || [[ "$MIGRATION_OUTPUT" == *"User default_user@example.com already exists"* ]]; then
|
||||
echo "Warning: Default user already exists, continuing startup..."
|
||||
else
|
||||
echo "Migration failed with unexpected error."
|
||||
|
||||
echo "Migration failed with unexpected error. Trying to run Cognee without migrations."
|
||||
echo "Initializing database tables..."
|
||||
python /app/src/run_cognee_database_setup.py
|
||||
INIT_EXIT_CODE=$?
|
||||
|
||||
if [[ $INIT_EXIT_CODE -ne 0 ]]; then
|
||||
echo "Database initialization failed!"
|
||||
exit 1
|
||||
fi
|
||||
fi
|
||||
|
|
|
|||
|
|
@ -8,7 +8,6 @@ requires-python = ">=3.10"
|
|||
dependencies = [
|
||||
# For local cognee repo usage, remove the comment below and add the absolute path to cognee. Then run `uv sync --reinstall` in the mcp folder after local cognee changes.
|
||||
#"cognee[postgres,codegraph,gemini,huggingface,docs,neo4j] @ file:/Users/igorilic/Desktop/cognee",
|
||||
# TODO: Remove gemini from optional dependencies for new Cognee version after 0.3.4
|
||||
"cognee[postgres,docs,neo4j]==0.5.0",
|
||||
"fastmcp>=2.10.0,<3.0.0",
|
||||
"mcp>=1.12.0,<2.0.0",
|
||||
|
|
|
|||
cognee-mcp/src/run_cognee_database_setup.py (new file, 5 lines)
|
|
@ -0,0 +1,5 @@
|
|||
from cognee.modules.engine.operations.setup import setup
|
||||
import asyncio
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(setup())
|
||||
|
|
@ -8,6 +8,7 @@ from pathlib import Path
|
|||
from typing import Optional
|
||||
|
||||
from cognee.shared.logging_utils import get_logger, setup_logging, get_log_file_location
|
||||
from cognee.shared.usage_logger import log_usage
|
||||
import importlib.util
|
||||
from contextlib import redirect_stdout
|
||||
import mcp.types as types
|
||||
|
|
@ -91,6 +92,7 @@ async def health_check(request):
|
|||
|
||||
|
||||
@mcp.tool()
|
||||
@log_usage(function_name="MCP cognify", log_type="mcp_tool")
|
||||
async def cognify(
|
||||
data: str, graph_model_file: str = None, graph_model_name: str = None, custom_prompt: str = None
|
||||
) -> list:
|
||||
|
|
@ -257,6 +259,7 @@ async def cognify(
|
|||
@mcp.tool(
|
||||
name="save_interaction", description="Logs user-agent interactions and query-answer pairs"
|
||||
)
|
||||
@log_usage(function_name="MCP save_interaction", log_type="mcp_tool")
|
||||
async def save_interaction(data: str) -> list:
|
||||
"""
|
||||
Transform and save a user-agent interaction into structured knowledge.
|
||||
|
|
@ -316,6 +319,7 @@ async def save_interaction(data: str) -> list:
|
|||
|
||||
|
||||
@mcp.tool()
|
||||
@log_usage(function_name="MCP search", log_type="mcp_tool")
|
||||
async def search(search_query: str, search_type: str, top_k: int = 10) -> list:
|
||||
"""
|
||||
Search and query the knowledge graph for insights, information, and connections.
|
||||
|
|
@ -496,6 +500,7 @@ async def search(search_query: str, search_type: str, top_k: int = 10) -> list:
|
|||
|
||||
|
||||
@mcp.tool()
|
||||
@log_usage(function_name="MCP list_data", log_type="mcp_tool")
|
||||
async def list_data(dataset_id: str = None) -> list:
|
||||
"""
|
||||
List all datasets and their data items with IDs for deletion operations.
|
||||
|
|
@ -624,6 +629,7 @@ async def list_data(dataset_id: str = None) -> list:
|
|||
|
||||
|
||||
@mcp.tool()
|
||||
@log_usage(function_name="MCP delete", log_type="mcp_tool")
|
||||
async def delete(data_id: str, dataset_id: str, mode: str = "soft") -> list:
|
||||
"""
|
||||
Delete specific data from a dataset in the Cognee knowledge graph.
|
||||
|
|
@ -703,6 +709,7 @@ async def delete(data_id: str, dataset_id: str, mode: str = "soft") -> list:
|
|||
|
||||
|
||||
@mcp.tool()
|
||||
@log_usage(function_name="MCP prune", log_type="mcp_tool")
|
||||
async def prune():
|
||||
"""
|
||||
Reset the Cognee knowledge graph by removing all stored information.
|
||||
|
|
@ -739,6 +746,7 @@ async def prune():
|
|||
|
||||
|
||||
@mcp.tool()
|
||||
@log_usage(function_name="MCP cognify_status", log_type="mcp_tool")
|
||||
async def cognify_status():
|
||||
"""
|
||||
Get the current status of the cognify pipeline.
|
||||
|
|
@ -884,26 +892,11 @@ async def main():
|
|||
|
||||
await setup()
|
||||
|
||||
# Run Alembic migrations from the main cognee directory where alembic.ini is located
|
||||
# Run Cognee migrations
|
||||
logger.info("Running database migrations...")
|
||||
migration_result = subprocess.run(
|
||||
["python", "-m", "alembic", "upgrade", "head"],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
cwd=Path(__file__).resolve().parent.parent.parent,
|
||||
)
|
||||
from cognee.run_migrations import run_migrations
|
||||
|
||||
if migration_result.returncode != 0:
|
||||
migration_output = migration_result.stderr + migration_result.stdout
|
||||
# Check for the expected UserAlreadyExists error (which is not critical)
|
||||
if (
|
||||
"UserAlreadyExists" in migration_output
|
||||
or "User default_user@example.com already exists" in migration_output
|
||||
):
|
||||
logger.warning("Warning: Default user already exists, continuing startup...")
|
||||
else:
|
||||
logger.error(f"Migration failed with unexpected error: {migration_output}")
|
||||
sys.exit(1)
|
||||
await run_migrations()
|
||||
|
||||
logger.info("Database migrations done.")
|
||||
elif args.api_url:
|
||||
|
|
|
|||
|
|
@ -33,3 +33,5 @@ from .api.v1.ui import start_ui
|
|||
|
||||
# Pipelines
|
||||
from .modules import pipelines
|
||||
|
||||
from cognee.run_migrations import run_migrations
|
||||
|
|
|
|||
|
|
@ -10,6 +10,7 @@ from cognee.modules.users.methods import get_authenticated_user
|
|||
from cognee.shared.utils import send_telemetry
|
||||
from cognee.modules.pipelines.models import PipelineRunErrored
|
||||
from cognee.shared.logging_utils import get_logger
|
||||
from cognee.shared.usage_logger import log_usage
|
||||
from cognee import __version__ as cognee_version
|
||||
|
||||
logger = get_logger()
|
||||
|
|
@ -19,6 +20,7 @@ def get_add_router() -> APIRouter:
|
|||
router = APIRouter()
|
||||
|
||||
@router.post("", response_model=dict)
|
||||
@log_usage(function_name="POST /v1/add", log_type="api_endpoint")
|
||||
async def add(
|
||||
data: List[UploadFile] = File(default=None),
|
||||
datasetName: Optional[str] = Form(default=None),
|
||||
|
|
|
|||
|
|
@ -29,6 +29,7 @@ from cognee.modules.pipelines.queues.pipeline_run_info_queues import (
|
|||
)
|
||||
from cognee.shared.logging_utils import get_logger
|
||||
from cognee.shared.utils import send_telemetry
|
||||
from cognee.shared.usage_logger import log_usage
|
||||
from cognee import __version__ as cognee_version
|
||||
|
||||
logger = get_logger("api.cognify")
|
||||
|
|
@ -57,6 +58,7 @@ def get_cognify_router() -> APIRouter:
|
|||
router = APIRouter()
|
||||
|
||||
@router.post("", response_model=dict)
|
||||
@log_usage(function_name="POST /v1/cognify", log_type="api_endpoint")
|
||||
async def cognify(payload: CognifyPayloadDTO, user: User = Depends(get_authenticated_user)):
|
||||
"""
|
||||
Transform datasets into structured knowledge graphs through cognitive processing.
|
||||
|
|
|
|||
|
|
@ -12,6 +12,7 @@ from cognee.modules.users.methods import get_authenticated_user
|
|||
from cognee.shared.utils import send_telemetry
|
||||
from cognee.modules.pipelines.models import PipelineRunErrored
|
||||
from cognee.shared.logging_utils import get_logger
|
||||
from cognee.shared.usage_logger import log_usage
|
||||
from cognee import __version__ as cognee_version
|
||||
|
||||
logger = get_logger()
|
||||
|
|
@ -35,6 +36,7 @@ def get_memify_router() -> APIRouter:
|
|||
router = APIRouter()
|
||||
|
||||
@router.post("", response_model=dict)
|
||||
@log_usage(function_name="POST /v1/memify", log_type="api_endpoint")
|
||||
async def memify(payload: MemifyPayloadDTO, user: User = Depends(get_authenticated_user)):
|
||||
"""
|
||||
Enrichment pipeline in Cognee that can work with already built graphs. If no data is provided, the existing knowledge graph will be used as data,
|
||||
|
|
|
|||
|
|
@ -13,6 +13,7 @@ from cognee.modules.users.models import User
|
|||
from cognee.modules.search.operations import get_history
|
||||
from cognee.modules.users.methods import get_authenticated_user
|
||||
from cognee.shared.utils import send_telemetry
|
||||
from cognee.shared.usage_logger import log_usage
|
||||
from cognee import __version__ as cognee_version
|
||||
from cognee.infrastructure.databases.exceptions import DatabaseNotCreatedError
|
||||
from cognee.exceptions import CogneeValidationError
|
||||
|
|
@ -75,6 +76,7 @@ def get_search_router() -> APIRouter:
|
|||
return JSONResponse(status_code=500, content={"error": str(error)})
|
||||
|
||||
@router.post("", response_model=Union[List[SearchResult], List])
|
||||
@log_usage(function_name="POST /v1/search", log_type="api_endpoint")
|
||||
async def search(payload: SearchPayloadDTO, user: User = Depends(get_authenticated_user)):
|
||||
"""
|
||||
Search for nodes in the graph database.
|
||||
|
|
|
|||
|
|
@ -8,10 +8,13 @@ class CacheDBInterface(ABC):
|
|||
Provides a common interface for lock acquisition, release, and context-managed locking.
|
||||
"""
|
||||
|
||||
def __init__(self, host: str, port: int, lock_key: str):
|
||||
def __init__(
|
||||
self, host: str, port: int, lock_key: str = "default_lock", log_key: str = "usage_logs"
|
||||
):
|
||||
self.host = host
|
||||
self.port = port
|
||||
self.lock_key = lock_key
|
||||
self.log_key = log_key
|
||||
self.lock = None
|
||||
|
||||
@abstractmethod
|
||||
|
|
@ -77,3 +80,37 @@ class CacheDBInterface(ABC):
|
|||
Gracefully close any async connections.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def log_usage(
|
||||
self,
|
||||
user_id: str,
|
||||
log_entry: dict,
|
||||
ttl: int | None = 604800,
|
||||
):
|
||||
"""
|
||||
Log usage information (API endpoint calls, MCP tool invocations) to cache.
|
||||
|
||||
Args:
|
||||
user_id: The user ID.
|
||||
log_entry: Dictionary containing usage log information.
|
||||
ttl: Optional time-to-live (seconds). If provided, the log list expires after this time.
|
||||
|
||||
Raises:
|
||||
CacheConnectionError: If cache connection fails or times out.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def get_usage_logs(self, user_id: str, limit: int = 100):
|
||||
"""
|
||||
Retrieve usage logs for a given user.
|
||||
|
||||
Args:
|
||||
user_id: The user ID.
|
||||
limit: Maximum number of logs to retrieve (default: 100).
|
||||
|
||||
Returns:
|
||||
List of usage log entries, most recent first.
|
||||
"""
|
||||
pass
|
||||
|
|
|
|||
|
|
@ -13,6 +13,8 @@ class CacheConfig(BaseSettings):
|
|||
- cache_port: Port number for the cache service.
|
||||
- agentic_lock_expire: Automatic lock expiration time (in seconds).
|
||||
- agentic_lock_timeout: Maximum time (in seconds) to wait for the lock release.
|
||||
- usage_logging: Enable/disable usage logging for API endpoints and MCP tools.
|
||||
- usage_logging_ttl: Time-to-live for usage logs in seconds (default: 7 days).
|
||||
"""
|
||||
|
||||
cache_backend: Literal["redis", "fs"] = "fs"
|
||||
|
|
@ -24,6 +26,8 @@ class CacheConfig(BaseSettings):
|
|||
cache_password: Optional[str] = None
|
||||
agentic_lock_expire: int = 240
|
||||
agentic_lock_timeout: int = 300
|
||||
usage_logging: bool = False
|
||||
usage_logging_ttl: int = 604800
|
||||
|
||||
model_config = SettingsConfigDict(env_file=".env", extra="allow")
|
||||
|
||||
|
|
@ -38,6 +42,8 @@ class CacheConfig(BaseSettings):
|
|||
"cache_password": self.cache_password,
|
||||
"agentic_lock_expire": self.agentic_lock_expire,
|
||||
"agentic_lock_timeout": self.agentic_lock_timeout,
|
||||
"usage_logging": self.usage_logging,
|
||||
"usage_logging_ttl": self.usage_logging_ttl,
|
||||
}
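Side note: a minimal sketch of how these new settings can be exercised locally (assumes a fresh process, since the config object is cached; the USAGE_LOGGING_TTL variable name is inferred from the field name above):

import os
from cognee.infrastructure.databases.cache.config import get_cache_config

# Hypothetical local setup: turn on usage logging against the Redis cache backend.
os.environ["CACHE_BACKEND"] = "redis"
os.environ["USAGE_LOGGING"] = "true"
os.environ["USAGE_LOGGING_TTL"] = "86400"  # keep logs for one day instead of the 7-day default

config = get_cache_config()
print(config.usage_logging, config.usage_logging_ttl)  # True 86400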
@ -89,6 +89,27 @@ class FSCacheAdapter(CacheDBInterface):
|
|||
return None
|
||||
return json.loads(value)
|
||||
|
||||
async def log_usage(
|
||||
self,
|
||||
user_id: str,
|
||||
log_entry: dict,
|
||||
ttl: int | None = 604800,
|
||||
):
|
||||
"""
|
||||
Usage logging is not supported in filesystem cache backend.
|
||||
This method is a no-op to satisfy the interface.
|
||||
"""
|
||||
logger.warning("Usage logging not supported in FSCacheAdapter, skipping")
|
||||
pass
|
||||
|
||||
async def get_usage_logs(self, user_id: str, limit: int = 100):
|
||||
"""
|
||||
Usage logging is not supported in filesystem cache backend.
|
||||
This method returns an empty list to satisfy the interface.
|
||||
"""
|
||||
logger.warning("Usage logging not supported in FSCacheAdapter, returning empty list")
|
||||
return []
|
||||
|
||||
async def close(self):
|
||||
if self.cache is not None:
|
||||
self.cache.expire()
|
||||
|
|
|
|||
|
|
@ -1,7 +1,6 @@
|
|||
"""Factory to get the appropriate cache coordination engine (e.g., Redis)."""
|
||||
|
||||
from functools import lru_cache
|
||||
import os
|
||||
from typing import Optional
|
||||
from cognee.infrastructure.databases.cache.config import get_cache_config
|
||||
from cognee.infrastructure.databases.cache.cache_db_interface import CacheDBInterface
|
||||
|
|
@ -17,6 +16,7 @@ def create_cache_engine(
|
|||
cache_username: str,
|
||||
cache_password: str,
|
||||
lock_key: str,
|
||||
log_key: str,
|
||||
agentic_lock_expire: int = 240,
|
||||
agentic_lock_timeout: int = 300,
|
||||
):
|
||||
|
|
@ -30,6 +30,7 @@ def create_cache_engine(
|
|||
- cache_username: Username to authenticate with.
|
||||
- cache_password: Password to authenticate with.
|
||||
- lock_key: Identifier used for the locking resource.
|
||||
- log_key: Identifier used for usage logging.
|
||||
- agentic_lock_expire: Duration to hold the lock after acquisition.
|
||||
- agentic_lock_timeout: Max time to wait for the lock before failing.
|
||||
|
||||
|
|
@ -37,7 +38,7 @@ def create_cache_engine(
|
|||
--------
|
||||
- CacheDBInterface: An instance of the appropriate cache adapter.
|
||||
"""
|
||||
if config.caching:
|
||||
if config.caching or config.usage_logging:
|
||||
from cognee.infrastructure.databases.cache.redis.RedisAdapter import RedisAdapter
|
||||
|
||||
if config.cache_backend == "redis":
|
||||
|
|
@ -47,6 +48,7 @@ def create_cache_engine(
|
|||
username=cache_username,
|
||||
password=cache_password,
|
||||
lock_name=lock_key,
|
||||
log_key=log_key,
|
||||
timeout=agentic_lock_expire,
|
||||
blocking_timeout=agentic_lock_timeout,
|
||||
)
|
||||
|
|
@ -61,7 +63,10 @@ def create_cache_engine(
|
|||
return None
|
||||
|
||||
|
||||
def get_cache_engine(lock_key: Optional[str] = None) -> CacheDBInterface:
|
||||
def get_cache_engine(
|
||||
lock_key: Optional[str] = "default_lock",
|
||||
log_key: Optional[str] = "usage_logs",
|
||||
) -> Optional[CacheDBInterface]:
|
||||
"""
|
||||
Returns a cache adapter instance using current context configuration.
|
||||
"""
|
||||
|
|
@ -72,6 +77,7 @@ def get_cache_engine(lock_key: Optional[str] = None) -> CacheDBInterface:
|
|||
cache_username=config.cache_username,
|
||||
cache_password=config.cache_password,
|
||||
lock_key=lock_key,
|
||||
log_key=log_key,
|
||||
agentic_lock_expire=config.agentic_lock_expire,
|
||||
agentic_lock_timeout=config.agentic_lock_timeout,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -17,13 +17,14 @@ class RedisAdapter(CacheDBInterface):
|
|||
host,
|
||||
port,
|
||||
lock_name="default_lock",
|
||||
log_key="usage_logs",
|
||||
username=None,
|
||||
password=None,
|
||||
timeout=240,
|
||||
blocking_timeout=300,
|
||||
connection_timeout=30,
|
||||
):
|
||||
super().__init__(host, port, lock_name)
|
||||
super().__init__(host, port, lock_name, log_key)
|
||||
|
||||
self.host = host
|
||||
self.port = port
|
||||
|
|
@ -177,6 +178,64 @@ class RedisAdapter(CacheDBInterface):
|
|||
entries = await self.async_redis.lrange(session_key, 0, -1)
|
||||
return [json.loads(e) for e in entries]
|
||||
|
||||
async def log_usage(
|
||||
self,
|
||||
user_id: str,
|
||||
log_entry: dict,
|
||||
ttl: int | None = 604800,
|
||||
):
|
||||
"""
|
||||
Log usage information (API endpoint calls, MCP tool invocations) to Redis.
|
||||
|
||||
Args:
|
||||
user_id: The user ID.
|
||||
log_entry: Dictionary containing usage log information.
|
||||
ttl: Optional time-to-live (seconds). If provided, the log list expires after this time.
|
||||
|
||||
Raises:
|
||||
CacheConnectionError: If Redis connection fails or times out.
|
||||
"""
|
||||
try:
|
||||
usage_logs_key = f"{self.log_key}:{user_id}"
|
||||
|
||||
await self.async_redis.rpush(usage_logs_key, json.dumps(log_entry))
|
||||
|
||||
if ttl is not None:
|
||||
await self.async_redis.expire(usage_logs_key, ttl)
|
||||
|
||||
except (redis.ConnectionError, redis.TimeoutError) as e:
|
||||
error_msg = f"Redis connection error while logging usage: {str(e)}"
|
||||
logger.error(error_msg)
|
||||
raise CacheConnectionError(error_msg) from e
|
||||
except Exception as e:
|
||||
error_msg = f"Unexpected error while logging usage to Redis: {str(e)}"
|
||||
logger.error(error_msg)
|
||||
raise CacheConnectionError(error_msg) from e
|
||||
|
||||
async def get_usage_logs(self, user_id: str, limit: int = 100):
|
||||
"""
|
||||
Retrieve usage logs for a given user.
|
||||
|
||||
Args:
|
||||
user_id: The user ID.
|
||||
limit: Maximum number of logs to retrieve (default: 100).
|
||||
|
||||
Returns:
|
||||
List of usage log entries, most recent first.
|
||||
"""
|
||||
try:
|
||||
usage_logs_key = f"{self.log_key}:{user_id}"
|
||||
entries = await self.async_redis.lrange(usage_logs_key, -limit, -1)
|
||||
return [json.loads(e) for e in reversed(entries)] if entries else []
|
||||
except (redis.ConnectionError, redis.TimeoutError) as e:
|
||||
error_msg = f"Redis connection error while retrieving usage logs: {str(e)}"
|
||||
logger.error(error_msg)
|
||||
raise CacheConnectionError(error_msg) from e
|
||||
except Exception as e:
|
||||
error_msg = f"Unexpected error while retrieving usage logs from Redis: {str(e)}"
|
||||
logger.error(error_msg)
|
||||
raise CacheConnectionError(error_msg) from e
|
||||
|
||||
async def close(self):
|
||||
"""
|
||||
Gracefully close the async Redis connection.
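Rough sketch of the intended round trip through these new adapter methods (assumes a reachable Redis instance and usage logging enabled; the user id is hypothetical and not part of this diff):

import asyncio
from cognee.infrastructure.databases.cache.get_cache_engine import get_cache_engine

async def demo():
    # Returns a RedisAdapter when the redis backend is configured,
    # or None when neither caching nor usage logging is enabled.
    cache = get_cache_engine()
    if cache is None:
        return
    # Append one usage entry for the user, then read the most recent entries back.
    await cache.log_usage(user_id="user-123", log_entry={"function_name": "demo", "success": True})
    logs = await cache.get_usage_logs("user-123", limit=10)
    print(logs[0]["function_name"])
    await cache.close()

asyncio.run(demo())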
@ -24,7 +24,6 @@ async def get_graph_engine() -> GraphDBInterface:
|
|||
return graph_client
|
||||
|
||||
|
||||
@lru_cache
|
||||
def create_graph_engine(
|
||||
graph_database_provider,
|
||||
graph_file_path,
|
||||
|
|
@ -35,6 +34,35 @@ def create_graph_engine(
|
|||
graph_database_port="",
|
||||
graph_database_key="",
|
||||
graph_dataset_database_handler="",
|
||||
):
|
||||
"""
|
||||
Wrapper function to call create graph engine with caching.
|
||||
For a detailed description, see _create_graph_engine.
|
||||
"""
|
||||
return _create_graph_engine(
|
||||
graph_database_provider,
|
||||
graph_file_path,
|
||||
graph_database_url,
|
||||
graph_database_name,
|
||||
graph_database_username,
|
||||
graph_database_password,
|
||||
graph_database_port,
|
||||
graph_database_key,
|
||||
graph_dataset_database_handler,
|
||||
)
|
||||
|
||||
|
||||
@lru_cache
|
||||
def _create_graph_engine(
|
||||
graph_database_provider,
|
||||
graph_file_path,
|
||||
graph_database_url="",
|
||||
graph_database_name="",
|
||||
graph_database_username="",
|
||||
graph_database_password="",
|
||||
graph_database_port="",
|
||||
graph_database_key="",
|
||||
graph_dataset_database_handler="",
|
||||
):
|
||||
"""
|
||||
Create a graph engine based on the specified provider type.
|
||||
|
|
|
|||
|
|
@ -236,6 +236,7 @@ class NeptuneAnalyticsAdapter(NeptuneGraphDB, VectorDBInterface):
|
|||
query_vector: Optional[List[float]] = None,
|
||||
limit: Optional[int] = None,
|
||||
with_vector: bool = False,
|
||||
include_payload: bool = False, # TODO: Add support for this parameter
|
||||
):
|
||||
"""
|
||||
Perform a search in the specified collection using either a text query or a vector
|
||||
|
|
@ -319,7 +320,12 @@ class NeptuneAnalyticsAdapter(NeptuneGraphDB, VectorDBInterface):
|
|||
self._na_exception_handler(e, query_string)
|
||||
|
||||
async def batch_search(
|
||||
self, collection_name: str, query_texts: List[str], limit: int, with_vectors: bool = False
|
||||
self,
|
||||
collection_name: str,
|
||||
query_texts: List[str],
|
||||
limit: int,
|
||||
with_vectors: bool = False,
|
||||
include_payload: bool = False,
|
||||
):
|
||||
"""
|
||||
Perform a batch search using multiple text queries against a collection.
|
||||
|
|
@ -342,7 +348,14 @@ class NeptuneAnalyticsAdapter(NeptuneGraphDB, VectorDBInterface):
|
|||
data_vectors = await self.embedding_engine.embed_text(query_texts)
|
||||
return await asyncio.gather(
|
||||
*[
|
||||
self.search(collection_name, None, vector, limit, with_vectors)
|
||||
self.search(
|
||||
collection_name,
|
||||
None,
|
||||
vector,
|
||||
limit,
|
||||
with_vectors,
|
||||
include_payload=include_payload,
|
||||
)
|
||||
for vector in data_vectors
|
||||
]
|
||||
)
|
||||
|
|
|
|||
|
|
@ -355,6 +355,7 @@ class ChromaDBAdapter(VectorDBInterface):
|
|||
limit: Optional[int] = 15,
|
||||
with_vector: bool = False,
|
||||
normalized: bool = True,
|
||||
include_payload: bool = False, # TODO: Add support for this parameter when set to False
|
||||
):
|
||||
"""
|
||||
Search for items in a collection using either a text or a vector query.
|
||||
|
|
@ -441,6 +442,7 @@ class ChromaDBAdapter(VectorDBInterface):
|
|||
query_texts: List[str],
|
||||
limit: int = 5,
|
||||
with_vectors: bool = False,
|
||||
include_payload: bool = False,
|
||||
):
|
||||
"""
|
||||
Perform multiple searches in a single request for efficiency, returning results for each
|
||||
|
|
|
|||
|
|
@ -7,7 +7,6 @@ from cognee.infrastructure.databases.graph.config import get_graph_context_confi
|
|||
from functools import lru_cache
|
||||
|
||||
|
||||
@lru_cache
|
||||
def create_vector_engine(
|
||||
vector_db_provider: str,
|
||||
vector_db_url: str,
|
||||
|
|
@ -15,6 +14,29 @@ def create_vector_engine(
|
|||
vector_db_port: str = "",
|
||||
vector_db_key: str = "",
|
||||
vector_dataset_database_handler: str = "",
|
||||
):
|
||||
"""
|
||||
Wrapper function to call create vector engine with caching.
|
||||
For a detailed description, see _create_vector_engine.
|
||||
"""
|
||||
return _create_vector_engine(
|
||||
vector_db_provider,
|
||||
vector_db_url,
|
||||
vector_db_name,
|
||||
vector_db_port,
|
||||
vector_db_key,
|
||||
vector_dataset_database_handler,
|
||||
)
|
||||
|
||||
|
||||
@lru_cache
|
||||
def _create_vector_engine(
|
||||
vector_db_provider: str,
|
||||
vector_db_url: str,
|
||||
vector_db_name: str,
|
||||
vector_db_port: str = "",
|
||||
vector_db_key: str = "",
|
||||
vector_dataset_database_handler: str = "",
|
||||
):
|
||||
"""
|
||||
Create a vector database engine based on the specified provider.
|
||||
|
|
|
|||
|
|
@ -231,6 +231,7 @@ class LanceDBAdapter(VectorDBInterface):
|
|||
limit: Optional[int] = 15,
|
||||
with_vector: bool = False,
|
||||
normalized: bool = True,
|
||||
include_payload: bool = False,
|
||||
):
|
||||
if query_text is None and query_vector is None:
|
||||
raise MissingQueryParameterError()
|
||||
|
|
@ -247,17 +248,27 @@ class LanceDBAdapter(VectorDBInterface):
|
|||
if limit <= 0:
|
||||
return []
|
||||
|
||||
result_values = await collection.vector_search(query_vector).limit(limit).to_list()
|
||||
# Note: Exclude payload if not needed to optimize performance
|
||||
select_columns = (
|
||||
["id", "vector", "payload", "_distance"]
|
||||
if include_payload
|
||||
else ["id", "vector", "_distance"]
|
||||
)
|
||||
result_values = (
|
||||
await collection.vector_search(query_vector)
|
||||
.select(select_columns)
|
||||
.limit(limit)
|
||||
.to_list()
|
||||
)
|
||||
|
||||
if not result_values:
|
||||
return []
|
||||
|
||||
normalized_values = normalize_distances(result_values)
|
||||
|
||||
return [
|
||||
ScoredResult(
|
||||
id=parse_id(result["id"]),
|
||||
payload=result["payload"],
|
||||
payload=result["payload"] if include_payload else None,
|
||||
score=normalized_values[value_index],
|
||||
)
|
||||
for value_index, result in enumerate(result_values)
|
||||
|
|
@ -269,6 +280,7 @@ class LanceDBAdapter(VectorDBInterface):
|
|||
query_texts: List[str],
|
||||
limit: Optional[int] = None,
|
||||
with_vectors: bool = False,
|
||||
include_payload: bool = False,
|
||||
):
|
||||
query_vectors = await self.embedding_engine.embed_text(query_texts)
|
||||
|
||||
|
|
@ -279,6 +291,7 @@ class LanceDBAdapter(VectorDBInterface):
|
|||
query_vector=query_vector,
|
||||
limit=limit,
|
||||
with_vector=with_vectors,
|
||||
include_payload=include_payload,
|
||||
)
|
||||
for query_vector in query_vectors
|
||||
]
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
from typing import Any, Dict
|
||||
from typing import Any, Dict, Optional
|
||||
from uuid import UUID
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
|
@ -12,10 +12,10 @@ class ScoredResult(BaseModel):
|
|||
- id (UUID): Unique identifier for the scored result.
|
||||
- score (float): The score associated with the result, where a lower score indicates a
|
||||
better outcome.
|
||||
- payload (Dict[str, Any]): Additional information related to the score, stored as
|
||||
- payload (Optional[Dict[str, Any]]): Additional information related to the score, stored as
|
||||
key-value pairs in a dictionary.
|
||||
"""
|
||||
|
||||
id: UUID
|
||||
score: float # Lower score is better
|
||||
payload: Dict[str, Any]
|
||||
payload: Optional[Dict[str, Any]] = None
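Since payload is optional now, downstream code that still reads it should guard for None; a small sketch of the pattern (hypothetical helper, not part of this diff):

def extract_texts(results) -> list[str]:
    # results: a list of ScoredResult; entries produced with include_payload=False carry payload=None.
    return [result.payload["text"] for result in results if result.payload and "text" in result.payload]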
@ -301,6 +301,7 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
|
|||
query_vector: Optional[List[float]] = None,
|
||||
limit: Optional[int] = 15,
|
||||
with_vector: bool = False,
|
||||
include_payload: bool = False,
|
||||
) -> List[ScoredResult]:
|
||||
if query_text is None and query_vector is None:
|
||||
raise MissingQueryParameterError()
|
||||
|
|
@ -324,10 +325,16 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
|
|||
# NOTE: This needs to be initialized in case search doesn't return a value
|
||||
closest_items = []
|
||||
|
||||
# Note: Exclude payload from returned columns if not needed to optimize performance
|
||||
select_columns = (
|
||||
[PGVectorDataPoint]
|
||||
if include_payload
|
||||
else [PGVectorDataPoint.c.id, PGVectorDataPoint.c.vector]
|
||||
)
|
||||
# Use async session to connect to the database
|
||||
async with self.get_async_session() as session:
|
||||
query = select(
|
||||
PGVectorDataPoint,
|
||||
*select_columns,
|
||||
PGVectorDataPoint.c.vector.cosine_distance(query_vector).label("similarity"),
|
||||
).order_by("similarity")
|
||||
|
||||
|
|
@ -344,7 +351,7 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
|
|||
vector_list.append(
|
||||
{
|
||||
"id": parse_id(str(vector.id)),
|
||||
"payload": vector.payload,
|
||||
"payload": vector.payload if include_payload else None,
|
||||
"_distance": vector.similarity,
|
||||
}
|
||||
)
|
||||
|
|
@ -359,7 +366,11 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
|
|||
|
||||
# Create and return ScoredResult objects
|
||||
return [
|
||||
ScoredResult(id=row.get("id"), payload=row.get("payload"), score=row.get("score"))
|
||||
ScoredResult(
|
||||
id=row.get("id"),
|
||||
payload=row.get("payload") if include_payload else None,
|
||||
score=row.get("score"),
|
||||
)
|
||||
for row in vector_list
|
||||
]
|
||||
|
||||
|
|
@ -369,6 +380,7 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
|
|||
query_texts: List[str],
|
||||
limit: int = None,
|
||||
with_vectors: bool = False,
|
||||
include_payload: bool = False,
|
||||
):
|
||||
query_vectors = await self.embedding_engine.embed_text(query_texts)
|
||||
|
||||
|
|
@ -379,6 +391,7 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
|
|||
query_vector=query_vector,
|
||||
limit=limit,
|
||||
with_vector=with_vectors,
|
||||
include_payload=include_payload,
|
||||
)
|
||||
for query_vector in query_vectors
|
||||
]
|
||||
|
|
|
|||
|
|
@ -87,6 +87,7 @@ class VectorDBInterface(Protocol):
|
|||
query_vector: Optional[List[float]],
|
||||
limit: Optional[int],
|
||||
with_vector: bool = False,
|
||||
include_payload: bool = False,
|
||||
):
|
||||
"""
|
||||
Perform a search in the specified collection using either a text query or a vector
|
||||
|
|
@ -103,6 +104,9 @@ class VectorDBInterface(Protocol):
|
|||
- limit (Optional[int]): The maximum number of results to return from the search.
|
||||
- with_vector (bool): Whether to return the vector representations with search
|
||||
results. (default False)
|
||||
- include_payload (bool): Whether to include the payload data with search. Search is faster when set to False.
|
||||
Payload contains metadata about the data point, useful for searches that are only based on embedding distances
|
||||
like the RAG_COMPLETION search type, but not needed when search also contains graph data.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
|
@ -113,6 +117,7 @@ class VectorDBInterface(Protocol):
|
|||
query_texts: List[str],
|
||||
limit: Optional[int],
|
||||
with_vectors: bool = False,
|
||||
include_payload: bool = False,
|
||||
):
|
||||
"""
|
||||
Perform a batch search using multiple text queries against a collection.
|
||||
|
|
@ -125,6 +130,9 @@ class VectorDBInterface(Protocol):
|
|||
- limit (Optional[int]): The maximum number of results to return for each query.
|
||||
- with_vectors (bool): Whether to include vector representations with search
|
||||
results. (default False)
|
||||
- include_payload (bool): Whether to include the payload data with search. Search is faster when set to False.
|
||||
Payload contains metadata about the data point, useful for searches that are only based on embedding distances
|
||||
like the RAG_COMPLETION search type, but not needed when search also contains graph data.
|
||||
"""
|
||||
raise NotImplementedError
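A quick illustration of the new flag from the caller's side (a sketch only; it mirrors the retriever changes further below, assumes the usual get_vector_engine helper, and assumes an existing "DocumentChunk_text" collection):

from cognee.infrastructure.databases.vector import get_vector_engine

async def lookup(query: str):
    vector_engine = get_vector_engine()
    # Retrievers that need the chunk text ask for the payload explicitly...
    with_payload = await vector_engine.search(
        "DocumentChunk_text", query, limit=5, include_payload=True
    )
    # ...while graph-oriented callers can skip it and work from result.id alone.
    ids_only = await vector_engine.search("DocumentChunk_text", query, limit=5)
    return with_payload, ids_only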
@ -1,5 +1,6 @@
|
|||
import time
|
||||
from cognee.shared.logging_utils import get_logger
|
||||
from cognee.modules.engine.utils.generate_edge_id import generate_edge_id
|
||||
from typing import List, Dict, Union, Optional, Type, Iterable, Tuple, Callable, Any
|
||||
|
||||
from cognee.modules.graph.exceptions import (
|
||||
|
|
@ -44,6 +45,12 @@ class CogneeGraph(CogneeAbstractGraph):
|
|||
|
||||
def add_edge(self, edge: Edge) -> None:
|
||||
self.edges.append(edge)
|
||||
|
||||
edge_text = edge.attributes.get("edge_text") or edge.attributes.get("relationship_type")
|
||||
edge.attributes["edge_type_id"] = (
|
||||
generate_edge_id(edge_id=edge_text) if edge_text else None
|
||||
) # Update edge with generated edge_type_id
|
||||
|
||||
edge.node1.add_skeleton_edge(edge)
|
||||
edge.node2.add_skeleton_edge(edge)
|
||||
key = edge.get_distance_key()
|
||||
|
|
@ -284,13 +291,7 @@ class CogneeGraph(CogneeAbstractGraph):
|
|||
|
||||
for query_index, scored_results in enumerate(per_query_scored_results):
|
||||
for result in scored_results:
|
||||
payload = getattr(result, "payload", None)
|
||||
if not isinstance(payload, dict):
|
||||
continue
|
||||
text = payload.get("text")
|
||||
if not text:
|
||||
continue
|
||||
matching_edges = self.edges_by_distance_key.get(str(text))
|
||||
matching_edges = self.edges_by_distance_key.get(str(result.id))
|
||||
if not matching_edges:
|
||||
continue
|
||||
for edge in matching_edges:
|
||||
|
|
|
|||
|
|
@ -141,7 +141,7 @@ class Edge:
|
|||
self.status = np.ones(dimension, dtype=int)
|
||||
|
||||
def get_distance_key(self) -> Optional[str]:
|
||||
key = self.attributes.get("edge_text") or self.attributes.get("relationship_type")
|
||||
key = self.attributes.get("edge_type_id")
|
||||
if key is None:
|
||||
return None
|
||||
return str(key)
|
||||
|
|
|
|||
|
|
@ -47,7 +47,9 @@ class ChunksRetriever(BaseRetriever):
|
|||
vector_engine = get_vector_engine()
|
||||
|
||||
try:
|
||||
found_chunks = await vector_engine.search("DocumentChunk_text", query, limit=self.top_k)
|
||||
found_chunks = await vector_engine.search(
|
||||
"DocumentChunk_text", query, limit=self.top_k, include_payload=True
|
||||
)
|
||||
logger.info(f"Found {len(found_chunks)} chunks from vector search")
|
||||
await update_node_access_timestamps(found_chunks)
|
||||
|
||||
|
|
|
|||
|
|
@ -62,7 +62,9 @@ class CompletionRetriever(BaseRetriever):
|
|||
vector_engine = get_vector_engine()
|
||||
|
||||
try:
|
||||
found_chunks = await vector_engine.search("DocumentChunk_text", query, limit=self.top_k)
|
||||
found_chunks = await vector_engine.search(
|
||||
"DocumentChunk_text", query, limit=self.top_k, include_payload=True
|
||||
)
|
||||
|
||||
if len(found_chunks) == 0:
|
||||
return ""
|
||||
|
|
|
|||
|
|
@ -52,7 +52,7 @@ class SummariesRetriever(BaseRetriever):
|
|||
|
||||
try:
|
||||
summaries_results = await vector_engine.search(
|
||||
"TextSummary_text", query, limit=self.top_k
|
||||
"TextSummary_text", query, limit=self.top_k, include_payload=True
|
||||
)
|
||||
logger.info(f"Found {len(summaries_results)} summaries from vector search")
|
||||
|
||||
|
|
|
|||
|
|
@ -98,7 +98,7 @@ class TemporalRetriever(GraphCompletionRetriever):
|
|||
|
||||
async def filter_top_k_events(self, relevant_events, scored_results):
|
||||
# Build a score lookup from vector search results
|
||||
score_lookup = {res.payload["id"]: res.score for res in scored_results}
|
||||
score_lookup = {res.id: res.score for res in scored_results}
|
||||
|
||||
events_with_scores = []
|
||||
for event in relevant_events[0]["events"]:
|
||||
|
|
|
|||
|
|
@ -67,7 +67,9 @@ class TripletRetriever(BaseRetriever):
|
|||
"In order to use TRIPLET_COMPLETION first use the create_triplet_embeddings memify pipeline. "
|
||||
)
|
||||
|
||||
found_triplets = await vector_engine.search("Triplet_text", query, limit=self.top_k)
|
||||
found_triplets = await vector_engine.search(
|
||||
"Triplet_text", query, limit=self.top_k, include_payload=True
|
||||
)
|
||||
|
||||
if len(found_triplets) == 0:
|
||||
return ""
|
||||
|
|
|
|||
cognee/run_migrations.py (new file, 48 lines)
|
|
@ -0,0 +1,48 @@
|
|||
import os
|
||||
import sys
|
||||
import subprocess
|
||||
from pathlib import Path
|
||||
import importlib.resources as pkg_resources
|
||||
|
||||
# Assuming your package is named 'cognee' and the migrations are under 'cognee/alembic'
|
||||
# This is a placeholder for the path logic.
|
||||
MIGRATIONS_PACKAGE = "cognee"
|
||||
MIGRATIONS_DIR_NAME = "alembic"
|
||||
|
||||
|
||||
async def run_migrations():
|
||||
"""
|
||||
Finds the Alembic configuration within the installed package and
|
||||
programmatically executes 'alembic upgrade head'.
|
||||
"""
|
||||
# 1. Locate the base path of the installed package.
|
||||
# This reliably finds the root directory of the installed 'cognee' package.
|
||||
# We look for the parent of the 'migrations' directory.
|
||||
package_root = str(pkg_resources.files(MIGRATIONS_PACKAGE))
|
||||
|
||||
# 2. Define the paths for config and scripts
|
||||
alembic_ini_path = os.path.join(package_root, "alembic.ini")
|
||||
script_location_path = os.path.join(package_root, MIGRATIONS_DIR_NAME)
|
||||
|
||||
if not os.path.exists(alembic_ini_path):
|
||||
raise FileNotFoundError(
|
||||
f"Error: alembic.ini not found at expected locations for package '{MIGRATIONS_PACKAGE}'."
|
||||
)
|
||||
if not os.path.exists(script_location_path):
|
||||
raise FileNotFoundError(
|
||||
f"Error: Migrations directory not found at expected locations for package '{MIGRATIONS_PACKAGE}'."
|
||||
)
|
||||
|
||||
migration_result = subprocess.run(
|
||||
["python", "-m", "alembic", "upgrade", "head"],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
cwd=Path(package_root),
|
||||
)
|
||||
|
||||
if migration_result.returncode != 0:
|
||||
migration_output = migration_result.stderr + migration_result.stdout
|
||||
print(f"Migration failed with unexpected error: {migration_output}")
|
||||
sys.exit(1)
|
||||
|
||||
print("Migration completed successfully.")
|
||||
|
|
@ -4,6 +4,4 @@ Custom exceptions for the Cognee API.
|
|||
This module defines a set of exceptions for handling various shared utility errors
|
||||
"""
|
||||
|
||||
from .exceptions import (
|
||||
IngestionError,
|
||||
)
|
||||
from .exceptions import IngestionError, UsageLoggerError
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
from cognee.exceptions import CogneeValidationError
|
||||
from cognee.exceptions import CogneeConfigurationError, CogneeValidationError
|
||||
from fastapi import status
|
||||
|
||||
|
||||
|
|
@ -10,3 +10,13 @@ class IngestionError(CogneeValidationError):
|
|||
status_code=status.HTTP_422_UNPROCESSABLE_CONTENT,
|
||||
):
|
||||
super().__init__(message, name, status_code)
|
||||
|
||||
|
||||
class UsageLoggerError(CogneeConfigurationError):
|
||||
def __init__(
|
||||
self,
|
||||
message: str = "Usage logging configuration is invalid.",
|
||||
name: str = "UsageLoggerError",
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
):
|
||||
super().__init__(message, name, status_code)
|
||||
|
|
|
|||
cognee/shared/usage_logger.py (new file, 332 lines)
|
|
@ -0,0 +1,332 @@
|
|||
import asyncio
|
||||
import inspect
|
||||
import os
|
||||
from datetime import datetime, timezone
|
||||
from functools import singledispatch, wraps
|
||||
from typing import Any, Callable, Optional
|
||||
from uuid import UUID
|
||||
|
||||
from cognee.infrastructure.databases.cache.config import get_cache_config
|
||||
from cognee.infrastructure.databases.cache.get_cache_engine import get_cache_engine
|
||||
from cognee.shared.exceptions import UsageLoggerError
|
||||
from cognee.shared.logging_utils import get_logger
|
||||
from cognee import __version__ as cognee_version
|
||||
|
||||
logger = get_logger("usage_logger")
|
||||
|
||||
|
||||
@singledispatch
|
||||
def _sanitize_value(value: Any) -> Any:
|
||||
"""Default handler for JSON serialization - converts to string."""
|
||||
try:
|
||||
str_repr = str(value)
|
||||
if str_repr.startswith("<") and str_repr.endswith(">"):
|
||||
return f"<cannot be serialized: {type(value).__name__}>"
|
||||
return str_repr
|
||||
except Exception:
|
||||
return f"<cannot be serialized: {type(value).__name__}>"
|
||||
|
||||
|
||||
@_sanitize_value.register(type(None))
|
||||
def _(value: None) -> None:
|
||||
"""Handle None values - returns None as-is."""
|
||||
return None
|
||||
|
||||
|
||||
@_sanitize_value.register(str)
|
||||
@_sanitize_value.register(int)
|
||||
@_sanitize_value.register(float)
|
||||
@_sanitize_value.register(bool)
|
||||
def _(value: str | int | float | bool) -> str | int | float | bool:
|
||||
"""Handle primitive types - returns value as-is since they're JSON-serializable."""
|
||||
return value
|
||||
|
||||
|
||||
@_sanitize_value.register(UUID)
|
||||
def _(value: UUID) -> str:
|
||||
"""Convert UUID to string representation."""
|
||||
return str(value)
|
||||
|
||||
|
||||
@_sanitize_value.register(datetime)
|
||||
def _(value: datetime) -> str:
|
||||
"""Convert datetime to ISO format string."""
|
||||
return value.isoformat()
|
||||
|
||||
|
||||
@_sanitize_value.register(list)
|
||||
@_sanitize_value.register(tuple)
|
||||
def _(value: list | tuple) -> list:
|
||||
"""Recursively sanitize list or tuple elements."""
|
||||
return [_sanitize_value(v) for v in value]
|
||||
|
||||
|
||||
@_sanitize_value.register(dict)
|
||||
def _(value: dict) -> dict:
|
||||
"""Recursively sanitize dictionary keys and values."""
|
||||
sanitized = {}
|
||||
for k, v in value.items():
|
||||
key_str = k if isinstance(k, str) else _sanitize_dict_key(k)
|
||||
sanitized[key_str] = _sanitize_value(v)
|
||||
return sanitized
|
||||
|
||||
|
||||
def _sanitize_dict_key(key: Any) -> str:
|
||||
"""Convert a non-string dict key to a string."""
|
||||
sanitized_key = _sanitize_value(key)
|
||||
if isinstance(sanitized_key, str):
|
||||
if sanitized_key.startswith("<cannot be serialized"):
|
||||
return f"<key:{type(key).__name__}>"
|
||||
return sanitized_key
|
||||
return str(sanitized_key)
|
||||
|
||||
|
||||
def _get_param_names(func: Callable) -> list[str]:
|
||||
"""Get parameter names from function signature."""
|
||||
try:
|
||||
return list(inspect.signature(func).parameters.keys())
|
||||
except Exception:
|
||||
return []
|
||||
|
||||
|
||||
def _get_param_defaults(func: Callable) -> dict[str, Any]:
|
||||
"""Get parameter defaults from function signature."""
|
||||
try:
|
||||
sig = inspect.signature(func)
|
||||
defaults = {}
|
||||
for param_name, param in sig.parameters.items():
|
||||
if param.default != inspect.Parameter.empty:
|
||||
defaults[param_name] = param.default
|
||||
return defaults
|
||||
except Exception:
|
||||
return {}
|
||||
|
||||
|
||||
def _extract_user_id(args: tuple, kwargs: dict, param_names: list[str]) -> Optional[str]:
|
||||
"""Extract user_id from function arguments if available."""
|
||||
try:
|
||||
if "user" in kwargs and kwargs["user"] is not None:
|
||||
user = kwargs["user"]
|
||||
if hasattr(user, "id"):
|
||||
return str(user.id)
|
||||
|
||||
for i, param_name in enumerate(param_names):
|
||||
if i < len(args) and param_name == "user":
|
||||
user = args[i]
|
||||
if user is not None and hasattr(user, "id"):
|
||||
return str(user.id)
|
||||
return None
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
def _extract_parameters(args: tuple, kwargs: dict, param_names: list[str], func: Callable) -> dict:
|
||||
"""Extract function parameters - captures all parameters including defaults, sanitizes for JSON."""
|
||||
params = {}
|
||||
|
||||
for key, value in kwargs.items():
|
||||
if key != "user":
|
||||
params[key] = _sanitize_value(value)
|
||||
|
||||
if param_names:
|
||||
for i, param_name in enumerate(param_names):
|
||||
if i < len(args) and param_name != "user" and param_name not in kwargs:
|
||||
params[param_name] = _sanitize_value(args[i])
|
||||
else:
|
||||
for i, arg_value in enumerate(args):
|
||||
params[f"arg_{i}"] = _sanitize_value(arg_value)
|
||||
|
||||
if param_names:
|
||||
defaults = _get_param_defaults(func)
|
||||
for param_name in param_names:
|
||||
if param_name != "user" and param_name not in params and param_name in defaults:
|
||||
params[param_name] = _sanitize_value(defaults[param_name])
|
||||
|
||||
return params
|
||||
|
||||
|
||||
async def _log_usage_async(
|
||||
function_name: str,
|
||||
log_type: str,
|
||||
user_id: Optional[str],
|
||||
parameters: dict,
|
||||
result: Any,
|
||||
success: bool,
|
||||
error: Optional[str],
|
||||
duration_ms: float,
|
||||
start_time: datetime,
|
||||
end_time: datetime,
|
||||
):
|
||||
"""Asynchronously log function usage to Redis.
|
||||
|
||||
Args:
|
||||
function_name: Name of the function being logged.
|
||||
log_type: Type of log entry (e.g., "api_endpoint", "mcp_tool", "function").
|
||||
user_id: User identifier, or None to use "unknown".
|
||||
parameters: Dictionary of function parameters (sanitized).
|
||||
result: Function return value (will be sanitized).
|
||||
success: Whether the function executed successfully.
|
||||
error: Error message if function failed, None otherwise.
|
||||
duration_ms: Execution duration in milliseconds.
|
||||
start_time: Function start timestamp.
|
||||
end_time: Function end timestamp.
|
||||
|
||||
Note:
|
||||
This function silently handles errors to avoid disrupting the original
|
||||
function execution. Logs are written to Redis with TTL from config.
|
||||
"""
|
||||
try:
|
||||
logger.debug(f"Starting to log usage for {function_name} at {start_time.isoformat()}")
|
||||
config = get_cache_config()
|
||||
if not config.usage_logging:
|
||||
logger.debug("Usage logging disabled, skipping log")
|
||||
return
|
||||
|
||||
logger.debug(f"Getting cache engine for {function_name}")
|
||||
cache_engine = get_cache_engine()
|
||||
if cache_engine is None:
|
||||
logger.warning(
|
||||
f"Cache engine not available for usage logging (function: {function_name})"
|
||||
)
|
||||
return
|
||||
|
||||
logger.debug(f"Cache engine obtained for {function_name}")
|
||||
|
||||
if user_id is None:
|
||||
user_id = "unknown"
|
||||
logger.debug(f"No user_id provided, using 'unknown' for {function_name}")
|
||||
|
||||
log_entry = {
|
||||
"timestamp": start_time.isoformat(),
|
||||
"type": log_type,
|
||||
"function_name": function_name,
|
||||
"user_id": user_id,
|
||||
"parameters": parameters,
|
||||
"result": _sanitize_value(result),
|
||||
"success": success,
|
||||
"error": error,
|
||||
"duration_ms": round(duration_ms, 2),
|
||||
"start_time": start_time.isoformat(),
|
||||
"end_time": end_time.isoformat(),
|
||||
"metadata": {
|
||||
"cognee_version": cognee_version,
|
||||
"environment": os.getenv("ENV", "prod"),
|
||||
},
|
||||
}
|
||||
|
||||
logger.debug(f"Calling log_usage for {function_name}, user_id={user_id}")
|
||||
await cache_engine.log_usage(
|
||||
user_id=user_id,
|
||||
log_entry=log_entry,
|
||||
ttl=config.usage_logging_ttl,
|
||||
)
|
||||
logger.info(f"Successfully logged usage for {function_name} (user_id={user_id})")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to log usage for {function_name}: {str(e)}", exc_info=True)
|
||||
|
||||
|
||||
def log_usage(function_name: Optional[str] = None, log_type: str = "function"):
|
||||
"""
|
||||
Decorator to log function usage to Redis.
|
||||
|
||||
This decorator is completely transparent - it doesn't change function behavior.
|
||||
It logs function name, parameters, result, timing, and user (if available).
|
||||
|
||||
Args:
|
||||
function_name: Optional name for the function (defaults to func.__name__)
|
||||
log_type: Type of log entry (e.g., "api_endpoint", "mcp_tool")
|
||||
|
||||
Usage:
|
||||
@log_usage(function_name="MCP my_mcp_tool", log_type="mcp_tool")
|
||||
async def my_mcp_tool(...):
|
||||
# mcp code
|
||||
|
||||
@log_usage(function_name="POST API /v1/add", log_type="api_endpoint")
|
||||
async def add(...):
|
||||
# endpoint code
|
||||
"""
|
||||
|
||||
def decorator(func: Callable) -> Callable:
|
||||
"""Inner decorator that wraps the function with usage logging.
|
||||
|
||||
Args:
|
||||
func: The async function to wrap with usage logging.
|
||||
|
||||
Returns:
|
||||
Callable: The wrapped function with usage logging enabled.
|
||||
|
||||
Raises:
|
||||
UsageLoggerError: If the function is not async.
|
||||
"""
|
||||
if not inspect.iscoroutinefunction(func):
|
||||
raise UsageLoggerError(
|
||||
f"@log_usage requires an async function. Got {func.__name__} which is not async."
|
||||
)
|
||||
|
||||
@wraps(func)
|
||||
async def async_wrapper(*args, **kwargs):
|
||||
"""Wrapper function that executes the original function and logs usage.
|
||||
|
||||
This wrapper:
|
||||
- Extracts user ID and parameters from function arguments
|
||||
- Executes the original function
|
||||
- Captures result, success status, and any errors
|
||||
- Logs usage information asynchronously without blocking
|
||||
|
||||
Args:
|
||||
*args: Positional arguments passed to the original function.
|
||||
**kwargs: Keyword arguments passed to the original function.
|
||||
|
||||
Returns:
|
||||
Any: The return value of the original function.
|
||||
|
||||
Raises:
|
||||
Any exception raised by the original function (re-raised after logging).
|
||||
"""
|
||||
config = get_cache_config()
|
||||
if not config.usage_logging:
|
||||
return await func(*args, **kwargs)
|
||||
|
||||
start_time = datetime.now(timezone.utc)
|
||||
|
||||
param_names = _get_param_names(func)
|
||||
user_id = _extract_user_id(args, kwargs, param_names)
|
||||
parameters = _extract_parameters(args, kwargs, param_names, func)
|
||||
|
||||
result = None
|
||||
success = True
|
||||
error = None
|
||||
|
||||
try:
|
||||
result = await func(*args, **kwargs)
|
||||
return result
|
||||
except Exception as e:
|
||||
success = False
|
||||
error = str(e)
|
||||
raise
|
||||
finally:
|
||||
end_time = datetime.now(timezone.utc)
|
||||
duration_ms = (end_time - start_time).total_seconds() * 1000
|
||||
|
||||
try:
|
||||
await _log_usage_async(
|
||||
function_name=function_name or func.__name__,
|
||||
log_type=log_type,
|
||||
user_id=user_id,
|
||||
parameters=parameters,
|
||||
result=result,
|
||||
success=success,
|
||||
error=error,
|
||||
duration_ms=duration_ms,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to log usage for {function_name or func.__name__}: {str(e)}",
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
return async_wrapper
|
||||
|
||||
return decorator
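For orientation, a tiny illustration of what the sanitizer above produces (values chosen arbitrarily; _sanitize_value is module-private, so this is only for local experimentation):

from uuid import uuid4
from datetime import datetime, timezone
from cognee.shared.usage_logger import _sanitize_value

entry = {"dataset_id": uuid4(), "created": datetime.now(timezone.utc), "tags": ("a", "b"), 42: "answer"}
print(_sanitize_value(entry))
# UUID and datetime become strings, the tuple becomes a list, and the int key becomes "42".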
cognee/tests/integration/shared/test_usage_logger_integration.py (new file, 255 lines)
|
|
@ -0,0 +1,255 @@
|
|||
"""Integration tests for usage logger with real Redis components."""
|
||||
|
||||
import os
|
||||
import pytest
|
||||
import asyncio
|
||||
from datetime import datetime, timezone
|
||||
from types import SimpleNamespace
|
||||
from uuid import UUID
|
||||
from unittest.mock import patch
|
||||
|
||||
from cognee.shared.usage_logger import log_usage
|
||||
from cognee.infrastructure.databases.cache.config import get_cache_config
|
||||
from cognee.infrastructure.databases.cache.get_cache_engine import (
|
||||
get_cache_engine,
|
||||
create_cache_engine,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def usage_logging_config():
|
||||
"""Fixture to enable usage logging via environment variables."""
|
||||
original_env = os.environ.copy()
|
||||
os.environ["USAGE_LOGGING"] = "true"
|
||||
os.environ["CACHE_BACKEND"] = "redis"
|
||||
os.environ["CACHE_HOST"] = "localhost"
|
||||
os.environ["CACHE_PORT"] = "6379"
|
||||
get_cache_config.cache_clear()
|
||||
create_cache_engine.cache_clear()
|
||||
yield
|
||||
os.environ.clear()
|
||||
os.environ.update(original_env)
|
||||
get_cache_config.cache_clear()
|
||||
create_cache_engine.cache_clear()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def usage_logging_disabled():
|
||||
"""Fixture to disable usage logging via environment variables."""
|
||||
original_env = os.environ.copy()
|
||||
os.environ["USAGE_LOGGING"] = "false"
|
||||
os.environ["CACHE_BACKEND"] = "redis"
|
||||
get_cache_config.cache_clear()
|
||||
create_cache_engine.cache_clear()
|
||||
yield
|
||||
os.environ.clear()
|
||||
os.environ.update(original_env)
|
||||
get_cache_config.cache_clear()
|
||||
create_cache_engine.cache_clear()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def redis_adapter():
|
||||
"""Real RedisAdapter instance for testing."""
|
||||
from cognee.infrastructure.databases.cache.redis.RedisAdapter import RedisAdapter
|
||||
|
||||
try:
|
||||
yield RedisAdapter(host="localhost", port=6379, log_key="test_usage_logs")
|
||||
except Exception as e:
|
||||
pytest.skip(f"Redis not available: {e}")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def test_user():
|
||||
"""Test user object."""
|
||||
return SimpleNamespace(id="test-user-123")
|
||||
|
||||
|
||||
class TestDecoratorBehavior:
|
||||
"""Test decorator behavior with real components."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_decorator_configuration(
|
||||
self, usage_logging_disabled, usage_logging_config, redis_adapter
|
||||
):
|
||||
"""Test decorator skips when disabled and logs when enabled."""
|
||||
# Test disabled
|
||||
call_count = 0
|
||||
|
||||
@log_usage(function_name="test_func", log_type="test")
|
||||
async def test_func():
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
return "result"
|
||||
|
||||
assert await test_func() == "result"
|
||||
assert call_count == 1
|
||||
|
||||
# Test enabled with cache engine None
|
||||
with patch("cognee.shared.usage_logger.get_cache_engine") as mock_get:
|
||||
mock_get.return_value = None
|
||||
assert await test_func() == "result"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_decorator_logging(self, usage_logging_config, redis_adapter, test_user):
|
||||
"""Test decorator logs to Redis with correct structure."""
|
||||
|
||||
@log_usage(function_name="test_func", log_type="test")
|
||||
async def test_func(param1: str, param2: int = 42, user=None):
|
||||
await asyncio.sleep(0.01)
|
||||
return {"result": f"{param1}_{param2}"}
|
||||
|
||||
with patch("cognee.shared.usage_logger.get_cache_engine") as mock_get:
|
||||
mock_get.return_value = redis_adapter
|
||||
|
||||
result = await test_func("value1", user=test_user)
|
||||
assert result == {"result": "value1_42"}
|
||||
|
||||
logs = await redis_adapter.get_usage_logs("test-user-123", limit=10)
|
||||
log = logs[0]
|
||||
assert log["function_name"] == "test_func"
|
||||
assert log["type"] == "test"
|
||||
assert log["user_id"] == "test-user-123"
|
||||
assert log["parameters"]["param1"] == "value1"
|
||||
assert log["parameters"]["param2"] == 42
|
||||
assert log["success"] is True
|
||||
assert all(
|
||||
field in log
|
||||
for field in [
|
||||
"timestamp",
|
||||
"result",
|
||||
"error",
|
||||
"duration_ms",
|
||||
"start_time",
|
||||
"end_time",
|
||||
"metadata",
|
||||
]
|
||||
)
|
||||
assert "cognee_version" in log["metadata"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_multiple_calls(self, usage_logging_config, redis_adapter, test_user):
|
||||
"""Test multiple consecutive calls are all logged."""
|
||||
|
||||
@log_usage(function_name="multi_test", log_type="test")
|
||||
async def multi_func(call_num: int, user=None):
|
||||
return {"call": call_num}
|
||||
|
||||
with patch("cognee.shared.usage_logger.get_cache_engine") as mock_get:
|
||||
mock_get.return_value = redis_adapter
|
||||
|
||||
for i in range(3):
|
||||
await multi_func(i, user=test_user)
|
||||
|
||||
logs = await redis_adapter.get_usage_logs("test-user-123", limit=10)
|
||||
assert len(logs) >= 3
|
||||
call_nums = {log["parameters"]["call_num"] for log in logs[:3]}
|
||||
assert call_nums == {0, 1, 2}
|
||||
|
||||
|
||||
class TestRealRedisIntegration:
|
||||
"""Test real Redis integration."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_redis_storage_retrieval_and_ttl(
|
||||
self, usage_logging_config, redis_adapter, test_user
|
||||
):
|
||||
"""Test logs are stored, retrieved with correct order/limits, and TTL is set."""
|
||||
|
||||
@log_usage(function_name="redis_test", log_type="test")
|
||||
async def redis_func(data: str, user=None):
|
||||
return {"processed": data}
|
||||
|
||||
@log_usage(function_name="order_test", log_type="test")
|
||||
async def order_func(num: int, user=None):
|
||||
return {"num": num}
|
||||
|
||||
with patch("cognee.shared.usage_logger.get_cache_engine") as mock_get:
|
||||
mock_get.return_value = redis_adapter
|
||||
|
||||
# Storage
|
||||
await redis_func("test_data", user=test_user)
|
||||
logs = await redis_adapter.get_usage_logs("test-user-123", limit=10)
|
||||
assert logs[0]["function_name"] == "redis_test"
|
||||
assert logs[0]["parameters"]["data"] == "test_data"
|
||||
|
||||
# Order (most recent first)
|
||||
for i in range(3):
|
||||
await order_func(i, user=test_user)
|
||||
await asyncio.sleep(0.01)
|
||||
|
||||
logs = await redis_adapter.get_usage_logs("test-user-123", limit=10)
|
||||
assert [log["parameters"]["num"] for log in logs[:3]] == [2, 1, 0]
|
||||
|
||||
# Limit
|
||||
assert len(await redis_adapter.get_usage_logs("test-user-123", limit=2)) == 2
|
||||
|
||||
# TTL
|
||||
ttl = await redis_adapter.async_redis.ttl("test_usage_logs:test-user-123")
|
||||
assert 0 < ttl <= 604800
|
||||
|
||||
|
||||
class TestEdgeCases:
|
||||
"""Test edge cases in integration tests."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_edge_cases(self, usage_logging_config, redis_adapter, test_user):
|
||||
"""Test no params, defaults, complex structures, exceptions, None, circular refs."""
|
||||
|
||||
@log_usage(function_name="no_params", log_type="test")
|
||||
async def no_params_func(user=None):
|
||||
return "result"
|
||||
|
||||
@log_usage(function_name="defaults_only", log_type="test")
|
||||
async def defaults_only_func(param1: str = "default1", param2: int = 42, user=None):
|
||||
return {"param1": param1, "param2": param2}
|
||||
|
||||
@log_usage(function_name="complex_test", log_type="test")
|
||||
async def complex_func(user=None):
|
||||
return {
|
||||
"nested": {
|
||||
"list": [1, 2, 3],
|
||||
"uuid": UUID("123e4567-e89b-12d3-a456-426614174000"),
|
||||
"datetime": datetime(2024, 1, 15, tzinfo=timezone.utc),
|
||||
}
|
||||
}
|
||||
|
||||
@log_usage(function_name="exception_test", log_type="test")
|
||||
async def exception_func(user=None):
|
||||
raise RuntimeError("Test exception")
|
||||
|
||||
@log_usage(function_name="none_test", log_type="test")
|
||||
async def none_func(user=None):
|
||||
return None
|
||||
|
||||
with patch("cognee.shared.usage_logger.get_cache_engine") as mock_get:
|
||||
mock_get.return_value = redis_adapter
|
||||
|
||||
# No parameters
|
||||
await no_params_func(user=test_user)
|
||||
logs = await redis_adapter.get_usage_logs("test-user-123", limit=10)
|
||||
assert logs[0]["parameters"] == {}
|
||||
|
||||
# Default parameters
|
||||
await defaults_only_func(user=test_user)
|
||||
logs = await redis_adapter.get_usage_logs("test-user-123", limit=10)
|
||||
assert logs[0]["parameters"]["param1"] == "default1"
|
||||
assert logs[0]["parameters"]["param2"] == 42
|
||||
|
||||
# Complex nested structures
|
||||
await complex_func(user=test_user)
|
||||
logs = await redis_adapter.get_usage_logs("test-user-123", limit=10)
|
||||
assert isinstance(logs[0]["result"]["nested"]["uuid"], str)
|
||||
assert isinstance(logs[0]["result"]["nested"]["datetime"], str)
|
||||
|
||||
# Exception handling
|
||||
with pytest.raises(RuntimeError):
|
||||
await exception_func(user=test_user)
|
||||
logs = await redis_adapter.get_usage_logs("test-user-123", limit=10)
|
||||
assert logs[0]["success"] is False
|
||||
assert "Test exception" in logs[0]["error"]
|
||||
|
||||
# None return value
|
||||
assert await none_func(user=test_user) is None
|
||||
logs = await redis_adapter.get_usage_logs("test-user-123", limit=10)
|
||||
assert logs[0]["result"] is None
|
||||
|
|
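Read together, the assertions above fix the shape of a single usage-log record. A minimal sketch of one entry as returned by get_usage_logs (field names come straight from the assertions; the concrete values are illustrative only, not taken from the codebase):

# Illustrative only - field names match the assertions above, values are made up.
example_log_entry = {
    "function_name": "test_func",
    "type": "test",
    "user_id": "test-user-123",
    "parameters": {"param1": "value1", "param2": 42},
    "success": True,
    "result": {"result": "value1_42"},
    "error": None,
    "duration_ms": 12.3,
    "timestamp": "2024-01-15T12:30:45+00:00",
    "start_time": "2024-01-15T12:30:45+00:00",
    "end_time": "2024-01-15T12:30:45+00:00",
    "metadata": {"cognee_version": "x.y.z"},
}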
@ -97,7 +97,7 @@ async def test_vector_engine_search_none_limit():
|
|||
query_vector = (await vector_engine.embedding_engine.embed_text([query_text]))[0]
|
||||
|
||||
result = await vector_engine.search(
|
||||
collection_name=collection_name, query_vector=query_vector, limit=None
|
||||
collection_name=collection_name, query_vector=query_vector, limit=None, include_payload=True
|
||||
)
|
||||
|
||||
# Check that we did not accidentally use any default value for limit
|
||||
|
|
|
|||
|
|
@ -70,7 +70,9 @@ async def main():
|
|||
from cognee.infrastructure.databases.vector import get_vector_engine
|
||||
|
||||
vector_engine = get_vector_engine()
|
||||
random_node = (await vector_engine.search("Entity_name", "Quantum computer"))[0]
|
||||
random_node = (
|
||||
await vector_engine.search("Entity_name", "Quantum computer", include_payload=True)
|
||||
)[0]
|
||||
random_node_name = random_node.payload["text"]
|
||||
|
||||
search_results = await cognee.search(
|
||||
|
|
|
|||
|
|
@ -149,7 +149,9 @@ async def main():
|
|||
await test_getting_of_documents(dataset_name_1)
|
||||
|
||||
vector_engine = get_vector_engine()
|
||||
random_node = (await vector_engine.search("Entity_name", "Quantum computer"))[0]
|
||||
random_node = (
|
||||
await vector_engine.search("Entity_name", "Quantum computer", include_payload=True)
|
||||
)[0]
|
||||
random_node_name = random_node.payload["text"]
|
||||
|
||||
search_results = await cognee.search(
|
||||
|
|
|
|||
|
|
@ -48,7 +48,7 @@ async def main():
|
|||
from cognee.infrastructure.databases.vector import get_vector_engine
|
||||
|
||||
vector_engine = get_vector_engine()
|
||||
random_node = (await vector_engine.search("Entity_name", "AI"))[0]
|
||||
random_node = (await vector_engine.search("Entity_name", "AI", include_payload=True))[0]
|
||||
random_node_name = random_node.payload["text"]
|
||||
|
||||
search_results = await cognee.search(
|
||||
|
|
|
|||
|
|
@ -63,7 +63,9 @@ async def main():
|
|||
from cognee.infrastructure.databases.vector import get_vector_engine
|
||||
|
||||
vector_engine = get_vector_engine()
|
||||
random_node = (await vector_engine.search("Entity_name", "Quantum computer"))[0]
|
||||
random_node = (
|
||||
await vector_engine.search("Entity_name", "Quantum computer", include_payload=True)
|
||||
)[0]
|
||||
random_node_name = random_node.payload["text"]
|
||||
|
||||
search_results = await cognee.search(
|
||||
|
|
|
|||
|
|
@ -52,7 +52,9 @@ async def main():
|
|||
await cognee.cognify([dataset_name])
|
||||
|
||||
vector_engine = get_vector_engine()
|
||||
random_node = (await vector_engine.search("Entity_name", "Quantum computer"))[0]
|
||||
random_node = (
|
||||
await vector_engine.search("Entity_name", "Quantum computer", include_payload=True)
|
||||
)[0]
|
||||
random_node_name = random_node.payload["text"]
|
||||
|
||||
search_results = await cognee.search(
|
||||
|
|
|
|||
|
|
@ -41,14 +41,14 @@ async def _reset_engines_and_prune() -> None:
|
|||
except Exception:
|
||||
pass
|
||||
|
||||
from cognee.infrastructure.databases.graph.get_graph_engine import create_graph_engine
|
||||
from cognee.infrastructure.databases.relational.create_relational_engine import (
|
||||
create_relational_engine,
|
||||
)
|
||||
from cognee.infrastructure.databases.vector.create_vector_engine import create_vector_engine
|
||||
from cognee.infrastructure.databases.vector.create_vector_engine import _create_vector_engine
|
||||
from cognee.infrastructure.databases.graph.get_graph_engine import _create_graph_engine
|
||||
|
||||
create_graph_engine.cache_clear()
|
||||
create_vector_engine.cache_clear()
|
||||
_create_graph_engine.cache_clear()
|
||||
_create_vector_engine.cache_clear()
|
||||
create_relational_engine.cache_clear()
|
||||
|
||||
await cognee.prune.prune_data()
|
||||
|
|
|
|||
|
|
@ -163,7 +163,9 @@ async def main():
|
|||
await test_getting_of_documents(dataset_name_1)
|
||||
|
||||
vector_engine = get_vector_engine()
|
||||
random_node = (await vector_engine.search("Entity_name", "Quantum computer"))[0]
|
||||
random_node = (
|
||||
await vector_engine.search("Entity_name", "Quantum computer", include_payload=True)
|
||||
)[0]
|
||||
random_node_name = random_node.payload["text"]
|
||||
|
||||
search_results = await cognee.search(
|
||||
|
|
|
|||
|
|
@ -58,7 +58,9 @@ async def main():
|
|||
from cognee.infrastructure.databases.vector import get_vector_engine
|
||||
|
||||
vector_engine = get_vector_engine()
|
||||
random_node = (await vector_engine.search("Entity_name", "Quantum computer"))[0]
|
||||
random_node = (
|
||||
await vector_engine.search("Entity_name", "Quantum computer", include_payload=True)
|
||||
)[0]
|
||||
random_node_name = random_node.payload["text"]
|
||||
|
||||
search_results = await cognee.search(
|
||||
|
|
|
|||
|
|
@ -43,7 +43,7 @@ async def main():
|
|||
from cognee.infrastructure.databases.vector import get_vector_engine
|
||||
|
||||
vector_engine = get_vector_engine()
|
||||
random_node = (await vector_engine.search("Entity_name", "AI"))[0]
|
||||
random_node = (await vector_engine.search("Entity_name", "AI", include_payload=True))[0]
|
||||
random_node_name = random_node.payload["text"]
|
||||
|
||||
search_results = await cognee.search(
|
||||
|
|
|
|||
|
|
@ -48,14 +48,14 @@ async def _reset_engines_and_prune() -> None:
|
|||
# Engine might not exist yet
|
||||
pass
|
||||
|
||||
from cognee.infrastructure.databases.graph.get_graph_engine import create_graph_engine
|
||||
from cognee.infrastructure.databases.vector.create_vector_engine import create_vector_engine
|
||||
from cognee.infrastructure.databases.graph.get_graph_engine import _create_graph_engine
|
||||
from cognee.infrastructure.databases.vector.create_vector_engine import _create_vector_engine
|
||||
from cognee.infrastructure.databases.relational.create_relational_engine import (
|
||||
create_relational_engine,
|
||||
)
|
||||
|
||||
create_graph_engine.cache_clear()
|
||||
create_vector_engine.cache_clear()
|
||||
_create_graph_engine.cache_clear()
|
||||
_create_vector_engine.cache_clear()
|
||||
create_relational_engine.cache_clear()
|
||||
|
||||
await cognee.prune.prune_data()
|
||||
|
|
|
|||
268 cognee/tests/test_usage_logger_e2e.py Normal file
|
|
@ -0,0 +1,268 @@
|
|||
import os
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
import asyncio
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
import cognee
|
||||
from cognee.api.client import app
|
||||
from cognee.modules.users.methods import get_default_user, get_authenticated_user
|
||||
|
||||
|
||||
async def _reset_engines_and_prune():
|
||||
"""Reset db engine caches and prune data/system."""
|
||||
try:
|
||||
from cognee.infrastructure.databases.vector import get_vector_engine
|
||||
|
||||
vector_engine = get_vector_engine()
|
||||
if hasattr(vector_engine, "engine") and hasattr(vector_engine.engine, "dispose"):
|
||||
await vector_engine.engine.dispose(close=True)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
await cognee.prune.prune_data()
|
||||
await cognee.prune.prune_system(metadata=True)
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def event_loop():
|
||||
"""Use a single asyncio event loop for this test module."""
|
||||
loop = asyncio.new_event_loop()
|
||||
try:
|
||||
yield loop
|
||||
finally:
|
||||
loop.close()
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def e2e_config():
|
||||
"""Configure environment for E2E tests."""
|
||||
original_env = os.environ.copy()
|
||||
os.environ["USAGE_LOGGING"] = "true"
|
||||
os.environ["CACHE_BACKEND"] = "redis"
|
||||
os.environ["CACHE_HOST"] = "localhost"
|
||||
os.environ["CACHE_PORT"] = "6379"
|
||||
yield
|
||||
os.environ.clear()
|
||||
os.environ.update(original_env)
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def authenticated_client(test_client):
|
||||
"""Override authentication to use default user."""
|
||||
|
||||
async def override_get_authenticated_user():
|
||||
return await get_default_user()
|
||||
|
||||
app.dependency_overrides[get_authenticated_user] = override_get_authenticated_user
|
||||
yield test_client
|
||||
app.dependency_overrides.pop(get_authenticated_user, None)
|
||||
|
||||
|
||||
@pytest_asyncio.fixture(scope="session")
|
||||
async def test_data_setup():
|
||||
"""Set up test data: prune first, then add file and cognify."""
|
||||
await _reset_engines_and_prune()
|
||||
|
||||
dataset_name = "test_e2e_dataset"
|
||||
test_text = "Germany is located in Europe right next to the Netherlands."
|
||||
|
||||
await cognee.add(test_text, dataset_name)
|
||||
await cognee.cognify([dataset_name])
|
||||
|
||||
yield dataset_name
|
||||
|
||||
await _reset_engines_and_prune()
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def mcp_data_setup():
|
||||
"""Set up test data for MCP tests: prune first, then add file and cognify."""
|
||||
await _reset_engines_and_prune()
|
||||
|
||||
dataset_name = "test_mcp_dataset"
|
||||
test_text = "Germany is located in Europe right next to the Netherlands."
|
||||
|
||||
await cognee.add(test_text, dataset_name)
|
||||
await cognee.cognify([dataset_name])
|
||||
|
||||
yield dataset_name
|
||||
|
||||
await _reset_engines_and_prune()
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def test_client():
|
||||
"""TestClient instance for API calls."""
|
||||
with TestClient(app) as client:
|
||||
yield client
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def cache_engine(e2e_config):
|
||||
"""Get cache engine for log verification in test's event loop."""
|
||||
from cognee.infrastructure.databases.cache.redis.RedisAdapter import RedisAdapter
|
||||
from cognee.infrastructure.databases.cache.config import get_cache_config
|
||||
|
||||
config = get_cache_config()
|
||||
if not config.usage_logging or config.cache_backend != "redis":
|
||||
pytest.skip("Redis usage logging not configured")
|
||||
|
||||
engine = RedisAdapter(
|
||||
host=config.cache_host,
|
||||
port=config.cache_port,
|
||||
username=config.cache_username,
|
||||
password=config.cache_password,
|
||||
log_key="usage_logs",
|
||||
)
|
||||
return engine
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_api_endpoint_logging(e2e_config, authenticated_client, cache_engine):
|
||||
"""Test that API endpoints succeed and log to Redis."""
|
||||
user = await get_default_user()
|
||||
dataset_name = "test_e2e_api_dataset"
|
||||
|
||||
add_response = authenticated_client.post(
|
||||
"/api/v1/add",
|
||||
data={"datasetName": dataset_name},
|
||||
files=[
|
||||
(
|
||||
"data",
|
||||
(
|
||||
"test.txt",
|
||||
b"Germany is located in Europe right next to the Netherlands.",
|
||||
"text/plain",
|
||||
),
|
||||
)
|
||||
],
|
||||
)
|
||||
assert add_response.status_code in [200, 201], f"Add endpoint failed: {add_response.text}"
|
||||
|
||||
cognify_response = authenticated_client.post(
|
||||
"/api/v1/cognify",
|
||||
json={"datasets": [dataset_name], "run_in_background": False},
|
||||
)
|
||||
assert cognify_response.status_code in [200, 201], (
|
||||
f"Cognify endpoint failed: {cognify_response.text}"
|
||||
)
|
||||
|
||||
search_response = authenticated_client.post(
|
||||
"/api/v1/search",
|
||||
json={"query": "Germany", "search_type": "GRAPH_COMPLETION", "datasets": [dataset_name]},
|
||||
)
|
||||
assert search_response.status_code == 200, f"Search endpoint failed: {search_response.text}"
|
||||
|
||||
logs = await cache_engine.get_usage_logs(str(user.id), limit=20)
|
||||
|
||||
add_logs = [log for log in logs if log.get("function_name") == "POST /v1/add"]
|
||||
assert len(add_logs) > 0
|
||||
assert add_logs[0]["type"] == "api_endpoint"
|
||||
assert add_logs[0]["user_id"] == str(user.id)
|
||||
assert add_logs[0]["success"] is True
|
||||
|
||||
cognify_logs = [log for log in logs if log.get("function_name") == "POST /v1/cognify"]
|
||||
assert len(cognify_logs) > 0
|
||||
assert cognify_logs[0]["type"] == "api_endpoint"
|
||||
assert cognify_logs[0]["user_id"] == str(user.id)
|
||||
assert cognify_logs[0]["success"] is True
|
||||
|
||||
search_logs = [log for log in logs if log.get("function_name") == "POST /v1/search"]
|
||||
assert len(search_logs) > 0
|
||||
assert search_logs[0]["type"] == "api_endpoint"
|
||||
assert search_logs[0]["user_id"] == str(user.id)
|
||||
assert search_logs[0]["success"] is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_mcp_tool_logging(e2e_config, cache_engine):
|
||||
"""Test that MCP tools succeed and log to Redis."""
|
||||
import sys
|
||||
import importlib.util
|
||||
from pathlib import Path
|
||||
|
||||
await _reset_engines_and_prune()
|
||||
|
||||
repo_root = Path(__file__).parent.parent.parent
|
||||
mcp_src_path = repo_root / "cognee-mcp" / "src"
|
||||
mcp_server_path = mcp_src_path / "server.py"
|
||||
|
||||
if not mcp_server_path.exists():
|
||||
pytest.skip(f"MCP server not found at {mcp_server_path}")
|
||||
|
||||
if str(mcp_src_path) not in sys.path:
|
||||
sys.path.insert(0, str(mcp_src_path))
|
||||
|
||||
spec = importlib.util.spec_from_file_location("mcp_server_module", mcp_server_path)
|
||||
mcp_server_module = importlib.util.module_from_spec(spec)
|
||||
|
||||
import os
|
||||
|
||||
original_cwd = os.getcwd()
|
||||
try:
|
||||
os.chdir(str(mcp_src_path))
|
||||
spec.loader.exec_module(mcp_server_module)
|
||||
finally:
|
||||
os.chdir(original_cwd)
|
||||
|
||||
if mcp_server_module.cognee_client is None:
|
||||
cognee_client_path = mcp_src_path / "cognee_client.py"
|
||||
if cognee_client_path.exists():
|
||||
spec_client = importlib.util.spec_from_file_location(
|
||||
"cognee_client", cognee_client_path
|
||||
)
|
||||
cognee_client_module = importlib.util.module_from_spec(spec_client)
|
||||
spec_client.loader.exec_module(cognee_client_module)
|
||||
CogneeClient = cognee_client_module.CogneeClient
|
||||
mcp_server_module.cognee_client = CogneeClient()
|
||||
else:
|
||||
pytest.skip(f"CogneeClient not found at {cognee_client_path}")
|
||||
|
||||
test_text = "Germany is located in Europe right next to the Netherlands."
|
||||
await mcp_server_module.cognify(data=test_text)
|
||||
await asyncio.sleep(30.0)
|
||||
|
||||
list_result = await mcp_server_module.list_data()
|
||||
assert list_result is not None, "List data should return results"
|
||||
|
||||
search_result = await mcp_server_module.search(
|
||||
search_query="Germany", search_type="GRAPH_COMPLETION", top_k=5
|
||||
)
|
||||
assert search_result is not None, "Search should return results"
|
||||
|
||||
interaction_data = "User: What is Germany?\nAgent: Germany is a country in Europe."
|
||||
await mcp_server_module.save_interaction(data=interaction_data)
|
||||
await asyncio.sleep(30.0)
|
||||
|
||||
status_result = await mcp_server_module.cognify_status()
|
||||
assert status_result is not None, "Cognify status should return results"
|
||||
|
||||
await mcp_server_module.prune()
|
||||
await asyncio.sleep(0.5)
|
||||
|
||||
logs = await cache_engine.get_usage_logs("unknown", limit=50)
|
||||
mcp_logs = [log for log in logs if log.get("type") == "mcp_tool"]
|
||||
assert len(mcp_logs) > 0, (
|
||||
f"Should have MCP tool logs with user_id='unknown'. Found logs: {[log.get('function_name') for log in logs[:5]]}"
|
||||
)
|
||||
assert len(mcp_logs) == 6
|
||||
function_names = [log.get("function_name") for log in mcp_logs]
|
||||
expected_tools = [
|
||||
"MCP cognify",
|
||||
"MCP list_data",
|
||||
"MCP search",
|
||||
"MCP save_interaction",
|
||||
"MCP cognify_status",
|
||||
"MCP prune",
|
||||
]
|
||||
|
||||
for expected_tool in expected_tools:
|
||||
assert expected_tool in function_names, (
|
||||
f"Should have {expected_tool} log. Found: {function_names}"
|
||||
)
|
||||
|
||||
for log in mcp_logs:
|
||||
assert log["type"] == "mcp_tool"
|
||||
assert log["user_id"] == "unknown"
|
||||
assert log["success"] is True
|
||||
|
|
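For local debugging, the same Redis-backed logs that the e2e test verifies can be read directly with the adapter used above. A minimal sketch, assuming the environment the e2e fixtures set up (Redis on localhost:6379, USAGE_LOGGING=true); the user id is a placeholder:

import asyncio

from cognee.infrastructure.databases.cache.config import get_cache_config
from cognee.infrastructure.databases.cache.redis.RedisAdapter import RedisAdapter


async def dump_usage_logs(user_id: str) -> None:
    # Mirrors the cache_engine fixture above; "usage_logs" is the key the running
    # services write to, while the integration tests use "test_usage_logs".
    config = get_cache_config()
    adapter = RedisAdapter(
        host=config.cache_host,
        port=config.cache_port,
        username=config.cache_username,
        password=config.cache_password,
        log_key="usage_logs",
    )
    for entry in await adapter.get_usage_logs(user_id, limit=20):
        print(entry["function_name"], entry["success"])


if __name__ == "__main__":
    asyncio.run(dump_usage_logs("replace-with-a-user-id"))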
@ -62,6 +62,8 @@ def test_cache_config_to_dict():
|
|||
"cache_password": None,
|
||||
"agentic_lock_expire": 100,
|
||||
"agentic_lock_timeout": 200,
|
||||
"usage_logging": False,
|
||||
"usage_logging_ttl": 604800,
|
||||
}
|
||||
|
||||
|
||||
|
|
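The two new keys extend the cache configuration behind usage logging: usage_logging toggles the feature and usage_logging_ttl is the log retention in seconds (604800 s = 7 days, matching the TTL bound asserted in the integration test). A small sketch of reading them, assuming the config attribute names mirror the to_dict() keys checked above:

from cognee.infrastructure.databases.cache.config import get_cache_config

config = get_cache_config()
if config.usage_logging:
    # Assumption: attribute names follow the to_dict() keys asserted in this test.
    print(f"Usage logs are kept for {config.usage_logging_ttl} seconds")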
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
import pytest
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
from cognee.modules.engine.utils.generate_edge_id import generate_edge_id
|
||||
from cognee.modules.graph.exceptions import EntityNotFoundError, EntityAlreadyExistsError
|
||||
from cognee.modules.graph.cognee_graph.CogneeGraph import CogneeGraph
|
||||
from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge, Node
|
||||
|
|
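The test updates below all follow one pattern: the scored results fed into map_vector_distances_to_graph_edges no longer carry hard-coded ids like "e1" but ids derived from the relationship text via generate_edge_id, which suggests edges are now matched by that deterministic id rather than by payload text alone. A minimal sketch of the pattern (the exact id format is whatever generate_edge_id produces):

from cognee.modules.engine.utils.generate_edge_id import generate_edge_id

relationship = "CONNECTS_TO"
# The edge's vector-store id and the scored result's id are both derived from the
# relationship text, so distance mapping can pair them deterministically.
edge_id = generate_edge_id(relationship)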
@ -379,7 +380,7 @@ async def test_map_vector_distances_to_graph_edges_with_payload(setup_graph):
|
|||
graph.add_edge(edge)
|
||||
|
||||
edge_distances = [
|
||||
MockScoredResult("e1", 0.92, payload={"text": "CONNECTS_TO"}),
|
||||
MockScoredResult(generate_edge_id("CONNECTS_TO"), 0.92, payload={"text": "CONNECTS_TO"}),
|
||||
]
|
||||
|
||||
await graph.map_vector_distances_to_graph_edges(edge_distances=edge_distances)
|
||||
|
|
@ -404,8 +405,9 @@ async def test_map_vector_distances_partial_edge_coverage(setup_graph):
|
|||
graph.add_edge(edge1)
|
||||
graph.add_edge(edge2)
|
||||
|
||||
edge_1_text = "CONNECTS_TO"
|
||||
edge_distances = [
|
||||
MockScoredResult("e1", 0.92, payload={"text": "CONNECTS_TO"}),
|
||||
MockScoredResult(generate_edge_id(edge_1_text), 0.92, payload={"text": edge_1_text}),
|
||||
]
|
||||
|
||||
await graph.map_vector_distances_to_graph_edges(edge_distances=edge_distances)
|
||||
|
|
@ -431,8 +433,9 @@ async def test_map_vector_distances_edges_fallback_to_relationship_type(setup_gr
|
|||
)
|
||||
graph.add_edge(edge)
|
||||
|
||||
edge_text = "KNOWS"
|
||||
edge_distances = [
|
||||
MockScoredResult("e1", 0.85, payload={"text": "KNOWS"}),
|
||||
MockScoredResult(generate_edge_id(edge_text), 0.85, payload={"text": edge_text}),
|
||||
]
|
||||
|
||||
await graph.map_vector_distances_to_graph_edges(edge_distances=edge_distances)
|
||||
|
|
@ -457,8 +460,9 @@ async def test_map_vector_distances_no_edge_matches(setup_graph):
|
|||
)
|
||||
graph.add_edge(edge)
|
||||
|
||||
edge_text = "SOME_OTHER_EDGE"
|
||||
edge_distances = [
|
||||
MockScoredResult("e1", 0.92, payload={"text": "SOME_OTHER_EDGE"}),
|
||||
MockScoredResult(generate_edge_id(edge_text), 0.92, payload={"text": edge_text}),
|
||||
]
|
||||
|
||||
await graph.map_vector_distances_to_graph_edges(edge_distances=edge_distances)
|
||||
|
|
@ -511,9 +515,15 @@ async def test_map_vector_distances_to_graph_edges_multi_query(setup_graph):
|
|||
graph.add_edge(edge1)
|
||||
graph.add_edge(edge2)
|
||||
|
||||
edge_1_text = "A"
|
||||
edge_2_text = "B"
|
||||
edge_distances = [
|
||||
[MockScoredResult("e1", 0.1, payload={"text": "A"})], # query 0
|
||||
[MockScoredResult("e2", 0.2, payload={"text": "B"})], # query 1
|
||||
[
|
||||
MockScoredResult(generate_edge_id(edge_1_text), 0.1, payload={"text": edge_1_text})
|
||||
], # query 0
|
||||
[
|
||||
MockScoredResult(generate_edge_id(edge_2_text), 0.2, payload={"text": edge_2_text})
|
||||
], # query 1
|
||||
]
|
||||
|
||||
await graph.map_vector_distances_to_graph_edges(
|
||||
|
|
@ -541,8 +551,11 @@ async def test_map_vector_distances_to_graph_edges_preserves_unmapped_indices(se
|
|||
graph.add_edge(edge1)
|
||||
graph.add_edge(edge2)
|
||||
|
||||
edge_1_text = "A"
|
||||
edge_distances = [
|
||||
[MockScoredResult("e1", 0.1, payload={"text": "A"})], # query 0: only edge1 mapped
|
||||
[
|
||||
MockScoredResult(generate_edge_id(edge_1_text), 0.1, payload={"text": edge_1_text})
|
||||
], # query 0: only edge1 mapped
|
||||
[], # query 1: no edges mapped
|
||||
]
|
||||
|
||||
|
|
|
|||
|
|
@ -35,7 +35,9 @@ async def test_get_context_success(mock_vector_engine):
|
|||
assert len(context) == 2
|
||||
assert context[0]["text"] == "Steve Rodger"
|
||||
assert context[1]["text"] == "Mike Broski"
|
||||
mock_vector_engine.search.assert_awaited_once_with("DocumentChunk_text", "test query", limit=5)
|
||||
mock_vector_engine.search.assert_awaited_once_with(
|
||||
"DocumentChunk_text", "test query", limit=5, include_payload=True
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -87,7 +89,9 @@ async def test_get_context_top_k_limit(mock_vector_engine):
|
|||
context = await retriever.get_context("test query")
|
||||
|
||||
assert len(context) == 3
|
||||
mock_vector_engine.search.assert_awaited_once_with("DocumentChunk_text", "test query", limit=3)
|
||||
mock_vector_engine.search.assert_awaited_once_with(
|
||||
"DocumentChunk_text", "test query", limit=3, include_payload=True
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
|
|||
|
|
@ -33,7 +33,9 @@ async def test_get_context_success(mock_vector_engine):
|
|||
context = await retriever.get_context("test query")
|
||||
|
||||
assert context == "Steve Rodger\nMike Broski"
|
||||
mock_vector_engine.search.assert_awaited_once_with("DocumentChunk_text", "test query", limit=2)
|
||||
mock_vector_engine.search.assert_awaited_once_with(
|
||||
"DocumentChunk_text", "test query", limit=2, include_payload=True
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -85,7 +87,9 @@ async def test_get_context_top_k_limit(mock_vector_engine):
|
|||
context = await retriever.get_context("test query")
|
||||
|
||||
assert context == "Chunk 0\nChunk 1"
|
||||
mock_vector_engine.search.assert_awaited_once_with("DocumentChunk_text", "test query", limit=2)
|
||||
mock_vector_engine.search.assert_awaited_once_with(
|
||||
"DocumentChunk_text", "test query", limit=2, include_payload=True
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
|
|||
|
|
@ -35,7 +35,9 @@ async def test_get_context_success(mock_vector_engine):
|
|||
assert len(context) == 2
|
||||
assert context[0]["text"] == "S.R."
|
||||
assert context[1]["text"] == "M.B."
|
||||
mock_vector_engine.search.assert_awaited_once_with("TextSummary_text", "test query", limit=5)
|
||||
mock_vector_engine.search.assert_awaited_once_with(
|
||||
"TextSummary_text", "test query", limit=5, include_payload=True
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -87,7 +89,9 @@ async def test_get_context_top_k_limit(mock_vector_engine):
|
|||
context = await retriever.get_context("test query")
|
||||
|
||||
assert len(context) == 3
|
||||
mock_vector_engine.search.assert_awaited_once_with("TextSummary_text", "test query", limit=3)
|
||||
mock_vector_engine.search.assert_awaited_once_with(
|
||||
"TextSummary_text", "test query", limit=3, include_payload=True
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
|
|||
|
|
@ -63,8 +63,8 @@ async def test_filter_top_k_events_sorts_and_limits():
|
|||
]
|
||||
|
||||
scored_results = [
|
||||
SimpleNamespace(payload={"id": "e2"}, score=0.10),
|
||||
SimpleNamespace(payload={"id": "e1"}, score=0.20),
|
||||
SimpleNamespace(id="e2", payload={"id": "e2"}, score=0.10),
|
||||
SimpleNamespace(id="e1", payload={"id": "e1"}, score=0.20),
|
||||
]
|
||||
|
||||
top = await tr.filter_top_k_events(relevant_events, scored_results)
|
||||
|
|
@ -91,8 +91,8 @@ async def test_filter_top_k_events_includes_unknown_as_infinite_but_not_in_top_k
|
|||
]
|
||||
|
||||
scored_results = [
|
||||
SimpleNamespace(payload={"id": "known2"}, score=0.05),
|
||||
SimpleNamespace(payload={"id": "known1"}, score=0.50),
|
||||
SimpleNamespace(id="known2", payload={"id": "known2"}, score=0.05),
|
||||
SimpleNamespace(id="known1", payload={"id": "known1"}, score=0.50),
|
||||
]
|
||||
|
||||
top = await tr.filter_top_k_events(relevant_events, scored_results)
|
||||
|
|
@ -119,8 +119,8 @@ async def test_filter_top_k_events_limits_when_top_k_exceeds_events():
|
|||
tr = TemporalRetriever(top_k=10)
|
||||
relevant_events = [{"events": [{"id": "a"}, {"id": "b"}]}]
|
||||
scored_results = [
|
||||
SimpleNamespace(payload={"id": "a"}, score=0.1),
|
||||
SimpleNamespace(payload={"id": "b"}, score=0.2),
|
||||
SimpleNamespace(id="a", payload={"id": "a"}, score=0.1),
|
||||
SimpleNamespace(id="b", payload={"id": "b"}, score=0.2),
|
||||
]
|
||||
out = await tr.filter_top_k_events(relevant_events, scored_results)
|
||||
assert [e["id"] for e in out] == ["a", "b"]
|
||||
|
|
@ -179,8 +179,8 @@ async def test_get_context_with_time_range(mock_graph_engine, mock_vector_engine
|
|||
}
|
||||
]
|
||||
|
||||
mock_result1 = SimpleNamespace(payload={"id": "e2"}, score=0.05)
|
||||
mock_result2 = SimpleNamespace(payload={"id": "e1"}, score=0.10)
|
||||
mock_result1 = SimpleNamespace(id="e2", payload={"id": "e2"}, score=0.05)
|
||||
mock_result2 = SimpleNamespace(id="e1", payload={"id": "e1"}, score=0.10)
|
||||
mock_vector_engine.search.return_value = [mock_result1, mock_result2]
|
||||
|
||||
with (
|
||||
|
|
@ -279,7 +279,7 @@ async def test_get_context_time_from_only(mock_graph_engine, mock_vector_engine)
|
|||
}
|
||||
]
|
||||
|
||||
mock_result = SimpleNamespace(payload={"id": "e1"}, score=0.05)
|
||||
mock_result = SimpleNamespace(id="e1", payload={"id": "e1"}, score=0.05)
|
||||
mock_vector_engine.search.return_value = [mock_result]
|
||||
|
||||
with (
|
||||
|
|
@ -313,7 +313,7 @@ async def test_get_context_time_to_only(mock_graph_engine, mock_vector_engine):
|
|||
}
|
||||
]
|
||||
|
||||
mock_result = SimpleNamespace(payload={"id": "e1"}, score=0.05)
|
||||
mock_result = SimpleNamespace(id="e1", payload={"id": "e1"}, score=0.05)
|
||||
mock_vector_engine.search.return_value = [mock_result]
|
||||
|
||||
with (
|
||||
|
|
@ -347,7 +347,7 @@ async def test_get_completion_without_context(mock_graph_engine, mock_vector_eng
|
|||
}
|
||||
]
|
||||
|
||||
mock_result = SimpleNamespace(payload={"id": "e1"}, score=0.05)
|
||||
mock_result = SimpleNamespace(id="e1", payload={"id": "e1"}, score=0.05)
|
||||
mock_vector_engine.search.return_value = [mock_result]
|
||||
|
||||
with (
|
||||
|
|
@ -416,7 +416,7 @@ async def test_get_completion_with_session(mock_graph_engine, mock_vector_engine
|
|||
}
|
||||
]
|
||||
|
||||
mock_result = SimpleNamespace(payload={"id": "e1"}, score=0.05)
|
||||
mock_result = SimpleNamespace(id="e1", payload={"id": "e1"}, score=0.05)
|
||||
mock_vector_engine.search.return_value = [mock_result]
|
||||
|
||||
mock_user = MagicMock()
|
||||
|
|
@ -481,7 +481,7 @@ async def test_get_completion_with_session_no_user_id(mock_graph_engine, mock_ve
|
|||
}
|
||||
]
|
||||
|
||||
mock_result = SimpleNamespace(payload={"id": "e1"}, score=0.05)
|
||||
mock_result = SimpleNamespace(id="e1", payload={"id": "e1"}, score=0.05)
|
||||
mock_vector_engine.search.return_value = [mock_result]
|
||||
|
||||
with (
|
||||
|
|
@ -570,7 +570,7 @@ async def test_get_completion_with_response_model(mock_graph_engine, mock_vector
|
|||
}
|
||||
]
|
||||
|
||||
mock_result = SimpleNamespace(payload={"id": "e1"}, score=0.05)
|
||||
mock_result = SimpleNamespace(id="e1", payload={"id": "e1"}, score=0.05)
|
||||
mock_vector_engine.search.return_value = [mock_result]
|
||||
|
||||
with (
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@ from cognee.modules.retrieval.utils.brute_force_triplet_search import (
|
|||
get_memory_fragment,
|
||||
format_triplets,
|
||||
)
|
||||
from cognee.modules.engine.utils.generate_edge_id import generate_edge_id
|
||||
from cognee.modules.graph.cognee_graph.CogneeGraph import CogneeGraph
|
||||
from cognee.modules.graph.exceptions.exceptions import EntityNotFoundError
|
||||
from cognee.infrastructure.databases.vector.exceptions.exceptions import CollectionNotFoundError
|
||||
|
|
@ -1036,9 +1037,11 @@ async def test_cognee_graph_mapping_batch_shapes():
|
|||
]
|
||||
}
|
||||
|
||||
edge_1_text = "relates_to"
|
||||
edge_2_text = "relates_to"
|
||||
edge_distances_batch = [
|
||||
[MockScoredResult("edge1", 0.92, payload={"text": "relates_to"})],
|
||||
[MockScoredResult("edge2", 0.88, payload={"text": "relates_to"})],
|
||||
[MockScoredResult(generate_edge_id(edge_1_text), 0.92, payload={"text": edge_1_text})],
|
||||
[MockScoredResult(generate_edge_id(edge_2_text), 0.88, payload={"text": edge_2_text})],
|
||||
]
|
||||
|
||||
await graph.map_vector_distances_to_graph_nodes(
|
||||
|
|
|
|||
|
|
@ -34,7 +34,9 @@ async def test_get_context_success(mock_vector_engine):
|
|||
context = await retriever.get_context("test query")
|
||||
|
||||
assert context == "Alice knows Bob\nBob works at Tech Corp"
|
||||
mock_vector_engine.search.assert_awaited_once_with("Triplet_text", "test query", limit=5)
|
||||
mock_vector_engine.search.assert_awaited_once_with(
|
||||
"Triplet_text", "test query", limit=5, include_payload=True
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
|
|||
241 cognee/tests/unit/shared/test_usage_logger.py Normal file
|
|
@ -0,0 +1,241 @@
|
|||
"""Unit tests for usage logger core functions."""
|
||||
|
||||
import pytest
|
||||
from datetime import datetime, timezone
|
||||
from uuid import UUID
|
||||
from types import SimpleNamespace
|
||||
|
||||
from cognee.shared.usage_logger import (
|
||||
_sanitize_value,
|
||||
_sanitize_dict_key,
|
||||
_get_param_names,
|
||||
_get_param_defaults,
|
||||
_extract_user_id,
|
||||
_extract_parameters,
|
||||
log_usage,
|
||||
)
|
||||
from cognee.shared.exceptions import UsageLoggerError
|
||||
|
||||
|
||||
class TestSanitizeValue:
|
||||
"""Test _sanitize_value function."""
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"value,expected",
|
||||
[
|
||||
(None, None),
|
||||
("string", "string"),
|
||||
(42, 42),
|
||||
(3.14, 3.14),
|
||||
(True, True),
|
||||
(False, False),
|
||||
],
|
||||
)
|
||||
def test_basic_types(self, value, expected):
|
||||
assert _sanitize_value(value) == expected
|
||||
|
||||
def test_uuid_and_datetime(self):
|
||||
"""Test UUID and datetime serialization."""
|
||||
uuid_val = UUID("123e4567-e89b-12d3-a456-426614174000")
|
||||
dt = datetime(2024, 1, 15, 12, 30, 45, tzinfo=timezone.utc)
|
||||
|
||||
assert _sanitize_value(uuid_val) == "123e4567-e89b-12d3-a456-426614174000"
|
||||
assert _sanitize_value(dt) == "2024-01-15T12:30:45+00:00"
|
||||
|
||||
def test_collections(self):
|
||||
"""Test list, tuple, and dict serialization."""
|
||||
assert _sanitize_value(
|
||||
[1, "string", UUID("123e4567-e89b-12d3-a456-426614174000"), None]
|
||||
) == [1, "string", "123e4567-e89b-12d3-a456-426614174000", None]
|
||||
assert _sanitize_value((1, "string", True)) == [1, "string", True]
|
||||
assert _sanitize_value({"key": UUID("123e4567-e89b-12d3-a456-426614174000")}) == {
|
||||
"key": "123e4567-e89b-12d3-a456-426614174000"
|
||||
}
|
||||
assert _sanitize_value([]) == []
|
||||
assert _sanitize_value({}) == {}
|
||||
|
||||
def test_nested_and_complex(self):
|
||||
"""Test nested structures and non-serializable types."""
|
||||
# Nested structure
|
||||
nested = {"level1": {"level2": {"level3": [1, 2, {"nested": "value"}]}}}
|
||||
assert _sanitize_value(nested)["level1"]["level2"]["level3"][2]["nested"] == "value"
|
||||
|
||||
# Non-serializable
|
||||
class CustomObject:
|
||||
def __str__(self):
|
||||
return "<CustomObject instance>"
|
||||
|
||||
result = _sanitize_value(CustomObject())
|
||||
assert isinstance(result, str)
|
||||
assert "<cannot be serialized" in result or "<CustomObject" in result
|
||||
|
||||
|
||||
class TestSanitizeDictKey:
|
||||
"""Test _sanitize_dict_key function."""
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"key,expected_contains",
|
||||
[
|
||||
("simple_key", "simple_key"),
|
||||
(UUID("123e4567-e89b-12d3-a456-426614174000"), "123e4567-e89b-12d3-a456-426614174000"),
|
||||
((1, 2, 3), ["1", "2"]),
|
||||
],
|
||||
)
|
||||
def test_key_types(self, key, expected_contains):
|
||||
result = _sanitize_dict_key(key)
|
||||
assert isinstance(result, str)
|
||||
if isinstance(expected_contains, list):
|
||||
assert all(item in result for item in expected_contains)
|
||||
else:
|
||||
assert expected_contains in result
|
||||
|
||||
def test_non_serializable_key(self):
|
||||
class BadKey:
|
||||
def __str__(self):
|
||||
return "<BadKey instance>"
|
||||
|
||||
result = _sanitize_dict_key(BadKey())
|
||||
assert isinstance(result, str)
|
||||
assert "<key:" in result or "<BadKey" in result
|
||||
|
||||
|
||||
class TestGetParamNames:
|
||||
"""Test _get_param_names function."""
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"func_def,expected",
|
||||
[
|
||||
(lambda a, b, c: None, ["a", "b", "c"]),
|
||||
(lambda a, b=42, c="default": None, ["a", "b", "c"]),
|
||||
(lambda a, **kwargs: None, ["a", "kwargs"]),
|
||||
(lambda *args: None, ["args"]),
|
||||
],
|
||||
)
|
||||
def test_param_extraction(self, func_def, expected):
|
||||
assert _get_param_names(func_def) == expected
|
||||
|
||||
def test_async_function(self):
|
||||
async def func(a, b):
|
||||
pass
|
||||
|
||||
assert _get_param_names(func) == ["a", "b"]
|
||||
|
||||
|
||||
class TestGetParamDefaults:
|
||||
"""Test _get_param_defaults function."""
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"func_def,expected",
|
||||
[
|
||||
(lambda a, b=42, c="default", d=None: None, {"b": 42, "c": "default", "d": None}),
|
||||
(lambda a, b, c: None, {}),
|
||||
(lambda a, b=10, c="test", d=None: None, {"b": 10, "c": "test", "d": None}),
|
||||
],
|
||||
)
|
||||
def test_default_extraction(self, func_def, expected):
|
||||
assert _get_param_defaults(func_def) == expected
|
||||
|
||||
|
||||
class TestExtractUserId:
|
||||
"""Test _extract_user_id function."""
|
||||
|
||||
def test_user_extraction(self):
|
||||
"""Test extracting user_id from kwargs and args."""
|
||||
user1 = SimpleNamespace(id=UUID("123e4567-e89b-12d3-a456-426614174000"))
|
||||
user2 = SimpleNamespace(id="user-123")
|
||||
|
||||
# From kwargs
|
||||
assert (
|
||||
_extract_user_id((), {"user": user1}, ["user", "other"])
|
||||
== "123e4567-e89b-12d3-a456-426614174000"
|
||||
)
|
||||
# From args
|
||||
assert _extract_user_id((user2, "other"), {}, ["user", "other"]) == "user-123"
|
||||
# Not present
|
||||
assert _extract_user_id(("arg1",), {}, ["param1"]) is None
|
||||
# None value
|
||||
assert _extract_user_id((None,), {}, ["user"]) is None
|
||||
# No id attribute
|
||||
assert _extract_user_id((SimpleNamespace(name="test"),), {}, ["user"]) is None
|
||||
|
||||
|
||||
class TestExtractParameters:
|
||||
"""Test _extract_parameters function."""
|
||||
|
||||
def test_parameter_extraction(self):
|
||||
"""Test parameter extraction with various scenarios."""
|
||||
|
||||
def func1(param1, param2, user=None):
|
||||
pass
|
||||
|
||||
def func2(param1, param2=42, param3="default", user=None):
|
||||
pass
|
||||
|
||||
def func3():
|
||||
pass
|
||||
|
||||
def func4(param1, user):
|
||||
pass
|
||||
|
||||
# Kwargs only
|
||||
result = _extract_parameters(
|
||||
(), {"param1": "v1", "param2": 42}, _get_param_names(func1), func1
|
||||
)
|
||||
assert result == {"param1": "v1", "param2": 42}
|
||||
assert "user" not in result
|
||||
|
||||
# Args only
|
||||
result = _extract_parameters(("v1", 42), {}, _get_param_names(func1), func1)
|
||||
assert result == {"param1": "v1", "param2": 42}
|
||||
|
||||
# Mixed args/kwargs
|
||||
result = _extract_parameters(("v1",), {"param3": "v3"}, _get_param_names(func2), func2)
|
||||
assert result["param1"] == "v1" and result["param3"] == "v3"
|
||||
|
||||
# Defaults included
|
||||
result = _extract_parameters(("v1",), {}, _get_param_names(func2), func2)
|
||||
assert result["param1"] == "v1" and result["param2"] == 42 and result["param3"] == "default"
|
||||
|
||||
# No parameters
|
||||
assert _extract_parameters((), {}, _get_param_names(func3), func3) == {}
|
||||
|
||||
# User excluded
|
||||
user = SimpleNamespace(id="user-123")
|
||||
result = _extract_parameters(("v1", user), {}, _get_param_names(func4), func4)
|
||||
assert result == {"param1": "v1"} and "user" not in result
|
||||
|
||||
# Fallback when inspection fails
|
||||
class BadFunc:
|
||||
pass
|
||||
|
||||
result = _extract_parameters(("arg1", "arg2"), {}, [], BadFunc())
|
||||
assert "arg_0" in result or "arg_1" in result
|
||||
|
||||
|
||||
class TestDecoratorValidation:
|
||||
"""Test decorator validation and behavior."""
|
||||
|
||||
def test_decorator_validation(self):
|
||||
"""Test decorator validation and metadata preservation."""
|
||||
# Sync function raises error
|
||||
with pytest.raises(UsageLoggerError, match="requires an async function"):
|
||||
|
||||
@log_usage()
|
||||
def sync_func():
|
||||
pass
|
||||
|
||||
# Async function accepted
|
||||
@log_usage()
|
||||
async def async_func():
|
||||
pass
|
||||
|
||||
assert callable(async_func)
|
||||
|
||||
# Metadata preserved
|
||||
@log_usage(function_name="test_func", log_type="test")
|
||||
async def test_func(param1: str, param2: int = 42):
|
||||
"""Test docstring."""
|
||||
return param1
|
||||
|
||||
assert test_func.__name__ == "test_func"
|
||||
assert "Test docstring" in test_func.__doc__
|
||||
|
|
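The unit tests above document the decorator's contract: it accepts only async functions, preserves the wrapped function's metadata, strips the user argument out of the logged parameters, and serializes UUIDs and datetimes to strings. A minimal usage sketch under those constraints (the function and its names are illustrative, not from the codebase):

from cognee.shared.usage_logger import log_usage


@log_usage(function_name="my_ingest_step", log_type="pipeline")
async def my_ingest_step(document_id: str, user=None):
    # `user` is used for attribution (user.id becomes the log's user_id) but is
    # excluded from the logged parameters, as the tests above assert.
    return {"document_id": document_id, "status": "done"}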
@ -11,34 +11,24 @@ echo "Debug port: $DEBUG_PORT"
|
|||
echo "HTTP port: $HTTP_PORT"
|
||||
|
||||
# Run Alembic migrations with proper error handling.
|
||||
# Note on UserAlreadyExists error handling:
|
||||
# During database migrations, we attempt to create a default user. If this user
|
||||
# already exists (e.g., from a previous deployment or migration), it's not a
|
||||
# critical error and shouldn't prevent the application from starting. This is
|
||||
# different from other migration errors which could indicate database schema
|
||||
# inconsistencies and should cause the startup to fail. This check allows for
|
||||
# smooth redeployments and container restarts while maintaining data integrity.
|
||||
echo "Running database migrations..."
|
||||
|
||||
# Move to the cognee directory to run alembic migrations from there
|
||||
set +e # Disable exit on error to handle specific migration errors
|
||||
MIGRATION_OUTPUT=$(alembic upgrade head)
|
||||
MIGRATION_OUTPUT=$(cd cognee && alembic upgrade head)
|
||||
MIGRATION_EXIT_CODE=$?
|
||||
set -e
|
||||
|
||||
if [[ $MIGRATION_EXIT_CODE -ne 0 ]]; then
|
||||
if [[ "$MIGRATION_OUTPUT" == *"UserAlreadyExists"* ]] || [[ "$MIGRATION_OUTPUT" == *"User default_user@example.com already exists"* ]]; then
|
||||
echo "Warning: Default user already exists, continuing startup..."
|
||||
else
|
||||
echo "Migration failed with unexpected error. Trying to run Cognee without migrations."
|
||||
echo "Migration failed with unexpected error. Trying to run Cognee without migrations."
|
||||
|
||||
echo "Initializing database tables..."
|
||||
python /app/cognee/modules/engine/operations/setup.py
|
||||
INIT_EXIT_CODE=$?
|
||||
echo "Initializing database tables..."
|
||||
python /app/cognee/modules/engine/operations/setup.py
|
||||
INIT_EXIT_CODE=$?
|
||||
|
||||
if [[ $INIT_EXIT_CODE -ne 0 ]]; then
|
||||
echo "Database initialization failed!"
|
||||
exit 1
|
||||
fi
|
||||
if [[ $INIT_EXIT_CODE -ne 0 ]]; then
|
||||
echo "Database initialization failed!"
|
||||
exit 1
|
||||
fi
|
||||
else
|
||||
echo "Database migrations done."
|
||||
|
|
|
|||
39 examples/guides/custom_data_models.py Normal file
|
|
@@ -0,0 +1,39 @@
import asyncio
from typing import Any
from pydantic import SkipValidation

import cognee
from cognee.infrastructure.engine import DataPoint
from cognee.infrastructure.engine.models.Edge import Edge
from cognee.tasks.storage import add_data_points


class Person(DataPoint):
    name: str
    # Keep it simple for forward refs / mixed values
    knows: SkipValidation[Any] = None  # single Person or list[Person]
    # Recommended: specify which fields to index for search
    metadata: dict = {"index_fields": ["name"]}


async def main():
    # Start clean (optional in your app)
    await cognee.prune.prune_data()
    await cognee.prune.prune_system(metadata=True)

    alice = Person(name="Alice")
    bob = Person(name="Bob")
    charlie = Person(name="Charlie")

    # Create relationships - field name becomes edge label
    alice.knows = bob
    # You can also do lists: alice.knows = [bob, charlie]

    # Optional: add weights and custom relationship types
    bob.knows = (Edge(weight=0.9, relationship_type="friend_of"), charlie)

    await add_data_points([alice, bob, charlie])


if __name__ == "__main__":
    asyncio.run(main())
30 examples/guides/custom_prompts.py Normal file
|
|
@@ -0,0 +1,30 @@
import asyncio
import cognee
from cognee.api.v1.search import SearchType

custom_prompt = """
Extract only people and cities as entities.
Connect people to cities with the relationship "lives_in".
Ignore all other entities.
"""


async def main():
    await cognee.add(
        [
            "Alice moved to Paris in 2010, while Bob has always lived in New York.",
            "Andreas was born in Venice, but later settled in Lisbon.",
            "Diana and Tom were born and raised in Helsinki. Diana currently resides in Berlin, while Tom never moved.",
        ]
    )
    await cognee.cognify(custom_prompt=custom_prompt)

    res = await cognee.search(
        query_type=SearchType.GRAPH_COMPLETION,
        query_text="Where does Alice live?",
    )
    print(res)


if __name__ == "__main__":
    asyncio.run(main())
53 examples/guides/custom_tasks_and_pipelines.py Normal file
|
|
@@ -0,0 +1,53 @@
import asyncio
from typing import Any, Dict, List
from pydantic import BaseModel, SkipValidation

import cognee
from cognee.modules.engine.operations.setup import setup
from cognee.infrastructure.llm.LLMGateway import LLMGateway
from cognee.infrastructure.engine import DataPoint
from cognee.tasks.storage import add_data_points
from cognee.modules.pipelines import Task, run_pipeline


class Person(DataPoint):
    name: str
    # Optional relationships (we'll let the LLM populate this)
    knows: List["Person"] = []
    # Make names searchable in the vector store
    metadata: Dict[str, Any] = {"index_fields": ["name"]}


class People(BaseModel):
    persons: List[Person]


async def extract_people(text: str) -> List[Person]:
    system_prompt = (
        "Extract people mentioned in the text. "
        "Return as `persons: Person[]` with each Person having `name` and optional `knows` relations. "
        "If the text says someone knows someone set `knows` accordingly. "
        "Only include facts explicitly stated."
    )
    people = await LLMGateway.acreate_structured_output(text, system_prompt, People)
    return people.persons


async def main():
    await cognee.prune.prune_data()
    await cognee.prune.prune_system(metadata=True)
    await setup()

    text = "Alice knows Bob."

    tasks = [
        Task(extract_people),  # input: text -> output: list[Person]
        Task(add_data_points),  # input: list[Person] -> output: list[Person]
    ]

    async for _ in run_pipeline(tasks=tasks, data=text, datasets=["people_demo"]):
        pass


if __name__ == "__main__":
    asyncio.run(main())
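A possible follow-up, not part of the guide above: once the custom pipeline has stored the Person data points, the dataset can be queried the same way as in the other guides. The query below is illustrative only.

import asyncio

import cognee
from cognee.api.v1.search import SearchType


async def ask() -> None:
    # Illustrative query; query_type/query_text follow the other guides.
    results = await cognee.search(
        query_type=SearchType.GRAPH_COMPLETION,
        query_text="Who does Alice know?",
    )
    print(results)


if __name__ == "__main__":
    asyncio.run(ask())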
14 examples/guides/graph_visualization.py Normal file
|
|
@@ -0,0 +1,14 @@
import asyncio
import cognee
from cognee.api.v1.visualize.visualize import visualize_graph


async def main():
    await cognee.add(["Alice knows Bob.", "NLP is a subfield of CS."])
    await cognee.cognify()

    await visualize_graph("./graph_after_cognify.html")


if __name__ == "__main__":
    asyncio.run(main())
31 examples/guides/low_level_llm.py Normal file
|
|
@@ -0,0 +1,31 @@
import asyncio

from pydantic import BaseModel
from typing import List
from cognee.infrastructure.llm.LLMGateway import LLMGateway


class MiniEntity(BaseModel):
    name: str
    type: str


class MiniGraph(BaseModel):
    nodes: List[MiniEntity]


async def main():
    system_prompt = (
        "Extract entities as nodes with name and type. "
        "Use concise, literal values present in the text."
    )

    text = "Apple develops iPhone; Audi produces the R8."

    result = await LLMGateway.acreate_structured_output(text, system_prompt, MiniGraph)
    print(result)
    # MiniGraph(nodes=[MiniEntity(name='Apple', type='Organization'), ...])


if __name__ == "__main__":
    asyncio.run(main())
30 examples/guides/memify_quickstart.py Normal file
|
|
@@ -0,0 +1,30 @@
import asyncio
import cognee
from cognee.modules.search.types import SearchType


async def main():
    # 1) Add two short chats and build a graph
    await cognee.add(
        [
            "We follow PEP8. Add type hints and docstrings.",
            "Releases should not be on Friday. Susan must review PRs.",
        ],
        dataset_name="rules_demo",
    )
    await cognee.cognify(datasets=["rules_demo"])  # builds graph

    # 2) Enrich the graph (uses default memify tasks)
    await cognee.memify(dataset="rules_demo")

    # 3) Query the new coding rules
    rules = await cognee.search(
        query_type=SearchType.CODING_RULES,
        query_text="List coding rules",
        node_name=["coding_agent_rules"],
    )
    print("Rules:", rules)


if __name__ == "__main__":
    asyncio.run(main())
290 examples/guides/ontology_input_example/basic_ontology.owl Normal file
|
|
@ -0,0 +1,290 @@
|
|||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<rdf:RDF
|
||||
xmlns:ns1="http://example.org/ontology#"
|
||||
xmlns:rdf="http://www.w3.org/1999/02/22-rdf-syntax-ns#"
|
||||
xmlns:rdfs="http://www.w3.org/2000/01/rdf-schema#"
|
||||
>
|
||||
<rdf:Description rdf:about="http://example.org/ontology#Volkswagen">
|
||||
<rdfs:comment>Created for making cars accessible to everyone.</rdfs:comment>
|
||||
<ns1:produces rdf:resource="http://example.org/ontology#VW_Golf"/>
|
||||
<ns1:produces rdf:resource="http://example.org/ontology#VW_ID4"/>
|
||||
<ns1:produces rdf:resource="http://example.org/ontology#VW_Touareg"/>
|
||||
<rdf:type rdf:resource="http://example.org/ontology#CarManufacturer"/>
|
||||
<rdf:type rdf:resource="http://www.w3.org/2002/07/owl#NamedIndividual"/>
|
||||
</rdf:Description>
|
||||
<rdf:Description rdf:about="http://example.org/ontology#Azure">
|
||||
<rdf:type rdf:resource="http://example.org/ontology#CloudServiceProvider"/>
|
||||
<rdf:type rdf:resource="http://www.w3.org/2002/07/owl#NamedIndividual"/>
|
||||
</rdf:Description>
|
||||
<rdf:Description rdf:about="http://example.org/ontology#Porsche">
|
||||
<ns1:produces rdf:resource="http://example.org/ontology#Porsche_Cayenne"/>
|
||||
<ns1:produces rdf:resource="http://example.org/ontology#Porsche_Taycan"/>
|
||||
<ns1:produces rdf:resource="http://example.org/ontology#Porsche_911"/>
|
||||
<rdf:type rdf:resource="http://example.org/ontology#CarManufacturer"/>
|
||||
<rdf:type rdf:resource="http://www.w3.org/2002/07/owl#NamedIndividual"/>
|
||||
<rdfs:comment>Famous for high-performance sports cars.</rdfs:comment>
|
||||
</rdf:Description>
|
||||
<rdf:Description rdf:about="http://example.org/ontology#Meta">
|
||||
<rdf:type rdf:resource="http://example.org/ontology#TechnologyCompany"/>
|
||||
<rdf:type rdf:resource="http://www.w3.org/2002/07/owl#NamedIndividual"/>
|
||||
<ns1:develops rdf:resource="http://example.org/ontology#Instagram"/>
|
||||
<ns1:develops rdf:resource="http://example.org/ontology#Facebook"/>
|
||||
<ns1:develops rdf:resource="http://example.org/ontology#Oculus"/>
|
||||
<ns1:develops rdf:resource="http://example.org/ontology#WhatsApp"/>
|
||||
<rdfs:comment>Pioneering social media and virtual reality technology.</rdfs:comment>
|
||||
</rdf:Description>
|
||||
<rdf:Description rdf:about="http://example.org/ontology#TechnologyCompany">
|
||||
<rdf:type rdf:resource="http://www.w3.org/2002/07/owl#Class"/>
|
||||
</rdf:Description>
|
||||
<rdf:Description rdf:about="http://example.org/ontology#Apple">
|
||||
<rdf:type rdf:resource="http://example.org/ontology#TechnologyCompany"/>
|
||||
<rdf:type rdf:resource="http://www.w3.org/2002/07/owl#NamedIndividual"/>
|
||||
<rdfs:comment>Known for its innovative consumer electronics and software.</rdfs:comment>
|
||||
<ns1:develops rdf:resource="http://example.org/ontology#iPad"/>
|
||||
<ns1:develops rdf:resource="http://example.org/ontology#iPhone"/>
|
||||
<ns1:develops rdf:resource="http://example.org/ontology#AppleWatch"/>
|
||||
<ns1:develops rdf:resource="http://example.org/ontology#MacBook"/>
|
||||
</rdf:Description>
|
||||
<rdf:Description rdf:about="http://example.org/ontology#Audi">
|
||||
<ns1:produces rdf:resource="http://example.org/ontology#Audi_eTron"/>
|
||||
<ns1:produces rdf:resource="http://example.org/ontology#Audi_R8"/>
|
||||
<ns1:produces rdf:resource="http://example.org/ontology#Audi_A8"/>
|
||||
<rdf:type rdf:resource="http://example.org/ontology#CarManufacturer"/>
|
||||
<rdf:type rdf:resource="http://www.w3.org/2002/07/owl#NamedIndividual"/>
|
||||
<rdfs:comment>Known for its modern designs and technology.</rdfs:comment>
|
||||
</rdf:Description>
|
||||
<rdf:Description rdf:about="http://example.org/ontology#AmazonEcho">
|
||||
<rdf:type rdf:resource="http://example.org/ontology#HardwareCompany"/>
|
||||
<rdf:type rdf:resource="http://www.w3.org/2002/07/owl#NamedIndividual"/>
|
||||
</rdf:Description>
|
||||
<rdf:Description rdf:about="http://example.org/ontology#Porsche_Taycan">
|
||||
<rdf:type rdf:resource="http://example.org/ontology#ElectricCar"/>
|
||||
<rdf:type rdf:resource="http://www.w3.org/2002/07/owl#NamedIndividual"/>
|
||||
</rdf:Description>
|
||||
<rdf:Description rdf:about="http://example.org/ontology#BMW">
|
||||
<ns1:produces rdf:resource="http://example.org/ontology#BMW_7Series"/>
|
||||
<ns1:produces rdf:resource="http://example.org/ontology#BMW_M4"/>
|
||||
<ns1:produces rdf:resource="http://example.org/ontology#BMW_iX"/>
|
||||
<rdf:type rdf:resource="http://example.org/ontology#CarManufacturer"/>
|
||||
<rdf:type rdf:resource="http://www.w3.org/2002/07/owl#NamedIndividual"/>
|
||||
<rdfs:comment>Focused on performance and driving pleasure.</rdfs:comment>
|
||||
</rdf:Description>
|
||||
<rdf:Description rdf:about="http://example.org/ontology#VW_Touareg">
|
||||
<rdf:type rdf:resource="http://example.org/ontology#SUV"/>
|
||||
<rdf:type rdf:resource="http://www.w3.org/2002/07/owl#NamedIndividual"/>
|
||||
</rdf:Description>
|
||||
<rdf:Description rdf:about="http://example.org/ontology#SportsCar">
|
||||
<rdf:type rdf:resource="http://www.w3.org/2002/07/owl#Class"/>
|
||||
<rdfs:subClassOf rdf:resource="http://example.org/ontology#Car"/>
|
||||
</rdf:Description>
|
||||
<rdf:Description rdf:about="http://example.org/ontology#ElectricCar">
|
||||
<rdfs:subClassOf rdf:resource="http://example.org/ontology#Car"/>
|
||||
<rdf:type rdf:resource="http://www.w3.org/2002/07/owl#Class"/>
|
||||
</rdf:Description>
|
||||
<rdf:Description rdf:about="http://example.org/ontology#Google">
|
||||
<ns1:develops rdf:resource="http://example.org/ontology#GooglePixel"/>
|
||||
<ns1:develops rdf:resource="http://example.org/ontology#GoogleCloud"/>
|
||||
<ns1:develops rdf:resource="http://example.org/ontology#Android"/>
|
||||
<ns1:develops rdf:resource="http://example.org/ontology#GoogleSearch"/>
|
||||
<rdf:type rdf:resource="http://example.org/ontology#TechnologyCompany"/>
|
||||
<rdf:type rdf:resource="http://www.w3.org/2002/07/owl#NamedIndividual"/>
|
||||
<rdfs:comment>Started as a search engine and expanded into cloud computing and AI.</rdfs:comment>
|
||||
</rdf:Description>
|
||||
<rdf:Description rdf:about="http://example.org/ontology#AmazonPrime">
|
||||
<rdf:type rdf:resource="http://example.org/ontology#SoftwareCompany"/>
|
||||
<rdf:type rdf:resource="http://www.w3.org/2002/07/owl#NamedIndividual"/>
|
||||
</rdf:Description>
|
||||
<rdf:Description rdf:about="http://example.org/ontology#Car">
|
||||
<rdf:type rdf:resource="http://www.w3.org/2002/07/owl#Class"/>
|
||||
</rdf:Description>
|
||||
<rdf:Description rdf:about="http://example.org/ontology#WindowsOS">
|
||||
<rdf:type rdf:resource="http://example.org/ontology#SoftwareCompany"/>
|
||||
<rdf:type rdf:resource="http://www.w3.org/2002/07/owl#NamedIndividual"/>
|
||||
</rdf:Description>
|
||||
<rdf:Description rdf:about="http://example.org/ontology#Android">
|
||||
<rdf:type rdf:resource="http://example.org/ontology#SoftwareCompany"/>
|
||||
<rdf:type rdf:resource="http://www.w3.org/2002/07/owl#NamedIndividual"/>
|
||||
</rdf:Description>
|
||||
<rdf:Description rdf:about="http://example.org/ontology#Oculus">
|
||||
<rdf:type rdf:resource="http://example.org/ontology#HardwareCompany"/>
|
||||
<rdf:type rdf:resource="http://www.w3.org/2002/07/owl#NamedIndividual"/>
|
||||
</rdf:Description>
|
||||
<rdf:Description rdf:about="http://example.org/ontology#GoogleCloud">
|
||||
<rdf:type rdf:resource="http://example.org/ontology#CloudServiceProvider"/>
|
||||
<rdf:type rdf:resource="http://www.w3.org/2002/07/owl#NamedIndividual"/>
|
||||
</rdf:Description>
|
||||
<rdf:Description rdf:about="http://example.org/ontology#Microsoft">
|
||||
<ns1:develops rdf:resource="http://example.org/ontology#Surface"/>
|
||||
<ns1:develops rdf:resource="http://example.org/ontology#WindowsOS"/>
|
||||
<ns1:develops rdf:resource="http://example.org/ontology#Azure"/>
|
||||
<ns1:develops rdf:resource="http://example.org/ontology#Xbox"/>
|
||||
<rdf:type rdf:resource="http://example.org/ontology#TechnologyCompany"/>
|
||||
<rdf:type rdf:resource="http://www.w3.org/2002/07/owl#NamedIndividual"/>
|
||||
<rdfs:comment>Dominant in software, cloud computing, and gaming.</rdfs:comment>
|
||||
</rdf:Description>
|
||||
<rdf:Description rdf:about="http://example.org/ontology#GoogleSearch">
|
||||
<rdf:type rdf:resource="http://example.org/ontology#SoftwareCompany"/>
|
||||
<rdf:type rdf:resource="http://www.w3.org/2002/07/owl#NamedIndividual"/>
|
||||
</rdf:Description>
|
||||
<rdf:Description rdf:about="http://example.org/ontology#Mercedes_SClass">
|
||||
<rdf:type rdf:resource="http://example.org/ontology#LuxuryCar"/>
|
||||
<rdf:type rdf:resource="http://www.w3.org/2002/07/owl#NamedIndividual"/>
|
||||
</rdf:Description>
|
||||
<rdf:Description rdf:about="http://example.org/ontology#Audi_A8">
|
||||
<rdf:type rdf:resource="http://example.org/ontology#LuxuryCar"/>
|
||||
<rdf:type rdf:resource="http://www.w3.org/2002/07/owl#NamedIndividual"/>
|
||||
</rdf:Description>
|
||||
<rdf:Description rdf:about="http://example.org/ontology#Sedan">
|
||||
<rdfs:subClassOf rdf:resource="http://example.org/ontology#Car"/>
|
||||
<rdf:type rdf:resource="http://www.w3.org/2002/07/owl#Class"/>
|
||||
</rdf:Description>
|
||||
<rdf:Description rdf:about="http://example.org/ontology#VW_Golf">
|
||||
<rdf:type rdf:resource="http://example.org/ontology#Sedan"/>
|
||||
<rdf:type rdf:resource="http://www.w3.org/2002/07/owl#NamedIndividual"/>
|
||||
</rdf:Description>
|
||||
<rdf:Description rdf:about="http://example.org/ontology#Facebook">
|
||||
<rdf:type rdf:resource="http://example.org/ontology#SoftwareCompany"/>
|
||||
<rdf:type rdf:resource="http://www.w3.org/2002/07/owl#NamedIndividual"/>
|
||||
</rdf:Description>
|
||||
<rdf:Description rdf:about="http://example.org/ontology#WhatsApp">
|
||||
<rdf:type rdf:resource="http://example.org/ontology#SoftwareCompany"/>
|
||||
<rdf:type rdf:resource="http://www.w3.org/2002/07/owl#NamedIndividual"/>
|
||||
</rdf:Description>
|
||||
<rdf:Description rdf:about="http://example.org/ontology#produces">
|
||||
<rdfs:domain rdf:resource="http://example.org/ontology#CarManufacturer"/>
|
||||
<rdfs:range rdf:resource="http://example.org/ontology#Car"/>
|
||||
<rdf:type rdf:resource="http://www.w3.org/2002/07/owl#ObjectProperty"/>
|
||||
</rdf:Description>
|
||||
<rdf:Description rdf:about="http://example.org/ontology#BMW_7Series">
|
||||
<rdf:type rdf:resource="http://example.org/ontology#LuxuryCar"/>
|
||||
<rdf:type rdf:resource="http://www.w3.org/2002/07/owl#NamedIndividual"/>
|
||||
</rdf:Description>
|
||||
<rdf:Description rdf:about="http://example.org/ontology#BMW_M4">
|
||||
<rdf:type rdf:resource="http://example.org/ontology#SportsCar"/>
|
||||
<rdf:type rdf:resource="http://www.w3.org/2002/07/owl#NamedIndividual"/>
|
||||
</rdf:Description>
|
||||
<rdf:Description rdf:about="http://example.org/ontology#Audi_eTron">
|
||||
<rdf:type rdf:resource="http://example.org/ontology#ElectricCar"/>
|
||||
<rdf:type rdf:resource="http://www.w3.org/2002/07/owl#NamedIndividual"/>
|
||||
</rdf:Description>
|
||||
<rdf:Description rdf:about="http://example.org/ontology#Kindle">
|
||||
<rdf:type rdf:resource="http://example.org/ontology#HardwareCompany"/>
|
||||
<rdf:type rdf:resource="http://www.w3.org/2002/07/owl#NamedIndividual"/>
|
||||
</rdf:Description>
|
||||
<rdf:Description rdf:about="http://example.org/ontology#BMW_iX">
|
||||
<rdf:type rdf:resource="http://example.org/ontology#ElectricCar"/>
|
||||
<rdf:type rdf:resource="http://www.w3.org/2002/07/owl#NamedIndividual"/>
|
||||
</rdf:Description>
|
||||
<rdf:Description rdf:about="http://example.org/ontology#SoftwareCompany">
|
||||
<rdfs:subClassOf rdf:resource="http://example.org/ontology#TechnologyCompany"/>
|
||||
<rdf:type rdf:resource="http://www.w3.org/2002/07/owl#Class"/>
|
||||
</rdf:Description>
|
||||
<rdf:Description rdf:about="http://example.org/ontology#Audi_R8">
|
||||
<rdf:type rdf:resource="http://example.org/ontology#SportsCar"/>
|
||||
<rdf:type rdf:resource="http://www.w3.org/2002/07/owl#NamedIndividual"/>
|
||||
</rdf:Description>
|
||||
<rdf:Description rdf:about="http://example.org/ontology#Xbox">
|
||||
<rdf:type rdf:resource="http://example.org/ontology#HardwareCompany"/>
|
||||
<rdf:type rdf:resource="http://www.w3.org/2002/07/owl#NamedIndividual"/>
|
||||
</rdf:Description>
|
||||
<rdf:Description rdf:about="http://example.org/ontology#Technology">
|
||||
<rdf:type rdf:resource="http://www.w3.org/2002/07/owl#Class"/>
|
||||
</rdf:Description>
|
||||
<rdf:Description rdf:about="http://example.org/ontology#Mercedes_EQS">
|
||||
<rdf:type rdf:resource="http://example.org/ontology#ElectricCar"/>
|
||||
<rdf:type rdf:resource="http://www.w3.org/2002/07/owl#NamedIndividual"/>
|
||||
</rdf:Description>
|
||||
<rdf:Description rdf:about="http://example.org/ontology#Porsche_911">
|
||||
<rdf:type rdf:resource="http://example.org/ontology#SportsCar"/>
|
||||
<rdf:type rdf:resource="http://www.w3.org/2002/07/owl#NamedIndividual"/>
|
||||
</rdf:Description>
|
||||
<rdf:Description rdf:about="http://example.org/ontology#HardwareCompany">
|
||||
<rdf:type rdf:resource="http://www.w3.org/2002/07/owl#Class"/>
|
||||
<rdfs:subClassOf rdf:resource="http://example.org/ontology#TechnologyCompany"/>
|
||||
</rdf:Description>
|
||||
<rdf:Description rdf:about="http://example.org/ontology#MercedesBenz">
|
||||
<ns1:produces rdf:resource="http://example.org/ontology#Mercedes_SClass"/>
|
||||
<ns1:produces rdf:resource="http://example.org/ontology#Mercedes_EQS"/>
|
||||
<ns1:produces rdf:resource="http://example.org/ontology#Mercedes_AMG_GT"/>
|
||||
<rdfs:comment>Synonymous with luxury and quality.</rdfs:comment>
|
||||
<rdf:type rdf:resource="http://example.org/ontology#CarManufacturer"/>
|
||||
<rdf:type rdf:resource="http://www.w3.org/2002/07/owl#NamedIndividual"/>
|
||||
</rdf:Description>
|
||||
<rdf:Description rdf:about="http://example.org/ontology#Amazon">
|
||||
<rdf:type rdf:resource="http://example.org/ontology#TechnologyCompany"/>
|
||||
<rdf:type rdf:resource="http://www.w3.org/2002/07/owl#NamedIndividual"/>
|
||||
<ns1:develops rdf:resource="http://example.org/ontology#Kindle"/>
|
||||
<ns1:develops rdf:resource="http://example.org/ontology#AmazonEcho"/>
|
||||
<ns1:develops rdf:resource="http://example.org/ontology#AWS"/>
|
||||
<ns1:develops rdf:resource="http://example.org/ontology#AmazonPrime"/>
|
||||
<rdfs:comment>From e-commerce to cloud computing giant with AWS.</rdfs:comment>
|
||||
</rdf:Description>
|
||||
<rdf:Description rdf:about="http://example.org/ontology#Instagram">
|
||||
<rdf:type rdf:resource="http://example.org/ontology#SoftwareCompany"/>
|
||||
<rdf:type rdf:resource="http://www.w3.org/2002/07/owl#NamedIndividual"/>
|
||||
</rdf:Description>
|
||||
<rdf:Description rdf:about="http://example.org/ontology#AWS">
|
||||
<rdf:type rdf:resource="http://example.org/ontology#CloudServiceProvider"/>
|
||||
<rdf:type rdf:resource="http://www.w3.org/2002/07/owl#NamedIndividual"/>
|
||||
</rdf:Description>
|
||||
<rdf:Description rdf:about="http://example.org/ontology#SUV">
|
||||
<rdf:type rdf:resource="http://www.w3.org/2002/07/owl#Class"/>
|
||||
<rdfs:subClassOf rdf:resource="http://example.org/ontology#Car"/>
|
||||
</rdf:Description>
|
||||
<rdf:Description rdf:about="http://example.org/ontology#VW_ID4">
|
||||
<rdf:type rdf:resource="http://example.org/ontology#ElectricCar"/>
|
||||
<rdf:type rdf:resource="http://www.w3.org/2002/07/owl#NamedIndividual"/>
|
||||
</rdf:Description>
|
||||
<rdf:Description rdf:about="http://example.org/ontology#CloudServiceProvider">
|
||||
<rdfs:subClassOf rdf:resource="http://example.org/ontology#TechnologyCompany"/>
|
||||
<rdf:type rdf:resource="http://www.w3.org/2002/07/owl#Class"/>
|
||||
</rdf:Description>
|
||||
<rdf:Description rdf:about="http://example.org/ontology#Surface">
|
||||
<rdf:type rdf:resource="http://example.org/ontology#HardwareCompany"/>
|
||||
<rdf:type rdf:resource="http://www.w3.org/2002/07/owl#NamedIndividual"/>
|
||||
</rdf:Description>
|
||||
<rdf:Description rdf:about="http://example.org/ontology#iPad">
|
||||
<rdf:type rdf:resource="http://example.org/ontology#HardwareCompany"/>
|
||||
<rdf:type rdf:resource="http://www.w3.org/2002/07/owl#NamedIndividual"/>
|
||||
</rdf:Description>
|
||||
<rdf:Description rdf:about="http://example.org/ontology#iPhone">
|
||||
<rdf:type rdf:resource="http://example.org/ontology#HardwareCompany"/>
|
||||
<rdf:type rdf:resource="http://www.w3.org/2002/07/owl#NamedIndividual"/>
|
||||
</rdf:Description>
|
||||
<rdf:Description rdf:about="http://example.org/ontology#Mercedes_AMG_GT">
|
||||
<rdf:type rdf:resource="http://example.org/ontology#SportsCar"/>
|
||||
<rdf:type rdf:resource="http://www.w3.org/2002/07/owl#NamedIndividual"/>
|
||||
</rdf:Description>
|
||||
<rdf:Description rdf:about="http://example.org/ontology#MacBook">
|
||||
<rdf:type rdf:resource="http://example.org/ontology#HardwareCompany"/>
|
||||
<rdf:type rdf:resource="http://www.w3.org/2002/07/owl#NamedIndividual"/>
|
||||
</rdf:Description>
|
||||
<rdf:Description rdf:about="http://example.org/ontology#develops">
|
||||
<rdf:type rdf:resource="http://www.w3.org/2002/07/owl#ObjectProperty"/>
|
||||
<rdfs:range rdf:resource="http://example.org/ontology#Technology"/>
|
||||
<rdfs:domain rdf:resource="http://example.org/ontology#TechnologyCompany"/>
|
||||
</rdf:Description>
|
||||
<rdf:Description rdf:about="http://example.org/ontology#LuxuryCar">
|
||||
<rdf:type rdf:resource="http://www.w3.org/2002/07/owl#Class"/>
|
||||
<rdfs:subClassOf rdf:resource="http://example.org/ontology#Car"/>
|
||||
</rdf:Description>
|
||||
<rdf:Description rdf:about="http://example.org/ontology#AppleWatch">
|
||||
<rdf:type rdf:resource="http://example.org/ontology#HardwareCompany"/>
|
||||
<rdf:type rdf:resource="http://www.w3.org/2002/07/owl#NamedIndividual"/>
|
||||
</rdf:Description>
|
||||
<rdf:Description rdf:about="http://example.org/ontology#Porsche_Cayenne">
|
||||
<rdf:type rdf:resource="http://example.org/ontology#SUV"/>
|
||||
<rdf:type rdf:resource="http://www.w3.org/2002/07/owl#NamedIndividual"/>
|
||||
</rdf:Description>
|
||||
<rdf:Description rdf:about="http://example.org/ontology#GooglePixel">
|
||||
<rdf:type rdf:resource="http://example.org/ontology#HardwareCompany"/>
|
||||
<rdf:type rdf:resource="http://www.w3.org/2002/07/owl#NamedIndividual"/>
|
||||
</rdf:Description>
|
||||
<rdf:Description rdf:about="http://example.org/ontology#Company">
|
||||
<rdf:type rdf:resource="http://www.w3.org/2002/07/owl#Class"/>
|
||||
</rdf:Description>
|
||||
<rdf:Description rdf:about="http://example.org/ontology#CarManufacturer">
|
||||
<rdf:type rdf:resource="http://www.w3.org/2002/07/owl#Class"/>
|
||||
<rdfs:subClassOf rdf:resource="http://example.org/ontology#Company"/>
|
||||
</rdf:Description>
|
||||
</rdf:RDF>
|
||||
29
examples/guides/ontology_quickstart.py
Normal file
@@ -0,0 +1,29 @@
import asyncio
import cognee
import os
from cognee.modules.ontology.ontology_config import Config
from cognee.modules.ontology.rdf_xml.RDFLibOntologyResolver import RDFLibOntologyResolver


async def main():
    texts = ["Audi produces the R8 and e-tron.", "Apple develops iPhone and MacBook."]

    await cognee.add(texts)
    # or: await cognee.add("/path/to/folder/of/files")

    ontology_path = os.path.join(
        os.path.dirname(os.path.abspath(__file__)), "ontology_input_example/basic_ontology.owl"
    )

    # Create full config structure manually
    config: Config = {
        "ontology_config": {
            "ontology_resolver": RDFLibOntologyResolver(ontology_file=ontology_path)
        }
    }

    await cognee.cognify(config=config)


if __name__ == "__main__":
    asyncio.run(main())
25
examples/guides/s3_storage.py
Normal file
@@ -0,0 +1,25 @@
import asyncio
import cognee


async def main():
    # Single file
    await cognee.add("s3://cognee-s3-small-test/Natural_language_processing.txt")

    # Folder/prefix (recursively expands)
    await cognee.add("s3://cognee-s3-small-test")

    # Mixed list
    await cognee.add(
        [
            "s3://cognee-s3-small-test/Natural_language_processing.txt",
            "Some inline text to ingest",
        ]
    )

    # Process the data
    await cognee.cognify()


if __name__ == "__main__":
    asyncio.run(main())
52
examples/guides/search_basics_additional.py
Normal file
@@ -0,0 +1,52 @@
import asyncio
import cognee

from cognee.api.v1.search import SearchType


async def main():
    await cognee.prune.prune_data()
    await cognee.prune.prune_system(metadata=True)

    text = """
    Natural language processing (NLP) is an interdisciplinary
    subfield of computer science and information retrieval.
    First rule of coding: Do not talk about coding.
    """

    text2 = """
    Sandwiches are best served toasted with cheese, ham, mayo,
    lettuce, mustard, and salt & pepper.
    """

    await cognee.add(text, dataset_name="NLP_coding")
    await cognee.add(text2, dataset_name="Sandwiches")
    await cognee.add(text2)

    await cognee.cognify()

    # Make sure you've already run cognee.cognify(...) so the graph has content
    answers = await cognee.search(query_text="What are the main themes in my data?")
    assert len(answers) > 0

    answers = await cognee.search(
        query_text="List coding guidelines",
        query_type=SearchType.CODING_RULES,
    )
    assert len(answers) > 0

    answers = await cognee.search(
        query_text="Give me a confident answer: What is NLP?",
        system_prompt="Answer succinctly and state confidence at the end.",
    )
    assert len(answers) > 0

    answers = await cognee.search(
        query_text="Tell me about NLP",
        only_context=True,
    )
    assert len(answers) > 0


if __name__ == "__main__":
    asyncio.run(main())
27
examples/guides/search_basics_core.py
Normal file
@@ -0,0 +1,27 @@
import asyncio
import cognee


async def main():
    # Start clean (optional in your app)
    await cognee.prune.prune_data()
    await cognee.prune.prune_system(metadata=True)
    # Prepare knowledge base
    await cognee.add(
        [
            "Alice moved to Paris in 2010. She works as a software engineer.",
            "Bob lives in New York. He is a data scientist.",
            "Alice and Bob met at a conference in 2015.",
        ]
    )

    await cognee.cognify()

    # Make sure you've already run cognee.cognify(...) so the graph has content
    answers = await cognee.search(query_text="What are the main themes in my data?")
    for answer in answers:
        print(answer)


if __name__ == "__main__":
    asyncio.run(main())
57
examples/guides/temporal_cognify.py
Normal file
@@ -0,0 +1,57 @@
import asyncio
import cognee


async def main():
    text = """
    In 1998 the project launched. In 2001 version 1.0 shipped. In 2004 the team merged
    with another group. In 2010 support for v1 ended.
    """

    await cognee.add(text, dataset_name="timeline_demo")

    await cognee.cognify(datasets=["timeline_demo"], temporal_cognify=True)

    from cognee.api.v1.search import SearchType

    # Before / after queries
    result = await cognee.search(
        query_type=SearchType.TEMPORAL, query_text="What happened before 2000?", top_k=10
    )

    assert result != []

    result = await cognee.search(
        query_type=SearchType.TEMPORAL, query_text="What happened after 2010?", top_k=10
    )

    assert result != []

    # Between queries
    result = await cognee.search(
        query_type=SearchType.TEMPORAL, query_text="Events between 2001 and 2004", top_k=10
    )

    assert result != []

    # Scoped descriptions
    result = await cognee.search(
        query_type=SearchType.TEMPORAL,
        query_text="Key project milestones between 1998 and 2010",
        top_k=10,
    )

    assert result != []

    result = await cognee.search(
        query_type=SearchType.TEMPORAL,
        query_text="What happened after 2004?",
        datasets=["timeline_demo"],
        top_k=10,
    )

    assert result != []


if __name__ == "__main__":
    asyncio.run(main())