Merge branch 'dev' into add_cli
This commit is contained in:
commit
1a15669779
52 changed files with 1793 additions and 1032 deletions
1
.gitignore
vendored
1
.gitignore
vendored
|
|
@ -37,6 +37,7 @@ share/python-wheels/
|
|||
.installed.cfg
|
||||
*.egg
|
||||
.python-version
|
||||
cognee-mcp/.python-version
|
||||
MANIFEST
|
||||
|
||||
# PyInstaller
|
||||
|
|
|
|||
|
|
@ -9,6 +9,12 @@ Create Date: 2025-07-24 17:11:52.174737
|
|||
import os
|
||||
from typing import Sequence, Union
|
||||
|
||||
from cognee.infrastructure.databases.graph.kuzu.kuzu_migrate import (
|
||||
kuzu_migration,
|
||||
read_kuzu_storage_version,
|
||||
)
|
||||
import kuzu
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "b9274c27a25a"
|
||||
down_revision: Union[str, None] = "e4ebee1091e7"
|
||||
|
|
@ -18,38 +24,48 @@ depends_on: Union[str, Sequence[str], None] = None
|
|||
|
||||
def upgrade() -> None:
|
||||
# This migration is only for multi-user Cognee mode
|
||||
if not os.getenv("ENABLE_BACKEND_ACCESS_CONTROL", "false").lower() == "true":
|
||||
return
|
||||
if os.getenv("ENABLE_BACKEND_ACCESS_CONTROL", "false").lower() == "true":
|
||||
from cognee.base_config import get_base_config
|
||||
|
||||
from cognee.base_config import get_base_config
|
||||
base_config = get_base_config()
|
||||
|
||||
base_config = get_base_config()
|
||||
databases_root = os.path.join(base_config.system_root_directory, "databases")
|
||||
if not os.path.isdir(databases_root):
|
||||
raise FileNotFoundError(f"Directory not found: {databases_root}")
|
||||
|
||||
databases_root = os.path.join(base_config.system_root_directory, "databases")
|
||||
if not os.path.isdir(databases_root):
|
||||
raise FileNotFoundError(f"Directory not found: {databases_root}")
|
||||
for current_path, dirnames, _ in os.walk(databases_root):
|
||||
# If file is kuzu graph database
|
||||
if ".pkl" in current_path[-4:]:
|
||||
kuzu_db_version = read_kuzu_storage_version(current_path)
|
||||
if (
|
||||
kuzu_db_version == "0.9.0" or kuzu_db_version == "0.8.2"
|
||||
) and kuzu_db_version != kuzu.__version__:
|
||||
# Try to migrate kuzu database to latest version
|
||||
kuzu_migration(
|
||||
new_db=current_path + "_new",
|
||||
old_db=current_path,
|
||||
new_version=kuzu.__version__,
|
||||
old_version=kuzu_db_version,
|
||||
overwrite=True,
|
||||
)
|
||||
else:
|
||||
from cognee.infrastructure.databases.graph import get_graph_config
|
||||
|
||||
for current_path, dirnames, _ in os.walk(databases_root):
|
||||
# If file is kuzu graph database
|
||||
if ".pkl" in current_path[-4:]:
|
||||
from cognee.infrastructure.databases.graph.kuzu.kuzu_migrate import (
|
||||
kuzu_migration,
|
||||
read_kuzu_storage_version,
|
||||
)
|
||||
import kuzu
|
||||
|
||||
kuzu_db_version = read_kuzu_storage_version(current_path)
|
||||
if (
|
||||
kuzu_db_version == "0.9.0" or kuzu_db_version == "0.8.2"
|
||||
) and kuzu_db_version != kuzu.__version__:
|
||||
# Try to migrate kuzu database to latest version
|
||||
kuzu_migration(
|
||||
new_db=current_path + "new",
|
||||
old_db=current_path,
|
||||
new_version=kuzu.__version__,
|
||||
old_version=kuzu_db_version,
|
||||
overwrite=True,
|
||||
)
|
||||
graph_config = get_graph_config()
|
||||
if graph_config.graph_database_provider.lower() == "kuzu":
|
||||
if os.path.exists(graph_config.graph_file_path):
|
||||
kuzu_db_version = read_kuzu_storage_version(graph_config.graph_file_path)
|
||||
if (
|
||||
kuzu_db_version == "0.9.0" or kuzu_db_version == "0.8.2"
|
||||
) and kuzu_db_version != kuzu.__version__:
|
||||
# Try to migrate kuzu database to latest version
|
||||
kuzu_migration(
|
||||
new_db=graph_config.graph_file_path + "_new",
|
||||
old_db=graph_config.graph_file_path,
|
||||
new_version=kuzu.__version__,
|
||||
old_version=kuzu_db_version,
|
||||
overwrite=True,
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
"use client";
|
||||
|
||||
import { MutableRefObject, useEffect, useImperativeHandle, useRef, useState } from "react";
|
||||
import { MutableRefObject, useEffect, useImperativeHandle, useRef, useState, useCallback } from "react";
|
||||
import { forceCollide, forceManyBody } from "d3-force-3d";
|
||||
import ForceGraph, { ForceGraphMethods, GraphData, LinkObject, NodeObject } from "react-force-graph-2d";
|
||||
import { GraphControlsAPI } from "./GraphControls";
|
||||
|
|
@ -22,6 +22,45 @@ export default function GraphVisualization({ ref, data, graphControls }: GraphVi
|
|||
const nodeSize = 15;
|
||||
// const addNodeDistanceFromSourceNode = 15;
|
||||
|
||||
// State for tracking container dimensions
|
||||
const [dimensions, setDimensions] = useState({ width: 0, height: 0 });
|
||||
const containerRef = useRef<HTMLDivElement>(null);
|
||||
|
||||
// Handle resize
|
||||
const handleResize = useCallback(() => {
|
||||
if (containerRef.current) {
|
||||
const { clientWidth, clientHeight } = containerRef.current;
|
||||
setDimensions({ width: clientWidth, height: clientHeight });
|
||||
|
||||
// Trigger graph refresh after resize
|
||||
if (graphRef.current) {
|
||||
// Small delay to ensure DOM has updated
|
||||
setTimeout(() => {
|
||||
graphRef.current?.zoomToFit(1000,50);
|
||||
}, 100);
|
||||
}
|
||||
}
|
||||
}, []);
|
||||
|
||||
// Set up resize observer
|
||||
useEffect(() => {
|
||||
// Initial size calculation
|
||||
handleResize();
|
||||
|
||||
// ResizeObserver
|
||||
const resizeObserver = new ResizeObserver(() => {
|
||||
handleResize();
|
||||
});
|
||||
|
||||
if (containerRef.current) {
|
||||
resizeObserver.observe(containerRef.current);
|
||||
}
|
||||
|
||||
return () => {
|
||||
resizeObserver.disconnect();
|
||||
};
|
||||
}, [handleResize]);
|
||||
|
||||
const handleNodeClick = (node: NodeObject) => {
|
||||
graphControls.current?.setSelectedNode(node);
|
||||
// ref.current?.d3ReheatSimulation()
|
||||
|
|
@ -174,10 +213,12 @@ export default function GraphVisualization({ ref, data, graphControls }: GraphVi
|
|||
}));
|
||||
|
||||
return (
|
||||
<div className="w-full h-full" id="graph-container">
|
||||
<div ref={containerRef} className="w-full h-full" id="graph-container">
|
||||
{(data && typeof window !== "undefined") ? (
|
||||
<ForceGraph
|
||||
ref={graphRef}
|
||||
width={dimensions.width}
|
||||
height={dimensions.height}
|
||||
dagMode={graphShape as unknown as undefined}
|
||||
dagLevelDistance={300}
|
||||
onDagError={handleDagError}
|
||||
|
|
@ -201,6 +242,8 @@ export default function GraphVisualization({ ref, data, graphControls }: GraphVi
|
|||
) : (
|
||||
<ForceGraph
|
||||
ref={graphRef}
|
||||
width={dimensions.width}
|
||||
height={dimensions.height}
|
||||
dagMode={graphShape as unknown as undefined}
|
||||
dagLevelDistance={100}
|
||||
graphData={{
|
||||
|
|
|
|||
|
|
@ -1 +0,0 @@
|
|||
3.11.5
|
||||
|
|
@ -51,7 +51,7 @@ RUN apt-get update && apt-get install -y \
|
|||
|
||||
WORKDIR /app
|
||||
|
||||
COPY --from=uv /root/.local /root/.local
|
||||
COPY --from=uv /usr/local /usr/local
|
||||
COPY --from=uv /app /app
|
||||
|
||||
RUN chmod +x /app/entrypoint.sh
|
||||
|
|
|
|||
|
|
@ -48,27 +48,27 @@ if [ "$ENVIRONMENT" = "dev" ] || [ "$ENVIRONMENT" = "local" ]; then
|
|||
if [ "$DEBUG" = "true" ]; then
|
||||
echo "Waiting for the debugger to attach..."
|
||||
if [ "$TRANSPORT_MODE" = "sse" ]; then
|
||||
exec python -m debugpy --wait-for-client --listen 0.0.0.0:$DEBUG_PORT -m cognee --transport sse
|
||||
exec python -m debugpy --wait-for-client --listen 0.0.0.0:$DEBUG_PORT -m cognee --transport sse --no-migration
|
||||
elif [ "$TRANSPORT_MODE" = "http" ]; then
|
||||
exec python -m debugpy --wait-for-client --listen 0.0.0.0:$DEBUG_PORT -m cognee --transport http --host 0.0.0.0 --port $HTTP_PORT
|
||||
exec python -m debugpy --wait-for-client --listen 0.0.0.0:$DEBUG_PORT -m cognee --transport http --host 0.0.0.0 --port $HTTP_PORT --no-migration
|
||||
else
|
||||
exec python -m debugpy --wait-for-client --listen 0.0.0.0:$DEBUG_PORT -m cognee --transport stdio
|
||||
exec python -m debugpy --wait-for-client --listen 0.0.0.0:$DEBUG_PORT -m cognee --transport stdio --no-migration
|
||||
fi
|
||||
else
|
||||
if [ "$TRANSPORT_MODE" = "sse" ]; then
|
||||
exec cognee --transport sse
|
||||
exec cognee --transport sse --no-migration
|
||||
elif [ "$TRANSPORT_MODE" = "http" ]; then
|
||||
exec cognee --transport http --host 0.0.0.0 --port $HTTP_PORT
|
||||
exec cognee --transport http --host 0.0.0.0 --port $HTTP_PORT --no-migration
|
||||
else
|
||||
exec cognee --transport stdio
|
||||
exec cognee --transport stdio --no-migration
|
||||
fi
|
||||
fi
|
||||
else
|
||||
if [ "$TRANSPORT_MODE" = "sse" ]; then
|
||||
exec cognee --transport sse
|
||||
exec cognee --transport sse --no-migration
|
||||
elif [ "$TRANSPORT_MODE" = "http" ]; then
|
||||
exec cognee --transport http --host 0.0.0.0 --port $HTTP_PORT
|
||||
exec cognee --transport http --host 0.0.0.0 --port $HTTP_PORT --no-migration
|
||||
else
|
||||
exec cognee --transport stdio
|
||||
exec cognee --transport stdio --no-migration
|
||||
fi
|
||||
fi
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@ requires-python = ">=3.10"
|
|||
dependencies = [
|
||||
# For local cognee repo usage remove comment bellow and add absolute path to cognee. Then run `uv sync --reinstall` in the mcp folder on local cognee changes.
|
||||
# "cognee[postgres,codegraph,gemini,huggingface,docs,neo4j] @ file:/Users/vasilije/Projects/tiktok/cognee",
|
||||
"cognee[postgres,codegraph,gemini,huggingface,docs,neo4j]>=0.2.0,<1.0.0",
|
||||
"cognee[postgres,codegraph,gemini,huggingface,docs,neo4j]==0.2.1",
|
||||
"fastmcp>=2.10.0,<3.0.0",
|
||||
"mcp>=1.12.0,<2.0.0",
|
||||
"uv>=0.6.3,<1.0.0",
|
||||
|
|
|
|||
|
|
@ -123,11 +123,34 @@ async def cognee_add_developer_rules(
|
|||
@mcp.tool()
|
||||
async def cognify(data: str, graph_model_file: str = None, graph_model_name: str = None) -> list:
|
||||
"""
|
||||
Transform data into a structured knowledge graph in Cognee's memory layer.
|
||||
Transform ingested data into a structured knowledge graph.
|
||||
|
||||
This function launches a background task that processes the provided text/file location and
|
||||
generates a knowledge graph representation. The function returns immediately while
|
||||
the processing continues in the background due to MCP timeout constraints.
|
||||
This is the core processing step in Cognee that converts raw text and documents
|
||||
into an intelligent knowledge graph. It analyzes content, extracts entities and
|
||||
relationships, and creates semantic connections for enhanced search and reasoning.
|
||||
|
||||
Prerequisites:
|
||||
- **LLM_API_KEY**: Must be configured (required for entity extraction and graph generation)
|
||||
- **Data Added**: Must have data previously added via `cognee.add()`
|
||||
- **Vector Database**: Must be accessible for embeddings storage
|
||||
- **Graph Database**: Must be accessible for relationship storage
|
||||
|
||||
Input Requirements:
|
||||
- **Content Types**: Works with any text-extractable content including:
|
||||
* Natural language documents
|
||||
* Structured data (CSV, JSON)
|
||||
* Code repositories
|
||||
* Academic papers and technical documentation
|
||||
* Mixed multimedia content (with text extraction)
|
||||
|
||||
Processing Pipeline:
|
||||
1. **Document Classification**: Identifies document types and structures
|
||||
2. **Permission Validation**: Ensures user has processing rights
|
||||
3. **Text Chunking**: Breaks content into semantically meaningful segments
|
||||
4. **Entity Extraction**: Identifies key concepts, people, places, organizations
|
||||
5. **Relationship Detection**: Discovers connections between entities
|
||||
6. **Graph Construction**: Builds semantic knowledge graph with embeddings
|
||||
7. **Content Summarization**: Creates hierarchical summaries for navigation
|
||||
|
||||
Parameters
|
||||
----------
|
||||
|
|
@ -152,11 +175,60 @@ async def cognify(data: str, graph_model_file: str = None, graph_model_name: str
|
|||
A list containing a single TextContent object with information about the
|
||||
background task launch and how to check its status.
|
||||
|
||||
Next Steps:
|
||||
After successful cognify processing, use search functions to query the knowledge:
|
||||
|
||||
```python
|
||||
import cognee
|
||||
from cognee import SearchType
|
||||
|
||||
# Process your data into knowledge graph
|
||||
await cognee.cognify()
|
||||
|
||||
# Query for insights using different search types:
|
||||
|
||||
# 1. Natural language completion with graph context
|
||||
insights = await cognee.search(
|
||||
"What are the main themes?",
|
||||
query_type=SearchType.GRAPH_COMPLETION
|
||||
)
|
||||
|
||||
# 2. Get entity relationships and connections
|
||||
relationships = await cognee.search(
|
||||
"connections between concepts",
|
||||
query_type=SearchType.INSIGHTS
|
||||
)
|
||||
|
||||
# 3. Find relevant document chunks
|
||||
chunks = await cognee.search(
|
||||
"specific topic",
|
||||
query_type=SearchType.CHUNKS
|
||||
)
|
||||
```
|
||||
|
||||
Environment Variables:
|
||||
Required:
|
||||
- LLM_API_KEY: API key for your LLM provider
|
||||
|
||||
Optional:
|
||||
- LLM_PROVIDER, LLM_MODEL, VECTOR_DB_PROVIDER, GRAPH_DATABASE_PROVIDER
|
||||
- LLM_RATE_LIMIT_ENABLED: Enable rate limiting (default: False)
|
||||
- LLM_RATE_LIMIT_REQUESTS: Max requests per interval (default: 60)
|
||||
|
||||
Notes
|
||||
-----
|
||||
- The function launches a background task and returns immediately
|
||||
- The actual cognify process may take significant time depending on text length
|
||||
- Use the cognify_status tool to check the progress of the operation
|
||||
|
||||
Raises
|
||||
------
|
||||
InvalidValueError
|
||||
If LLM_API_KEY is not set
|
||||
ValueError
|
||||
If chunks exceed max token limits (reduce chunk_size)
|
||||
DatabaseNotCreatedError
|
||||
If databases are not properly initialized
|
||||
"""
|
||||
|
||||
async def cognify_task(
|
||||
|
|
@ -327,17 +399,69 @@ async def codify(repo_path: str) -> list:
|
|||
@mcp.tool()
|
||||
async def search(search_query: str, search_type: str) -> list:
|
||||
"""
|
||||
Search the Cognee knowledge graph for information relevant to the query.
|
||||
Search and query the knowledge graph for insights, information, and connections.
|
||||
|
||||
This function executes a search against the Cognee knowledge graph using the
|
||||
specified query and search type. It returns formatted results based on the
|
||||
search type selected.
|
||||
This is the final step in the Cognee workflow that retrieves information from the
|
||||
processed knowledge graph. It supports multiple search modes optimized for different
|
||||
use cases - from simple fact retrieval to complex reasoning and code analysis.
|
||||
|
||||
Search Prerequisites:
|
||||
- **LLM_API_KEY**: Required for GRAPH_COMPLETION and RAG_COMPLETION search types
|
||||
- **Data Added**: Must have data previously added via `cognee.add()`
|
||||
- **Knowledge Graph Built**: Must have processed data via `cognee.cognify()`
|
||||
- **Vector Database**: Must be accessible for semantic search functionality
|
||||
|
||||
Search Types & Use Cases:
|
||||
|
||||
**GRAPH_COMPLETION** (Recommended):
|
||||
Natural language Q&A using full graph context and LLM reasoning.
|
||||
Best for: Complex questions, analysis, summaries, insights.
|
||||
Returns: Conversational AI responses with graph-backed context.
|
||||
|
||||
**RAG_COMPLETION**:
|
||||
Traditional RAG using document chunks without graph structure.
|
||||
Best for: Direct document retrieval, specific fact-finding.
|
||||
Returns: LLM responses based on relevant text chunks.
|
||||
|
||||
**INSIGHTS**:
|
||||
Structured entity relationships and semantic connections.
|
||||
Best for: Understanding concept relationships, knowledge mapping.
|
||||
Returns: Formatted relationship data and entity connections.
|
||||
|
||||
**CHUNKS**:
|
||||
Raw text segments that match the query semantically.
|
||||
Best for: Finding specific passages, citations, exact content.
|
||||
Returns: Ranked list of relevant text chunks with metadata.
|
||||
|
||||
**SUMMARIES**:
|
||||
Pre-generated hierarchical summaries of content.
|
||||
Best for: Quick overviews, document abstracts, topic summaries.
|
||||
Returns: Multi-level summaries from detailed to high-level.
|
||||
|
||||
**CODE**:
|
||||
Code-specific search with syntax and semantic understanding.
|
||||
Best for: Finding functions, classes, implementation patterns.
|
||||
Returns: Structured code information with context and relationships.
|
||||
|
||||
**CYPHER**:
|
||||
Direct graph database queries using Cypher syntax.
|
||||
Best for: Advanced users, specific graph traversals, debugging.
|
||||
Returns: Raw graph query results.
|
||||
|
||||
**FEELING_LUCKY**:
|
||||
Intelligently selects and runs the most appropriate search type.
|
||||
Best for: General-purpose queries or when you're unsure which search type is best.
|
||||
Returns: The results from the automatically selected search type.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
search_query : str
|
||||
The search query in natural language. This can be a question, instruction, or
|
||||
any text that expresses what information is needed from the knowledge graph.
|
||||
Your question or search query in natural language.
|
||||
Examples:
|
||||
- "What are the main themes in this research?"
|
||||
- "How do these concepts relate to each other?"
|
||||
- "Find information about machine learning algorithms"
|
||||
- "What functions handle user authentication?"
|
||||
|
||||
search_type : str
|
||||
The type of search to perform. Valid options include:
|
||||
|
|
@ -346,6 +470,9 @@ async def search(search_query: str, search_type: str) -> list:
|
|||
- "CODE": Returns code-related knowledge in JSON format
|
||||
- "CHUNKS": Returns raw text chunks from the knowledge graph
|
||||
- "INSIGHTS": Returns relationships between nodes in readable format
|
||||
- "SUMMARIES": Returns pre-generated hierarchical summaries
|
||||
- "CYPHER": Direct graph database queries
|
||||
- "FEELING_LUCKY": Automatically selects best search type
|
||||
|
||||
The search_type is case-insensitive and will be converted to uppercase.
|
||||
|
||||
|
|
@ -354,16 +481,45 @@ async def search(search_query: str, search_type: str) -> list:
|
|||
list
|
||||
A list containing a single TextContent object with the search results.
|
||||
The format of the result depends on the search_type:
|
||||
- For CODE: JSON-formatted search results
|
||||
- For GRAPH_COMPLETION/RAG_COMPLETION: A single text completion
|
||||
- For CHUNKS: String representation of the raw chunks
|
||||
- For INSIGHTS: Formatted string showing node relationships
|
||||
- For other types: String representation of the search results
|
||||
- **GRAPH_COMPLETION/RAG_COMPLETION**: Conversational AI response strings
|
||||
- **INSIGHTS**: Formatted relationship descriptions and entity connections
|
||||
- **CHUNKS**: Relevant text passages with source metadata
|
||||
- **SUMMARIES**: Hierarchical summaries from general to specific
|
||||
- **CODE**: Structured code information with context
|
||||
- **FEELING_LUCKY**: Results in format of automatically selected search type
|
||||
- **CYPHER**: Raw graph query results
|
||||
|
||||
Performance & Optimization:
|
||||
- **GRAPH_COMPLETION**: Slower but most intelligent, uses LLM + graph context
|
||||
- **RAG_COMPLETION**: Medium speed, uses LLM + document chunks (no graph traversal)
|
||||
- **INSIGHTS**: Fast, returns structured relationships without LLM processing
|
||||
- **CHUNKS**: Fastest, pure vector similarity search without LLM
|
||||
- **SUMMARIES**: Fast, returns pre-computed summaries
|
||||
- **CODE**: Medium speed, specialized for code understanding
|
||||
- **FEELING_LUCKY**: Variable speed, uses LLM + search type selection intelligently
|
||||
|
||||
Environment Variables:
|
||||
Required for LLM-based search types (GRAPH_COMPLETION, RAG_COMPLETION):
|
||||
- LLM_API_KEY: API key for your LLM provider
|
||||
|
||||
Optional:
|
||||
- LLM_PROVIDER, LLM_MODEL: Configure LLM for search responses
|
||||
- VECTOR_DB_PROVIDER: Must match what was used during cognify
|
||||
- GRAPH_DATABASE_PROVIDER: Must match what was used during cognify
|
||||
|
||||
Notes
|
||||
-----
|
||||
- Different search types produce different output formats
|
||||
- The function handles the conversion between Cognee's internal result format and MCP's output format
|
||||
|
||||
Raises
|
||||
------
|
||||
InvalidValueError
|
||||
If LLM_API_KEY is not set (for LLM-based search types)
|
||||
ValueError
|
||||
If query_text is empty or search parameters are invalid
|
||||
NoDataError
|
||||
If no relevant data found for the search query
|
||||
"""
|
||||
|
||||
async def search_task(search_query: str, search_type: str) -> str:
|
||||
|
|
@ -380,7 +536,7 @@ async def search(search_query: str, search_type: str) -> list:
|
|||
elif (
|
||||
search_type.upper() == "GRAPH_COMPLETION" or search_type.upper() == "RAG_COMPLETION"
|
||||
):
|
||||
return search_results[0]
|
||||
return str(search_results[0])
|
||||
elif search_type.upper() == "CHUNKS":
|
||||
return str(search_results)
|
||||
elif search_type.upper() == "INSIGHTS":
|
||||
|
|
@ -782,30 +938,38 @@ async def main():
|
|||
help="Log level for the HTTP server (default: info)",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Run Alembic migrations from the main cognee directory where alembic.ini is located
|
||||
print("Running database migrations...")
|
||||
migration_result = subprocess.run(
|
||||
["alembic", "upgrade", "head"],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
cwd=Path(__file__).resolve().parent.parent.parent,
|
||||
parser.add_argument(
|
||||
"--no-migration",
|
||||
default=False,
|
||||
action="store_true",
|
||||
help="Argument stops database migration from being attempted",
|
||||
)
|
||||
|
||||
if migration_result.returncode != 0:
|
||||
migration_output = migration_result.stderr + migration_result.stdout
|
||||
# Check for the expected UserAlreadyExists error (which is not critical)
|
||||
if (
|
||||
"UserAlreadyExists" in migration_output
|
||||
or "User default_user@example.com already exists" in migration_output
|
||||
):
|
||||
print("Warning: Default user already exists, continuing startup...")
|
||||
else:
|
||||
print(f"Migration failed with unexpected error: {migration_output}")
|
||||
sys.exit(1)
|
||||
args = parser.parse_args()
|
||||
|
||||
print("Database migrations done.")
|
||||
if not args.no_migration:
|
||||
# Run Alembic migrations from the main cognee directory where alembic.ini is located
|
||||
logger.info("Running database migrations...")
|
||||
migration_result = subprocess.run(
|
||||
["python", "-m", "alembic", "upgrade", "head"],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
cwd=Path(__file__).resolve().parent.parent.parent,
|
||||
)
|
||||
|
||||
if migration_result.returncode != 0:
|
||||
migration_output = migration_result.stderr + migration_result.stdout
|
||||
# Check for the expected UserAlreadyExists error (which is not critical)
|
||||
if (
|
||||
"UserAlreadyExists" in migration_output
|
||||
or "User default_user@example.com already exists" in migration_output
|
||||
):
|
||||
logger.warning("Warning: Default user already exists, continuing startup...")
|
||||
else:
|
||||
logger.error(f"Migration failed with unexpected error: {migration_output}")
|
||||
sys.exit(1)
|
||||
|
||||
logger.info("Database migrations done.")
|
||||
|
||||
logger.info(f"Starting MCP server with transport: {args.transport}")
|
||||
if args.transport == "stdio":
|
||||
|
|
|
|||
9
cognee-mcp/uv.lock
generated
9
cognee-mcp/uv.lock
generated
|
|
@ -620,7 +620,7 @@ wheels = [
|
|||
|
||||
[[package]]
|
||||
name = "cognee"
|
||||
version = "0.2.1.dev7"
|
||||
version = "0.2.1"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "aiofiles" },
|
||||
|
|
@ -663,9 +663,10 @@ dependencies = [
|
|||
{ name = "tiktoken" },
|
||||
{ name = "typing-extensions" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/f4/84/da4d45a0d74f91ac9302ec66c728aa14c93a980fa24fa590f3c30d2f7e23/cognee-0.2.1.dev7.tar.gz", hash = "sha256:fb280ccf900753fd984d23a984440a1e4ca52c1fcba267743fb4149047301c49", size = 15492306 }
|
||||
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/41/46/e7df1faebc92fa31ef8e33faf81feb435782727a789de5532d178e047224/cognee-0.2.1.tar.gz", hash = "sha256:bf5208383fc841981641c040e5b6588e58111af4d771f9eab6552f441e6a8e6c", size = 15497626, upload-time = "2025-07-25T15:53:57.009Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/e0/39/4ea24c1e43c99ff2ec0b2904a12cf6947a26d9a875f92a6be776d7101397/cognee-0.2.1.dev7-py3-none-any.whl", hash = "sha256:5923395fd7ec3fa52bbcf84f1b85c4a3a69004b4c7ed721fa90797c7bd727f22", size = 1016188 },
|
||||
{ url = "https://files.pythonhosted.org/packages/b3/0e/b705c6eeb538dcdd8fbbb331be25fe8e0bbc1af7d76e61566ec9845b29d3/cognee-0.2.1-py3-none-any.whl", hash = "sha256:6e9d437e0c58a16233841ebf19b1a3d8b67da069460a4f08d0c0e00301b1d36d", size = 1019851, upload-time = "2025-07-25T15:53:53.488Z" },
|
||||
]
|
||||
|
||||
[package.optional-dependencies]
|
||||
|
|
@ -711,7 +712,7 @@ dev = [
|
|||
|
||||
[package.metadata]
|
||||
requires-dist = [
|
||||
{ name = "cognee", extras = ["postgres", "codegraph", "gemini", "huggingface", "docs", "neo4j"], specifier = ">=0.2.0,<1.0.0" },
|
||||
{ name = "cognee", extras = ["postgres", "codegraph", "gemini", "huggingface", "docs", "neo4j"], specifier = "==0.2.1" },
|
||||
{ name = "fastmcp", specifier = ">=2.10.0,<3.0.0" },
|
||||
{ name = "mcp", specifier = ">=1.12.0,<2.0.0" },
|
||||
{ name = "uv", specifier = ">=0.6.3,<1.0.0" },
|
||||
|
|
|
|||
|
|
@ -15,6 +15,7 @@ async def add(
|
|||
vector_db_config: dict = None,
|
||||
graph_db_config: dict = None,
|
||||
dataset_id: Optional[UUID] = None,
|
||||
incremental_loading: bool = True,
|
||||
):
|
||||
"""
|
||||
Add data to Cognee for knowledge graph processing.
|
||||
|
|
@ -153,6 +154,7 @@ async def add(
|
|||
pipeline_name="add_pipeline",
|
||||
vector_db_config=vector_db_config,
|
||||
graph_db_config=graph_db_config,
|
||||
incremental_loading=incremental_loading,
|
||||
):
|
||||
pipeline_run_info = run_info
|
||||
|
||||
|
|
|
|||
|
|
@ -11,6 +11,7 @@ from typing import List, Optional, Union, Literal
|
|||
from cognee.modules.users.models import User
|
||||
from cognee.modules.users.methods import get_authenticated_user
|
||||
from cognee.shared.utils import send_telemetry
|
||||
from cognee.modules.pipelines.models import PipelineRunErrored
|
||||
from cognee.shared.logging_utils import get_logger
|
||||
|
||||
logger = get_logger()
|
||||
|
|
@ -100,6 +101,8 @@ def get_add_router() -> APIRouter:
|
|||
else:
|
||||
add_run = await cognee_add(data, datasetName, user=user, dataset_id=datasetId)
|
||||
|
||||
if isinstance(add_run, PipelineRunErrored):
|
||||
return JSONResponse(status_code=420, content=add_run.model_dump(mode="json"))
|
||||
return add_run.model_dump()
|
||||
except Exception as error:
|
||||
return JSONResponse(status_code=409, content={"error": str(error)})
|
||||
|
|
|
|||
|
|
@ -79,7 +79,9 @@ async def run_code_graph_pipeline(repo_path, include_docs=False):
|
|||
async for run_status in non_code_pipeline_run:
|
||||
yield run_status
|
||||
|
||||
async for run_status in run_tasks(tasks, dataset.id, repo_path, user, "cognify_code_pipeline"):
|
||||
async for run_status in run_tasks(
|
||||
tasks, dataset.id, repo_path, user, "cognify_code_pipeline", incremental_loading=False
|
||||
):
|
||||
yield run_status
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -39,6 +39,7 @@ async def cognify(
|
|||
vector_db_config: dict = None,
|
||||
graph_db_config: dict = None,
|
||||
run_in_background: bool = False,
|
||||
incremental_loading: bool = True,
|
||||
):
|
||||
"""
|
||||
Transform ingested data into a structured knowledge graph.
|
||||
|
|
@ -194,6 +195,7 @@ async def cognify(
|
|||
datasets=datasets,
|
||||
vector_db_config=vector_db_config,
|
||||
graph_db_config=graph_db_config,
|
||||
incremental_loading=incremental_loading,
|
||||
)
|
||||
else:
|
||||
return await run_cognify_blocking(
|
||||
|
|
@ -202,6 +204,7 @@ async def cognify(
|
|||
datasets=datasets,
|
||||
vector_db_config=vector_db_config,
|
||||
graph_db_config=graph_db_config,
|
||||
incremental_loading=incremental_loading,
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -211,6 +214,7 @@ async def run_cognify_blocking(
|
|||
datasets,
|
||||
graph_db_config: dict = None,
|
||||
vector_db_config: dict = False,
|
||||
incremental_loading: bool = True,
|
||||
):
|
||||
total_run_info = {}
|
||||
|
||||
|
|
@ -221,6 +225,7 @@ async def run_cognify_blocking(
|
|||
pipeline_name="cognify_pipeline",
|
||||
graph_db_config=graph_db_config,
|
||||
vector_db_config=vector_db_config,
|
||||
incremental_loading=incremental_loading,
|
||||
):
|
||||
if run_info.dataset_id:
|
||||
total_run_info[run_info.dataset_id] = run_info
|
||||
|
|
@ -236,6 +241,7 @@ async def run_cognify_as_background_process(
|
|||
datasets,
|
||||
graph_db_config: dict = None,
|
||||
vector_db_config: dict = False,
|
||||
incremental_loading: bool = True,
|
||||
):
|
||||
# Convert dataset to list if it's a string
|
||||
if isinstance(datasets, str):
|
||||
|
|
@ -246,6 +252,7 @@ async def run_cognify_as_background_process(
|
|||
|
||||
async def handle_rest_of_the_run(pipeline_list):
|
||||
# Execute all provided pipelines one by one to avoid database write conflicts
|
||||
# TODO: Convert to async gather task instead of for loop when Queue mechanism for database is created
|
||||
for pipeline in pipeline_list:
|
||||
while True:
|
||||
try:
|
||||
|
|
@ -270,6 +277,7 @@ async def run_cognify_as_background_process(
|
|||
pipeline_name="cognify_pipeline",
|
||||
graph_db_config=graph_db_config,
|
||||
vector_db_config=vector_db_config,
|
||||
incremental_loading=incremental_loading,
|
||||
)
|
||||
|
||||
# Save dataset Pipeline run started info
|
||||
|
|
|
|||
|
|
@ -16,7 +16,11 @@ from cognee.modules.graph.methods import get_formatted_graph_data
|
|||
from cognee.modules.users.get_user_manager import get_user_manager_context
|
||||
from cognee.infrastructure.databases.relational import get_relational_engine
|
||||
from cognee.modules.users.authentication.default.default_jwt_strategy import DefaultJWTStrategy
|
||||
from cognee.modules.pipelines.models.PipelineRunInfo import PipelineRunCompleted, PipelineRunInfo
|
||||
from cognee.modules.pipelines.models.PipelineRunInfo import (
|
||||
PipelineRunCompleted,
|
||||
PipelineRunInfo,
|
||||
PipelineRunErrored,
|
||||
)
|
||||
from cognee.modules.pipelines.queues.pipeline_run_info_queues import (
|
||||
get_from_queue,
|
||||
initialize_queue,
|
||||
|
|
@ -105,6 +109,9 @@ def get_cognify_router() -> APIRouter:
|
|||
datasets, user, run_in_background=payload.run_in_background
|
||||
)
|
||||
|
||||
# If any cognify run errored return JSONResponse with proper error status code
|
||||
if any(isinstance(v, PipelineRunErrored) for v in cognify_run.values()):
|
||||
return JSONResponse(status_code=420, content=cognify_run)
|
||||
return cognify_run
|
||||
except Exception as error:
|
||||
return JSONResponse(status_code=409, content={"error": str(error)})
|
||||
|
|
|
|||
|
|
@ -353,7 +353,7 @@ def get_datasets_router() -> APIRouter:
|
|||
|
||||
@router.get("/status", response_model=dict[str, PipelineRunStatus])
|
||||
async def get_dataset_status(
|
||||
datasets: Annotated[List[UUID], Query(alias="dataset")] = None,
|
||||
datasets: Annotated[List[UUID], Query(alias="dataset")] = [],
|
||||
user: User = Depends(get_authenticated_user),
|
||||
):
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -71,6 +71,12 @@ async def search(
|
|||
Best for: Advanced users, specific graph traversals, debugging.
|
||||
Returns: Raw graph query results.
|
||||
|
||||
**FEELING_LUCKY**:
|
||||
Intelligently selects and runs the most appropriate search type.
|
||||
Best for: General-purpose queries or when you're unsure which search type is best.
|
||||
Returns: The results from the automatically selected search type.
|
||||
|
||||
|
||||
Args:
|
||||
query_text: Your question or search query in natural language.
|
||||
Examples:
|
||||
|
|
@ -119,6 +125,9 @@ async def search(
|
|||
**CODE**:
|
||||
[List of structured code information with context]
|
||||
|
||||
**FEELING_LUCKY**:
|
||||
[List of results in the format of the search type that is automatically selected]
|
||||
|
||||
|
||||
|
||||
|
||||
|
|
@ -130,6 +139,7 @@ async def search(
|
|||
- **CHUNKS**: Fastest, pure vector similarity search without LLM
|
||||
- **SUMMARIES**: Fast, returns pre-computed summaries
|
||||
- **CODE**: Medium speed, specialized for code understanding
|
||||
- **FEELING_LUCKY**: Variable speed, uses LLM + search type selection intelligently
|
||||
- **top_k**: Start with 10, increase for comprehensive analysis (max 100)
|
||||
- **datasets**: Specify datasets to improve speed and relevance
|
||||
|
||||
|
|
|
|||
|
|
@ -86,12 +86,11 @@ class KuzuAdapter(GraphDBInterface):
|
|||
if (
|
||||
kuzu_db_version == "0.9.0" or kuzu_db_version == "0.8.2"
|
||||
) and kuzu_db_version != kuzu.__version__:
|
||||
# TODO: Write migration script that will handle all user graph databases in multi-user mode
|
||||
# Try to migrate kuzu database to latest version
|
||||
from .kuzu_migrate import kuzu_migration
|
||||
|
||||
kuzu_migration(
|
||||
new_db=self.db_path + "new",
|
||||
new_db=self.db_path + "_new",
|
||||
old_db=self.db_path,
|
||||
new_version=kuzu.__version__,
|
||||
old_version=kuzu_db_version,
|
||||
|
|
@ -1464,11 +1463,8 @@ class KuzuAdapter(GraphDBInterface):
|
|||
It raises exceptions for failures occurring during deletion processes.
|
||||
"""
|
||||
try:
|
||||
# Use DETACH DELETE to remove both nodes and their relationships in one operation
|
||||
await self.query("MATCH (n:Node) DETACH DELETE n")
|
||||
logger.info("Cleared all data from graph while preserving structure")
|
||||
|
||||
if self.connection:
|
||||
self.connection.close()
|
||||
self.connection = None
|
||||
if self.db:
|
||||
self.db.close()
|
||||
|
|
|
|||
|
|
@ -94,6 +94,7 @@ def ensure_env(version: str, export_dir) -> str:
|
|||
|
||||
print(f"→ Setting up venv for Kùzu {version}...", file=sys.stderr)
|
||||
# Create venv
|
||||
# NOTE: Running python in debug mode can cause issues with creating a virtual environment from that python instance
|
||||
subprocess.run([sys.executable, "-m", "venv", base], check=True)
|
||||
# Install the specific Kùzu version
|
||||
subprocess.run([py_bin, "-m", "pip", "install", "--upgrade", "pip"], check=True)
|
||||
|
|
|
|||
|
|
@ -410,6 +410,38 @@ class Neo4jAdapter(GraphDBInterface):
|
|||
|
||||
return await self.query(query, params)
|
||||
|
||||
def _flatten_edge_properties(self, properties: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""
|
||||
Flatten edge properties to handle nested dictionaries like weights.
|
||||
|
||||
Neo4j doesn't support nested dictionaries as property values, so we need to
|
||||
flatten the 'weights' dictionary into individual properties with prefixes.
|
||||
|
||||
Args:
|
||||
properties: Dictionary of edge properties that may contain nested dicts
|
||||
|
||||
Returns:
|
||||
Flattened properties dictionary suitable for Neo4j storage
|
||||
"""
|
||||
flattened = {}
|
||||
|
||||
for key, value in properties.items():
|
||||
if key == "weights" and isinstance(value, dict):
|
||||
# Flatten weights dictionary into individual properties
|
||||
for weight_name, weight_value in value.items():
|
||||
flattened[f"weight_{weight_name}"] = weight_value
|
||||
elif isinstance(value, dict):
|
||||
# For other nested dictionaries, serialize as JSON string
|
||||
flattened[f"{key}_json"] = json.dumps(value, cls=JSONEncoder)
|
||||
elif isinstance(value, list):
|
||||
# For lists, serialize as JSON string
|
||||
flattened[f"{key}_json"] = json.dumps(value, cls=JSONEncoder)
|
||||
else:
|
||||
# Keep primitive types as-is
|
||||
flattened[key] = value
|
||||
|
||||
return flattened
|
||||
|
||||
@record_graph_changes
|
||||
@override_distributed(queued_add_edges)
|
||||
async def add_edges(self, edges: list[tuple[str, str, str, dict[str, Any]]]) -> None:
|
||||
|
|
@ -448,11 +480,13 @@ class Neo4jAdapter(GraphDBInterface):
|
|||
"from_node": str(edge[0]),
|
||||
"to_node": str(edge[1]),
|
||||
"relationship_name": edge[2],
|
||||
"properties": {
|
||||
**(edge[3] if edge[3] else {}),
|
||||
"source_node_id": str(edge[0]),
|
||||
"target_node_id": str(edge[1]),
|
||||
},
|
||||
"properties": self._flatten_edge_properties(
|
||||
{
|
||||
**(edge[3] if edge[3] else {}),
|
||||
"source_node_id": str(edge[0]),
|
||||
"target_node_id": str(edge[1]),
|
||||
}
|
||||
),
|
||||
}
|
||||
for edge in edges
|
||||
]
|
||||
|
|
|
|||
|
|
@ -177,7 +177,12 @@ class LiteLLMEmbeddingEngine(EmbeddingEngine):
|
|||
elif "mistral" in self.provider.lower():
|
||||
tokenizer = MistralTokenizer(model=model, max_tokens=self.max_tokens)
|
||||
else:
|
||||
tokenizer = HuggingFaceTokenizer(model=self.model, max_tokens=self.max_tokens)
|
||||
try:
|
||||
tokenizer = HuggingFaceTokenizer(model=self.model, max_tokens=self.max_tokens)
|
||||
except Exception as e:
|
||||
logger.warning(f"Could not get tokenizer from HuggingFace due to: {e}")
|
||||
logger.info("Switching to TikToken default tokenizer.")
|
||||
tokenizer = TikTokenTokenizer(model=None, max_tokens=self.max_tokens)
|
||||
|
||||
logger.debug(f"Tokenizer loaded for model: {self.model}")
|
||||
return tokenizer
|
||||
|
|
|
|||
|
|
@ -0,0 +1,130 @@
|
|||
You are an expert query analyzer for a **GraphRAG system**. Your primary goal is to analyze a user's query and select the single most appropriate `SearchType` tool to answer it.
|
||||
|
||||
Here are the available `SearchType` tools and their specific functions:
|
||||
|
||||
- **`SUMMARIES`**: The `SUMMARIES` search type retrieves summarized information from the knowledge graph.
|
||||
|
||||
**Best for:**
|
||||
|
||||
- Getting concise overviews of topics
|
||||
- Summarizing large amounts of information
|
||||
- Quick understanding of complex subjects
|
||||
|
||||
* **`INSIGHTS`**: The `INSIGHTS` search type discovers connections and relationships between entities in the knowledge graph.
|
||||
|
||||
**Best for:**
|
||||
|
||||
- Discovering how entities are connected
|
||||
- Understanding relationships between concepts
|
||||
- Exploring the structure of your knowledge graph
|
||||
|
||||
* **`CHUNKS`**: The `CHUNKS` search type retrieves specific facts and information chunks from the knowledge graph.
|
||||
|
||||
**Best for:**
|
||||
|
||||
- Finding specific facts
|
||||
- Getting direct answers to questions
|
||||
- Retrieving precise information
|
||||
|
||||
* **`RAG_COMPLETION`**: Use for direct factual questions that can likely be answered by retrieving a specific text passage from a document. It does not use the graph's relationship structure.
|
||||
|
||||
**Best for:**
|
||||
|
||||
- Getting detailed explanations or comprehensive answers
|
||||
- Combining multiple pieces of information
|
||||
- Getting a single, coherent answer that is generated from relevant text passages
|
||||
|
||||
* **`GRAPH_COMPLETION`**: The `GRAPH_COMPLETION` search type leverages the graph structure to provide more contextually aware completions.
|
||||
|
||||
**Best for:**
|
||||
|
||||
- Complex queries requiring graph traversal
|
||||
- Questions that benefit from understanding relationships
|
||||
- Queries where context from connected entities matters
|
||||
|
||||
* **`GRAPH_SUMMARY_COMPLETION`**: The `GRAPH_SUMMARY_COMPLETION` search type combines graph traversal with summarization to provide concise but comprehensive answers.
|
||||
|
||||
**Best for:**
|
||||
|
||||
- Getting summarized information that requires understanding relationships
|
||||
- Complex topics that need concise explanations
|
||||
- Queries that benefit from both graph structure and summarization
|
||||
|
||||
* **`GRAPH_COMPLETION_COT`**: The `GRAPH_COMPLETION_COT` search type combines graph traversal with chain of thought to provide answers to complex multi hop questions.
|
||||
|
||||
**Best for:**
|
||||
|
||||
- Multi-hop questions that require following several linked concepts or entities
|
||||
- Tracing relational paths in a knowledge graph while also getting clear step-by-step reasoning
|
||||
- Summarizing completx linkages into a concise, human-readable answer once all hops have been explored
|
||||
|
||||
* **`GRAPH_COMPLETION_CONTEXT_EXTENSION`**: The `GRAPH_COMPLETION_CONTEXT_EXTENSION` search type combines graph traversal with multi-round context extension.
|
||||
|
||||
**Best for:**
|
||||
|
||||
- Iterative, multi-hop queries where intermediate facts aren’t all present upfront
|
||||
- Complex linkages that benefit from multi-round “search → extend context → reason” loops to uncover deep connections.
|
||||
- Sparse or evolving graphs that require on-the-fly expansion—issuing follow-up searches to discover missing nodes or properties
|
||||
|
||||
* **`CODE`**: The `CODE` search type is specialized for retrieving and understanding code-related information from the knowledge graph.
|
||||
|
||||
**Best for:**
|
||||
|
||||
- Code-related queries
|
||||
- Programming examples and patterns
|
||||
- Technical documentation searches
|
||||
|
||||
* **`CYPHER`**: The `CYPHER` search type allows user to execute raw Cypher queries directly against your graph database.
|
||||
|
||||
**Best for:**
|
||||
|
||||
- Executing precise graph queries with full control
|
||||
- Leveraging Cypher features and functions
|
||||
- Getting raw data directly from the graph database
|
||||
|
||||
* **`NATURAL_LANGUAGE`**: The `NATURAL_LANGUAGE` search type translates a natural language question into a precise Cypher query that is executed directly against the graph database.
|
||||
|
||||
**Best for:**
|
||||
|
||||
- Getting precise, structured answers from the graph using natural language.
|
||||
- Performing advanced graph operations like filtering and aggregating data using natural language.
|
||||
- Asking precise, database-style questions without needing to write Cypher.
|
||||
|
||||
**Examples:**
|
||||
|
||||
Query: "Summarize the key findings from these research papers"
|
||||
Response: `SUMMARIES`
|
||||
|
||||
Query: "What is the relationship between the methodologies used in these papers?"
|
||||
Response: `INSIGHTS`
|
||||
|
||||
Query: "When was Einstein born?"
|
||||
Response: `CHUNKS`
|
||||
|
||||
Query: "Explain Einstein's contributions to physics"
|
||||
Response: `RAG_COMPLETION`
|
||||
|
||||
Query: "Provide a comprehensive analysis of how these papers contribute to the field"
|
||||
Response: `GRAPH_COMPLETION`
|
||||
|
||||
Query: "Explain the overall architecture of this codebase"
|
||||
Response: `GRAPH_SUMMARY_COMPLETION`
|
||||
|
||||
Query: "Who was the father of the person who invented the lightbulb"
|
||||
Response: `GRAPH_COMPLETION_COT`
|
||||
|
||||
Query: "What county was XY born in"
|
||||
Response: `GRAPH_COMPLETION_CONTEXT_EXTENSION`
|
||||
|
||||
Query: "How to implement authentication in this codebase"
|
||||
Response: `CODE`
|
||||
|
||||
Query: "MATCH (n) RETURN labels(n) as types, n.name as name LIMIT 10"
|
||||
Response: `CYPHER`
|
||||
|
||||
Query: "Get all nodes connected to John"
|
||||
Response: `NATURAL_LANGUAGE`
|
||||
|
||||
|
||||
|
||||
Your response MUST be a single word, consisting of only the chosen `SearchType` name. Do not provide any explanation.
|
||||
|
|
@ -1,4 +1,4 @@
|
|||
from typing import List, Any
|
||||
from typing import List, Any, Optional
|
||||
import tiktoken
|
||||
|
||||
from ..tokenizer_interface import TokenizerInterface
|
||||
|
|
@ -12,13 +12,17 @@ class TikTokenTokenizer(TokenizerInterface):
|
|||
|
||||
def __init__(
|
||||
self,
|
||||
model: str,
|
||||
model: Optional[str] = None,
|
||||
max_tokens: int = 8191,
|
||||
):
|
||||
self.model = model
|
||||
self.max_tokens = max_tokens
|
||||
# Initialize TikToken for GPT based on model
|
||||
self.tokenizer = tiktoken.encoding_for_model(self.model)
|
||||
if model:
|
||||
self.tokenizer = tiktoken.encoding_for_model(self.model)
|
||||
else:
|
||||
# Use default if model not provided
|
||||
self.tokenizer = tiktoken.get_encoding("cl100k_base")
|
||||
|
||||
def extract_tokens(self, text: str) -> List[Any]:
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
from datetime import datetime, timezone
|
||||
from uuid import uuid4
|
||||
from sqlalchemy import UUID, Column, DateTime, String, JSON, Integer
|
||||
from sqlalchemy.ext.mutable import MutableDict
|
||||
from sqlalchemy.orm import relationship
|
||||
|
||||
from cognee.infrastructure.databases.relational import Base
|
||||
|
|
@ -21,7 +22,11 @@ class Data(Base):
|
|||
tenant_id = Column(UUID, index=True, nullable=True)
|
||||
content_hash = Column(String)
|
||||
external_metadata = Column(JSON)
|
||||
node_set = Column(JSON, nullable=True) # Store NodeSet as JSON list of strings
|
||||
# Store NodeSet as JSON list of strings
|
||||
node_set = Column(JSON, nullable=True)
|
||||
# MutableDict allows SQLAlchemy to notice key-value pair changes, without it changing a value for a key
|
||||
# wouldn't be noticed when commiting a database session
|
||||
pipeline_status = Column(MutableDict.as_mutable(JSON))
|
||||
token_count = Column(Integer)
|
||||
data_size = Column(Integer, nullable=True) # File size in bytes
|
||||
created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc))
|
||||
|
|
|
|||
|
|
@ -5,7 +5,6 @@ from cognee.modules.chunking.Chunker import Chunker
|
|||
from cognee.infrastructure.files.utils.open_data_file import open_data_file
|
||||
|
||||
from .Document import Document
|
||||
from .exceptions.exceptions import PyPdfInternalError
|
||||
|
||||
logger = get_logger("PDFDocument")
|
||||
|
||||
|
|
@ -17,18 +16,12 @@ class PdfDocument(Document):
|
|||
async with open_data_file(self.raw_data_location, mode="rb") as stream:
|
||||
logger.info(f"Reading PDF: {self.raw_data_location}")
|
||||
|
||||
try:
|
||||
file = PdfReader(stream, strict=False)
|
||||
except Exception:
|
||||
raise PyPdfInternalError()
|
||||
file = PdfReader(stream, strict=False)
|
||||
|
||||
async def get_text():
|
||||
try:
|
||||
for page in file.pages:
|
||||
page_text = page.extract_text()
|
||||
yield page_text
|
||||
except Exception:
|
||||
raise PyPdfInternalError()
|
||||
for page in file.pages:
|
||||
page_text = page.extract_text()
|
||||
yield page_text
|
||||
|
||||
chunker = chunker_cls(self, get_text=get_text, max_chunk_size=max_chunk_size)
|
||||
|
||||
|
|
|
|||
5
cognee/modules/engine/utils/generate_edge_id.py
Normal file
5
cognee/modules/engine/utils/generate_edge_id.py
Normal file
|
|
@ -0,0 +1,5 @@
|
|||
from uuid import NAMESPACE_OID, uuid5
|
||||
|
||||
|
||||
def generate_edge_id(edge_id: str) -> str:
|
||||
return uuid5(NAMESPACE_OID, edge_id.lower().replace(" ", "_").replace("'", ""))
|
||||
|
|
@ -1,3 +1,4 @@
|
|||
import time
|
||||
from cognee.shared.logging_utils import get_logger
|
||||
from typing import List, Dict, Union, Optional, Type
|
||||
|
||||
|
|
@ -154,38 +155,34 @@ class CogneeGraph(CogneeAbstractGraph):
|
|||
raise ValueError("Failed to generate query embedding.")
|
||||
|
||||
if edge_distances is None:
|
||||
start_time = time.time()
|
||||
edge_distances = await vector_engine.search(
|
||||
collection_name="EdgeType_relationship_name",
|
||||
query_vector=query_vector,
|
||||
limit=0,
|
||||
)
|
||||
projection_time = time.time() - start_time
|
||||
logger.info(
|
||||
f"Edge collection distances were calculated separately from nodes in {projection_time:.2f}s"
|
||||
)
|
||||
|
||||
embedding_map = {result.payload["text"]: result.score for result in edge_distances}
|
||||
|
||||
for edge in self.edges:
|
||||
relationship_type = edge.attributes.get("relationship_type")
|
||||
if relationship_type and relationship_type in embedding_map:
|
||||
edge.attributes["vector_distance"] = embedding_map[relationship_type]
|
||||
distance = embedding_map.get(relationship_type, None)
|
||||
if distance is not None:
|
||||
edge.attributes["vector_distance"] = distance
|
||||
|
||||
except Exception as ex:
|
||||
logger.error(f"Error mapping vector distances to edges: {str(ex)}")
|
||||
raise ex
|
||||
|
||||
async def calculate_top_triplet_importances(self, k: int) -> List:
|
||||
min_heap = []
|
||||
def score(edge):
|
||||
n1 = edge.node1.attributes.get("vector_distance", 1)
|
||||
n2 = edge.node2.attributes.get("vector_distance", 1)
|
||||
e = edge.attributes.get("vector_distance", 1)
|
||||
return n1 + n2 + e
|
||||
|
||||
for i, edge in enumerate(self.edges):
|
||||
source_node = self.get_node(edge.node1.id)
|
||||
target_node = self.get_node(edge.node2.id)
|
||||
|
||||
source_distance = source_node.attributes.get("vector_distance", 1) if source_node else 1
|
||||
target_distance = target_node.attributes.get("vector_distance", 1) if target_node else 1
|
||||
edge_distance = edge.attributes.get("vector_distance", 1)
|
||||
|
||||
total_distance = source_distance + target_distance + edge_distance
|
||||
|
||||
heapq.heappush(min_heap, (-total_distance, i, edge))
|
||||
if len(min_heap) > k:
|
||||
heapq.heappop(min_heap)
|
||||
|
||||
return [edge for _, _, edge in sorted(min_heap)]
|
||||
return heapq.nsmallest(k, self.edges, key=score)
|
||||
|
|
|
|||
1
cognee/modules/pipelines/exceptions/__init__.py
Normal file
1
cognee/modules/pipelines/exceptions/__init__.py
Normal file
|
|
@ -0,0 +1 @@
|
|||
from .exceptions import PipelineRunFailedError
|
||||
12
cognee/modules/pipelines/exceptions/exceptions.py
Normal file
12
cognee/modules/pipelines/exceptions/exceptions.py
Normal file
|
|
@ -0,0 +1,12 @@
|
|||
from cognee.exceptions import CogneeApiError
|
||||
from fastapi import status
|
||||
|
||||
|
||||
class PipelineRunFailedError(CogneeApiError):
|
||||
def __init__(
|
||||
self,
|
||||
message: str = "Pipeline run failed.",
|
||||
name: str = "PipelineRunFailedError",
|
||||
status_code: int = status.HTTP_422_UNPROCESSABLE_ENTITY,
|
||||
):
|
||||
super().__init__(message, name, status_code)
|
||||
5
cognee/modules/pipelines/models/DataItemStatus.py
Normal file
5
cognee/modules/pipelines/models/DataItemStatus.py
Normal file
|
|
@ -0,0 +1,5 @@
|
|||
import enum
|
||||
|
||||
|
||||
class DataItemStatus(str, enum.Enum):
|
||||
DATA_ITEM_PROCESSING_COMPLETED = "DATA_ITEM_PROCESSING_COMPLETED"
|
||||
|
|
@ -9,6 +9,7 @@ class PipelineRunInfo(BaseModel):
|
|||
dataset_id: UUID
|
||||
dataset_name: str
|
||||
payload: Optional[Any] = None
|
||||
data_ingestion_info: Optional[list] = None
|
||||
|
||||
model_config = {
|
||||
"arbitrary_types_allowed": True,
|
||||
|
|
@ -30,6 +31,11 @@ class PipelineRunCompleted(PipelineRunInfo):
|
|||
pass
|
||||
|
||||
|
||||
class PipelineRunAlreadyCompleted(PipelineRunInfo):
|
||||
status: str = "PipelineRunAlreadyCompleted"
|
||||
pass
|
||||
|
||||
|
||||
class PipelineRunErrored(PipelineRunInfo):
|
||||
status: str = "PipelineRunErrored"
|
||||
pass
|
||||
|
|
|
|||
|
|
@ -6,3 +6,4 @@ from .PipelineRunInfo import (
|
|||
PipelineRunCompleted,
|
||||
PipelineRunErrored,
|
||||
)
|
||||
from .DataItemStatus import DataItemStatus
|
||||
|
|
|
|||
|
|
@ -52,6 +52,7 @@ async def cognee_pipeline(
|
|||
pipeline_name: str = "custom_pipeline",
|
||||
vector_db_config: dict = None,
|
||||
graph_db_config: dict = None,
|
||||
incremental_loading: bool = True,
|
||||
):
|
||||
# Note: These context variables allow different value assignment for databases in Cognee
|
||||
# per async task, thread, process and etc.
|
||||
|
|
@ -106,6 +107,7 @@ async def cognee_pipeline(
|
|||
data=data,
|
||||
pipeline_name=pipeline_name,
|
||||
context={"dataset": dataset},
|
||||
incremental_loading=incremental_loading,
|
||||
):
|
||||
yield run_info
|
||||
|
||||
|
|
@ -117,6 +119,7 @@ async def run_pipeline(
|
|||
data=None,
|
||||
pipeline_name: str = "custom_pipeline",
|
||||
context: dict = None,
|
||||
incremental_loading=True,
|
||||
):
|
||||
check_dataset_name(dataset.name)
|
||||
|
||||
|
|
@ -184,7 +187,9 @@ async def run_pipeline(
|
|||
if not isinstance(task, Task):
|
||||
raise ValueError(f"Task {task} is not an instance of Task")
|
||||
|
||||
pipeline_run = run_tasks(tasks, dataset_id, data, user, pipeline_name, context)
|
||||
pipeline_run = run_tasks(
|
||||
tasks, dataset_id, data, user, pipeline_name, context, incremental_loading
|
||||
)
|
||||
|
||||
async for pipeline_run_info in pipeline_run:
|
||||
yield pipeline_run_info
|
||||
|
|
|
|||
|
|
@ -1,21 +1,31 @@
|
|||
import os
|
||||
|
||||
import asyncio
|
||||
from uuid import UUID
|
||||
from typing import Any
|
||||
from functools import wraps
|
||||
from sqlalchemy import select
|
||||
|
||||
import cognee.modules.ingestion as ingestion
|
||||
from cognee.infrastructure.databases.graph import get_graph_engine
|
||||
from cognee.infrastructure.databases.relational import get_relational_engine
|
||||
from cognee.modules.pipelines.operations.run_tasks_distributed import run_tasks_distributed
|
||||
from cognee.modules.users.models import User
|
||||
from cognee.modules.data.models import Data
|
||||
from cognee.infrastructure.files.utils.open_data_file import open_data_file
|
||||
from cognee.shared.logging_utils import get_logger
|
||||
from cognee.modules.users.methods import get_default_user
|
||||
from cognee.modules.pipelines.utils import generate_pipeline_id
|
||||
from cognee.modules.pipelines.exceptions import PipelineRunFailedError
|
||||
from cognee.tasks.ingestion import save_data_item_to_storage, resolve_data_directories
|
||||
from cognee.modules.pipelines.models.PipelineRunInfo import (
|
||||
PipelineRunCompleted,
|
||||
PipelineRunErrored,
|
||||
PipelineRunStarted,
|
||||
PipelineRunYield,
|
||||
PipelineRunAlreadyCompleted,
|
||||
)
|
||||
from cognee.modules.pipelines.models.DataItemStatus import DataItemStatus
|
||||
|
||||
from cognee.modules.pipelines.operations import (
|
||||
log_pipeline_run_start,
|
||||
|
|
@ -56,34 +66,116 @@ async def run_tasks(
|
|||
user: User = None,
|
||||
pipeline_name: str = "unknown_pipeline",
|
||||
context: dict = None,
|
||||
incremental_loading: bool = True,
|
||||
):
|
||||
if not user:
|
||||
user = await get_default_user()
|
||||
async def _run_tasks_data_item_incremental(
|
||||
data_item,
|
||||
dataset,
|
||||
tasks,
|
||||
pipeline_name,
|
||||
pipeline_id,
|
||||
pipeline_run_id,
|
||||
context,
|
||||
user,
|
||||
):
|
||||
db_engine = get_relational_engine()
|
||||
# If incremental_loading of data is set to True don't process documents already processed by pipeline
|
||||
# If data is being added to Cognee for the first time calculate the id of the data
|
||||
if not isinstance(data_item, Data):
|
||||
file_path = await save_data_item_to_storage(data_item)
|
||||
# Ingest data and add metadata
|
||||
async with open_data_file(file_path) as file:
|
||||
classified_data = ingestion.classify(file)
|
||||
# data_id is the hash of file contents + owner id to avoid duplicate data
|
||||
data_id = ingestion.identify(classified_data, user)
|
||||
else:
|
||||
# If data was already processed by Cognee get data id
|
||||
data_id = data_item.id
|
||||
|
||||
# Get Dataset object
|
||||
db_engine = get_relational_engine()
|
||||
async with db_engine.get_async_session() as session:
|
||||
from cognee.modules.data.models import Dataset
|
||||
# Check pipeline status, if Data already processed for pipeline before skip current processing
|
||||
async with db_engine.get_async_session() as session:
|
||||
data_point = (
|
||||
await session.execute(select(Data).filter(Data.id == data_id))
|
||||
).scalar_one_or_none()
|
||||
if data_point:
|
||||
if (
|
||||
data_point.pipeline_status.get(pipeline_name, {}).get(str(dataset.id))
|
||||
== DataItemStatus.DATA_ITEM_PROCESSING_COMPLETED
|
||||
):
|
||||
yield {
|
||||
"run_info": PipelineRunAlreadyCompleted(
|
||||
pipeline_run_id=pipeline_run_id,
|
||||
dataset_id=dataset.id,
|
||||
dataset_name=dataset.name,
|
||||
),
|
||||
"data_id": data_id,
|
||||
}
|
||||
return
|
||||
|
||||
dataset = await session.get(Dataset, dataset_id)
|
||||
try:
|
||||
# Process data based on data_item and list of tasks
|
||||
async for result in run_tasks_with_telemetry(
|
||||
tasks=tasks,
|
||||
data=[data_item],
|
||||
user=user,
|
||||
pipeline_name=pipeline_id,
|
||||
context=context,
|
||||
):
|
||||
yield PipelineRunYield(
|
||||
pipeline_run_id=pipeline_run_id,
|
||||
dataset_id=dataset.id,
|
||||
dataset_name=dataset.name,
|
||||
payload=result,
|
||||
)
|
||||
|
||||
pipeline_id = generate_pipeline_id(user.id, dataset.id, pipeline_name)
|
||||
# Update pipeline status for Data element
|
||||
async with db_engine.get_async_session() as session:
|
||||
data_point = (
|
||||
await session.execute(select(Data).filter(Data.id == data_id))
|
||||
).scalar_one_or_none()
|
||||
data_point.pipeline_status[pipeline_name] = {
|
||||
str(dataset.id): DataItemStatus.DATA_ITEM_PROCESSING_COMPLETED
|
||||
}
|
||||
await session.merge(data_point)
|
||||
await session.commit()
|
||||
|
||||
pipeline_run = await log_pipeline_run_start(pipeline_id, pipeline_name, dataset_id, data)
|
||||
yield {
|
||||
"run_info": PipelineRunCompleted(
|
||||
pipeline_run_id=pipeline_run_id,
|
||||
dataset_id=dataset.id,
|
||||
dataset_name=dataset.name,
|
||||
),
|
||||
"data_id": data_id,
|
||||
}
|
||||
|
||||
pipeline_run_id = pipeline_run.pipeline_run_id
|
||||
except Exception as error:
|
||||
# Temporarily swallow error and try to process rest of documents first, then re-raise error at end of data ingestion pipeline
|
||||
logger.error(
|
||||
f"Exception caught while processing data: {error}.\n Data processing failed for data item: {data_item}."
|
||||
)
|
||||
yield {
|
||||
"run_info": PipelineRunErrored(
|
||||
pipeline_run_id=pipeline_run_id,
|
||||
payload=repr(error),
|
||||
dataset_id=dataset.id,
|
||||
dataset_name=dataset.name,
|
||||
),
|
||||
"data_id": data_id,
|
||||
}
|
||||
|
||||
yield PipelineRunStarted(
|
||||
pipeline_run_id=pipeline_run_id,
|
||||
dataset_id=dataset.id,
|
||||
dataset_name=dataset.name,
|
||||
payload=data,
|
||||
)
|
||||
|
||||
try:
|
||||
async def _run_tasks_data_item_regular(
|
||||
data_item,
|
||||
dataset,
|
||||
tasks,
|
||||
pipeline_id,
|
||||
pipeline_run_id,
|
||||
context,
|
||||
user,
|
||||
):
|
||||
# Process data based on data_item and list of tasks
|
||||
async for result in run_tasks_with_telemetry(
|
||||
tasks=tasks,
|
||||
data=data,
|
||||
data=[data_item],
|
||||
user=user,
|
||||
pipeline_name=pipeline_id,
|
||||
context=context,
|
||||
|
|
@ -95,6 +187,112 @@ async def run_tasks(
|
|||
payload=result,
|
||||
)
|
||||
|
||||
yield {
|
||||
"run_info": PipelineRunCompleted(
|
||||
pipeline_run_id=pipeline_run_id,
|
||||
dataset_id=dataset.id,
|
||||
dataset_name=dataset.name,
|
||||
)
|
||||
}
|
||||
|
||||
async def _run_tasks_data_item(
|
||||
data_item,
|
||||
dataset,
|
||||
tasks,
|
||||
pipeline_name,
|
||||
pipeline_id,
|
||||
pipeline_run_id,
|
||||
context,
|
||||
user,
|
||||
incremental_loading,
|
||||
):
|
||||
# Go through async generator and return data item processing result. Result can be PipelineRunAlreadyCompleted when data item is skipped,
|
||||
# PipelineRunCompleted when processing was successful and PipelineRunErrored if there were issues
|
||||
result = None
|
||||
if incremental_loading:
|
||||
async for result in _run_tasks_data_item_incremental(
|
||||
data_item=data_item,
|
||||
dataset=dataset,
|
||||
tasks=tasks,
|
||||
pipeline_name=pipeline_name,
|
||||
pipeline_id=pipeline_id,
|
||||
pipeline_run_id=pipeline_run_id,
|
||||
context=context,
|
||||
user=user,
|
||||
):
|
||||
pass
|
||||
else:
|
||||
async for result in _run_tasks_data_item_regular(
|
||||
data_item=data_item,
|
||||
dataset=dataset,
|
||||
tasks=tasks,
|
||||
pipeline_id=pipeline_id,
|
||||
pipeline_run_id=pipeline_run_id,
|
||||
context=context,
|
||||
user=user,
|
||||
):
|
||||
pass
|
||||
|
||||
return result
|
||||
|
||||
if not user:
|
||||
user = await get_default_user()
|
||||
|
||||
# Get Dataset object
|
||||
db_engine = get_relational_engine()
|
||||
async with db_engine.get_async_session() as session:
|
||||
from cognee.modules.data.models import Dataset
|
||||
|
||||
dataset = await session.get(Dataset, dataset_id)
|
||||
|
||||
pipeline_id = generate_pipeline_id(user.id, dataset.id, pipeline_name)
|
||||
pipeline_run = await log_pipeline_run_start(pipeline_id, pipeline_name, dataset_id, data)
|
||||
pipeline_run_id = pipeline_run.pipeline_run_id
|
||||
|
||||
yield PipelineRunStarted(
|
||||
pipeline_run_id=pipeline_run_id,
|
||||
dataset_id=dataset.id,
|
||||
dataset_name=dataset.name,
|
||||
payload=data,
|
||||
)
|
||||
|
||||
try:
|
||||
if not isinstance(data, list):
|
||||
data = [data]
|
||||
|
||||
if incremental_loading:
|
||||
data = await resolve_data_directories(data)
|
||||
|
||||
# Create async tasks per data item that will run the pipeline for the data item
|
||||
data_item_tasks = [
|
||||
asyncio.create_task(
|
||||
_run_tasks_data_item(
|
||||
data_item,
|
||||
dataset,
|
||||
tasks,
|
||||
pipeline_name,
|
||||
pipeline_id,
|
||||
pipeline_run_id,
|
||||
context,
|
||||
user,
|
||||
incremental_loading,
|
||||
)
|
||||
)
|
||||
for data_item in data
|
||||
]
|
||||
results = await asyncio.gather(*data_item_tasks)
|
||||
# Remove skipped data items from results
|
||||
results = [result for result in results if result]
|
||||
|
||||
# If any data item could not be processed propagate error
|
||||
errored_results = [
|
||||
result for result in results if isinstance(result["run_info"], PipelineRunErrored)
|
||||
]
|
||||
if errored_results:
|
||||
raise PipelineRunFailedError(
|
||||
message="Pipeline run failed. Data item could not be processed."
|
||||
)
|
||||
|
||||
await log_pipeline_run_complete(
|
||||
pipeline_run_id, pipeline_id, pipeline_name, dataset_id, data
|
||||
)
|
||||
|
|
@ -103,6 +301,7 @@ async def run_tasks(
|
|||
pipeline_run_id=pipeline_run_id,
|
||||
dataset_id=dataset.id,
|
||||
dataset_name=dataset.name,
|
||||
data_ingestion_info=results,
|
||||
)
|
||||
|
||||
graph_engine = await get_graph_engine()
|
||||
|
|
@ -120,9 +319,14 @@ async def run_tasks(
|
|||
|
||||
yield PipelineRunErrored(
|
||||
pipeline_run_id=pipeline_run_id,
|
||||
payload=error,
|
||||
payload=repr(error),
|
||||
dataset_id=dataset.id,
|
||||
dataset_name=dataset.name,
|
||||
data_ingestion_info=locals().get(
|
||||
"results"
|
||||
), # Returns results if they exist or returns None
|
||||
)
|
||||
|
||||
raise error
|
||||
# In case of error during incremental loading of data just let the user know the pipeline Errored, don't raise error
|
||||
if not isinstance(error, PipelineRunFailedError):
|
||||
raise error
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
import asyncio
|
||||
import time
|
||||
from typing import List, Optional, Type
|
||||
|
||||
from cognee.shared.logging_utils import get_logger, ERROR
|
||||
|
|
@ -174,6 +175,8 @@ async def brute_force_search(
|
|||
return []
|
||||
|
||||
try:
|
||||
start_time = time.time()
|
||||
|
||||
results = await asyncio.gather(
|
||||
*[search_in_collection(collection_name) for collection_name in collections]
|
||||
)
|
||||
|
|
@ -181,6 +184,12 @@ async def brute_force_search(
|
|||
if all(not item for item in results):
|
||||
return []
|
||||
|
||||
# Final statistics
|
||||
projection_time = time.time() - start_time
|
||||
logger.info(
|
||||
f"Vector collection retrieval completed: Retrieved distances from {sum(1 for res in results if res)} collections in {projection_time:.2f}s"
|
||||
)
|
||||
|
||||
node_distances = {collection: result for collection, result in zip(collections, results)}
|
||||
|
||||
edge_distances = node_distances.get("EdgeType_relationship_name", None)
|
||||
|
|
|
|||
|
|
@ -27,7 +27,7 @@ from cognee.modules.users.models import User
|
|||
from cognee.modules.data.models import Dataset
|
||||
from cognee.shared.utils import send_telemetry
|
||||
from cognee.modules.users.permissions.methods import get_specific_user_permission_datasets
|
||||
from cognee.modules.search.operations import log_query, log_result
|
||||
from cognee.modules.search.operations import log_query, log_result, select_search_type
|
||||
|
||||
|
||||
async def search(
|
||||
|
|
@ -129,6 +129,10 @@ async def specific_search(
|
|||
SearchType.NATURAL_LANGUAGE: NaturalLanguageRetriever().get_completion,
|
||||
}
|
||||
|
||||
# If the query type is FEELING_LUCKY, select the search type intelligently
|
||||
if query_type is SearchType.FEELING_LUCKY:
|
||||
query_type = await select_search_type(query)
|
||||
|
||||
search_task = search_tasks.get(query_type)
|
||||
|
||||
if search_task is None:
|
||||
|
|
|
|||
|
|
@ -1,3 +1,4 @@
|
|||
from .log_query import log_query
|
||||
from .log_result import log_result
|
||||
from .get_history import get_history
|
||||
from .select_search_type import select_search_type
|
||||
|
|
|
|||
43
cognee/modules/search/operations/select_search_type.py
Normal file
43
cognee/modules/search/operations/select_search_type.py
Normal file
|
|
@ -0,0 +1,43 @@
|
|||
from cognee.infrastructure.llm.get_llm_client import get_llm_client
|
||||
from cognee.infrastructure.llm.prompts import read_query_prompt
|
||||
from cognee.modules.search.types import SearchType
|
||||
from cognee.shared.logging_utils import get_logger
|
||||
|
||||
logger = get_logger("SearchTypeSelector")
|
||||
|
||||
|
||||
async def select_search_type(
|
||||
query: str,
|
||||
system_prompt_path: str = "search_type_selector_prompt.txt",
|
||||
) -> SearchType:
|
||||
"""
|
||||
Analyzes the query and Selects the best search type.
|
||||
|
||||
Args:
|
||||
query: The query to analyze.
|
||||
system_prompt_path: The path to the system prompt.
|
||||
|
||||
Returns:
|
||||
The best search type given by the LLM.
|
||||
"""
|
||||
default_search_type = SearchType.RAG_COMPLETION
|
||||
system_prompt = read_query_prompt(system_prompt_path)
|
||||
llm_client = get_llm_client()
|
||||
|
||||
try:
|
||||
response = await llm_client.acreate_structured_output(
|
||||
text_input=query,
|
||||
system_prompt=system_prompt,
|
||||
response_model=str,
|
||||
)
|
||||
|
||||
if response.upper() in SearchType.__members__:
|
||||
logger.info(f"Selected lucky search type: {response.upper()}")
|
||||
return SearchType(response.upper())
|
||||
|
||||
# If the response is not a valid search type, return the default search type
|
||||
logger.info(f"LLM gives an invalid search type: {response.upper()}")
|
||||
return default_search_type
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to select search type intelligently from LLM: {str(e)}")
|
||||
return default_search_type
|
||||
|
|
@ -13,3 +13,4 @@ class SearchType(Enum):
|
|||
NATURAL_LANGUAGE = "NATURAL_LANGUAGE"
|
||||
GRAPH_COMPLETION_COT = "GRAPH_COMPLETION_COT"
|
||||
GRAPH_COMPLETION_CONTEXT_EXTENSION = "GRAPH_COMPLETION_CONTEXT_EXTENSION"
|
||||
FEELING_LUCKY = "FEELING_LUCKY"
|
||||
|
|
|
|||
|
|
@ -8,7 +8,6 @@ from cognee.modules.data.models import Data
|
|||
from cognee.infrastructure.databases.relational import get_relational_engine
|
||||
from cognee.modules.chunking.TextChunker import TextChunker
|
||||
from cognee.modules.chunking.Chunker import Chunker
|
||||
from cognee.modules.data.processing.document_types.exceptions.exceptions import PyPdfInternalError
|
||||
|
||||
|
||||
async def update_document_token_count(document_id: UUID, token_count: int) -> None:
|
||||
|
|
@ -40,15 +39,14 @@ async def extract_chunks_from_documents(
|
|||
"""
|
||||
for document in documents:
|
||||
document_token_count = 0
|
||||
try:
|
||||
async for document_chunk in document.read(
|
||||
max_chunk_size=max_chunk_size, chunker_cls=chunker
|
||||
):
|
||||
document_token_count += document_chunk.chunk_size
|
||||
document_chunk.belongs_to_set = document.belongs_to_set
|
||||
yield document_chunk
|
||||
|
||||
await update_document_token_count(document.id, document_token_count)
|
||||
except PyPdfInternalError:
|
||||
pass
|
||||
async for document_chunk in document.read(
|
||||
max_chunk_size=max_chunk_size, chunker_cls=chunker
|
||||
):
|
||||
document_token_count += document_chunk.chunk_size
|
||||
document_chunk.belongs_to_set = document.belongs_to_set
|
||||
yield document_chunk
|
||||
|
||||
await update_document_token_count(document.id, document_token_count)
|
||||
|
||||
# todo rita
|
||||
|
|
|
|||
|
|
@ -5,12 +5,12 @@ from uuid import UUID
|
|||
from typing import Union, BinaryIO, Any, List, Optional
|
||||
|
||||
import cognee.modules.ingestion as ingestion
|
||||
from cognee.infrastructure.files.utils.open_data_file import open_data_file
|
||||
from cognee.infrastructure.databases.relational import get_relational_engine
|
||||
from cognee.modules.data.models import Data
|
||||
from cognee.modules.users.models import User
|
||||
from cognee.modules.users.methods import get_default_user
|
||||
from cognee.modules.users.permissions.methods import get_specific_user_permission_datasets
|
||||
from cognee.infrastructure.files.utils.open_data_file import open_data_file
|
||||
from cognee.modules.data.methods import (
|
||||
get_authorized_existing_datasets,
|
||||
get_dataset_data,
|
||||
|
|
@ -134,6 +134,7 @@ async def ingest_data(
|
|||
node_set=json.dumps(node_set) if node_set else None,
|
||||
data_size=file_metadata["file_size"],
|
||||
tenant_id=user.tenant_id if user.tenant_id else None,
|
||||
pipeline_status={},
|
||||
token_count=-1,
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -40,6 +40,9 @@ async def resolve_data_directories(
|
|||
if include_subdirectories:
|
||||
base_path = item if item.endswith("/") else item + "/"
|
||||
s3_keys = fs.glob(base_path + "**")
|
||||
# If path is not directory attempt to add item directly
|
||||
if not s3_keys:
|
||||
s3_keys = fs.ls(item)
|
||||
else:
|
||||
s3_keys = fs.ls(item)
|
||||
# Filter out keys that represent directories using fs.isdir
|
||||
|
|
|
|||
|
|
@ -103,6 +103,9 @@ async def get_repo_file_dependencies(
|
|||
extraction of dependencies (default is False). (default False)
|
||||
"""
|
||||
|
||||
if isinstance(repo_path, list) and len(repo_path) == 1:
|
||||
repo_path = repo_path[0]
|
||||
|
||||
if not os.path.exists(repo_path):
|
||||
raise FileNotFoundError(f"Repository path {repo_path} does not exist.")
|
||||
|
||||
|
|
|
|||
|
|
@ -1,3 +1,5 @@
|
|||
import asyncio
|
||||
|
||||
from cognee.shared.logging_utils import get_logger
|
||||
|
||||
from cognee.infrastructure.databases.exceptions.EmbeddingException import EmbeddingException
|
||||
|
|
@ -6,6 +8,9 @@ from cognee.infrastructure.engine import DataPoint
|
|||
|
||||
logger = get_logger("index_data_points")
|
||||
|
||||
# A single lock shared by all coroutines
|
||||
vector_index_lock = asyncio.Lock()
|
||||
|
||||
|
||||
async def index_data_points(data_points: list[DataPoint]):
|
||||
created_indexes = {}
|
||||
|
|
@ -22,9 +27,11 @@ async def index_data_points(data_points: list[DataPoint]):
|
|||
|
||||
index_name = f"{data_point_type.__name__}_{field_name}"
|
||||
|
||||
if index_name not in created_indexes:
|
||||
await vector_engine.create_vector_index(data_point_type.__name__, field_name)
|
||||
created_indexes[index_name] = True
|
||||
# Add async lock to make sure two different coroutines won't create a table at the same time
|
||||
async with vector_index_lock:
|
||||
if index_name not in created_indexes:
|
||||
await vector_engine.create_vector_index(data_point_type.__name__, field_name)
|
||||
created_indexes[index_name] = True
|
||||
|
||||
if index_name not in index_points:
|
||||
index_points[index_name] = []
|
||||
|
|
|
|||
|
|
@ -1,3 +1,4 @@
|
|||
from cognee.modules.engine.utils.generate_edge_id import generate_edge_id
|
||||
from cognee.shared.logging_utils import get_logger, ERROR
|
||||
from collections import Counter
|
||||
|
||||
|
|
@ -49,7 +50,9 @@ async def index_graph_edges(batch_size: int = 1024):
|
|||
)
|
||||
|
||||
for text, count in edge_types.items():
|
||||
edge = EdgeType(relationship_name=text, number_of_edges=count)
|
||||
edge = EdgeType(
|
||||
id=generate_edge_id(edge_id=text), relationship_name=text, number_of_edges=count
|
||||
)
|
||||
data_point_type = type(edge)
|
||||
|
||||
for field_name in edge.metadata["index_fields"]:
|
||||
|
|
|
|||
|
|
@ -26,8 +26,8 @@ async def test_deduplication():
|
|||
explanation_file_path2 = os.path.join(
|
||||
pathlib.Path(__file__).parent, "test_data/Natural_language_processing_copy.txt"
|
||||
)
|
||||
await cognee.add([explanation_file_path], dataset_name)
|
||||
await cognee.add([explanation_file_path2], dataset_name2)
|
||||
await cognee.add([explanation_file_path], dataset_name, incremental_loading=False)
|
||||
await cognee.add([explanation_file_path2], dataset_name2, incremental_loading=False)
|
||||
|
||||
result = await relational_engine.get_all_data_from_table("data")
|
||||
assert len(result) == 1, "More than one data entity was found."
|
||||
|
|
|
|||
|
|
@ -28,18 +28,38 @@ class TestGraphCompletionRetriever:
|
|||
|
||||
class Company(DataPoint):
|
||||
name: str
|
||||
description: str
|
||||
|
||||
class Person(DataPoint):
|
||||
name: str
|
||||
description: str
|
||||
works_for: Company
|
||||
|
||||
company1 = Company(name="Figma")
|
||||
company2 = Company(name="Canva")
|
||||
person1 = Person(name="Steve Rodger", works_for=company1)
|
||||
person2 = Person(name="Ike Loma", works_for=company1)
|
||||
person3 = Person(name="Jason Statham", works_for=company1)
|
||||
person4 = Person(name="Mike Broski", works_for=company2)
|
||||
person5 = Person(name="Christina Mayer", works_for=company2)
|
||||
company1 = Company(name="Figma", description="Figma is a company")
|
||||
company2 = Company(name="Canva", description="Canvas is a company")
|
||||
person1 = Person(
|
||||
name="Steve Rodger",
|
||||
description="This is description about Steve Rodger",
|
||||
works_for=company1,
|
||||
)
|
||||
person2 = Person(
|
||||
name="Ike Loma", description="This is description about Ike Loma", works_for=company1
|
||||
)
|
||||
person3 = Person(
|
||||
name="Jason Statham",
|
||||
description="This is description about Jason Statham",
|
||||
works_for=company1,
|
||||
)
|
||||
person4 = Person(
|
||||
name="Mike Broski",
|
||||
description="This is description about Mike Broski",
|
||||
works_for=company2,
|
||||
)
|
||||
person5 = Person(
|
||||
name="Christina Mayer",
|
||||
description="This is description about Christina Mayer",
|
||||
works_for=company2,
|
||||
)
|
||||
|
||||
entities = [company1, company2, person1, person2, person3, person4, person5]
|
||||
|
||||
|
|
@ -49,8 +69,63 @@ class TestGraphCompletionRetriever:
|
|||
|
||||
context = await retriever.get_context("Who works at Canva?")
|
||||
|
||||
assert "Mike Broski --[works_for]--> Canva" in context, "Failed to get Mike Broski"
|
||||
assert "Christina Mayer --[works_for]--> Canva" in context, "Failed to get Christina Mayer"
|
||||
# Ensure the top-level sections are present
|
||||
assert "Nodes:" in context, "Missing 'Nodes:' section in context"
|
||||
assert "Connections:" in context, "Missing 'Connections:' section in context"
|
||||
|
||||
# --- Nodes headers ---
|
||||
assert "Node: Steve Rodger" in context, "Missing node header for Steve Rodger"
|
||||
assert "Node: Figma" in context, "Missing node header for Figma"
|
||||
assert "Node: Ike Loma" in context, "Missing node header for Ike Loma"
|
||||
assert "Node: Jason Statham" in context, "Missing node header for Jason Statham"
|
||||
assert "Node: Mike Broski" in context, "Missing node header for Mike Broski"
|
||||
assert "Node: Canva" in context, "Missing node header for Canva"
|
||||
assert "Node: Christina Mayer" in context, "Missing node header for Christina Mayer"
|
||||
|
||||
# --- Node contents ---
|
||||
assert (
|
||||
"__node_content_start__\nThis is description about Steve Rodger\n__node_content_end__"
|
||||
in context
|
||||
), "Description block for Steve Rodger altered"
|
||||
assert "__node_content_start__\nFigma is a company\n__node_content_end__" in context, (
|
||||
"Description block for Figma altered"
|
||||
)
|
||||
assert (
|
||||
"__node_content_start__\nThis is description about Ike Loma\n__node_content_end__"
|
||||
in context
|
||||
), "Description block for Ike Loma altered"
|
||||
assert (
|
||||
"__node_content_start__\nThis is description about Jason Statham\n__node_content_end__"
|
||||
in context
|
||||
), "Description block for Jason Statham altered"
|
||||
assert (
|
||||
"__node_content_start__\nThis is description about Mike Broski\n__node_content_end__"
|
||||
in context
|
||||
), "Description block for Mike Broski altered"
|
||||
assert "__node_content_start__\nCanvas is a company\n__node_content_end__" in context, (
|
||||
"Description block for Canva altered"
|
||||
)
|
||||
assert (
|
||||
"__node_content_start__\nThis is description about Christina Mayer\n__node_content_end__"
|
||||
in context
|
||||
), "Description block for Christina Mayer altered"
|
||||
|
||||
# --- Connections ---
|
||||
assert "Steve Rodger --[works_for]--> Figma" in context, (
|
||||
"Connection Steve Rodger→Figma missing or changed"
|
||||
)
|
||||
assert "Ike Loma --[works_for]--> Figma" in context, (
|
||||
"Connection Ike Loma→Figma missing or changed"
|
||||
)
|
||||
assert "Jason Statham --[works_for]--> Figma" in context, (
|
||||
"Connection Jason Statham→Figma missing or changed"
|
||||
)
|
||||
assert "Mike Broski --[works_for]--> Canva" in context, (
|
||||
"Connection Mike Broski→Canva missing or changed"
|
||||
)
|
||||
assert "Christina Mayer --[works_for]--> Canva" in context, (
|
||||
"Connection Christina Mayer→Canva missing or changed"
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_graph_completion_context_complex(self):
|
||||
|
|
|
|||
|
|
@ -155,6 +155,61 @@ async def test_specific_search_chunks(mock_send_telemetry, mock_chunks_retriever
|
|||
assert results[0]["content"] == "Chunk result"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"selected_type, retriever_name, expected_content, top_k",
|
||||
[
|
||||
(SearchType.RAG_COMPLETION, "CompletionRetriever", "RAG result from lucky search", 10),
|
||||
(SearchType.CHUNKS, "ChunksRetriever", "Chunk result from lucky search", 5),
|
||||
(SearchType.SUMMARIES, "SummariesRetriever", "Summary from lucky search", 15),
|
||||
(SearchType.INSIGHTS, "InsightsRetriever", "Insight result from lucky search", 20),
|
||||
],
|
||||
)
|
||||
@patch.object(search_module, "select_search_type")
|
||||
@patch.object(search_module, "send_telemetry")
|
||||
async def test_specific_search_feeling_lucky(
|
||||
mock_send_telemetry,
|
||||
mock_select_search_type,
|
||||
selected_type,
|
||||
retriever_name,
|
||||
expected_content,
|
||||
top_k,
|
||||
mock_user,
|
||||
):
|
||||
with patch.object(search_module, retriever_name) as mock_retriever_class:
|
||||
# Setup
|
||||
query = f"test query for {retriever_name}"
|
||||
query_type = SearchType.FEELING_LUCKY
|
||||
|
||||
# Mock the intelligent search type selection
|
||||
mock_select_search_type.return_value = selected_type
|
||||
|
||||
# Mock the retriever
|
||||
mock_retriever_instance = MagicMock()
|
||||
mock_retriever_instance.get_completion = AsyncMock(
|
||||
return_value=[{"content": expected_content}]
|
||||
)
|
||||
mock_retriever_class.return_value = mock_retriever_instance
|
||||
|
||||
# Execute
|
||||
results = await specific_search(query_type, query, mock_user, top_k=top_k)
|
||||
|
||||
# Verify
|
||||
mock_select_search_type.assert_called_once_with(query)
|
||||
|
||||
if retriever_name == "CompletionRetriever":
|
||||
mock_retriever_class.assert_called_once_with(
|
||||
system_prompt_path="answer_simple_question.txt", top_k=top_k
|
||||
)
|
||||
else:
|
||||
mock_retriever_class.assert_called_once_with(top_k=top_k)
|
||||
|
||||
mock_retriever_instance.get_completion.assert_called_once_with(query)
|
||||
mock_send_telemetry.assert_called()
|
||||
assert len(results) == 1
|
||||
assert results[0]["content"] == expected_content
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_specific_search_invalid_type(mock_user):
|
||||
# Setup
|
||||
|
|
|
|||
|
|
@ -43,10 +43,10 @@ sleep 2
|
|||
if [ "$ENVIRONMENT" = "dev" ] || [ "$ENVIRONMENT" = "local" ]; then
|
||||
if [ "$DEBUG" = "true" ]; then
|
||||
echo "Waiting for the debugger to attach..."
|
||||
debugpy --wait-for-client --listen 0.0.0.0:$DEBUG_PORT -m gunicorn -w 3 -k uvicorn.workers.UvicornWorker -t 30000 --bind=0.0.0.0:$HTTP_PORT --log-level debug --reload cognee.api.client:app
|
||||
debugpy --wait-for-client --listen 0.0.0.0:$DEBUG_PORT -m gunicorn -w 1 -k uvicorn.workers.UvicornWorker -t 30000 --bind=0.0.0.0:$HTTP_PORT --log-level debug --reload cognee.api.client:app
|
||||
else
|
||||
gunicorn -w 3 -k uvicorn.workers.UvicornWorker -t 30000 --bind=0.0.0.0:$HTTP_PORT --log-level debug --reload cognee.api.client:app
|
||||
gunicorn -w 1 -k uvicorn.workers.UvicornWorker -t 30000 --bind=0.0.0.0:$HTTP_PORT --log-level debug --reload cognee.api.client:app
|
||||
fi
|
||||
else
|
||||
gunicorn -w 3 -k uvicorn.workers.UvicornWorker -t 30000 --bind=0.0.0.0:$HTTP_PORT --log-level error cognee.api.client:app
|
||||
gunicorn -w 1 -k uvicorn.workers.UvicornWorker -t 30000 --bind=0.0.0.0:$HTTP_PORT --log-level error cognee.api.client:app
|
||||
fi
|
||||
|
|
|
|||
File diff suppressed because one or more lines are too long
863
poetry.lock
generated
863
poetry.lock
generated
File diff suppressed because it is too large
Load diff
|
|
@ -1,6 +1,6 @@
|
|||
[project]
|
||||
name = "cognee"
|
||||
version = "0.2.1-dev7"
|
||||
version = "0.2.2.dev0"
|
||||
description = "Cognee - is a library for enriching LLM context with a semantic layer for better understanding and reasoning."
|
||||
authors = [
|
||||
{ name = "Vasilije Markovic" },
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue