Merge branch 'dev' into merge-main-vol6

This commit is contained in:
Vasilije 2025-09-28 15:29:23 +02:00 committed by GitHub
commit 7403e31738
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
21 changed files with 728 additions and 212 deletions

View file

@ -203,6 +203,16 @@ LITELLM_LOG="ERROR"
# DEFAULT_USER_EMAIL=""
# DEFAULT_USER_PASSWORD=""
################################################################################
# 📂 AWS Settings
################################################################################
#AWS_REGION=""
#AWS_ENDPOINT_URL=""
#AWS_ACCESS_KEY_ID=""
#AWS_SECRET_ACCESS_KEY=""
#AWS_SESSION_TOKEN=""
------------------------------- END OF POSSIBLE SETTINGS -------------------------------

View file

@ -2,10 +2,11 @@
import Link from "next/link";
import Image from "next/image";
import { useBoolean } from "@/utils";
import { useEffect } from "react";
import { useBoolean, fetch } from "@/utils";
import { CloseIcon, CloudIcon, CogneeIcon } from "../Icons";
import { CTAButton, GhostButton, IconButton, Modal } from "../elements";
import { CTAButton, GhostButton, IconButton, Modal, StatusDot } from "../elements";
import syncData from "@/modules/cloud/syncData";
interface HeaderProps {
@ -23,6 +24,12 @@ export default function Header({ user }: HeaderProps) {
setFalse: closeSyncModal,
} = useBoolean(false);
const {
value: isMCPConnected,
setTrue: setMCPConnected,
setFalse: setMCPDisconnected,
} = useBoolean(false);
const handleDataSyncConfirm = () => {
syncData()
.finally(() => {
@ -30,6 +37,19 @@ export default function Header({ user }: HeaderProps) {
});
};
useEffect(() => {
const checkMCPConnection = () => {
fetch.checkMCPHealth()
.then(() => setMCPConnected())
.catch(() => setMCPDisconnected());
};
checkMCPConnection();
const interval = setInterval(checkMCPConnection, 30000);
return () => clearInterval(interval);
}, [setMCPConnected, setMCPDisconnected]);
return (
<>
<header className="relative flex flex-row h-14 min-h-14 px-5 items-center justify-between w-full max-w-[1920px] mx-auto">
@ -39,6 +59,10 @@ export default function Header({ user }: HeaderProps) {
</div>
<div className="flex flex-row items-center gap-2.5">
<Link href="/mcp-status" className="!text-indigo-600 pl-4 pr-4">
<StatusDot className="mr-2" isActive={isMCPConnected} />
{ isMCPConnected ? "MCP connected" : "MCP disconnected" }
</Link>
<GhostButton onClick={openSyncModal} className="text-indigo-600 gap-3 pl-4 pr-4">
<CloudIcon />
<div>Sync</div>

View file

@ -0,0 +1,13 @@
import React from "react";
/**
 * Small colored indicator dot: green when the monitored service is active,
 * red otherwise.
 *
 * @param isActive  - Whether the monitored service is currently reachable.
 * @param className - Optional extra CSS classes appended to the dot.
 */
const StatusDot = ({ isActive, className }: { isActive: boolean, className?: string }) => {
  return (
    <span
      // `?? ""` prevents the literal string "undefined" from being injected
      // into the class list when `className` is omitted.
      className={`inline-block w-3 h-3 rounded-full ${className ?? ""} ${
        isActive ? "bg-green-500" : "bg-red-500"
      }`}
    />
  );
};

export default StatusDot;

View file

@ -8,5 +8,6 @@ export { default as IconButton } from "./IconButton";
export { default as GhostButton } from "./GhostButton";
export { default as NeutralButton } from "./NeutralButton";
export { default as StatusIndicator } from "./StatusIndicator";
export { default as StatusDot } from "./StatusDot";
export { default as Accordion } from "./Accordion";
export { default as Notebook } from "./Notebook";

View file

@ -9,6 +9,8 @@ const backendApiUrl = process.env.NEXT_PUBLIC_BACKEND_API_URL || "http://localho
const cloudApiUrl = process.env.NEXT_PUBLIC_CLOUD_API_URL || "http://localhost:8001";
const mcpApiUrl = process.env.NEXT_PUBLIC_MCP_API_URL || "http://localhost:8001";
let apiKey: string | null = process.env.NEXT_PUBLIC_COGWIT_API_KEY || null;
let accessToken: string | null = null;
@ -66,6 +68,10 @@ fetch.checkHealth = () => {
return global.fetch(`${backendApiUrl.replace("/api", "")}/health`);
};
// Probe the MCP server's /health endpoint; the returned promise resolves
// when the server is reachable. Any "/api" suffix on the configured base URL
// is stripped because the /health route is served at the server root.
fetch.checkMCPHealth = () => {
  return global.fetch(`${mcpApiUrl.replace("/api", "")}/health`);
};
fetch.setApiKey = (newApiKey: string) => {
apiKey = newApiKey;
};

View file

@ -48,27 +48,27 @@ if [ "$ENVIRONMENT" = "dev" ] || [ "$ENVIRONMENT" = "local" ]; then
if [ "$DEBUG" = "true" ]; then
echo "Waiting for the debugger to attach..."
if [ "$TRANSPORT_MODE" = "sse" ]; then
exec python -m debugpy --wait-for-client --listen 0.0.0.0:$DEBUG_PORT -m cognee --transport sse --host 0.0.0.0 --port $HTTP_PORT --no-migration
exec python -m debugpy --wait-for-client --listen 0.0.0.0:$DEBUG_PORT -m cognee-mcp --transport sse --host 0.0.0.0 --port $HTTP_PORT --no-migration
elif [ "$TRANSPORT_MODE" = "http" ]; then
exec python -m debugpy --wait-for-client --listen 0.0.0.0:$DEBUG_PORT -m cognee --transport http --host 0.0.0.0 --port $HTTP_PORT --no-migration
exec python -m debugpy --wait-for-client --listen 0.0.0.0:$DEBUG_PORT -m cognee-mcp --transport http --host 0.0.0.0 --port $HTTP_PORT --no-migration
else
exec python -m debugpy --wait-for-client --listen 0.0.0.0:$DEBUG_PORT -m cognee --transport stdio --no-migration
exec python -m debugpy --wait-for-client --listen 0.0.0.0:$DEBUG_PORT -m cognee-mcp --transport stdio --no-migration
fi
else
if [ "$TRANSPORT_MODE" = "sse" ]; then
exec cognee --transport sse --host 0.0.0.0 --port $HTTP_PORT --no-migration
exec cognee-mcp --transport sse --host 0.0.0.0 --port $HTTP_PORT --no-migration
elif [ "$TRANSPORT_MODE" = "http" ]; then
exec cognee --transport http --host 0.0.0.0 --port $HTTP_PORT --no-migration
exec cognee-mcp --transport http --host 0.0.0.0 --port $HTTP_PORT --no-migration
else
exec cognee --transport stdio --no-migration
exec cognee-mcp --transport stdio --no-migration
fi
fi
else
if [ "$TRANSPORT_MODE" = "sse" ]; then
exec cognee --transport sse --host 0.0.0.0 --port $HTTP_PORT --no-migration
exec cognee-mcp --transport sse --host 0.0.0.0 --port $HTTP_PORT --no-migration
elif [ "$TRANSPORT_MODE" = "http" ]; then
exec cognee --transport http --host 0.0.0.0 --port $HTTP_PORT --no-migration
exec cognee-mcp --transport http --host 0.0.0.0 --port $HTTP_PORT --no-migration
else
exec cognee --transport stdio --no-migration
exec cognee-mcp --transport stdio --no-migration
fi
fi

View file

@ -36,4 +36,4 @@ dev = [
allow-direct-references = true
[project.scripts]
cognee = "src:main"
cognee-mcp = "src:main"

View file

@ -19,6 +19,10 @@ from cognee.api.v1.cognify.code_graph_pipeline import run_code_graph_pipeline
from cognee.modules.search.types import SearchType
from cognee.shared.data_models import KnowledgeGraph
from cognee.modules.storage.utils import JSONEncoder
from starlette.responses import JSONResponse
from starlette.middleware import Middleware
from starlette.middleware.cors import CORSMiddleware
import uvicorn
try:
@ -38,6 +42,53 @@ mcp = FastMCP("Cognee")
logger = get_logger()
async def run_sse_with_cors():
    """Custom SSE transport with CORS middleware.

    Wraps FastMCP's SSE app with CORS so the browser frontend at
    http://localhost:3000 can connect cross-origin, then serves it with
    uvicorn using the host/port/log-level configured on ``mcp.settings``.
    Runs until the uvicorn server exits.
    """
    sse_app = mcp.sse_app()

    # Only the local frontend origin is allowed; GET suffices for SSE streams.
    sse_app.add_middleware(
        CORSMiddleware,
        allow_origins=["http://localhost:3000"],
        allow_credentials=True,
        allow_methods=["GET"],
        allow_headers=["*"],
    )

    config = uvicorn.Config(
        sse_app,
        host=mcp.settings.host,
        port=mcp.settings.port,
        log_level=mcp.settings.log_level.lower(),
    )
    server = uvicorn.Server(config)
    await server.serve()
async def run_http_with_cors():
    """Custom HTTP transport with CORS middleware.

    Same pattern as the SSE variant: wraps FastMCP's streamable-HTTP app with
    CORS for the local frontend origin and serves it via uvicorn with the
    settings configured on ``mcp.settings``. Runs until the server exits.
    """
    http_app = mcp.streamable_http_app()

    # NOTE(review): only GET is allowed here — confirm streamable HTTP does
    # not also require POST from the client before relying on this in prod.
    http_app.add_middleware(
        CORSMiddleware,
        allow_origins=["http://localhost:3000"],
        allow_credentials=True,
        allow_methods=["GET"],
        allow_headers=["*"],
    )

    config = uvicorn.Config(
        http_app,
        host=mcp.settings.host,
        port=mcp.settings.port,
        log_level=mcp.settings.log_level.lower(),
    )
    server = uvicorn.Server(config)
    await server.serve()
@mcp.custom_route("/health", methods=["GET"])
async def health_check(request):
    # Liveness probe polled by the frontend (see fetch.checkMCPHealth);
    # always answers 200 with a static payload.
    return JSONResponse({"status": "ok"})
@mcp.tool()
async def cognee_add_developer_rules(
base_path: str = ".", graph_model_file: str = None, graph_model_name: str = None
@ -975,12 +1026,12 @@ async def main():
await mcp.run_stdio_async()
elif args.transport == "sse":
logger.info(f"Running MCP server with SSE transport on {args.host}:{args.port}")
await mcp.run_sse_async()
await run_sse_with_cors()
elif args.transport == "http":
logger.info(
f"Running MCP server with Streamable HTTP transport on {args.host}:{args.port}{args.path}"
)
await mcp.run_streamable_http_async()
await run_http_with_cors()
if __name__ == "__main__":

View file

@ -1 +1 @@
from .ui import start_ui, stop_ui, ui
from .ui import start_ui

View file

@ -1,5 +1,5 @@
import os
import signal
import socket
import subprocess
import threading
import time
@ -7,7 +7,7 @@ import webbrowser
import zipfile
import requests
from pathlib import Path
from typing import Callable, Optional, Tuple
from typing import Callable, Optional, Tuple, List
import tempfile
import shutil
@ -17,6 +17,80 @@ from cognee.version import get_cognee_version
logger = get_logger()
def _stream_process_output(
    process: subprocess.Popen, stream_name: str, prefix: str, color_code: str = ""
) -> threading.Thread:
    """
    Stream output from a process with a prefix to identify the source.

    Args:
        process: The subprocess to monitor
        stream_name: 'stdout' or 'stderr'
        prefix: Text prefix for each line (e.g., '[BACKEND]', '[FRONTEND]')
        color_code: ANSI color code for the prefix (optional)

    Returns:
        Thread that handles the streaming
    """

    def stream_reader():
        # Resolve the requested pipe ('stdout' or 'stderr') off the Popen
        # object; it is None when the stream was not redirected to a PIPE.
        stream = getattr(process, stream_name)
        if stream is None:
            return

        # Only append an ANSI reset when a color was actually requested.
        reset_code = "\033[0m" if color_code else ""
        try:
            # iter(readline, b"") yields raw byte lines until EOF (pipe closed).
            for line in iter(stream.readline, b""):
                if line:
                    line_text = line.decode("utf-8").rstrip()
                    if line_text:
                        print(f"{color_code}{prefix}{reset_code} {line_text}", flush=True)
        except Exception:
            # Best-effort streaming: never let a decode/IO error kill the thread.
            pass
        finally:
            if stream:
                stream.close()

    # Daemon thread so a still-streaming child never blocks interpreter exit.
    thread = threading.Thread(target=stream_reader, daemon=True)
    thread.start()
    return thread
def _is_port_available(port: int) -> bool:
"""
Check if a port is available on localhost.
Returns True if the port is available, False otherwise.
"""
try:
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
sock.settimeout(1) # 1 second timeout
result = sock.connect_ex(("localhost", port))
return result != 0 # Port is available if connection fails
except Exception:
return False
def _check_required_ports(ports_to_check: List[Tuple[int, str]]) -> Tuple[bool, List[str]]:
    """
    Check if all required ports are available on localhost.

    Args:
        ports_to_check: List of (port, service_name) tuples

    Returns:
        Tuple of (all_available: bool, unavailable_services: List[str])
    """
    blocked: List[str] = []
    for candidate_port, label in ports_to_check:
        if _is_port_available(candidate_port):
            continue
        # Record the conflict both in the log and in the returned summary.
        logger.error(f"Port {candidate_port} is already in use for {label}")
        blocked.append(f"{label} (port {candidate_port})")
    return not blocked, blocked
def normalize_version_for_comparison(version: str) -> str:
"""
Normalize version string for comparison.
@ -327,55 +401,111 @@ def prompt_user_for_download() -> bool:
def start_ui(
pid_callback: Callable[[int], None],
host: str = "localhost",
port: int = 3000,
open_browser: bool = True,
auto_download: bool = False,
start_backend: bool = False,
backend_host: str = "localhost",
backend_port: int = 8000,
start_mcp: bool = False,
mcp_port: int = 8001,
) -> Optional[subprocess.Popen]:
"""
Start the cognee frontend UI server, optionally with the backend API server.
Start the cognee frontend UI server, optionally with the backend API server and MCP server.
This function will:
1. Optionally start the cognee backend API server
2. Find the cognee-frontend directory (development) or download it (pip install)
3. Check if Node.js and npm are available (for development mode)
4. Install dependencies if needed (development mode)
5. Start the frontend server
6. Optionally open the browser
2. Optionally start the cognee MCP server
3. Find the cognee-frontend directory (development) or download it (pip install)
4. Check if Node.js and npm are available (for development mode)
5. Install dependencies if needed (development mode)
6. Start the frontend server
7. Optionally open the browser
Args:
pid_callback: Callback to notify with PID of each spawned process
host: Host to bind the frontend server to (default: localhost)
port: Port to run the frontend server on (default: 3000)
open_browser: Whether to open the browser automatically (default: True)
auto_download: If True, download frontend without prompting (default: False)
start_backend: If True, also start the cognee API backend server (default: False)
backend_host: Host to bind the backend server to (default: localhost)
backend_port: Port to run the backend server on (default: 8000)
start_mcp: If True, also start the cognee MCP server (default: False)
mcp_port: Port to run the MCP server on (default: 8001)
Returns:
subprocess.Popen object representing the running frontend server, or None if failed
Note: If backend is started, it runs in a separate process that will be cleaned up
when the frontend process is terminated.
Note: If backend and/or MCP server are started, they run in separate processes
that will be cleaned up when the frontend process is terminated.
Example:
>>> import cognee
>>> def dummy_callback(pid): pass
>>> # Start just the frontend
>>> server = cognee.start_ui()
>>> server = cognee.start_ui(dummy_callback)
>>>
>>> # Start both frontend and backend
>>> server = cognee.start_ui(start_backend=True)
>>> server = cognee.start_ui(dummy_callback, start_backend=True)
>>> # UI will be available at http://localhost:3000
>>> # API will be available at http://localhost:8000
>>> # To stop both servers later:
>>>
>>> # Start frontend with MCP server
>>> server = cognee.start_ui(dummy_callback, start_mcp=True)
>>> # UI will be available at http://localhost:3000
>>> # MCP server will be available at http://127.0.0.1:8001/sse
>>> # To stop all servers later:
>>> server.terminate()
"""
logger.info("Starting cognee UI...")
ports_to_check = [(port, "Frontend UI")]
if start_backend:
ports_to_check.append((backend_port, "Backend API"))
if start_mcp:
ports_to_check.append((mcp_port, "MCP Server"))
logger.info("Checking port availability...")
all_ports_available, unavailable_services = _check_required_ports(ports_to_check)
if not all_ports_available:
error_msg = f"Cannot start cognee UI: The following services have ports already in use: {', '.join(unavailable_services)}"
logger.error(error_msg)
logger.error("Please stop the conflicting services or change the port configuration.")
return None
logger.info("✓ All required ports are available")
backend_process = None
if start_mcp:
logger.info("Starting Cognee MCP server with Docker...")
cwd = os.getcwd()
env_file = os.path.join(cwd, ".env")
try:
mcp_process = subprocess.Popen(
[
"docker",
"run",
"-p",
f"{mcp_port}:8000",
"--rm",
"--env-file",
env_file,
"-e",
"TRANSPORT_MODE=sse",
"cognee/cognee-mcp:daulet-dev",
],
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
preexec_fn=os.setsid if hasattr(os, "setsid") else None,
)
_stream_process_output(mcp_process, "stdout", "[MCP]", "\033[34m") # Blue
_stream_process_output(mcp_process, "stderr", "[MCP]", "\033[34m") # Blue
pid_callback(mcp_process.pid)
logger.info(f"✓ Cognee MCP server starting on http://127.0.0.1:{mcp_port}/sse")
except Exception as e:
logger.error(f"Failed to start MCP server with Docker: {str(e)}")
# Start backend server if requested
if start_backend:
logger.info("Starting cognee backend API server...")
@ -389,16 +519,19 @@ def start_ui(
"uvicorn",
"cognee.api.client:app",
"--host",
backend_host,
"localhost",
"--port",
str(backend_port),
],
# Inherit stdout/stderr from parent process to show logs
stdout=None,
stderr=None,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
preexec_fn=os.setsid if hasattr(os, "setsid") else None,
)
# Start threads to stream backend output with prefix
_stream_process_output(backend_process, "stdout", "[BACKEND]", "\033[32m") # Green
_stream_process_output(backend_process, "stderr", "[BACKEND]", "\033[32m") # Green
pid_callback(backend_process.pid)
# Give the backend a moment to start
@ -408,7 +541,7 @@ def start_ui(
logger.error("Backend server failed to start - process exited early")
return None
logger.info(f"✓ Backend API started at http://{backend_host}:{backend_port}")
logger.info(f"✓ Backend API started at http://localhost:{backend_port}")
except Exception as e:
logger.error(f"Failed to start backend server: {str(e)}")
@ -453,11 +586,11 @@ def start_ui(
# Prepare environment variables
env = os.environ.copy()
env["HOST"] = host
env["HOST"] = "localhost"
env["PORT"] = str(port)
# Start the development server
logger.info(f"Starting frontend server at http://{host}:{port}")
logger.info(f"Starting frontend server at http://localhost:{port}")
logger.info("This may take a moment to compile and start...")
try:
@ -468,10 +601,13 @@ def start_ui(
env=env,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
text=True,
preexec_fn=os.setsid if hasattr(os, "setsid") else None,
)
# Start threads to stream frontend output with prefix
_stream_process_output(process, "stdout", "[FRONTEND]", "\033[33m") # Yellow
_stream_process_output(process, "stderr", "[FRONTEND]", "\033[33m") # Yellow
pid_callback(process.pid)
# Give it a moment to start up
@ -479,10 +615,7 @@ def start_ui(
# Check if process is still running
if process.poll() is not None:
stdout, stderr = process.communicate()
logger.error("Frontend server failed to start:")
logger.error(f"stdout: {stdout}")
logger.error(f"stderr: {stderr}")
logger.error("Frontend server failed to start - check the logs above for details")
return None
# Open browser if requested
@ -491,7 +624,7 @@ def start_ui(
def open_browser_delayed():
time.sleep(5) # Give Next.js time to fully start
try:
webbrowser.open(f"http://{host}:{port}") # TODO: use dashboard url?
webbrowser.open(f"http://localhost:{port}")
except Exception as e:
logger.warning(f"Could not open browser automatically: {e}")
@ -499,13 +632,9 @@ def start_ui(
browser_thread.start()
logger.info("✓ Cognee UI is starting up...")
logger.info(f"✓ Open your browser to: http://{host}:{port}")
logger.info(f"✓ Open your browser to: http://localhost:{port}")
logger.info("✓ The UI will be available once Next.js finishes compiling")
# Store backend process reference in the frontend process for cleanup
if backend_process:
process._cognee_backend_process = backend_process
return process
except Exception as e:
@ -523,102 +652,3 @@ def start_ui(
except (OSError, ProcessLookupError):
pass
return None
def stop_ui(process: subprocess.Popen) -> bool:
    """
    Stop a running UI server process and backend process (if started), along with all their children.

    Args:
        process: The subprocess.Popen object returned by start_ui()

    Returns:
        bool: True if stopped successfully, False otherwise
    """
    # NOTE(review): this function uses the `signal` module — confirm `import
    # signal` is still present at the top of the file after this merge.
    if not process:
        return False

    success = True

    try:
        # First, stop the backend process if it exists
        # (start_ui stashes it on the frontend Popen as an attribute).
        backend_process = getattr(process, "_cognee_backend_process", None)
        if backend_process:
            logger.info("Stopping backend server...")
            try:
                backend_process.terminate()
                try:
                    backend_process.wait(timeout=5)
                    logger.info("Backend server stopped gracefully")
                except subprocess.TimeoutExpired:
                    logger.warning("Backend didn't terminate gracefully, forcing kill")
                    backend_process.kill()
                    backend_process.wait()
                    logger.info("Backend server stopped")
            except Exception as e:
                # Keep going: the frontend should still be torn down.
                logger.error(f"Error stopping backend server: {str(e)}")
                success = False

        # Now stop the frontend process
        logger.info("Stopping frontend server...")

        # Try to terminate the process group (includes child processes like Next.js)
        if hasattr(os, "killpg"):
            try:
                # Kill the entire process group
                os.killpg(os.getpgid(process.pid), signal.SIGTERM)
                logger.debug("Sent SIGTERM to process group")
            except (OSError, ProcessLookupError):
                # Fall back to terminating just the main process
                process.terminate()
                logger.debug("Terminated main process only")
        else:
            process.terminate()
            logger.debug("Terminated main process (Windows)")

        try:
            process.wait(timeout=10)
            logger.info("Frontend server stopped gracefully")
        except subprocess.TimeoutExpired:
            logger.warning("Frontend didn't terminate gracefully, forcing kill")
            # Force kill the process group
            if hasattr(os, "killpg"):
                try:
                    os.killpg(os.getpgid(process.pid), signal.SIGKILL)
                    logger.debug("Sent SIGKILL to process group")
                except (OSError, ProcessLookupError):
                    process.kill()
                    logger.debug("Force killed main process only")
            else:
                process.kill()
                logger.debug("Force killed main process (Windows)")
            process.wait()

        if success:
            logger.info("UI servers stopped successfully")
        return success

    except Exception as e:
        logger.error(f"Error stopping UI servers: {str(e)}")
        return False
# Convenience function similar to DuckDB's approach
# Convenience function similar to DuckDB's approach
def ui() -> Optional[subprocess.Popen]:
    """
    Convenient alias for start_ui() with default parameters.
    Similar to how DuckDB provides simple ui() function.

    Returns:
        The frontend server process, or None if startup failed.
    """
    # start_ui now requires pid_callback as its first positional argument;
    # calling start_ui() bare would raise TypeError. A no-op callback is
    # supplied since this convenience wrapper has nowhere to record PIDs.
    return start_ui(lambda pid: None)
if __name__ == "__main__":
    # Manual smoke test: start the UI stack and tear it down on Enter.
    # start_ui requires a pid_callback (first positional argument); calling
    # it bare would raise TypeError, so pass a no-op callback here.
    server = start_ui(lambda pid: None)
    if server:
        try:
            input("Press Enter to stop the server...")
        finally:
            stop_ui(server)

View file

@ -204,19 +204,27 @@ def main() -> int:
nonlocal spawned_pids
spawned_pids.append(pid)
frontend_port = 3000
start_backend, backend_port = True, 8000
start_mcp, mcp_port = True, 8001
server_process = start_ui(
host="localhost",
port=3000,
open_browser=True,
start_backend=True,
auto_download=True,
pid_callback=pid_callback,
port=frontend_port,
open_browser=True,
auto_download=True,
start_backend=start_backend,
backend_port=backend_port,
start_mcp=start_mcp,
mcp_port=mcp_port,
)
if server_process:
fmt.success("UI server started successfully!")
fmt.echo("The interface is available at: http://localhost:3000")
fmt.echo("The API backend is available at: http://localhost:8000")
fmt.echo(f"The interface is available at: http://localhost:{frontend_port}")
if start_backend:
fmt.echo(f"The API backend is available at: http://localhost:{backend_port}")
if start_mcp:
fmt.echo(f"The MCP server is available at: http://localhost:{mcp_port}")
fmt.note("Press Ctrl+C to stop the server...")
try:

View file

@ -34,6 +34,7 @@ class S3FileStorage(Storage):
self.s3 = s3fs.S3FileSystem(
key=s3_config.aws_access_key_id,
secret=s3_config.aws_secret_access_key,
token=s3_config.aws_session_token,
anon=False,
endpoint_url=s3_config.aws_endpoint_url,
client_kwargs={"region_name": s3_config.aws_region},

View file

@ -8,6 +8,7 @@ class S3Config(BaseSettings):
aws_endpoint_url: Optional[str] = None
aws_access_key_id: Optional[str] = None
aws_secret_access_key: Optional[str] = None
aws_session_token: Optional[str] = None
model_config = SettingsConfigDict(env_file=".env", extra="allow")

View file

@ -23,6 +23,9 @@ async def cognee_network_visualization(graph_data, destination_file_path: str =
"TableRow": "#f47710",
"TableType": "#6510f4",
"ColumnValue": "#13613a",
"SchemaTable": "#f47710",
"DatabaseSchema": "#6510f4",
"SchemaRelationship": "#13613a",
"default": "#D3D3D3",
}

View file

@ -4,16 +4,20 @@ from sqlalchemy import text
from cognee.infrastructure.databases.relational.get_migration_relational_engine import (
get_migration_relational_engine,
)
from cognee.infrastructure.databases.relational.config import get_migration_config
from cognee.tasks.storage.index_data_points import index_data_points
from cognee.tasks.storage.index_graph_edges import index_graph_edges
from cognee.tasks.schema.ingest_database_schema import ingest_database_schema
from cognee.modules.engine.models import TableRow, TableType, ColumnValue
logger = logging.getLogger(__name__)
async def migrate_relational_database(graph_db, schema, migrate_column_data=True):
async def migrate_relational_database(
graph_db, schema, migrate_column_data=True, schema_only=False
):
"""
Migrates data from a relational database into a graph database.
@ -26,11 +30,133 @@ async def migrate_relational_database(graph_db, schema, migrate_column_data=True
Both TableType and TableRow inherit from DataPoint to maintain consistency with Cognee data model.
"""
# Create a mapping of node_id to node objects for referencing in edge creation
if schema_only:
node_mapping, edge_mapping = await schema_only_ingestion(schema)
else:
node_mapping, edge_mapping = await complete_database_ingestion(schema, migrate_column_data)
def _remove_duplicate_edges(edge_mapping):
seen = set()
unique_original_shape = []
for tup in edge_mapping:
# We go through all the tuples in the edge_mapping and we only add unique tuples to the list
# To eliminate duplicate edges.
source_id, target_id, rel_name, rel_dict = tup
# We need to convert the dictionary to a frozenset to be able to compare values for it
rel_dict_hashable = frozenset(sorted(rel_dict.items()))
hashable_tup = (source_id, target_id, rel_name, rel_dict_hashable)
# We use the seen set to keep track of unique edges
if hashable_tup not in seen:
# A list that has frozensets elements instead of dictionaries is needed to be able to compare values
seen.add(hashable_tup)
# append the original tuple shape (with the dictionary) if it's the first time we see it
unique_original_shape.append(tup)
return unique_original_shape
# Add all nodes and edges to the graph
# NOTE: Nodes and edges have to be added in batch for speed optimization, Especially for NetworkX.
# If we'd create nodes and add them to graph in real time the process would take too long.
# Every node and edge added to NetworkX is saved to file which is very slow when not done in batches.
await graph_db.add_nodes(list(node_mapping.values()))
await graph_db.add_edges(_remove_duplicate_edges(edge_mapping))
# In these steps we calculate the vector embeddings of our nodes and edges and save them to vector database
# Cognee uses this information to perform searches on the knowledge graph.
await index_data_points(list(node_mapping.values()))
await index_graph_edges()
logger.info("Data successfully migrated from relational database to desired graph database.")
return await graph_db.get_graph_data()
async def schema_only_ingestion(schema):
    """
    Build graph nodes and edges describing only the database *schema*
    (tables and their relationships) — no per-row data.

    Args:
        schema: Database schema mapping as produced by the migration engine.

    Returns:
        Tuple of (node_mapping, edge_mapping) where node_mapping maps
        node id -> DataPoint and edge_mapping is a list of
        (source_id, target_id, relationship_name, properties) tuples.
    """
    node_mapping = {}
    edge_mapping = []

    # Calling the ingest_database_schema function to return DataPoint subclasses
    result = await ingest_database_schema(
        schema=schema,
        max_sample_rows=5,
    )

    database_schema = result["database_schema"]
    schema_tables = result["schema_tables"]
    schema_relationships = result["relationships"]

    # The database node anchors every table via "is_part_of" edges.
    database_node_id = database_schema.id
    node_mapping[database_node_id] = database_schema

    for table in schema_tables:
        table_node_id = table.id
        # Add TableSchema Datapoint as a node.
        node_mapping[table_node_id] = table

        edge_mapping.append(
            (
                table_node_id,
                database_node_id,
                "is_part_of",
                dict(
                    source_node_id=table_node_id,
                    target_node_id=database_node_id,
                    relationship_name="is_part_of",
                ),
            )
        )

    table_name_to_id = {t.name: t.id for t in schema_tables}

    for rel in schema_relationships:
        source_table_id = table_name_to_id.get(rel.source_table)
        target_table_id = table_name_to_id.get(rel.target_table)

        # Guard: a relationship may reference a table that was not ingested
        # (e.g. filtered out of the schema). Skip it rather than emitting
        # edges with None endpoints, which would corrupt the graph.
        if source_table_id is None or target_table_id is None:
            logger.warning(
                "Skipping relationship %s: unknown table(s) %r -> %r",
                rel.id,
                rel.source_table,
                rel.target_table,
            )
            continue

        relationship_id = rel.id
        # Add RelationshipTable DataPoint as a node.
        node_mapping[relationship_id] = rel

        # Table -> relationship node.
        edge_mapping.append(
            (
                source_table_id,
                relationship_id,
                "has_relationship",
                dict(
                    source_node_id=source_table_id,
                    target_node_id=relationship_id,
                    relationship_name=rel.relationship_type,
                ),
            )
        )
        # Relationship node -> target table.
        edge_mapping.append(
            (
                relationship_id,
                target_table_id,
                "has_relationship",
                dict(
                    source_node_id=relationship_id,
                    target_node_id=target_table_id,
                    relationship_name=rel.relationship_type,
                ),
            )
        )
        # Direct table-to-table shortcut edge labeled with the relationship type.
        edge_mapping.append(
            (
                source_table_id,
                target_table_id,
                rel.relationship_type,
                dict(
                    source_node_id=source_table_id,
                    target_node_id=target_table_id,
                    relationship_name=rel.relationship_type,
                ),
            )
        )

    return node_mapping, edge_mapping
async def complete_database_ingestion(schema, migrate_column_data):
engine = get_migration_relational_engine()
# Create a mapping of node_id to node objects for referencing in edge creation
node_mapping = {}
edge_mapping = []
async with engine.engine.begin() as cursor:
# First, create table type nodes for all tables
for table_name, details in schema.items():
@ -38,7 +164,7 @@ async def migrate_relational_database(graph_db, schema, migrate_column_data=True
table_node = TableType(
id=uuid5(NAMESPACE_OID, name=table_name),
name=table_name,
description=f"Table: {table_name}",
description=f'Relational database table with the following name: "{table_name}".',
)
# Add TableType node to mapping ( node will be added to the graph later based on this mapping )
@ -75,7 +201,7 @@ async def migrate_relational_database(graph_db, schema, migrate_column_data=True
name=node_id,
is_a=table_node,
properties=str(row_properties),
description=f"Row in {table_name} with {primary_key_col}={primary_key_value}",
description=f'Row in relational database table from the table with the name: "{table_name}" with the following row data {str(row_properties)} where the dictionary key value is the column name and the value is the column value. This row has the id of: {node_id}',
)
# Store the node object in our mapping
@ -113,7 +239,7 @@ async def migrate_relational_database(graph_db, schema, migrate_column_data=True
id=uuid5(NAMESPACE_OID, name=column_node_id),
name=column_node_id,
properties=f"{key} {value} {table_name}",
description=f"Column name={key} and value={value} from column from table={table_name}",
description=f"column from relational database table={table_name}. Column name={key} and value={value}. The value of the column is related to the following row with this id: {row_node.id}. This column has the following ID: {column_node_id}",
)
node_mapping[column_node_id] = column_node
@ -180,39 +306,4 @@ async def migrate_relational_database(graph_db, schema, migrate_column_data=True
),
)
)
def _remove_duplicate_edges(edge_mapping):
seen = set()
unique_original_shape = []
for tup in edge_mapping:
# We go through all the tuples in the edge_mapping and we only add unique tuples to the list
# To eliminate duplicate edges.
source_id, target_id, rel_name, rel_dict = tup
# We need to convert the dictionary to a frozenset to be able to compare values for it
rel_dict_hashable = frozenset(sorted(rel_dict.items()))
hashable_tup = (source_id, target_id, rel_name, rel_dict_hashable)
# We use the seen set to keep track of unique edges
if hashable_tup not in seen:
# A list that has frozensets elements instead of dictionaries is needed to be able to compare values
seen.add(hashable_tup)
# append the original tuple shape (with the dictionary) if it's the first time we see it
unique_original_shape.append(tup)
return unique_original_shape
# Add all nodes and edges to the graph
# NOTE: Nodes and edges have to be added in batch for speed optimization, Especially for NetworkX.
# If we'd create nodes and add them to graph in real time the process would take too long.
# Every node and edge added to NetworkX is saved to file which is very slow when not done in batches.
await graph_db.add_nodes(list(node_mapping.values()))
await graph_db.add_edges(_remove_duplicate_edges(edge_mapping))
# In these steps we calculate the vector embeddings of our nodes and edges and save them to vector database
# Cognee uses this information to perform searches on the knowledge graph.
await index_data_points(list(node_mapping.values()))
await index_graph_edges()
logger.info("Data successfully migrated from relational database to desired graph database.")
return await graph_db.get_graph_data()
return node_mapping, edge_mapping

View file

@ -32,7 +32,10 @@ async def resolve_data_directories(
import s3fs
fs = s3fs.S3FileSystem(
key=s3_config.aws_access_key_id, secret=s3_config.aws_secret_access_key, anon=False
key=s3_config.aws_access_key_id,
secret=s3_config.aws_secret_access_key,
token=s3_config.aws_session_token,
anon=False,
)
for item in data:

View file

@ -0,0 +1,134 @@
from typing import List, Dict
from uuid import uuid5, NAMESPACE_OID
from cognee.infrastructure.engine.models.DataPoint import DataPoint
from sqlalchemy import text
from cognee.tasks.schema.models import DatabaseSchema, SchemaTable, SchemaRelationship
from cognee.infrastructure.databases.relational.get_migration_relational_engine import (
get_migration_relational_engine,
)
from cognee.infrastructure.databases.relational.config import get_migration_config
from datetime import datetime, timezone
async def ingest_database_schema(
    schema,
    max_sample_rows: int = 0,
) -> Dict[str, List[DataPoint] | DataPoint]:
    """
    Extract database schema metadata (optionally with sample data) and return
    DataPoint models for graph construction.

    Args:
        schema: Database schema mapping table names to their details
            (columns, primary_key, foreign_keys), as produced by the
            migration engine's schema extraction.
        max_sample_rows: Maximum sample rows per table (0 means no sampling).
            Negative or non-integer values are coerced to 0.

    Returns:
        Dict with keys:
            "database_schema": DatabaseSchema
            "schema_tables": List[SchemaTable]
            "relationships": List[SchemaRelationship]
    """
    tables = {}
    sample_data = {}
    schema_tables = []
    schema_relationships = []

    migration_config = get_migration_config()
    engine = get_migration_relational_engine()
    # Dialect-aware identifier quoting so table names can be interpolated safely.
    qi = engine.engine.dialect.identifier_preparer.quote

    # Defensively coerce max_sample_rows; callers may pass None or strings.
    try:
        max_sample_rows = max(0, int(max_sample_rows))
    except (TypeError, ValueError):
        max_sample_rows = 0

    def qname(name: str):
        # Quote each dot-separated part: "schema.table" -> "schema"."table".
        split_name = name.split(".")
        return ".".join(qi(p) for p in split_name)

    async with engine.engine.begin() as cursor:
        for table_name, details in schema.items():
            tn = qname(table_name)

            if max_sample_rows > 0:
                rows_result = await cursor.execute(
                    text(f"SELECT * FROM {tn} LIMIT :limit;"),  # noqa: S608 - tn is fully quoted
                    {"limit": max_sample_rows},
                )
                rows = [dict(r) for r in rows_result.mappings().all()]
            else:
                rows = []

            if engine.engine.dialect.name == "postgresql":
                # Use the planner's statistics (pg_class.reltuples) for a cheap
                # row-count estimate instead of a potentially expensive COUNT(*).
                if "." in table_name:
                    schema_part, table_part = table_name.split(".", 1)
                else:
                    schema_part, table_part = "public", table_name
                estimate = await cursor.execute(
                    text(
                        "SELECT reltuples::bigint AS estimate "
                        "FROM pg_class c "
                        "JOIN pg_namespace n ON n.oid = c.relnamespace "
                        "WHERE n.nspname = :schema AND c.relname = :table"
                    ),
                    {"schema": schema_part, "table": table_part},
                )
                row_count_estimate = estimate.scalar() or 0
            else:
                count_result = await cursor.execute(text(f"SELECT COUNT(*) FROM {tn};"))  # noqa: S608 - tn is fully quoted
                row_count_estimate = count_result.scalar()

            schema_table = SchemaTable(
                id=uuid5(NAMESPACE_OID, name=f"{table_name}"),
                name=table_name,
                columns=details["columns"],
                primary_key=details.get("primary_key"),
                foreign_keys=details.get("foreign_keys", []),
                sample_rows=rows,
                row_count_estimate=row_count_estimate,
                # Trailing spaces keep the concatenated fragments from running together.
                description=f"Relational database table '{table_name}' with {len(details['columns'])} columns and approx. {row_count_estimate} rows. "
                f"Here are the columns this table contains: {details['columns']} "
                f"Here are a few sample_rows to show the contents of the table: {rows} "
                f"Table is part of the database: {migration_config.migration_db_name}",
            )
            schema_tables.append(schema_table)

            tables[table_name] = details
            sample_data[table_name] = rows

            for fk in details.get("foreign_keys", []):
                ref_table_fq = fk["ref_table"]
                # Qualify the referenced table with the source table's schema
                # when the FK target is unqualified (e.g. Postgres schemas).
                if "." not in ref_table_fq and "." in table_name:
                    ref_table_fq = f"{table_name.split('.', 1)[0]}.{ref_table_fq}"
                relationship_name = (
                    f"{table_name}:{fk['column']}->{ref_table_fq}:{fk['ref_column']}"
                )
                relationship = SchemaRelationship(
                    id=uuid5(NAMESPACE_OID, name=relationship_name),
                    name=relationship_name,
                    source_table=table_name,
                    target_table=ref_table_fq,
                    relationship_type="foreign_key",
                    source_column=fk["column"],
                    target_column=fk["ref_column"],
                    description=f"Relational database table foreign key relationship between: {table_name}.{fk['column']} -> {ref_table_fq}.{fk['ref_column']} "
                    f"This foreign key relationship between table columns is a part of the following database: {migration_config.migration_db_name}",
                )
                schema_relationships.append(relationship)

    id_str = f"{migration_config.migration_db_provider}:{migration_config.migration_db_name}"
    database_schema = DatabaseSchema(
        id=uuid5(NAMESPACE_OID, name=id_str),
        name=migration_config.migration_db_name,
        database_type=migration_config.migration_db_provider,
        tables=tables,
        sample_data=sample_data,
        extraction_timestamp=datetime.now(timezone.utc),
        description=f"Database schema containing {len(schema_tables)} tables and {len(schema_relationships)} relationships. "
        f"The database type is {migration_config.migration_db_provider}. "
        f"The database contains the following tables: {tables}",
    )

    return {
        "database_schema": database_schema,
        "schema_tables": schema_tables,
        "relationships": schema_relationships,
    }

View file

@ -0,0 +1,41 @@
from cognee.infrastructure.engine.models.DataPoint import DataPoint
from typing import List, Dict, Optional
from datetime import datetime
class DatabaseSchema(DataPoint):
    """Represents a complete database schema with sample data"""

    # Database name (used to derive the node's deterministic UUID).
    name: str
    database_type: str  # sqlite, postgres, etc.
    tables: Dict[str, Dict]  # Reuse existing schema format from SqlAlchemyAdapter
    sample_data: Dict[str, List[Dict]]  # Limited examples per table
    # Timestamp of when this schema snapshot was extracted.
    extraction_timestamp: datetime
    # Free-text summary of the schema contents.
    description: str
    # NOTE(review): index_fields presumably selects the fields embedded for
    # vector search — confirm against DataPoint/index_data_points handling.
    metadata: dict = {"index_fields": ["description", "name"]}
class SchemaTable(DataPoint):
    """Represents an individual table schema with relationships"""

    # Table name; may carry a schema prefix (e.g. "public.users").
    name: str
    columns: List[Dict]  # Column definitions with types
    primary_key: Optional[str]
    foreign_keys: List[Dict]  # Foreign key relationships
    sample_rows: List[Dict]  # Max 3-5 example rows
    row_count_estimate: Optional[int]  # Actual table size
    # Free-text summary of the table (columns, sample rows, parent database).
    description: str
    # Fields selected for indexing; see DataPoint for how this is consumed.
    metadata: dict = {"index_fields": ["description", "name"]}
class SchemaRelationship(DataPoint):
    """Represents relationships between tables"""

    # Relationship identifier, e.g. "table:col->ref_table:ref_col".
    name: str
    source_table: str
    target_table: str
    relationship_type: str  # "foreign_key", "one_to_many", etc.
    source_column: str
    target_column: str
    # Free-text summary of the relationship.
    description: str
    # Fields selected for indexing; see DataPoint for how this is consumed.
    metadata: dict = {"index_fields": ["description", "name"]}

View file

@ -197,6 +197,80 @@ async def relational_db_migration():
print(f"All checks passed for {graph_db_provider} provider with '{relationship_label}' edges!")
async def test_schema_only_migration():
    """
    End-to-end check of schema-only migration: migrate only the schema graph
    (tables and their relationships, no row data), then validate the table
    count via graph search and the per-relationship edge counts for whichever
    graph provider is configured (neo4j, kuzu, or networkx).
    """
    # 1. Setup test DB and extract schema
    migration_engine = await setup_test_db()
    schema = await migration_engine.extract_schema()

    # 2. Setup graph engine
    graph_engine = await get_graph_engine()

    # 3. Migrate schema only
    await migrate_relational_database(graph_engine, schema=schema, schema_only=True)

    # 4. Verify number of tables through search
    search_results = await cognee.search(
        query_text="How many tables are there in this database",
        query_type=cognee.SearchType.GRAPH_COMPLETION,
        top_k=30,
    )
    assert any("11" in r for r in search_results), (
        "Number of tables in the database reported in search_results is either None or not equal to 11"
    )

    # 5. Count edges per relationship type with a provider-specific query
    graph_db_provider = os.getenv("GRAPH_DATABASE_PROVIDER", "networkx").lower()
    edge_counts = {
        "is_part_of": 0,
        "has_relationship": 0,
        "foreign_key": 0,
    }
    if graph_db_provider == "neo4j":
        # Neo4j encodes the relationship name as the edge type.
        for rel_type in edge_counts.keys():
            query_str = f"""
            MATCH ()-[r:{rel_type}]->()
            RETURN count(r) as c
            """
            rows = await graph_engine.query(query_str)
            edge_counts[rel_type] = rows[0]["c"]
    elif graph_db_provider == "kuzu":
        # Kuzu stores all edges in one EDGE table with a relationship_name property.
        for rel_type in edge_counts.keys():
            query_str = f"""
            MATCH ()-[r:EDGE]->()
            WHERE r.relationship_name = '{rel_type}'
            RETURN count(r) as c
            """
            rows = await graph_engine.query(query_str)
            edge_counts[rel_type] = rows[0][0]
    elif graph_db_provider == "networkx":
        # NOTE(review): edges appear to be (source, target, key, data) tuples
        # where key is the relationship name — verify against get_graph_data().
        nodes, edges = await graph_engine.get_graph_data()
        for _, _, key, _ in edges:
            if key in edge_counts:
                edge_counts[key] += 1
    else:
        raise ValueError(f"Unsupported graph database provider: {graph_db_provider}")

    # 6. Assert counts match expected values
    expected_counts = {
        "is_part_of": 11,
        "has_relationship": 22,
        "foreign_key": 11,
    }
    for rel_type, expected in expected_counts.items():
        actual = edge_counts[rel_type]
        assert actual == expected, (
            f"Expected {expected} edges for relationship '{rel_type}', but found {actual}"
        )

    print("Schema-only migration edge counts validated successfully!")
    print(f"Edge counts: {edge_counts}")
async def test_migration_sqlite():
database_to_migrate_path = os.path.join(pathlib.Path(__file__).parent, "test_data/")
@ -209,6 +283,7 @@ async def test_migration_sqlite():
)
await relational_db_migration()
await test_schema_only_migration()
async def test_migration_postgres():
@ -224,6 +299,7 @@ async def test_migration_postgres():
}
)
await relational_db_migration()
await test_schema_only_migration()
async def main():

View file

@ -1,16 +1,15 @@
from pathlib import Path
import asyncio
import cognee
import os
import cognee
from cognee.infrastructure.databases.relational.config import get_migration_config
from cognee.infrastructure.databases.graph import get_graph_engine
from cognee.api.v1.visualize.visualize import visualize_graph
from cognee.infrastructure.databases.relational import (
get_migration_relational_engine,
)
from cognee.modules.search.types import SearchType
from cognee.infrastructure.databases.relational import (
create_db_and_tables as create_relational_db_and_tables,
)
@ -32,16 +31,29 @@ from cognee.infrastructure.databases.vector.pgvector import (
async def main():
engine = get_migration_relational_engine()
# Clean all data stored in Cognee
await cognee.prune.prune_data()
await cognee.prune.prune_system(metadata=True)
# Needed to create appropriate tables only on the Cognee side
# Needed to create appropriate database tables only on the Cognee side
await create_relational_db_and_tables()
await create_vector_db_and_tables()
# In case environment variables are not set use the example database from the Cognee repo
migration_db_provider = os.environ.get("MIGRATION_DB_PROVIDER", "sqlite")
migration_db_path = os.environ.get(
"MIGRATION_DB_PATH",
os.path.join(Path(__file__).resolve().parent.parent.parent, "cognee/tests/test_data"),
)
migration_db_name = os.environ.get("MIGRATION_DB_NAME", "migration_database.sqlite")
migration_config = get_migration_config()
migration_config.migration_db_provider = migration_db_provider
migration_config.migration_db_path = migration_db_path
migration_config.migration_db_name = migration_db_name
engine = get_migration_relational_engine()
print("\nExtracting schema of database to migrate.")
schema = await engine.extract_schema()
print(f"Migrated database schema:\n{schema}")
@ -53,10 +65,6 @@ async def main():
await migrate_relational_database(graph, schema=schema)
print("Relational database migration complete.")
# Define location where to store html visualization of graph of the migrated database
home_dir = os.path.expanduser("~")
destination_file_path = os.path.join(home_dir, "graph_visualization.html")
# Make sure to set top_k at a high value for a broader search, the default value is only 10!
# top_k represent the number of graph tripplets to supply to the LLM to answer your question
search_results = await cognee.search(
@ -69,13 +77,25 @@ async def main():
# Having a top_k value set to too high might overwhelm the LLM context when specific questions need to be answered.
# For this kind of question we've set the top_k to 30
search_results = await cognee.search(
query_type=SearchType.GRAPH_COMPLETION_COT,
query_type=SearchType.GRAPH_COMPLETION,
query_text="What invoices are related to Leonie Köhler?",
top_k=30,
)
print(f"Search results: {search_results}")
# test.html is a file with visualized data migration
search_results = await cognee.search(
query_type=SearchType.GRAPH_COMPLETION,
query_text="What invoices are related to Luís Gonçalves?",
top_k=30,
)
print(f"Search results: {search_results}")
# If you check the relational database for this example you can see that the search results successfully found all
# the invoices related to the two customers, without any hallucinations or additional information
# Define location where to store html visualization of graph of the migrated database
home_dir = os.path.expanduser("~")
destination_file_path = os.path.join(home_dir, "graph_visualization.html")
print("Adding html visualization of graph database after migration.")
await visualize_graph(destination_file_path)
print(f"Visualization can be found at: {destination_file_path}")

View file

@ -29,8 +29,11 @@ async def main():
print("=" * 60)
# Start the UI server
def dummy_callback(pid):
pass
server = cognee.start_ui(
host="localhost",
pid_callback=dummy_callback,
port=3000,
open_browser=True, # This will automatically open your browser
)