Merge branch 'dev' into update-endpoint

Igor Ilic 2025-10-07 11:27:07 +02:00 committed by GitHub
commit 9ae3c97aef
22 changed files with 3074 additions and 697 deletions

View file

@@ -43,7 +43,7 @@ jobs:
strategy:
matrix:
python-version: ${{ fromJSON(inputs.python-versions) }}
os: [ubuntu-22.04, macos-13, macos-15, windows-latest]
os: [ubuntu-22.04, macos-15, windows-latest]
fail-fast: false
steps:
- name: Check out
@@ -79,7 +79,7 @@ jobs:
strategy:
matrix:
python-version: ${{ fromJSON(inputs.python-versions) }}
os: [ ubuntu-22.04, macos-13, macos-15, windows-latest ]
os: [ ubuntu-22.04, macos-15, windows-latest ]
fail-fast: false
steps:
- name: Check out
@@ -115,7 +115,7 @@ jobs:
strategy:
matrix:
python-version: ${{ fromJSON(inputs.python-versions) }}
os: [ ubuntu-22.04, macos-13, macos-15, windows-latest ]
os: [ ubuntu-22.04, macos-15, windows-latest ]
fail-fast: false
steps:
- name: Check out
@@ -151,7 +151,7 @@ jobs:
strategy:
matrix:
python-version: ${{ fromJSON(inputs.python-versions) }}
os: [ ubuntu-22.04, macos-13, macos-15, windows-latest ]
os: [ ubuntu-22.04, macos-15, windows-latest ]
fail-fast: false
steps:
- name: Check out
@@ -180,7 +180,7 @@ jobs:
strategy:
matrix:
python-version: ${{ fromJSON(inputs.python-versions) }}
os: [ ubuntu-22.04, macos-13, macos-15, windows-latest ]
os: [ ubuntu-22.04, macos-15, windows-latest ]
fail-fast: false
steps:
- name: Check out
@@ -210,7 +210,7 @@ jobs:
strategy:
matrix:
python-version: ${{ fromJSON(inputs.python-versions) }}
os: [ ubuntu-22.04, macos-13, macos-15, windows-latest ]
os: [ ubuntu-22.04, macos-15, windows-latest ]
fail-fast: false
steps:
- name: Check out

View file

@@ -3,10 +3,18 @@
import classNames from "classnames";
import { MutableRefObject, useEffect, useImperativeHandle, useRef, useState, useCallback } from "react";
import { forceCollide, forceManyBody } from "d3-force-3d";
import ForceGraph, { ForceGraphMethods, GraphData, LinkObject, NodeObject } from "react-force-graph-2d";
import dynamic from "next/dynamic";
import { GraphControlsAPI } from "./GraphControls";
import getColorForNodeType from "./getColorForNodeType";
// Dynamically import ForceGraph to prevent SSR issues
const ForceGraph = dynamic(() => import("react-force-graph-2d"), {
ssr: false,
loading: () => <div className="w-full h-full flex items-center justify-center">Loading graph...</div>
});
import type { ForceGraphMethods, GraphData, LinkObject, NodeObject } from "react-force-graph-2d";
interface GraphVisuzaliationProps {
ref: MutableRefObject<GraphVisualizationAPI>;
data?: GraphData<NodeObject, LinkObject>;
@@ -200,7 +208,7 @@ export default function GraphVisualization({ ref, data, graphControls, className
const graphRef = useRef<ForceGraphMethods>();
useEffect(() => {
if (typeof window !== "undefined" && data && graphRef.current) {
if (data && graphRef.current) {
// add collision force
graphRef.current.d3Force("collision", forceCollide(nodeSize * 1.5));
graphRef.current.d3Force("charge", forceManyBody().strength(-10).distanceMin(10).distanceMax(50));
@@ -216,56 +224,34 @@ export default function GraphVisualization({ ref, data, graphControls, className
return (
<div ref={containerRef} className={classNames("w-full h-full", className)} id="graph-container">
{(data && typeof window !== "undefined") ? (
<ForceGraph
ref={graphRef}
width={dimensions.width}
height={dimensions.height}
dagMode={graphShape as unknown as undefined}
dagLevelDistance={300}
onDagError={handleDagError}
graphData={data}
<ForceGraph
ref={graphRef}
width={dimensions.width}
height={dimensions.height}
dagMode={graphShape as unknown as undefined}
dagLevelDistance={data ? 300 : 100}
onDagError={handleDagError}
graphData={data || {
nodes: [{ id: 1, label: "Add" }, { id: 2, label: "Cognify" }, { id: 3, label: "Search" }],
links: [{ source: 1, target: 2, label: "but don't forget to" }, { source: 2, target: 3, label: "and after that you can" }],
}}
nodeLabel="label"
nodeRelSize={nodeSize}
nodeCanvasObject={renderNode}
nodeCanvasObjectMode={() => "replace"}
nodeLabel="label"
nodeRelSize={data ? nodeSize : 20}
nodeCanvasObject={data ? renderNode : renderInitialNode}
nodeCanvasObjectMode={() => data ? "replace" : "after"}
nodeAutoColorBy={data ? undefined : "type"}
linkLabel="label"
linkCanvasObject={renderLink}
linkCanvasObjectMode={() => "after"}
linkDirectionalArrowLength={3.5}
linkDirectionalArrowRelPos={1}
linkLabel="label"
linkCanvasObject={renderLink}
linkCanvasObjectMode={() => "after"}
linkDirectionalArrowLength={3.5}
linkDirectionalArrowRelPos={1}
onNodeClick={handleNodeClick}
onBackgroundClick={handleBackgroundClick}
d3VelocityDecay={0.3}
/>
) : (
<ForceGraph
ref={graphRef}
width={dimensions.width}
height={dimensions.height}
dagMode={graphShape as unknown as undefined}
dagLevelDistance={100}
graphData={{
nodes: [{ id: 1, label: "Add" }, { id: 2, label: "Cognify" }, { id: 3, label: "Search" }],
links: [{ source: 1, target: 2, label: "but don't forget to" }, { source: 2, target: 3, label: "and after that you can" }],
}}
nodeLabel="label"
nodeRelSize={20}
nodeCanvasObject={renderInitialNode}
nodeCanvasObjectMode={() => "after"}
nodeAutoColorBy="type"
linkLabel="label"
linkCanvasObject={renderLink}
linkCanvasObjectMode={() => "after"}
linkDirectionalArrowLength={3.5}
linkDirectionalArrowRelPos={1}
/>
)}
onNodeClick={handleNodeClick}
onBackgroundClick={handleBackgroundClick}
d3VelocityDecay={data ? 0.3 : undefined}
/>
</div>
);
}

View file

@@ -51,6 +51,13 @@ export default async function fetch(url: string, options: RequestInit = {}, useC
)
.then((response) => handleServerErrors(response, retry, useCloud))
.catch((error) => {
// Handle network errors more gracefully
if (error.name === 'TypeError' && error.message.includes('fetch')) {
return Promise.reject(
new Error("Backend server is not responding. Please check if the server is running.")
);
}
if (error.detail === undefined) {
return Promise.reject(
new Error("No connection to the server.")
@@ -64,8 +71,27 @@ export default async function fetch(url: string, options: RequestInit = {}, useC
});
}
fetch.checkHealth = () => {
return global.fetch(`${backendApiUrl.replace("/api", "")}/health`);
fetch.checkHealth = async () => {
const maxRetries = 5;
const retryDelay = 1000; // 1 second
for (let i = 0; i < maxRetries; i++) {
try {
const response = await global.fetch(`${backendApiUrl.replace("/api", "")}/health`);
if (response.ok) {
return response;
}
} catch (error) {
// If this is the last retry, throw the error
if (i === maxRetries - 1) {
throw error;
}
// Wait before retrying
await new Promise(resolve => setTimeout(resolve, retryDelay));
}
}
throw new Error("Backend server is not responding after multiple attempts");
};
fetch.checkMCPHealth = () => {
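The bounded retry-with-delay loop above can be mirrored on the Python side when a script or test needs to wait for the backend before calling the API. A minimal standard-library sketch, assuming a locally running server; the URL, retry count, and delay are illustrative rather than project configuration:

import time
import urllib.error
import urllib.request


def wait_for_backend(url: str = "http://localhost:8000/health", retries: int = 5, delay: float = 1.0) -> bool:
    """Poll the health endpoint until it returns 200 or the retries are exhausted."""
    for attempt in range(retries):
        try:
            with urllib.request.urlopen(url, timeout=5) as response:
                if response.status == 200:
                    return True
        except (urllib.error.URLError, OSError):
            pass  # server not reachable yet; wait and try again
        if attempt < retries - 1:
            time.sleep(delay)
    return False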

View file

@@ -194,7 +194,7 @@ class HealthChecker:
config = get_llm_config()
# Test actual API connection with minimal request
LLMGateway.show_prompt("test", "test")
LLMGateway.show_prompt("test", "test.txt")
response_time = int((time.time() - start_time) * 1000)
return ComponentHealth(

View file

@@ -1,4 +1,6 @@
import os
import platform
import signal
import socket
import subprocess
import threading
@@ -288,6 +290,7 @@ def check_node_npm() -> tuple[bool, str]:
Check if Node.js and npm are available.
Returns (is_available, error_message)
"""
try:
# Check Node.js
result = subprocess.run(["node", "--version"], capture_output=True, text=True, timeout=10)
@@ -297,8 +300,17 @@ def check_node_npm() -> tuple[bool, str]:
node_version = result.stdout.strip()
logger.debug(f"Found Node.js version: {node_version}")
# Check npm
result = subprocess.run(["npm", "--version"], capture_output=True, text=True, timeout=10)
# Check npm - handle Windows PowerShell scripts
if platform.system() == "Windows":
# On Windows, npm might be a PowerShell script, so we need to use shell=True
result = subprocess.run(
["npm", "--version"], capture_output=True, text=True, timeout=10, shell=True
)
else:
result = subprocess.run(
["npm", "--version"], capture_output=True, text=True, timeout=10
)
if result.returncode != 0:
return False, "npm is not installed or not in PATH"
@@ -320,6 +332,7 @@ def install_frontend_dependencies(frontend_path: Path) -> bool:
Install frontend dependencies if node_modules doesn't exist.
This is needed for both development and downloaded frontends since both use npm run dev.
"""
node_modules = frontend_path / "node_modules"
if node_modules.exists():
logger.debug("Frontend dependencies already installed")
@@ -328,13 +341,24 @@ def install_frontend_dependencies(frontend_path: Path) -> bool:
logger.info("Installing frontend dependencies (this may take a few minutes)...")
try:
result = subprocess.run(
["npm", "install"],
cwd=frontend_path,
capture_output=True,
text=True,
timeout=300, # 5 minutes timeout
)
# Use shell=True on Windows for npm commands
if platform.system() == "Windows":
result = subprocess.run(
["npm", "install"],
cwd=frontend_path,
capture_output=True,
text=True,
timeout=300, # 5 minutes timeout
shell=True,
)
else:
result = subprocess.run(
["npm", "install"],
cwd=frontend_path,
capture_output=True,
text=True,
timeout=300, # 5 minutes timeout
)
if result.returncode == 0:
logger.info("Frontend dependencies installed successfully")
@@ -595,14 +619,27 @@ def start_ui(
try:
# Create frontend in its own process group for clean termination
process = subprocess.Popen(
["npm", "run", "dev"],
cwd=frontend_path,
env=env,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
preexec_fn=os.setsid if hasattr(os, "setsid") else None,
)
# Use shell=True on Windows for npm commands
if platform.system() == "Windows":
process = subprocess.Popen(
["npm", "run", "dev"],
cwd=frontend_path,
env=env,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
text=True,
shell=True,
)
else:
process = subprocess.Popen(
["npm", "run", "dev"],
cwd=frontend_path,
env=env,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
text=True,
preexec_fn=os.setsid if hasattr(os, "setsid") else None,
)
# Start threads to stream frontend output with prefix
_stream_process_output(process, "stdout", "[FRONTEND]", "\033[33m") # Yellow
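Because the Windows shell=True branch now repeats for each npm subprocess.run call (the version check and npm install), the platform switch could be captured in one small wrapper; the Popen-based dev server start would need a similar but separate treatment. A sketch only, under the diff's own assumption that npm resolves to a cmd/PowerShell shim on Windows; the helper name run_npm is invented here:

import platform
import subprocess
from typing import Sequence


def run_npm(args: Sequence[str], **kwargs) -> subprocess.CompletedProcess:
    """Run an npm command, adding shell=True on Windows where npm is not a plain executable."""
    if platform.system() == "Windows":
        kwargs.setdefault("shell", True)
    return subprocess.run(["npm", *args], capture_output=True, text=True, **kwargs)


# Usage would look like run_npm(["--version"], timeout=10) or run_npm(["install"], cwd=frontend_path, timeout=300).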

View file

@@ -183,10 +183,20 @@ def main() -> int:
for pid in spawned_pids:
try:
pgid = os.getpgid(pid)
os.killpg(pgid, signal.SIGTERM)
fmt.success(f"✓ Process group {pgid} (PID {pid}) terminated.")
except (OSError, ProcessLookupError) as e:
if hasattr(os, "killpg"):
# Unix-like systems: Use process groups
pgid = os.getpgid(pid)
os.killpg(pgid, signal.SIGTERM)
fmt.success(f"✓ Process group {pgid} (PID {pid}) terminated.")
else:
# Windows: Use taskkill to terminate process and its children
subprocess.run(
["taskkill", "/F", "/T", "/PID", str(pid)],
capture_output=True,
check=False,
)
fmt.success(f"✓ Process {pid} and its children terminated.")
except (OSError, ProcessLookupError, subprocess.SubprocessError) as e:
fmt.warning(f"Could not terminate process {pid}: {e}")
sys.exit(0)
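The same Unix-versus-Windows split reappears in the server test teardown near the end of this diff, so the termination step could be written once. A sketch, not the shipped code:

import os
import signal
import subprocess


def terminate_tree(pid: int) -> None:
    """SIGTERM the whole process group on Unix; fall back to taskkill /T on Windows."""
    if hasattr(os, "killpg"):
        os.killpg(os.getpgid(pid), signal.SIGTERM)
    else:
        subprocess.run(["taskkill", "/F", "/T", "/PID", str(pid)], capture_output=True, check=False)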

View file

@@ -6,6 +6,7 @@ from cognee.cli.reference import SupportsCliCommand
from cognee.cli import DEFAULT_DOCS_URL
import cognee.cli.echo as fmt
from cognee.cli.exceptions import CliCommandException, CliCommandInnerException
from cognee.modules.data.methods.get_deletion_counts import get_deletion_counts
class DeleteCommand(SupportsCliCommand):
@@ -41,7 +42,34 @@ Be careful with deletion operations as they are irreversible.
fmt.error("Please specify what to delete: --dataset-name, --user-id, or --all")
return
# Build confirmation message
# If --force is used, skip the preview and go straight to deletion
if not args.force:
# --- START PREVIEW LOGIC ---
fmt.echo("Gathering data for preview...")
try:
preview_data = asyncio.run(
get_deletion_counts(
dataset_name=args.dataset_name,
user_id=args.user_id,
all_data=args.all,
)
)
except CliCommandException as e:
fmt.error(f"Error occured when fetching preview data: {str(e)}")
return
if not preview_data:
fmt.success("No data found to delete.")
return
fmt.echo("You are about to delete:")
fmt.echo(
f"Datasets: {preview_data.datasets}\nEntries: {preview_data.entries}\nUsers: {preview_data.users}"
)
fmt.echo("-" * 20)
# --- END PREVIEW LOGIC ---
# Build operation message for success/failure logging
if args.all:
confirm_msg = "Delete ALL data from cognee?"
operation = "all data"
@@ -51,8 +79,9 @@ Be careful with deletion operations as they are irreversible.
elif args.user_id:
confirm_msg = f"Delete all data for user '{args.user_id}'?"
operation = f"data for user '{args.user_id}'"
else:
operation = "data"
# Confirm deletion unless forced
if not args.force:
fmt.warning("This operation is irreversible!")
if not fmt.confirm(confirm_msg):
@@ -64,6 +93,8 @@ Be careful with deletion operations as they are irreversible.
# Run the async delete function
async def run_delete():
try:
# NOTE: The underlying cognee.delete() function is currently not working as expected.
# This is a separate bug that this preview feature helps to expose.
if args.all:
await cognee.delete(dataset_name=None, user_id=args.user_id)
else:
@@ -72,6 +103,7 @@ Be careful with deletion operations as they are irreversible.
raise CliCommandInnerException(f"Failed to delete: {str(e)}")
asyncio.run(run_delete())
# This success message may be inaccurate due to the underlying bug, but we leave it for now.
fmt.success(f"Successfully deleted {operation}")
except Exception as e:

View file

@@ -26,6 +26,7 @@ def read_query_prompt(prompt_file_name: str, base_directory: str = None):
read due to an error.
"""
logger = get_logger(level=ERROR)
try:
if base_directory is None:
base_directory = get_absolute_path("./infrastructure/llm/prompts")
@@ -35,8 +36,8 @@ def read_query_prompt(prompt_file_name: str, base_directory: str = None):
with open(file_path, "r", encoding="utf-8") as file:
return file.read()
except FileNotFoundError:
logger.error(f"Error: Prompt file not found. Attempted to read: %s {file_path}")
logger.error(f"Error: Prompt file not found. Attempted to read: {file_path}")
return None
except Exception as e:
logger.error(f"An error occurred: %s {e}")
logger.error(f"An error occurred: {e}")
return None
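The removed lines mixed %-style placeholders with f-strings, which logs a literal %s; the fix keeps plain f-strings. For comparison, both idioms side by side in a minimal sketch:

import logging

logging.basicConfig(level=logging.ERROR)
logger = logging.getLogger(__name__)
file_path = "/tmp/missing_prompt.txt"

logger.error("Prompt file not found: %s", file_path)    # lazy, logging fills in the placeholder
logger.error(f"Prompt file not found: {file_path}")      # eager, formatted before logging
# logger.error(f"Prompt file not found: %s {file_path}")  # the old bug: the message keeps a literal "%s"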

View file

@@ -0,0 +1 @@
Respond with: test

View file

@@ -73,10 +73,19 @@ class OpenAIAdapter(LLMInterface):
fallback_api_key: str = None,
fallback_endpoint: str = None,
):
self.aclient = instructor.from_litellm(
litellm.acompletion, mode=instructor.Mode.JSON_SCHEMA
)
self.client = instructor.from_litellm(litellm.completion, mode=instructor.Mode.JSON_SCHEMA)
# TODO: With gpt5 series models OpenAI expects JSON_SCHEMA as a mode for structured outputs.
# Make sure all new gpt models will work with this mode as well.
if "gpt-5" in model:
self.aclient = instructor.from_litellm(
litellm.acompletion, mode=instructor.Mode.JSON_SCHEMA
)
self.client = instructor.from_litellm(
litellm.completion, mode=instructor.Mode.JSON_SCHEMA
)
else:
self.aclient = instructor.from_litellm(litellm.acompletion)
self.client = instructor.from_litellm(litellm.completion)
self.transcription_model = transcription_model
self.model = model
self.api_key = api_key
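The branch keys the instructor mode off the model name: gpt-5 family models need Mode.JSON_SCHEMA for structured outputs, while other models keep instructor's default. A sketch of that decision in isolation, reusing the instructor and litellm calls shown above; it is not how the adapter itself is factored:

import instructor
import litellm


def build_structured_clients(model: str):
    """Return (async_client, sync_client); gpt-5 models require JSON_SCHEMA mode."""
    if "gpt-5" in model:
        mode = instructor.Mode.JSON_SCHEMA
        return (
            instructor.from_litellm(litellm.acompletion, mode=mode),
            instructor.from_litellm(litellm.completion, mode=mode),
        )
    return (
        instructor.from_litellm(litellm.acompletion),
        instructor.from_litellm(litellm.completion),
    )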

View file

@@ -27,6 +27,7 @@ class LoaderEngine:
self.default_loader_priority = [
"text_loader",
"advanced_pdf_loader",
"pypdf_loader",
"image_loader",
"audio_loader",
@@ -86,7 +87,7 @@
if loader.can_handle(extension=file_info.extension, mime_type=file_info.mime):
return loader
else:
raise ValueError(f"Loader does not exist: {loader_name}")
logger.info(f"Skipping {loader_name}: Preferred Loader not registered")
# Try default priority order
for loader_name in self.default_loader_priority:
@@ -95,7 +96,9 @@
if loader.can_handle(extension=file_info.extension, mime_type=file_info.mime):
return loader
else:
raise ValueError(f"Loader does not exist: {loader_name}")
logger.info(
f"Skipping {loader_name}: Loader not registered (in default priority list)."
)
return None
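With the ValueError replaced by an info log, loader resolution degrades gracefully: unknown preferred loaders are skipped, then the default priority list is tried, and None signals that nothing matched. A condensed sketch of that lookup; the loaders registry attribute is an assumption, since the hunk does not show its real name:

def resolve_loader(engine, file_info, preferred_loaders=None):
    """Try preferred loaders first, then the default priority list; skip names that are not registered."""
    for loader_name in list(preferred_loaders or []) + list(engine.default_loader_priority):
        loader = engine.loaders.get(loader_name)  # assumed registry attribute
        if loader is None:
            continue  # mirrors the "not registered" logger.info branches above
        if loader.can_handle(extension=file_info.extension, mime_type=file_info.mime):
            return loader
    return None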

View file

@@ -20,3 +20,10 @@ try:
__all__.append("UnstructuredLoader")
except ImportError:
pass
try:
from .advanced_pdf_loader import AdvancedPdfLoader
__all__.append("AdvancedPdfLoader")
except ImportError:
pass

View file

@@ -0,0 +1,244 @@
"""Advanced PDF loader leveraging unstructured for layout-aware extraction."""
from __future__ import annotations
from dataclasses import dataclass
from typing import Any, Dict, List, Optional
import asyncio
from cognee.infrastructure.files.storage import get_file_storage, get_storage_config
from cognee.infrastructure.files.utils.get_file_metadata import get_file_metadata
from cognee.infrastructure.loaders.LoaderInterface import LoaderInterface
from cognee.shared.logging_utils import get_logger
from cognee.infrastructure.loaders.external.pypdf_loader import PyPdfLoader
logger = get_logger(__name__)
try:
from unstructured.partition.pdf import partition_pdf
except ImportError as e:
logger.info(
"unstructured[pdf] not installed, can't use AdvancedPdfLoader, will use PyPdfLoader instead."
)
raise ImportError from e
@dataclass
class _PageBuffer:
page_num: Optional[int]
segments: List[str]
class AdvancedPdfLoader(LoaderInterface):
"""
PDF loader using unstructured library.
Extracts text content, images, tables from PDF files page by page, providing
structured page information and handling PDF-specific errors.
"""
@property
def supported_extensions(self) -> List[str]:
return ["pdf"]
@property
def supported_mime_types(self) -> List[str]:
return ["application/pdf"]
@property
def loader_name(self) -> str:
return "advanced_pdf_loader"
def can_handle(self, extension: str, mime_type: str) -> bool:
"""Check if file can be handled by this loader."""
# Check file extension
if extension in self.supported_extensions and mime_type in self.supported_mime_types:
return True
return False
async def load(self, file_path: str, strategy: str = "auto", **kwargs: Any) -> str:
"""Load PDF file using unstructured library. If Exception occurs, fallback to PyPDFLoader.
Args:
file_path: Path to the document file
strategy: Partitioning strategy ("auto", "fast", "hi_res", "ocr_only")
**kwargs: Additional arguments passed to unstructured partition
Returns:
LoaderResult with extracted text content and metadata
"""
try:
logger.info(f"Processing PDF: {file_path}")
with open(file_path, "rb") as f:
file_metadata = await get_file_metadata(f)
# Name ingested file of current loader based on original file content hash
storage_file_name = "text_" + file_metadata["content_hash"] + ".txt"
# Set partitioning parameters
partition_kwargs: Dict[str, Any] = {
"filename": file_path,
"strategy": strategy,
"infer_table_structure": True,
"include_page_breaks": False,
"include_metadata": True,
**kwargs,
}
# Use partition to extract elements
elements = partition_pdf(**partition_kwargs)
# Process elements into text content
page_contents = self._format_elements_by_page(elements)
# Check if there is any content
if not page_contents:
logger.warning(
"AdvancedPdfLoader returned no content. Falling back to PyPDF loader."
)
return await self._fallback(file_path, **kwargs)
# Combine all page outputs
full_content = "\n".join(page_contents)
# Store the content
storage_config = get_storage_config()
data_root_directory = storage_config["data_root_directory"]
storage = get_file_storage(data_root_directory)
full_file_path = await storage.store(storage_file_name, full_content)
return full_file_path
except Exception as exc:
logger.warning("Failed to process PDF with AdvancedPdfLoader: %s", exc)
return await self._fallback(file_path, **kwargs)
async def _fallback(self, file_path: str, **kwargs: Any) -> str:
logger.info("Falling back to PyPDF loader for %s", file_path)
fallback_loader = PyPdfLoader()
return await fallback_loader.load(file_path, **kwargs)
def _format_elements_by_page(self, elements: List[Any]) -> List[str]:
"""Format elements by page."""
page_buffers: List[_PageBuffer] = []
current_buffer = _PageBuffer(page_num=None, segments=[])
for element in elements:
element_dict = self._safe_to_dict(element)
metadata = element_dict.get("metadata", {})
page_num = metadata.get("page_number")
if current_buffer.page_num != page_num:
if current_buffer.segments:
page_buffers.append(current_buffer)
current_buffer = _PageBuffer(page_num=page_num, segments=[])
formatted = self._format_element(element_dict)
if formatted:
current_buffer.segments.append(formatted)
if current_buffer.segments:
page_buffers.append(current_buffer)
page_contents: List[str] = []
for buffer in page_buffers:
header = f"Page {buffer.page_num}:\n" if buffer.page_num is not None else "Page:"
content = header + "\n\n".join(buffer.segments) + "\n"
page_contents.append(str(content))
return page_contents
def _format_element(
self,
element: Dict[str, Any],
) -> str:
"""Format element."""
element_type = element.get("type")
text = self._clean_text(element.get("text", ""))
metadata = element.get("metadata", {})
if element_type.lower() == "table":
return self._format_table_element(element) or text
if element_type.lower() == "image":
description = text or self._format_image_element(metadata)
return description
# Ignore header and footer
if element_type.lower() in ["header", "footer"]:
pass
return text
def _format_table_element(self, element: Dict[str, Any]) -> str:
"""Format table element."""
metadata = element.get("metadata", {})
text = self._clean_text(element.get("text", ""))
table_html = metadata.get("text_as_html")
if table_html:
return table_html.strip()
return text
def _format_image_element(self, metadata: Dict[str, Any]) -> str:
"""Format image."""
placeholder = "[Image omitted]"
image_text = placeholder
coordinates = metadata.get("coordinates", {})
points = coordinates.get("points") if isinstance(coordinates, dict) else None
if points and isinstance(points, tuple) and len(points) == 4:
leftup = points[0]
rightdown = points[3]
if (
isinstance(leftup, tuple)
and isinstance(rightdown, tuple)
and len(leftup) == 2
and len(rightdown) == 2
):
image_text = f"{placeholder} (bbox=({leftup[0]}, {leftup[1]}, {rightdown[0]}, {rightdown[1]}))"
layout_width = coordinates.get("layout_width")
layout_height = coordinates.get("layout_height")
system = coordinates.get("system")
if layout_width and layout_height and system:
image_text = (
image_text
+ f", system={system}, layout_width={layout_width}, layout_height={layout_height}))"
)
return image_text
def _safe_to_dict(self, element: Any) -> Dict[str, Any]:
"""Safe to dict."""
try:
if hasattr(element, "to_dict"):
return element.to_dict()
except Exception:
pass
fallback_type = getattr(element, "category", None)
if not fallback_type:
fallback_type = getattr(element, "__class__", type("", (), {})).__name__
return {
"type": fallback_type,
"text": getattr(element, "text", ""),
"metadata": getattr(element, "metadata", {}),
}
def _clean_text(self, value: Any) -> str:
if value is None:
return ""
return str(value).replace("\xa0", " ").strip()
if __name__ == "__main__":
loader = AdvancedPdfLoader()
asyncio.run(
loader.load(
"/Users/xiaotao/work/cognee/cognee/infrastructure/loaders/external/attention_is_all_you_need.pdf"
)
)
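_format_elements_by_page buffers consecutive elements that share a page_number and renders each buffer as one "Page N:" block. A standalone simplification of that grouping, using plain dicts instead of unstructured elements and skipping the table and image handling:

from itertools import groupby


def group_pages(elements):
    """Group consecutive element dicts by page_number into one text block per page."""
    pages = []
    for page_num, elems in groupby(elements, key=lambda e: e.get("metadata", {}).get("page_number")):
        body = "\n\n".join(e.get("text", "") for e in elems if e.get("text"))
        header = f"Page {page_num}:\n" if page_num is not None else "Page:\n"
        pages.append(header + body + "\n")
    return pages


print(group_pages([
    {"text": "Attention Is All You Need", "metadata": {"page_number": 1}},
    {"text": "The dominant sequence transduction models are based on complex recurrent or convolutional neural networks.", "metadata": {"page_number": 1}},
    {"text": "<table><tr><td>Data</td></tr></table>", "metadata": {"page_number": 2}},
]))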

View file

@@ -16,3 +16,10 @@ try:
supported_loaders[UnstructuredLoader.loader_name] = UnstructuredLoader
except ImportError:
pass
try:
from cognee.infrastructure.loaders.external import AdvancedPdfLoader
supported_loaders[AdvancedPdfLoader.loader_name] = AdvancedPdfLoader
except ImportError:
pass

View file

@@ -0,0 +1,92 @@
from uuid import UUID
from cognee.cli.exceptions import CliCommandException
from cognee.infrastructure.databases.exceptions.exceptions import EntityNotFoundError
from sqlalchemy import select
from sqlalchemy.sql import func
from cognee.infrastructure.databases.relational import get_relational_engine
from cognee.modules.data.models import Dataset, Data, DatasetData
from cognee.modules.users.models import User
from cognee.modules.users.methods import get_user
from dataclasses import dataclass
@dataclass
class DeletionCountsPreview:
datasets: int = 0
entries: int = 0
users: int = 0
async def get_deletion_counts(
dataset_name: str = None, user_id: str = None, all_data: bool = False
) -> DeletionCountsPreview:
"""
Calculates the number of items that will be deleted based on the provided arguments.
"""
counts = DeletionCountsPreview()
relational_engine = get_relational_engine()
async with relational_engine.get_async_session() as session:
if dataset_name:
# Find the dataset by name
dataset_result = await session.execute(
select(Dataset).where(Dataset.name == dataset_name)
)
dataset = dataset_result.scalar_one_or_none()
if dataset is None:
raise CliCommandException(
f"No Dataset exists with the name {dataset_name}", error_code=1
)
# Count data entries linked to this dataset
count_query = (
select(func.count())
.select_from(DatasetData)
.where(DatasetData.dataset_id == dataset.id)
)
data_entry_count = (await session.execute(count_query)).scalar_one()
counts.users = 1
counts.datasets = 1
counts.entries = data_entry_count
return counts
elif all_data:
# Simplified logic: Get total counts directly from the tables.
counts.datasets = (
await session.execute(select(func.count()).select_from(Dataset))
).scalar_one()
counts.entries = (
await session.execute(select(func.count()).select_from(Data))
).scalar_one()
counts.users = (
await session.execute(select(func.count()).select_from(User))
).scalar_one()
return counts
# Placeholder for user_id logic
elif user_id:
user = None
try:
user_uuid = UUID(user_id)
user = await get_user(user_uuid)
except (ValueError, EntityNotFoundError):
raise CliCommandException(f"No User exists with ID {user_id}", error_code=1)
counts.users = 1
# Find all datasets owned by this user
datasets_query = select(Dataset).where(Dataset.owner_id == user.id)
user_datasets = (await session.execute(datasets_query)).scalars().all()
dataset_count = len(user_datasets)
counts.datasets = dataset_count
if dataset_count > 0:
dataset_ids = [d.id for d in user_datasets]
# Count all data entries across all of the user's datasets
data_count_query = (
select(func.count())
.select_from(DatasetData)
.where(DatasetData.dataset_id.in_(dataset_ids))
)
data_entry_count = (await session.execute(data_count_query)).scalar_one()
counts.entries = data_entry_count
else:
counts.entries = 0
return counts
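From the CLI's side the helper is awaited once before the confirmation prompt and its three counters are printed. A minimal usage sketch; the dataset name is illustrative and a configured cognee database is assumed:

import asyncio

from cognee.modules.data.methods.get_deletion_counts import get_deletion_counts


async def preview_delete() -> None:
    counts = await get_deletion_counts(dataset_name="my_dataset")
    print(f"Datasets: {counts.datasets}, Entries: {counts.entries}, Users: {counts.users}")


if __name__ == "__main__":
    asyncio.run(preview_delete())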

View file

@@ -12,7 +12,8 @@ from cognee.cli.commands.search_command import SearchCommand
from cognee.cli.commands.cognify_command import CognifyCommand
from cognee.cli.commands.delete_command import DeleteCommand
from cognee.cli.commands.config_command import ConfigCommand
from cognee.cli.exceptions import CliCommandException, CliCommandInnerException
from cognee.cli.exceptions import CliCommandException
from cognee.modules.data.methods.get_deletion_counts import DeletionCountsPreview
# Mock asyncio.run to properly handle coroutines
@@ -282,13 +283,18 @@ class TestDeleteCommand:
assert "all" in actions
assert "force" in actions
@patch("cognee.cli.commands.delete_command.get_deletion_counts")
@patch("cognee.cli.commands.delete_command.fmt.confirm")
@patch("cognee.cli.commands.delete_command.asyncio.run", side_effect=_mock_run)
def test_execute_delete_dataset_with_confirmation(self, mock_asyncio_run, mock_confirm):
def test_execute_delete_dataset_with_confirmation(
self, mock_asyncio_run, mock_confirm, mock_get_deletion_counts
):
"""Test execute delete dataset with user confirmation"""
# Mock the cognee module
mock_cognee = MagicMock()
mock_cognee.delete = AsyncMock()
mock_get_deletion_counts = AsyncMock()
mock_get_deletion_counts.return_value = DeletionCountsPreview()
with patch.dict(sys.modules, {"cognee": mock_cognee}):
command = DeleteCommand()
@@ -301,13 +307,16 @@ class TestDeleteCommand:
command.execute(args)
mock_confirm.assert_called_once_with(f"Delete dataset '{args.dataset_name}'?")
mock_asyncio_run.assert_called_once()
assert mock_asyncio_run.call_count == 2
assert asyncio.iscoroutine(mock_asyncio_run.call_args[0][0])
mock_cognee.delete.assert_awaited_once_with(dataset_name="test_dataset", user_id=None)
@patch("cognee.cli.commands.delete_command.get_deletion_counts")
@patch("cognee.cli.commands.delete_command.fmt.confirm")
def test_execute_delete_cancelled(self, mock_confirm):
def test_execute_delete_cancelled(self, mock_confirm, mock_get_deletion_counts):
"""Test execute when user cancels deletion"""
mock_get_deletion_counts = AsyncMock()
mock_get_deletion_counts.return_value = DeletionCountsPreview()
command = DeleteCommand()
args = argparse.Namespace(dataset_name="test_dataset", user_id=None, all=False, force=False)

View file

@@ -13,6 +13,7 @@ from cognee.cli.commands.cognify_command import CognifyCommand
from cognee.cli.commands.delete_command import DeleteCommand
from cognee.cli.commands.config_command import ConfigCommand
from cognee.cli.exceptions import CliCommandException, CliCommandInnerException
from cognee.modules.data.methods.get_deletion_counts import DeletionCountsPreview
# Mock asyncio.run to properly handle coroutines
@@ -378,13 +379,18 @@ class TestCognifyCommandEdgeCases:
class TestDeleteCommandEdgeCases:
"""Test edge cases for DeleteCommand"""
@patch("cognee.cli.commands.delete_command.get_deletion_counts")
@patch("cognee.cli.commands.delete_command.fmt.confirm")
@patch("cognee.cli.commands.delete_command.asyncio.run", side_effect=_mock_run)
def test_delete_all_with_user_id(self, mock_asyncio_run, mock_confirm):
def test_delete_all_with_user_id(
self, mock_asyncio_run, mock_confirm, mock_get_deletion_counts
):
"""Test delete command with both --all and --user-id"""
# Mock the cognee module
mock_cognee = MagicMock()
mock_cognee.delete = AsyncMock()
mock_get_deletion_counts = AsyncMock()
mock_get_deletion_counts.return_value = DeletionCountsPreview()
with patch.dict(sys.modules, {"cognee": mock_cognee}):
command = DeleteCommand()
@@ -396,13 +402,17 @@ class TestDeleteCommandEdgeCases:
command.execute(args)
mock_confirm.assert_called_once_with("Delete ALL data from cognee?")
mock_asyncio_run.assert_called_once()
assert mock_asyncio_run.call_count == 2
assert asyncio.iscoroutine(mock_asyncio_run.call_args[0][0])
mock_cognee.delete.assert_awaited_once_with(dataset_name=None, user_id="test_user")
@patch("cognee.cli.commands.delete_command.get_deletion_counts")
@patch("cognee.cli.commands.delete_command.fmt.confirm")
def test_delete_confirmation_keyboard_interrupt(self, mock_confirm):
def test_delete_confirmation_keyboard_interrupt(self, mock_confirm, mock_get_deletion_counts):
"""Test delete command when user interrupts confirmation"""
mock_get_deletion_counts = AsyncMock()
mock_get_deletion_counts.return_value = DeletionCountsPreview()
command = DeleteCommand()
args = argparse.Namespace(dataset_name="test_dataset", user_id=None, all=False, force=False)

View file

@@ -0,0 +1,141 @@
import sys
from unittest.mock import patch, MagicMock, AsyncMock, mock_open
import pytest
from cognee.infrastructure.loaders.external.advanced_pdf_loader import AdvancedPdfLoader
advanced_pdf_loader_module = sys.modules.get(
"cognee.infrastructure.loaders.external.advanced_pdf_loader"
)
class MockElement:
def __init__(self, category, text, metadata):
self.category = category
self.text = text
self.metadata = metadata
def to_dict(self):
return {
"type": self.category,
"text": self.text,
"metadata": self.metadata,
}
@pytest.fixture
def loader():
return AdvancedPdfLoader()
@pytest.mark.parametrize(
"extension, mime_type, expected",
[
("pdf", "application/pdf", True),
("txt", "text/plain", False),
("pdf", "text/plain", False),
("doc", "application/pdf", False),
],
)
def test_can_handle(loader, extension, mime_type, expected):
"""Test can_handle method can correctly identify PDF files"""
assert loader.can_handle(extension, mime_type) == expected
@pytest.mark.asyncio
@patch("cognee.infrastructure.loaders.external.advanced_pdf_loader.open", new_callable=mock_open)
@patch(
"cognee.infrastructure.loaders.external.advanced_pdf_loader.get_file_metadata",
new_callable=AsyncMock,
)
@patch("cognee.infrastructure.loaders.external.advanced_pdf_loader.get_storage_config")
@patch("cognee.infrastructure.loaders.external.advanced_pdf_loader.get_file_storage")
@patch("cognee.infrastructure.loaders.external.advanced_pdf_loader.PyPdfLoader")
@patch("cognee.infrastructure.loaders.external.advanced_pdf_loader.partition_pdf")
async def test_load_success_with_unstructured(
mock_partition_pdf,
mock_pypdf_loader,
mock_get_file_storage,
mock_get_storage_config,
mock_get_file_metadata,
mock_open,
loader,
):
"""Test the main flow of using unstructured to successfully process PDF"""
# Prepare Mock data and objects
mock_elements = [
MockElement(
category="Title", text="Attention Is All You Need", metadata={"page_number": 1}
),
MockElement(
category="NarrativeText",
text="The dominant sequence transduction models are based on complex recurrent or convolutional neural networks.",
metadata={"page_number": 1},
),
MockElement(
category="Table",
text="This is a table.",
metadata={"page_number": 2, "text_as_html": "<table><tr><td>Data</td></tr></table>"},
),
]
mock_pypdf_loader.return_value.load = AsyncMock(return_value="/fake/path/fallback.txt")
mock_partition_pdf.return_value = mock_elements
mock_get_file_metadata.return_value = {"content_hash": "abc123def456"}
mock_storage_instance = MagicMock()
mock_storage_instance.store = AsyncMock(return_value="/stored/text_abc123def456.txt")
mock_get_file_storage.return_value = mock_storage_instance
mock_get_storage_config.return_value = {"data_root_directory": "/fake/data/root"}
test_file_path = "/fake/path/document.pdf"
# Run
result_path = await loader.load(test_file_path)
# Assert
assert result_path == "/stored/text_abc123def456.txt"
# Verify partition_pdf is called with the correct parameters
mock_partition_pdf.assert_called_once()
call_args, call_kwargs = mock_partition_pdf.call_args
assert call_kwargs.get("filename") == test_file_path
assert call_kwargs.get("strategy") == "auto" # Default strategy
# Verify the stored content is correct
expected_content = "Page 1:\nAttention Is All You Need\n\nThe dominant sequence transduction models are based on complex recurrent or convolutional neural networks.\n\nPage 2:\n<table><tr><td>Data</td></tr></table>\n"
mock_storage_instance.store.assert_awaited_once_with("text_abc123def456.txt", expected_content)
# Verify fallback is not called
mock_pypdf_loader.assert_not_called()
@pytest.mark.asyncio
@patch("cognee.infrastructure.loaders.external.advanced_pdf_loader.open", new_callable=mock_open)
@patch(
"cognee.infrastructure.loaders.external.advanced_pdf_loader.get_file_metadata",
new_callable=AsyncMock,
)
@patch("cognee.infrastructure.loaders.external.advanced_pdf_loader.PyPdfLoader")
@patch(
"cognee.infrastructure.loaders.external.advanced_pdf_loader.partition_pdf",
side_effect=Exception("Unstructured failed!"),
)
async def test_load_fallback_on_unstructured_exception(
mock_partition_pdf, mock_pypdf_loader, mock_get_file_metadata, mock_open, loader
):
"""Test fallback to PyPdfLoader when unstructured throws an exception"""
# Prepare Mock
mock_fallback_instance = MagicMock()
mock_fallback_instance.load = AsyncMock(return_value="/fake/path/fallback.txt")
mock_pypdf_loader.return_value = mock_fallback_instance
mock_get_file_metadata.return_value = {"content_hash": "anyhash"}
test_file_path = "/fake/path/document.pdf"
# Run
result_path = await loader.load(test_file_path)
# Assert
assert result_path == "/fake/path/fallback.txt"
mock_partition_pdf.assert_called_once() # Verify partition_pdf is called
mock_fallback_instance.load.assert_awaited_once_with(test_file_path)

View file

@@ -41,7 +41,12 @@ class TestCogneeServerStart(unittest.TestCase):
def tearDownClass(cls):
# Terminate the server process
if hasattr(cls, "server_process") and cls.server_process:
os.killpg(os.getpgid(cls.server_process.pid), signal.SIGTERM)
if hasattr(os, "killpg"):
# Unix-like systems: Use process groups
os.killpg(os.getpgid(cls.server_process.pid), signal.SIGTERM)
else:
# Windows: Just terminate the main process
cls.server_process.terminate()
cls.server_process.wait()
def test_server_is_running(self):

poetry.lock (generated): 1634 changed lines. File diff suppressed because it is too large.

View file

@@ -95,7 +95,7 @@ chromadb = [
"chromadb>=0.6,<0.7",
"pypika==0.48.9",
]
docs = ["unstructured[csv, doc, docx, epub, md, odt, org, ppt, pptx, rst, rtf, tsv, xlsx]>=0.18.1,<19"]
docs = ["unstructured[csv, doc, docx, epub, md, odt, org, ppt, pptx, rst, rtf, tsv, xlsx, pdf]>=0.18.1,<19"]
codegraph = [
"fastembed<=0.6.0 ; python_version < '3.13'",
"transformers>=4.46.3,<5",
@@ -142,6 +142,7 @@ Homepage = "https://www.cognee.ai"
Repository = "https://github.com/topoteretes/cognee"
[project.scripts]
cognee = "cognee.cli._cognee:main"
cognee-cli = "cognee.cli._cognee:main"
[build-system]

uv.lock (generated): 1320 changed lines. File diff suppressed because it is too large.