From 5fa5bfa68211f645bef20de54e9e73e5408a9c19 Mon Sep 17 00:00:00 2001 From: Mohammad Date: Wed, 10 Sep 2025 14:11:00 +0200 Subject: [PATCH 01/36] feat: add support for AWS session token in S3 configuration --- cognee/infrastructure/files/storage/S3FileStorage.py | 8 +++++++- cognee/infrastructure/files/storage/s3_config.py | 2 +- cognee/tasks/ingestion/resolve_data_directories.py | 2 +- 3 files changed, 9 insertions(+), 3 deletions(-) diff --git a/cognee/infrastructure/files/storage/S3FileStorage.py b/cognee/infrastructure/files/storage/S3FileStorage.py index 7c5a1033c..a0d611241 100644 --- a/cognee/infrastructure/files/storage/S3FileStorage.py +++ b/cognee/infrastructure/files/storage/S3FileStorage.py @@ -21,10 +21,11 @@ class S3FileStorage(Storage): def __init__(self, storage_path: str): self.storage_path = storage_path s3_config = get_s3_config() - if s3_config.aws_access_key_id is not None and s3_config.aws_secret_access_key is not None: + if s3_config.aws_access_key_id is not None and s3_config.aws_secret_access_key is not None and s3_config.aws_session_token is not None: self.s3 = s3fs.S3FileSystem( key=s3_config.aws_access_key_id, secret=s3_config.aws_secret_access_key, + token=s3_config.aws_session_token, anon=False, endpoint_url=s3_config.aws_endpoint_url, client_kwargs={"region_name": s3_config.aws_region}, @@ -146,6 +147,11 @@ class S3FileStorage(Storage): self.s3.isfile, os.path.join(self.storage_path.replace("s3://", ""), file_path) ) + async def get_size(self, file_path: str) -> int: + return await run_async( + self.s3.size, os.path.join(self.storage_path.replace("s3://", ""), file_path) + ) + async def ensure_directory_exists(self, directory_path: str = ""): """ Ensure that the specified directory exists, creating it if necessary. diff --git a/cognee/infrastructure/files/storage/s3_config.py b/cognee/infrastructure/files/storage/s3_config.py index 0b9372b7e..3b59bcd57 100644 --- a/cognee/infrastructure/files/storage/s3_config.py +++ b/cognee/infrastructure/files/storage/s3_config.py @@ -8,9 +8,9 @@ class S3Config(BaseSettings): aws_endpoint_url: Optional[str] = None aws_access_key_id: Optional[str] = None aws_secret_access_key: Optional[str] = None + aws_session_token: Optional[str] = None model_config = SettingsConfigDict(env_file=".env", extra="allow") - @lru_cache def get_s3_config(): return S3Config() diff --git a/cognee/tasks/ingestion/resolve_data_directories.py b/cognee/tasks/ingestion/resolve_data_directories.py index 1d3124a0c..cbd979e16 100644 --- a/cognee/tasks/ingestion/resolve_data_directories.py +++ b/cognee/tasks/ingestion/resolve_data_directories.py @@ -32,7 +32,7 @@ async def resolve_data_directories( import s3fs fs = s3fs.S3FileSystem( - key=s3_config.aws_access_key_id, secret=s3_config.aws_secret_access_key, anon=False + key=s3_config.aws_access_key_id, secret=s3_config.aws_secret_access_key,token=s3_config.aws_session_token, anon=False ) for item in data: From e2ed2793140a6993d500f823bf656f5ba93af6cd Mon Sep 17 00:00:00 2001 From: Mohammad Date: Wed, 10 Sep 2025 14:14:22 +0200 Subject: [PATCH 02/36] feat: add support for AWS session token in S3 configuration --- cognee/infrastructure/files/storage/s3_config.py | 1 + 1 file changed, 1 insertion(+) diff --git a/cognee/infrastructure/files/storage/s3_config.py b/cognee/infrastructure/files/storage/s3_config.py index 3b59bcd57..cefe5cd2f 100644 --- a/cognee/infrastructure/files/storage/s3_config.py +++ b/cognee/infrastructure/files/storage/s3_config.py @@ -11,6 +11,7 @@ class S3Config(BaseSettings): 
aws_session_token: Optional[str] = None model_config = SettingsConfigDict(env_file=".env", extra="allow") + @lru_cache def get_s3_config(): return S3Config() From a0c951336e22a6ee14036c3e51527e2b1723f9ed Mon Sep 17 00:00:00 2001 From: Mohammad Date: Wed, 10 Sep 2025 14:20:42 +0200 Subject: [PATCH 03/36] feat: add support for AWS session token in S3 configuration --- cognee/infrastructure/files/storage/S3FileStorage.py | 6 +++++- cognee/tasks/ingestion/resolve_data_directories.py | 5 ++++- 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/cognee/infrastructure/files/storage/S3FileStorage.py b/cognee/infrastructure/files/storage/S3FileStorage.py index a0d611241..078d5fe2a 100644 --- a/cognee/infrastructure/files/storage/S3FileStorage.py +++ b/cognee/infrastructure/files/storage/S3FileStorage.py @@ -21,7 +21,11 @@ class S3FileStorage(Storage): def __init__(self, storage_path: str): self.storage_path = storage_path s3_config = get_s3_config() - if s3_config.aws_access_key_id is not None and s3_config.aws_secret_access_key is not None and s3_config.aws_session_token is not None: + if ( + s3_config.aws_access_key_id is not None + and s3_config.aws_secret_access_key is not None + and s3_config.aws_session_token is not None + ): self.s3 = s3fs.S3FileSystem( key=s3_config.aws_access_key_id, secret=s3_config.aws_secret_access_key, diff --git a/cognee/tasks/ingestion/resolve_data_directories.py b/cognee/tasks/ingestion/resolve_data_directories.py index cbd979e16..aa2f95303 100644 --- a/cognee/tasks/ingestion/resolve_data_directories.py +++ b/cognee/tasks/ingestion/resolve_data_directories.py @@ -32,7 +32,10 @@ async def resolve_data_directories( import s3fs fs = s3fs.S3FileSystem( - key=s3_config.aws_access_key_id, secret=s3_config.aws_secret_access_key,token=s3_config.aws_session_token, anon=False + key=s3_config.aws_access_key_id, + secret=s3_config.aws_secret_access_key, + token=s3_config.aws_session_token, + anon=False, ) for item in data: From 8ad3ab23285a2865a0e1455081a5ece5e8104c7a Mon Sep 17 00:00:00 2001 From: Igor Ilic Date: Wed, 10 Sep 2025 15:06:37 +0200 Subject: [PATCH 04/36] fix: Allow S3 usage without token --- .env.template | 10 ++++++++++ cognee/infrastructure/files/storage/S3FileStorage.py | 6 +----- 2 files changed, 11 insertions(+), 5 deletions(-) diff --git a/.env.template b/.env.template index 28980de74..b90ad7525 100644 --- a/.env.template +++ b/.env.template @@ -155,6 +155,16 @@ LITELLM_LOG="ERROR" # DEFAULT_USER_EMAIL="" # DEFAULT_USER_PASSWORD="" +################################################################################ +# 📂 AWS Settings +################################################################################ + +#AWS_REGION="" +#AWS_ENDPOINT_URL="" +#AWS_ACCESS_KEY_ID="" +#AWS_SECRET_ACCESS_KEY="" +#AWS_SESSION_TOKEN="" + ------------------------------- END OF POSSIBLE SETTINGS ------------------------------- diff --git a/cognee/infrastructure/files/storage/S3FileStorage.py b/cognee/infrastructure/files/storage/S3FileStorage.py index 078d5fe2a..4284dcac2 100644 --- a/cognee/infrastructure/files/storage/S3FileStorage.py +++ b/cognee/infrastructure/files/storage/S3FileStorage.py @@ -21,11 +21,7 @@ class S3FileStorage(Storage): def __init__(self, storage_path: str): self.storage_path = storage_path s3_config = get_s3_config() - if ( - s3_config.aws_access_key_id is not None - and s3_config.aws_secret_access_key is not None - and s3_config.aws_session_token is not None - ): + if s3_config.aws_access_key_id is not None and 
s3_config.aws_secret_access_key is not None: self.s3 = s3fs.S3FileSystem( key=s3_config.aws_access_key_id, secret=s3_config.aws_secret_access_key, From e3494ca15f43601e3105411ef6be713556ab3ce5 Mon Sep 17 00:00:00 2001 From: Daulet Amirkhanov Date: Thu, 25 Sep 2025 17:43:34 +0100 Subject: [PATCH 05/36] feat: add mcp status display to frontend --- cognee-frontend/src/ui/Layout/Header.tsx | 11 ++++++++++- cognee-frontend/src/ui/elements/StatusDot.tsx | 13 +++++++++++++ cognee-frontend/src/ui/elements/index.ts | 1 + 3 files changed, 24 insertions(+), 1 deletion(-) create mode 100644 cognee-frontend/src/ui/elements/StatusDot.tsx diff --git a/cognee-frontend/src/ui/Layout/Header.tsx b/cognee-frontend/src/ui/Layout/Header.tsx index 7a1d2e906..69968076a 100644 --- a/cognee-frontend/src/ui/Layout/Header.tsx +++ b/cognee-frontend/src/ui/Layout/Header.tsx @@ -5,7 +5,7 @@ import Image from "next/image"; import { useBoolean } from "@/utils"; import { CloseIcon, CloudIcon, CogneeIcon } from "../Icons"; -import { CTAButton, GhostButton, IconButton, Modal } from "../elements"; +import { CTAButton, GhostButton, IconButton, Modal, StatusDot } from "../elements"; import syncData from "@/modules/cloud/syncData"; interface HeaderProps { @@ -23,6 +23,11 @@ export default function Header({ user }: HeaderProps) { setFalse: closeSyncModal, } = useBoolean(false); + const { + value: isMCPStatusOpen, + setTrue: setMCPStatusOpen, + } = useBoolean(false); + const handleDataSyncConfirm = () => { syncData() .finally(() => { @@ -39,6 +44,10 @@ export default function Header({ user }: HeaderProps) {
+ + + MCP status +
Sync
diff --git a/cognee-frontend/src/ui/elements/StatusDot.tsx b/cognee-frontend/src/ui/elements/StatusDot.tsx new file mode 100644 index 000000000..4eb71a6e0 --- /dev/null +++ b/cognee-frontend/src/ui/elements/StatusDot.tsx @@ -0,0 +1,13 @@ +import React from "react"; + +const StatusDot = ({ isActive, className }: { isActive: boolean, className?: string }) => { + return ( + + ); +}; + +export default StatusDot; diff --git a/cognee-frontend/src/ui/elements/index.ts b/cognee-frontend/src/ui/elements/index.ts index 551b06596..0133f56f6 100644 --- a/cognee-frontend/src/ui/elements/index.ts +++ b/cognee-frontend/src/ui/elements/index.ts @@ -8,5 +8,6 @@ export { default as IconButton } from "./IconButton"; export { default as GhostButton } from "./GhostButton"; export { default as NeutralButton } from "./NeutralButton"; export { default as StatusIndicator } from "./StatusIndicator"; +export { default as StatusDot } from "./StatusDot"; export { default as Accordion } from "./Accordion"; export { default as Notebook } from "./Notebook"; From 38e3f11533bceefe9776cd130269c742b850c2a4 Mon Sep 17 00:00:00 2001 From: Daulet Amirkhanov Date: Thu, 25 Sep 2025 20:42:40 +0100 Subject: [PATCH 06/36] fix: update entrypoint script to use cognee-mcp module --- cognee-mcp/entrypoint.sh | 18 +++++++++--------- cognee-mcp/pyproject.toml | 2 +- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/cognee-mcp/entrypoint.sh b/cognee-mcp/entrypoint.sh index 53da83c11..e3ff849e0 100644 --- a/cognee-mcp/entrypoint.sh +++ b/cognee-mcp/entrypoint.sh @@ -48,27 +48,27 @@ if [ "$ENVIRONMENT" = "dev" ] || [ "$ENVIRONMENT" = "local" ]; then if [ "$DEBUG" = "true" ]; then echo "Waiting for the debugger to attach..." if [ "$TRANSPORT_MODE" = "sse" ]; then - exec python -m debugpy --wait-for-client --listen 0.0.0.0:$DEBUG_PORT -m cognee --transport sse --host 0.0.0.0 --port $HTTP_PORT --no-migration + exec python -m debugpy --wait-for-client --listen 0.0.0.0:$DEBUG_PORT -m cognee-mcp --transport sse --host 0.0.0.0 --port $HTTP_PORT --no-migration elif [ "$TRANSPORT_MODE" = "http" ]; then - exec python -m debugpy --wait-for-client --listen 0.0.0.0:$DEBUG_PORT -m cognee --transport http --host 0.0.0.0 --port $HTTP_PORT --no-migration + exec python -m debugpy --wait-for-client --listen 0.0.0.0:$DEBUG_PORT -m cognee-mcp --transport http --host 0.0.0.0 --port $HTTP_PORT --no-migration else - exec python -m debugpy --wait-for-client --listen 0.0.0.0:$DEBUG_PORT -m cognee --transport stdio --no-migration + exec python -m debugpy --wait-for-client --listen 0.0.0.0:$DEBUG_PORT -m cognee-mcp --transport stdio --no-migration fi else if [ "$TRANSPORT_MODE" = "sse" ]; then - exec cognee --transport sse --host 0.0.0.0 --port $HTTP_PORT --no-migration + exec cognee-mcp --transport sse --host 0.0.0.0 --port $HTTP_PORT --no-migration elif [ "$TRANSPORT_MODE" = "http" ]; then - exec cognee --transport http --host 0.0.0.0 --port $HTTP_PORT --no-migration + exec cognee-mcp --transport http --host 0.0.0.0 --port $HTTP_PORT --no-migration else - exec cognee --transport stdio --no-migration + exec cognee-mcp --transport stdio --no-migration fi fi else if [ "$TRANSPORT_MODE" = "sse" ]; then - exec cognee --transport sse --host 0.0.0.0 --port $HTTP_PORT --no-migration + exec cognee-mcp --transport sse --host 0.0.0.0 --port $HTTP_PORT --no-migration elif [ "$TRANSPORT_MODE" = "http" ]; then - exec cognee --transport http --host 0.0.0.0 --port $HTTP_PORT --no-migration + exec cognee-mcp --transport http --host 0.0.0.0 --port $HTTP_PORT 
--no-migration else - exec cognee --transport stdio --no-migration + exec cognee-mcp --transport stdio --no-migration fi fi diff --git a/cognee-mcp/pyproject.toml b/cognee-mcp/pyproject.toml index a1ee22985..f22396bd4 100644 --- a/cognee-mcp/pyproject.toml +++ b/cognee-mcp/pyproject.toml @@ -36,4 +36,4 @@ dev = [ allow-direct-references = true [project.scripts] -cognee = "src:main" +cognee-mcp = "src:main" From 921c4481f034a5e22968e8067781b1ba63504334 Mon Sep 17 00:00:00 2001 From: Daulet Amirkhanov Date: Thu, 25 Sep 2025 22:04:06 +0100 Subject: [PATCH 07/36] feat: start cognee-mcp as part of cognee -ui --- cognee/api/v1/ui/__init__.py | 2 +- cognee/api/v1/ui/ui.py | 140 +++++++++-------------------------- cognee/cli/_cognee.py | 1 + 3 files changed, 39 insertions(+), 104 deletions(-) diff --git a/cognee/api/v1/ui/__init__.py b/cognee/api/v1/ui/__init__.py index f268a2e54..03876e999 100644 --- a/cognee/api/v1/ui/__init__.py +++ b/cognee/api/v1/ui/__init__.py @@ -1 +1 @@ -from .ui import start_ui, stop_ui, ui +from .ui import start_ui, ui diff --git a/cognee/api/v1/ui/ui.py b/cognee/api/v1/ui/ui.py index 6faca19e8..fb20b0420 100644 --- a/cognee/api/v1/ui/ui.py +++ b/cognee/api/v1/ui/ui.py @@ -334,17 +334,19 @@ def start_ui( start_backend: bool = False, backend_host: str = "localhost", backend_port: int = 8000, + start_mcp: bool = False, ) -> Optional[subprocess.Popen]: """ - Start the cognee frontend UI server, optionally with the backend API server. + Start the cognee frontend UI server, optionally with the backend API server and MCP server. This function will: 1. Optionally start the cognee backend API server - 2. Find the cognee-frontend directory (development) or download it (pip install) - 3. Check if Node.js and npm are available (for development mode) - 4. Install dependencies if needed (development mode) - 5. Start the frontend server - 6. Optionally open the browser + 2. Optionally start the cognee MCP server + 3. Find the cognee-frontend directory (development) or download it (pip install) + 4. Check if Node.js and npm are available (for development mode) + 5. Install dependencies if needed (development mode) + 6. Start the frontend server + 7. Optionally open the browser Args: pid_callback: Callback to notify with PID of each spawned process @@ -355,11 +357,12 @@ def start_ui( start_backend: If True, also start the cognee API backend server (default: False) backend_host: Host to bind the backend server to (default: localhost) backend_port: Port to run the backend server on (default: 8000) + start_mcp: If True, also start the cognee MCP server on port 8001 (default: False) Returns: subprocess.Popen object representing the running frontend server, or None if failed - Note: If backend is started, it runs in a separate process that will be cleaned up - when the frontend process is terminated. + Note: If backend and/or MCP server are started, they run in separate processes + that will be cleaned up when the frontend process is terminated. 
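+
+    Note: The MCP server is launched as a Docker container; Docker must be available,
+    and the .env file from the current working directory is passed to the container.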
Example: >>> import cognee @@ -370,12 +373,37 @@ def start_ui( >>> server = cognee.start_ui(start_backend=True) >>> # UI will be available at http://localhost:3000 >>> # API will be available at http://localhost:8000 - >>> # To stop both servers later: + >>> + >>> # Start frontend with MCP server + >>> server = cognee.start_ui(start_mcp=True) + >>> # UI will be available at http://localhost:3000 + >>> # MCP server will be available at http://127.0.0.1:8001/sse + >>> # To stop all servers later: >>> server.terminate() """ logger.info("Starting cognee UI...") backend_process = None + if start_mcp: + logger.info("Starting Cognee MCP server with Docker...") + cwd = os.getcwd() + env_file = os.path.join(cwd, ".env") + try: + mcp_process = subprocess.Popen( + [ + "docker", "run", + "-p", "8001:8000", + "--rm", + "--env-file", env_file, + "-e", "TRANSPORT_MODE=sse", + "cognee/cognee-mcp:daulet-dev" + ], + preexec_fn=os.setsid if hasattr(os, "setsid") else None, + ) + pid_callback(mcp_process.pid) + logger.info("✓ Cognee MCP server starting on http://127.0.0.1:8001/sse") + except Exception as e: + logger.error(f"Failed to start MCP server with Docker: {str(e)}") # Start backend server if requested if start_backend: logger.info("Starting cognee backend API server...") @@ -502,10 +530,6 @@ def start_ui( logger.info(f"✓ Open your browser to: http://{host}:{port}") logger.info("✓ The UI will be available once Next.js finishes compiling") - # Store backend process reference in the frontend process for cleanup - if backend_process: - process._cognee_backend_process = backend_process - return process except Exception as e: @@ -525,86 +549,6 @@ def start_ui( return None -def stop_ui(process: subprocess.Popen) -> bool: - """ - Stop a running UI server process and backend process (if started), along with all their children. 
- - Args: - process: The subprocess.Popen object returned by start_ui() - - Returns: - bool: True if stopped successfully, False otherwise - """ - if not process: - return False - - success = True - - try: - # First, stop the backend process if it exists - backend_process = getattr(process, "_cognee_backend_process", None) - if backend_process: - logger.info("Stopping backend server...") - try: - backend_process.terminate() - try: - backend_process.wait(timeout=5) - logger.info("Backend server stopped gracefully") - except subprocess.TimeoutExpired: - logger.warning("Backend didn't terminate gracefully, forcing kill") - backend_process.kill() - backend_process.wait() - logger.info("Backend server stopped") - except Exception as e: - logger.error(f"Error stopping backend server: {str(e)}") - success = False - - # Now stop the frontend process - logger.info("Stopping frontend server...") - # Try to terminate the process group (includes child processes like Next.js) - if hasattr(os, "killpg"): - try: - # Kill the entire process group - os.killpg(os.getpgid(process.pid), signal.SIGTERM) - logger.debug("Sent SIGTERM to process group") - except (OSError, ProcessLookupError): - # Fall back to terminating just the main process - process.terminate() - logger.debug("Terminated main process only") - else: - process.terminate() - logger.debug("Terminated main process (Windows)") - - try: - process.wait(timeout=10) - logger.info("Frontend server stopped gracefully") - except subprocess.TimeoutExpired: - logger.warning("Frontend didn't terminate gracefully, forcing kill") - - # Force kill the process group - if hasattr(os, "killpg"): - try: - os.killpg(os.getpgid(process.pid), signal.SIGKILL) - logger.debug("Sent SIGKILL to process group") - except (OSError, ProcessLookupError): - process.kill() - logger.debug("Force killed main process only") - else: - process.kill() - logger.debug("Force killed main process (Windows)") - - process.wait() - - if success: - logger.info("UI servers stopped successfully") - - return success - - except Exception as e: - logger.error(f"Error stopping UI servers: {str(e)}") - return False - - # Convenience function similar to DuckDB's approach def ui() -> Optional[subprocess.Popen]: """ @@ -612,13 +556,3 @@ def ui() -> Optional[subprocess.Popen]: Similar to how DuckDB provides simple ui() function. 
""" return start_ui() - - -if __name__ == "__main__": - # Test the UI startup - server = start_ui() - if server: - try: - input("Press Enter to stop the server...") - finally: - stop_ui(server) diff --git a/cognee/cli/_cognee.py b/cognee/cli/_cognee.py index 52915594b..6010ea679 100644 --- a/cognee/cli/_cognee.py +++ b/cognee/cli/_cognee.py @@ -209,6 +209,7 @@ def main() -> int: port=3000, open_browser=True, start_backend=True, + start_mcp=True, auto_download=True, pid_callback=pid_callback, ) From 80da5531853059de2960140e091eafe2ae8612e3 Mon Sep 17 00:00:00 2001 From: Daulet Amirkhanov Date: Thu, 25 Sep 2025 22:04:41 +0100 Subject: [PATCH 08/36] format: ruff format --- cognee/api/v1/ui/ui.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/cognee/api/v1/ui/ui.py b/cognee/api/v1/ui/ui.py index fb20b0420..3b583ac33 100644 --- a/cognee/api/v1/ui/ui.py +++ b/cognee/api/v1/ui/ui.py @@ -361,7 +361,7 @@ def start_ui( Returns: subprocess.Popen object representing the running frontend server, or None if failed - Note: If backend and/or MCP server are started, they run in separate processes + Note: If backend and/or MCP server are started, they run in separate processes that will be cleaned up when the frontend process is terminated. Example: @@ -391,12 +391,16 @@ def start_ui( try: mcp_process = subprocess.Popen( [ - "docker", "run", - "-p", "8001:8000", + "docker", + "run", + "-p", + "8001:8000", "--rm", - "--env-file", env_file, - "-e", "TRANSPORT_MODE=sse", - "cognee/cognee-mcp:daulet-dev" + "--env-file", + env_file, + "-e", + "TRANSPORT_MODE=sse", + "cognee/cognee-mcp:daulet-dev", ], preexec_fn=os.setsid if hasattr(os, "setsid") else None, ) From a68401ee70ee31b8d01c6b3c7fefd71116845d8d Mon Sep 17 00:00:00 2001 From: Daulet Amirkhanov Date: Fri, 26 Sep 2025 12:54:09 +0100 Subject: [PATCH 09/36] chore: update MCP status text to connected/disconnected --- cognee-frontend/src/ui/Layout/Header.tsx | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cognee-frontend/src/ui/Layout/Header.tsx b/cognee-frontend/src/ui/Layout/Header.tsx index 69968076a..30bf7ddb0 100644 --- a/cognee-frontend/src/ui/Layout/Header.tsx +++ b/cognee-frontend/src/ui/Layout/Header.tsx @@ -46,7 +46,7 @@ export default function Header({ user }: HeaderProps) {
- MCP status + { isMCPStatusOpen ? "MCP connected" : "MCP disconnected" } From c518f149f252c96ba8b2195bf14a8ca6ff4bdf5b Mon Sep 17 00:00:00 2001 From: Daulet Amirkhanov Date: Fri, 26 Sep 2025 14:26:43 +0100 Subject: [PATCH 10/36] refactor: streamline UI server startup and port availability checks --- cognee/api/v1/ui/__init__.py | 2 +- cognee/api/v1/ui/ui.py | 96 ++++++++++++++++++++++++++---------- cognee/cli/_cognee.py | 21 +++++--- examples/start_ui_example.py | 5 +- 4 files changed, 89 insertions(+), 35 deletions(-) diff --git a/cognee/api/v1/ui/__init__.py b/cognee/api/v1/ui/__init__.py index 03876e999..d5708da5a 100644 --- a/cognee/api/v1/ui/__init__.py +++ b/cognee/api/v1/ui/__init__.py @@ -1 +1 @@ -from .ui import start_ui, ui +from .ui import start_ui diff --git a/cognee/api/v1/ui/ui.py b/cognee/api/v1/ui/ui.py index 3b583ac33..af499421b 100644 --- a/cognee/api/v1/ui/ui.py +++ b/cognee/api/v1/ui/ui.py @@ -1,5 +1,6 @@ import os import signal +import socket import subprocess import threading import time @@ -7,7 +8,7 @@ import webbrowser import zipfile import requests from pathlib import Path -from typing import Callable, Optional, Tuple +from typing import Callable, Optional, Tuple, List import tempfile import shutil @@ -17,6 +18,40 @@ from cognee.version import get_cognee_version logger = get_logger() +def _is_port_available(port: int) -> bool: + """ + Check if a port is available on localhost. + Returns True if the port is available, False otherwise. + """ + try: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: + sock.settimeout(1) # 1 second timeout + result = sock.connect_ex(("localhost", port)) + return result != 0 # Port is available if connection fails + except Exception: + return False + + +def _check_required_ports(ports_to_check: List[Tuple[int, str]]) -> Tuple[bool, List[str]]: + """ + Check if all required ports are available on localhost. + + Args: + ports_to_check: List of (port, service_name) tuples + + Returns: + Tuple of (all_available: bool, unavailable_services: List[str]) + """ + unavailable = [] + + for port, service_name in ports_to_check: + if not _is_port_available(port): + unavailable.append(f"{service_name} (port {port})") + logger.error(f"Port {port} is already in use for {service_name}") + + return len(unavailable) == 0, unavailable + + def normalize_version_for_comparison(version: str) -> str: """ Normalize version string for comparison. @@ -327,14 +362,13 @@ def prompt_user_for_download() -> bool: def start_ui( pid_callback: Callable[[int], None], - host: str = "localhost", port: int = 3000, open_browser: bool = True, auto_download: bool = False, start_backend: bool = False, - backend_host: str = "localhost", backend_port: int = 8000, start_mcp: bool = False, + mcp_port: int = 8001, ) -> Optional[subprocess.Popen]: """ Start the cognee frontend UI server, optionally with the backend API server and MCP server. 
@@ -350,14 +384,13 @@ def start_ui( Args: pid_callback: Callback to notify with PID of each spawned process - host: Host to bind the frontend server to (default: localhost) port: Port to run the frontend server on (default: 3000) open_browser: Whether to open the browser automatically (default: True) auto_download: If True, download frontend without prompting (default: False) start_backend: If True, also start the cognee API backend server (default: False) - backend_host: Host to bind the backend server to (default: localhost) backend_port: Port to run the backend server on (default: 8000) - start_mcp: If True, also start the cognee MCP server on port 8001 (default: False) + start_mcp: If True, also start the cognee MCP server (default: False) + mcp_port: Port to run the MCP server on (default: 8001) Returns: subprocess.Popen object representing the running frontend server, or None if failed @@ -366,22 +399,42 @@ def start_ui( Example: >>> import cognee + >>> def dummy_callback(pid): pass >>> # Start just the frontend - >>> server = cognee.start_ui() + >>> server = cognee.start_ui(dummy_callback) >>> >>> # Start both frontend and backend - >>> server = cognee.start_ui(start_backend=True) + >>> server = cognee.start_ui(dummy_callback, start_backend=True) >>> # UI will be available at http://localhost:3000 >>> # API will be available at http://localhost:8000 >>> >>> # Start frontend with MCP server - >>> server = cognee.start_ui(start_mcp=True) + >>> server = cognee.start_ui(dummy_callback, start_mcp=True) >>> # UI will be available at http://localhost:3000 >>> # MCP server will be available at http://127.0.0.1:8001/sse >>> # To stop all servers later: >>> server.terminate() """ logger.info("Starting cognee UI...") + + ports_to_check = [(port, "Frontend UI")] + + if start_backend: + ports_to_check.append((backend_port, "Backend API")) + + if start_mcp: + ports_to_check.append((mcp_port, "MCP Server")) + + logger.info("Checking port availability...") + all_ports_available, unavailable_services = _check_required_ports(ports_to_check) + + if not all_ports_available: + error_msg = f"Cannot start cognee UI: The following services have ports already in use: {', '.join(unavailable_services)}" + logger.error(error_msg) + logger.error("Please stop the conflicting services or change the port configuration.") + return None + + logger.info("✓ All required ports are available") backend_process = None if start_mcp: @@ -394,7 +447,7 @@ def start_ui( "docker", "run", "-p", - "8001:8000", + f"{mcp_port}:8000", "--rm", "--env-file", env_file, @@ -405,7 +458,7 @@ def start_ui( preexec_fn=os.setsid if hasattr(os, "setsid") else None, ) pid_callback(mcp_process.pid) - logger.info("✓ Cognee MCP server starting on http://127.0.0.1:8001/sse") + logger.info(f"✓ Cognee MCP server starting on http://127.0.0.1:{mcp_port}/sse") except Exception as e: logger.error(f"Failed to start MCP server with Docker: {str(e)}") # Start backend server if requested @@ -421,7 +474,7 @@ def start_ui( "uvicorn", "cognee.api.client:app", "--host", - backend_host, + "localhost", "--port", str(backend_port), ], @@ -440,7 +493,7 @@ def start_ui( logger.error("Backend server failed to start - process exited early") return None - logger.info(f"✓ Backend API started at http://{backend_host}:{backend_port}") + logger.info(f"✓ Backend API started at http://localhost:{backend_port}") except Exception as e: logger.error(f"Failed to start backend server: {str(e)}") @@ -485,11 +538,11 @@ def start_ui( # Prepare environment variables env = 
os.environ.copy() - env["HOST"] = host + env["HOST"] = "localhost" env["PORT"] = str(port) # Start the development server - logger.info(f"Starting frontend server at http://{host}:{port}") + logger.info(f"Starting frontend server at http://localhost:{port}") logger.info("This may take a moment to compile and start...") try: @@ -523,7 +576,7 @@ def start_ui( def open_browser_delayed(): time.sleep(5) # Give Next.js time to fully start try: - webbrowser.open(f"http://{host}:{port}") # TODO: use dashboard url? + webbrowser.open(f"http://localhost:{port}") except Exception as e: logger.warning(f"Could not open browser automatically: {e}") @@ -531,7 +584,7 @@ def start_ui( browser_thread.start() logger.info("✓ Cognee UI is starting up...") - logger.info(f"✓ Open your browser to: http://{host}:{port}") + logger.info(f"✓ Open your browser to: http://localhost:{port}") logger.info("✓ The UI will be available once Next.js finishes compiling") return process @@ -551,12 +604,3 @@ def start_ui( except (OSError, ProcessLookupError): pass return None - - -# Convenience function similar to DuckDB's approach -def ui() -> Optional[subprocess.Popen]: - """ - Convenient alias for start_ui() with default parameters. - Similar to how DuckDB provides simple ui() function. - """ - return start_ui() diff --git a/cognee/cli/_cognee.py b/cognee/cli/_cognee.py index 6010ea679..7f2b06c89 100644 --- a/cognee/cli/_cognee.py +++ b/cognee/cli/_cognee.py @@ -204,20 +204,27 @@ def main() -> int: nonlocal spawned_pids spawned_pids.append(pid) + frontend_port = 3000 + start_backend, backend_port = True, 8000 + start_mcp, mcp_port = True, 8001 server_process = start_ui( - host="localhost", - port=3000, - open_browser=True, - start_backend=True, - start_mcp=True, - auto_download=True, pid_callback=pid_callback, + port=frontend_port, + open_browser=True, + auto_download=True, + start_backend=start_backend, + backend_port=backend_port, + start_mcp=start_mcp, + mcp_port=mcp_port, ) if server_process: fmt.success("UI server started successfully!") fmt.echo("The interface is available at: http://localhost:3000") - fmt.echo("The API backend is available at: http://localhost:8000") + if start_backend: + fmt.echo(f"The API backend is available at: http://localhost:{backend_port}") + if start_mcp: + fmt.echo(f"The MCP server is available at: http://localhost:{mcp_port}") fmt.note("Press Ctrl+C to stop the server...") try: diff --git a/examples/start_ui_example.py b/examples/start_ui_example.py index 55796727b..1fb29d239 100644 --- a/examples/start_ui_example.py +++ b/examples/start_ui_example.py @@ -29,8 +29,11 @@ async def main(): print("=" * 60) # Start the UI server + def dummy_callback(pid): + pass + server = cognee.start_ui( - host="localhost", + pid_callback=dummy_callback, port=3000, open_browser=True, # This will automatically open your browser ) From 056da9699558eca2ae5c2a541ea39be2dbad4905 Mon Sep 17 00:00:00 2001 From: Daulet Amirkhanov Date: Fri, 26 Sep 2025 14:32:15 +0100 Subject: [PATCH 11/36] feat: add logging distinction for mcp/backend/frontend processes for clearer output --- cognee/api/v1/ui/ui.py | 65 ++++++++++++++++++++++++++++++++++++------ 1 file changed, 57 insertions(+), 8 deletions(-) diff --git a/cognee/api/v1/ui/ui.py b/cognee/api/v1/ui/ui.py index af499421b..4d5674832 100644 --- a/cognee/api/v1/ui/ui.py +++ b/cognee/api/v1/ui/ui.py @@ -1,5 +1,4 @@ import os -import signal import socket import subprocess import threading @@ -18,6 +17,46 @@ from cognee.version import get_cognee_version logger = get_logger() 
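+
+# The helpers below tag each child process's output with a colored prefix
+# ([MCP], [BACKEND], [FRONTEND]) so interleaved logs stay readable.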
+def _stream_process_output( + process: subprocess.Popen, stream_name: str, prefix: str, color_code: str = "" +) -> threading.Thread: + """ + Stream output from a process with a prefix to identify the source. + + Args: + process: The subprocess to monitor + stream_name: 'stdout' or 'stderr' + prefix: Text prefix for each line (e.g., '[BACKEND]', '[FRONTEND]') + color_code: ANSI color code for the prefix (optional) + + Returns: + Thread that handles the streaming + """ + + def stream_reader(): + stream = getattr(process, stream_name) + if stream is None: + return + + reset_code = "\033[0m" if color_code else "" + + try: + for line in iter(stream.readline, b""): + if line: + line_text = line.decode("utf-8").rstrip() + if line_text: + print(f"{color_code}{prefix}{reset_code} {line_text}", flush=True) + except Exception: + pass + finally: + if stream: + stream.close() + + thread = threading.Thread(target=stream_reader, daemon=True) + thread.start() + return thread + + def _is_port_available(port: int) -> bool: """ Check if a port is available on localhost. @@ -455,8 +494,14 @@ def start_ui( "TRANSPORT_MODE=sse", "cognee/cognee-mcp:daulet-dev", ], + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, preexec_fn=os.setsid if hasattr(os, "setsid") else None, ) + + _stream_process_output(mcp_process, "stdout", "[MCP]", "\033[34m") # Blue + _stream_process_output(mcp_process, "stderr", "[MCP]", "\033[34m") # Blue + pid_callback(mcp_process.pid) logger.info(f"✓ Cognee MCP server starting on http://127.0.0.1:{mcp_port}/sse") except Exception as e: @@ -478,12 +523,15 @@ def start_ui( "--port", str(backend_port), ], - # Inherit stdout/stderr from parent process to show logs - stdout=None, - stderr=None, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, preexec_fn=os.setsid if hasattr(os, "setsid") else None, ) + # Start threads to stream backend output with prefix + _stream_process_output(backend_process, "stdout", "[BACKEND]", "\033[32m") # Green + _stream_process_output(backend_process, "stderr", "[BACKEND]", "\033[32m") # Green + pid_callback(backend_process.pid) # Give the backend a moment to start @@ -557,6 +605,10 @@ def start_ui( preexec_fn=os.setsid if hasattr(os, "setsid") else None, ) + # Start threads to stream frontend output with prefix + _stream_process_output(process, "stdout", "[FRONTEND]", "\033[33m") # Yellow + _stream_process_output(process, "stderr", "[FRONTEND]", "\033[33m") # Yellow + pid_callback(process.pid) # Give it a moment to start up @@ -564,10 +616,7 @@ def start_ui( # Check if process is still running if process.poll() is not None: - stdout, stderr = process.communicate() - logger.error("Frontend server failed to start:") - logger.error(f"stdout: {stdout}") - logger.error(f"stderr: {stderr}") + logger.error("Frontend server failed to start - check the logs above for details") return None # Open browser if requested From b7441f81cdf6775110cc050491fd4c1beb52e545 Mon Sep 17 00:00:00 2001 From: Daulet Amirkhanov Date: Fri, 26 Sep 2025 16:29:14 +0100 Subject: [PATCH 12/36] feat: add health check endpoint to MCP server --- cognee-mcp/src/server.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/cognee-mcp/src/server.py b/cognee-mcp/src/server.py index 9393fe71b..7670db9f4 100755 --- a/cognee-mcp/src/server.py +++ b/cognee-mcp/src/server.py @@ -19,6 +19,7 @@ from cognee.api.v1.cognify.code_graph_pipeline import run_code_graph_pipeline from cognee.modules.search.types import SearchType from cognee.shared.data_models import KnowledgeGraph from cognee.modules.storage.utils 
import JSONEncoder +from starlette.responses import JSONResponse try: @@ -37,6 +38,9 @@ mcp = FastMCP("Cognee") logger = get_logger() +@mcp.custom_route("/health", methods=["GET"]) +async def health_check(request) -> dict: + return JSONResponse({"status": "ok"}) @mcp.tool() async def cognee_add_developer_rules( From 143d9433b1620ce0925df34e7608c7c0980e1db0 Mon Sep 17 00:00:00 2001 From: Daulet Amirkhanov Date: Fri, 26 Sep 2025 17:53:17 +0100 Subject: [PATCH 13/36] refactor: remove text parameter from subprocess call in UI startup --- cognee-mcp/src/server.py | 2 ++ cognee/api/v1/ui/ui.py | 1 - 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/cognee-mcp/src/server.py b/cognee-mcp/src/server.py index 7670db9f4..f249f1d08 100755 --- a/cognee-mcp/src/server.py +++ b/cognee-mcp/src/server.py @@ -38,10 +38,12 @@ mcp = FastMCP("Cognee") logger = get_logger() + @mcp.custom_route("/health", methods=["GET"]) async def health_check(request) -> dict: return JSONResponse({"status": "ok"}) + @mcp.tool() async def cognee_add_developer_rules( base_path: str = ".", graph_model_file: str = None, graph_model_name: str = None diff --git a/cognee/api/v1/ui/ui.py b/cognee/api/v1/ui/ui.py index 4d5674832..7df0b519a 100644 --- a/cognee/api/v1/ui/ui.py +++ b/cognee/api/v1/ui/ui.py @@ -601,7 +601,6 @@ def start_ui( env=env, stdout=subprocess.PIPE, stderr=subprocess.PIPE, - text=True, preexec_fn=os.setsid if hasattr(os, "setsid") else None, ) From 39fa0180f32edfe6815f4db7e5c40400383fbebc Mon Sep 17 00:00:00 2001 From: Igor Ilic Date: Fri, 26 Sep 2025 22:42:39 +0200 Subject: [PATCH 14/36] refactor: Make relational database search more effective --- .../ingestion/migrate_relational_database.py | 6 +-- .../relational_database_migration_example.py | 46 +++++++++++++------ 2 files changed, 36 insertions(+), 16 deletions(-) diff --git a/cognee/tasks/ingestion/migrate_relational_database.py b/cognee/tasks/ingestion/migrate_relational_database.py index 936ea59e0..82319e9f5 100644 --- a/cognee/tasks/ingestion/migrate_relational_database.py +++ b/cognee/tasks/ingestion/migrate_relational_database.py @@ -38,7 +38,7 @@ async def migrate_relational_database(graph_db, schema, migrate_column_data=True table_node = TableType( id=uuid5(NAMESPACE_OID, name=table_name), name=table_name, - description=f"Table: {table_name}", + description=f'Relational database table with the following name: "{table_name}".', ) # Add TableType node to mapping ( node will be added to the graph later based on this mapping ) @@ -75,7 +75,7 @@ async def migrate_relational_database(graph_db, schema, migrate_column_data=True name=node_id, is_a=table_node, properties=str(row_properties), - description=f"Row in {table_name} with {primary_key_col}={primary_key_value}", + description=f'Row in relational database table from the table with the name: "{table_name}" with the following row data {str(row_properties)} where the dictionary key value is the column name and the value is the column value. This row has the id of: {node_id}', ) # Store the node object in our mapping @@ -113,7 +113,7 @@ async def migrate_relational_database(graph_db, schema, migrate_column_data=True id=uuid5(NAMESPACE_OID, name=column_node_id), name=column_node_id, properties=f"{key} {value} {table_name}", - description=f"Column name={key} and value={value} from column from table={table_name}", + description=f"column from relational database table={table_name}. Column name={key} and value={value}. The value of the column is related to the following row with this id: {row_node.id}. 
This column has the following ID: {column_node_id}",
                         )
                         node_mapping[column_node_id] = column_node
 
diff --git a/examples/python/relational_database_migration_example.py b/examples/python/relational_database_migration_example.py
index fae8cfb3d..6a5c3b78b 100644
--- a/examples/python/relational_database_migration_example.py
+++ b/examples/python/relational_database_migration_example.py
@@ -1,16 +1,15 @@
+from pathlib import Path
 import asyncio
-
-import cognee
 import os
 
+import cognee
+from cognee.infrastructure.databases.relational.config import get_migration_config
 from cognee.infrastructure.databases.graph import get_graph_engine
 from cognee.api.v1.visualize.visualize import visualize_graph
 from cognee.infrastructure.databases.relational import (
     get_migration_relational_engine,
 )
-
 from cognee.modules.search.types import SearchType
-
 from cognee.infrastructure.databases.relational import (
     create_db_and_tables as create_relational_db_and_tables,
 )
@@ -32,16 +31,29 @@ from cognee.infrastructure.databases.vector.pgvector import (
 
 
 async def main():
-    engine = get_migration_relational_engine()
-
     # Clean all data stored in Cognee
     await cognee.prune.prune_data()
     await cognee.prune.prune_system(metadata=True)
 
-    # Needed to create appropriate tables only on the Cognee side
+    # Needed to create appropriate database tables only on the Cognee side
     await create_relational_db_and_tables()
     await create_vector_db_and_tables()
 
+    # In case environment variables are not set use the example database from the Cognee repo
+    migration_db_provider = os.environ.get("MIGRATION_DB_PROVIDER", "sqlite")
+    migration_db_path = os.environ.get(
+        "MIGRATION_DB_PATH",
+        os.path.join(Path(__file__).resolve().parent.parent.parent, "cognee/tests/test_data"),
+    )
+    migration_db_name = os.environ.get("MIGRATION_DB_NAME", "migration_database.sqlite")
+
+    migration_config = get_migration_config()
+    migration_config.migration_db_provider = migration_db_provider
+    migration_config.migration_db_path = migration_db_path
+    migration_config.migration_db_name = migration_db_name
+
+    engine = get_migration_relational_engine()
+
     print("\nExtracting schema of database to migrate.")
     schema = await engine.extract_schema()
     print(f"Migrated database schema:\n{schema}")
@@ -53,10 +65,6 @@ async def main():
     await migrate_relational_database(graph, schema=schema)
     print("Relational database migration complete.")
 
-    # Define location where to store html visualization of graph of the migrated database
-    home_dir = os.path.expanduser("~")
-    destination_file_path = os.path.join(home_dir, "graph_visualization.html")
-
     # Make sure to set top_k at a high value for a broader search, the default value is only 10!
     # top_k represents the number of graph triplets to supply to the LLM to answer your question
     search_results = await cognee.search(
         query_type=SearchType.GRAPH_COMPLETION,
         query_text="What kind of data does the migrated database contain?",
     )
     print(f"Search results: {search_results}")
 
     # Having the top_k value set too high might overwhelm the LLM context when specific questions need to be answered.
     # For this kind of question we've set the top_k to 30
     search_results = await cognee.search(
-        query_type=SearchType.GRAPH_COMPLETION_COT,
+        query_type=SearchType.GRAPH_COMPLETION,
         query_text="What invoices are related to Leonie Köhler?",
         top_k=30,
     )
     print(f"Search results: {search_results}")
 
-    # test.html is a file with visualized data migration
+    search_results = await cognee.search(
+        query_type=SearchType.GRAPH_COMPLETION,
+        query_text="What invoices are related to Luís Gonçalves?",
+        top_k=30,
+    )
+    print(f"Search results: {search_results}")
+
+    # If you check the relational database for this example you can see that the search results successfully found all
+    # the invoices related to the two customers, without any hallucinations or additional information
+
+    # Define location where to store html visualization of graph of the migrated database
+    home_dir = os.path.expanduser("~")
+    destination_file_path = os.path.join(home_dir, "graph_visualization.html")
     print("Adding html visualization of graph database after migration.")
     await visualize_graph(destination_file_path)
     print(f"Visualization can be found at: {destination_file_path}")

From 9d801f5fe0a37267f2ece139152ffa79a280d439 Mon Sep 17 00:00:00 2001
From: Geoff-Robin
Date: Fri, 12 Sep 2025 10:42:47 +0530
Subject: [PATCH 15/36] feat: add schema models and ingest_database_schema stub

---
 cognee/tasks/schema/ingest_database_schema.py | 22 +++++++++++++
 cognee/tasks/schema/models.py                 | 32 +++++++++++++++++++
 2 files changed, 54 insertions(+)
 create mode 100644 cognee/tasks/schema/ingest_database_schema.py
 create mode 100644 cognee/tasks/schema/models.py

diff --git a/cognee/tasks/schema/ingest_database_schema.py b/cognee/tasks/schema/ingest_database_schema.py
new file mode 100644
index 000000000..6f9c538cd
--- /dev/null
+++ b/cognee/tasks/schema/ingest_database_schema.py
@@ -0,0 +1,22 @@
+from typing import List, Dict
+from cognee.infrastructure.engine.models.DataPoint import DataPoint
+
+async def ingest_database_schema(
+    database_config: Dict,
+    schema_name: str = "default",
+    max_sample_rows: int = 5,
+    node_set: List[str] = ["database_schema"]
+) -> List[DataPoint]:
+    """
+    Ingest database schema with sample data into dedicated nodeset
+
+    Args:
+        database_config: Database connection configuration
+        schema_name: Name identifier for this schema
+        max_sample_rows: Maximum sample rows per table
+        node_set: Target nodeset (default: ["database_schema"])
+
+    Returns:
+        List of created DataPoint objects
+    """
+    pass
\ No newline at end of file
diff --git a/cognee/tasks/schema/models.py b/cognee/tasks/schema/models.py
new file mode 100644
index 000000000..b38ec5ff5
--- /dev/null
+++ b/cognee/tasks/schema/models.py
@@ -0,0 +1,32 @@
+from cognee.infrastructure.engine.models.DataPoint import DataPoint
+from typing import List, Dict, Optional
+from datetime import datetime
+
+class DatabaseSchema(DataPoint):
+    """Represents a complete database schema with sample data"""
+    schema_name: str
+    database_type: str  # sqlite, postgres, etc.
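+    # The two fields below hold the structure returned by extract_schema() and the
+    # per-table sample rows gathered during ingestion.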
+    tables: Dict[str, Dict]  # Reuse existing schema format from SqlAlchemyAdapter
+    sample_data: Dict[str, List[Dict]]  # Limited examples per table
+    extraction_timestamp: datetime
+    metadata: dict = {"index_fields": ["schema_name", "database_type"]}
+
+class SchemaTable(DataPoint):
+    """Represents an individual table schema with relationships"""
+    table_name: str
+    schema_name: str
+    columns: List[Dict]  # Column definitions with types
+    primary_key: Optional[str]
+    foreign_keys: List[Dict]  # Foreign key relationships
+    sample_rows: List[Dict]  # Max 3-5 example rows
+    row_count_estimate: Optional[int]  # Actual table size
+    metadata: dict = {"index_fields": ["table_name", "schema_name"]}
+
+class SchemaRelationship(DataPoint):
+    """Represents relationships between tables"""
+    source_table: str
+    target_table: str
+    relationship_type: str  # "foreign_key", "one_to_many", etc.
+    source_column: str
+    target_column: str
+    metadata: dict = {"index_fields": ["source_table", "target_table"]}
\ No newline at end of file

From a43f19cc5914f46d9a83e1b3e7a25189f89cf252 Mon Sep 17 00:00:00 2001
From: Geoff-Robin
Date: Sat, 13 Sep 2025 14:33:06 +0530
Subject: [PATCH 16/36] feat: implement ingest_database_schema, returning Dict[str, List[DataPoint] | DataPoint]

---
 cognee/tasks/schema/ingest_database_schema.py | 63 ++++++++++++++++++-
 1 file changed, 61 insertions(+), 2 deletions(-)

diff --git a/cognee/tasks/schema/ingest_database_schema.py b/cognee/tasks/schema/ingest_database_schema.py
index 6f9c538cd..2b9cd38c5 100644
--- a/cognee/tasks/schema/ingest_database_schema.py
+++ b/cognee/tasks/schema/ingest_database_schema.py
@@ -1,12 +1,17 @@
 from typing import List, Dict
+from uuid import uuid5, NAMESPACE_OID
 from cognee.infrastructure.engine.models.DataPoint import DataPoint
+from cognee.infrastructure.databases.relational.get_migration_relational_engine import get_migration_relational_engine
+from sqlalchemy import text
+from cognee.tasks.schema.models import DatabaseSchema, SchemaTable, SchemaRelationship
+from cognee.infrastructure.databases.relational.create_relational_engine import create_relational_engine
+from datetime import datetime
 
 async def ingest_database_schema(
     database_config: Dict,
     schema_name: str = "default",
     max_sample_rows: int = 5,
     node_set: List[str] = ["database_schema"]
-) -> List[DataPoint]:
+) -> Dict[str, List[DataPoint] | DataPoint]:
     """
     Ingest database schema with sample data into dedicated nodeset
 
     Args:
         database_config: Database connection configuration
         schema_name: Name identifier for this schema
         max_sample_rows: Maximum sample rows per table
         node_set: Target nodeset (default: ["database_schema"])
 
     Returns:
         List of created DataPoint objects
     """
-    pass
\ No newline at end of file
+    engine = create_relational_engine(
+        db_path=database_config.get("db_path", ""),
+        db_name=database_config.get("db_name", "cognee_db"),
+        db_host=database_config.get("db_host"),
+        db_port=database_config.get("db_port"),
+        db_username=database_config.get("db_username"),
+        db_password=database_config.get("db_password"),
+        db_provider=database_config.get("db_provider", "sqlite"),
+    )
+    schema = await engine.extract_schema()
+    tables = {}
+    sample_data = {}
+    schema_tables = []
+    schema_relationships = []
+    async with engine.engine.begin() as cursor:
+        for table_name, details in schema.items():
+            rows_result = await cursor.execute(text(f"SELECT * FROM {table_name} LIMIT {max_sample_rows}"))
+            rows = [dict(zip([col["name"] for col in details["columns"]], row)) for row in rows_result.fetchall()]
+            count_result = await cursor.execute(text(f"SELECT COUNT(*) FROM {table_name};"))
+            row_count_estimate = count_result.scalar()
+            schema_table = SchemaTable(
+                table_name=table_name,
+                schema_name=schema_name,
+                columns=details["columns"],
+                primary_key=details.get("primary_key"),
+                foreign_keys=details.get("foreign_keys", []),
+                sample_rows=rows,
+                row_count_estimate=row_count_estimate
+            )
+            schema_tables.append(schema_table)
+            tables[table_name] = details
+            sample_data[table_name] = rows
+
+            for fk in details.get("foreign_keys", []):
+                relationship = SchemaRelationship(
+                    source_table=table_name,
+                    target_table=fk["ref_table"],
+                    relationship_type=fk["type"],
+                    source_column=fk["source_column"],
+                    target_column=fk["target_column"]
+                )
+                schema_relationships.append(relationship)
+    database_schema = DatabaseSchema(
+        schema_name=schema_name,
+        database_type=database_config.get("db_provider", "sqlite"),
+        tables=tables,
+        sample_data=sample_data,
+        extraction_timestamp=datetime.utcnow()
+    )
+
+    return {
+        "database_schema": database_schema,
+        "schema_tables": schema_tables,
+        "relationships": schema_relationships
+    }
\ No newline at end of file

From 17df14363c96e8c01d6927a99932ec751869539e Mon Sep 17 00:00:00 2001
From: Geoff-Robin
Date: Sun, 14 Sep 2025 03:26:06 +0530
Subject: [PATCH 17/36] feat: integrate schema-only ingestion

---
 .../ingestion/migrate_relational_database.py  | 350 +++++++++++-------
 1 file changed, 211 insertions(+), 139 deletions(-)

diff --git a/cognee/tasks/ingestion/migrate_relational_database.py b/cognee/tasks/ingestion/migrate_relational_database.py
index 936ea59e0..e535a0ed8 100644
--- a/cognee/tasks/ingestion/migrate_relational_database.py
+++ b/cognee/tasks/ingestion/migrate_relational_database.py
@@ -4,16 +4,19 @@
 from sqlalchemy import text
 from cognee.infrastructure.databases.relational.get_migration_relational_engine import (
     get_migration_relational_engine,
 )
+from cognee.infrastructure.databases.relational.config import get_migration_config
 from cognee.tasks.storage.index_data_points import index_data_points
 from cognee.tasks.storage.index_graph_edges import index_graph_edges
+from cognee.tasks.schema.ingest_database_schema import ingest_database_schema
+from cognee.tasks.schema.models import SchemaTable
 from cognee.modules.engine.models import TableRow, TableType, ColumnValue
 
 
 logger = logging.getLogger(__name__)
 
 
-async def migrate_relational_database(graph_db, schema, migrate_column_data=True):
+async def migrate_relational_database(graph_db, schema, migrate_column_data=True, schema_only=False):
     """
     Migrates data from a relational database into a graph database.
@@ -30,157 +33,226 @@ async def migrate_relational_database(graph_db, schema, migrate_column_data=True # Create a mapping of node_id to node objects for referencing in edge creation node_mapping = {} edge_mapping = [] - - async with engine.engine.begin() as cursor: - # First, create table type nodes for all tables - for table_name, details in schema.items(): - # Create a TableType node for each table - table_node = TableType( - id=uuid5(NAMESPACE_OID, name=table_name), - name=table_name, - description=f"Table: {table_name}", - ) - - # Add TableType node to mapping ( node will be added to the graph later based on this mapping ) - node_mapping[table_name] = table_node - - # Fetch all rows for the current table - rows_result = await cursor.execute(text(f"SELECT * FROM {table_name};")) - rows = rows_result.fetchall() - - for row in rows: - # Build a dictionary of properties from the row - row_properties = { - col["name"]: row[idx] for idx, col in enumerate(details["columns"]) - } - - # Determine the primary key value - if not details["primary_key"]: - # Use the first column as primary key if not specified - primary_key_col = details["columns"][0]["name"] - primary_key_value = row_properties[primary_key_col] - else: - # Use value of the specified primary key column - primary_key_col = details["primary_key"] - primary_key_value = row_properties[primary_key_col] - - # Create a node ID in the format "table_name:primary_key_value" - node_id = f"{table_name}:{primary_key_value}" - - # Create a TableRow node - # Node id must uniquely map to the id used in the relational database - # To catch the foreign key relationships properly - row_node = TableRow( - id=uuid5(NAMESPACE_OID, name=node_id), - name=node_id, - is_a=table_node, - properties=str(row_properties), - description=f"Row in {table_name} with {primary_key_col}={primary_key_value}", + + if schema_only: + database_config = get_migration_config().to_dict() + # Calling the ingest_database_schema function to return DataPoint subclasses + result = await ingest_database_schema( + database_config=database_config, + schema_name="migrated_schema", + max_sample_rows=5, + node_set=["database_schema", "schema_tables", "relationships"] + ) + database_schema = result["database_schema"] + schema_tables = result["schema_tables"] + schema_relationships = result["relationships"] + database_node_id = database_schema.id + node_mapping[database_node_id] = database_schema + for table in schema_tables: + table_node_id = table.id + # Add TableSchema Datapoint as a node. + node_mapping[table_node_id]=table + edge_mapping.append(( + table_node_id, + database_node_id, + "is_part_of", + dict( + source_node_id=table_node_id, + target_node_id=database_node_id, + relationship_name="is_part_of", + ), + )) + for rel in schema_relationships: + source_table_id = uuid5(NAMESPACE_OID,name=rel.source_table) + target_table_id = uuid5(NAMESPACE_OID,name=rel.target_table) + relationship_id = rel.id + + # Add RelationshipTable DataPoint as a node. 
+ node_mapping[relationship_id]=rel + edge_mapping.append(( + source_table_id, + relationship_id, + "has_relationship", + dict( + source_node_id=source_table_id, + target_node_id=relationship_id, + relationship_name=rel.relationship_type, + ), + )) + edge_mapping.append(( + relationship_id, + target_table_id, + "has_relationship", + dict( + source_node_id=relationship_id, + target_node_id=target_table_id, + relationship_name=rel.relationship_type, + ) + )) + edge_mapping.append(( + source_table_id, + target_table_id, + rel.relationship_type, + dict( + source_node_id=source_table_id, + target_node_id=target_table_id, + relationship_name=rel.relationship_type, + ), + )) + + + + else: + async with engine.engine.begin() as cursor: + # First, create table type nodes for all tables + for table_name, details in schema.items(): + # Create a TableType node for each table + table_node = TableType( + id=uuid5(NAMESPACE_OID, name=table_name), + name=table_name, + description=f"Table: {table_name}", ) - # Store the node object in our mapping - node_mapping[node_id] = row_node + # Add TableType node to mapping ( node will be added to the graph later based on this mapping ) + node_mapping[table_name] = table_node - # Add edge between row node and table node ( it will be added to the graph later ) - edge_mapping.append( - ( - row_node.id, - table_node.id, - "is_part_of", - dict( - relationship_name="is_part_of", - source_node_id=row_node.id, - target_node_id=table_node.id, - ), + # Fetch all rows for the current table + rows_result = await cursor.execute(text(f"SELECT * FROM {table_name};")) + rows = rows_result.fetchall() + + for row in rows: + # Build a dictionary of properties from the row + row_properties = { + col["name"]: row[idx] for idx, col in enumerate(details["columns"]) + } + + # Determine the primary key value + if not details["primary_key"]: + # Use the first column as primary key if not specified + primary_key_col = details["columns"][0]["name"] + primary_key_value = row_properties[primary_key_col] + else: + # Use value of the specified primary key column + primary_key_col = details["primary_key"] + primary_key_value = row_properties[primary_key_col] + + # Create a node ID in the format "table_name:primary_key_value" + node_id = f"{table_name}:{primary_key_value}" + + # Create a TableRow node + # Node id must uniquely map to the id used in the relational database + # To catch the foreign key relationships properly + row_node = TableRow( + id=uuid5(NAMESPACE_OID, name=node_id), + name=node_id, + is_a=table_node, + properties=str(row_properties), + description=f"Row in {table_name} with {primary_key_col}={primary_key_value}", ) - ) - # Migrate data stored in columns of table rows - if migrate_column_data: - # Get foreign key columns to filter them out from column migration - foreign_keys = [] - for fk in details.get("foreign_keys", []): - foreign_keys.append(fk["ref_column"]) + # Store the node object in our mapping + node_mapping[node_id] = row_node - for key, value in row_properties.items(): - # Skip mapping primary key information to itself and mapping of foreign key information (as it will be mapped bellow) - if key is primary_key_col or key in foreign_keys: - continue - - # Create column value node - column_node_id = f"{table_name}:{key}:{value}" - column_node = ColumnValue( - id=uuid5(NAMESPACE_OID, name=column_node_id), - name=column_node_id, - properties=f"{key} {value} {table_name}", - description=f"Column name={key} and value={value} from column from table={table_name}", - ) - 
node_mapping[column_node_id] = column_node - - # Create relationship between column value of table row and table row - edge_mapping.append( - ( - row_node.id, - column_node.id, - key, - dict( - relationship_name=key, - source_node_id=row_node.id, - target_node_id=column_node.id, - ), - ) - ) - - # Process foreign key relationships after all nodes are created - for table_name, details in schema.items(): - # Process foreign key relationships for the current table - for fk in details.get("foreign_keys", []): - # Aliases needed for self-referencing tables - alias_1 = f"{table_name}_e1" - alias_2 = f"{fk['ref_table']}_e2" - - # Determine primary key column - if not details["primary_key"]: - primary_key_col = details["columns"][0]["name"] - else: - primary_key_col = details["primary_key"] - - # Query to find relationships based on foreign keys - fk_query = text( - f"SELECT {alias_1}.{primary_key_col} AS source_id, " - f"{alias_2}.{fk['ref_column']} AS ref_value " - f"FROM {table_name} AS {alias_1} " - f"JOIN {fk['ref_table']} AS {alias_2} " - f"ON {alias_1}.{fk['column']} = {alias_2}.{fk['ref_column']};" - ) - - fk_result = await cursor.execute(fk_query) - relations = fk_result.fetchall() - - for source_id, ref_value in relations: - # Construct node ids - source_node_id = f"{table_name}:{source_id}" - target_node_id = f"{fk['ref_table']}:{ref_value}" - - # Get the source and target node objects from our mapping - source_node = node_mapping[source_node_id] - target_node = node_mapping[target_node_id] - - # Add edge representing the foreign key relationship using the node objects - # Create edge to add to graph later + # Add edge between row node and table node ( it will be added to the graph later ) edge_mapping.append( ( - source_node.id, - target_node.id, - fk["column"], + row_node.id, + table_node.id, + "is_part_of", dict( - source_node_id=source_node.id, - target_node_id=target_node.id, - relationship_name=fk["column"], + relationship_name="is_part_of", + source_node_id=row_node.id, + target_node_id=table_node.id, ), ) ) + # Migrate data stored in columns of table rows + if migrate_column_data: + # Get foreign key columns to filter them out from column migration + foreign_keys = [] + for fk in details.get("foreign_keys", []): + foreign_keys.append(fk["ref_column"]) + + for key, value in row_properties.items(): + # Skip mapping primary key information to itself and mapping of foreign key information (as it will be mapped bellow) + if key is primary_key_col or key in foreign_keys: + continue + + # Create column value node + column_node_id = f"{table_name}:{key}:{value}" + column_node = ColumnValue( + id=uuid5(NAMESPACE_OID, name=column_node_id), + name=column_node_id, + properties=f"{key} {value} {table_name}", + description=f"Column name={key} and value={value} from column from table={table_name}", + ) + node_mapping[column_node_id] = column_node + + # Create relationship between column value of table row and table row + edge_mapping.append( + ( + row_node.id, + column_node.id, + key, + dict( + relationship_name=key, + source_node_id=row_node.id, + target_node_id=column_node.id, + ), + ) + ) + + # Process foreign key relationships after all nodes are created + for table_name, details in schema.items(): + # Process foreign key relationships for the current table + for fk in details.get("foreign_keys", []): + # Aliases needed for self-referencing tables + alias_1 = f"{table_name}_e1" + alias_2 = f"{fk['ref_table']}_e2" + + # Determine primary key column + if not details["primary_key"]: + 
primary_key_col = details["columns"][0]["name"] + else: + primary_key_col = details["primary_key"] + + # Query to find relationships based on foreign keys + fk_query = text( + f"SELECT {alias_1}.{primary_key_col} AS source_id, " + f"{alias_2}.{fk['ref_column']} AS ref_value " + f"FROM {table_name} AS {alias_1} " + f"JOIN {fk['ref_table']} AS {alias_2} " + f"ON {alias_1}.{fk['column']} = {alias_2}.{fk['ref_column']};" + ) + + fk_result = await cursor.execute(fk_query) + relations = fk_result.fetchall() + + for source_id, ref_value in relations: + # Construct node ids + source_node_id = f"{table_name}:{source_id}" + target_node_id = f"{fk['ref_table']}:{ref_value}" + + # Get the source and target node objects from our mapping + source_node = node_mapping[source_node_id] + target_node = node_mapping[target_node_id] + + # Add edge representing the foreign key relationship using the node objects + # Create edge to add to graph later + edge_mapping.append( + ( + source_node.id, + target_node.id, + fk["column"], + dict( + source_node_id=source_node.id, + target_node_id=target_node.id, + relationship_name=fk["column"], + ), + ) + ) + def _remove_duplicate_edges(edge_mapping): seen = set() unique_original_shape = [] From f5bb91e49df908c2eed636c539196aa2c4cba4ca Mon Sep 17 00:00:00 2001 From: Geoff-Robin Date: Sun, 14 Sep 2025 03:29:38 +0530 Subject: [PATCH 18/36] added description attribute to every schema model --- cognee/tasks/schema/ingest_database_schema.py | 53 ++++++++++++------- cognee/tasks/schema/models.py | 5 +- 2 files changed, 37 insertions(+), 21 deletions(-) diff --git a/cognee/tasks/schema/ingest_database_schema.py b/cognee/tasks/schema/ingest_database_schema.py index 2b9cd38c5..6d6f0b5f3 100644 --- a/cognee/tasks/schema/ingest_database_schema.py +++ b/cognee/tasks/schema/ingest_database_schema.py @@ -1,4 +1,5 @@ from typing import List, Dict +from uuid import uuid5, NAMESPACE_OID from cognee.infrastructure.engine.models.DataPoint import DataPoint from cognee.infrastructure.databases.relational.get_migration_relational_engine import get_migration_relational_engine from sqlalchemy import text @@ -11,7 +12,7 @@ async def ingest_database_schema( schema_name: str = "default", max_sample_rows: int = 5, node_set: List[str] = ["database_schema"] -) -> Dict[str, List[DataPoint]|DataPoint]: +) -> Dict[str, List[DataPoint] | DataPoint]: """ Ingest database schema with sample data into dedicated nodeset @@ -25,57 +26,69 @@ async def ingest_database_schema( List of created DataPoint objects """ engine = create_relational_engine( - db_path=database_config.get("db_path", ""), - db_name=database_config.get("db_name", "cognee_db"), - db_host=database_config.get("db_host"), - db_port=database_config.get("db_port"), - db_username=database_config.get("db_username"), - db_password=database_config.get("db_password"), - db_provider=database_config.get("db_provider", "sqlite"), + db_path=database_config.get("migration_db_path", ""), + db_name=database_config.get("migration_db_name", "cognee_db"), + db_host=database_config.get("migration_db_host"), + db_port=database_config.get("migration_db_port"), + db_username=database_config.get("migration_db_username"), + db_password=database_config.get("migration_db_password"), + db_provider=database_config.get("migration_db_provider", "sqlite"), ) schema = await engine.extract_schema() - tables={} - sample_data={} + tables = {} + sample_data = {} schema_tables = [] schema_relationships = [] + async with engine.engine.begin() as cursor: for table_name, details in 
schema.items(): + print(table_name) rows_result = await cursor.execute(text(f"SELECT * FROM {table_name} LIMIT {max_sample_rows}")) rows = [dict(zip([col["name"] for col in details["columns"]], row)) for row in rows_result.fetchall()] count_result = await cursor.execute(text(f"SELECT COUNT(*) FROM {table_name};")) row_count_estimate = count_result.scalar() + schema_table = SchemaTable( + id=uuid5(NAMESPACE_OID, name=table_name), table_name=table_name, schema_name=schema_name, columns=details["columns"], primary_key=details.get("primary_key"), foreign_keys=details.get("foreign_keys", []), sample_rows=rows, - row_count_estimate=row_count_estimate + row_count_estimate=row_count_estimate, + description=f"Schema table for '{table_name}' with {len(details['columns'])} columns and approx. {row_count_estimate} rows." ) schema_tables.append(schema_table) tables[table_name] = details sample_data[table_name] = rows - - for fk in details.get("foreign_keys",[]): + + for fk in details.get("foreign_keys", []): + print(f"ref_table:{fk['ref_table']}") + print(f"table_name:{table_name}") relationship = SchemaRelationship( + id=uuid5(NAMESPACE_OID, name=f"{fk['column']}:{table_name}:{fk['ref_column']}:{fk['ref_table']}"), source_table=table_name, target_table=fk["ref_table"], - relationship_type=fk["type"], - source_column=fk["source_column"], - target_column=fk["target_column"] + relationship_type="foreign_key", + source_column=fk["column"], + target_column=fk["ref_column"], + description=f"Foreign key relationship: {table_name}.{fk['column']} → {fk['ref_table']}.{fk['ref_column']}" ) schema_relationships.append(relationship) + database_schema = DatabaseSchema( + id=uuid5(NAMESPACE_OID, name=schema_name), schema_name=schema_name, database_type=database_config.get("db_provider", "sqlite"), tables=tables, sample_data=sample_data, - extraction_timestamp=datetime.utcnow() + extraction_timestamp=datetime.utcnow(), + description=f"Database schema '{schema_name}' containing {len(schema_tables)} tables and {len(schema_relationships)} relationships." ) - - return{ + + return { "database_schema": database_schema, "schema_tables": schema_tables, "relationships": schema_relationships - } \ No newline at end of file + } diff --git a/cognee/tasks/schema/models.py b/cognee/tasks/schema/models.py index b38ec5ff5..ef9374163 100644 --- a/cognee/tasks/schema/models.py +++ b/cognee/tasks/schema/models.py @@ -9,6 +9,7 @@ class DatabaseSchema(DataPoint): tables: Dict[str, Dict] # Reuse existing schema format from SqlAlchemyAdapter sample_data: Dict[str, List[Dict]] # Limited examples per table extraction_timestamp: datetime + description: str metadata: dict = {"index_fields": ["schema_name", "database_type"]} class SchemaTable(DataPoint): @@ -20,13 +21,15 @@ class SchemaTable(DataPoint): foreign_keys: List[Dict] # Foreign key relationships sample_rows: List[Dict] # Max 3-5 example rows row_count_estimate: Optional[int] # Actual table size + description: str metadata: dict = {"index_fields": ["table_name", "schema_name"]} class SchemaRelationship(DataPoint): """Represents relationships between tables""" source_table: str target_table: str - relationship_type: str # "foreign_key", "one_to_many", etc. 
+ relationship_type: str source_column: str target_column: str + description: str metadata: dict = {"index_fields": ["source_table", "target_table"]} \ No newline at end of file From 51dfac359debcaa05d685412d53f559547de7f6c Mon Sep 17 00:00:00 2001 From: Geoff-Robin Date: Sun, 14 Sep 2025 21:30:26 +0530 Subject: [PATCH 19/36] Removed print statements used while debugging --- cognee/tasks/schema/ingest_database_schema.py | 9 +++------ cognee/tasks/schema/models.py | 2 +- 2 files changed, 4 insertions(+), 7 deletions(-) diff --git a/cognee/tasks/schema/ingest_database_schema.py b/cognee/tasks/schema/ingest_database_schema.py index 6d6f0b5f3..2ac57d0ba 100644 --- a/cognee/tasks/schema/ingest_database_schema.py +++ b/cognee/tasks/schema/ingest_database_schema.py @@ -42,7 +42,6 @@ async def ingest_database_schema( async with engine.engine.begin() as cursor: for table_name, details in schema.items(): - print(table_name) rows_result = await cursor.execute(text(f"SELECT * FROM {table_name} LIMIT {max_sample_rows}")) rows = [dict(zip([col["name"] for col in details["columns"]], row)) for row in rows_result.fetchall()] count_result = await cursor.execute(text(f"SELECT COUNT(*) FROM {table_name};")) @@ -57,15 +56,13 @@ async def ingest_database_schema( foreign_keys=details.get("foreign_keys", []), sample_rows=rows, row_count_estimate=row_count_estimate, - description=f"Schema table for '{table_name}' with {len(details['columns'])} columns and approx. {row_count_estimate} rows." + description=f"" ) schema_tables.append(schema_table) tables[table_name] = details sample_data[table_name] = rows for fk in details.get("foreign_keys", []): - print(f"ref_table:{fk['ref_table']}") - print(f"table_name:{table_name}") relationship = SchemaRelationship( id=uuid5(NAMESPACE_OID, name=f"{fk['column']}:{table_name}:{fk['ref_column']}:{fk['ref_table']}"), source_table=table_name, @@ -73,7 +70,7 @@ async def ingest_database_schema( relationship_type="foreign_key", source_column=fk["column"], target_column=fk["ref_column"], - description=f"Foreign key relationship: {table_name}.{fk['column']} → {fk['ref_table']}.{fk['ref_column']}" + description=f"" ) schema_relationships.append(relationship) @@ -84,7 +81,7 @@ async def ingest_database_schema( tables=tables, sample_data=sample_data, extraction_timestamp=datetime.utcnow(), - description=f"Database schema '{schema_name}' containing {len(schema_tables)} tables and {len(schema_relationships)} relationships." + description=f"" ) return { diff --git a/cognee/tasks/schema/models.py b/cognee/tasks/schema/models.py index ef9374163..0fb248758 100644 --- a/cognee/tasks/schema/models.py +++ b/cognee/tasks/schema/models.py @@ -28,7 +28,7 @@ class SchemaRelationship(DataPoint): """Represents relationships between tables""" source_table: str target_table: str - relationship_type: str + relationship_type: str # "foreign_key", "one_to_many", etc. 
source_column: str target_column: str description: str From 1ba9e1df317810b0ba796dd3fe75b8ca4c61cb89 Mon Sep 17 00:00:00 2001 From: Geoff-Robin Date: Sun, 14 Sep 2025 21:56:31 +0530 Subject: [PATCH 20/36] done with ruff checks --- .../ingestion/migrate_relational_database.py | 117 ++++++++++-------- cognee/tasks/schema/ingest_database_schema.py | 37 ++++-- cognee/tasks/schema/models.py | 8 +- 3 files changed, 95 insertions(+), 67 deletions(-) diff --git a/cognee/tasks/ingestion/migrate_relational_database.py b/cognee/tasks/ingestion/migrate_relational_database.py index e535a0ed8..62a8a0eac 100644 --- a/cognee/tasks/ingestion/migrate_relational_database.py +++ b/cognee/tasks/ingestion/migrate_relational_database.py @@ -16,7 +16,9 @@ from cognee.modules.engine.models import TableRow, TableType, ColumnValue logger = logging.getLogger(__name__) -async def migrate_relational_database(graph_db, schema, migrate_column_data=True,schema_only=False): +async def migrate_relational_database( + graph_db, schema, migrate_column_data=True, schema_only=False +): """ Migrates data from a relational database into a graph database. @@ -33,15 +35,15 @@ async def migrate_relational_database(graph_db, schema, migrate_column_data=True # Create a mapping of node_id to node objects for referencing in edge creation node_mapping = {} edge_mapping = [] - + if schema_only: - database_config = get_migration_config().to_dict() + database_config = get_migration_config().to_dict() # Calling the ingest_database_schema function to return DataPoint subclasses result = await ingest_database_schema( database_config=database_config, schema_name="migrated_schema", max_sample_rows=5, - node_set=["database_schema", "schema_tables", "relationships"] + node_set=["database_schema", "schema_tables", "relationships"], ) database_schema = result["database_schema"] schema_tables = result["schema_tables"] @@ -51,57 +53,64 @@ async def migrate_relational_database(graph_db, schema, migrate_column_data=True for table in schema_tables: table_node_id = table.id # Add TableSchema Datapoint as a node. - node_mapping[table_node_id]=table - edge_mapping.append(( - table_node_id, - database_node_id, - "is_part_of", - dict( - source_node_id=table_node_id, - target_node_id=database_node_id, - relationship_name="is_part_of", - ), - )) - for rel in schema_relationships: - source_table_id = uuid5(NAMESPACE_OID,name=rel.source_table) - target_table_id = uuid5(NAMESPACE_OID,name=rel.target_table) - relationship_id = rel.id - - # Add RelationshipTable DataPoint as a node. 
- node_mapping[relationship_id]=rel - edge_mapping.append(( - source_table_id, - relationship_id, - "has_relationship", - dict( - source_node_id=source_table_id, - target_node_id=relationship_id, - relationship_name=rel.relationship_type, - ), - )) - edge_mapping.append(( - relationship_id, - target_table_id, - "has_relationship", - dict( - source_node_id=relationship_id, - target_node_id=target_table_id, - relationship_name=rel.relationship_type, + node_mapping[table_node_id] = table + edge_mapping.append( + ( + table_node_id, + database_node_id, + "is_part_of", + dict( + source_node_id=table_node_id, + target_node_id=database_node_id, + relationship_name="is_part_of", + ), ) - )) - edge_mapping.append(( - source_table_id, - target_table_id, - rel.relationship_type, - dict( - source_node_id=source_table_id, - target_node_id=target_table_id, - relationship_name=rel.relationship_type, - ), - )) - - - + ) + for rel in schema_relationships: + source_table_id = uuid5(NAMESPACE_OID, name=rel.source_table) + target_table_id = uuid5(NAMESPACE_OID, name=rel.target_table) + + relationship_id = rel.id + + # Add RelationshipTable DataPoint as a node. + node_mapping[relationship_id] = rel + edge_mapping.append( + ( + source_table_id, + relationship_id, + "has_relationship", + dict( + source_node_id=source_table_id, + target_node_id=relationship_id, + relationship_name=rel.relationship_type, + ), + ) + ) + edge_mapping.append( + ( + relationship_id, + target_table_id, + "has_relationship", + dict( + source_node_id=relationship_id, + target_node_id=target_table_id, + relationship_name=rel.relationship_type, + ), + ) + ) + edge_mapping.append( + ( + source_table_id, + target_table_id, + rel.relationship_type, + dict( + source_node_id=source_table_id, + target_node_id=target_table_id, + relationship_name=rel.relationship_type, + ), + ) + ) + else: async with engine.engine.begin() as cursor: # First, create table type nodes for all tables diff --git a/cognee/tasks/schema/ingest_database_schema.py b/cognee/tasks/schema/ingest_database_schema.py index 2ac57d0ba..c4c13449d 100644 --- a/cognee/tasks/schema/ingest_database_schema.py +++ b/cognee/tasks/schema/ingest_database_schema.py @@ -1,27 +1,32 @@ from typing import List, Dict from uuid import uuid5, NAMESPACE_OID from cognee.infrastructure.engine.models.DataPoint import DataPoint -from cognee.infrastructure.databases.relational.get_migration_relational_engine import get_migration_relational_engine +from cognee.infrastructure.databases.relational.get_migration_relational_engine import ( + get_migration_relational_engine, +) from sqlalchemy import text from cognee.tasks.schema.models import DatabaseSchema, SchemaTable, SchemaRelationship -from cognee.infrastructure.databases.relational.create_relational_engine import create_relational_engine +from cognee.infrastructure.databases.relational.create_relational_engine import ( + create_relational_engine, +) from datetime import datetime + async def ingest_database_schema( database_config: Dict, schema_name: str = "default", max_sample_rows: int = 5, - node_set: List[str] = ["database_schema"] + node_set: List[str] = ["database_schema"], ) -> Dict[str, List[DataPoint] | DataPoint]: """ Ingest database schema with sample data into dedicated nodeset - + Args: database_config: Database connection configuration schema_name: Name identifier for this schema max_sample_rows: Maximum sample rows per table node_set: Target nodeset (default: ["database_schema"]) - + Returns: List of created DataPoint objects """ @@ -42,8 
+47,13 @@ async def ingest_database_schema( async with engine.engine.begin() as cursor: for table_name, details in schema.items(): - rows_result = await cursor.execute(text(f"SELECT * FROM {table_name} LIMIT {max_sample_rows}")) - rows = [dict(zip([col["name"] for col in details["columns"]], row)) for row in rows_result.fetchall()] + rows_result = await cursor.execute( + text(f"SELECT * FROM {table_name} LIMIT {max_sample_rows}") + ) + rows = [ + dict(zip([col["name"] for col in details["columns"]], row)) + for row in rows_result.fetchall() + ] count_result = await cursor.execute(text(f"SELECT COUNT(*) FROM {table_name};")) row_count_estimate = count_result.scalar() @@ -56,7 +66,7 @@ async def ingest_database_schema( foreign_keys=details.get("foreign_keys", []), sample_rows=rows, row_count_estimate=row_count_estimate, - description=f"" + description=f"Schema table for '{table_name}' with {len(details['columns'])} columns and approx. {row_count_estimate} rows.", ) schema_tables.append(schema_table) tables[table_name] = details @@ -64,13 +74,16 @@ async def ingest_database_schema( for fk in details.get("foreign_keys", []): relationship = SchemaRelationship( - id=uuid5(NAMESPACE_OID, name=f"{fk['column']}:{table_name}:{fk['ref_column']}:{fk['ref_table']}"), + id=uuid5( + NAMESPACE_OID, + name=f"{fk['column']}:{table_name}:{fk['ref_column']}:{fk['ref_table']}", + ), source_table=table_name, target_table=fk["ref_table"], relationship_type="foreign_key", source_column=fk["column"], target_column=fk["ref_column"], - description=f"" + description=f"Foreign key relationship: {table_name}.{fk['column']} → {fk['ref_table']}.{fk['ref_column']}", ) schema_relationships.append(relationship) @@ -81,11 +94,11 @@ async def ingest_database_schema( tables=tables, sample_data=sample_data, extraction_timestamp=datetime.utcnow(), - description=f"" + description=f"Database schema '{schema_name}' containing {len(schema_tables)} tables and {len(schema_relationships)} relationships.", ) return { "database_schema": database_schema, "schema_tables": schema_tables, - "relationships": schema_relationships + "relationships": schema_relationships, } diff --git a/cognee/tasks/schema/models.py b/cognee/tasks/schema/models.py index 0fb248758..423c92050 100644 --- a/cognee/tasks/schema/models.py +++ b/cognee/tasks/schema/models.py @@ -2,8 +2,10 @@ from cognee.infrastructure.engine.models.DataPoint import DataPoint from typing import List, Dict, Optional from datetime import datetime + class DatabaseSchema(DataPoint): """Represents a complete database schema with sample data""" + schema_name: str database_type: str # sqlite, postgres, etc. tables: Dict[str, Dict] # Reuse existing schema format from SqlAlchemyAdapter @@ -12,8 +14,10 @@ class DatabaseSchema(DataPoint): description: str metadata: dict = {"index_fields": ["schema_name", "database_type"]} + class SchemaTable(DataPoint): """Represents an individual table schema with relationships""" + table_name: str schema_name: str columns: List[Dict] # Column definitions with types @@ -24,12 +28,14 @@ class SchemaTable(DataPoint): description: str metadata: dict = {"index_fields": ["table_name", "schema_name"]} + class SchemaRelationship(DataPoint): """Represents relationships between tables""" + source_table: str target_table: str relationship_type: str # "foreign_key", "one_to_many", etc. 
source_column: str
     target_column: str
     description: str
-    metadata: dict = {"index_fields": ["source_table", "target_table"]} \ No newline at end of file
+    metadata: dict = {"index_fields": ["source_table", "target_table"]}

From 7cf4a0daeb260318bb2f9db9c32b69c2187f1db5 Mon Sep 17 00:00:00 2001
From: Geoff-Robin
Date: Mon, 15 Sep 2025 11:12:43 +0530
Subject: [PATCH 21/36] Negate id mismatch risk by resolving table ids from SchemaTable nodes

---
 cognee/tasks/ingestion/migrate_relational_database.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/cognee/tasks/ingestion/migrate_relational_database.py b/cognee/tasks/ingestion/migrate_relational_database.py
index 62a8a0eac..e857ab34d 100644
--- a/cognee/tasks/ingestion/migrate_relational_database.py
+++ b/cognee/tasks/ingestion/migrate_relational_database.py
@@ -66,9 +66,10 @@ async def migrate_relational_database(
                 ),
             )
         )
+        table_name_to_id = {t.table_name: t.id for t in schema_tables}
         for rel in schema_relationships:
-            source_table_id = uuid5(NAMESPACE_OID, name=rel.source_table)
-            target_table_id = uuid5(NAMESPACE_OID, name=rel.target_table)
+            source_table_id = table_name_to_id.get(rel.source_table)
+            target_table_id = table_name_to_id.get(rel.target_table)

             relationship_id = rel.id

From 60016a6b09a76083ce01baeb38b1559b0acd130a Mon Sep 17 00:00:00 2001
From: Geoff-Robin
Date: Mon, 15 Sep 2025 12:55:38 +0530
Subject: [PATCH 22/36] Resolve more review nitpicks: quote identifiers, bind the LIMIT parameter, namespace table ids

---
 cognee/tasks/schema/ingest_database_schema.py | 7 +++++--
 1 file changed, 5 insertions(+), 2 deletions(-)

diff --git a/cognee/tasks/schema/ingest_database_schema.py b/cognee/tasks/schema/ingest_database_schema.py
index c4c13449d..f93314a3d 100644
--- a/cognee/tasks/schema/ingest_database_schema.py
+++ b/cognee/tasks/schema/ingest_database_schema.py
@@ -47,8 +47,11 @@ async def ingest_database_schema(

     async with engine.engine.begin() as cursor:
         for table_name, details in schema.items():
+            qi = engine.engine.dialect.identifier_preparer.quote
+            tn = qi(table_name)
             rows_result = await cursor.execute(
-                text(f"SELECT * FROM {table_name} LIMIT {max_sample_rows}")
+                text(f"SELECT * FROM {tn} LIMIT :limit;"),
+                {"limit": max_sample_rows}
             )
             rows = [
                 dict(zip([col["name"] for col in details["columns"]], row))
@@ -58,7 +61,7 @@ async def ingest_database_schema(
             row_count_estimate = count_result.scalar()

             schema_table = SchemaTable(
-                id=uuid5(NAMESPACE_OID, name=table_name),
+                id=uuid5(NAMESPACE_OID, name=f"{schema_name}:{table_name}"),
                 table_name=table_name,
                 schema_name=schema_name,
                 columns=details["columns"],

From 7ec066111ed0d8323d5dedca74b065186a8a2d73 Mon Sep 17 00:00:00 2001
From: Geoff-Robin
Date: Mon, 15 Sep 2025 13:39:15 +0530
Subject: [PATCH 23/36] Address configuration key inconsistency in the database type lookup
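
The migration config dict built by get_migration_config().to_dict() exposes
"migration_db_*" keys, so the previous "db_provider" lookup never matched and
the reported database_type always fell back to "sqlite". A minimal sketch of
the mismatch, with hypothetical values for illustration only:

    database_config = {"migration_db_provider": "postgres"}
    database_config.get("db_provider", "sqlite")            # always "sqlite"
    database_config.get("migration_db_provider", "sqlite")  # "postgres"

Reading the provider through the prefixed key keeps the reported database_type
in sync with the engine actually used for the migration.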
--- cognee/tasks/schema/ingest_database_schema.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cognee/tasks/schema/ingest_database_schema.py b/cognee/tasks/schema/ingest_database_schema.py index f93314a3d..e80ce2e75 100644 --- a/cognee/tasks/schema/ingest_database_schema.py +++ b/cognee/tasks/schema/ingest_database_schema.py @@ -93,7 +93,7 @@ async def ingest_database_schema( database_schema = DatabaseSchema( id=uuid5(NAMESPACE_OID, name=schema_name), schema_name=schema_name, - database_type=database_config.get("db_provider", "sqlite"), + database_type=database_config.get("migration_db_provider", "sqlite"), tables=tables, sample_data=sample_data, extraction_timestamp=datetime.utcnow(), From 93c733e6871f7232bc847a00829f9903611c6048 Mon Sep 17 00:00:00 2001 From: Geoff-Robin Date: Mon, 15 Sep 2025 16:13:32 +0530 Subject: [PATCH 24/36] solved more nitpick comments --- .../ingestion/migrate_relational_database.py | 1 - cognee/tasks/schema/ingest_database_schema.py | 20 +++++++++++-------- 2 files changed, 12 insertions(+), 9 deletions(-) diff --git a/cognee/tasks/ingestion/migrate_relational_database.py b/cognee/tasks/ingestion/migrate_relational_database.py index e857ab34d..824fef2fa 100644 --- a/cognee/tasks/ingestion/migrate_relational_database.py +++ b/cognee/tasks/ingestion/migrate_relational_database.py @@ -43,7 +43,6 @@ async def migrate_relational_database( database_config=database_config, schema_name="migrated_schema", max_sample_rows=5, - node_set=["database_schema", "schema_tables", "relationships"], ) database_schema = result["database_schema"] schema_tables = result["schema_tables"] diff --git a/cognee/tasks/schema/ingest_database_schema.py b/cognee/tasks/schema/ingest_database_schema.py index e80ce2e75..be544408b 100644 --- a/cognee/tasks/schema/ingest_database_schema.py +++ b/cognee/tasks/schema/ingest_database_schema.py @@ -1,4 +1,4 @@ -from typing import List, Dict +from typing import List, Dict, Optional from uuid import uuid5, NAMESPACE_OID from cognee.infrastructure.engine.models.DataPoint import DataPoint from cognee.infrastructure.databases.relational.get_migration_relational_engine import ( @@ -16,7 +16,6 @@ async def ingest_database_schema( database_config: Dict, schema_name: str = "default", max_sample_rows: int = 5, - node_set: List[str] = ["database_schema"], ) -> Dict[str, List[DataPoint] | DataPoint]: """ Ingest database schema with sample data into dedicated nodeset @@ -25,7 +24,6 @@ async def ingest_database_schema( database_config: Database connection configuration schema_name: Name identifier for this schema max_sample_rows: Maximum sample rows per table - node_set: Target nodeset (default: ["database_schema"]) Returns: List of created DataPoint objects @@ -48,6 +46,8 @@ async def ingest_database_schema( async with engine.engine.begin() as cursor: for table_name, details in schema.items(): qi = engine.engine.dialect.identifier_preparer.quote + qname = lambda name : ".".join(qi(p) for p in name.split(".")) + tn = qname(table_name) tn = qi(table_name) rows_result = await cursor.execute( text(f"SELECT * FROM {tn} LIMIT :limit;"), @@ -57,11 +57,11 @@ async def ingest_database_schema( dict(zip([col["name"] for col in details["columns"]], row)) for row in rows_result.fetchall() ] - count_result = await cursor.execute(text(f"SELECT COUNT(*) FROM {table_name};")) + count_result = await cursor.execute(text(f"SELECT COUNT(*) FROM {tn};")) row_count_estimate = count_result.scalar() schema_table = SchemaTable( - id=uuid5(NAMESPACE_OID, 
name=f"{schema_name}:{table_name}"), + id=uuid5(NAMESPACE_OID, name=f"{schema_name}:{tn}"), table_name=table_name, schema_name=schema_name, columns=details["columns"], @@ -76,17 +76,21 @@ async def ingest_database_schema( sample_data[table_name] = rows for fk in details.get("foreign_keys", []): + ref_table_fq = fk["ref_table"] + if '.' not in ref_table_fq and '.' in table_name: + ref_table_fq = f"{table_name.split('.', 1)[0]}.{ref_table_fq}" + relationship = SchemaRelationship( id=uuid5( NAMESPACE_OID, - name=f"{fk['column']}:{table_name}:{fk['ref_column']}:{fk['ref_table']}", + name=f"{schema_name}:{table_name}:{fk['column']}->{ref_table_fq}:{fk['ref_column']}", ), source_table=table_name, - target_table=fk["ref_table"], + target_table=ref_table_fq, relationship_type="foreign_key", source_column=fk["column"], target_column=fk["ref_column"], - description=f"Foreign key relationship: {table_name}.{fk['column']} → {fk['ref_table']}.{fk['ref_column']}", + description=f"Foreign key relationship: {table_name}.{fk['column']} → {ref_table_fq}.{fk['ref_column']}", ) schema_relationships.append(relationship) From 1e59f1594cf6778a866a6f4c8906d58faa98fa76 Mon Sep 17 00:00:00 2001 From: Geoff-Robin Date: Mon, 15 Sep 2025 16:30:58 +0530 Subject: [PATCH 25/36] solved more nitpick comments --- cognee/tasks/schema/ingest_database_schema.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/cognee/tasks/schema/ingest_database_schema.py b/cognee/tasks/schema/ingest_database_schema.py index be544408b..2a343ea0d 100644 --- a/cognee/tasks/schema/ingest_database_schema.py +++ b/cognee/tasks/schema/ingest_database_schema.py @@ -26,7 +26,10 @@ async def ingest_database_schema( max_sample_rows: Maximum sample rows per table Returns: - List of created DataPoint objects + Dict with keys: + "database_schema": DatabaseSchema + "schema_tables": List[SchemaTable] + "relationships": List[SchemaRelationship] """ engine = create_relational_engine( db_path=database_config.get("migration_db_path", ""), @@ -48,7 +51,6 @@ async def ingest_database_schema( qi = engine.engine.dialect.identifier_preparer.quote qname = lambda name : ".".join(qi(p) for p in name.split(".")) tn = qname(table_name) - tn = qi(table_name) rows_result = await cursor.execute( text(f"SELECT * FROM {tn} LIMIT :limit;"), {"limit": max_sample_rows} @@ -61,7 +63,7 @@ async def ingest_database_schema( row_count_estimate = count_result.scalar() schema_table = SchemaTable( - id=uuid5(NAMESPACE_OID, name=f"{schema_name}:{tn}"), + id=uuid5(NAMESPACE_OID, name=f"{schema_name}:{table_name}"), table_name=table_name, schema_name=schema_name, columns=details["columns"], From df8b80d4a9e9b21e1ba8cbeb89ffae1e72b6f8b1 Mon Sep 17 00:00:00 2001 From: Geoff-Robin Date: Mon, 15 Sep 2025 19:05:00 +0530 Subject: [PATCH 26/36] solved more nitpick comments --- cognee/tasks/schema/ingest_database_schema.py | 53 +++++++++++++------ 1 file changed, 36 insertions(+), 17 deletions(-) diff --git a/cognee/tasks/schema/ingest_database_schema.py b/cognee/tasks/schema/ingest_database_schema.py index 2a343ea0d..e362734fc 100644 --- a/cognee/tasks/schema/ingest_database_schema.py +++ b/cognee/tasks/schema/ingest_database_schema.py @@ -9,13 +9,13 @@ from cognee.tasks.schema.models import DatabaseSchema, SchemaTable, SchemaRelati from cognee.infrastructure.databases.relational.create_relational_engine import ( create_relational_engine, ) -from datetime import datetime +from datetime import datetime, timezone async def ingest_database_schema( database_config: Dict, 
schema_name: str = "default", - max_sample_rows: int = 5, + max_sample_rows: int = 0, ) -> Dict[str, List[DataPoint] | DataPoint]: """ Ingest database schema with sample data into dedicated nodeset @@ -45,22 +45,41 @@ async def ingest_database_schema( sample_data = {} schema_tables = [] schema_relationships = [] + qi = engine.engine.dialect.identifier_preparer.quote + + def qname(name: str): + split_name = name.split(".") + ".".join(qi(p) for p in split_name) async with engine.engine.begin() as cursor: for table_name, details in schema.items(): - qi = engine.engine.dialect.identifier_preparer.quote - qname = lambda name : ".".join(qi(p) for p in name.split(".")) tn = qname(table_name) - rows_result = await cursor.execute( - text(f"SELECT * FROM {tn} LIMIT :limit;"), - {"limit": max_sample_rows} - ) - rows = [ - dict(zip([col["name"] for col in details["columns"]], row)) - for row in rows_result.fetchall() - ] - count_result = await cursor.execute(text(f"SELECT COUNT(*) FROM {tn};")) - row_count_estimate = count_result.scalar() + if max_sample_rows > 0: + rows_result = await cursor.execute( + text(f"SELECT * FROM {tn} LIMIT :limit;"), {"limit": max_sample_rows} + ) + rows = [dict(r) for r in rows_result.mappings().all()] + else: + rows = [] + row_count_estimate = 0 + if engine.engine.dialect.name == "postegresql": + if "." in table_name: + schema_part, table_part = table_name.split(".", 1) + else: + schema_part, table_part = "public", table_name + estimate = await cursor.execute( + text( + "SELECT reltuples:bigint " + "FROM pg_class c " + "JOIN pg_namespace n ON n.oid = c.relnamespace " + "WHERE n.nspname = :schema AND c.relname = :table" + ), + {"schema": schema_part, "table": table_part}, + ) + row_count_estimate = estimate.scalar() + else: + count_result = await cursor.execute(text(f"SELECT COUNT(*) FROM {tn};")) + row_count_estimate = count_result.scalar() schema_table = SchemaTable( id=uuid5(NAMESPACE_OID, name=f"{schema_name}:{table_name}"), @@ -79,9 +98,9 @@ async def ingest_database_schema( for fk in details.get("foreign_keys", []): ref_table_fq = fk["ref_table"] - if '.' not in ref_table_fq and '.' in table_name: + if "." not in ref_table_fq and "." 
in table_name: ref_table_fq = f"{table_name.split('.', 1)[0]}.{ref_table_fq}" - + relationship = SchemaRelationship( id=uuid5( NAMESPACE_OID, @@ -102,7 +121,7 @@ async def ingest_database_schema( database_type=database_config.get("migration_db_provider", "sqlite"), tables=tables, sample_data=sample_data, - extraction_timestamp=datetime.utcnow(), + extraction_timestamp=datetime.now(timezone.utc), description=f"Database schema '{schema_name}' containing {len(schema_tables)} tables and {len(schema_relationships)} relationships.", ) From e7bcf9043f36ab631390beef2be7a1d2a3ecc359 Mon Sep 17 00:00:00 2001 From: Geoff-Robin Date: Mon, 15 Sep 2025 19:53:21 +0530 Subject: [PATCH 27/36] solved more nitpick comments --- cognee/tasks/schema/ingest_database_schema.py | 25 +++++++++---------- 1 file changed, 12 insertions(+), 13 deletions(-) diff --git a/cognee/tasks/schema/ingest_database_schema.py b/cognee/tasks/schema/ingest_database_schema.py index e362734fc..026bf588c 100644 --- a/cognee/tasks/schema/ingest_database_schema.py +++ b/cognee/tasks/schema/ingest_database_schema.py @@ -1,9 +1,6 @@ -from typing import List, Dict, Optional +from typing import List, Dict from uuid import uuid5, NAMESPACE_OID from cognee.infrastructure.engine.models.DataPoint import DataPoint -from cognee.infrastructure.databases.relational.get_migration_relational_engine import ( - get_migration_relational_engine, -) from sqlalchemy import text from cognee.tasks.schema.models import DatabaseSchema, SchemaTable, SchemaRelationship from cognee.infrastructure.databases.relational.create_relational_engine import ( @@ -18,12 +15,12 @@ async def ingest_database_schema( max_sample_rows: int = 0, ) -> Dict[str, List[DataPoint] | DataPoint]: """ - Ingest database schema with sample data into dedicated nodeset + Extract database schema metadata (optionally with sample data) and return DataPoint models for graph construction. Args: database_config: Database connection configuration schema_name: Name identifier for this schema - max_sample_rows: Maximum sample rows per table + max_sample_rows: Maximum sample rows per table (0 means no sampling) Returns: Dict with keys: @@ -49,36 +46,37 @@ async def ingest_database_schema( def qname(name: str): split_name = name.split(".") - ".".join(qi(p) for p in split_name) + return ".".join(qi(p) for p in split_name) async with engine.engine.begin() as cursor: for table_name, details in schema.items(): tn = qname(table_name) if max_sample_rows > 0: rows_result = await cursor.execute( - text(f"SELECT * FROM {tn} LIMIT :limit;"), {"limit": max_sample_rows} + text(f"SELECT * FROM {tn} LIMIT :limit;"), + {"limit": max_sample_rows}, # noqa: S608 - tn is fully quoted ) rows = [dict(r) for r in rows_result.mappings().all()] else: rows = [] row_count_estimate = 0 - if engine.engine.dialect.name == "postegresql": + if engine.engine.dialect.name == "postgresql": if "." 
in table_name:
                     schema_part, table_part = table_name.split(".", 1)
                 else:
                     schema_part, table_part = "public", table_name
                 estimate = await cursor.execute(
                     text(
-                        "SELECT reltuples:bigint "
+                        "SELECT reltuples:bigint AS estimate "
                         "FROM pg_class c "
                         "JOIN pg_namespace n ON n.oid = c.relnamespace "
                         "WHERE n.nspname = :schema AND c.relname = :table"
                     ),
                     {"schema": schema_part, "table": table_part},
                 )
-                row_count_estimate = estimate.scalar()
+                row_count_estimate = estimate.scalar() or 0
             else:
-                count_result = await cursor.execute(text(f"SELECT COUNT(*) FROM {tn};"))
+                count_result = await cursor.execute(text(f"SELECT COUNT(*) FROM {tn};"))  # noqa: S608 - tn is fully quoted
                 row_count_estimate = count_result.scalar()

             schema_table = SchemaTable(
@@ -115,8 +113,9 @@ async def ingest_database_schema(
     )
     schema_relationships.append(relationship)

+    id_str = f"{database_config.get('migration_db_provider', 'sqlite')}:{database_config.get('migration_db_name', 'cognee_db')}:{schema_name}"
     database_schema = DatabaseSchema(
-        id=uuid5(NAMESPACE_OID, name=schema_name),
+        id=uuid5(NAMESPACE_OID, name=id_str),
         schema_name=schema_name,
         database_type=database_config.get("migration_db_provider", "sqlite"),
         tables=tables,
         sample_data=sample_data,
         extraction_timestamp=datetime.now(timezone.utc),

From 67f948a1458e2474641180f530689f939bb29923 Mon Sep 17 00:00:00 2001
From: Geoff-Robin
Date: Mon, 15 Sep 2025 20:24:49 +0530
Subject: [PATCH 28/36] solved nitpick comments

---
 cognee/tasks/schema/ingest_database_schema.py | 10 +++++++---
 1 file changed, 7 insertions(+), 3 deletions(-)

diff --git a/cognee/tasks/schema/ingest_database_schema.py b/cognee/tasks/schema/ingest_database_schema.py
index 026bf588c..be9bf6ff1 100644
--- a/cognee/tasks/schema/ingest_database_schema.py
+++ b/cognee/tasks/schema/ingest_database_schema.py
@@ -43,6 +43,10 @@ async def ingest_database_schema(
     schema_tables = []
     schema_relationships = []
     qi = engine.engine.dialect.identifier_preparer.quote
+    try:
+        max_sample_rows = max(0, int(max_sample_rows))
+    except (TypeError, ValueError):
+        max_sample_rows = 0

     def qname(name: str):
         split_name = name.split(".")
@@ -53,8 +57,8 @@ async def ingest_database_schema(
             tn = qname(table_name)
             if max_sample_rows > 0:
                 rows_result = await cursor.execute(
-                    text(f"SELECT * FROM {tn} LIMIT :limit;"),
-                    {"limit": max_sample_rows},  # noqa: S608 - tn is fully quoted
+                    text(f"SELECT * FROM {tn} LIMIT :limit;"),  # noqa: S608 - tn is fully quoted
+                    {"limit": max_sample_rows},
                 )
                 rows = [dict(r) for r in rows_result.mappings().all()]
             else:
@@ -67,7 +71,7 @@ async def ingest_database_schema(
                 schema_part, table_part = "public", table_name
                 estimate = await cursor.execute(
                     text(
-                        "SELECT reltuples:bigint AS estimate "
+                        "SELECT reltuples::bigint AS estimate "
                         "FROM pg_class c "
                         "JOIN pg_namespace n ON n.oid = c.relnamespace "
                         "WHERE n.nspname = :schema AND c.relname = :table"

From 656894370e38ef9bf795c1c2b7e448ac7fa109e8 Mon Sep 17 00:00:00 2001
From: Geoff-Robin
Date: Sat, 20 Sep 2025 11:05:39 +0530
Subject: [PATCH 29/36] Edited test_relational_db_migration.py to include a schema_only ingestion test case

---
 cognee/tests/test_relational_db_migration.py | 75 ++++++++++++++++++++
 1 file changed, 75 insertions(+)

diff --git a/cognee/tests/test_relational_db_migration.py b/cognee/tests/test_relational_db_migration.py
index 68b46dbf5..cb360f1c2 100644
--- a/cognee/tests/test_relational_db_migration.py
+++ b/cognee/tests/test_relational_db_migration.py
@@ -197,6 +197,79 @@ async def relational_db_migration():
     print(f"All checks passed for {graph_db_provider} provider with '{relationship_label}' edges!")
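+
+# Note: the schema-only assertions below assume the 11-table test database with
+# 11 foreign keys; each foreign key yields two has_relationship edges
+# (table -> relationship node -> table) plus one direct foreign_key edge,
+# hence the expected counts of 11 / 22 / 11.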
+async def test_schema_only_migration(): + # 1. Setup test DB and extract schema + migration_engine = await setup_test_db() + schema = await migration_engine.extract_schema() + + # 2. Setup graph engine + graph_engine = await get_graph_engine() + + # 4. Migrate schema only + await migrate_relational_database(graph_engine, schema=schema, schema_only=True) + + # 5. Verify number of tables through search + search_results = await cognee.search( + query_text="How many tables are there in this database", + query_type=cognee.SearchType.GRAPH_COMPLETION, + ) + assert any("11" in r for r in search_results), ( + "Number of tables in the database reported in search_results is either None or not equal to 11" + ) + + graph_db_provider = os.getenv("GRAPH_DATABASE_PROVIDER", "networkx").lower() + + edge_counts = { + "is_part_of": 0, + "has_relationship": 0, + "foreign_key": 0, + } + + if graph_db_provider == "neo4j": + for rel_type in edge_counts.keys(): + query_str = f""" + MATCH ()-[r:{rel_type}]->() + RETURN count(r) as c + """ + rows = await graph_engine.query(query_str) + edge_counts[rel_type] = rows[0]["c"] + + elif graph_db_provider == "kuzu": + for rel_type in edge_counts.keys(): + query_str = f""" + MATCH ()-[r:EDGE]->() + WHERE r.relationship_name = '{rel_type}' + RETURN count(r) as c + """ + rows = await graph_engine.query(query_str) + edge_counts[rel_type] = rows[0][0] + + elif graph_db_provider == "networkx": + nodes, edges = await graph_engine.get_graph_data() + for _, _, key, _ in edges: + if key in edge_counts: + edge_counts[key] += 1 + + else: + raise ValueError(f"Unsupported graph database provider: {graph_db_provider}") + + # 7. Assert counts match expected values + expected_counts = { + "is_part_of": 11, + "has_relationship": 22, + "foreign_key": 11, + } + + for rel_type, expected in expected_counts.items(): + actual = edge_counts[rel_type] + assert actual == expected, ( + f"Expected {expected} edges for relationship '{rel_type}', but found {actual}" + ) + + print("Schema-only migration edge counts validated successfully!") + print(f"Edge counts: {edge_counts}") + + async def test_migration_sqlite(): database_to_migrate_path = os.path.join(pathlib.Path(__file__).parent, "test_data/") @@ -209,6 +282,7 @@ async def test_migration_sqlite(): ) await relational_db_migration() + await test_schema_only_migration() async def test_migration_postgres(): @@ -224,6 +298,7 @@ async def test_migration_postgres(): } ) await relational_db_migration() + await test_schema_only_migration() async def main(): From 2921021ca309f631aae49b686f4ec971ff24b0fe Mon Sep 17 00:00:00 2001 From: Geoff-Robin Date: Fri, 26 Sep 2025 00:58:43 +0530 Subject: [PATCH 30/36] improved code readability by splitting code blocks under conditional statements into separate functions --- .../ingestion/migrate_relational_database.py | 462 +++++++++--------- 1 file changed, 238 insertions(+), 224 deletions(-) diff --git a/cognee/tasks/ingestion/migrate_relational_database.py b/cognee/tasks/ingestion/migrate_relational_database.py index 824fef2fa..ffebf442d 100644 --- a/cognee/tasks/ingestion/migrate_relational_database.py +++ b/cognee/tasks/ingestion/migrate_relational_database.py @@ -9,7 +9,6 @@ from cognee.infrastructure.databases.relational.config import get_migration_conf from cognee.tasks.storage.index_data_points import index_data_points from cognee.tasks.storage.index_graph_edges import index_graph_edges from cognee.tasks.schema.ingest_database_schema import ingest_database_schema -from cognee.tasks.schema.models import 
SchemaTable from cognee.modules.engine.models import TableRow, TableType, ColumnValue @@ -31,236 +30,15 @@ async def migrate_relational_database( Both TableType and TableRow inherit from DataPoint to maintain consistency with Cognee data model. """ - engine = get_migration_relational_engine() # Create a mapping of node_id to node objects for referencing in edge creation node_mapping = {} edge_mapping = [] if schema_only: - database_config = get_migration_config().to_dict() - # Calling the ingest_database_schema function to return DataPoint subclasses - result = await ingest_database_schema( - database_config=database_config, - schema_name="migrated_schema", - max_sample_rows=5, - ) - database_schema = result["database_schema"] - schema_tables = result["schema_tables"] - schema_relationships = result["relationships"] - database_node_id = database_schema.id - node_mapping[database_node_id] = database_schema - for table in schema_tables: - table_node_id = table.id - # Add TableSchema Datapoint as a node. - node_mapping[table_node_id] = table - edge_mapping.append( - ( - table_node_id, - database_node_id, - "is_part_of", - dict( - source_node_id=table_node_id, - target_node_id=database_node_id, - relationship_name="is_part_of", - ), - ) - ) - table_name_to_id = {t.table_name: t.id for t in schema_tables} - for rel in schema_relationships: - source_table_id = table_name_to_id.get(rel.source_table) - target_table_id = table_name_to_id.get(rel.target_table) - - relationship_id = rel.id - - # Add RelationshipTable DataPoint as a node. - node_mapping[relationship_id] = rel - edge_mapping.append( - ( - source_table_id, - relationship_id, - "has_relationship", - dict( - source_node_id=source_table_id, - target_node_id=relationship_id, - relationship_name=rel.relationship_type, - ), - ) - ) - edge_mapping.append( - ( - relationship_id, - target_table_id, - "has_relationship", - dict( - source_node_id=relationship_id, - target_node_id=target_table_id, - relationship_name=rel.relationship_type, - ), - ) - ) - edge_mapping.append( - ( - source_table_id, - target_table_id, - rel.relationship_type, - dict( - source_node_id=source_table_id, - target_node_id=target_table_id, - relationship_name=rel.relationship_type, - ), - ) - ) + node_mapping, edge_mapping = await schema_only_ingestion() else: - async with engine.engine.begin() as cursor: - # First, create table type nodes for all tables - for table_name, details in schema.items(): - # Create a TableType node for each table - table_node = TableType( - id=uuid5(NAMESPACE_OID, name=table_name), - name=table_name, - description=f"Table: {table_name}", - ) - - # Add TableType node to mapping ( node will be added to the graph later based on this mapping ) - node_mapping[table_name] = table_node - - # Fetch all rows for the current table - rows_result = await cursor.execute(text(f"SELECT * FROM {table_name};")) - rows = rows_result.fetchall() - - for row in rows: - # Build a dictionary of properties from the row - row_properties = { - col["name"]: row[idx] for idx, col in enumerate(details["columns"]) - } - - # Determine the primary key value - if not details["primary_key"]: - # Use the first column as primary key if not specified - primary_key_col = details["columns"][0]["name"] - primary_key_value = row_properties[primary_key_col] - else: - # Use value of the specified primary key column - primary_key_col = details["primary_key"] - primary_key_value = row_properties[primary_key_col] - - # Create a node ID in the format "table_name:primary_key_value" - node_id 
= f"{table_name}:{primary_key_value}" - - # Create a TableRow node - # Node id must uniquely map to the id used in the relational database - # To catch the foreign key relationships properly - row_node = TableRow( - id=uuid5(NAMESPACE_OID, name=node_id), - name=node_id, - is_a=table_node, - properties=str(row_properties), - description=f"Row in {table_name} with {primary_key_col}={primary_key_value}", - ) - - # Store the node object in our mapping - node_mapping[node_id] = row_node - - # Add edge between row node and table node ( it will be added to the graph later ) - edge_mapping.append( - ( - row_node.id, - table_node.id, - "is_part_of", - dict( - relationship_name="is_part_of", - source_node_id=row_node.id, - target_node_id=table_node.id, - ), - ) - ) - - # Migrate data stored in columns of table rows - if migrate_column_data: - # Get foreign key columns to filter them out from column migration - foreign_keys = [] - for fk in details.get("foreign_keys", []): - foreign_keys.append(fk["ref_column"]) - - for key, value in row_properties.items(): - # Skip mapping primary key information to itself and mapping of foreign key information (as it will be mapped bellow) - if key is primary_key_col or key in foreign_keys: - continue - - # Create column value node - column_node_id = f"{table_name}:{key}:{value}" - column_node = ColumnValue( - id=uuid5(NAMESPACE_OID, name=column_node_id), - name=column_node_id, - properties=f"{key} {value} {table_name}", - description=f"Column name={key} and value={value} from column from table={table_name}", - ) - node_mapping[column_node_id] = column_node - - # Create relationship between column value of table row and table row - edge_mapping.append( - ( - row_node.id, - column_node.id, - key, - dict( - relationship_name=key, - source_node_id=row_node.id, - target_node_id=column_node.id, - ), - ) - ) - - # Process foreign key relationships after all nodes are created - for table_name, details in schema.items(): - # Process foreign key relationships for the current table - for fk in details.get("foreign_keys", []): - # Aliases needed for self-referencing tables - alias_1 = f"{table_name}_e1" - alias_2 = f"{fk['ref_table']}_e2" - - # Determine primary key column - if not details["primary_key"]: - primary_key_col = details["columns"][0]["name"] - else: - primary_key_col = details["primary_key"] - - # Query to find relationships based on foreign keys - fk_query = text( - f"SELECT {alias_1}.{primary_key_col} AS source_id, " - f"{alias_2}.{fk['ref_column']} AS ref_value " - f"FROM {table_name} AS {alias_1} " - f"JOIN {fk['ref_table']} AS {alias_2} " - f"ON {alias_1}.{fk['column']} = {alias_2}.{fk['ref_column']};" - ) - - fk_result = await cursor.execute(fk_query) - relations = fk_result.fetchall() - - for source_id, ref_value in relations: - # Construct node ids - source_node_id = f"{table_name}:{source_id}" - target_node_id = f"{fk['ref_table']}:{ref_value}" - - # Get the source and target node objects from our mapping - source_node = node_mapping[source_node_id] - target_node = node_mapping[target_node_id] - - # Add edge representing the foreign key relationship using the node objects - # Create edge to add to graph later - edge_mapping.append( - ( - source_node.id, - target_node.id, - fk["column"], - dict( - source_node_id=source_node.id, - target_node_id=target_node.id, - relationship_name=fk["column"], - ), - ) - ) + node_mapping, edge_mapping = await complete_database_ingestion(schema, migrate_column_data) def _remove_duplicate_edges(edge_mapping): seen = set() 
@@ -297,3 +75,239 @@ async def migrate_relational_database( logger.info("Data successfully migrated from relational database to desired graph database.") return await graph_db.get_graph_data() + + +async def schema_only_ingestion(): + node_mapping = {} + edge_mapping = [] + database_config = get_migration_config().to_dict() + # Calling the ingest_database_schema function to return DataPoint subclasses + result = await ingest_database_schema( + database_config=database_config, + schema_name="migrated_schema", + max_sample_rows=5, + ) + database_schema = result["database_schema"] + schema_tables = result["schema_tables"] + schema_relationships = result["relationships"] + database_node_id = database_schema.id + node_mapping[database_node_id] = database_schema + for table in schema_tables: + table_node_id = table.id + # Add TableSchema Datapoint as a node. + node_mapping[table_node_id] = table + edge_mapping.append( + ( + table_node_id, + database_node_id, + "is_part_of", + dict( + source_node_id=table_node_id, + target_node_id=database_node_id, + relationship_name="is_part_of", + ), + ) + ) + table_name_to_id = {t.table_name: t.id for t in schema_tables} + for rel in schema_relationships: + source_table_id = table_name_to_id.get(rel.source_table) + target_table_id = table_name_to_id.get(rel.target_table) + + relationship_id = rel.id + + # Add RelationshipTable DataPoint as a node. + node_mapping[relationship_id] = rel + edge_mapping.append( + ( + source_table_id, + relationship_id, + "has_relationship", + dict( + source_node_id=source_table_id, + target_node_id=relationship_id, + relationship_name=rel.relationship_type, + ), + ) + ) + edge_mapping.append( + ( + relationship_id, + target_table_id, + "has_relationship", + dict( + source_node_id=relationship_id, + target_node_id=target_table_id, + relationship_name=rel.relationship_type, + ), + ) + ) + edge_mapping.append( + ( + source_table_id, + target_table_id, + rel.relationship_type, + dict( + source_node_id=source_table_id, + target_node_id=target_table_id, + relationship_name=rel.relationship_type, + ), + ) + ) + return node_mapping, edge_mapping + + +async def complete_database_ingestion(schema, migrate_column_data): + engine = get_migration_relational_engine() + # Create a mapping of node_id to node objects for referencing in edge creation + node_mapping = {} + edge_mapping = [] + async with engine.engine.begin() as cursor: + # First, create table type nodes for all tables + for table_name, details in schema.items(): + # Create a TableType node for each table + table_node = TableType( + id=uuid5(NAMESPACE_OID, name=table_name), + name=table_name, + description=f"Table: {table_name}", + ) + + # Add TableType node to mapping ( node will be added to the graph later based on this mapping ) + node_mapping[table_name] = table_node + + # Fetch all rows for the current table + rows_result = await cursor.execute(text(f"SELECT * FROM {table_name};")) + rows = rows_result.fetchall() + + for row in rows: + # Build a dictionary of properties from the row + row_properties = { + col["name"]: row[idx] for idx, col in enumerate(details["columns"]) + } + + # Determine the primary key value + if not details["primary_key"]: + # Use the first column as primary key if not specified + primary_key_col = details["columns"][0]["name"] + primary_key_value = row_properties[primary_key_col] + else: + # Use value of the specified primary key column + primary_key_col = details["primary_key"] + primary_key_value = row_properties[primary_key_col] + + # Create a node ID 
in the format "table_name:primary_key_value" + node_id = f"{table_name}:{primary_key_value}" + + # Create a TableRow node + # Node id must uniquely map to the id used in the relational database + # To catch the foreign key relationships properly + row_node = TableRow( + id=uuid5(NAMESPACE_OID, name=node_id), + name=node_id, + is_a=table_node, + properties=str(row_properties), + description=f"Row in {table_name} with {primary_key_col}={primary_key_value}", + ) + + # Store the node object in our mapping + node_mapping[node_id] = row_node + + # Add edge between row node and table node ( it will be added to the graph later ) + edge_mapping.append( + ( + row_node.id, + table_node.id, + "is_part_of", + dict( + relationship_name="is_part_of", + source_node_id=row_node.id, + target_node_id=table_node.id, + ), + ) + ) + + # Migrate data stored in columns of table rows + if migrate_column_data: + # Get foreign key columns to filter them out from column migration + foreign_keys = [] + for fk in details.get("foreign_keys", []): + foreign_keys.append(fk["ref_column"]) + + for key, value in row_properties.items(): + # Skip mapping primary key information to itself and mapping of foreign key information (as it will be mapped bellow) + if key is primary_key_col or key in foreign_keys: + continue + + # Create column value node + column_node_id = f"{table_name}:{key}:{value}" + column_node = ColumnValue( + id=uuid5(NAMESPACE_OID, name=column_node_id), + name=column_node_id, + properties=f"{key} {value} {table_name}", + description=f"Column name={key} and value={value} from column from table={table_name}", + ) + node_mapping[column_node_id] = column_node + + # Create relationship between column value of table row and table row + edge_mapping.append( + ( + row_node.id, + column_node.id, + key, + dict( + relationship_name=key, + source_node_id=row_node.id, + target_node_id=column_node.id, + ), + ) + ) + + # Process foreign key relationships after all nodes are created + for table_name, details in schema.items(): + # Process foreign key relationships for the current table + for fk in details.get("foreign_keys", []): + # Aliases needed for self-referencing tables + alias_1 = f"{table_name}_e1" + alias_2 = f"{fk['ref_table']}_e2" + + # Determine primary key column + if not details["primary_key"]: + primary_key_col = details["columns"][0]["name"] + else: + primary_key_col = details["primary_key"] + + # Query to find relationships based on foreign keys + fk_query = text( + f"SELECT {alias_1}.{primary_key_col} AS source_id, " + f"{alias_2}.{fk['ref_column']} AS ref_value " + f"FROM {table_name} AS {alias_1} " + f"JOIN {fk['ref_table']} AS {alias_2} " + f"ON {alias_1}.{fk['column']} = {alias_2}.{fk['ref_column']};" + ) + + fk_result = await cursor.execute(fk_query) + relations = fk_result.fetchall() + + for source_id, ref_value in relations: + # Construct node ids + source_node_id = f"{table_name}:{source_id}" + target_node_id = f"{fk['ref_table']}:{ref_value}" + + # Get the source and target node objects from our mapping + source_node = node_mapping[source_node_id] + target_node = node_mapping[target_node_id] + + # Add edge representing the foreign key relationship using the node objects + # Create edge to add to graph later + edge_mapping.append( + ( + source_node.id, + target_node.id, + fk["column"], + dict( + source_node_id=source_node.id, + target_node_id=target_node.id, + relationship_name=fk["column"], + ), + ) + ) + return node_mapping, edge_mapping From 920bc78f151aa8c6e75334881b8e5cc52fdf814f Mon Sep 17 
From 920bc78f151aa8c6e75334881b8e5cc52fdf814f Mon Sep 17 00:00:00 2001
From: Igor Ilic
Date: Sat, 27 Sep 2025 00:18:57 +0200
Subject: [PATCH 31/36] refactor: Remove unused code

---
 cognee/tasks/ingestion/migrate_relational_database.py | 3 ---
 1 file changed, 3 deletions(-)

diff --git a/cognee/tasks/ingestion/migrate_relational_database.py b/cognee/tasks/ingestion/migrate_relational_database.py
index ffebf442d..7ea08d5e0 100644
--- a/cognee/tasks/ingestion/migrate_relational_database.py
+++ b/cognee/tasks/ingestion/migrate_relational_database.py
@@ -31,9 +31,6 @@ async def migrate_relational_database(
     Both TableType and TableRow inherit from DataPoint to maintain consistency with the Cognee data model.
     """
     # Create a mapping of node_id to node objects for referencing in edge creation
-    node_mapping = {}
-    edge_mapping = []
-
     if schema_only:
         node_mapping, edge_mapping = await schema_only_ingestion()

From f93d30ae77f7232e03110a817226539d5eb4d483 Mon Sep 17 00:00:00 2001
From: Igor Ilic
Date: Sat, 27 Sep 2025 00:41:58 +0200
Subject: [PATCH 32/36] refactor: refactor schema migration

---
 .../ingestion/migrate_relational_database.py  |  8 +++---
 cognee/tasks/schema/ingest_database_schema.py | 27 ++++++++-----------
 cognee/tests/test_relational_db_migration.py  |  1 +
 3 files changed, 16 insertions(+), 20 deletions(-)

diff --git a/cognee/tasks/ingestion/migrate_relational_database.py b/cognee/tasks/ingestion/migrate_relational_database.py
index 5ee9f5973..83ad452c3 100644
--- a/cognee/tasks/ingestion/migrate_relational_database.py
+++ b/cognee/tasks/ingestion/migrate_relational_database.py
@@ -32,7 +32,7 @@ async def migrate_relational_database(
     """
     # Create a mapping of node_id to node objects for referencing in edge creation
     if schema_only:
-        node_mapping, edge_mapping = await schema_only_ingestion()
+        node_mapping, edge_mapping = await schema_only_ingestion(schema)
     else:
         node_mapping, edge_mapping = await complete_database_ingestion(schema, migrate_column_data)
 
@@ -74,13 +74,13 @@
     return await graph_db.get_graph_data()
 
 
-async def schema_only_ingestion():
+async def schema_only_ingestion(schema):
     node_mapping = {}
     edge_mapping = []
-    database_config = get_migration_config().to_dict()
+
     # Call the ingest_database_schema function to get DataPoint subclasses
     result = await ingest_database_schema(
-        database_config=database_config,
+        schema=schema,
         schema_name="migrated_schema",
         max_sample_rows=5,
     )
diff --git a/cognee/tasks/schema/ingest_database_schema.py b/cognee/tasks/schema/ingest_database_schema.py
index be9bf6ff1..e89b679d2 100644
--- a/cognee/tasks/schema/ingest_database_schema.py
+++ b/cognee/tasks/schema/ingest_database_schema.py
@@ -3,14 +3,15 @@ from uuid import uuid5, NAMESPACE_OID
 from cognee.infrastructure.engine.models.DataPoint import DataPoint
 from sqlalchemy import text
 from cognee.tasks.schema.models import DatabaseSchema, SchemaTable, SchemaRelationship
-from cognee.infrastructure.databases.relational.create_relational_engine import (
-    create_relational_engine,
+from cognee.infrastructure.databases.relational.get_migration_relational_engine import (
+    get_migration_relational_engine,
 )
+from cognee.infrastructure.databases.relational.config import get_migration_config
 from datetime import datetime, timezone
 
 
 async def ingest_database_schema(
-    database_config: Dict,
+    schema,
     schema_name: str = "default",
     max_sample_rows: int = 0,
 ) -> Dict[str, List[DataPoint] | DataPoint]:
@@ -28,20 +29,13 @@ async def ingest_database_schema(
         "schema_tables": List[SchemaTable]
         "relationships": List[SchemaRelationship]
     """
-    engine = create_relational_engine(
-        db_path=database_config.get("migration_db_path", ""),
-        db_name=database_config.get("migration_db_name", "cognee_db"),
-        db_host=database_config.get("migration_db_host"),
-        db_port=database_config.get("migration_db_port"),
-        db_username=database_config.get("migration_db_username"),
-        db_password=database_config.get("migration_db_password"),
-        db_provider=database_config.get("migration_db_provider", "sqlite"),
-    )
-    schema = await engine.extract_schema()
+
     tables = {}
     sample_data = {}
     schema_tables = []
     schema_relationships = []
+
+    engine = get_migration_relational_engine()
     qi = engine.engine.dialect.identifier_preparer.quote
     try:
         max_sample_rows = max(0, int(max_sample_rows))
@@ -63,7 +57,7 @@ async def ingest_database_schema(
                 rows = [dict(r) for r in rows_result.mappings().all()]
             else:
                 rows = []
-            row_count_estimate = 0
+
             if engine.engine.dialect.name == "postgresql":
                 if "." in table_name:
                     schema_part, table_part = table_name.split(".", 1)
@@ -117,11 +111,12 @@ async def ingest_database_schema(
                 )
                 schema_relationships.append(relationship)
 
-    id_str = f"{database_config.get('migration_db_provider', 'sqlite')}:{database_config.get('migration_db_name', 'cognee_db')}:{schema_name}"
+    migration_config = get_migration_config()
+    id_str = f"{migration_config.migration_db_provider}:{migration_config.migration_db_name}:{schema_name}"
     database_schema = DatabaseSchema(
         id=uuid5(NAMESPACE_OID, name=id_str),
         schema_name=schema_name,
-        database_type=database_config.get("migration_db_provider", "sqlite"),
+        database_type=migration_config.migration_db_provider,
         tables=tables,
         sample_data=sample_data,
         extraction_timestamp=datetime.now(timezone.utc),
diff --git a/cognee/tests/test_relational_db_migration.py b/cognee/tests/test_relational_db_migration.py
index cb360f1c2..2b69ce854 100644
--- a/cognee/tests/test_relational_db_migration.py
+++ b/cognee/tests/test_relational_db_migration.py
@@ -212,6 +212,7 @@ async def test_schema_only_migration():
     search_results = await cognee.search(
         query_text="How many tables are there in this database",
         query_type=cognee.SearchType.GRAPH_COMPLETION,
+        top_k=30,
     )
     assert any("11" in r for r in search_results), (
         "Number of tables in the database reported in search_results is either None or not equal to 11"

From 17fb3b49efce5801dc37341bd40967f0fb202017 Mon Sep 17 00:00:00 2001
From: Igor Ilic
Date: Sat, 27 Sep 2025 01:15:30 +0200
Subject: [PATCH 33/36] refactor: add visualization to schema migration

---
 .../cognee_network_visualization.py           |  3 ++
 .../ingestion/migrate_relational_database.py  |  3 +-
 cognee/tasks/schema/ingest_database_schema.py | 36 ++++++++++---------
 cognee/tasks/schema/models.py                 | 12 +++----
 4 files changed, 30 insertions(+), 24 deletions(-)

diff --git a/cognee/modules/visualization/cognee_network_visualization.py b/cognee/modules/visualization/cognee_network_visualization.py
index bbdbc0019..c735e70f1 100644
--- a/cognee/modules/visualization/cognee_network_visualization.py
+++ b/cognee/modules/visualization/cognee_network_visualization.py
@@ -23,6 +23,9 @@ async def cognee_network_visualization(graph_data, destination_file_path: str =
         "TableRow": "#f47710",
         "TableType": "#6510f4",
         "ColumnValue": "#13613a",
+        "SchemaTable": "#f47710",
+        "DatabaseSchema": "#6510f4",
+        "SchemaRelationship": "#13613a",
         "default": "#D3D3D3",
     }
 
diff --git a/cognee/tasks/ingestion/migrate_relational_database.py b/cognee/tasks/ingestion/migrate_relational_database.py
index 83ad452c3..53ce176e8 100644
--- a/cognee/tasks/ingestion/migrate_relational_database.py
+++ b/cognee/tasks/ingestion/migrate_relational_database.py
@@ -81,7 +81,6 @@ async def schema_only_ingestion(schema):
     # Call the ingest_database_schema function to get DataPoint subclasses
     result = await ingest_database_schema(
         schema=schema,
-        schema_name="migrated_schema",
         max_sample_rows=5,
     )
     database_schema = result["database_schema"]
@@ -105,7 +104,7 @@ async def schema_only_ingestion(schema):
             ),
         )
     )
-    table_name_to_id = {t.table_name: t.id for t in schema_tables}
+    table_name_to_id = {t.name: t.id for t in schema_tables}
     for rel in schema_relationships:
         source_table_id = table_name_to_id.get(rel.source_table)
         target_table_id = table_name_to_id.get(rel.target_table)
diff --git a/cognee/tasks/schema/ingest_database_schema.py b/cognee/tasks/schema/ingest_database_schema.py
index e89b679d2..e3823701c 100644
--- a/cognee/tasks/schema/ingest_database_schema.py
+++ b/cognee/tasks/schema/ingest_database_schema.py
@@ -12,15 +12,13 @@ from datetime import datetime, timezone
 
 async def ingest_database_schema(
     schema,
-    schema_name: str = "default",
     max_sample_rows: int = 0,
 ) -> Dict[str, List[DataPoint] | DataPoint]:
     """
     Extract database schema metadata (optionally with sample data) and return DataPoint models for graph construction.
 
     Args:
-        database_config: Database connection configuration
-        schema_name: Name identifier for this schema
+        schema: Database schema
         max_sample_rows: Maximum sample rows per table (0 means no sampling)
 
     Returns:
@@ -35,6 +33,7 @@ async def ingest_database_schema(
     schema_tables = []
     schema_relationships = []
 
+    migration_config = get_migration_config()
     engine = get_migration_relational_engine()
     qi = engine.engine.dialect.identifier_preparer.quote
     try:
         max_sample_rows = max(0, int(max_sample_rows))
@@ -78,15 +77,17 @@ async def ingest_database_schema(
                 row_count_estimate = count_result.scalar()
 
             schema_table = SchemaTable(
-                id=uuid5(NAMESPACE_OID, name=f"{schema_name}:{table_name}"),
-                table_name=table_name,
-                schema_name=schema_name,
+                id=uuid5(NAMESPACE_OID, name=f"{table_name}"),
+                name=table_name,
                 columns=details["columns"],
                 primary_key=details.get("primary_key"),
                 foreign_keys=details.get("foreign_keys", []),
                 sample_rows=rows,
                 row_count_estimate=row_count_estimate,
-                description=f"Schema table for '{table_name}' with {len(details['columns'])} columns and approx. {row_count_estimate} rows.",
+                description=f"Relational database table '{table_name}' with {len(details['columns'])} columns and approx. {row_count_estimate} rows. "
+                f"Here are the columns this table contains: {details['columns']} "
+                f"Here are a few sample rows to show the contents of the table: {rows} "
+                f"Table is part of the database: {migration_config.migration_db_name}",
             )
             schema_tables.append(schema_table)
             tables[table_name] = details
@@ -97,30 +98,33 @@ async def ingest_database_schema(
                 if "." not in ref_table_fq and "." in table_name:
                     ref_table_fq = f"{table_name.split('.', 1)[0]}.{ref_table_fq}"
 
+                relationship_name = (
+                    f"{table_name}:{fk['column']}->{ref_table_fq}:{fk['ref_column']}"
+                )
                 relationship = SchemaRelationship(
-                    id=uuid5(
-                        NAMESPACE_OID,
-                        name=f"{schema_name}:{table_name}:{fk['column']}->{ref_table_fq}:{fk['ref_column']}",
-                    ),
+                    id=uuid5(NAMESPACE_OID, name=relationship_name),
+                    name=relationship_name,
                     source_table=table_name,
                     target_table=ref_table_fq,
                     relationship_type="foreign_key",
                     source_column=fk["column"],
                    target_column=fk["ref_column"],
-                    description=f"Foreign key relationship: {table_name}.{fk['column']} → {ref_table_fq}.{fk['ref_column']}",
+                    description=f"Relational database table foreign key relationship between: {table_name}.{fk['column']} → {ref_table_fq}.{fk['ref_column']} "
+                    f"This foreign key relationship between table columns is part of the following database: {migration_config.migration_db_name}",
                 )
                 schema_relationships.append(relationship)
 
-    id_str = f"{migration_config.migration_db_provider}:{migration_config.migration_db_name}:{schema_name}"
+    id_str = f"{migration_config.migration_db_provider}:{migration_config.migration_db_name}"
     database_schema = DatabaseSchema(
         id=uuid5(NAMESPACE_OID, name=id_str),
-        schema_name=schema_name,
+        name=migration_config.migration_db_name,
         database_type=migration_config.migration_db_provider,
         tables=tables,
         sample_data=sample_data,
         extraction_timestamp=datetime.now(timezone.utc),
-        description=f"Database schema '{schema_name}' containing {len(schema_tables)} tables and {len(schema_relationships)} relationships.",
+        description=f"Database schema containing {len(schema_tables)} tables and {len(schema_relationships)} relationships. "
+        f"The database type is {migration_config.migration_db_provider}. "
+        f"The database contains the following tables: {tables}",
     )
 
     return {
diff --git a/cognee/tasks/schema/models.py b/cognee/tasks/schema/models.py
index 423c92050..4b13f420b 100644
--- a/cognee/tasks/schema/models.py
+++ b/cognee/tasks/schema/models.py
@@ -6,36 +6,36 @@ from datetime import datetime
 class DatabaseSchema(DataPoint):
     """Represents a complete database schema with sample data"""
 
-    schema_name: str
+    name: str
     database_type: str  # sqlite, postgres, etc.
     tables: Dict[str, Dict]  # Reuse existing schema format from SqlAlchemyAdapter
     sample_data: Dict[str, List[Dict]]  # Limited examples per table
     extraction_timestamp: datetime
     description: str
-    metadata: dict = {"index_fields": ["schema_name", "database_type"]}
+    metadata: dict = {"index_fields": ["description", "name"]}
 
 
 class SchemaTable(DataPoint):
     """Represents an individual table schema with relationships"""
 
-    table_name: str
-    schema_name: str
+    name: str
     columns: List[Dict]  # Column definitions with types
     primary_key: Optional[str]
     foreign_keys: List[Dict]  # Foreign key relationships
     sample_rows: List[Dict]  # Max 3-5 example rows
     row_count_estimate: Optional[int]  # Actual table size
     description: str
-    metadata: dict = {"index_fields": ["table_name", "schema_name"]}
+    metadata: dict = {"index_fields": ["description", "name"]}
 
 
 class SchemaRelationship(DataPoint):
     """Represents relationships between tables"""
 
+    name: str
     source_table: str
     target_table: str
     relationship_type: str  # "foreign_key", "one_to_many", etc.
     source_column: str
     target_column: str
     description: str
-    metadata: dict = {"index_fields": ["source_table", "target_table"]}
+    metadata: dict = {"index_fields": ["description", "name"]}

From dc1669a948eaeb04028129e42ee1b2e39e5f9e25 Mon Sep 17 00:00:00 2001
From: Daulet Amirkhanov
Date: Sat, 27 Sep 2025 19:31:39 +0100
Subject: [PATCH 34/36] feat: add CORS middleware support for SSE and HTTP
 transports in MCP server

---
 cognee-mcp/src/server.py | 60 ++++++++++++++++++++++++++++++++++++++--
 1 file changed, 57 insertions(+), 3 deletions(-)

diff --git a/cognee-mcp/src/server.py b/cognee-mcp/src/server.py
index f249f1d08..33cd26cb1 100755
--- a/cognee-mcp/src/server.py
+++ b/cognee-mcp/src/server.py
@@ -20,6 +20,9 @@ from cognee.modules.search.types import SearchType
 from cognee.shared.data_models import KnowledgeGraph
 from cognee.modules.storage.utils import JSONEncoder
 from starlette.responses import JSONResponse
+from starlette.middleware import Middleware
+from starlette.middleware.cors import CORSMiddleware
+import uvicorn
 
 
 try:
@@ -39,8 +42,59 @@ mcp = FastMCP("Cognee")
 
 logger = get_logger()
 
+cors_middleware = Middleware(
+    CORSMiddleware,
+    allow_origins=["*"],
+    allow_credentials=True,
+    allow_methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"],
+    allow_headers=["*"],
+)
+
+
+async def run_sse_with_cors():
+    """Custom SSE transport with CORS middleware."""
+    sse_app = mcp.sse_app()
+    sse_app.add_middleware(
+        CORSMiddleware,
+        allow_origins=["*"],
+        allow_credentials=True,
+        allow_methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"],
+        allow_headers=["*"],
+    )
+
+    config = uvicorn.Config(
+        sse_app,
+        host=mcp.settings.host,
+        port=mcp.settings.port,
+        log_level=mcp.settings.log_level.lower(),
+    )
+    server = uvicorn.Server(config)
+    await server.serve()
+
+
+async def run_http_with_cors():
+    """Custom HTTP transport with CORS middleware."""
+    http_app = mcp.streamable_http_app()
+    http_app.add_middleware(
+        CORSMiddleware,
+        allow_origins=["*"],
+        allow_credentials=True,
+        allow_methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"],
+        allow_headers=["*"],
+    )
+
+    config = uvicorn.Config(
+        http_app,
+        host=mcp.settings.host,
+        port=mcp.settings.port,
+        log_level=mcp.settings.log_level.lower(),
+    )
+    server = uvicorn.Server(config)
+    await server.serve()
+
 
 @mcp.custom_route("/health", methods=["GET"])
-async def health_check(request) -> dict:
+async def health_check(request):
     return JSONResponse({"status": "ok"})
 
 
@@ -981,12 +1035,12 @@ async def main():
         await mcp.run_stdio_async()
     elif args.transport == "sse":
         logger.info(f"Running MCP server with SSE transport on {args.host}:{args.port}")
-        await mcp.run_sse_async()
+        await run_sse_with_cors()
     elif args.transport == "http":
         logger.info(
             f"Running MCP server with Streamable HTTP transport on {args.host}:{args.port}{args.path}"
         )
-        await mcp.run_streamable_http_async()
+        await run_http_with_cors()
 
 
 if __name__ == "__main__":

From c0d2abdf5e7758c8fa9c94c68b6fc8d4351ceab9 Mon Sep 17 00:00:00 2001
From: Daulet Amirkhanov
Date: Sat, 27 Sep 2025 19:31:56 +0100
Subject: [PATCH 35/36] feat: implement MCP connection health check in header
 component

---
 cognee-frontend/src/ui/Layout/Header.tsx | 25 +++++++++++++++++++-----
 cognee-frontend/src/utils/fetch.ts       |  6 ++++++
 2 files changed, 26 insertions(+), 5 deletions(-)

diff --git a/cognee-frontend/src/ui/Layout/Header.tsx b/cognee-frontend/src/ui/Layout/Header.tsx
index 30bf7ddb0..1bc57f699 100644
--- a/cognee-frontend/src/ui/Layout/Header.tsx
+++ b/cognee-frontend/src/ui/Layout/Header.tsx
@@ -2,7 +2,8 @@
 import Link from "next/link";
 import Image from "next/image";
 
-import { useBoolean } from "@/utils";
+import { useEffect } from "react";
+import { useBoolean, fetch } from "@/utils";
 
 import { CloseIcon, CloudIcon, CogneeIcon } from "../Icons";
 import { CTAButton, GhostButton, IconButton, Modal, StatusDot } from "../elements";
@@ -24,8 +25,9 @@ export default function Header({ user }: HeaderProps) {
   } = useBoolean(false);
 
   const {
-    value: isMCPStatusOpen,
-    setTrue: setMCPStatusOpen,
+    value: isMCPConnected,
+    setTrue: setMCPConnected,
+    setFalse: setMCPDisconnected,
   } = useBoolean(false);
 
   const handleDataSyncConfirm = () => {
@@ -35,6 +37,19 @@ export default function Header({ user }: HeaderProps) {
     });
   };
 
+  useEffect(() => {
+    const checkMCPConnection = () => {
+      fetch.checkMCPHealth()
+        .then(() => setMCPConnected())
+        .catch(() => setMCPDisconnected());
+    };
+
+    checkMCPConnection();
+    const interval = setInterval(checkMCPConnection, 30000);
+
+    return () => clearInterval(interval);
+  }, [setMCPConnected, setMCPDisconnected]);
+
   return (
     <>
@@ -45,8 +60,8 @@ export default function Header({ user }: HeaderProps) {
-
-          { isMCPStatusOpen ? "MCP connected" : "MCP disconnected" }
+
+          { isMCPConnected ? "MCP connected" : "MCP disconnected" }
diff --git a/cognee-frontend/src/utils/fetch.ts b/cognee-frontend/src/utils/fetch.ts
index 246853fb9..e67845d78 100644
--- a/cognee-frontend/src/utils/fetch.ts
+++ b/cognee-frontend/src/utils/fetch.ts
@@ -9,6 +9,8 @@ const backendApiUrl = process.env.NEXT_PUBLIC_BACKEND_API_URL || "http://localho
 const cloudApiUrl = process.env.NEXT_PUBLIC_CLOUD_API_URL || "http://localhost:8001";
 
+const mcpApiUrl = process.env.NEXT_PUBLIC_MCP_API_URL || "http://localhost:8001";
+
 let apiKey: string | null = process.env.NEXT_PUBLIC_COGWIT_API_KEY || null;
 let accessToken: string | null = null;
@@ -66,6 +68,10 @@ fetch.checkHealth = () => {
   return global.fetch(`${backendApiUrl.replace("/api", "")}/health`);
 };
 
+fetch.checkMCPHealth = () => {
+  return global.fetch(`${mcpApiUrl.replace("/api", "")}/health`);
+};
+
 fetch.setApiKey = (newApiKey: string) => {
   apiKey = newApiKey;
 };

From 0fac104fc7ec6d757b3fc10a3bd1e114384b7a10 Mon Sep 17 00:00:00 2001
From: Daulet Amirkhanov
Date: Sat, 27 Sep 2025 20:11:39 +0100
Subject: [PATCH 36/36] fix: update UI server startup message to reflect
 dynamic frontend port

---
 cognee-mcp/src/server.py | 17 ++++-------------
 cognee/cli/_cognee.py    |  2 +-
 2 files changed, 5 insertions(+), 14 deletions(-)

diff --git a/cognee-mcp/src/server.py b/cognee-mcp/src/server.py
index 33cd26cb1..cc6eac09e 100755
--- a/cognee-mcp/src/server.py
+++ b/cognee-mcp/src/server.py
@@ -42,23 +42,14 @@ mcp = FastMCP("Cognee")
 
 logger = get_logger()
 
-cors_middleware = Middleware(
-    CORSMiddleware,
-    allow_origins=["*"],
-    allow_credentials=True,
-    allow_methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"],
-    allow_headers=["*"],
-)
-
-
 async def run_sse_with_cors():
     """Custom SSE transport with CORS middleware."""
     sse_app = mcp.sse_app()
     sse_app.add_middleware(
         CORSMiddleware,
-        allow_origins=["*"],
+        allow_origins=["http://localhost:3000"],
         allow_credentials=True,
-        allow_methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"],
+        allow_methods=["GET"],
         allow_headers=["*"],
     )
 
@@ -77,9 +68,9 @@ async def run_http_with_cors():
     http_app = mcp.streamable_http_app()
     http_app.add_middleware(
         CORSMiddleware,
-        allow_origins=["*"],
+        allow_origins=["http://localhost:3000"],
         allow_credentials=True,
-        allow_methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"],
+        allow_methods=["GET"],
         allow_headers=["*"],
     )
 
diff --git a/cognee/cli/_cognee.py b/cognee/cli/_cognee.py
index 7f2b06c89..b68e5c80f 100644
--- a/cognee/cli/_cognee.py
+++ b/cognee/cli/_cognee.py
@@ -220,7 +220,7 @@ def main() -> int:
 
     if server_process:
         fmt.success("UI server started successfully!")
-        fmt.echo("The interface is available at: http://localhost:3000")
+        fmt.echo(f"The interface is available at: http://localhost:{frontend_port}")
         if start_backend:
             fmt.echo(f"The API backend is available at: http://localhost:{backend_port}")
         if start_mcp:
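Taken together, patches 34 and 36 settle on one pattern: take the transport's ASGI app, wrap it in CORSMiddleware, and serve it with an explicitly constructed uvicorn server. Below is a condensed, self-contained sketch of that pattern, using a bare Starlette app in place of FastMCP's sse_app(); that substitution is an assumption for illustration, not the actual MCP server wiring.

# Minimal sketch of the CORS-wrapping pattern from patches 34/36.
# The Starlette app here stands in for FastMCP's sse_app()/streamable_http_app().
import uvicorn
from starlette.applications import Starlette
from starlette.middleware.cors import CORSMiddleware
from starlette.responses import JSONResponse
from starlette.routing import Route


async def health(request):
    # Same shape as the /health route the frontend polls every 30 seconds.
    return JSONResponse({"status": "ok"})


app = Starlette(routes=[Route("/health", health, methods=["GET"])])
# add_middleware wraps the whole ASGI app, so CORS preflight and response
# headers are handled before any route is dispatched.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["http://localhost:3000"],  # the frontend origin, per patch 36
    allow_credentials=True,
    allow_methods=["GET"],
    allow_headers=["*"],
)

if __name__ == "__main__":
    uvicorn.run(app, host="127.0.0.1", port=8001)

The narrowing in patch 36 from allow_origins=["*"] to the explicit frontend origin matters once allow_credentials=True is in play: browsers reject a wildcard Access-Control-Allow-Origin on credentialed requests, so pinning the origin keeps the configuration spec-compliant.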