Merge branch 'dev' into update-endpoint
This commit is contained in:
commit
9ae3c97aef
22 changed files with 3074 additions and 697 deletions
|
|
@ -43,7 +43,7 @@ jobs:
|
|||
strategy:
|
||||
matrix:
|
||||
python-version: ${{ fromJSON(inputs.python-versions) }}
|
||||
os: [ubuntu-22.04, macos-13, macos-15, windows-latest]
|
||||
os: [ubuntu-22.04, macos-15, windows-latest]
|
||||
fail-fast: false
|
||||
steps:
|
||||
- name: Check out
|
||||
|
|
@ -79,7 +79,7 @@ jobs:
|
|||
strategy:
|
||||
matrix:
|
||||
python-version: ${{ fromJSON(inputs.python-versions) }}
|
||||
os: [ ubuntu-22.04, macos-13, macos-15, windows-latest ]
|
||||
os: [ ubuntu-22.04, macos-15, windows-latest ]
|
||||
fail-fast: false
|
||||
steps:
|
||||
- name: Check out
|
||||
|
|
@ -115,7 +115,7 @@ jobs:
|
|||
strategy:
|
||||
matrix:
|
||||
python-version: ${{ fromJSON(inputs.python-versions) }}
|
||||
os: [ ubuntu-22.04, macos-13, macos-15, windows-latest ]
|
||||
os: [ ubuntu-22.04, macos-15, windows-latest ]
|
||||
fail-fast: false
|
||||
steps:
|
||||
- name: Check out
|
||||
|
|
@ -151,7 +151,7 @@ jobs:
|
|||
strategy:
|
||||
matrix:
|
||||
python-version: ${{ fromJSON(inputs.python-versions) }}
|
||||
os: [ ubuntu-22.04, macos-13, macos-15, windows-latest ]
|
||||
os: [ ubuntu-22.04, macos-15, windows-latest ]
|
||||
fail-fast: false
|
||||
steps:
|
||||
- name: Check out
|
||||
|
|
@ -180,7 +180,7 @@ jobs:
|
|||
strategy:
|
||||
matrix:
|
||||
python-version: ${{ fromJSON(inputs.python-versions) }}
|
||||
os: [ ubuntu-22.04, macos-13, macos-15, windows-latest ]
|
||||
os: [ ubuntu-22.04, macos-15, windows-latest ]
|
||||
fail-fast: false
|
||||
steps:
|
||||
- name: Check out
|
||||
|
|
@ -210,7 +210,7 @@ jobs:
|
|||
strategy:
|
||||
matrix:
|
||||
python-version: ${{ fromJSON(inputs.python-versions) }}
|
||||
os: [ ubuntu-22.04, macos-13, macos-15, windows-latest ]
|
||||
os: [ ubuntu-22.04, macos-15, windows-latest ]
|
||||
fail-fast: false
|
||||
steps:
|
||||
- name: Check out
|
||||
|
|
|
|||
|
|
@ -3,10 +3,18 @@
|
|||
import classNames from "classnames";
|
||||
import { MutableRefObject, useEffect, useImperativeHandle, useRef, useState, useCallback } from "react";
|
||||
import { forceCollide, forceManyBody } from "d3-force-3d";
|
||||
import ForceGraph, { ForceGraphMethods, GraphData, LinkObject, NodeObject } from "react-force-graph-2d";
|
||||
import dynamic from "next/dynamic";
|
||||
import { GraphControlsAPI } from "./GraphControls";
|
||||
import getColorForNodeType from "./getColorForNodeType";
|
||||
|
||||
// Dynamically import ForceGraph to prevent SSR issues
|
||||
const ForceGraph = dynamic(() => import("react-force-graph-2d"), {
|
||||
ssr: false,
|
||||
loading: () => <div className="w-full h-full flex items-center justify-center">Loading graph...</div>
|
||||
});
|
||||
|
||||
import type { ForceGraphMethods, GraphData, LinkObject, NodeObject } from "react-force-graph-2d";
|
||||
|
||||
interface GraphVisuzaliationProps {
|
||||
ref: MutableRefObject<GraphVisualizationAPI>;
|
||||
data?: GraphData<NodeObject, LinkObject>;
|
||||
|
|
@ -200,7 +208,7 @@ export default function GraphVisualization({ ref, data, graphControls, className
|
|||
const graphRef = useRef<ForceGraphMethods>();
|
||||
|
||||
useEffect(() => {
|
||||
if (typeof window !== "undefined" && data && graphRef.current) {
|
||||
if (data && graphRef.current) {
|
||||
// add collision force
|
||||
graphRef.current.d3Force("collision", forceCollide(nodeSize * 1.5));
|
||||
graphRef.current.d3Force("charge", forceManyBody().strength(-10).distanceMin(10).distanceMax(50));
|
||||
|
|
@ -216,56 +224,34 @@ export default function GraphVisualization({ ref, data, graphControls, className
|
|||
|
||||
return (
|
||||
<div ref={containerRef} className={classNames("w-full h-full", className)} id="graph-container">
|
||||
{(data && typeof window !== "undefined") ? (
|
||||
<ForceGraph
|
||||
ref={graphRef}
|
||||
width={dimensions.width}
|
||||
height={dimensions.height}
|
||||
dagMode={graphShape as unknown as undefined}
|
||||
dagLevelDistance={300}
|
||||
onDagError={handleDagError}
|
||||
graphData={data}
|
||||
<ForceGraph
|
||||
ref={graphRef}
|
||||
width={dimensions.width}
|
||||
height={dimensions.height}
|
||||
dagMode={graphShape as unknown as undefined}
|
||||
dagLevelDistance={data ? 300 : 100}
|
||||
onDagError={handleDagError}
|
||||
graphData={data || {
|
||||
nodes: [{ id: 1, label: "Add" }, { id: 2, label: "Cognify" }, { id: 3, label: "Search" }],
|
||||
links: [{ source: 1, target: 2, label: "but don't forget to" }, { source: 2, target: 3, label: "and after that you can" }],
|
||||
}}
|
||||
|
||||
nodeLabel="label"
|
||||
nodeRelSize={nodeSize}
|
||||
nodeCanvasObject={renderNode}
|
||||
nodeCanvasObjectMode={() => "replace"}
|
||||
nodeLabel="label"
|
||||
nodeRelSize={data ? nodeSize : 20}
|
||||
nodeCanvasObject={data ? renderNode : renderInitialNode}
|
||||
nodeCanvasObjectMode={() => data ? "replace" : "after"}
|
||||
nodeAutoColorBy={data ? undefined : "type"}
|
||||
|
||||
linkLabel="label"
|
||||
linkCanvasObject={renderLink}
|
||||
linkCanvasObjectMode={() => "after"}
|
||||
linkDirectionalArrowLength={3.5}
|
||||
linkDirectionalArrowRelPos={1}
|
||||
linkLabel="label"
|
||||
linkCanvasObject={renderLink}
|
||||
linkCanvasObjectMode={() => "after"}
|
||||
linkDirectionalArrowLength={3.5}
|
||||
linkDirectionalArrowRelPos={1}
|
||||
|
||||
onNodeClick={handleNodeClick}
|
||||
onBackgroundClick={handleBackgroundClick}
|
||||
d3VelocityDecay={0.3}
|
||||
/>
|
||||
) : (
|
||||
<ForceGraph
|
||||
ref={graphRef}
|
||||
width={dimensions.width}
|
||||
height={dimensions.height}
|
||||
dagMode={graphShape as unknown as undefined}
|
||||
dagLevelDistance={100}
|
||||
graphData={{
|
||||
nodes: [{ id: 1, label: "Add" }, { id: 2, label: "Cognify" }, { id: 3, label: "Search" }],
|
||||
links: [{ source: 1, target: 2, label: "but don't forget to" }, { source: 2, target: 3, label: "and after that you can" }],
|
||||
}}
|
||||
|
||||
nodeLabel="label"
|
||||
nodeRelSize={20}
|
||||
nodeCanvasObject={renderInitialNode}
|
||||
nodeCanvasObjectMode={() => "after"}
|
||||
nodeAutoColorBy="type"
|
||||
|
||||
linkLabel="label"
|
||||
linkCanvasObject={renderLink}
|
||||
linkCanvasObjectMode={() => "after"}
|
||||
linkDirectionalArrowLength={3.5}
|
||||
linkDirectionalArrowRelPos={1}
|
||||
/>
|
||||
)}
|
||||
onNodeClick={handleNodeClick}
|
||||
onBackgroundClick={handleBackgroundClick}
|
||||
d3VelocityDecay={data ? 0.3 : undefined}
|
||||
/>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -51,6 +51,13 @@ export default async function fetch(url: string, options: RequestInit = {}, useC
|
|||
)
|
||||
.then((response) => handleServerErrors(response, retry, useCloud))
|
||||
.catch((error) => {
|
||||
// Handle network errors more gracefully
|
||||
if (error.name === 'TypeError' && error.message.includes('fetch')) {
|
||||
return Promise.reject(
|
||||
new Error("Backend server is not responding. Please check if the server is running.")
|
||||
);
|
||||
}
|
||||
|
||||
if (error.detail === undefined) {
|
||||
return Promise.reject(
|
||||
new Error("No connection to the server.")
|
||||
|
|
@ -64,8 +71,27 @@ export default async function fetch(url: string, options: RequestInit = {}, useC
|
|||
});
|
||||
}
|
||||
|
||||
fetch.checkHealth = () => {
|
||||
return global.fetch(`${backendApiUrl.replace("/api", "")}/health`);
|
||||
fetch.checkHealth = async () => {
|
||||
const maxRetries = 5;
|
||||
const retryDelay = 1000; // 1 second
|
||||
|
||||
for (let i = 0; i < maxRetries; i++) {
|
||||
try {
|
||||
const response = await global.fetch(`${backendApiUrl.replace("/api", "")}/health`);
|
||||
if (response.ok) {
|
||||
return response;
|
||||
}
|
||||
} catch (error) {
|
||||
// If this is the last retry, throw the error
|
||||
if (i === maxRetries - 1) {
|
||||
throw error;
|
||||
}
|
||||
// Wait before retrying
|
||||
await new Promise(resolve => setTimeout(resolve, retryDelay));
|
||||
}
|
||||
}
|
||||
|
||||
throw new Error("Backend server is not responding after multiple attempts");
|
||||
};
|
||||
|
||||
fetch.checkMCPHealth = () => {
|
||||
|
|
|
|||
|
|
@ -194,7 +194,7 @@ class HealthChecker:
|
|||
config = get_llm_config()
|
||||
|
||||
# Test actual API connection with minimal request
|
||||
LLMGateway.show_prompt("test", "test")
|
||||
LLMGateway.show_prompt("test", "test.txt")
|
||||
|
||||
response_time = int((time.time() - start_time) * 1000)
|
||||
return ComponentHealth(
|
||||
|
|
|
|||
|
|
@ -1,4 +1,6 @@
|
|||
import os
|
||||
import platform
|
||||
import signal
|
||||
import socket
|
||||
import subprocess
|
||||
import threading
|
||||
|
|
@ -288,6 +290,7 @@ def check_node_npm() -> tuple[bool, str]:
|
|||
Check if Node.js and npm are available.
|
||||
Returns (is_available, error_message)
|
||||
"""
|
||||
|
||||
try:
|
||||
# Check Node.js
|
||||
result = subprocess.run(["node", "--version"], capture_output=True, text=True, timeout=10)
|
||||
|
|
@ -297,8 +300,17 @@ def check_node_npm() -> tuple[bool, str]:
|
|||
node_version = result.stdout.strip()
|
||||
logger.debug(f"Found Node.js version: {node_version}")
|
||||
|
||||
# Check npm
|
||||
result = subprocess.run(["npm", "--version"], capture_output=True, text=True, timeout=10)
|
||||
# Check npm - handle Windows PowerShell scripts
|
||||
if platform.system() == "Windows":
|
||||
# On Windows, npm might be a PowerShell script, so we need to use shell=True
|
||||
result = subprocess.run(
|
||||
["npm", "--version"], capture_output=True, text=True, timeout=10, shell=True
|
||||
)
|
||||
else:
|
||||
result = subprocess.run(
|
||||
["npm", "--version"], capture_output=True, text=True, timeout=10
|
||||
)
|
||||
|
||||
if result.returncode != 0:
|
||||
return False, "npm is not installed or not in PATH"
|
||||
|
||||
|
|
@ -320,6 +332,7 @@ def install_frontend_dependencies(frontend_path: Path) -> bool:
|
|||
Install frontend dependencies if node_modules doesn't exist.
|
||||
This is needed for both development and downloaded frontends since both use npm run dev.
|
||||
"""
|
||||
|
||||
node_modules = frontend_path / "node_modules"
|
||||
if node_modules.exists():
|
||||
logger.debug("Frontend dependencies already installed")
|
||||
|
|
@ -328,13 +341,24 @@ def install_frontend_dependencies(frontend_path: Path) -> bool:
|
|||
logger.info("Installing frontend dependencies (this may take a few minutes)...")
|
||||
|
||||
try:
|
||||
result = subprocess.run(
|
||||
["npm", "install"],
|
||||
cwd=frontend_path,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=300, # 5 minutes timeout
|
||||
)
|
||||
# Use shell=True on Windows for npm commands
|
||||
if platform.system() == "Windows":
|
||||
result = subprocess.run(
|
||||
["npm", "install"],
|
||||
cwd=frontend_path,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=300, # 5 minutes timeout
|
||||
shell=True,
|
||||
)
|
||||
else:
|
||||
result = subprocess.run(
|
||||
["npm", "install"],
|
||||
cwd=frontend_path,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=300, # 5 minutes timeout
|
||||
)
|
||||
|
||||
if result.returncode == 0:
|
||||
logger.info("Frontend dependencies installed successfully")
|
||||
|
|
@ -595,14 +619,27 @@ def start_ui(
|
|||
|
||||
try:
|
||||
# Create frontend in its own process group for clean termination
|
||||
process = subprocess.Popen(
|
||||
["npm", "run", "dev"],
|
||||
cwd=frontend_path,
|
||||
env=env,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE,
|
||||
preexec_fn=os.setsid if hasattr(os, "setsid") else None,
|
||||
)
|
||||
# Use shell=True on Windows for npm commands
|
||||
if platform.system() == "Windows":
|
||||
process = subprocess.Popen(
|
||||
["npm", "run", "dev"],
|
||||
cwd=frontend_path,
|
||||
env=env,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE,
|
||||
text=True,
|
||||
shell=True,
|
||||
)
|
||||
else:
|
||||
process = subprocess.Popen(
|
||||
["npm", "run", "dev"],
|
||||
cwd=frontend_path,
|
||||
env=env,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE,
|
||||
text=True,
|
||||
preexec_fn=os.setsid if hasattr(os, "setsid") else None,
|
||||
)
|
||||
|
||||
# Start threads to stream frontend output with prefix
|
||||
_stream_process_output(process, "stdout", "[FRONTEND]", "\033[33m") # Yellow
|
||||
|
|
|
|||
|
|
@ -183,10 +183,20 @@ def main() -> int:
|
|||
|
||||
for pid in spawned_pids:
|
||||
try:
|
||||
pgid = os.getpgid(pid)
|
||||
os.killpg(pgid, signal.SIGTERM)
|
||||
fmt.success(f"✓ Process group {pgid} (PID {pid}) terminated.")
|
||||
except (OSError, ProcessLookupError) as e:
|
||||
if hasattr(os, "killpg"):
|
||||
# Unix-like systems: Use process groups
|
||||
pgid = os.getpgid(pid)
|
||||
os.killpg(pgid, signal.SIGTERM)
|
||||
fmt.success(f"✓ Process group {pgid} (PID {pid}) terminated.")
|
||||
else:
|
||||
# Windows: Use taskkill to terminate process and its children
|
||||
subprocess.run(
|
||||
["taskkill", "/F", "/T", "/PID", str(pid)],
|
||||
capture_output=True,
|
||||
check=False,
|
||||
)
|
||||
fmt.success(f"✓ Process {pid} and its children terminated.")
|
||||
except (OSError, ProcessLookupError, subprocess.SubprocessError) as e:
|
||||
fmt.warning(f"Could not terminate process {pid}: {e}")
|
||||
|
||||
sys.exit(0)
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@ from cognee.cli.reference import SupportsCliCommand
|
|||
from cognee.cli import DEFAULT_DOCS_URL
|
||||
import cognee.cli.echo as fmt
|
||||
from cognee.cli.exceptions import CliCommandException, CliCommandInnerException
|
||||
from cognee.modules.data.methods.get_deletion_counts import get_deletion_counts
|
||||
|
||||
|
||||
class DeleteCommand(SupportsCliCommand):
|
||||
|
|
@ -41,7 +42,34 @@ Be careful with deletion operations as they are irreversible.
|
|||
fmt.error("Please specify what to delete: --dataset-name, --user-id, or --all")
|
||||
return
|
||||
|
||||
# Build confirmation message
|
||||
# If --force is used, skip the preview and go straight to deletion
|
||||
if not args.force:
|
||||
# --- START PREVIEW LOGIC ---
|
||||
fmt.echo("Gathering data for preview...")
|
||||
try:
|
||||
preview_data = asyncio.run(
|
||||
get_deletion_counts(
|
||||
dataset_name=args.dataset_name,
|
||||
user_id=args.user_id,
|
||||
all_data=args.all,
|
||||
)
|
||||
)
|
||||
except CliCommandException as e:
|
||||
fmt.error(f"Error occured when fetching preview data: {str(e)}")
|
||||
return
|
||||
|
||||
if not preview_data:
|
||||
fmt.success("No data found to delete.")
|
||||
return
|
||||
|
||||
fmt.echo("You are about to delete:")
|
||||
fmt.echo(
|
||||
f"Datasets: {preview_data.datasets}\nEntries: {preview_data.entries}\nUsers: {preview_data.users}"
|
||||
)
|
||||
fmt.echo("-" * 20)
|
||||
# --- END PREVIEW LOGIC ---
|
||||
|
||||
# Build operation message for success/failure logging
|
||||
if args.all:
|
||||
confirm_msg = "Delete ALL data from cognee?"
|
||||
operation = "all data"
|
||||
|
|
@ -51,8 +79,9 @@ Be careful with deletion operations as they are irreversible.
|
|||
elif args.user_id:
|
||||
confirm_msg = f"Delete all data for user '{args.user_id}'?"
|
||||
operation = f"data for user '{args.user_id}'"
|
||||
else:
|
||||
operation = "data"
|
||||
|
||||
# Confirm deletion unless forced
|
||||
if not args.force:
|
||||
fmt.warning("This operation is irreversible!")
|
||||
if not fmt.confirm(confirm_msg):
|
||||
|
|
@ -64,6 +93,8 @@ Be careful with deletion operations as they are irreversible.
|
|||
# Run the async delete function
|
||||
async def run_delete():
|
||||
try:
|
||||
# NOTE: The underlying cognee.delete() function is currently not working as expected.
|
||||
# This is a separate bug that this preview feature helps to expose.
|
||||
if args.all:
|
||||
await cognee.delete(dataset_name=None, user_id=args.user_id)
|
||||
else:
|
||||
|
|
@ -72,6 +103,7 @@ Be careful with deletion operations as they are irreversible.
|
|||
raise CliCommandInnerException(f"Failed to delete: {str(e)}")
|
||||
|
||||
asyncio.run(run_delete())
|
||||
# This success message may be inaccurate due to the underlying bug, but we leave it for now.
|
||||
fmt.success(f"Successfully deleted {operation}")
|
||||
|
||||
except Exception as e:
|
||||
|
|
|
|||
|
|
@ -26,6 +26,7 @@ def read_query_prompt(prompt_file_name: str, base_directory: str = None):
|
|||
read due to an error.
|
||||
"""
|
||||
logger = get_logger(level=ERROR)
|
||||
|
||||
try:
|
||||
if base_directory is None:
|
||||
base_directory = get_absolute_path("./infrastructure/llm/prompts")
|
||||
|
|
@ -35,8 +36,8 @@ def read_query_prompt(prompt_file_name: str, base_directory: str = None):
|
|||
with open(file_path, "r", encoding="utf-8") as file:
|
||||
return file.read()
|
||||
except FileNotFoundError:
|
||||
logger.error(f"Error: Prompt file not found. Attempted to read: %s {file_path}")
|
||||
logger.error(f"Error: Prompt file not found. Attempted to read: {file_path}")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"An error occurred: %s {e}")
|
||||
logger.error(f"An error occurred: {e}")
|
||||
return None
|
||||
|
|
|
|||
1
cognee/infrastructure/llm/prompts/test.txt
Normal file
1
cognee/infrastructure/llm/prompts/test.txt
Normal file
|
|
@ -0,0 +1 @@
|
|||
Respond with: test
|
||||
|
|
@ -73,10 +73,19 @@ class OpenAIAdapter(LLMInterface):
|
|||
fallback_api_key: str = None,
|
||||
fallback_endpoint: str = None,
|
||||
):
|
||||
self.aclient = instructor.from_litellm(
|
||||
litellm.acompletion, mode=instructor.Mode.JSON_SCHEMA
|
||||
)
|
||||
self.client = instructor.from_litellm(litellm.completion, mode=instructor.Mode.JSON_SCHEMA)
|
||||
# TODO: With gpt5 series models OpenAI expects JSON_SCHEMA as a mode for structured outputs.
|
||||
# Make sure all new gpt models will work with this mode as well.
|
||||
if "gpt-5" in model:
|
||||
self.aclient = instructor.from_litellm(
|
||||
litellm.acompletion, mode=instructor.Mode.JSON_SCHEMA
|
||||
)
|
||||
self.client = instructor.from_litellm(
|
||||
litellm.completion, mode=instructor.Mode.JSON_SCHEMA
|
||||
)
|
||||
else:
|
||||
self.aclient = instructor.from_litellm(litellm.acompletion)
|
||||
self.client = instructor.from_litellm(litellm.completion)
|
||||
|
||||
self.transcription_model = transcription_model
|
||||
self.model = model
|
||||
self.api_key = api_key
|
||||
|
|
|
|||
|
|
@ -27,6 +27,7 @@ class LoaderEngine:
|
|||
|
||||
self.default_loader_priority = [
|
||||
"text_loader",
|
||||
"advanced_pdf_loader",
|
||||
"pypdf_loader",
|
||||
"image_loader",
|
||||
"audio_loader",
|
||||
|
|
@ -86,7 +87,7 @@ class LoaderEngine:
|
|||
if loader.can_handle(extension=file_info.extension, mime_type=file_info.mime):
|
||||
return loader
|
||||
else:
|
||||
raise ValueError(f"Loader does not exist: {loader_name}")
|
||||
logger.info(f"Skipping {loader_name}: Preferred Loader not registered")
|
||||
|
||||
# Try default priority order
|
||||
for loader_name in self.default_loader_priority:
|
||||
|
|
@ -95,7 +96,9 @@ class LoaderEngine:
|
|||
if loader.can_handle(extension=file_info.extension, mime_type=file_info.mime):
|
||||
return loader
|
||||
else:
|
||||
raise ValueError(f"Loader does not exist: {loader_name}")
|
||||
logger.info(
|
||||
f"Skipping {loader_name}: Loader not registered (in default priority list)."
|
||||
)
|
||||
|
||||
return None
|
||||
|
||||
|
|
|
|||
|
|
@ -20,3 +20,10 @@ try:
|
|||
__all__.append("UnstructuredLoader")
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
try:
|
||||
from .advanced_pdf_loader import AdvancedPdfLoader
|
||||
|
||||
__all__.append("AdvancedPdfLoader")
|
||||
except ImportError:
|
||||
pass
|
||||
|
|
|
|||
244
cognee/infrastructure/loaders/external/advanced_pdf_loader.py
vendored
Normal file
244
cognee/infrastructure/loaders/external/advanced_pdf_loader.py
vendored
Normal file
|
|
@ -0,0 +1,244 @@
|
|||
"""Advanced PDF loader leveraging unstructured for layout-aware extraction."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Optional
|
||||
import asyncio
|
||||
from cognee.infrastructure.files.storage import get_file_storage, get_storage_config
|
||||
from cognee.infrastructure.files.utils.get_file_metadata import get_file_metadata
|
||||
from cognee.infrastructure.loaders.LoaderInterface import LoaderInterface
|
||||
from cognee.shared.logging_utils import get_logger
|
||||
|
||||
from cognee.infrastructure.loaders.external.pypdf_loader import PyPdfLoader
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
try:
|
||||
from unstructured.partition.pdf import partition_pdf
|
||||
except ImportError as e:
|
||||
logger.info(
|
||||
"unstructured[pdf] not installed, can't use AdvancedPdfLoader, will use PyPdfLoader instead."
|
||||
)
|
||||
raise ImportError from e
|
||||
|
||||
|
||||
@dataclass
|
||||
class _PageBuffer:
|
||||
page_num: Optional[int]
|
||||
segments: List[str]
|
||||
|
||||
|
||||
class AdvancedPdfLoader(LoaderInterface):
|
||||
"""
|
||||
PDF loader using unstructured library.
|
||||
|
||||
Extracts text content, images, tables from PDF files page by page, providing
|
||||
structured page information and handling PDF-specific errors.
|
||||
"""
|
||||
|
||||
@property
|
||||
def supported_extensions(self) -> List[str]:
|
||||
return ["pdf"]
|
||||
|
||||
@property
|
||||
def supported_mime_types(self) -> List[str]:
|
||||
return ["application/pdf"]
|
||||
|
||||
@property
|
||||
def loader_name(self) -> str:
|
||||
return "advanced_pdf_loader"
|
||||
|
||||
def can_handle(self, extension: str, mime_type: str) -> bool:
|
||||
"""Check if file can be handled by this loader."""
|
||||
# Check file extension
|
||||
if extension in self.supported_extensions and mime_type in self.supported_mime_types:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
async def load(self, file_path: str, strategy: str = "auto", **kwargs: Any) -> str:
|
||||
"""Load PDF file using unstructured library. If Exception occurs, fallback to PyPDFLoader.
|
||||
|
||||
Args:
|
||||
file_path: Path to the document file
|
||||
strategy: Partitioning strategy ("auto", "fast", "hi_res", "ocr_only")
|
||||
**kwargs: Additional arguments passed to unstructured partition
|
||||
|
||||
Returns:
|
||||
LoaderResult with extracted text content and metadata
|
||||
|
||||
"""
|
||||
try:
|
||||
logger.info(f"Processing PDF: {file_path}")
|
||||
|
||||
with open(file_path, "rb") as f:
|
||||
file_metadata = await get_file_metadata(f)
|
||||
|
||||
# Name ingested file of current loader based on original file content hash
|
||||
storage_file_name = "text_" + file_metadata["content_hash"] + ".txt"
|
||||
|
||||
# Set partitioning parameters
|
||||
partition_kwargs: Dict[str, Any] = {
|
||||
"filename": file_path,
|
||||
"strategy": strategy,
|
||||
"infer_table_structure": True,
|
||||
"include_page_breaks": False,
|
||||
"include_metadata": True,
|
||||
**kwargs,
|
||||
}
|
||||
# Use partition to extract elements
|
||||
elements = partition_pdf(**partition_kwargs)
|
||||
|
||||
# Process elements into text content
|
||||
page_contents = self._format_elements_by_page(elements)
|
||||
|
||||
# Check if there is any content
|
||||
if not page_contents:
|
||||
logger.warning(
|
||||
"AdvancedPdfLoader returned no content. Falling back to PyPDF loader."
|
||||
)
|
||||
return await self._fallback(file_path, **kwargs)
|
||||
|
||||
# Combine all page outputs
|
||||
full_content = "\n".join(page_contents)
|
||||
|
||||
# Store the content
|
||||
storage_config = get_storage_config()
|
||||
data_root_directory = storage_config["data_root_directory"]
|
||||
storage = get_file_storage(data_root_directory)
|
||||
|
||||
full_file_path = await storage.store(storage_file_name, full_content)
|
||||
|
||||
return full_file_path
|
||||
|
||||
except Exception as exc:
|
||||
logger.warning("Failed to process PDF with AdvancedPdfLoader: %s", exc)
|
||||
return await self._fallback(file_path, **kwargs)
|
||||
|
||||
async def _fallback(self, file_path: str, **kwargs: Any) -> str:
|
||||
logger.info("Falling back to PyPDF loader for %s", file_path)
|
||||
fallback_loader = PyPdfLoader()
|
||||
return await fallback_loader.load(file_path, **kwargs)
|
||||
|
||||
def _format_elements_by_page(self, elements: List[Any]) -> List[str]:
|
||||
"""Format elements by page."""
|
||||
page_buffers: List[_PageBuffer] = []
|
||||
current_buffer = _PageBuffer(page_num=None, segments=[])
|
||||
|
||||
for element in elements:
|
||||
element_dict = self._safe_to_dict(element)
|
||||
metadata = element_dict.get("metadata", {})
|
||||
page_num = metadata.get("page_number")
|
||||
|
||||
if current_buffer.page_num != page_num:
|
||||
if current_buffer.segments:
|
||||
page_buffers.append(current_buffer)
|
||||
current_buffer = _PageBuffer(page_num=page_num, segments=[])
|
||||
|
||||
formatted = self._format_element(element_dict)
|
||||
|
||||
if formatted:
|
||||
current_buffer.segments.append(formatted)
|
||||
|
||||
if current_buffer.segments:
|
||||
page_buffers.append(current_buffer)
|
||||
|
||||
page_contents: List[str] = []
|
||||
for buffer in page_buffers:
|
||||
header = f"Page {buffer.page_num}:\n" if buffer.page_num is not None else "Page:"
|
||||
content = header + "\n\n".join(buffer.segments) + "\n"
|
||||
page_contents.append(str(content))
|
||||
return page_contents
|
||||
|
||||
def _format_element(
|
||||
self,
|
||||
element: Dict[str, Any],
|
||||
) -> str:
|
||||
"""Format element."""
|
||||
element_type = element.get("type")
|
||||
text = self._clean_text(element.get("text", ""))
|
||||
metadata = element.get("metadata", {})
|
||||
|
||||
if element_type.lower() == "table":
|
||||
return self._format_table_element(element) or text
|
||||
|
||||
if element_type.lower() == "image":
|
||||
description = text or self._format_image_element(metadata)
|
||||
return description
|
||||
|
||||
# Ignore header and footer
|
||||
if element_type.lower() in ["header", "footer"]:
|
||||
pass
|
||||
|
||||
return text
|
||||
|
||||
def _format_table_element(self, element: Dict[str, Any]) -> str:
|
||||
"""Format table element."""
|
||||
metadata = element.get("metadata", {})
|
||||
text = self._clean_text(element.get("text", ""))
|
||||
table_html = metadata.get("text_as_html")
|
||||
|
||||
if table_html:
|
||||
return table_html.strip()
|
||||
|
||||
return text
|
||||
|
||||
def _format_image_element(self, metadata: Dict[str, Any]) -> str:
|
||||
"""Format image."""
|
||||
placeholder = "[Image omitted]"
|
||||
image_text = placeholder
|
||||
coordinates = metadata.get("coordinates", {})
|
||||
points = coordinates.get("points") if isinstance(coordinates, dict) else None
|
||||
if points and isinstance(points, tuple) and len(points) == 4:
|
||||
leftup = points[0]
|
||||
rightdown = points[3]
|
||||
if (
|
||||
isinstance(leftup, tuple)
|
||||
and isinstance(rightdown, tuple)
|
||||
and len(leftup) == 2
|
||||
and len(rightdown) == 2
|
||||
):
|
||||
image_text = f"{placeholder} (bbox=({leftup[0]}, {leftup[1]}, {rightdown[0]}, {rightdown[1]}))"
|
||||
|
||||
layout_width = coordinates.get("layout_width")
|
||||
layout_height = coordinates.get("layout_height")
|
||||
system = coordinates.get("system")
|
||||
if layout_width and layout_height and system:
|
||||
image_text = (
|
||||
image_text
|
||||
+ f", system={system}, layout_width={layout_width}, layout_height={layout_height}))"
|
||||
)
|
||||
|
||||
return image_text
|
||||
|
||||
def _safe_to_dict(self, element: Any) -> Dict[str, Any]:
|
||||
"""Safe to dict."""
|
||||
try:
|
||||
if hasattr(element, "to_dict"):
|
||||
return element.to_dict()
|
||||
except Exception:
|
||||
pass
|
||||
fallback_type = getattr(element, "category", None)
|
||||
if not fallback_type:
|
||||
fallback_type = getattr(element, "__class__", type("", (), {})).__name__
|
||||
|
||||
return {
|
||||
"type": fallback_type,
|
||||
"text": getattr(element, "text", ""),
|
||||
"metadata": getattr(element, "metadata", {}),
|
||||
}
|
||||
|
||||
def _clean_text(self, value: Any) -> str:
|
||||
if value is None:
|
||||
return ""
|
||||
return str(value).replace("\xa0", " ").strip()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
loader = AdvancedPdfLoader()
|
||||
asyncio.run(
|
||||
loader.load(
|
||||
"/Users/xiaotao/work/cognee/cognee/infrastructure/loaders/external/attention_is_all_you_need.pdf"
|
||||
)
|
||||
)
|
||||
|
|
@ -16,3 +16,10 @@ try:
|
|||
supported_loaders[UnstructuredLoader.loader_name] = UnstructuredLoader
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
try:
|
||||
from cognee.infrastructure.loaders.external import AdvancedPdfLoader
|
||||
|
||||
supported_loaders[AdvancedPdfLoader.loader_name] = AdvancedPdfLoader
|
||||
except ImportError:
|
||||
pass
|
||||
|
|
|
|||
92
cognee/modules/data/methods/get_deletion_counts.py
Normal file
92
cognee/modules/data/methods/get_deletion_counts.py
Normal file
|
|
@ -0,0 +1,92 @@
|
|||
from uuid import UUID
|
||||
from cognee.cli.exceptions import CliCommandException
|
||||
from cognee.infrastructure.databases.exceptions.exceptions import EntityNotFoundError
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.sql import func
|
||||
from cognee.infrastructure.databases.relational import get_relational_engine
|
||||
from cognee.modules.data.models import Dataset, Data, DatasetData
|
||||
from cognee.modules.users.models import User
|
||||
from cognee.modules.users.methods import get_user
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
@dataclass
|
||||
class DeletionCountsPreview:
|
||||
datasets: int = 0
|
||||
data_entries: int = 0
|
||||
users: int = 0
|
||||
|
||||
|
||||
async def get_deletion_counts(
|
||||
dataset_name: str = None, user_id: str = None, all_data: bool = False
|
||||
) -> DeletionCountsPreview:
|
||||
"""
|
||||
Calculates the number of items that will be deleted based on the provided arguments.
|
||||
"""
|
||||
counts = DeletionCountsPreview()
|
||||
relational_engine = get_relational_engine()
|
||||
async with relational_engine.get_async_session() as session:
|
||||
if dataset_name:
|
||||
# Find the dataset by name
|
||||
dataset_result = await session.execute(
|
||||
select(Dataset).where(Dataset.name == dataset_name)
|
||||
)
|
||||
dataset = dataset_result.scalar_one_or_none()
|
||||
|
||||
if dataset is None:
|
||||
raise CliCommandException(
|
||||
f"No Dataset exists with the name {dataset_name}", error_code=1
|
||||
)
|
||||
|
||||
# Count data entries linked to this dataset
|
||||
count_query = (
|
||||
select(func.count())
|
||||
.select_from(DatasetData)
|
||||
.where(DatasetData.dataset_id == dataset.id)
|
||||
)
|
||||
data_entry_count = (await session.execute(count_query)).scalar_one()
|
||||
counts.users = 1
|
||||
counts.datasets = 1
|
||||
counts.entries = data_entry_count
|
||||
return counts
|
||||
|
||||
elif all_data:
|
||||
# Simplified logic: Get total counts directly from the tables.
|
||||
counts.datasets = (
|
||||
await session.execute(select(func.count()).select_from(Dataset))
|
||||
).scalar_one()
|
||||
counts.entries = (
|
||||
await session.execute(select(func.count()).select_from(Data))
|
||||
).scalar_one()
|
||||
counts.users = (
|
||||
await session.execute(select(func.count()).select_from(User))
|
||||
).scalar_one()
|
||||
return counts
|
||||
|
||||
# Placeholder for user_id logic
|
||||
elif user_id:
|
||||
user = None
|
||||
try:
|
||||
user_uuid = UUID(user_id)
|
||||
user = await get_user(user_uuid)
|
||||
except (ValueError, EntityNotFoundError):
|
||||
raise CliCommandException(f"No User exists with ID {user_id}", error_code=1)
|
||||
counts.users = 1
|
||||
# Find all datasets owned by this user
|
||||
datasets_query = select(Dataset).where(Dataset.owner_id == user.id)
|
||||
user_datasets = (await session.execute(datasets_query)).scalars().all()
|
||||
dataset_count = len(user_datasets)
|
||||
counts.datasets = dataset_count
|
||||
if dataset_count > 0:
|
||||
dataset_ids = [d.id for d in user_datasets]
|
||||
# Count all data entries across all of the user's datasets
|
||||
data_count_query = (
|
||||
select(func.count())
|
||||
.select_from(DatasetData)
|
||||
.where(DatasetData.dataset_id.in_(dataset_ids))
|
||||
)
|
||||
data_entry_count = (await session.execute(data_count_query)).scalar_one()
|
||||
counts.entries = data_entry_count
|
||||
else:
|
||||
counts.entries = 0
|
||||
return counts
|
||||
|
|
@ -12,7 +12,8 @@ from cognee.cli.commands.search_command import SearchCommand
|
|||
from cognee.cli.commands.cognify_command import CognifyCommand
|
||||
from cognee.cli.commands.delete_command import DeleteCommand
|
||||
from cognee.cli.commands.config_command import ConfigCommand
|
||||
from cognee.cli.exceptions import CliCommandException, CliCommandInnerException
|
||||
from cognee.cli.exceptions import CliCommandException
|
||||
from cognee.modules.data.methods.get_deletion_counts import DeletionCountsPreview
|
||||
|
||||
|
||||
# Mock asyncio.run to properly handle coroutines
|
||||
|
|
@ -282,13 +283,18 @@ class TestDeleteCommand:
|
|||
assert "all" in actions
|
||||
assert "force" in actions
|
||||
|
||||
@patch("cognee.cli.commands.delete_command.get_deletion_counts")
|
||||
@patch("cognee.cli.commands.delete_command.fmt.confirm")
|
||||
@patch("cognee.cli.commands.delete_command.asyncio.run", side_effect=_mock_run)
|
||||
def test_execute_delete_dataset_with_confirmation(self, mock_asyncio_run, mock_confirm):
|
||||
def test_execute_delete_dataset_with_confirmation(
|
||||
self, mock_asyncio_run, mock_confirm, mock_get_deletion_counts
|
||||
):
|
||||
"""Test execute delete dataset with user confirmation"""
|
||||
# Mock the cognee module
|
||||
mock_cognee = MagicMock()
|
||||
mock_cognee.delete = AsyncMock()
|
||||
mock_get_deletion_counts = AsyncMock()
|
||||
mock_get_deletion_counts.return_value = DeletionCountsPreview()
|
||||
|
||||
with patch.dict(sys.modules, {"cognee": mock_cognee}):
|
||||
command = DeleteCommand()
|
||||
|
|
@ -301,13 +307,16 @@ class TestDeleteCommand:
|
|||
command.execute(args)
|
||||
|
||||
mock_confirm.assert_called_once_with(f"Delete dataset '{args.dataset_name}'?")
|
||||
mock_asyncio_run.assert_called_once()
|
||||
assert mock_asyncio_run.call_count == 2
|
||||
assert asyncio.iscoroutine(mock_asyncio_run.call_args[0][0])
|
||||
mock_cognee.delete.assert_awaited_once_with(dataset_name="test_dataset", user_id=None)
|
||||
|
||||
@patch("cognee.cli.commands.delete_command.get_deletion_counts")
|
||||
@patch("cognee.cli.commands.delete_command.fmt.confirm")
|
||||
def test_execute_delete_cancelled(self, mock_confirm):
|
||||
def test_execute_delete_cancelled(self, mock_confirm, mock_get_deletion_counts):
|
||||
"""Test execute when user cancels deletion"""
|
||||
mock_get_deletion_counts = AsyncMock()
|
||||
mock_get_deletion_counts.return_value = DeletionCountsPreview()
|
||||
command = DeleteCommand()
|
||||
args = argparse.Namespace(dataset_name="test_dataset", user_id=None, all=False, force=False)
|
||||
|
||||
|
|
|
|||
|
|
@ -13,6 +13,7 @@ from cognee.cli.commands.cognify_command import CognifyCommand
|
|||
from cognee.cli.commands.delete_command import DeleteCommand
|
||||
from cognee.cli.commands.config_command import ConfigCommand
|
||||
from cognee.cli.exceptions import CliCommandException, CliCommandInnerException
|
||||
from cognee.modules.data.methods.get_deletion_counts import DeletionCountsPreview
|
||||
|
||||
|
||||
# Mock asyncio.run to properly handle coroutines
|
||||
|
|
@ -378,13 +379,18 @@ class TestCognifyCommandEdgeCases:
|
|||
class TestDeleteCommandEdgeCases:
|
||||
"""Test edge cases for DeleteCommand"""
|
||||
|
||||
@patch("cognee.cli.commands.delete_command.get_deletion_counts")
|
||||
@patch("cognee.cli.commands.delete_command.fmt.confirm")
|
||||
@patch("cognee.cli.commands.delete_command.asyncio.run", side_effect=_mock_run)
|
||||
def test_delete_all_with_user_id(self, mock_asyncio_run, mock_confirm):
|
||||
def test_delete_all_with_user_id(
|
||||
self, mock_asyncio_run, mock_confirm, mock_get_deletion_counts
|
||||
):
|
||||
"""Test delete command with both --all and --user-id"""
|
||||
# Mock the cognee module
|
||||
mock_cognee = MagicMock()
|
||||
mock_cognee.delete = AsyncMock()
|
||||
mock_get_deletion_counts = AsyncMock()
|
||||
mock_get_deletion_counts.return_value = DeletionCountsPreview()
|
||||
|
||||
with patch.dict(sys.modules, {"cognee": mock_cognee}):
|
||||
command = DeleteCommand()
|
||||
|
|
@ -396,13 +402,17 @@ class TestDeleteCommandEdgeCases:
|
|||
command.execute(args)
|
||||
|
||||
mock_confirm.assert_called_once_with("Delete ALL data from cognee?")
|
||||
mock_asyncio_run.assert_called_once()
|
||||
assert mock_asyncio_run.call_count == 2
|
||||
assert asyncio.iscoroutine(mock_asyncio_run.call_args[0][0])
|
||||
mock_cognee.delete.assert_awaited_once_with(dataset_name=None, user_id="test_user")
|
||||
|
||||
@patch("cognee.cli.commands.delete_command.get_deletion_counts")
|
||||
@patch("cognee.cli.commands.delete_command.fmt.confirm")
|
||||
def test_delete_confirmation_keyboard_interrupt(self, mock_confirm):
|
||||
def test_delete_confirmation_keyboard_interrupt(self, mock_confirm, mock_get_deletion_counts):
|
||||
"""Test delete command when user interrupts confirmation"""
|
||||
mock_get_deletion_counts = AsyncMock()
|
||||
mock_get_deletion_counts.return_value = DeletionCountsPreview()
|
||||
|
||||
command = DeleteCommand()
|
||||
args = argparse.Namespace(dataset_name="test_dataset", user_id=None, all=False, force=False)
|
||||
|
||||
|
|
|
|||
141
cognee/tests/test_advanced_pdf_loader.py
Normal file
141
cognee/tests/test_advanced_pdf_loader.py
Normal file
|
|
@ -0,0 +1,141 @@
|
|||
import sys
|
||||
from unittest.mock import patch, MagicMock, AsyncMock, mock_open
|
||||
import pytest
|
||||
|
||||
from cognee.infrastructure.loaders.external.advanced_pdf_loader import AdvancedPdfLoader
|
||||
|
||||
advanced_pdf_loader_module = sys.modules.get(
|
||||
"cognee.infrastructure.loaders.external.advanced_pdf_loader"
|
||||
)
|
||||
|
||||
|
||||
class MockElement:
|
||||
def __init__(self, category, text, metadata):
|
||||
self.category = category
|
||||
self.text = text
|
||||
self.metadata = metadata
|
||||
|
||||
def to_dict(self):
|
||||
return {
|
||||
"type": self.category,
|
||||
"text": self.text,
|
||||
"metadata": self.metadata,
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def loader():
|
||||
return AdvancedPdfLoader()
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"extension, mime_type, expected",
|
||||
[
|
||||
("pdf", "application/pdf", True),
|
||||
("txt", "text/plain", False),
|
||||
("pdf", "text/plain", False),
|
||||
("doc", "application/pdf", False),
|
||||
],
|
||||
)
|
||||
def test_can_handle(loader, extension, mime_type, expected):
|
||||
"""Test can_handle method can correctly identify PDF files"""
|
||||
assert loader.can_handle(extension, mime_type) == expected
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("cognee.infrastructure.loaders.external.advanced_pdf_loader.open", new_callable=mock_open)
|
||||
@patch(
|
||||
"cognee.infrastructure.loaders.external.advanced_pdf_loader.get_file_metadata",
|
||||
new_callable=AsyncMock,
|
||||
)
|
||||
@patch("cognee.infrastructure.loaders.external.advanced_pdf_loader.get_storage_config")
|
||||
@patch("cognee.infrastructure.loaders.external.advanced_pdf_loader.get_file_storage")
|
||||
@patch("cognee.infrastructure.loaders.external.advanced_pdf_loader.PyPdfLoader")
|
||||
@patch("cognee.infrastructure.loaders.external.advanced_pdf_loader.partition_pdf")
|
||||
async def test_load_success_with_unstructured(
|
||||
mock_partition_pdf,
|
||||
mock_pypdf_loader,
|
||||
mock_get_file_storage,
|
||||
mock_get_storage_config,
|
||||
mock_get_file_metadata,
|
||||
mock_open,
|
||||
loader,
|
||||
):
|
||||
"""Test the main flow of using unstructured to successfully process PDF"""
|
||||
# Prepare Mock data and objects
|
||||
mock_elements = [
|
||||
MockElement(
|
||||
category="Title", text="Attention Is All You Need", metadata={"page_number": 1}
|
||||
),
|
||||
MockElement(
|
||||
category="NarrativeText",
|
||||
text="The dominant sequence transduction models are based on complex recurrent or convolutional neural networks.",
|
||||
metadata={"page_number": 1},
|
||||
),
|
||||
MockElement(
|
||||
category="Table",
|
||||
text="This is a table.",
|
||||
metadata={"page_number": 2, "text_as_html": "<table><tr><td>Data</td></tr></table>"},
|
||||
),
|
||||
]
|
||||
mock_pypdf_loader.return_value.load = AsyncMock(return_value="/fake/path/fallback.txt")
|
||||
mock_partition_pdf.return_value = mock_elements
|
||||
mock_get_file_metadata.return_value = {"content_hash": "abc123def456"}
|
||||
|
||||
mock_storage_instance = MagicMock()
|
||||
mock_storage_instance.store = AsyncMock(return_value="/stored/text_abc123def456.txt")
|
||||
mock_get_file_storage.return_value = mock_storage_instance
|
||||
|
||||
mock_get_storage_config.return_value = {"data_root_directory": "/fake/data/root"}
|
||||
test_file_path = "/fake/path/document.pdf"
|
||||
|
||||
# Run
|
||||
|
||||
result_path = await loader.load(test_file_path)
|
||||
|
||||
# Assert
|
||||
assert result_path == "/stored/text_abc123def456.txt"
|
||||
|
||||
# Verify partition_pdf is called with the correct parameters
|
||||
mock_partition_pdf.assert_called_once()
|
||||
call_args, call_kwargs = mock_partition_pdf.call_args
|
||||
assert call_kwargs.get("filename") == test_file_path
|
||||
assert call_kwargs.get("strategy") == "auto" # Default strategy
|
||||
|
||||
# Verify the stored content is correct
|
||||
expected_content = "Page 1:\nAttention Is All You Need\n\nThe dominant sequence transduction models are based on complex recurrent or convolutional neural networks.\n\nPage 2:\n<table><tr><td>Data</td></tr></table>\n"
|
||||
mock_storage_instance.store.assert_awaited_once_with("text_abc123def456.txt", expected_content)
|
||||
|
||||
# Verify fallback is not called
|
||||
mock_pypdf_loader.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("cognee.infrastructure.loaders.external.advanced_pdf_loader.open", new_callable=mock_open)
|
||||
@patch(
|
||||
"cognee.infrastructure.loaders.external.advanced_pdf_loader.get_file_metadata",
|
||||
new_callable=AsyncMock,
|
||||
)
|
||||
@patch("cognee.infrastructure.loaders.external.advanced_pdf_loader.PyPdfLoader")
|
||||
@patch(
|
||||
"cognee.infrastructure.loaders.external.advanced_pdf_loader.partition_pdf",
|
||||
side_effect=Exception("Unstructured failed!"),
|
||||
)
|
||||
async def test_load_fallback_on_unstructured_exception(
|
||||
mock_partition_pdf, mock_pypdf_loader, mock_get_file_metadata, mock_open, loader
|
||||
):
|
||||
"""Test fallback to PyPdfLoader when unstructured throws an exception"""
|
||||
# Prepare Mock
|
||||
mock_fallback_instance = MagicMock()
|
||||
mock_fallback_instance.load = AsyncMock(return_value="/fake/path/fallback.txt")
|
||||
mock_pypdf_loader.return_value = mock_fallback_instance
|
||||
mock_get_file_metadata.return_value = {"content_hash": "anyhash"}
|
||||
test_file_path = "/fake/path/document.pdf"
|
||||
|
||||
# Run
|
||||
result_path = await loader.load(test_file_path)
|
||||
|
||||
# Assert
|
||||
assert result_path == "/fake/path/fallback.txt"
|
||||
mock_partition_pdf.assert_called_once() # Verify partition_pdf is called
|
||||
mock_fallback_instance.load.assert_awaited_once_with(test_file_path)
|
||||
|
|
@ -41,7 +41,12 @@ class TestCogneeServerStart(unittest.TestCase):
|
|||
def tearDownClass(cls):
|
||||
# Terminate the server process
|
||||
if hasattr(cls, "server_process") and cls.server_process:
|
||||
os.killpg(os.getpgid(cls.server_process.pid), signal.SIGTERM)
|
||||
if hasattr(os, "killpg"):
|
||||
# Unix-like systems: Use process groups
|
||||
os.killpg(os.getpgid(cls.server_process.pid), signal.SIGTERM)
|
||||
else:
|
||||
# Windows: Just terminate the main process
|
||||
cls.server_process.terminate()
|
||||
cls.server_process.wait()
|
||||
|
||||
def test_server_is_running(self):
|
||||
|
|
|
|||
1634
poetry.lock
generated
1634
poetry.lock
generated
File diff suppressed because it is too large
Load diff
|
|
@ -95,7 +95,7 @@ chromadb = [
|
|||
"chromadb>=0.6,<0.7",
|
||||
"pypika==0.48.9",
|
||||
]
|
||||
docs = ["unstructured[csv, doc, docx, epub, md, odt, org, ppt, pptx, rst, rtf, tsv, xlsx]>=0.18.1,<19"]
|
||||
docs = ["unstructured[csv, doc, docx, epub, md, odt, org, ppt, pptx, rst, rtf, tsv, xlsx, pdf]>=0.18.1,<19"]
|
||||
codegraph = [
|
||||
"fastembed<=0.6.0 ; python_version < '3.13'",
|
||||
"transformers>=4.46.3,<5",
|
||||
|
|
@ -142,6 +142,7 @@ Homepage = "https://www.cognee.ai"
|
|||
Repository = "https://github.com/topoteretes/cognee"
|
||||
|
||||
[project.scripts]
|
||||
cognee = "cognee.cli._cognee:main"
|
||||
cognee-cli = "cognee.cli._cognee:main"
|
||||
|
||||
[build-system]
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue