Merge branch 'dev' into update-endpoint
This commit is contained in:
commit
9ae3c97aef
22 changed files with 3074 additions and 697 deletions
|
|
@ -43,7 +43,7 @@ jobs:
|
||||||
strategy:
|
strategy:
|
||||||
matrix:
|
matrix:
|
||||||
python-version: ${{ fromJSON(inputs.python-versions) }}
|
python-version: ${{ fromJSON(inputs.python-versions) }}
|
||||||
os: [ubuntu-22.04, macos-13, macos-15, windows-latest]
|
os: [ubuntu-22.04, macos-15, windows-latest]
|
||||||
fail-fast: false
|
fail-fast: false
|
||||||
steps:
|
steps:
|
||||||
- name: Check out
|
- name: Check out
|
||||||
|
|
@ -79,7 +79,7 @@ jobs:
|
||||||
strategy:
|
strategy:
|
||||||
matrix:
|
matrix:
|
||||||
python-version: ${{ fromJSON(inputs.python-versions) }}
|
python-version: ${{ fromJSON(inputs.python-versions) }}
|
||||||
os: [ ubuntu-22.04, macos-13, macos-15, windows-latest ]
|
os: [ ubuntu-22.04, macos-15, windows-latest ]
|
||||||
fail-fast: false
|
fail-fast: false
|
||||||
steps:
|
steps:
|
||||||
- name: Check out
|
- name: Check out
|
||||||
|
|
@ -115,7 +115,7 @@ jobs:
|
||||||
strategy:
|
strategy:
|
||||||
matrix:
|
matrix:
|
||||||
python-version: ${{ fromJSON(inputs.python-versions) }}
|
python-version: ${{ fromJSON(inputs.python-versions) }}
|
||||||
os: [ ubuntu-22.04, macos-13, macos-15, windows-latest ]
|
os: [ ubuntu-22.04, macos-15, windows-latest ]
|
||||||
fail-fast: false
|
fail-fast: false
|
||||||
steps:
|
steps:
|
||||||
- name: Check out
|
- name: Check out
|
||||||
|
|
@ -151,7 +151,7 @@ jobs:
|
||||||
strategy:
|
strategy:
|
||||||
matrix:
|
matrix:
|
||||||
python-version: ${{ fromJSON(inputs.python-versions) }}
|
python-version: ${{ fromJSON(inputs.python-versions) }}
|
||||||
os: [ ubuntu-22.04, macos-13, macos-15, windows-latest ]
|
os: [ ubuntu-22.04, macos-15, windows-latest ]
|
||||||
fail-fast: false
|
fail-fast: false
|
||||||
steps:
|
steps:
|
||||||
- name: Check out
|
- name: Check out
|
||||||
|
|
@ -180,7 +180,7 @@ jobs:
|
||||||
strategy:
|
strategy:
|
||||||
matrix:
|
matrix:
|
||||||
python-version: ${{ fromJSON(inputs.python-versions) }}
|
python-version: ${{ fromJSON(inputs.python-versions) }}
|
||||||
os: [ ubuntu-22.04, macos-13, macos-15, windows-latest ]
|
os: [ ubuntu-22.04, macos-15, windows-latest ]
|
||||||
fail-fast: false
|
fail-fast: false
|
||||||
steps:
|
steps:
|
||||||
- name: Check out
|
- name: Check out
|
||||||
|
|
@ -210,7 +210,7 @@ jobs:
|
||||||
strategy:
|
strategy:
|
||||||
matrix:
|
matrix:
|
||||||
python-version: ${{ fromJSON(inputs.python-versions) }}
|
python-version: ${{ fromJSON(inputs.python-versions) }}
|
||||||
os: [ ubuntu-22.04, macos-13, macos-15, windows-latest ]
|
os: [ ubuntu-22.04, macos-15, windows-latest ]
|
||||||
fail-fast: false
|
fail-fast: false
|
||||||
steps:
|
steps:
|
||||||
- name: Check out
|
- name: Check out
|
||||||
|
|
|
||||||
|
|
@ -3,10 +3,18 @@
|
||||||
import classNames from "classnames";
|
import classNames from "classnames";
|
||||||
import { MutableRefObject, useEffect, useImperativeHandle, useRef, useState, useCallback } from "react";
|
import { MutableRefObject, useEffect, useImperativeHandle, useRef, useState, useCallback } from "react";
|
||||||
import { forceCollide, forceManyBody } from "d3-force-3d";
|
import { forceCollide, forceManyBody } from "d3-force-3d";
|
||||||
import ForceGraph, { ForceGraphMethods, GraphData, LinkObject, NodeObject } from "react-force-graph-2d";
|
import dynamic from "next/dynamic";
|
||||||
import { GraphControlsAPI } from "./GraphControls";
|
import { GraphControlsAPI } from "./GraphControls";
|
||||||
import getColorForNodeType from "./getColorForNodeType";
|
import getColorForNodeType from "./getColorForNodeType";
|
||||||
|
|
||||||
|
// Dynamically import ForceGraph to prevent SSR issues
|
||||||
|
const ForceGraph = dynamic(() => import("react-force-graph-2d"), {
|
||||||
|
ssr: false,
|
||||||
|
loading: () => <div className="w-full h-full flex items-center justify-center">Loading graph...</div>
|
||||||
|
});
|
||||||
|
|
||||||
|
import type { ForceGraphMethods, GraphData, LinkObject, NodeObject } from "react-force-graph-2d";
|
||||||
|
|
||||||
interface GraphVisuzaliationProps {
|
interface GraphVisuzaliationProps {
|
||||||
ref: MutableRefObject<GraphVisualizationAPI>;
|
ref: MutableRefObject<GraphVisualizationAPI>;
|
||||||
data?: GraphData<NodeObject, LinkObject>;
|
data?: GraphData<NodeObject, LinkObject>;
|
||||||
|
|
@ -200,7 +208,7 @@ export default function GraphVisualization({ ref, data, graphControls, className
|
||||||
const graphRef = useRef<ForceGraphMethods>();
|
const graphRef = useRef<ForceGraphMethods>();
|
||||||
|
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
if (typeof window !== "undefined" && data && graphRef.current) {
|
if (data && graphRef.current) {
|
||||||
// add collision force
|
// add collision force
|
||||||
graphRef.current.d3Force("collision", forceCollide(nodeSize * 1.5));
|
graphRef.current.d3Force("collision", forceCollide(nodeSize * 1.5));
|
||||||
graphRef.current.d3Force("charge", forceManyBody().strength(-10).distanceMin(10).distanceMax(50));
|
graphRef.current.d3Force("charge", forceManyBody().strength(-10).distanceMin(10).distanceMax(50));
|
||||||
|
|
@ -216,56 +224,34 @@ export default function GraphVisualization({ ref, data, graphControls, className
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<div ref={containerRef} className={classNames("w-full h-full", className)} id="graph-container">
|
<div ref={containerRef} className={classNames("w-full h-full", className)} id="graph-container">
|
||||||
{(data && typeof window !== "undefined") ? (
|
<ForceGraph
|
||||||
<ForceGraph
|
ref={graphRef}
|
||||||
ref={graphRef}
|
width={dimensions.width}
|
||||||
width={dimensions.width}
|
height={dimensions.height}
|
||||||
height={dimensions.height}
|
dagMode={graphShape as unknown as undefined}
|
||||||
dagMode={graphShape as unknown as undefined}
|
dagLevelDistance={data ? 300 : 100}
|
||||||
dagLevelDistance={300}
|
onDagError={handleDagError}
|
||||||
onDagError={handleDagError}
|
graphData={data || {
|
||||||
graphData={data}
|
nodes: [{ id: 1, label: "Add" }, { id: 2, label: "Cognify" }, { id: 3, label: "Search" }],
|
||||||
|
links: [{ source: 1, target: 2, label: "but don't forget to" }, { source: 2, target: 3, label: "and after that you can" }],
|
||||||
|
}}
|
||||||
|
|
||||||
nodeLabel="label"
|
nodeLabel="label"
|
||||||
nodeRelSize={nodeSize}
|
nodeRelSize={data ? nodeSize : 20}
|
||||||
nodeCanvasObject={renderNode}
|
nodeCanvasObject={data ? renderNode : renderInitialNode}
|
||||||
nodeCanvasObjectMode={() => "replace"}
|
nodeCanvasObjectMode={() => data ? "replace" : "after"}
|
||||||
|
nodeAutoColorBy={data ? undefined : "type"}
|
||||||
|
|
||||||
linkLabel="label"
|
linkLabel="label"
|
||||||
linkCanvasObject={renderLink}
|
linkCanvasObject={renderLink}
|
||||||
linkCanvasObjectMode={() => "after"}
|
linkCanvasObjectMode={() => "after"}
|
||||||
linkDirectionalArrowLength={3.5}
|
linkDirectionalArrowLength={3.5}
|
||||||
linkDirectionalArrowRelPos={1}
|
linkDirectionalArrowRelPos={1}
|
||||||
|
|
||||||
onNodeClick={handleNodeClick}
|
onNodeClick={handleNodeClick}
|
||||||
onBackgroundClick={handleBackgroundClick}
|
onBackgroundClick={handleBackgroundClick}
|
||||||
d3VelocityDecay={0.3}
|
d3VelocityDecay={data ? 0.3 : undefined}
|
||||||
/>
|
/>
|
||||||
) : (
|
|
||||||
<ForceGraph
|
|
||||||
ref={graphRef}
|
|
||||||
width={dimensions.width}
|
|
||||||
height={dimensions.height}
|
|
||||||
dagMode={graphShape as unknown as undefined}
|
|
||||||
dagLevelDistance={100}
|
|
||||||
graphData={{
|
|
||||||
nodes: [{ id: 1, label: "Add" }, { id: 2, label: "Cognify" }, { id: 3, label: "Search" }],
|
|
||||||
links: [{ source: 1, target: 2, label: "but don't forget to" }, { source: 2, target: 3, label: "and after that you can" }],
|
|
||||||
}}
|
|
||||||
|
|
||||||
nodeLabel="label"
|
|
||||||
nodeRelSize={20}
|
|
||||||
nodeCanvasObject={renderInitialNode}
|
|
||||||
nodeCanvasObjectMode={() => "after"}
|
|
||||||
nodeAutoColorBy="type"
|
|
||||||
|
|
||||||
linkLabel="label"
|
|
||||||
linkCanvasObject={renderLink}
|
|
||||||
linkCanvasObjectMode={() => "after"}
|
|
||||||
linkDirectionalArrowLength={3.5}
|
|
||||||
linkDirectionalArrowRelPos={1}
|
|
||||||
/>
|
|
||||||
)}
|
|
||||||
</div>
|
</div>
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -51,6 +51,13 @@ export default async function fetch(url: string, options: RequestInit = {}, useC
|
||||||
)
|
)
|
||||||
.then((response) => handleServerErrors(response, retry, useCloud))
|
.then((response) => handleServerErrors(response, retry, useCloud))
|
||||||
.catch((error) => {
|
.catch((error) => {
|
||||||
|
// Handle network errors more gracefully
|
||||||
|
if (error.name === 'TypeError' && error.message.includes('fetch')) {
|
||||||
|
return Promise.reject(
|
||||||
|
new Error("Backend server is not responding. Please check if the server is running.")
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
if (error.detail === undefined) {
|
if (error.detail === undefined) {
|
||||||
return Promise.reject(
|
return Promise.reject(
|
||||||
new Error("No connection to the server.")
|
new Error("No connection to the server.")
|
||||||
|
|
@ -64,8 +71,27 @@ export default async function fetch(url: string, options: RequestInit = {}, useC
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
fetch.checkHealth = () => {
|
fetch.checkHealth = async () => {
|
||||||
return global.fetch(`${backendApiUrl.replace("/api", "")}/health`);
|
const maxRetries = 5;
|
||||||
|
const retryDelay = 1000; // 1 second
|
||||||
|
|
||||||
|
for (let i = 0; i < maxRetries; i++) {
|
||||||
|
try {
|
||||||
|
const response = await global.fetch(`${backendApiUrl.replace("/api", "")}/health`);
|
||||||
|
if (response.ok) {
|
||||||
|
return response;
|
||||||
|
}
|
||||||
|
} catch (error) {
|
||||||
|
// If this is the last retry, throw the error
|
||||||
|
if (i === maxRetries - 1) {
|
||||||
|
throw error;
|
||||||
|
}
|
||||||
|
// Wait before retrying
|
||||||
|
await new Promise(resolve => setTimeout(resolve, retryDelay));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
throw new Error("Backend server is not responding after multiple attempts");
|
||||||
};
|
};
|
||||||
|
|
||||||
fetch.checkMCPHealth = () => {
|
fetch.checkMCPHealth = () => {
|
||||||
|
|
|
||||||
|
|
@ -194,7 +194,7 @@ class HealthChecker:
|
||||||
config = get_llm_config()
|
config = get_llm_config()
|
||||||
|
|
||||||
# Test actual API connection with minimal request
|
# Test actual API connection with minimal request
|
||||||
LLMGateway.show_prompt("test", "test")
|
LLMGateway.show_prompt("test", "test.txt")
|
||||||
|
|
||||||
response_time = int((time.time() - start_time) * 1000)
|
response_time = int((time.time() - start_time) * 1000)
|
||||||
return ComponentHealth(
|
return ComponentHealth(
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,6 @@
|
||||||
import os
|
import os
|
||||||
|
import platform
|
||||||
|
import signal
|
||||||
import socket
|
import socket
|
||||||
import subprocess
|
import subprocess
|
||||||
import threading
|
import threading
|
||||||
|
|
@ -288,6 +290,7 @@ def check_node_npm() -> tuple[bool, str]:
|
||||||
Check if Node.js and npm are available.
|
Check if Node.js and npm are available.
|
||||||
Returns (is_available, error_message)
|
Returns (is_available, error_message)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Check Node.js
|
# Check Node.js
|
||||||
result = subprocess.run(["node", "--version"], capture_output=True, text=True, timeout=10)
|
result = subprocess.run(["node", "--version"], capture_output=True, text=True, timeout=10)
|
||||||
|
|
@ -297,8 +300,17 @@ def check_node_npm() -> tuple[bool, str]:
|
||||||
node_version = result.stdout.strip()
|
node_version = result.stdout.strip()
|
||||||
logger.debug(f"Found Node.js version: {node_version}")
|
logger.debug(f"Found Node.js version: {node_version}")
|
||||||
|
|
||||||
# Check npm
|
# Check npm - handle Windows PowerShell scripts
|
||||||
result = subprocess.run(["npm", "--version"], capture_output=True, text=True, timeout=10)
|
if platform.system() == "Windows":
|
||||||
|
# On Windows, npm might be a PowerShell script, so we need to use shell=True
|
||||||
|
result = subprocess.run(
|
||||||
|
["npm", "--version"], capture_output=True, text=True, timeout=10, shell=True
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
result = subprocess.run(
|
||||||
|
["npm", "--version"], capture_output=True, text=True, timeout=10
|
||||||
|
)
|
||||||
|
|
||||||
if result.returncode != 0:
|
if result.returncode != 0:
|
||||||
return False, "npm is not installed or not in PATH"
|
return False, "npm is not installed or not in PATH"
|
||||||
|
|
||||||
|
|
@ -320,6 +332,7 @@ def install_frontend_dependencies(frontend_path: Path) -> bool:
|
||||||
Install frontend dependencies if node_modules doesn't exist.
|
Install frontend dependencies if node_modules doesn't exist.
|
||||||
This is needed for both development and downloaded frontends since both use npm run dev.
|
This is needed for both development and downloaded frontends since both use npm run dev.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
node_modules = frontend_path / "node_modules"
|
node_modules = frontend_path / "node_modules"
|
||||||
if node_modules.exists():
|
if node_modules.exists():
|
||||||
logger.debug("Frontend dependencies already installed")
|
logger.debug("Frontend dependencies already installed")
|
||||||
|
|
@ -328,13 +341,24 @@ def install_frontend_dependencies(frontend_path: Path) -> bool:
|
||||||
logger.info("Installing frontend dependencies (this may take a few minutes)...")
|
logger.info("Installing frontend dependencies (this may take a few minutes)...")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
result = subprocess.run(
|
# Use shell=True on Windows for npm commands
|
||||||
["npm", "install"],
|
if platform.system() == "Windows":
|
||||||
cwd=frontend_path,
|
result = subprocess.run(
|
||||||
capture_output=True,
|
["npm", "install"],
|
||||||
text=True,
|
cwd=frontend_path,
|
||||||
timeout=300, # 5 minutes timeout
|
capture_output=True,
|
||||||
)
|
text=True,
|
||||||
|
timeout=300, # 5 minutes timeout
|
||||||
|
shell=True,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
result = subprocess.run(
|
||||||
|
["npm", "install"],
|
||||||
|
cwd=frontend_path,
|
||||||
|
capture_output=True,
|
||||||
|
text=True,
|
||||||
|
timeout=300, # 5 minutes timeout
|
||||||
|
)
|
||||||
|
|
||||||
if result.returncode == 0:
|
if result.returncode == 0:
|
||||||
logger.info("Frontend dependencies installed successfully")
|
logger.info("Frontend dependencies installed successfully")
|
||||||
|
|
@ -595,14 +619,27 @@ def start_ui(
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Create frontend in its own process group for clean termination
|
# Create frontend in its own process group for clean termination
|
||||||
process = subprocess.Popen(
|
# Use shell=True on Windows for npm commands
|
||||||
["npm", "run", "dev"],
|
if platform.system() == "Windows":
|
||||||
cwd=frontend_path,
|
process = subprocess.Popen(
|
||||||
env=env,
|
["npm", "run", "dev"],
|
||||||
stdout=subprocess.PIPE,
|
cwd=frontend_path,
|
||||||
stderr=subprocess.PIPE,
|
env=env,
|
||||||
preexec_fn=os.setsid if hasattr(os, "setsid") else None,
|
stdout=subprocess.PIPE,
|
||||||
)
|
stderr=subprocess.PIPE,
|
||||||
|
text=True,
|
||||||
|
shell=True,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
process = subprocess.Popen(
|
||||||
|
["npm", "run", "dev"],
|
||||||
|
cwd=frontend_path,
|
||||||
|
env=env,
|
||||||
|
stdout=subprocess.PIPE,
|
||||||
|
stderr=subprocess.PIPE,
|
||||||
|
text=True,
|
||||||
|
preexec_fn=os.setsid if hasattr(os, "setsid") else None,
|
||||||
|
)
|
||||||
|
|
||||||
# Start threads to stream frontend output with prefix
|
# Start threads to stream frontend output with prefix
|
||||||
_stream_process_output(process, "stdout", "[FRONTEND]", "\033[33m") # Yellow
|
_stream_process_output(process, "stdout", "[FRONTEND]", "\033[33m") # Yellow
|
||||||
|
|
|
||||||
|
|
@ -183,10 +183,20 @@ def main() -> int:
|
||||||
|
|
||||||
for pid in spawned_pids:
|
for pid in spawned_pids:
|
||||||
try:
|
try:
|
||||||
pgid = os.getpgid(pid)
|
if hasattr(os, "killpg"):
|
||||||
os.killpg(pgid, signal.SIGTERM)
|
# Unix-like systems: Use process groups
|
||||||
fmt.success(f"✓ Process group {pgid} (PID {pid}) terminated.")
|
pgid = os.getpgid(pid)
|
||||||
except (OSError, ProcessLookupError) as e:
|
os.killpg(pgid, signal.SIGTERM)
|
||||||
|
fmt.success(f"✓ Process group {pgid} (PID {pid}) terminated.")
|
||||||
|
else:
|
||||||
|
# Windows: Use taskkill to terminate process and its children
|
||||||
|
subprocess.run(
|
||||||
|
["taskkill", "/F", "/T", "/PID", str(pid)],
|
||||||
|
capture_output=True,
|
||||||
|
check=False,
|
||||||
|
)
|
||||||
|
fmt.success(f"✓ Process {pid} and its children terminated.")
|
||||||
|
except (OSError, ProcessLookupError, subprocess.SubprocessError) as e:
|
||||||
fmt.warning(f"Could not terminate process {pid}: {e}")
|
fmt.warning(f"Could not terminate process {pid}: {e}")
|
||||||
|
|
||||||
sys.exit(0)
|
sys.exit(0)
|
||||||
|
|
|
||||||
|
|
@ -6,6 +6,7 @@ from cognee.cli.reference import SupportsCliCommand
|
||||||
from cognee.cli import DEFAULT_DOCS_URL
|
from cognee.cli import DEFAULT_DOCS_URL
|
||||||
import cognee.cli.echo as fmt
|
import cognee.cli.echo as fmt
|
||||||
from cognee.cli.exceptions import CliCommandException, CliCommandInnerException
|
from cognee.cli.exceptions import CliCommandException, CliCommandInnerException
|
||||||
|
from cognee.modules.data.methods.get_deletion_counts import get_deletion_counts
|
||||||
|
|
||||||
|
|
||||||
class DeleteCommand(SupportsCliCommand):
|
class DeleteCommand(SupportsCliCommand):
|
||||||
|
|
@ -41,7 +42,34 @@ Be careful with deletion operations as they are irreversible.
|
||||||
fmt.error("Please specify what to delete: --dataset-name, --user-id, or --all")
|
fmt.error("Please specify what to delete: --dataset-name, --user-id, or --all")
|
||||||
return
|
return
|
||||||
|
|
||||||
# Build confirmation message
|
# If --force is used, skip the preview and go straight to deletion
|
||||||
|
if not args.force:
|
||||||
|
# --- START PREVIEW LOGIC ---
|
||||||
|
fmt.echo("Gathering data for preview...")
|
||||||
|
try:
|
||||||
|
preview_data = asyncio.run(
|
||||||
|
get_deletion_counts(
|
||||||
|
dataset_name=args.dataset_name,
|
||||||
|
user_id=args.user_id,
|
||||||
|
all_data=args.all,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
except CliCommandException as e:
|
||||||
|
fmt.error(f"Error occured when fetching preview data: {str(e)}")
|
||||||
|
return
|
||||||
|
|
||||||
|
if not preview_data:
|
||||||
|
fmt.success("No data found to delete.")
|
||||||
|
return
|
||||||
|
|
||||||
|
fmt.echo("You are about to delete:")
|
||||||
|
fmt.echo(
|
||||||
|
f"Datasets: {preview_data.datasets}\nEntries: {preview_data.entries}\nUsers: {preview_data.users}"
|
||||||
|
)
|
||||||
|
fmt.echo("-" * 20)
|
||||||
|
# --- END PREVIEW LOGIC ---
|
||||||
|
|
||||||
|
# Build operation message for success/failure logging
|
||||||
if args.all:
|
if args.all:
|
||||||
confirm_msg = "Delete ALL data from cognee?"
|
confirm_msg = "Delete ALL data from cognee?"
|
||||||
operation = "all data"
|
operation = "all data"
|
||||||
|
|
@ -51,8 +79,9 @@ Be careful with deletion operations as they are irreversible.
|
||||||
elif args.user_id:
|
elif args.user_id:
|
||||||
confirm_msg = f"Delete all data for user '{args.user_id}'?"
|
confirm_msg = f"Delete all data for user '{args.user_id}'?"
|
||||||
operation = f"data for user '{args.user_id}'"
|
operation = f"data for user '{args.user_id}'"
|
||||||
|
else:
|
||||||
|
operation = "data"
|
||||||
|
|
||||||
# Confirm deletion unless forced
|
|
||||||
if not args.force:
|
if not args.force:
|
||||||
fmt.warning("This operation is irreversible!")
|
fmt.warning("This operation is irreversible!")
|
||||||
if not fmt.confirm(confirm_msg):
|
if not fmt.confirm(confirm_msg):
|
||||||
|
|
@ -64,6 +93,8 @@ Be careful with deletion operations as they are irreversible.
|
||||||
# Run the async delete function
|
# Run the async delete function
|
||||||
async def run_delete():
|
async def run_delete():
|
||||||
try:
|
try:
|
||||||
|
# NOTE: The underlying cognee.delete() function is currently not working as expected.
|
||||||
|
# This is a separate bug that this preview feature helps to expose.
|
||||||
if args.all:
|
if args.all:
|
||||||
await cognee.delete(dataset_name=None, user_id=args.user_id)
|
await cognee.delete(dataset_name=None, user_id=args.user_id)
|
||||||
else:
|
else:
|
||||||
|
|
@ -72,6 +103,7 @@ Be careful with deletion operations as they are irreversible.
|
||||||
raise CliCommandInnerException(f"Failed to delete: {str(e)}")
|
raise CliCommandInnerException(f"Failed to delete: {str(e)}")
|
||||||
|
|
||||||
asyncio.run(run_delete())
|
asyncio.run(run_delete())
|
||||||
|
# This success message may be inaccurate due to the underlying bug, but we leave it for now.
|
||||||
fmt.success(f"Successfully deleted {operation}")
|
fmt.success(f"Successfully deleted {operation}")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|
|
||||||
|
|
@ -26,6 +26,7 @@ def read_query_prompt(prompt_file_name: str, base_directory: str = None):
|
||||||
read due to an error.
|
read due to an error.
|
||||||
"""
|
"""
|
||||||
logger = get_logger(level=ERROR)
|
logger = get_logger(level=ERROR)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if base_directory is None:
|
if base_directory is None:
|
||||||
base_directory = get_absolute_path("./infrastructure/llm/prompts")
|
base_directory = get_absolute_path("./infrastructure/llm/prompts")
|
||||||
|
|
@ -35,8 +36,8 @@ def read_query_prompt(prompt_file_name: str, base_directory: str = None):
|
||||||
with open(file_path, "r", encoding="utf-8") as file:
|
with open(file_path, "r", encoding="utf-8") as file:
|
||||||
return file.read()
|
return file.read()
|
||||||
except FileNotFoundError:
|
except FileNotFoundError:
|
||||||
logger.error(f"Error: Prompt file not found. Attempted to read: %s {file_path}")
|
logger.error(f"Error: Prompt file not found. Attempted to read: {file_path}")
|
||||||
return None
|
return None
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"An error occurred: %s {e}")
|
logger.error(f"An error occurred: {e}")
|
||||||
return None
|
return None
|
||||||
|
|
|
||||||
1
cognee/infrastructure/llm/prompts/test.txt
Normal file
1
cognee/infrastructure/llm/prompts/test.txt
Normal file
|
|
@ -0,0 +1 @@
|
||||||
|
Respond with: test
|
||||||
|
|
@ -73,10 +73,19 @@ class OpenAIAdapter(LLMInterface):
|
||||||
fallback_api_key: str = None,
|
fallback_api_key: str = None,
|
||||||
fallback_endpoint: str = None,
|
fallback_endpoint: str = None,
|
||||||
):
|
):
|
||||||
self.aclient = instructor.from_litellm(
|
# TODO: With gpt5 series models OpenAI expects JSON_SCHEMA as a mode for structured outputs.
|
||||||
litellm.acompletion, mode=instructor.Mode.JSON_SCHEMA
|
# Make sure all new gpt models will work with this mode as well.
|
||||||
)
|
if "gpt-5" in model:
|
||||||
self.client = instructor.from_litellm(litellm.completion, mode=instructor.Mode.JSON_SCHEMA)
|
self.aclient = instructor.from_litellm(
|
||||||
|
litellm.acompletion, mode=instructor.Mode.JSON_SCHEMA
|
||||||
|
)
|
||||||
|
self.client = instructor.from_litellm(
|
||||||
|
litellm.completion, mode=instructor.Mode.JSON_SCHEMA
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.aclient = instructor.from_litellm(litellm.acompletion)
|
||||||
|
self.client = instructor.from_litellm(litellm.completion)
|
||||||
|
|
||||||
self.transcription_model = transcription_model
|
self.transcription_model = transcription_model
|
||||||
self.model = model
|
self.model = model
|
||||||
self.api_key = api_key
|
self.api_key = api_key
|
||||||
|
|
|
||||||
|
|
@ -27,6 +27,7 @@ class LoaderEngine:
|
||||||
|
|
||||||
self.default_loader_priority = [
|
self.default_loader_priority = [
|
||||||
"text_loader",
|
"text_loader",
|
||||||
|
"advanced_pdf_loader",
|
||||||
"pypdf_loader",
|
"pypdf_loader",
|
||||||
"image_loader",
|
"image_loader",
|
||||||
"audio_loader",
|
"audio_loader",
|
||||||
|
|
@ -86,7 +87,7 @@ class LoaderEngine:
|
||||||
if loader.can_handle(extension=file_info.extension, mime_type=file_info.mime):
|
if loader.can_handle(extension=file_info.extension, mime_type=file_info.mime):
|
||||||
return loader
|
return loader
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Loader does not exist: {loader_name}")
|
logger.info(f"Skipping {loader_name}: Preferred Loader not registered")
|
||||||
|
|
||||||
# Try default priority order
|
# Try default priority order
|
||||||
for loader_name in self.default_loader_priority:
|
for loader_name in self.default_loader_priority:
|
||||||
|
|
@ -95,7 +96,9 @@ class LoaderEngine:
|
||||||
if loader.can_handle(extension=file_info.extension, mime_type=file_info.mime):
|
if loader.can_handle(extension=file_info.extension, mime_type=file_info.mime):
|
||||||
return loader
|
return loader
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Loader does not exist: {loader_name}")
|
logger.info(
|
||||||
|
f"Skipping {loader_name}: Loader not registered (in default priority list)."
|
||||||
|
)
|
||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -20,3 +20,10 @@ try:
|
||||||
__all__.append("UnstructuredLoader")
|
__all__.append("UnstructuredLoader")
|
||||||
except ImportError:
|
except ImportError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
try:
|
||||||
|
from .advanced_pdf_loader import AdvancedPdfLoader
|
||||||
|
|
||||||
|
__all__.append("AdvancedPdfLoader")
|
||||||
|
except ImportError:
|
||||||
|
pass
|
||||||
|
|
|
||||||
244
cognee/infrastructure/loaders/external/advanced_pdf_loader.py
vendored
Normal file
244
cognee/infrastructure/loaders/external/advanced_pdf_loader.py
vendored
Normal file
|
|
@ -0,0 +1,244 @@
|
||||||
|
"""Advanced PDF loader leveraging unstructured for layout-aware extraction."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Any, Dict, List, Optional
|
||||||
|
import asyncio
|
||||||
|
from cognee.infrastructure.files.storage import get_file_storage, get_storage_config
|
||||||
|
from cognee.infrastructure.files.utils.get_file_metadata import get_file_metadata
|
||||||
|
from cognee.infrastructure.loaders.LoaderInterface import LoaderInterface
|
||||||
|
from cognee.shared.logging_utils import get_logger
|
||||||
|
|
||||||
|
from cognee.infrastructure.loaders.external.pypdf_loader import PyPdfLoader
|
||||||
|
|
||||||
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
try:
|
||||||
|
from unstructured.partition.pdf import partition_pdf
|
||||||
|
except ImportError as e:
|
||||||
|
logger.info(
|
||||||
|
"unstructured[pdf] not installed, can't use AdvancedPdfLoader, will use PyPdfLoader instead."
|
||||||
|
)
|
||||||
|
raise ImportError from e
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class _PageBuffer:
|
||||||
|
page_num: Optional[int]
|
||||||
|
segments: List[str]
|
||||||
|
|
||||||
|
|
||||||
|
class AdvancedPdfLoader(LoaderInterface):
|
||||||
|
"""
|
||||||
|
PDF loader using unstructured library.
|
||||||
|
|
||||||
|
Extracts text content, images, tables from PDF files page by page, providing
|
||||||
|
structured page information and handling PDF-specific errors.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@property
|
||||||
|
def supported_extensions(self) -> List[str]:
|
||||||
|
return ["pdf"]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def supported_mime_types(self) -> List[str]:
|
||||||
|
return ["application/pdf"]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def loader_name(self) -> str:
|
||||||
|
return "advanced_pdf_loader"
|
||||||
|
|
||||||
|
def can_handle(self, extension: str, mime_type: str) -> bool:
|
||||||
|
"""Check if file can be handled by this loader."""
|
||||||
|
# Check file extension
|
||||||
|
if extension in self.supported_extensions and mime_type in self.supported_mime_types:
|
||||||
|
return True
|
||||||
|
|
||||||
|
return False
|
||||||
|
|
||||||
|
async def load(self, file_path: str, strategy: str = "auto", **kwargs: Any) -> str:
|
||||||
|
"""Load PDF file using unstructured library. If Exception occurs, fallback to PyPDFLoader.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
file_path: Path to the document file
|
||||||
|
strategy: Partitioning strategy ("auto", "fast", "hi_res", "ocr_only")
|
||||||
|
**kwargs: Additional arguments passed to unstructured partition
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
LoaderResult with extracted text content and metadata
|
||||||
|
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
logger.info(f"Processing PDF: {file_path}")
|
||||||
|
|
||||||
|
with open(file_path, "rb") as f:
|
||||||
|
file_metadata = await get_file_metadata(f)
|
||||||
|
|
||||||
|
# Name ingested file of current loader based on original file content hash
|
||||||
|
storage_file_name = "text_" + file_metadata["content_hash"] + ".txt"
|
||||||
|
|
||||||
|
# Set partitioning parameters
|
||||||
|
partition_kwargs: Dict[str, Any] = {
|
||||||
|
"filename": file_path,
|
||||||
|
"strategy": strategy,
|
||||||
|
"infer_table_structure": True,
|
||||||
|
"include_page_breaks": False,
|
||||||
|
"include_metadata": True,
|
||||||
|
**kwargs,
|
||||||
|
}
|
||||||
|
# Use partition to extract elements
|
||||||
|
elements = partition_pdf(**partition_kwargs)
|
||||||
|
|
||||||
|
# Process elements into text content
|
||||||
|
page_contents = self._format_elements_by_page(elements)
|
||||||
|
|
||||||
|
# Check if there is any content
|
||||||
|
if not page_contents:
|
||||||
|
logger.warning(
|
||||||
|
"AdvancedPdfLoader returned no content. Falling back to PyPDF loader."
|
||||||
|
)
|
||||||
|
return await self._fallback(file_path, **kwargs)
|
||||||
|
|
||||||
|
# Combine all page outputs
|
||||||
|
full_content = "\n".join(page_contents)
|
||||||
|
|
||||||
|
# Store the content
|
||||||
|
storage_config = get_storage_config()
|
||||||
|
data_root_directory = storage_config["data_root_directory"]
|
||||||
|
storage = get_file_storage(data_root_directory)
|
||||||
|
|
||||||
|
full_file_path = await storage.store(storage_file_name, full_content)
|
||||||
|
|
||||||
|
return full_file_path
|
||||||
|
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("Failed to process PDF with AdvancedPdfLoader: %s", exc)
|
||||||
|
return await self._fallback(file_path, **kwargs)
|
||||||
|
|
||||||
|
async def _fallback(self, file_path: str, **kwargs: Any) -> str:
|
||||||
|
logger.info("Falling back to PyPDF loader for %s", file_path)
|
||||||
|
fallback_loader = PyPdfLoader()
|
||||||
|
return await fallback_loader.load(file_path, **kwargs)
|
||||||
|
|
||||||
|
def _format_elements_by_page(self, elements: List[Any]) -> List[str]:
|
||||||
|
"""Format elements by page."""
|
||||||
|
page_buffers: List[_PageBuffer] = []
|
||||||
|
current_buffer = _PageBuffer(page_num=None, segments=[])
|
||||||
|
|
||||||
|
for element in elements:
|
||||||
|
element_dict = self._safe_to_dict(element)
|
||||||
|
metadata = element_dict.get("metadata", {})
|
||||||
|
page_num = metadata.get("page_number")
|
||||||
|
|
||||||
|
if current_buffer.page_num != page_num:
|
||||||
|
if current_buffer.segments:
|
||||||
|
page_buffers.append(current_buffer)
|
||||||
|
current_buffer = _PageBuffer(page_num=page_num, segments=[])
|
||||||
|
|
||||||
|
formatted = self._format_element(element_dict)
|
||||||
|
|
||||||
|
if formatted:
|
||||||
|
current_buffer.segments.append(formatted)
|
||||||
|
|
||||||
|
if current_buffer.segments:
|
||||||
|
page_buffers.append(current_buffer)
|
||||||
|
|
||||||
|
page_contents: List[str] = []
|
||||||
|
for buffer in page_buffers:
|
||||||
|
header = f"Page {buffer.page_num}:\n" if buffer.page_num is not None else "Page:"
|
||||||
|
content = header + "\n\n".join(buffer.segments) + "\n"
|
||||||
|
page_contents.append(str(content))
|
||||||
|
return page_contents
|
||||||
|
|
||||||
|
def _format_element(
    self,
    element: Dict[str, Any],
) -> str:
    """Render a single element dict to display text.

    Tables prefer their HTML representation, images get a placeholder
    description, headers/footers are dropped, and everything else returns
    its cleaned text.
    """
    # Guard against a missing/None "type" key; calling .lower() on None raises.
    element_type = (element.get("type") or "").lower()
    text = self._clean_text(element.get("text", ""))
    metadata = element.get("metadata", {})

    if element_type == "table":
        return self._format_table_element(element) or text

    if element_type == "image":
        return text or self._format_image_element(metadata)

    # Ignore header and footer content entirely.
    # BUG FIX: the original branch was a `pass` that fell through and
    # returned the header/footer text despite the comment saying to ignore it.
    if element_type in ("header", "footer"):
        return ""

    return text
def _format_table_element(self, element: Dict[str, Any]) -> str:
    """Return a table element as HTML when available, else its plain text."""
    meta = element.get("metadata", {})
    html = meta.get("text_as_html")
    if html:
        return html.strip()
    # No HTML rendering in the metadata — fall back to the cleaned raw text.
    return self._clean_text(element.get("text", ""))
def _format_image_element(self, metadata: Dict[str, Any]) -> str:
|
||||||
|
"""Format image."""
|
||||||
|
placeholder = "[Image omitted]"
|
||||||
|
image_text = placeholder
|
||||||
|
coordinates = metadata.get("coordinates", {})
|
||||||
|
points = coordinates.get("points") if isinstance(coordinates, dict) else None
|
||||||
|
if points and isinstance(points, tuple) and len(points) == 4:
|
||||||
|
leftup = points[0]
|
||||||
|
rightdown = points[3]
|
||||||
|
if (
|
||||||
|
isinstance(leftup, tuple)
|
||||||
|
and isinstance(rightdown, tuple)
|
||||||
|
and len(leftup) == 2
|
||||||
|
and len(rightdown) == 2
|
||||||
|
):
|
||||||
|
image_text = f"{placeholder} (bbox=({leftup[0]}, {leftup[1]}, {rightdown[0]}, {rightdown[1]}))"
|
||||||
|
|
||||||
|
layout_width = coordinates.get("layout_width")
|
||||||
|
layout_height = coordinates.get("layout_height")
|
||||||
|
system = coordinates.get("system")
|
||||||
|
if layout_width and layout_height and system:
|
||||||
|
image_text = (
|
||||||
|
image_text
|
||||||
|
+ f", system={system}, layout_width={layout_width}, layout_height={layout_height}))"
|
||||||
|
)
|
||||||
|
|
||||||
|
return image_text
|
||||||
|
|
||||||
|
def _safe_to_dict(self, element: Any) -> Dict[str, Any]:
|
||||||
|
"""Safe to dict."""
|
||||||
|
try:
|
||||||
|
if hasattr(element, "to_dict"):
|
||||||
|
return element.to_dict()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
fallback_type = getattr(element, "category", None)
|
||||||
|
if not fallback_type:
|
||||||
|
fallback_type = getattr(element, "__class__", type("", (), {})).__name__
|
||||||
|
|
||||||
|
return {
|
||||||
|
"type": fallback_type,
|
||||||
|
"text": getattr(element, "text", ""),
|
||||||
|
"metadata": getattr(element, "metadata", {}),
|
||||||
|
}
|
||||||
|
|
||||||
|
def _clean_text(self, value: Any) -> str:
|
||||||
|
if value is None:
|
||||||
|
return ""
|
||||||
|
return str(value).replace("\xa0", " ").strip()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    import sys

    # Accept the PDF path on the command line instead of the hard-coded,
    # developer-specific absolute path that was committed here.
    if len(sys.argv) != 2:
        print(f"usage: {sys.argv[0]} <path-to-pdf>", file=sys.stderr)
        sys.exit(1)

    loader = AdvancedPdfLoader()
    asyncio.run(loader.load(sys.argv[1]))
@ -16,3 +16,10 @@ try:
|
||||||
supported_loaders[UnstructuredLoader.loader_name] = UnstructuredLoader
|
supported_loaders[UnstructuredLoader.loader_name] = UnstructuredLoader
|
||||||
except ImportError:
|
except ImportError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
try:
|
||||||
|
from cognee.infrastructure.loaders.external import AdvancedPdfLoader
|
||||||
|
|
||||||
|
supported_loaders[AdvancedPdfLoader.loader_name] = AdvancedPdfLoader
|
||||||
|
except ImportError:
|
||||||
|
pass
|
||||||
|
|
|
||||||
92
cognee/modules/data/methods/get_deletion_counts.py
Normal file
92
cognee/modules/data/methods/get_deletion_counts.py
Normal file
|
|
@ -0,0 +1,92 @@
|
||||||
|
from uuid import UUID
|
||||||
|
from cognee.cli.exceptions import CliCommandException
|
||||||
|
from cognee.infrastructure.databases.exceptions.exceptions import EntityNotFoundError
|
||||||
|
from sqlalchemy import select
|
||||||
|
from sqlalchemy.sql import func
|
||||||
|
from cognee.infrastructure.databases.relational import get_relational_engine
|
||||||
|
from cognee.modules.data.models import Dataset, Data, DatasetData
|
||||||
|
from cognee.modules.users.models import User
|
||||||
|
from cognee.modules.users.methods import get_user
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class DeletionCountsPreview:
    """Summary of how many rows a pending delete operation would remove."""

    datasets: int = 0      # Dataset rows slated for deletion.
    data_entries: int = 0  # Data entries slated for deletion.
    users: int = 0         # User accounts slated for deletion.
async def get_deletion_counts(
    dataset_name: Optional[str] = None, user_id: Optional[str] = None, all_data: bool = False
) -> DeletionCountsPreview:
    """
    Calculates the number of items that will be deleted based on the provided arguments.

    Exactly one selector is honored, checked in this order: ``dataset_name``,
    ``all_data``, ``user_id``. With no selector an all-zero preview is returned.

    Raises:
        CliCommandException: when the named dataset or the given user does not exist.
    """
    counts = DeletionCountsPreview()
    relational_engine = get_relational_engine()
    async with relational_engine.get_async_session() as session:
        if dataset_name:
            # Find the dataset by name
            dataset_result = await session.execute(
                select(Dataset).where(Dataset.name == dataset_name)
            )
            dataset = dataset_result.scalar_one_or_none()

            if dataset is None:
                raise CliCommandException(
                    f"No Dataset exists with the name {dataset_name}", error_code=1
                )

            # Count data entries linked to this dataset
            count_query = (
                select(func.count())
                .select_from(DatasetData)
                .where(DatasetData.dataset_id == dataset.id)
            )
            data_entry_count = (await session.execute(count_query)).scalar_one()
            # NOTE(review): counting one user for a dataset-only delete looks
            # suspicious — confirm whether deleting a dataset really removes a user.
            counts.users = 1
            counts.datasets = 1
            # BUG FIX: the dataclass field is `data_entries`; the original assigned
            # `counts.entries`, creating a stray attribute and leaving the real
            # field at 0 in every preview.
            counts.data_entries = data_entry_count
            return counts

        elif all_data:
            # Simplified logic: Get total counts directly from the tables.
            counts.datasets = (
                await session.execute(select(func.count()).select_from(Dataset))
            ).scalar_one()
            counts.data_entries = (
                await session.execute(select(func.count()).select_from(Data))
            ).scalar_one()
            counts.users = (
                await session.execute(select(func.count()).select_from(User))
            ).scalar_one()
            return counts

        elif user_id:
            user = None
            try:
                user_uuid = UUID(user_id)
                user = await get_user(user_uuid)
            except (ValueError, EntityNotFoundError):
                raise CliCommandException(f"No User exists with ID {user_id}", error_code=1)
            counts.users = 1
            # Find all datasets owned by this user
            datasets_query = select(Dataset).where(Dataset.owner_id == user.id)
            user_datasets = (await session.execute(datasets_query)).scalars().all()
            dataset_count = len(user_datasets)
            counts.datasets = dataset_count
            if dataset_count > 0:
                dataset_ids = [d.id for d in user_datasets]
                # Count all data entries across all of the user's datasets
                data_count_query = (
                    select(func.count())
                    .select_from(DatasetData)
                    .where(DatasetData.dataset_id.in_(dataset_ids))
                )
                data_entry_count = (await session.execute(data_count_query)).scalar_one()
                counts.data_entries = data_entry_count
            else:
                counts.data_entries = 0
            return counts

    # No selector was provided: return the zeroed preview instead of
    # implicitly returning None (which callers would then crash on).
    return counts
|
@ -12,7 +12,8 @@ from cognee.cli.commands.search_command import SearchCommand
|
||||||
from cognee.cli.commands.cognify_command import CognifyCommand
|
from cognee.cli.commands.cognify_command import CognifyCommand
|
||||||
from cognee.cli.commands.delete_command import DeleteCommand
|
from cognee.cli.commands.delete_command import DeleteCommand
|
||||||
from cognee.cli.commands.config_command import ConfigCommand
|
from cognee.cli.commands.config_command import ConfigCommand
|
||||||
from cognee.cli.exceptions import CliCommandException, CliCommandInnerException
|
from cognee.cli.exceptions import CliCommandException
|
||||||
|
from cognee.modules.data.methods.get_deletion_counts import DeletionCountsPreview
|
||||||
|
|
||||||
|
|
||||||
# Mock asyncio.run to properly handle coroutines
|
# Mock asyncio.run to properly handle coroutines
|
||||||
|
|
@ -282,13 +283,18 @@ class TestDeleteCommand:
|
||||||
assert "all" in actions
|
assert "all" in actions
|
||||||
assert "force" in actions
|
assert "force" in actions
|
||||||
|
|
||||||
|
@patch("cognee.cli.commands.delete_command.get_deletion_counts")
|
||||||
@patch("cognee.cli.commands.delete_command.fmt.confirm")
|
@patch("cognee.cli.commands.delete_command.fmt.confirm")
|
||||||
@patch("cognee.cli.commands.delete_command.asyncio.run", side_effect=_mock_run)
|
@patch("cognee.cli.commands.delete_command.asyncio.run", side_effect=_mock_run)
|
||||||
def test_execute_delete_dataset_with_confirmation(self, mock_asyncio_run, mock_confirm):
|
def test_execute_delete_dataset_with_confirmation(
|
||||||
|
self, mock_asyncio_run, mock_confirm, mock_get_deletion_counts
|
||||||
|
):
|
||||||
"""Test execute delete dataset with user confirmation"""
|
"""Test execute delete dataset with user confirmation"""
|
||||||
# Mock the cognee module
|
# Mock the cognee module
|
||||||
mock_cognee = MagicMock()
|
mock_cognee = MagicMock()
|
||||||
mock_cognee.delete = AsyncMock()
|
mock_cognee.delete = AsyncMock()
|
||||||
|
mock_get_deletion_counts = AsyncMock()
|
||||||
|
mock_get_deletion_counts.return_value = DeletionCountsPreview()
|
||||||
|
|
||||||
with patch.dict(sys.modules, {"cognee": mock_cognee}):
|
with patch.dict(sys.modules, {"cognee": mock_cognee}):
|
||||||
command = DeleteCommand()
|
command = DeleteCommand()
|
||||||
|
|
@ -301,13 +307,16 @@ class TestDeleteCommand:
|
||||||
command.execute(args)
|
command.execute(args)
|
||||||
|
|
||||||
mock_confirm.assert_called_once_with(f"Delete dataset '{args.dataset_name}'?")
|
mock_confirm.assert_called_once_with(f"Delete dataset '{args.dataset_name}'?")
|
||||||
mock_asyncio_run.assert_called_once()
|
assert mock_asyncio_run.call_count == 2
|
||||||
assert asyncio.iscoroutine(mock_asyncio_run.call_args[0][0])
|
assert asyncio.iscoroutine(mock_asyncio_run.call_args[0][0])
|
||||||
mock_cognee.delete.assert_awaited_once_with(dataset_name="test_dataset", user_id=None)
|
mock_cognee.delete.assert_awaited_once_with(dataset_name="test_dataset", user_id=None)
|
||||||
|
|
||||||
|
@patch("cognee.cli.commands.delete_command.get_deletion_counts")
|
||||||
@patch("cognee.cli.commands.delete_command.fmt.confirm")
|
@patch("cognee.cli.commands.delete_command.fmt.confirm")
|
||||||
def test_execute_delete_cancelled(self, mock_confirm):
|
def test_execute_delete_cancelled(self, mock_confirm, mock_get_deletion_counts):
|
||||||
"""Test execute when user cancels deletion"""
|
"""Test execute when user cancels deletion"""
|
||||||
|
mock_get_deletion_counts = AsyncMock()
|
||||||
|
mock_get_deletion_counts.return_value = DeletionCountsPreview()
|
||||||
command = DeleteCommand()
|
command = DeleteCommand()
|
||||||
args = argparse.Namespace(dataset_name="test_dataset", user_id=None, all=False, force=False)
|
args = argparse.Namespace(dataset_name="test_dataset", user_id=None, all=False, force=False)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -13,6 +13,7 @@ from cognee.cli.commands.cognify_command import CognifyCommand
|
||||||
from cognee.cli.commands.delete_command import DeleteCommand
|
from cognee.cli.commands.delete_command import DeleteCommand
|
||||||
from cognee.cli.commands.config_command import ConfigCommand
|
from cognee.cli.commands.config_command import ConfigCommand
|
||||||
from cognee.cli.exceptions import CliCommandException, CliCommandInnerException
|
from cognee.cli.exceptions import CliCommandException, CliCommandInnerException
|
||||||
|
from cognee.modules.data.methods.get_deletion_counts import DeletionCountsPreview
|
||||||
|
|
||||||
|
|
||||||
# Mock asyncio.run to properly handle coroutines
|
# Mock asyncio.run to properly handle coroutines
|
||||||
|
|
@ -378,13 +379,18 @@ class TestCognifyCommandEdgeCases:
|
||||||
class TestDeleteCommandEdgeCases:
|
class TestDeleteCommandEdgeCases:
|
||||||
"""Test edge cases for DeleteCommand"""
|
"""Test edge cases for DeleteCommand"""
|
||||||
|
|
||||||
|
@patch("cognee.cli.commands.delete_command.get_deletion_counts")
|
||||||
@patch("cognee.cli.commands.delete_command.fmt.confirm")
|
@patch("cognee.cli.commands.delete_command.fmt.confirm")
|
||||||
@patch("cognee.cli.commands.delete_command.asyncio.run", side_effect=_mock_run)
|
@patch("cognee.cli.commands.delete_command.asyncio.run", side_effect=_mock_run)
|
||||||
def test_delete_all_with_user_id(self, mock_asyncio_run, mock_confirm):
|
def test_delete_all_with_user_id(
|
||||||
|
self, mock_asyncio_run, mock_confirm, mock_get_deletion_counts
|
||||||
|
):
|
||||||
"""Test delete command with both --all and --user-id"""
|
"""Test delete command with both --all and --user-id"""
|
||||||
# Mock the cognee module
|
# Mock the cognee module
|
||||||
mock_cognee = MagicMock()
|
mock_cognee = MagicMock()
|
||||||
mock_cognee.delete = AsyncMock()
|
mock_cognee.delete = AsyncMock()
|
||||||
|
mock_get_deletion_counts = AsyncMock()
|
||||||
|
mock_get_deletion_counts.return_value = DeletionCountsPreview()
|
||||||
|
|
||||||
with patch.dict(sys.modules, {"cognee": mock_cognee}):
|
with patch.dict(sys.modules, {"cognee": mock_cognee}):
|
||||||
command = DeleteCommand()
|
command = DeleteCommand()
|
||||||
|
|
@ -396,13 +402,17 @@ class TestDeleteCommandEdgeCases:
|
||||||
command.execute(args)
|
command.execute(args)
|
||||||
|
|
||||||
mock_confirm.assert_called_once_with("Delete ALL data from cognee?")
|
mock_confirm.assert_called_once_with("Delete ALL data from cognee?")
|
||||||
mock_asyncio_run.assert_called_once()
|
assert mock_asyncio_run.call_count == 2
|
||||||
assert asyncio.iscoroutine(mock_asyncio_run.call_args[0][0])
|
assert asyncio.iscoroutine(mock_asyncio_run.call_args[0][0])
|
||||||
mock_cognee.delete.assert_awaited_once_with(dataset_name=None, user_id="test_user")
|
mock_cognee.delete.assert_awaited_once_with(dataset_name=None, user_id="test_user")
|
||||||
|
|
||||||
|
@patch("cognee.cli.commands.delete_command.get_deletion_counts")
|
||||||
@patch("cognee.cli.commands.delete_command.fmt.confirm")
|
@patch("cognee.cli.commands.delete_command.fmt.confirm")
|
||||||
def test_delete_confirmation_keyboard_interrupt(self, mock_confirm):
|
def test_delete_confirmation_keyboard_interrupt(self, mock_confirm, mock_get_deletion_counts):
|
||||||
"""Test delete command when user interrupts confirmation"""
|
"""Test delete command when user interrupts confirmation"""
|
||||||
|
mock_get_deletion_counts = AsyncMock()
|
||||||
|
mock_get_deletion_counts.return_value = DeletionCountsPreview()
|
||||||
|
|
||||||
command = DeleteCommand()
|
command = DeleteCommand()
|
||||||
args = argparse.Namespace(dataset_name="test_dataset", user_id=None, all=False, force=False)
|
args = argparse.Namespace(dataset_name="test_dataset", user_id=None, all=False, force=False)
|
||||||
|
|
||||||
|
|
|
||||||
141
cognee/tests/test_advanced_pdf_loader.py
Normal file
141
cognee/tests/test_advanced_pdf_loader.py
Normal file
|
|
@ -0,0 +1,141 @@
|
||||||
|
import sys
|
||||||
|
from unittest.mock import patch, MagicMock, AsyncMock, mock_open
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from cognee.infrastructure.loaders.external.advanced_pdf_loader import AdvancedPdfLoader
|
||||||
|
|
||||||
|
advanced_pdf_loader_module = sys.modules.get(
|
||||||
|
"cognee.infrastructure.loaders.external.advanced_pdf_loader"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class MockElement:
    """Minimal stand-in for an unstructured element exposing ``to_dict``."""

    def __init__(self, category, text, metadata):
        self.category = category
        self.text = text
        self.metadata = metadata

    def to_dict(self):
        # Mirror the dict shape AdvancedPdfLoader._safe_to_dict expects.
        return {"type": self.category, "text": self.text, "metadata": self.metadata}
|
@pytest.fixture
def loader():
    # Fresh AdvancedPdfLoader per test so no state leaks between tests.
    return AdvancedPdfLoader()
@pytest.mark.parametrize(
    "extension, mime_type, expected",
    [
        ("pdf", "application/pdf", True),
        ("txt", "text/plain", False),
        ("pdf", "text/plain", False),
        ("doc", "application/pdf", False),
    ],
)
def test_can_handle(loader, extension, mime_type, expected):
    """Test can_handle method can correctly identify PDF files"""
    # Both the extension AND the MIME type must match for the loader to accept.
    assert loader.can_handle(extension, mime_type) == expected
@pytest.mark.asyncio
@patch("cognee.infrastructure.loaders.external.advanced_pdf_loader.open", new_callable=mock_open)
@patch(
    "cognee.infrastructure.loaders.external.advanced_pdf_loader.get_file_metadata",
    new_callable=AsyncMock,
)
@patch("cognee.infrastructure.loaders.external.advanced_pdf_loader.get_storage_config")
@patch("cognee.infrastructure.loaders.external.advanced_pdf_loader.get_file_storage")
@patch("cognee.infrastructure.loaders.external.advanced_pdf_loader.PyPdfLoader")
@patch("cognee.infrastructure.loaders.external.advanced_pdf_loader.partition_pdf")
async def test_load_success_with_unstructured(
    # Mock parameters are injected bottom-up relative to the decorator stack:
    # the last @patch (partition_pdf) becomes the first argument.
    mock_partition_pdf,
    mock_pypdf_loader,
    mock_get_file_storage,
    mock_get_storage_config,
    mock_get_file_metadata,
    mock_open,
    loader,
):
    """Test the main flow of using unstructured to successfully process PDF"""
    # Prepare Mock data and objects
    # Three elements across two pages: title + narrative text on page 1,
    # a table (with an HTML rendering in metadata) on page 2.
    mock_elements = [
        MockElement(
            category="Title", text="Attention Is All You Need", metadata={"page_number": 1}
        ),
        MockElement(
            category="NarrativeText",
            text="The dominant sequence transduction models are based on complex recurrent or convolutional neural networks.",
            metadata={"page_number": 1},
        ),
        MockElement(
            category="Table",
            text="This is a table.",
            metadata={"page_number": 2, "text_as_html": "<table><tr><td>Data</td></tr></table>"},
        ),
    ]
    mock_pypdf_loader.return_value.load = AsyncMock(return_value="/fake/path/fallback.txt")
    mock_partition_pdf.return_value = mock_elements
    # content_hash drives the stored file name (text_<hash>.txt below).
    mock_get_file_metadata.return_value = {"content_hash": "abc123def456"}

    mock_storage_instance = MagicMock()
    mock_storage_instance.store = AsyncMock(return_value="/stored/text_abc123def456.txt")
    mock_get_file_storage.return_value = mock_storage_instance

    mock_get_storage_config.return_value = {"data_root_directory": "/fake/data/root"}
    test_file_path = "/fake/path/document.pdf"

    # Run

    result_path = await loader.load(test_file_path)

    # Assert
    assert result_path == "/stored/text_abc123def456.txt"

    # Verify partition_pdf is called with the correct parameters
    mock_partition_pdf.assert_called_once()
    call_args, call_kwargs = mock_partition_pdf.call_args
    assert call_kwargs.get("filename") == test_file_path
    assert call_kwargs.get("strategy") == "auto"  # Default strategy

    # Verify the stored content is correct
    # Page 1 joins its two segments with a blank line; page 2 uses the table's
    # HTML rendering rather than its plain text.
    expected_content = "Page 1:\nAttention Is All You Need\n\nThe dominant sequence transduction models are based on complex recurrent or convolutional neural networks.\n\nPage 2:\n<table><tr><td>Data</td></tr></table>\n"
    mock_storage_instance.store.assert_awaited_once_with("text_abc123def456.txt", expected_content)

    # Verify fallback is not called
    mock_pypdf_loader.assert_not_called()
|
@pytest.mark.asyncio
@patch("cognee.infrastructure.loaders.external.advanced_pdf_loader.open", new_callable=mock_open)
@patch(
    "cognee.infrastructure.loaders.external.advanced_pdf_loader.get_file_metadata",
    new_callable=AsyncMock,
)
@patch("cognee.infrastructure.loaders.external.advanced_pdf_loader.PyPdfLoader")
@patch(
    "cognee.infrastructure.loaders.external.advanced_pdf_loader.partition_pdf",
    side_effect=Exception("Unstructured failed!"),
)
async def test_load_fallback_on_unstructured_exception(
    mock_partition_pdf, mock_pypdf_loader, mock_get_file_metadata, mock_open, loader
):
    """Test fallback to PyPdfLoader when unstructured throws an exception"""
    # Prepare Mock
    mock_fallback_instance = MagicMock()
    mock_fallback_instance.load = AsyncMock(return_value="/fake/path/fallback.txt")
    mock_pypdf_loader.return_value = mock_fallback_instance
    mock_get_file_metadata.return_value = {"content_hash": "anyhash"}
    test_file_path = "/fake/path/document.pdf"

    # Run
    result_path = await loader.load(test_file_path)

    # Assert
    # The loader must swallow the unstructured failure and hand back the
    # fallback loader's result unchanged.
    assert result_path == "/fake/path/fallback.txt"
    mock_partition_pdf.assert_called_once()  # Verify partition_pdf is called
    mock_fallback_instance.load.assert_awaited_once_with(test_file_path)
@ -41,7 +41,12 @@ class TestCogneeServerStart(unittest.TestCase):
|
||||||
def tearDownClass(cls):
|
def tearDownClass(cls):
|
||||||
# Terminate the server process
|
# Terminate the server process
|
||||||
if hasattr(cls, "server_process") and cls.server_process:
|
if hasattr(cls, "server_process") and cls.server_process:
|
||||||
os.killpg(os.getpgid(cls.server_process.pid), signal.SIGTERM)
|
if hasattr(os, "killpg"):
|
||||||
|
# Unix-like systems: Use process groups
|
||||||
|
os.killpg(os.getpgid(cls.server_process.pid), signal.SIGTERM)
|
||||||
|
else:
|
||||||
|
# Windows: Just terminate the main process
|
||||||
|
cls.server_process.terminate()
|
||||||
cls.server_process.wait()
|
cls.server_process.wait()
|
||||||
|
|
||||||
def test_server_is_running(self):
|
def test_server_is_running(self):
|
||||||
|
|
|
||||||
1634
poetry.lock
generated
1634
poetry.lock
generated
File diff suppressed because it is too large
Load diff
|
|
@ -95,7 +95,7 @@ chromadb = [
|
||||||
"chromadb>=0.6,<0.7",
|
"chromadb>=0.6,<0.7",
|
||||||
"pypika==0.48.9",
|
"pypika==0.48.9",
|
||||||
]
|
]
|
||||||
docs = ["unstructured[csv, doc, docx, epub, md, odt, org, ppt, pptx, rst, rtf, tsv, xlsx]>=0.18.1,<19"]
|
docs = ["unstructured[csv, doc, docx, epub, md, odt, org, ppt, pptx, rst, rtf, tsv, xlsx, pdf]>=0.18.1,<19"]
|
||||||
codegraph = [
|
codegraph = [
|
||||||
"fastembed<=0.6.0 ; python_version < '3.13'",
|
"fastembed<=0.6.0 ; python_version < '3.13'",
|
||||||
"transformers>=4.46.3,<5",
|
"transformers>=4.46.3,<5",
|
||||||
|
|
@ -142,6 +142,7 @@ Homepage = "https://www.cognee.ai"
|
||||||
Repository = "https://github.com/topoteretes/cognee"
|
Repository = "https://github.com/topoteretes/cognee"
|
||||||
|
|
||||||
[project.scripts]
|
[project.scripts]
|
||||||
|
cognee = "cognee.cli._cognee:main"
|
||||||
cognee-cli = "cognee.cli._cognee:main"
|
cognee-cli = "cognee.cli._cognee:main"
|
||||||
|
|
||||||
[build-system]
|
[build-system]
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue