Merge branch 'dev' into merge-main-vol6
Commit 7403e31738
21 changed files with 728 additions and 212 deletions
@@ -203,6 +203,16 @@ LITELLM_LOG="ERROR"
 # DEFAULT_USER_EMAIL=""
 # DEFAULT_USER_PASSWORD=""
 
+################################################################################
+# 📂 AWS Settings
+################################################################################
+
+#AWS_REGION=""
+#AWS_ENDPOINT_URL=""
+#AWS_ACCESS_KEY_ID=""
+#AWS_SECRET_ACCESS_KEY=""
+#AWS_SESSION_TOKEN=""
+
 ------------------------------- END OF POSSIBLE SETTINGS -------------------------------
 
 
@@ -2,10 +2,11 @@
 
 import Link from "next/link";
 import Image from "next/image";
-import { useBoolean } from "@/utils";
+import { useEffect } from "react";
+import { useBoolean, fetch } from "@/utils";
 
 import { CloseIcon, CloudIcon, CogneeIcon } from "../Icons";
-import { CTAButton, GhostButton, IconButton, Modal } from "../elements";
+import { CTAButton, GhostButton, IconButton, Modal, StatusDot } from "../elements";
 import syncData from "@/modules/cloud/syncData";
 
 interface HeaderProps {
@@ -23,6 +24,12 @@ export default function Header({ user }: HeaderProps) {
     setFalse: closeSyncModal,
   } = useBoolean(false);
 
+  const {
+    value: isMCPConnected,
+    setTrue: setMCPConnected,
+    setFalse: setMCPDisconnected,
+  } = useBoolean(false);
+
   const handleDataSyncConfirm = () => {
     syncData()
       .finally(() => {
@@ -30,6 +37,19 @@ export default function Header({ user }: HeaderProps) {
       });
   };
 
+  useEffect(() => {
+    const checkMCPConnection = () => {
+      fetch.checkMCPHealth()
+        .then(() => setMCPConnected())
+        .catch(() => setMCPDisconnected());
+    };
+
+    checkMCPConnection();
+    const interval = setInterval(checkMCPConnection, 30000);
+
+    return () => clearInterval(interval);
+  }, [setMCPConnected, setMCPDisconnected]);
+
   return (
     <>
       <header className="relative flex flex-row h-14 min-h-14 px-5 items-center justify-between w-full max-w-[1920px] mx-auto">
@@ -39,6 +59,10 @@ export default function Header({ user }: HeaderProps) {
         </div>
 
         <div className="flex flex-row items-center gap-2.5">
+          <Link href="/mcp-status" className="!text-indigo-600 pl-4 pr-4">
+            <StatusDot className="mr-2" isActive={isMCPConnected} />
+            { isMCPConnected ? "MCP connected" : "MCP disconnected" }
+          </Link>
           <GhostButton onClick={openSyncModal} className="text-indigo-600 gap-3 pl-4 pr-4">
             <CloudIcon />
             <div>Sync</div>
cognee-frontend/src/ui/elements/StatusDot.tsx (Normal file, 13 lines)
@@ -0,0 +1,13 @@
+import React from "react";
+
+const StatusDot = ({ isActive, className }: { isActive: boolean, className?: string }) => {
+  return (
+    <span
+      className={`inline-block w-3 h-3 rounded-full ${className} ${
+        isActive ? "bg-green-500" : "bg-red-500"
+      }`}
+    />
+  );
+};
+
+export default StatusDot;
@@ -8,5 +8,6 @@ export { default as IconButton } from "./IconButton";
 export { default as GhostButton } from "./GhostButton";
 export { default as NeutralButton } from "./NeutralButton";
 export { default as StatusIndicator } from "./StatusIndicator";
+export { default as StatusDot } from "./StatusDot";
 export { default as Accordion } from "./Accordion";
 export { default as Notebook } from "./Notebook";
@@ -9,6 +9,8 @@ const backendApiUrl = process.env.NEXT_PUBLIC_BACKEND_API_URL || "http://localho
 
 const cloudApiUrl = process.env.NEXT_PUBLIC_CLOUD_API_URL || "http://localhost:8001";
 
+const mcpApiUrl = process.env.NEXT_PUBLIC_MCP_API_URL || "http://localhost:8001";
+
 let apiKey: string | null = process.env.NEXT_PUBLIC_COGWIT_API_KEY || null;
 let accessToken: string | null = null;
 
@@ -66,6 +68,10 @@ fetch.checkHealth = () => {
   return global.fetch(`${backendApiUrl.replace("/api", "")}/health`);
 };
 
+fetch.checkMCPHealth = () => {
+  return global.fetch(`${mcpApiUrl.replace("/api", "")}/health`);
+};
+
 fetch.setApiKey = (newApiKey: string) => {
   apiKey = newApiKey;
 };
@@ -48,27 +48,27 @@ if [ "$ENVIRONMENT" = "dev" ] || [ "$ENVIRONMENT" = "local" ]; then
     if [ "$DEBUG" = "true" ]; then
         echo "Waiting for the debugger to attach..."
        if [ "$TRANSPORT_MODE" = "sse" ]; then
-            exec python -m debugpy --wait-for-client --listen 0.0.0.0:$DEBUG_PORT -m cognee --transport sse --host 0.0.0.0 --port $HTTP_PORT --no-migration
+            exec python -m debugpy --wait-for-client --listen 0.0.0.0:$DEBUG_PORT -m cognee-mcp --transport sse --host 0.0.0.0 --port $HTTP_PORT --no-migration
         elif [ "$TRANSPORT_MODE" = "http" ]; then
-            exec python -m debugpy --wait-for-client --listen 0.0.0.0:$DEBUG_PORT -m cognee --transport http --host 0.0.0.0 --port $HTTP_PORT --no-migration
+            exec python -m debugpy --wait-for-client --listen 0.0.0.0:$DEBUG_PORT -m cognee-mcp --transport http --host 0.0.0.0 --port $HTTP_PORT --no-migration
         else
-            exec python -m debugpy --wait-for-client --listen 0.0.0.0:$DEBUG_PORT -m cognee --transport stdio --no-migration
+            exec python -m debugpy --wait-for-client --listen 0.0.0.0:$DEBUG_PORT -m cognee-mcp --transport stdio --no-migration
         fi
     else
         if [ "$TRANSPORT_MODE" = "sse" ]; then
-            exec cognee --transport sse --host 0.0.0.0 --port $HTTP_PORT --no-migration
+            exec cognee-mcp --transport sse --host 0.0.0.0 --port $HTTP_PORT --no-migration
         elif [ "$TRANSPORT_MODE" = "http" ]; then
-            exec cognee --transport http --host 0.0.0.0 --port $HTTP_PORT --no-migration
+            exec cognee-mcp --transport http --host 0.0.0.0 --port $HTTP_PORT --no-migration
         else
-            exec cognee --transport stdio --no-migration
+            exec cognee-mcp --transport stdio --no-migration
         fi
     fi
 else
     if [ "$TRANSPORT_MODE" = "sse" ]; then
-        exec cognee --transport sse --host 0.0.0.0 --port $HTTP_PORT --no-migration
+        exec cognee-mcp --transport sse --host 0.0.0.0 --port $HTTP_PORT --no-migration
     elif [ "$TRANSPORT_MODE" = "http" ]; then
-        exec cognee --transport http --host 0.0.0.0 --port $HTTP_PORT --no-migration
+        exec cognee-mcp --transport http --host 0.0.0.0 --port $HTTP_PORT --no-migration
     else
-        exec cognee --transport stdio --no-migration
+        exec cognee-mcp --transport stdio --no-migration
     fi
 fi
@@ -36,4 +36,4 @@ dev = [
 allow-direct-references = true
 
 [project.scripts]
-cognee = "src:main"
+cognee-mcp = "src:main"
@@ -19,6 +19,10 @@ from cognee.api.v1.cognify.code_graph_pipeline import run_code_graph_pipeline
 from cognee.modules.search.types import SearchType
 from cognee.shared.data_models import KnowledgeGraph
 from cognee.modules.storage.utils import JSONEncoder
+from starlette.responses import JSONResponse
+from starlette.middleware import Middleware
+from starlette.middleware.cors import CORSMiddleware
+import uvicorn
 
 
 try:
@@ -38,6 +42,53 @@ mcp = FastMCP("Cognee")
 logger = get_logger()
 
 
+async def run_sse_with_cors():
+    """Custom SSE transport with CORS middleware."""
+    sse_app = mcp.sse_app()
+    sse_app.add_middleware(
+        CORSMiddleware,
+        allow_origins=["http://localhost:3000"],
+        allow_credentials=True,
+        allow_methods=["GET"],
+        allow_headers=["*"],
+    )
+
+    config = uvicorn.Config(
+        sse_app,
+        host=mcp.settings.host,
+        port=mcp.settings.port,
+        log_level=mcp.settings.log_level.lower(),
+    )
+    server = uvicorn.Server(config)
+    await server.serve()
+
+
+async def run_http_with_cors():
+    """Custom HTTP transport with CORS middleware."""
+    http_app = mcp.streamable_http_app()
+    http_app.add_middleware(
+        CORSMiddleware,
+        allow_origins=["http://localhost:3000"],
+        allow_credentials=True,
+        allow_methods=["GET"],
+        allow_headers=["*"],
+    )
+
+    config = uvicorn.Config(
+        http_app,
+        host=mcp.settings.host,
+        port=mcp.settings.port,
+        log_level=mcp.settings.log_level.lower(),
+    )
+    server = uvicorn.Server(config)
+    await server.serve()
+
+
+@mcp.custom_route("/health", methods=["GET"])
+async def health_check(request):
+    return JSONResponse({"status": "ok"})
+
+
 @mcp.tool()
 async def cognee_add_developer_rules(
     base_path: str = ".", graph_model_file: str = None, graph_model_name: str = None
@@ -975,12 +1026,12 @@ async def main():
         await mcp.run_stdio_async()
     elif args.transport == "sse":
         logger.info(f"Running MCP server with SSE transport on {args.host}:{args.port}")
-        await mcp.run_sse_async()
+        await run_sse_with_cors()
     elif args.transport == "http":
         logger.info(
             f"Running MCP server with Streamable HTTP transport on {args.host}:{args.port}{args.path}"
         )
-        await mcp.run_streamable_http_async()
+        await run_http_with_cors()
 
 
 if __name__ == "__main__":
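The /health route added above is what the frontend's 30-second MCP status poll (fetch.checkMCPHealth) talks to. A minimal, self-contained Python sketch of the same kind of poll follows; the URL, timeout, and interval are assumptions based on the defaults used elsewhere in this change, not part of the diff.

    import time

    import requests  # assumed HTTP client; any client works


    MCP_HEALTH_URL = "http://localhost:8001/health"  # assumed default MCP port from this change set


    def poll_mcp_health(interval_seconds: int = 30) -> None:
        """Poll the MCP /health route and report connection state, mirroring the UI's 30 s check."""
        while True:
            try:
                response = requests.get(MCP_HEALTH_URL, timeout=5)
                connected = response.ok and response.json().get("status") == "ok"
            except requests.RequestException:
                connected = False
            print("MCP connected" if connected else "MCP disconnected")
            time.sleep(interval_seconds)


    if __name__ == "__main__":
        poll_mcp_health()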
@@ -1 +1 @@
-from .ui import start_ui, stop_ui, ui
+from .ui import start_ui
@@ -1,5 +1,5 @@
 import os
-import signal
+import socket
 import subprocess
 import threading
 import time
@@ -7,7 +7,7 @@ import webbrowser
 import zipfile
 import requests
 from pathlib import Path
-from typing import Callable, Optional, Tuple
+from typing import Callable, Optional, Tuple, List
 import tempfile
 import shutil
 
@@ -17,6 +17,80 @@ from cognee.version import get_cognee_version
 logger = get_logger()
 
 
+def _stream_process_output(
+    process: subprocess.Popen, stream_name: str, prefix: str, color_code: str = ""
+) -> threading.Thread:
+    """
+    Stream output from a process with a prefix to identify the source.
+
+    Args:
+        process: The subprocess to monitor
+        stream_name: 'stdout' or 'stderr'
+        prefix: Text prefix for each line (e.g., '[BACKEND]', '[FRONTEND]')
+        color_code: ANSI color code for the prefix (optional)
+
+    Returns:
+        Thread that handles the streaming
+    """
+
+    def stream_reader():
+        stream = getattr(process, stream_name)
+        if stream is None:
+            return
+
+        reset_code = "\033[0m" if color_code else ""
+
+        try:
+            for line in iter(stream.readline, b""):
+                if line:
+                    line_text = line.decode("utf-8").rstrip()
+                    if line_text:
+                        print(f"{color_code}{prefix}{reset_code} {line_text}", flush=True)
+        except Exception:
+            pass
+        finally:
+            if stream:
+                stream.close()
+
+    thread = threading.Thread(target=stream_reader, daemon=True)
+    thread.start()
+    return thread
+
+
+def _is_port_available(port: int) -> bool:
+    """
+    Check if a port is available on localhost.
+    Returns True if the port is available, False otherwise.
+    """
+    try:
+        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
+            sock.settimeout(1)  # 1 second timeout
+            result = sock.connect_ex(("localhost", port))
+            return result != 0  # Port is available if connection fails
+    except Exception:
+        return False
+
+
+def _check_required_ports(ports_to_check: List[Tuple[int, str]]) -> Tuple[bool, List[str]]:
+    """
+    Check if all required ports are available on localhost.
+
+    Args:
+        ports_to_check: List of (port, service_name) tuples
+
+    Returns:
+        Tuple of (all_available: bool, unavailable_services: List[str])
+    """
+    unavailable = []
+
+    for port, service_name in ports_to_check:
+        if not _is_port_available(port):
+            unavailable.append(f"{service_name} (port {port})")
+            logger.error(f"Port {port} is already in use for {service_name}")
+
+    return len(unavailable) == 0, unavailable
+
+
 def normalize_version_for_comparison(version: str) -> str:
     """
     Normalize version string for comparison.
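The port pre-check above comes down to connect_ex returning a non-zero code when nothing is listening. A self-contained sketch of the same check, run against the default ports used elsewhere in this change (the port/service pairs are assumptions):

    import socket


    def port_is_free(port: int, host: str = "localhost") -> bool:
        """connect_ex returns a non-zero error code when nothing is listening, i.e. the port is free."""
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
            sock.settimeout(1)
            return sock.connect_ex((host, port)) != 0


    for port, service in [(3000, "Frontend UI"), (8000, "Backend API"), (8001, "MCP Server")]:
        state = "free" if port_is_free(port) else "already in use"
        print(f"{service} (port {port}): {state}")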
@@ -327,55 +401,111 @@ def prompt_user_for_download() -> bool:
 
 def start_ui(
     pid_callback: Callable[[int], None],
-    host: str = "localhost",
     port: int = 3000,
     open_browser: bool = True,
     auto_download: bool = False,
     start_backend: bool = False,
-    backend_host: str = "localhost",
     backend_port: int = 8000,
+    start_mcp: bool = False,
+    mcp_port: int = 8001,
 ) -> Optional[subprocess.Popen]:
     """
-    Start the cognee frontend UI server, optionally with the backend API server.
+    Start the cognee frontend UI server, optionally with the backend API server and MCP server.
 
     This function will:
     1. Optionally start the cognee backend API server
-    2. Find the cognee-frontend directory (development) or download it (pip install)
-    3. Check if Node.js and npm are available (for development mode)
-    4. Install dependencies if needed (development mode)
-    5. Start the frontend server
-    6. Optionally open the browser
+    2. Optionally start the cognee MCP server
+    3. Find the cognee-frontend directory (development) or download it (pip install)
+    4. Check if Node.js and npm are available (for development mode)
+    5. Install dependencies if needed (development mode)
+    6. Start the frontend server
+    7. Optionally open the browser
 
     Args:
         pid_callback: Callback to notify with PID of each spawned process
-        host: Host to bind the frontend server to (default: localhost)
         port: Port to run the frontend server on (default: 3000)
         open_browser: Whether to open the browser automatically (default: True)
         auto_download: If True, download frontend without prompting (default: False)
         start_backend: If True, also start the cognee API backend server (default: False)
-        backend_host: Host to bind the backend server to (default: localhost)
         backend_port: Port to run the backend server on (default: 8000)
+        start_mcp: If True, also start the cognee MCP server (default: False)
+        mcp_port: Port to run the MCP server on (default: 8001)
 
     Returns:
         subprocess.Popen object representing the running frontend server, or None if failed
-        Note: If backend is started, it runs in a separate process that will be cleaned up
-        when the frontend process is terminated.
+        Note: If backend and/or MCP server are started, they run in separate processes
+        that will be cleaned up when the frontend process is terminated.
 
     Example:
         >>> import cognee
+        >>> def dummy_callback(pid): pass
         >>> # Start just the frontend
-        >>> server = cognee.start_ui()
+        >>> server = cognee.start_ui(dummy_callback)
         >>>
         >>> # Start both frontend and backend
-        >>> server = cognee.start_ui(start_backend=True)
+        >>> server = cognee.start_ui(dummy_callback, start_backend=True)
         >>> # UI will be available at http://localhost:3000
         >>> # API will be available at http://localhost:8000
-        >>> # To stop both servers later:
+        >>>
+        >>> # Start frontend with MCP server
+        >>> server = cognee.start_ui(dummy_callback, start_mcp=True)
+        >>> # UI will be available at http://localhost:3000
+        >>> # MCP server will be available at http://127.0.0.1:8001/sse
+        >>> # To stop all servers later:
        >>> server.terminate()
     """
     logger.info("Starting cognee UI...")
 
+    ports_to_check = [(port, "Frontend UI")]
+
+    if start_backend:
+        ports_to_check.append((backend_port, "Backend API"))
+
+    if start_mcp:
+        ports_to_check.append((mcp_port, "MCP Server"))
+
+    logger.info("Checking port availability...")
+    all_ports_available, unavailable_services = _check_required_ports(ports_to_check)
+
+    if not all_ports_available:
+        error_msg = f"Cannot start cognee UI: The following services have ports already in use: {', '.join(unavailable_services)}"
+        logger.error(error_msg)
+        logger.error("Please stop the conflicting services or change the port configuration.")
+        return None
+
+    logger.info("✓ All required ports are available")
     backend_process = None
 
+    if start_mcp:
+        logger.info("Starting Cognee MCP server with Docker...")
+        cwd = os.getcwd()
+        env_file = os.path.join(cwd, ".env")
+        try:
+            mcp_process = subprocess.Popen(
+                [
+                    "docker",
+                    "run",
+                    "-p",
+                    f"{mcp_port}:8000",
+                    "--rm",
+                    "--env-file",
+                    env_file,
+                    "-e",
+                    "TRANSPORT_MODE=sse",
+                    "cognee/cognee-mcp:daulet-dev",
+                ],
+                stdout=subprocess.PIPE,
+                stderr=subprocess.PIPE,
+                preexec_fn=os.setsid if hasattr(os, "setsid") else None,
+            )
+
+            _stream_process_output(mcp_process, "stdout", "[MCP]", "\033[34m")  # Blue
+            _stream_process_output(mcp_process, "stderr", "[MCP]", "\033[34m")  # Blue
+
+            pid_callback(mcp_process.pid)
+            logger.info(f"✓ Cognee MCP server starting on http://127.0.0.1:{mcp_port}/sse")
+        except Exception as e:
+            logger.error(f"Failed to start MCP server with Docker: {str(e)}")
     # Start backend server if requested
     if start_backend:
         logger.info("Starting cognee backend API server...")
@@ -389,16 +519,19 @@ def start_ui(
                     "uvicorn",
                     "cognee.api.client:app",
                     "--host",
-                    backend_host,
+                    "localhost",
                     "--port",
                     str(backend_port),
                 ],
-                # Inherit stdout/stderr from parent process to show logs
-                stdout=None,
-                stderr=None,
+                stdout=subprocess.PIPE,
+                stderr=subprocess.PIPE,
                 preexec_fn=os.setsid if hasattr(os, "setsid") else None,
             )
 
+            # Start threads to stream backend output with prefix
+            _stream_process_output(backend_process, "stdout", "[BACKEND]", "\033[32m")  # Green
+            _stream_process_output(backend_process, "stderr", "[BACKEND]", "\033[32m")  # Green
+
             pid_callback(backend_process.pid)
 
             # Give the backend a moment to start
@@ -408,7 +541,7 @@ def start_ui(
                 logger.error("Backend server failed to start - process exited early")
                 return None
 
-            logger.info(f"✓ Backend API started at http://{backend_host}:{backend_port}")
+            logger.info(f"✓ Backend API started at http://localhost:{backend_port}")
 
         except Exception as e:
             logger.error(f"Failed to start backend server: {str(e)}")
@@ -453,11 +586,11 @@ def start_ui(
 
     # Prepare environment variables
     env = os.environ.copy()
-    env["HOST"] = host
+    env["HOST"] = "localhost"
     env["PORT"] = str(port)
 
     # Start the development server
-    logger.info(f"Starting frontend server at http://{host}:{port}")
+    logger.info(f"Starting frontend server at http://localhost:{port}")
     logger.info("This may take a moment to compile and start...")
 
     try:
@@ -468,10 +601,13 @@ def start_ui(
             env=env,
             stdout=subprocess.PIPE,
             stderr=subprocess.PIPE,
-            text=True,
             preexec_fn=os.setsid if hasattr(os, "setsid") else None,
         )
 
+        # Start threads to stream frontend output with prefix
+        _stream_process_output(process, "stdout", "[FRONTEND]", "\033[33m")  # Yellow
+        _stream_process_output(process, "stderr", "[FRONTEND]", "\033[33m")  # Yellow
+
         pid_callback(process.pid)
 
         # Give it a moment to start up
@@ -479,10 +615,7 @@ def start_ui(
 
         # Check if process is still running
         if process.poll() is not None:
-            stdout, stderr = process.communicate()
-            logger.error("Frontend server failed to start:")
-            logger.error(f"stdout: {stdout}")
-            logger.error(f"stderr: {stderr}")
+            logger.error("Frontend server failed to start - check the logs above for details")
             return None
 
         # Open browser if requested
@@ -491,7 +624,7 @@ def start_ui(
         def open_browser_delayed():
             time.sleep(5)  # Give Next.js time to fully start
             try:
-                webbrowser.open(f"http://{host}:{port}")  # TODO: use dashboard url?
+                webbrowser.open(f"http://localhost:{port}")
             except Exception as e:
                 logger.warning(f"Could not open browser automatically: {e}")
 
@@ -499,13 +632,9 @@ def start_ui(
             browser_thread.start()
 
         logger.info("✓ Cognee UI is starting up...")
-        logger.info(f"✓ Open your browser to: http://{host}:{port}")
+        logger.info(f"✓ Open your browser to: http://localhost:{port}")
         logger.info("✓ The UI will be available once Next.js finishes compiling")
 
-        # Store backend process reference in the frontend process for cleanup
-        if backend_process:
-            process._cognee_backend_process = backend_process
-
         return process
 
     except Exception as e:
@@ -523,102 +652,3 @@ def start_ui(
             except (OSError, ProcessLookupError):
                 pass
         return None
-
-
-def stop_ui(process: subprocess.Popen) -> bool:
-    """
-    Stop a running UI server process and backend process (if started), along with all their children.
-
-    Args:
-        process: The subprocess.Popen object returned by start_ui()
-
-    Returns:
-        bool: True if stopped successfully, False otherwise
-    """
-    if not process:
-        return False
-
-    success = True
-
-    try:
-        # First, stop the backend process if it exists
-        backend_process = getattr(process, "_cognee_backend_process", None)
-        if backend_process:
-            logger.info("Stopping backend server...")
-            try:
-                backend_process.terminate()
-                try:
-                    backend_process.wait(timeout=5)
-                    logger.info("Backend server stopped gracefully")
-                except subprocess.TimeoutExpired:
-                    logger.warning("Backend didn't terminate gracefully, forcing kill")
-                    backend_process.kill()
-                    backend_process.wait()
-                    logger.info("Backend server stopped")
-            except Exception as e:
-                logger.error(f"Error stopping backend server: {str(e)}")
-                success = False
-
-        # Now stop the frontend process
-        logger.info("Stopping frontend server...")
-        # Try to terminate the process group (includes child processes like Next.js)
-        if hasattr(os, "killpg"):
-            try:
-                # Kill the entire process group
-                os.killpg(os.getpgid(process.pid), signal.SIGTERM)
-                logger.debug("Sent SIGTERM to process group")
-            except (OSError, ProcessLookupError):
-                # Fall back to terminating just the main process
-                process.terminate()
-                logger.debug("Terminated main process only")
-        else:
-            process.terminate()
-            logger.debug("Terminated main process (Windows)")
-
-        try:
-            process.wait(timeout=10)
-            logger.info("Frontend server stopped gracefully")
-        except subprocess.TimeoutExpired:
-            logger.warning("Frontend didn't terminate gracefully, forcing kill")
-
-            # Force kill the process group
-            if hasattr(os, "killpg"):
-                try:
-                    os.killpg(os.getpgid(process.pid), signal.SIGKILL)
-                    logger.debug("Sent SIGKILL to process group")
-                except (OSError, ProcessLookupError):
-                    process.kill()
-                    logger.debug("Force killed main process only")
-            else:
-                process.kill()
-                logger.debug("Force killed main process (Windows)")
-
-            process.wait()
-
-        if success:
-            logger.info("UI servers stopped successfully")
-
-        return success
-
-    except Exception as e:
-        logger.error(f"Error stopping UI servers: {str(e)}")
-        return False
-
-
-# Convenience function similar to DuckDB's approach
-def ui() -> Optional[subprocess.Popen]:
-    """
-    Convenient alias for start_ui() with default parameters.
-    Similar to how DuckDB provides simple ui() function.
-    """
-    return start_ui()
-
-
-if __name__ == "__main__":
-    # Test the UI startup
-    server = start_ui()
-    if server:
-        try:
-            input("Press Enter to stop the server...")
-        finally:
-            stop_ui(server)
@@ -204,19 +204,27 @@ def main() -> int:
         nonlocal spawned_pids
         spawned_pids.append(pid)
 
+    frontend_port = 3000
+    start_backend, backend_port = True, 8000
+    start_mcp, mcp_port = True, 8001
     server_process = start_ui(
-        host="localhost",
-        port=3000,
-        open_browser=True,
-        start_backend=True,
-        auto_download=True,
         pid_callback=pid_callback,
+        port=frontend_port,
+        open_browser=True,
+        auto_download=True,
+        start_backend=start_backend,
+        backend_port=backend_port,
+        start_mcp=start_mcp,
+        mcp_port=mcp_port,
     )
 
     if server_process:
         fmt.success("UI server started successfully!")
-        fmt.echo("The interface is available at: http://localhost:3000")
-        fmt.echo("The API backend is available at: http://localhost:8000")
+        fmt.echo(f"The interface is available at: http://localhost:{frontend_port}")
+        if start_backend:
+            fmt.echo(f"The API backend is available at: http://localhost:{backend_port}")
+        if start_mcp:
+            fmt.echo(f"The MCP server is available at: http://localhost:{mcp_port}")
         fmt.note("Press Ctrl+C to stop the server...")
 
         try:
@@ -34,6 +34,7 @@ class S3FileStorage(Storage):
         self.s3 = s3fs.S3FileSystem(
             key=s3_config.aws_access_key_id,
             secret=s3_config.aws_secret_access_key,
+            token=s3_config.aws_session_token,
             anon=False,
             endpoint_url=s3_config.aws_endpoint_url,
             client_kwargs={"region_name": s3_config.aws_region},
@@ -8,6 +8,7 @@ class S3Config(BaseSettings):
     aws_endpoint_url: Optional[str] = None
     aws_access_key_id: Optional[str] = None
     aws_secret_access_key: Optional[str] = None
+    aws_session_token: Optional[str] = None
     model_config = SettingsConfigDict(env_file=".env", extra="allow")
 
 
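With the new aws_session_token field, temporary STS credentials can flow from the .env entries added earlier in this change straight through to s3fs. A rough standalone sketch of the settings side; the class name, import path, and token value are placeholders/assumptions, and only the field names follow the diff:

    import os
    from typing import Optional

    from pydantic_settings import BaseSettings, SettingsConfigDict  # assumed pydantic v2 layout


    class S3ConfigSketch(BaseSettings):
        # Mirrors the fields shown in the diff; extra="allow" keeps unrelated .env keys from failing validation.
        aws_region: Optional[str] = None
        aws_endpoint_url: Optional[str] = None
        aws_access_key_id: Optional[str] = None
        aws_secret_access_key: Optional[str] = None
        aws_session_token: Optional[str] = None
        model_config = SettingsConfigDict(env_file=".env", extra="allow")


    os.environ["AWS_SESSION_TOKEN"] = "example-session-token"  # placeholder value
    print(S3ConfigSketch().aws_session_token)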
@@ -23,6 +23,9 @@ async def cognee_network_visualization(graph_data, destination_file_path: str =
         "TableRow": "#f47710",
         "TableType": "#6510f4",
         "ColumnValue": "#13613a",
+        "SchemaTable": "#f47710",
+        "DatabaseSchema": "#6510f4",
+        "SchemaRelationship": "#13613a",
         "default": "#D3D3D3",
     }
 
@@ -4,16 +4,20 @@ from sqlalchemy import text
 from cognee.infrastructure.databases.relational.get_migration_relational_engine import (
     get_migration_relational_engine,
 )
+from cognee.infrastructure.databases.relational.config import get_migration_config
 
 from cognee.tasks.storage.index_data_points import index_data_points
 from cognee.tasks.storage.index_graph_edges import index_graph_edges
+from cognee.tasks.schema.ingest_database_schema import ingest_database_schema
 
 from cognee.modules.engine.models import TableRow, TableType, ColumnValue
 
 logger = logging.getLogger(__name__)
 
 
-async def migrate_relational_database(graph_db, schema, migrate_column_data=True):
+async def migrate_relational_database(
+    graph_db, schema, migrate_column_data=True, schema_only=False
+):
     """
     Migrates data from a relational database into a graph database.
 
@@ -26,11 +30,133 @@ async def migrate_relational_database(graph_db, schema, migrate_column_data=True
 
     Both TableType and TableRow inherit from DataPoint to maintain consistency with Cognee data model.
     """
+    # Create a mapping of node_id to node objects for referencing in edge creation
+    if schema_only:
+        node_mapping, edge_mapping = await schema_only_ingestion(schema)
+
+    else:
+        node_mapping, edge_mapping = await complete_database_ingestion(schema, migrate_column_data)
+
+    def _remove_duplicate_edges(edge_mapping):
+        seen = set()
+        unique_original_shape = []
+
+        for tup in edge_mapping:
+            # We go through all the tuples in the edge_mapping and we only add unique tuples to the list
+            # To eliminate duplicate edges.
+            source_id, target_id, rel_name, rel_dict = tup
+            # We need to convert the dictionary to a frozenset to be able to compare values for it
+            rel_dict_hashable = frozenset(sorted(rel_dict.items()))
+            hashable_tup = (source_id, target_id, rel_name, rel_dict_hashable)
+
+            # We use the seen set to keep track of unique edges
+            if hashable_tup not in seen:
+                # A list that has frozensets elements instead of dictionaries is needed to be able to compare values
+                seen.add(hashable_tup)
+                # append the original tuple shape (with the dictionary) if it's the first time we see it
+                unique_original_shape.append(tup)
+
+        return unique_original_shape
+
+    # Add all nodes and edges to the graph
+    # NOTE: Nodes and edges have to be added in batch for speed optimization, Especially for NetworkX.
+    # If we'd create nodes and add them to graph in real time the process would take too long.
+    # Every node and edge added to NetworkX is saved to file which is very slow when not done in batches.
+    await graph_db.add_nodes(list(node_mapping.values()))
+    await graph_db.add_edges(_remove_duplicate_edges(edge_mapping))
+
+    # In these steps we calculate the vector embeddings of our nodes and edges and save them to vector database
+    # Cognee uses this information to perform searches on the knowledge graph.
+    await index_data_points(list(node_mapping.values()))
+    await index_graph_edges()
+
+    logger.info("Data successfully migrated from relational database to desired graph database.")
+    return await graph_db.get_graph_data()
+
+
+async def schema_only_ingestion(schema):
+    node_mapping = {}
+    edge_mapping = []
+
+    # Calling the ingest_database_schema function to return DataPoint subclasses
+    result = await ingest_database_schema(
+        schema=schema,
+        max_sample_rows=5,
+    )
+    database_schema = result["database_schema"]
+    schema_tables = result["schema_tables"]
+    schema_relationships = result["relationships"]
+    database_node_id = database_schema.id
+    node_mapping[database_node_id] = database_schema
+    for table in schema_tables:
+        table_node_id = table.id
+        # Add TableSchema Datapoint as a node.
+        node_mapping[table_node_id] = table
+        edge_mapping.append(
+            (
+                table_node_id,
+                database_node_id,
+                "is_part_of",
+                dict(
+                    source_node_id=table_node_id,
+                    target_node_id=database_node_id,
+                    relationship_name="is_part_of",
+                ),
+            )
+        )
+    table_name_to_id = {t.name: t.id for t in schema_tables}
+    for rel in schema_relationships:
+        source_table_id = table_name_to_id.get(rel.source_table)
+        target_table_id = table_name_to_id.get(rel.target_table)
+
+        relationship_id = rel.id
+
+        # Add RelationshipTable DataPoint as a node.
+        node_mapping[relationship_id] = rel
+        edge_mapping.append(
+            (
+                source_table_id,
+                relationship_id,
+                "has_relationship",
+                dict(
+                    source_node_id=source_table_id,
+                    target_node_id=relationship_id,
+                    relationship_name=rel.relationship_type,
+                ),
+            )
+        )
+        edge_mapping.append(
+            (
+                relationship_id,
+                target_table_id,
+                "has_relationship",
+                dict(
+                    source_node_id=relationship_id,
+                    target_node_id=target_table_id,
+                    relationship_name=rel.relationship_type,
+                ),
+            )
+        )
+        edge_mapping.append(
+            (
+                source_table_id,
+                target_table_id,
+                rel.relationship_type,
+                dict(
+                    source_node_id=source_table_id,
+                    target_node_id=target_table_id,
+                    relationship_name=rel.relationship_type,
+                ),
+            )
+        )
+    return node_mapping, edge_mapping
+
+
+async def complete_database_ingestion(schema, migrate_column_data):
     engine = get_migration_relational_engine()
     # Create a mapping of node_id to node objects for referencing in edge creation
     node_mapping = {}
     edge_mapping = []
 
     async with engine.engine.begin() as cursor:
         # First, create table type nodes for all tables
         for table_name, details in schema.items():
@@ -38,7 +164,7 @@ async def migrate_relational_database(graph_db, schema, migrate_column_data=True
             table_node = TableType(
                 id=uuid5(NAMESPACE_OID, name=table_name),
                 name=table_name,
-                description=f"Table: {table_name}",
+                description=f'Relational database table with the following name: "{table_name}".',
             )
 
             # Add TableType node to mapping ( node will be added to the graph later based on this mapping )
@@ -75,7 +201,7 @@ async def migrate_relational_database(graph_db, schema, migrate_column_data=True
                     name=node_id,
                     is_a=table_node,
                     properties=str(row_properties),
-                    description=f"Row in {table_name} with {primary_key_col}={primary_key_value}",
+                    description=f'Row in relational database table from the table with the name: "{table_name}" with the following row data {str(row_properties)} where the dictionary key value is the column name and the value is the column value. This row has the id of: {node_id}',
                 )
 
                 # Store the node object in our mapping
@@ -113,7 +239,7 @@ async def migrate_relational_database(graph_db, schema, migrate_column_data=True
                             id=uuid5(NAMESPACE_OID, name=column_node_id),
                             name=column_node_id,
                             properties=f"{key} {value} {table_name}",
-                            description=f"Column name={key} and value={value} from column from table={table_name}",
+                            description=f"column from relational database table={table_name}. Column name={key} and value={value}. The value of the column is related to the following row with this id: {row_node.id}. This column has the following ID: {column_node_id}",
                         )
                         node_mapping[column_node_id] = column_node
 
@@ -180,39 +306,4 @@ async def migrate_relational_database(graph_db, schema, migrate_column_data=True
                             ),
                         )
                     )
+    return node_mapping, edge_mapping
-
-    def _remove_duplicate_edges(edge_mapping):
-        seen = set()
-        unique_original_shape = []
-
-        for tup in edge_mapping:
-            # We go through all the tuples in the edge_mapping and we only add unique tuples to the list
-            # To eliminate duplicate edges.
-            source_id, target_id, rel_name, rel_dict = tup
-            # We need to convert the dictionary to a frozenset to be able to compare values for it
-            rel_dict_hashable = frozenset(sorted(rel_dict.items()))
-            hashable_tup = (source_id, target_id, rel_name, rel_dict_hashable)
-
-            # We use the seen set to keep track of unique edges
-            if hashable_tup not in seen:
-                # A list that has frozensets elements instead of dictionaries is needed to be able to compare values
-                seen.add(hashable_tup)
-                # append the original tuple shape (with the dictionary) if it's the first time we see it
-                unique_original_shape.append(tup)
-
-        return unique_original_shape
-
-    # Add all nodes and edges to the graph
-    # NOTE: Nodes and edges have to be added in batch for speed optimization, Especially for NetworkX.
-    # If we'd create nodes and add them to graph in real time the process would take too long.
-    # Every node and edge added to NetworkX is saved to file which is very slow when not done in batches.
-    await graph_db.add_nodes(list(node_mapping.values()))
-    await graph_db.add_edges(_remove_duplicate_edges(edge_mapping))
-
-    # In these steps we calculate the vector embeddings of our nodes and edges and save them to vector database
-    # Cognee uses this information to perform searches on the knowledge graph.
-    await index_data_points(list(node_mapping.values()))
-    await index_graph_edges()
-
-    logger.info("Data successfully migrated from relational database to desired graph database.")
-    return await graph_db.get_graph_data()
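The edge deduplication above relies on converting each relationship's property dict into a frozenset so the whole edge tuple becomes hashable. A self-contained sketch of the same idea, with made-up edge tuples:

    def remove_duplicate_edges(edge_mapping):
        # Same approach as the diff: hash (source, target, name, frozenset(props)) and keep first occurrences.
        seen = set()
        unique = []
        for source_id, target_id, rel_name, rel_dict in edge_mapping:
            key = (source_id, target_id, rel_name, frozenset(sorted(rel_dict.items())))
            if key not in seen:
                seen.add(key)
                unique.append((source_id, target_id, rel_name, rel_dict))
        return unique


    # Hypothetical edges: the exact duplicate is dropped, the one with a distinct property dict is kept.
    edges = [
        ("users", "orders", "foreign_key", {"relationship_name": "foreign_key"}),
        ("users", "orders", "foreign_key", {"relationship_name": "foreign_key"}),
        ("users", "orders", "foreign_key", {"relationship_name": "references"}),
    ]
    print(len(remove_duplicate_edges(edges)))  # 2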
@@ -32,7 +32,10 @@ async def resolve_data_directories(
         import s3fs
 
         fs = s3fs.S3FileSystem(
-            key=s3_config.aws_access_key_id, secret=s3_config.aws_secret_access_key, anon=False
+            key=s3_config.aws_access_key_id,
+            secret=s3_config.aws_secret_access_key,
+            token=s3_config.aws_session_token,
+            anon=False,
         )
 
         for item in data:
cognee/tasks/schema/ingest_database_schema.py (Normal file, 134 lines)
@@ -0,0 +1,134 @@
+from typing import List, Dict
+from uuid import uuid5, NAMESPACE_OID
+from cognee.infrastructure.engine.models.DataPoint import DataPoint
+from sqlalchemy import text
+from cognee.tasks.schema.models import DatabaseSchema, SchemaTable, SchemaRelationship
+from cognee.infrastructure.databases.relational.get_migration_relational_engine import (
+    get_migration_relational_engine,
+)
+from cognee.infrastructure.databases.relational.config import get_migration_config
+from datetime import datetime, timezone
+
+
+async def ingest_database_schema(
+    schema,
+    max_sample_rows: int = 0,
+) -> Dict[str, List[DataPoint] | DataPoint]:
+    """
+    Extract database schema metadata (optionally with sample data) and return DataPoint models for graph construction.
+
+    Args:
+        schema: Database schema
+        max_sample_rows: Maximum sample rows per table (0 means no sampling)
+
+    Returns:
+        Dict with keys:
+            "database_schema": DatabaseSchema
+            "schema_tables": List[SchemaTable]
+            "relationships": List[SchemaRelationship]
+    """
+
+    tables = {}
+    sample_data = {}
+    schema_tables = []
+    schema_relationships = []
+
+    migration_config = get_migration_config()
+    engine = get_migration_relational_engine()
+    qi = engine.engine.dialect.identifier_preparer.quote
+    try:
+        max_sample_rows = max(0, int(max_sample_rows))
+    except (TypeError, ValueError):
+        max_sample_rows = 0
+
+    def qname(name: str):
+        split_name = name.split(".")
+        return ".".join(qi(p) for p in split_name)
+
+    async with engine.engine.begin() as cursor:
+        for table_name, details in schema.items():
+            tn = qname(table_name)
+            if max_sample_rows > 0:
+                rows_result = await cursor.execute(
+                    text(f"SELECT * FROM {tn} LIMIT :limit;"),  # noqa: S608 - tn is fully quoted
+                    {"limit": max_sample_rows},
+                )
+                rows = [dict(r) for r in rows_result.mappings().all()]
+            else:
+                rows = []
+
+            if engine.engine.dialect.name == "postgresql":
+                if "." in table_name:
+                    schema_part, table_part = table_name.split(".", 1)
+                else:
+                    schema_part, table_part = "public", table_name
+                estimate = await cursor.execute(
+                    text(
+                        "SELECT reltuples::bigint AS estimate "
+                        "FROM pg_class c "
+                        "JOIN pg_namespace n ON n.oid = c.relnamespace "
+                        "WHERE n.nspname = :schema AND c.relname = :table"
+                    ),
+                    {"schema": schema_part, "table": table_part},
+                )
+                row_count_estimate = estimate.scalar() or 0
+            else:
+                count_result = await cursor.execute(text(f"SELECT COUNT(*) FROM {tn};"))  # noqa: S608 - tn is fully quoted
+                row_count_estimate = count_result.scalar()
+
+            schema_table = SchemaTable(
+                id=uuid5(NAMESPACE_OID, name=f"{table_name}"),
+                name=table_name,
+                columns=details["columns"],
+                primary_key=details.get("primary_key"),
+                foreign_keys=details.get("foreign_keys", []),
+                sample_rows=rows,
+                row_count_estimate=row_count_estimate,
+                description=f"Relational database table with '{table_name}' with {len(details['columns'])} columns and approx. {row_count_estimate} rows."
+                f"Here are the columns this table contains: {details['columns']}"
+                f"Here are a few sample_rows to show the contents of the table: {rows}"
+                f"Table is part of the database: {migration_config.migration_db_name}",
+            )
+            schema_tables.append(schema_table)
+            tables[table_name] = details
+            sample_data[table_name] = rows
+
+            for fk in details.get("foreign_keys", []):
+                ref_table_fq = fk["ref_table"]
+                if "." not in ref_table_fq and "." in table_name:
+                    ref_table_fq = f"{table_name.split('.', 1)[0]}.{ref_table_fq}"
+
+                relationship_name = (
+                    f"{table_name}:{fk['column']}->{ref_table_fq}:{fk['ref_column']}"
+                )
+                relationship = SchemaRelationship(
+                    id=uuid5(NAMESPACE_OID, name=relationship_name),
+                    name=relationship_name,
+                    source_table=table_name,
+                    target_table=ref_table_fq,
+                    relationship_type="foreign_key",
+                    source_column=fk["column"],
+                    target_column=fk["ref_column"],
+                    description=f"Relational database table foreign key relationship between: {table_name}.{fk['column']} → {ref_table_fq}.{fk['ref_column']}"
+                    f"This foreing key relationship between table columns is a part of the following database: {migration_config.migration_db_name}",
+                )
+                schema_relationships.append(relationship)
+
+    id_str = f"{migration_config.migration_db_provider}:{migration_config.migration_db_name}"
+    database_schema = DatabaseSchema(
+        id=uuid5(NAMESPACE_OID, name=id_str),
+        name=migration_config.migration_db_name,
+        database_type=migration_config.migration_db_provider,
+        tables=tables,
+        sample_data=sample_data,
+        extraction_timestamp=datetime.now(timezone.utc),
+        description=f"Database schema containing {len(schema_tables)} tables and {len(schema_relationships)} relationships. "
+        f"The database type is {migration_config.migration_db_provider}."
+        f"The database contains the following tables: {tables}",
+    )
+
+    return {
+        "database_schema": database_schema,
+        "schema_tables": schema_tables,
+        "relationships": schema_relationships,
+    }
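On PostgreSQL the new task reads the planner's row estimate from pg_class rather than paying for a full COUNT(*) on every table. A standalone sketch of that query outside the task; the DSN and table name are placeholders, not part of the diff:

    from sqlalchemy import create_engine, text

    ESTIMATE_SQL = text(
        "SELECT reltuples::bigint AS estimate "
        "FROM pg_class c "
        "JOIN pg_namespace n ON n.oid = c.relnamespace "
        "WHERE n.nspname = :schema AND c.relname = :table"
    )

    engine = create_engine("postgresql+psycopg://localhost/example_db")  # placeholder DSN
    with engine.connect() as conn:
        estimate = conn.execute(ESTIMATE_SQL, {"schema": "public", "table": "users"}).scalar() or 0
        print(f"~{estimate} rows in public.users (planner estimate, avoids a full COUNT(*))")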
cognee/tasks/schema/models.py (Normal file, 41 lines)
@@ -0,0 +1,41 @@
+from cognee.infrastructure.engine.models.DataPoint import DataPoint
+from typing import List, Dict, Optional
+from datetime import datetime
+
+
+class DatabaseSchema(DataPoint):
+    """Represents a complete database schema with sample data"""
+
+    name: str
+    database_type: str  # sqlite, postgres, etc.
+    tables: Dict[str, Dict]  # Reuse existing schema format from SqlAlchemyAdapter
+    sample_data: Dict[str, List[Dict]]  # Limited examples per table
+    extraction_timestamp: datetime
+    description: str
+    metadata: dict = {"index_fields": ["description", "name"]}
+
+
+class SchemaTable(DataPoint):
+    """Represents an individual table schema with relationships"""
+
+    name: str
+    columns: List[Dict]  # Column definitions with types
+    primary_key: Optional[str]
+    foreign_keys: List[Dict]  # Foreign key relationships
+    sample_rows: List[Dict]  # Max 3-5 example rows
+    row_count_estimate: Optional[int]  # Actual table size
+    description: str
+    metadata: dict = {"index_fields": ["description", "name"]}
+
+
+class SchemaRelationship(DataPoint):
+    """Represents relationships between tables"""
+
+    name: str
+    source_table: str
+    target_table: str
+    relationship_type: str  # "foreign_key", "one_to_many", etc.
+    source_column: str
+    target_column: str
+    description: str
+    metadata: dict = {"index_fields": ["description", "name"]}
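Like the rest of this change, the schema DataPoints get their IDs from uuid5(NAMESPACE_OID, name=...), so re-ingesting the same schema maps onto the same graph nodes instead of creating duplicates. A tiny illustration of that determinism (the table name is made up):

    from uuid import uuid5, NAMESPACE_OID

    # Deterministic: the same name always yields the same UUID, so a second ingestion
    # of the same schema resolves to the same node IDs.
    print(uuid5(NAMESPACE_OID, name="public.users"))
    print(uuid5(NAMESPACE_OID, name="public.users"))  # identical to the line above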
@ -197,6 +197,80 @@ async def relational_db_migration():
    print(f"All checks passed for {graph_db_provider} provider with '{relationship_label}' edges!")


async def test_schema_only_migration():
    # 1. Setup test DB and extract schema
    migration_engine = await setup_test_db()
    schema = await migration_engine.extract_schema()

    # 2. Setup graph engine
    graph_engine = await get_graph_engine()

    # 3. Migrate schema only
    await migrate_relational_database(graph_engine, schema=schema, schema_only=True)

    # 4. Verify number of tables through search
    search_results = await cognee.search(
        query_text="How many tables are there in this database",
        query_type=cognee.SearchType.GRAPH_COMPLETION,
        top_k=30,
    )
    assert any("11" in r for r in search_results), (
        "Number of tables in the database reported in search_results is either None or not equal to 11"
    )

    graph_db_provider = os.getenv("GRAPH_DATABASE_PROVIDER", "networkx").lower()

    edge_counts = {
        "is_part_of": 0,
        "has_relationship": 0,
        "foreign_key": 0,
    }

    if graph_db_provider == "neo4j":
        for rel_type in edge_counts.keys():
            query_str = f"""
            MATCH ()-[r:{rel_type}]->()
            RETURN count(r) as c
            """
            rows = await graph_engine.query(query_str)
            edge_counts[rel_type] = rows[0]["c"]

    elif graph_db_provider == "kuzu":
        for rel_type in edge_counts.keys():
            query_str = f"""
            MATCH ()-[r:EDGE]->()
            WHERE r.relationship_name = '{rel_type}'
            RETURN count(r) as c
            """
            rows = await graph_engine.query(query_str)
            edge_counts[rel_type] = rows[0][0]

    elif graph_db_provider == "networkx":
        nodes, edges = await graph_engine.get_graph_data()
        for _, _, key, _ in edges:
            if key in edge_counts:
                edge_counts[key] += 1

    else:
        raise ValueError(f"Unsupported graph database provider: {graph_db_provider}")

    # 5. Assert counts match expected values
    expected_counts = {
        "is_part_of": 11,
        "has_relationship": 22,
        "foreign_key": 11,
    }

    for rel_type, expected in expected_counts.items():
        actual = edge_counts[rel_type]
        assert actual == expected, (
            f"Expected {expected} edges for relationship '{rel_type}', but found {actual}"
        )

    print("Schema-only migration edge counts validated successfully!")
    print(f"Edge counts: {edge_counts}")
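In the networkx branch above, edges arrive as (source, target, key, properties) tuples and the relationship name is the third element. The same tally can be written with collections.Counter; a small standalone sketch, with an edge list that is made up purely for illustration:

from collections import Counter

# Hypothetical edge list in the (source, target, key, properties) shape
# unpacked by the networkx branch of the test above.
edges = [
    ("table:Invoice", "db:example", "is_part_of", {}),
    ("table:Customer", "db:example", "is_part_of", {}),
    ("table:Invoice", "table:Customer", "foreign_key", {"source_column": "CustomerId"}),
]

tracked = {"is_part_of", "has_relationship", "foreign_key"}
edge_counts = Counter(key for _, _, key, _ in edges if key in tracked)

print(edge_counts)  # Counter({'is_part_of': 2, 'foreign_key': 1})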
async def test_migration_sqlite():
    database_to_migrate_path = os.path.join(pathlib.Path(__file__).parent, "test_data/")
@ -209,6 +283,7 @@ async def test_migration_sqlite():
    )

    await relational_db_migration()
    await test_schema_only_migration()


async def test_migration_postgres():
@ -224,6 +299,7 @@ async def test_migration_postgres():
        }
    )
    await relational_db_migration()
    await test_schema_only_migration()


async def main():
@ -1,16 +1,15 @@
from pathlib import Path
import asyncio
import os

import cognee
from cognee.infrastructure.databases.relational.config import get_migration_config
from cognee.infrastructure.databases.graph import get_graph_engine
from cognee.api.v1.visualize.visualize import visualize_graph
from cognee.infrastructure.databases.relational import (
    get_migration_relational_engine,
)

from cognee.modules.search.types import SearchType

from cognee.infrastructure.databases.relational import (
    create_db_and_tables as create_relational_db_and_tables,
)
@ -32,16 +31,29 @@ from cognee.infrastructure.databases.vector.pgvector import (


async def main():
    # Clean all data stored in Cognee
    await cognee.prune.prune_data()
    await cognee.prune.prune_system(metadata=True)

    # Needed to create appropriate database tables only on the Cognee side
    await create_relational_db_and_tables()
    await create_vector_db_and_tables()

    # In case environment variables are not set, use the example database from the Cognee repo
    migration_db_provider = os.environ.get("MIGRATION_DB_PROVIDER", "sqlite")
    migration_db_path = os.environ.get(
        "MIGRATION_DB_PATH",
        os.path.join(Path(__file__).resolve().parent.parent.parent, "cognee/tests/test_data"),
    )
    migration_db_name = os.environ.get("MIGRATION_DB_NAME", "migration_database.sqlite")

    migration_config = get_migration_config()
    migration_config.migration_db_provider = migration_db_provider
    migration_config.migration_db_path = migration_db_path
    migration_config.migration_db_name = migration_db_name

    engine = get_migration_relational_engine()

    print("\nExtracting schema of database to migrate.")
    schema = await engine.extract_schema()
    print(f"Migrated database schema:\n{schema}")
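The MIGRATION_DB_* defaults above point at the bundled SQLite example; to migrate your own database you can set the same environment variables before the script reads them. A hedged sketch (the Postgres values are placeholders, and the full set of connection variables your provider needs may differ — check the migration config for the exact names):

import os

# Illustrative only: point the migration at a hypothetical Postgres database
# instead of the bundled migration_database.sqlite.
os.environ["MIGRATION_DB_PROVIDER"] = "postgres"
os.environ["MIGRATION_DB_NAME"] = "chinook"

# Server-based providers typically also need host, port, username, and password;
# those variable names are not shown in this example and are assumptions here.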
@ -53,10 +65,6 @@ async def main():
    await migrate_relational_database(graph, schema=schema)
    print("Relational database migration complete.")

    # Make sure to set top_k to a high value for a broader search, the default value is only 10!
    # top_k represents the number of graph triplets supplied to the LLM to answer your question
    search_results = await cognee.search(
@ -69,13 +77,25 @@ async def main():
    # Setting the top_k value too high might overwhelm the LLM context when specific questions need to be answered.
    # For this kind of question we've set top_k to 30
    search_results = await cognee.search(
        query_type=SearchType.GRAPH_COMPLETION,
        query_text="What invoices are related to Leonie Köhler?",
        top_k=30,
    )
    print(f"Search results: {search_results}")

    search_results = await cognee.search(
        query_type=SearchType.GRAPH_COMPLETION,
        query_text="What invoices are related to Luís Gonçalves?",
        top_k=30,
    )
    print(f"Search results: {search_results}")

    # If you check the relational database for this example, you can see that the search results successfully found all
    # the invoices related to the two customers, without any hallucinations or additional information.

    # Define location where to store the html visualization of the migrated database graph
    home_dir = os.path.expanduser("~")
    destination_file_path = os.path.join(home_dir, "graph_visualization.html")
    print("Adding html visualization of graph database after migration.")
    await visualize_graph(destination_file_path)
    print(f"Visualization can be found at: {destination_file_path}")
@ -29,8 +29,11 @@ async def main():
    print("=" * 60)

    # Start the UI server
    def dummy_callback(pid):
        pass

    server = cognee.start_ui(
        pid_callback=dummy_callback,
        port=3000,
        open_browser=True,  # This will automatically open your browser
    )
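The new pid_callback parameter hands the UI server's process id to the caller. A minimal sketch of a callback that records the pid for later cleanup, assuming only what the call above shows (that the callback receives a single pid argument); it could be passed as pid_callback=remember_pid in place of dummy_callback:

import os
import signal

ui_pids: list[int] = []

def remember_pid(pid: int) -> None:
    # Record the UI server's process id so the script can stop it later.
    ui_pids.append(pid)

def stop_ui_servers() -> None:
    # Illustrative shutdown helper; SIGTERM handling differs on Windows.
    for pid in ui_pids:
        try:
            os.kill(pid, signal.SIGTERM)
        except ProcessLookupError:
            pass  # process already exited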