Merge branch 'dev' into update-endpoint

This commit is contained in:
Igor Ilic 2025-10-07 11:27:07 +02:00 committed by GitHub
commit 9ae3c97aef
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
22 changed files with 3074 additions and 697 deletions

View file

@ -43,7 +43,7 @@ jobs:
strategy: strategy:
matrix: matrix:
python-version: ${{ fromJSON(inputs.python-versions) }} python-version: ${{ fromJSON(inputs.python-versions) }}
os: [ubuntu-22.04, macos-13, macos-15, windows-latest] os: [ubuntu-22.04, macos-15, windows-latest]
fail-fast: false fail-fast: false
steps: steps:
- name: Check out - name: Check out
@ -79,7 +79,7 @@ jobs:
strategy: strategy:
matrix: matrix:
python-version: ${{ fromJSON(inputs.python-versions) }} python-version: ${{ fromJSON(inputs.python-versions) }}
os: [ ubuntu-22.04, macos-13, macos-15, windows-latest ] os: [ ubuntu-22.04, macos-15, windows-latest ]
fail-fast: false fail-fast: false
steps: steps:
- name: Check out - name: Check out
@ -115,7 +115,7 @@ jobs:
strategy: strategy:
matrix: matrix:
python-version: ${{ fromJSON(inputs.python-versions) }} python-version: ${{ fromJSON(inputs.python-versions) }}
os: [ ubuntu-22.04, macos-13, macos-15, windows-latest ] os: [ ubuntu-22.04, macos-15, windows-latest ]
fail-fast: false fail-fast: false
steps: steps:
- name: Check out - name: Check out
@ -151,7 +151,7 @@ jobs:
strategy: strategy:
matrix: matrix:
python-version: ${{ fromJSON(inputs.python-versions) }} python-version: ${{ fromJSON(inputs.python-versions) }}
os: [ ubuntu-22.04, macos-13, macos-15, windows-latest ] os: [ ubuntu-22.04, macos-15, windows-latest ]
fail-fast: false fail-fast: false
steps: steps:
- name: Check out - name: Check out
@ -180,7 +180,7 @@ jobs:
strategy: strategy:
matrix: matrix:
python-version: ${{ fromJSON(inputs.python-versions) }} python-version: ${{ fromJSON(inputs.python-versions) }}
os: [ ubuntu-22.04, macos-13, macos-15, windows-latest ] os: [ ubuntu-22.04, macos-15, windows-latest ]
fail-fast: false fail-fast: false
steps: steps:
- name: Check out - name: Check out
@ -210,7 +210,7 @@ jobs:
strategy: strategy:
matrix: matrix:
python-version: ${{ fromJSON(inputs.python-versions) }} python-version: ${{ fromJSON(inputs.python-versions) }}
os: [ ubuntu-22.04, macos-13, macos-15, windows-latest ] os: [ ubuntu-22.04, macos-15, windows-latest ]
fail-fast: false fail-fast: false
steps: steps:
- name: Check out - name: Check out

View file

@ -3,10 +3,18 @@
import classNames from "classnames"; import classNames from "classnames";
import { MutableRefObject, useEffect, useImperativeHandle, useRef, useState, useCallback } from "react"; import { MutableRefObject, useEffect, useImperativeHandle, useRef, useState, useCallback } from "react";
import { forceCollide, forceManyBody } from "d3-force-3d"; import { forceCollide, forceManyBody } from "d3-force-3d";
import ForceGraph, { ForceGraphMethods, GraphData, LinkObject, NodeObject } from "react-force-graph-2d"; import dynamic from "next/dynamic";
import { GraphControlsAPI } from "./GraphControls"; import { GraphControlsAPI } from "./GraphControls";
import getColorForNodeType from "./getColorForNodeType"; import getColorForNodeType from "./getColorForNodeType";
// Dynamically import ForceGraph to prevent SSR issues
const ForceGraph = dynamic(() => import("react-force-graph-2d"), {
ssr: false,
loading: () => <div className="w-full h-full flex items-center justify-center">Loading graph...</div>
});
import type { ForceGraphMethods, GraphData, LinkObject, NodeObject } from "react-force-graph-2d";
interface GraphVisuzaliationProps { interface GraphVisuzaliationProps {
ref: MutableRefObject<GraphVisualizationAPI>; ref: MutableRefObject<GraphVisualizationAPI>;
data?: GraphData<NodeObject, LinkObject>; data?: GraphData<NodeObject, LinkObject>;
@ -200,7 +208,7 @@ export default function GraphVisualization({ ref, data, graphControls, className
const graphRef = useRef<ForceGraphMethods>(); const graphRef = useRef<ForceGraphMethods>();
useEffect(() => { useEffect(() => {
if (typeof window !== "undefined" && data && graphRef.current) { if (data && graphRef.current) {
// add collision force // add collision force
graphRef.current.d3Force("collision", forceCollide(nodeSize * 1.5)); graphRef.current.d3Force("collision", forceCollide(nodeSize * 1.5));
graphRef.current.d3Force("charge", forceManyBody().strength(-10).distanceMin(10).distanceMax(50)); graphRef.current.d3Force("charge", forceManyBody().strength(-10).distanceMin(10).distanceMax(50));
@ -216,56 +224,34 @@ export default function GraphVisualization({ ref, data, graphControls, className
return ( return (
<div ref={containerRef} className={classNames("w-full h-full", className)} id="graph-container"> <div ref={containerRef} className={classNames("w-full h-full", className)} id="graph-container">
{(data && typeof window !== "undefined") ? ( <ForceGraph
<ForceGraph ref={graphRef}
ref={graphRef} width={dimensions.width}
width={dimensions.width} height={dimensions.height}
height={dimensions.height} dagMode={graphShape as unknown as undefined}
dagMode={graphShape as unknown as undefined} dagLevelDistance={data ? 300 : 100}
dagLevelDistance={300} onDagError={handleDagError}
onDagError={handleDagError} graphData={data || {
graphData={data} nodes: [{ id: 1, label: "Add" }, { id: 2, label: "Cognify" }, { id: 3, label: "Search" }],
links: [{ source: 1, target: 2, label: "but don't forget to" }, { source: 2, target: 3, label: "and after that you can" }],
}}
nodeLabel="label" nodeLabel="label"
nodeRelSize={nodeSize} nodeRelSize={data ? nodeSize : 20}
nodeCanvasObject={renderNode} nodeCanvasObject={data ? renderNode : renderInitialNode}
nodeCanvasObjectMode={() => "replace"} nodeCanvasObjectMode={() => data ? "replace" : "after"}
nodeAutoColorBy={data ? undefined : "type"}
linkLabel="label" linkLabel="label"
linkCanvasObject={renderLink} linkCanvasObject={renderLink}
linkCanvasObjectMode={() => "after"} linkCanvasObjectMode={() => "after"}
linkDirectionalArrowLength={3.5} linkDirectionalArrowLength={3.5}
linkDirectionalArrowRelPos={1} linkDirectionalArrowRelPos={1}
onNodeClick={handleNodeClick} onNodeClick={handleNodeClick}
onBackgroundClick={handleBackgroundClick} onBackgroundClick={handleBackgroundClick}
d3VelocityDecay={0.3} d3VelocityDecay={data ? 0.3 : undefined}
/> />
) : (
<ForceGraph
ref={graphRef}
width={dimensions.width}
height={dimensions.height}
dagMode={graphShape as unknown as undefined}
dagLevelDistance={100}
graphData={{
nodes: [{ id: 1, label: "Add" }, { id: 2, label: "Cognify" }, { id: 3, label: "Search" }],
links: [{ source: 1, target: 2, label: "but don't forget to" }, { source: 2, target: 3, label: "and after that you can" }],
}}
nodeLabel="label"
nodeRelSize={20}
nodeCanvasObject={renderInitialNode}
nodeCanvasObjectMode={() => "after"}
nodeAutoColorBy="type"
linkLabel="label"
linkCanvasObject={renderLink}
linkCanvasObjectMode={() => "after"}
linkDirectionalArrowLength={3.5}
linkDirectionalArrowRelPos={1}
/>
)}
</div> </div>
); );
} }

View file

@ -51,6 +51,13 @@ export default async function fetch(url: string, options: RequestInit = {}, useC
) )
.then((response) => handleServerErrors(response, retry, useCloud)) .then((response) => handleServerErrors(response, retry, useCloud))
.catch((error) => { .catch((error) => {
// Handle network errors more gracefully
if (error.name === 'TypeError' && error.message.includes('fetch')) {
return Promise.reject(
new Error("Backend server is not responding. Please check if the server is running.")
);
}
if (error.detail === undefined) { if (error.detail === undefined) {
return Promise.reject( return Promise.reject(
new Error("No connection to the server.") new Error("No connection to the server.")
@ -64,8 +71,27 @@ export default async function fetch(url: string, options: RequestInit = {}, useC
}); });
} }
fetch.checkHealth = () => { fetch.checkHealth = async () => {
return global.fetch(`${backendApiUrl.replace("/api", "")}/health`); const maxRetries = 5;
const retryDelay = 1000; // 1 second
for (let i = 0; i < maxRetries; i++) {
try {
const response = await global.fetch(`${backendApiUrl.replace("/api", "")}/health`);
if (response.ok) {
return response;
}
} catch (error) {
// If this is the last retry, throw the error
if (i === maxRetries - 1) {
throw error;
}
// Wait before retrying
await new Promise(resolve => setTimeout(resolve, retryDelay));
}
}
throw new Error("Backend server is not responding after multiple attempts");
}; };
fetch.checkMCPHealth = () => { fetch.checkMCPHealth = () => {

View file

@ -194,7 +194,7 @@ class HealthChecker:
config = get_llm_config() config = get_llm_config()
# Test actual API connection with minimal request # Test actual API connection with minimal request
LLMGateway.show_prompt("test", "test") LLMGateway.show_prompt("test", "test.txt")
response_time = int((time.time() - start_time) * 1000) response_time = int((time.time() - start_time) * 1000)
return ComponentHealth( return ComponentHealth(

View file

@ -1,4 +1,6 @@
import os import os
import platform
import signal
import socket import socket
import subprocess import subprocess
import threading import threading
@ -288,6 +290,7 @@ def check_node_npm() -> tuple[bool, str]:
Check if Node.js and npm are available. Check if Node.js and npm are available.
Returns (is_available, error_message) Returns (is_available, error_message)
""" """
try: try:
# Check Node.js # Check Node.js
result = subprocess.run(["node", "--version"], capture_output=True, text=True, timeout=10) result = subprocess.run(["node", "--version"], capture_output=True, text=True, timeout=10)
@ -297,8 +300,17 @@ def check_node_npm() -> tuple[bool, str]:
node_version = result.stdout.strip() node_version = result.stdout.strip()
logger.debug(f"Found Node.js version: {node_version}") logger.debug(f"Found Node.js version: {node_version}")
# Check npm # Check npm - handle Windows PowerShell scripts
result = subprocess.run(["npm", "--version"], capture_output=True, text=True, timeout=10) if platform.system() == "Windows":
# On Windows, npm might be a PowerShell script, so we need to use shell=True
result = subprocess.run(
["npm", "--version"], capture_output=True, text=True, timeout=10, shell=True
)
else:
result = subprocess.run(
["npm", "--version"], capture_output=True, text=True, timeout=10
)
if result.returncode != 0: if result.returncode != 0:
return False, "npm is not installed or not in PATH" return False, "npm is not installed or not in PATH"
@ -320,6 +332,7 @@ def install_frontend_dependencies(frontend_path: Path) -> bool:
Install frontend dependencies if node_modules doesn't exist. Install frontend dependencies if node_modules doesn't exist.
This is needed for both development and downloaded frontends since both use npm run dev. This is needed for both development and downloaded frontends since both use npm run dev.
""" """
node_modules = frontend_path / "node_modules" node_modules = frontend_path / "node_modules"
if node_modules.exists(): if node_modules.exists():
logger.debug("Frontend dependencies already installed") logger.debug("Frontend dependencies already installed")
@ -328,13 +341,24 @@ def install_frontend_dependencies(frontend_path: Path) -> bool:
logger.info("Installing frontend dependencies (this may take a few minutes)...") logger.info("Installing frontend dependencies (this may take a few minutes)...")
try: try:
result = subprocess.run( # Use shell=True on Windows for npm commands
["npm", "install"], if platform.system() == "Windows":
cwd=frontend_path, result = subprocess.run(
capture_output=True, ["npm", "install"],
text=True, cwd=frontend_path,
timeout=300, # 5 minutes timeout capture_output=True,
) text=True,
timeout=300, # 5 minutes timeout
shell=True,
)
else:
result = subprocess.run(
["npm", "install"],
cwd=frontend_path,
capture_output=True,
text=True,
timeout=300, # 5 minutes timeout
)
if result.returncode == 0: if result.returncode == 0:
logger.info("Frontend dependencies installed successfully") logger.info("Frontend dependencies installed successfully")
@ -595,14 +619,27 @@ def start_ui(
try: try:
# Create frontend in its own process group for clean termination # Create frontend in its own process group for clean termination
process = subprocess.Popen( # Use shell=True on Windows for npm commands
["npm", "run", "dev"], if platform.system() == "Windows":
cwd=frontend_path, process = subprocess.Popen(
env=env, ["npm", "run", "dev"],
stdout=subprocess.PIPE, cwd=frontend_path,
stderr=subprocess.PIPE, env=env,
preexec_fn=os.setsid if hasattr(os, "setsid") else None, stdout=subprocess.PIPE,
) stderr=subprocess.PIPE,
text=True,
shell=True,
)
else:
process = subprocess.Popen(
["npm", "run", "dev"],
cwd=frontend_path,
env=env,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
text=True,
preexec_fn=os.setsid if hasattr(os, "setsid") else None,
)
# Start threads to stream frontend output with prefix # Start threads to stream frontend output with prefix
_stream_process_output(process, "stdout", "[FRONTEND]", "\033[33m") # Yellow _stream_process_output(process, "stdout", "[FRONTEND]", "\033[33m") # Yellow

View file

@ -183,10 +183,20 @@ def main() -> int:
for pid in spawned_pids: for pid in spawned_pids:
try: try:
pgid = os.getpgid(pid) if hasattr(os, "killpg"):
os.killpg(pgid, signal.SIGTERM) # Unix-like systems: Use process groups
fmt.success(f"✓ Process group {pgid} (PID {pid}) terminated.") pgid = os.getpgid(pid)
except (OSError, ProcessLookupError) as e: os.killpg(pgid, signal.SIGTERM)
fmt.success(f"✓ Process group {pgid} (PID {pid}) terminated.")
else:
# Windows: Use taskkill to terminate process and its children
subprocess.run(
["taskkill", "/F", "/T", "/PID", str(pid)],
capture_output=True,
check=False,
)
fmt.success(f"✓ Process {pid} and its children terminated.")
except (OSError, ProcessLookupError, subprocess.SubprocessError) as e:
fmt.warning(f"Could not terminate process {pid}: {e}") fmt.warning(f"Could not terminate process {pid}: {e}")
sys.exit(0) sys.exit(0)

View file

@ -6,6 +6,7 @@ from cognee.cli.reference import SupportsCliCommand
from cognee.cli import DEFAULT_DOCS_URL from cognee.cli import DEFAULT_DOCS_URL
import cognee.cli.echo as fmt import cognee.cli.echo as fmt
from cognee.cli.exceptions import CliCommandException, CliCommandInnerException from cognee.cli.exceptions import CliCommandException, CliCommandInnerException
from cognee.modules.data.methods.get_deletion_counts import get_deletion_counts
class DeleteCommand(SupportsCliCommand): class DeleteCommand(SupportsCliCommand):
@ -41,7 +42,34 @@ Be careful with deletion operations as they are irreversible.
fmt.error("Please specify what to delete: --dataset-name, --user-id, or --all") fmt.error("Please specify what to delete: --dataset-name, --user-id, or --all")
return return
# Build confirmation message # If --force is used, skip the preview and go straight to deletion
if not args.force:
# --- START PREVIEW LOGIC ---
fmt.echo("Gathering data for preview...")
try:
preview_data = asyncio.run(
get_deletion_counts(
dataset_name=args.dataset_name,
user_id=args.user_id,
all_data=args.all,
)
)
except CliCommandException as e:
fmt.error(f"Error occured when fetching preview data: {str(e)}")
return
if not preview_data:
fmt.success("No data found to delete.")
return
fmt.echo("You are about to delete:")
fmt.echo(
f"Datasets: {preview_data.datasets}\nEntries: {preview_data.entries}\nUsers: {preview_data.users}"
)
fmt.echo("-" * 20)
# --- END PREVIEW LOGIC ---
# Build operation message for success/failure logging
if args.all: if args.all:
confirm_msg = "Delete ALL data from cognee?" confirm_msg = "Delete ALL data from cognee?"
operation = "all data" operation = "all data"
@ -51,8 +79,9 @@ Be careful with deletion operations as they are irreversible.
elif args.user_id: elif args.user_id:
confirm_msg = f"Delete all data for user '{args.user_id}'?" confirm_msg = f"Delete all data for user '{args.user_id}'?"
operation = f"data for user '{args.user_id}'" operation = f"data for user '{args.user_id}'"
else:
operation = "data"
# Confirm deletion unless forced
if not args.force: if not args.force:
fmt.warning("This operation is irreversible!") fmt.warning("This operation is irreversible!")
if not fmt.confirm(confirm_msg): if not fmt.confirm(confirm_msg):
@ -64,6 +93,8 @@ Be careful with deletion operations as they are irreversible.
# Run the async delete function # Run the async delete function
async def run_delete(): async def run_delete():
try: try:
# NOTE: The underlying cognee.delete() function is currently not working as expected.
# This is a separate bug that this preview feature helps to expose.
if args.all: if args.all:
await cognee.delete(dataset_name=None, user_id=args.user_id) await cognee.delete(dataset_name=None, user_id=args.user_id)
else: else:
@ -72,6 +103,7 @@ Be careful with deletion operations as they are irreversible.
raise CliCommandInnerException(f"Failed to delete: {str(e)}") raise CliCommandInnerException(f"Failed to delete: {str(e)}")
asyncio.run(run_delete()) asyncio.run(run_delete())
# This success message may be inaccurate due to the underlying bug, but we leave it for now.
fmt.success(f"Successfully deleted {operation}") fmt.success(f"Successfully deleted {operation}")
except Exception as e: except Exception as e:

View file

@ -26,6 +26,7 @@ def read_query_prompt(prompt_file_name: str, base_directory: str = None):
read due to an error. read due to an error.
""" """
logger = get_logger(level=ERROR) logger = get_logger(level=ERROR)
try: try:
if base_directory is None: if base_directory is None:
base_directory = get_absolute_path("./infrastructure/llm/prompts") base_directory = get_absolute_path("./infrastructure/llm/prompts")
@ -35,8 +36,8 @@ def read_query_prompt(prompt_file_name: str, base_directory: str = None):
with open(file_path, "r", encoding="utf-8") as file: with open(file_path, "r", encoding="utf-8") as file:
return file.read() return file.read()
except FileNotFoundError: except FileNotFoundError:
logger.error(f"Error: Prompt file not found. Attempted to read: %s {file_path}") logger.error(f"Error: Prompt file not found. Attempted to read: {file_path}")
return None return None
except Exception as e: except Exception as e:
logger.error(f"An error occurred: %s {e}") logger.error(f"An error occurred: {e}")
return None return None

View file

@ -0,0 +1 @@
Respond with: test

View file

@ -73,10 +73,19 @@ class OpenAIAdapter(LLMInterface):
fallback_api_key: str = None, fallback_api_key: str = None,
fallback_endpoint: str = None, fallback_endpoint: str = None,
): ):
self.aclient = instructor.from_litellm( # TODO: With gpt5 series models OpenAI expects JSON_SCHEMA as a mode for structured outputs.
litellm.acompletion, mode=instructor.Mode.JSON_SCHEMA # Make sure all new gpt models will work with this mode as well.
) if "gpt-5" in model:
self.client = instructor.from_litellm(litellm.completion, mode=instructor.Mode.JSON_SCHEMA) self.aclient = instructor.from_litellm(
litellm.acompletion, mode=instructor.Mode.JSON_SCHEMA
)
self.client = instructor.from_litellm(
litellm.completion, mode=instructor.Mode.JSON_SCHEMA
)
else:
self.aclient = instructor.from_litellm(litellm.acompletion)
self.client = instructor.from_litellm(litellm.completion)
self.transcription_model = transcription_model self.transcription_model = transcription_model
self.model = model self.model = model
self.api_key = api_key self.api_key = api_key

View file

@ -27,6 +27,7 @@ class LoaderEngine:
self.default_loader_priority = [ self.default_loader_priority = [
"text_loader", "text_loader",
"advanced_pdf_loader",
"pypdf_loader", "pypdf_loader",
"image_loader", "image_loader",
"audio_loader", "audio_loader",
@ -86,7 +87,7 @@ class LoaderEngine:
if loader.can_handle(extension=file_info.extension, mime_type=file_info.mime): if loader.can_handle(extension=file_info.extension, mime_type=file_info.mime):
return loader return loader
else: else:
raise ValueError(f"Loader does not exist: {loader_name}") logger.info(f"Skipping {loader_name}: Preferred Loader not registered")
# Try default priority order # Try default priority order
for loader_name in self.default_loader_priority: for loader_name in self.default_loader_priority:
@ -95,7 +96,9 @@ class LoaderEngine:
if loader.can_handle(extension=file_info.extension, mime_type=file_info.mime): if loader.can_handle(extension=file_info.extension, mime_type=file_info.mime):
return loader return loader
else: else:
raise ValueError(f"Loader does not exist: {loader_name}") logger.info(
f"Skipping {loader_name}: Loader not registered (in default priority list)."
)
return None return None

View file

@ -20,3 +20,10 @@ try:
__all__.append("UnstructuredLoader") __all__.append("UnstructuredLoader")
except ImportError: except ImportError:
pass pass
try:
from .advanced_pdf_loader import AdvancedPdfLoader
__all__.append("AdvancedPdfLoader")
except ImportError:
pass

View file

@ -0,0 +1,244 @@
"""Advanced PDF loader leveraging unstructured for layout-aware extraction."""
from __future__ import annotations
from dataclasses import dataclass
from typing import Any, Dict, List, Optional
import asyncio
from cognee.infrastructure.files.storage import get_file_storage, get_storage_config
from cognee.infrastructure.files.utils.get_file_metadata import get_file_metadata
from cognee.infrastructure.loaders.LoaderInterface import LoaderInterface
from cognee.shared.logging_utils import get_logger
from cognee.infrastructure.loaders.external.pypdf_loader import PyPdfLoader
logger = get_logger(__name__)
try:
from unstructured.partition.pdf import partition_pdf
except ImportError as e:
logger.info(
"unstructured[pdf] not installed, can't use AdvancedPdfLoader, will use PyPdfLoader instead."
)
raise ImportError from e
@dataclass
class _PageBuffer:
page_num: Optional[int]
segments: List[str]
class AdvancedPdfLoader(LoaderInterface):
"""
PDF loader using unstructured library.
Extracts text content, images, tables from PDF files page by page, providing
structured page information and handling PDF-specific errors.
"""
@property
def supported_extensions(self) -> List[str]:
return ["pdf"]
@property
def supported_mime_types(self) -> List[str]:
return ["application/pdf"]
@property
def loader_name(self) -> str:
return "advanced_pdf_loader"
def can_handle(self, extension: str, mime_type: str) -> bool:
"""Check if file can be handled by this loader."""
# Check file extension
if extension in self.supported_extensions and mime_type in self.supported_mime_types:
return True
return False
async def load(self, file_path: str, strategy: str = "auto", **kwargs: Any) -> str:
"""Load PDF file using unstructured library. If Exception occurs, fallback to PyPDFLoader.
Args:
file_path: Path to the document file
strategy: Partitioning strategy ("auto", "fast", "hi_res", "ocr_only")
**kwargs: Additional arguments passed to unstructured partition
Returns:
LoaderResult with extracted text content and metadata
"""
try:
logger.info(f"Processing PDF: {file_path}")
with open(file_path, "rb") as f:
file_metadata = await get_file_metadata(f)
# Name ingested file of current loader based on original file content hash
storage_file_name = "text_" + file_metadata["content_hash"] + ".txt"
# Set partitioning parameters
partition_kwargs: Dict[str, Any] = {
"filename": file_path,
"strategy": strategy,
"infer_table_structure": True,
"include_page_breaks": False,
"include_metadata": True,
**kwargs,
}
# Use partition to extract elements
elements = partition_pdf(**partition_kwargs)
# Process elements into text content
page_contents = self._format_elements_by_page(elements)
# Check if there is any content
if not page_contents:
logger.warning(
"AdvancedPdfLoader returned no content. Falling back to PyPDF loader."
)
return await self._fallback(file_path, **kwargs)
# Combine all page outputs
full_content = "\n".join(page_contents)
# Store the content
storage_config = get_storage_config()
data_root_directory = storage_config["data_root_directory"]
storage = get_file_storage(data_root_directory)
full_file_path = await storage.store(storage_file_name, full_content)
return full_file_path
except Exception as exc:
logger.warning("Failed to process PDF with AdvancedPdfLoader: %s", exc)
return await self._fallback(file_path, **kwargs)
async def _fallback(self, file_path: str, **kwargs: Any) -> str:
logger.info("Falling back to PyPDF loader for %s", file_path)
fallback_loader = PyPdfLoader()
return await fallback_loader.load(file_path, **kwargs)
def _format_elements_by_page(self, elements: List[Any]) -> List[str]:
"""Format elements by page."""
page_buffers: List[_PageBuffer] = []
current_buffer = _PageBuffer(page_num=None, segments=[])
for element in elements:
element_dict = self._safe_to_dict(element)
metadata = element_dict.get("metadata", {})
page_num = metadata.get("page_number")
if current_buffer.page_num != page_num:
if current_buffer.segments:
page_buffers.append(current_buffer)
current_buffer = _PageBuffer(page_num=page_num, segments=[])
formatted = self._format_element(element_dict)
if formatted:
current_buffer.segments.append(formatted)
if current_buffer.segments:
page_buffers.append(current_buffer)
page_contents: List[str] = []
for buffer in page_buffers:
header = f"Page {buffer.page_num}:\n" if buffer.page_num is not None else "Page:"
content = header + "\n\n".join(buffer.segments) + "\n"
page_contents.append(str(content))
return page_contents
def _format_element(
self,
element: Dict[str, Any],
) -> str:
"""Format element."""
element_type = element.get("type")
text = self._clean_text(element.get("text", ""))
metadata = element.get("metadata", {})
if element_type.lower() == "table":
return self._format_table_element(element) or text
if element_type.lower() == "image":
description = text or self._format_image_element(metadata)
return description
# Ignore header and footer
if element_type.lower() in ["header", "footer"]:
pass
return text
def _format_table_element(self, element: Dict[str, Any]) -> str:
"""Format table element."""
metadata = element.get("metadata", {})
text = self._clean_text(element.get("text", ""))
table_html = metadata.get("text_as_html")
if table_html:
return table_html.strip()
return text
def _format_image_element(self, metadata: Dict[str, Any]) -> str:
"""Format image."""
placeholder = "[Image omitted]"
image_text = placeholder
coordinates = metadata.get("coordinates", {})
points = coordinates.get("points") if isinstance(coordinates, dict) else None
if points and isinstance(points, tuple) and len(points) == 4:
leftup = points[0]
rightdown = points[3]
if (
isinstance(leftup, tuple)
and isinstance(rightdown, tuple)
and len(leftup) == 2
and len(rightdown) == 2
):
image_text = f"{placeholder} (bbox=({leftup[0]}, {leftup[1]}, {rightdown[0]}, {rightdown[1]}))"
layout_width = coordinates.get("layout_width")
layout_height = coordinates.get("layout_height")
system = coordinates.get("system")
if layout_width and layout_height and system:
image_text = (
image_text
+ f", system={system}, layout_width={layout_width}, layout_height={layout_height}))"
)
return image_text
def _safe_to_dict(self, element: Any) -> Dict[str, Any]:
"""Safe to dict."""
try:
if hasattr(element, "to_dict"):
return element.to_dict()
except Exception:
pass
fallback_type = getattr(element, "category", None)
if not fallback_type:
fallback_type = getattr(element, "__class__", type("", (), {})).__name__
return {
"type": fallback_type,
"text": getattr(element, "text", ""),
"metadata": getattr(element, "metadata", {}),
}
def _clean_text(self, value: Any) -> str:
if value is None:
return ""
return str(value).replace("\xa0", " ").strip()
if __name__ == "__main__":
loader = AdvancedPdfLoader()
asyncio.run(
loader.load(
"/Users/xiaotao/work/cognee/cognee/infrastructure/loaders/external/attention_is_all_you_need.pdf"
)
)

View file

@ -16,3 +16,10 @@ try:
supported_loaders[UnstructuredLoader.loader_name] = UnstructuredLoader supported_loaders[UnstructuredLoader.loader_name] = UnstructuredLoader
except ImportError: except ImportError:
pass pass
try:
from cognee.infrastructure.loaders.external import AdvancedPdfLoader
supported_loaders[AdvancedPdfLoader.loader_name] = AdvancedPdfLoader
except ImportError:
pass

View file

@ -0,0 +1,92 @@
from uuid import UUID
from cognee.cli.exceptions import CliCommandException
from cognee.infrastructure.databases.exceptions.exceptions import EntityNotFoundError
from sqlalchemy import select
from sqlalchemy.sql import func
from cognee.infrastructure.databases.relational import get_relational_engine
from cognee.modules.data.models import Dataset, Data, DatasetData
from cognee.modules.users.models import User
from cognee.modules.users.methods import get_user
from dataclasses import dataclass
@dataclass
class DeletionCountsPreview:
datasets: int = 0
data_entries: int = 0
users: int = 0
async def get_deletion_counts(
dataset_name: str = None, user_id: str = None, all_data: bool = False
) -> DeletionCountsPreview:
"""
Calculates the number of items that will be deleted based on the provided arguments.
"""
counts = DeletionCountsPreview()
relational_engine = get_relational_engine()
async with relational_engine.get_async_session() as session:
if dataset_name:
# Find the dataset by name
dataset_result = await session.execute(
select(Dataset).where(Dataset.name == dataset_name)
)
dataset = dataset_result.scalar_one_or_none()
if dataset is None:
raise CliCommandException(
f"No Dataset exists with the name {dataset_name}", error_code=1
)
# Count data entries linked to this dataset
count_query = (
select(func.count())
.select_from(DatasetData)
.where(DatasetData.dataset_id == dataset.id)
)
data_entry_count = (await session.execute(count_query)).scalar_one()
counts.users = 1
counts.datasets = 1
counts.entries = data_entry_count
return counts
elif all_data:
# Simplified logic: Get total counts directly from the tables.
counts.datasets = (
await session.execute(select(func.count()).select_from(Dataset))
).scalar_one()
counts.entries = (
await session.execute(select(func.count()).select_from(Data))
).scalar_one()
counts.users = (
await session.execute(select(func.count()).select_from(User))
).scalar_one()
return counts
# Placeholder for user_id logic
elif user_id:
user = None
try:
user_uuid = UUID(user_id)
user = await get_user(user_uuid)
except (ValueError, EntityNotFoundError):
raise CliCommandException(f"No User exists with ID {user_id}", error_code=1)
counts.users = 1
# Find all datasets owned by this user
datasets_query = select(Dataset).where(Dataset.owner_id == user.id)
user_datasets = (await session.execute(datasets_query)).scalars().all()
dataset_count = len(user_datasets)
counts.datasets = dataset_count
if dataset_count > 0:
dataset_ids = [d.id for d in user_datasets]
# Count all data entries across all of the user's datasets
data_count_query = (
select(func.count())
.select_from(DatasetData)
.where(DatasetData.dataset_id.in_(dataset_ids))
)
data_entry_count = (await session.execute(data_count_query)).scalar_one()
counts.entries = data_entry_count
else:
counts.entries = 0
return counts

View file

@ -12,7 +12,8 @@ from cognee.cli.commands.search_command import SearchCommand
from cognee.cli.commands.cognify_command import CognifyCommand from cognee.cli.commands.cognify_command import CognifyCommand
from cognee.cli.commands.delete_command import DeleteCommand from cognee.cli.commands.delete_command import DeleteCommand
from cognee.cli.commands.config_command import ConfigCommand from cognee.cli.commands.config_command import ConfigCommand
from cognee.cli.exceptions import CliCommandException, CliCommandInnerException from cognee.cli.exceptions import CliCommandException
from cognee.modules.data.methods.get_deletion_counts import DeletionCountsPreview
# Mock asyncio.run to properly handle coroutines # Mock asyncio.run to properly handle coroutines
@ -282,13 +283,18 @@ class TestDeleteCommand:
assert "all" in actions assert "all" in actions
assert "force" in actions assert "force" in actions
@patch("cognee.cli.commands.delete_command.get_deletion_counts")
@patch("cognee.cli.commands.delete_command.fmt.confirm") @patch("cognee.cli.commands.delete_command.fmt.confirm")
@patch("cognee.cli.commands.delete_command.asyncio.run", side_effect=_mock_run) @patch("cognee.cli.commands.delete_command.asyncio.run", side_effect=_mock_run)
def test_execute_delete_dataset_with_confirmation(self, mock_asyncio_run, mock_confirm): def test_execute_delete_dataset_with_confirmation(
self, mock_asyncio_run, mock_confirm, mock_get_deletion_counts
):
"""Test execute delete dataset with user confirmation""" """Test execute delete dataset with user confirmation"""
# Mock the cognee module # Mock the cognee module
mock_cognee = MagicMock() mock_cognee = MagicMock()
mock_cognee.delete = AsyncMock() mock_cognee.delete = AsyncMock()
mock_get_deletion_counts = AsyncMock()
mock_get_deletion_counts.return_value = DeletionCountsPreview()
with patch.dict(sys.modules, {"cognee": mock_cognee}): with patch.dict(sys.modules, {"cognee": mock_cognee}):
command = DeleteCommand() command = DeleteCommand()
@ -301,13 +307,16 @@ class TestDeleteCommand:
command.execute(args) command.execute(args)
mock_confirm.assert_called_once_with(f"Delete dataset '{args.dataset_name}'?") mock_confirm.assert_called_once_with(f"Delete dataset '{args.dataset_name}'?")
mock_asyncio_run.assert_called_once() assert mock_asyncio_run.call_count == 2
assert asyncio.iscoroutine(mock_asyncio_run.call_args[0][0]) assert asyncio.iscoroutine(mock_asyncio_run.call_args[0][0])
mock_cognee.delete.assert_awaited_once_with(dataset_name="test_dataset", user_id=None) mock_cognee.delete.assert_awaited_once_with(dataset_name="test_dataset", user_id=None)
@patch("cognee.cli.commands.delete_command.get_deletion_counts")
@patch("cognee.cli.commands.delete_command.fmt.confirm") @patch("cognee.cli.commands.delete_command.fmt.confirm")
def test_execute_delete_cancelled(self, mock_confirm): def test_execute_delete_cancelled(self, mock_confirm, mock_get_deletion_counts):
"""Test execute when user cancels deletion""" """Test execute when user cancels deletion"""
mock_get_deletion_counts = AsyncMock()
mock_get_deletion_counts.return_value = DeletionCountsPreview()
command = DeleteCommand() command = DeleteCommand()
args = argparse.Namespace(dataset_name="test_dataset", user_id=None, all=False, force=False) args = argparse.Namespace(dataset_name="test_dataset", user_id=None, all=False, force=False)

View file

@ -13,6 +13,7 @@ from cognee.cli.commands.cognify_command import CognifyCommand
from cognee.cli.commands.delete_command import DeleteCommand
from cognee.cli.commands.config_command import ConfigCommand
from cognee.cli.exceptions import CliCommandException, CliCommandInnerException
from cognee.modules.data.methods.get_deletion_counts import DeletionCountsPreview
# Mock asyncio.run to properly handle coroutines # Mock asyncio.run to properly handle coroutines
@ -378,13 +379,18 @@ class TestCognifyCommandEdgeCases:
class TestDeleteCommandEdgeCases: class TestDeleteCommandEdgeCases:
"""Test edge cases for DeleteCommand""" """Test edge cases for DeleteCommand"""
@patch("cognee.cli.commands.delete_command.get_deletion_counts")
@patch("cognee.cli.commands.delete_command.fmt.confirm") @patch("cognee.cli.commands.delete_command.fmt.confirm")
@patch("cognee.cli.commands.delete_command.asyncio.run", side_effect=_mock_run) @patch("cognee.cli.commands.delete_command.asyncio.run", side_effect=_mock_run)
def test_delete_all_with_user_id(self, mock_asyncio_run, mock_confirm): def test_delete_all_with_user_id(
self, mock_asyncio_run, mock_confirm, mock_get_deletion_counts
):
"""Test delete command with both --all and --user-id""" """Test delete command with both --all and --user-id"""
# Mock the cognee module # Mock the cognee module
mock_cognee = MagicMock() mock_cognee = MagicMock()
mock_cognee.delete = AsyncMock() mock_cognee.delete = AsyncMock()
mock_get_deletion_counts = AsyncMock()
mock_get_deletion_counts.return_value = DeletionCountsPreview()
with patch.dict(sys.modules, {"cognee": mock_cognee}): with patch.dict(sys.modules, {"cognee": mock_cognee}):
command = DeleteCommand() command = DeleteCommand()
@ -396,13 +402,17 @@ class TestDeleteCommandEdgeCases:
command.execute(args) command.execute(args)
mock_confirm.assert_called_once_with("Delete ALL data from cognee?") mock_confirm.assert_called_once_with("Delete ALL data from cognee?")
mock_asyncio_run.assert_called_once() assert mock_asyncio_run.call_count == 2
assert asyncio.iscoroutine(mock_asyncio_run.call_args[0][0]) assert asyncio.iscoroutine(mock_asyncio_run.call_args[0][0])
mock_cognee.delete.assert_awaited_once_with(dataset_name=None, user_id="test_user") mock_cognee.delete.assert_awaited_once_with(dataset_name=None, user_id="test_user")
@patch("cognee.cli.commands.delete_command.get_deletion_counts")
@patch("cognee.cli.commands.delete_command.fmt.confirm") @patch("cognee.cli.commands.delete_command.fmt.confirm")
def test_delete_confirmation_keyboard_interrupt(self, mock_confirm): def test_delete_confirmation_keyboard_interrupt(self, mock_confirm, mock_get_deletion_counts):
"""Test delete command when user interrupts confirmation""" """Test delete command when user interrupts confirmation"""
mock_get_deletion_counts = AsyncMock()
mock_get_deletion_counts.return_value = DeletionCountsPreview()
command = DeleteCommand() command = DeleteCommand()
args = argparse.Namespace(dataset_name="test_dataset", user_id=None, all=False, force=False) args = argparse.Namespace(dataset_name="test_dataset", user_id=None, all=False, force=False)

View file

@ -0,0 +1,141 @@
import sys
from unittest.mock import patch, MagicMock, AsyncMock, mock_open
import pytest
from cognee.infrastructure.loaders.external.advanced_pdf_loader import AdvancedPdfLoader
# Grab the already-imported loader module object, if present.
# NOTE(review): this binding appears unused in the visible tests — presumably
# kept as a patch/inspection handle; confirm before removing.
advanced_pdf_loader_module = sys.modules.get(
    "cognee.infrastructure.loaders.external.advanced_pdf_loader"
)
class MockElement:
    """Lightweight stand-in for an `unstructured` document element.

    Carries the three attributes the loader reads (category, text, metadata)
    and mirrors the library's `to_dict()` output shape.
    """
    def __init__(self, category, text, metadata):
        # Store the raw pieces exactly as given; tests assert on them directly.
        self.category = category
        self.text = text
        self.metadata = metadata
    def to_dict(self):
        """Return the dict representation keyed like unstructured's elements."""
        return dict(type=self.category, text=self.text, metadata=self.metadata)
@pytest.fixture
def loader():
    """Provide a fresh AdvancedPdfLoader instance per test (no shared state)."""
    return AdvancedPdfLoader()
# Matrix of (extension, mime_type) pairs: only the pdf + application/pdf
# combination is accepted; a mismatch on either axis must be rejected.
@pytest.mark.parametrize(
    "extension, mime_type, expected",
    [
        ("pdf", "application/pdf", True),
        ("txt", "text/plain", False),
        ("pdf", "text/plain", False),
        ("doc", "application/pdf", False),
    ],
)
def test_can_handle(loader, extension, mime_type, expected):
    """Test can_handle method can correctly identify PDF files"""
    # Both the extension and the MIME type must match for the loader to claim the file.
    assert loader.can_handle(extension, mime_type) == expected
# NOTE: @patch decorators apply bottom-up, so the injected mock parameters
# below are in reverse decorator order (partition_pdf first, open last).
@pytest.mark.asyncio
@patch("cognee.infrastructure.loaders.external.advanced_pdf_loader.open", new_callable=mock_open)
@patch(
    "cognee.infrastructure.loaders.external.advanced_pdf_loader.get_file_metadata",
    new_callable=AsyncMock,
)
@patch("cognee.infrastructure.loaders.external.advanced_pdf_loader.get_storage_config")
@patch("cognee.infrastructure.loaders.external.advanced_pdf_loader.get_file_storage")
@patch("cognee.infrastructure.loaders.external.advanced_pdf_loader.PyPdfLoader")
@patch("cognee.infrastructure.loaders.external.advanced_pdf_loader.partition_pdf")
async def test_load_success_with_unstructured(
    mock_partition_pdf,
    mock_pypdf_loader,
    mock_get_file_storage,
    mock_get_storage_config,
    mock_get_file_metadata,
    mock_open,
    loader,
):
    """Test the main flow of using unstructured to successfully process PDF"""
    # Prepare Mock data and objects
    # Three element kinds: Title and NarrativeText on page 1, and a Table on
    # page 2 whose text_as_html is expected to replace its plain text.
    mock_elements = [
        MockElement(
            category="Title", text="Attention Is All You Need", metadata={"page_number": 1}
        ),
        MockElement(
            category="NarrativeText",
            text="The dominant sequence transduction models are based on complex recurrent or convolutional neural networks.",
            metadata={"page_number": 1},
        ),
        MockElement(
            category="Table",
            text="This is a table.",
            metadata={"page_number": 2, "text_as_html": "<table><tr><td>Data</td></tr></table>"},
        ),
    ]
    # The fallback loader is wired up but must never be invoked on this path.
    mock_pypdf_loader.return_value.load = AsyncMock(return_value="/fake/path/fallback.txt")
    mock_partition_pdf.return_value = mock_elements
    # content_hash drives the stored file name ("text_<hash>.txt" below).
    mock_get_file_metadata.return_value = {"content_hash": "abc123def456"}
    mock_storage_instance = MagicMock()
    mock_storage_instance.store = AsyncMock(return_value="/stored/text_abc123def456.txt")
    mock_get_file_storage.return_value = mock_storage_instance
    mock_get_storage_config.return_value = {"data_root_directory": "/fake/data/root"}
    test_file_path = "/fake/path/document.pdf"
    # Run
    result_path = await loader.load(test_file_path)
    # Assert
    # load() should return the path reported by the storage backend.
    assert result_path == "/stored/text_abc123def456.txt"
    # Verify partition_pdf is called with the correct parameters
    mock_partition_pdf.assert_called_once()
    call_args, call_kwargs = mock_partition_pdf.call_args
    assert call_kwargs.get("filename") == test_file_path
    assert call_kwargs.get("strategy") == "auto"  # Default strategy
    # Verify the stored content is correct
    # Elements are grouped under "Page N:" headers; the Table contributes its
    # text_as_html rendering instead of its raw text.
    expected_content = "Page 1:\nAttention Is All You Need\n\nThe dominant sequence transduction models are based on complex recurrent or convolutional neural networks.\n\nPage 2:\n<table><tr><td>Data</td></tr></table>\n"
    mock_storage_instance.store.assert_awaited_once_with("text_abc123def456.txt", expected_content)
    # Verify fallback is not called
    mock_pypdf_loader.assert_not_called()
# NOTE: injected mock parameters are in reverse decorator order
# (partition_pdf first); partition_pdf is patched to raise immediately.
@pytest.mark.asyncio
@patch("cognee.infrastructure.loaders.external.advanced_pdf_loader.open", new_callable=mock_open)
@patch(
    "cognee.infrastructure.loaders.external.advanced_pdf_loader.get_file_metadata",
    new_callable=AsyncMock,
)
@patch("cognee.infrastructure.loaders.external.advanced_pdf_loader.PyPdfLoader")
@patch(
    "cognee.infrastructure.loaders.external.advanced_pdf_loader.partition_pdf",
    side_effect=Exception("Unstructured failed!"),
)
async def test_load_fallback_on_unstructured_exception(
    mock_partition_pdf, mock_pypdf_loader, mock_get_file_metadata, mock_open, loader
):
    """Test fallback to PyPdfLoader when unstructured throws an exception"""
    # Prepare Mock
    # The fallback PyPdfLoader returns a known path so we can verify the
    # loader hands off to it after partition_pdf raises.
    mock_fallback_instance = MagicMock()
    mock_fallback_instance.load = AsyncMock(return_value="/fake/path/fallback.txt")
    mock_pypdf_loader.return_value = mock_fallback_instance
    mock_get_file_metadata.return_value = {"content_hash": "anyhash"}
    test_file_path = "/fake/path/document.pdf"
    # Run
    result_path = await loader.load(test_file_path)
    # Assert
    # The result must come from the fallback loader, not unstructured.
    assert result_path == "/fake/path/fallback.txt"
    mock_partition_pdf.assert_called_once()  # Verify partition_pdf is called
    mock_fallback_instance.load.assert_awaited_once_with(test_file_path)

View file

@ -41,7 +41,12 @@ class TestCogneeServerStart(unittest.TestCase):
def tearDownClass(cls): def tearDownClass(cls):
# Terminate the server process # Terminate the server process
if hasattr(cls, "server_process") and cls.server_process: if hasattr(cls, "server_process") and cls.server_process:
os.killpg(os.getpgid(cls.server_process.pid), signal.SIGTERM) if hasattr(os, "killpg"):
# Unix-like systems: Use process groups
os.killpg(os.getpgid(cls.server_process.pid), signal.SIGTERM)
else:
# Windows: Just terminate the main process
cls.server_process.terminate()
cls.server_process.wait() cls.server_process.wait()
def test_server_is_running(self): def test_server_is_running(self):

1634
poetry.lock generated

File diff suppressed because it is too large Load diff

View file

@ -95,7 +95,7 @@ chromadb = [
    "chromadb>=0.6,<0.7",
    "pypika==0.48.9",
]
docs = ["unstructured[csv, doc, docx, epub, md, odt, org, ppt, pptx, rst, rtf, tsv, xlsx]>=0.18.1,<19"] docs = ["unstructured[csv, doc, docx, epub, md, odt, org, ppt, pptx, rst, rtf, tsv, xlsx, pdf]>=0.18.1,<19"]
codegraph = [ codegraph = [
"fastembed<=0.6.0 ; python_version < '3.13'", "fastembed<=0.6.0 ; python_version < '3.13'",
"transformers>=4.46.3,<5", "transformers>=4.46.3,<5",
@ -142,6 +142,7 @@ Homepage = "https://www.cognee.ai"
Repository = "https://github.com/topoteretes/cognee" Repository = "https://github.com/topoteretes/cognee"
[project.scripts] [project.scripts]
cognee = "cognee.cli._cognee:main"
cognee-cli = "cognee.cli._cognee:main" cognee-cli = "cognee.cli._cognee:main"
[build-system] [build-system]

1320
uv.lock generated

File diff suppressed because it is too large Load diff