Merge branch 'dev' into COG-2082

This commit is contained in:
Igor Ilic 2025-08-05 19:15:03 +02:00
commit a09d2d0b3c
35 changed files with 701 additions and 120 deletions

View file

@ -15,6 +15,7 @@ async def add(
vector_db_config: dict = None,
graph_db_config: dict = None,
dataset_id: Optional[UUID] = None,
incremental_loading: bool = True,
):
"""
Add data to Cognee for knowledge graph processing.
@ -153,6 +154,7 @@ async def add(
pipeline_name="add_pipeline",
vector_db_config=vector_db_config,
graph_db_config=graph_db_config,
incremental_loading=incremental_loading,
):
pipeline_run_info = run_info

View file

@ -11,6 +11,7 @@ from typing import List, Optional, Union, Literal
from cognee.modules.users.models import User
from cognee.modules.users.methods import get_authenticated_user
from cognee.shared.utils import send_telemetry
from cognee.modules.pipelines.models import PipelineRunErrored
from cognee.shared.logging_utils import get_logger
logger = get_logger()
@ -100,6 +101,8 @@ def get_add_router() -> APIRouter:
else:
add_run = await cognee_add(data, datasetName, user=user, dataset_id=datasetId)
if isinstance(add_run, PipelineRunErrored):
return JSONResponse(status_code=420, content=add_run.model_dump(mode="json"))
return add_run.model_dump()
except Exception as error:
return JSONResponse(status_code=409, content={"error": str(error)})

View file

@ -79,7 +79,9 @@ async def run_code_graph_pipeline(repo_path, include_docs=False):
async for run_status in non_code_pipeline_run:
yield run_status
async for run_status in run_tasks(tasks, dataset.id, repo_path, user, "cognify_code_pipeline"):
async for run_status in run_tasks(
tasks, dataset.id, repo_path, user, "cognify_code_pipeline", incremental_loading=False
):
yield run_status

View file

@ -39,6 +39,7 @@ async def cognify(
vector_db_config: dict = None,
graph_db_config: dict = None,
run_in_background: bool = False,
incremental_loading: bool = True,
):
"""
Transform ingested data into a structured knowledge graph.
@ -194,6 +195,7 @@ async def cognify(
datasets=datasets,
vector_db_config=vector_db_config,
graph_db_config=graph_db_config,
incremental_loading=incremental_loading,
)
else:
return await run_cognify_blocking(
@ -202,6 +204,7 @@ async def cognify(
datasets=datasets,
vector_db_config=vector_db_config,
graph_db_config=graph_db_config,
incremental_loading=incremental_loading,
)
@ -211,6 +214,7 @@ async def run_cognify_blocking(
datasets,
graph_db_config: dict = None,
vector_db_config: dict = False,
incremental_loading: bool = True,
):
total_run_info = {}
@ -221,6 +225,7 @@ async def run_cognify_blocking(
pipeline_name="cognify_pipeline",
graph_db_config=graph_db_config,
vector_db_config=vector_db_config,
incremental_loading=incremental_loading,
):
if run_info.dataset_id:
total_run_info[run_info.dataset_id] = run_info
@ -236,6 +241,7 @@ async def run_cognify_as_background_process(
datasets,
graph_db_config: dict = None,
vector_db_config: dict = False,
incremental_loading: bool = True,
):
# Convert dataset to list if it's a string
if isinstance(datasets, str):
@ -246,6 +252,7 @@ async def run_cognify_as_background_process(
async def handle_rest_of_the_run(pipeline_list):
# Execute all provided pipelines one by one to avoid database write conflicts
# TODO: Convert to async gather task instead of for loop when Queue mechanism for database is created
for pipeline in pipeline_list:
while True:
try:
@ -270,6 +277,7 @@ async def run_cognify_as_background_process(
pipeline_name="cognify_pipeline",
graph_db_config=graph_db_config,
vector_db_config=vector_db_config,
incremental_loading=incremental_loading,
)
# Save dataset Pipeline run started info

View file

@ -16,7 +16,11 @@ from cognee.modules.graph.methods import get_formatted_graph_data
from cognee.modules.users.get_user_manager import get_user_manager_context
from cognee.infrastructure.databases.relational import get_relational_engine
from cognee.modules.users.authentication.default.default_jwt_strategy import DefaultJWTStrategy
from cognee.modules.pipelines.models.PipelineRunInfo import PipelineRunCompleted, PipelineRunInfo
from cognee.modules.pipelines.models.PipelineRunInfo import (
PipelineRunCompleted,
PipelineRunInfo,
PipelineRunErrored,
)
from cognee.modules.pipelines.queues.pipeline_run_info_queues import (
get_from_queue,
initialize_queue,
@ -105,6 +109,9 @@ def get_cognify_router() -> APIRouter:
datasets, user, run_in_background=payload.run_in_background
)
# If any cognify run errored return JSONResponse with proper error status code
if any(isinstance(v, PipelineRunErrored) for v in cognify_run.values()):
return JSONResponse(status_code=420, content=cognify_run)
return cognify_run
except Exception as error:
return JSONResponse(status_code=409, content={"error": str(error)})

View file

@ -71,6 +71,12 @@ async def search(
Best for: Advanced users, specific graph traversals, debugging.
Returns: Raw graph query results.
**FEELING_LUCKY**:
Intelligently selects and runs the most appropriate search type.
Best for: General-purpose queries or when you're unsure which search type is best.
Returns: The results from the automatically selected search type.
Args:
query_text: Your question or search query in natural language.
Examples:
@ -119,6 +125,9 @@ async def search(
**CODE**:
[List of structured code information with context]
**FEELING_LUCKY**:
[List of results in the format of the search type that is automatically selected]
@ -130,6 +139,7 @@ async def search(
- **CHUNKS**: Fastest, pure vector similarity search without LLM
- **SUMMARIES**: Fast, returns pre-computed summaries
- **CODE**: Medium speed, specialized for code understanding
- **FEELING_LUCKY**: Variable speed, uses LLM + search type selection intelligently
- **top_k**: Start with 10, increase for comprehensive analysis (max 100)
- **datasets**: Specify datasets to improve speed and relevance

View file

@ -410,6 +410,38 @@ class Neo4jAdapter(GraphDBInterface):
return await self.query(query, params)
def _flatten_edge_properties(self, properties: Dict[str, Any]) -> Dict[str, Any]:
"""
Flatten edge properties to handle nested dictionaries like weights.
Neo4j doesn't support nested dictionaries as property values, so we need to
flatten the 'weights' dictionary into individual properties with prefixes.
Args:
properties: Dictionary of edge properties that may contain nested dicts
Returns:
Flattened properties dictionary suitable for Neo4j storage
"""
flattened = {}
for key, value in properties.items():
if key == "weights" and isinstance(value, dict):
# Flatten weights dictionary into individual properties
for weight_name, weight_value in value.items():
flattened[f"weight_{weight_name}"] = weight_value
elif isinstance(value, dict):
# For other nested dictionaries, serialize as JSON string
flattened[f"{key}_json"] = json.dumps(value, cls=JSONEncoder)
elif isinstance(value, list):
# For lists, serialize as JSON string
flattened[f"{key}_json"] = json.dumps(value, cls=JSONEncoder)
else:
# Keep primitive types as-is
flattened[key] = value
return flattened
@record_graph_changes
@override_distributed(queued_add_edges)
async def add_edges(self, edges: list[tuple[str, str, str, dict[str, Any]]]) -> None:
@ -448,11 +480,13 @@ class Neo4jAdapter(GraphDBInterface):
"from_node": str(edge[0]),
"to_node": str(edge[1]),
"relationship_name": edge[2],
"properties": {
**(edge[3] if edge[3] else {}),
"source_node_id": str(edge[0]),
"target_node_id": str(edge[1]),
},
"properties": self._flatten_edge_properties(
{
**(edge[3] if edge[3] else {}),
"source_node_id": str(edge[0]),
"target_node_id": str(edge[1]),
}
),
}
for edge in edges
]

View file

@ -185,7 +185,12 @@ class LiteLLMEmbeddingEngine(EmbeddingEngine):
elif "mistral" in self.provider.lower():
tokenizer = MistralTokenizer(model=model, max_tokens=self.max_tokens)
else:
tokenizer = HuggingFaceTokenizer(model=self.model, max_tokens=self.max_tokens)
try:
tokenizer = HuggingFaceTokenizer(model=self.model, max_tokens=self.max_tokens)
except Exception as e:
logger.warning(f"Could not get tokenizer from HuggingFace due to: {e}")
logger.info("Switching to TikToken default tokenizer.")
tokenizer = TikTokenTokenizer(model=None, max_tokens=self.max_tokens)
logger.debug(f"Tokenizer loaded for model: {self.model}")
return tokenizer

View file

@ -0,0 +1,130 @@
You are an expert query analyzer for a **GraphRAG system**. Your primary goal is to analyze a user's query and select the single most appropriate `SearchType` tool to answer it.
Here are the available `SearchType` tools and their specific functions:
- **`SUMMARIES`**: The `SUMMARIES` search type retrieves summarized information from the knowledge graph.
**Best for:**
- Getting concise overviews of topics
- Summarizing large amounts of information
- Quick understanding of complex subjects
* **`INSIGHTS`**: The `INSIGHTS` search type discovers connections and relationships between entities in the knowledge graph.
**Best for:**
- Discovering how entities are connected
- Understanding relationships between concepts
- Exploring the structure of your knowledge graph
* **`CHUNKS`**: The `CHUNKS` search type retrieves specific facts and information chunks from the knowledge graph.
**Best for:**
- Finding specific facts
- Getting direct answers to questions
- Retrieving precise information
* **`RAG_COMPLETION`**: Use for direct factual questions that can likely be answered by retrieving a specific text passage from a document. It does not use the graph's relationship structure.
**Best for:**
- Getting detailed explanations or comprehensive answers
- Combining multiple pieces of information
- Getting a single, coherent answer that is generated from relevant text passages
* **`GRAPH_COMPLETION`**: The `GRAPH_COMPLETION` search type leverages the graph structure to provide more contextually aware completions.
**Best for:**
- Complex queries requiring graph traversal
- Questions that benefit from understanding relationships
- Queries where context from connected entities matters
* **`GRAPH_SUMMARY_COMPLETION`**: The `GRAPH_SUMMARY_COMPLETION` search type combines graph traversal with summarization to provide concise but comprehensive answers.
**Best for:**
- Getting summarized information that requires understanding relationships
- Complex topics that need concise explanations
- Queries that benefit from both graph structure and summarization
* **`GRAPH_COMPLETION_COT`**: The `GRAPH_COMPLETION_COT` search type combines graph traversal with chain of thought to provide answers to complex multi hop questions.
**Best for:**
- Multi-hop questions that require following several linked concepts or entities
- Tracing relational paths in a knowledge graph while also getting clear step-by-step reasoning
- Summarizing completx linkages into a concise, human-readable answer once all hops have been explored
* **`GRAPH_COMPLETION_CONTEXT_EXTENSION`**: The `GRAPH_COMPLETION_CONTEXT_EXTENSION` search type combines graph traversal with multi-round context extension.
**Best for:**
- Iterative, multi-hop queries where intermediate facts arent all present upfront
- Complex linkages that benefit from multi-round “search → extend context → reason” loops to uncover deep connections.
- Sparse or evolving graphs that require on-the-fly expansion—issuing follow-up searches to discover missing nodes or properties
* **`CODE`**: The `CODE` search type is specialized for retrieving and understanding code-related information from the knowledge graph.
**Best for:**
- Code-related queries
- Programming examples and patterns
- Technical documentation searches
* **`CYPHER`**: The `CYPHER` search type allows user to execute raw Cypher queries directly against your graph database.
**Best for:**
- Executing precise graph queries with full control
- Leveraging Cypher features and functions
- Getting raw data directly from the graph database
* **`NATURAL_LANGUAGE`**: The `NATURAL_LANGUAGE` search type translates a natural language question into a precise Cypher query that is executed directly against the graph database.
**Best for:**
- Getting precise, structured answers from the graph using natural language.
- Performing advanced graph operations like filtering and aggregating data using natural language.
- Asking precise, database-style questions without needing to write Cypher.
**Examples:**
Query: "Summarize the key findings from these research papers"
Response: `SUMMARIES`
Query: "What is the relationship between the methodologies used in these papers?"
Response: `INSIGHTS`
Query: "When was Einstein born?"
Response: `CHUNKS`
Query: "Explain Einstein's contributions to physics"
Response: `RAG_COMPLETION`
Query: "Provide a comprehensive analysis of how these papers contribute to the field"
Response: `GRAPH_COMPLETION`
Query: "Explain the overall architecture of this codebase"
Response: `GRAPH_SUMMARY_COMPLETION`
Query: "Who was the father of the person who invented the lightbulb"
Response: `GRAPH_COMPLETION_COT`
Query: "What county was XY born in"
Response: `GRAPH_COMPLETION_CONTEXT_EXTENSION`
Query: "How to implement authentication in this codebase"
Response: `CODE`
Query: "MATCH (n) RETURN labels(n) as types, n.name as name LIMIT 10"
Response: `CYPHER`
Query: "Get all nodes connected to John"
Response: `NATURAL_LANGUAGE`
Your response MUST be a single word, consisting of only the chosen `SearchType` name. Do not provide any explanation.

View file

@ -1,4 +1,4 @@
from typing import List, Any
from typing import List, Any, Optional
import tiktoken
from ..tokenizer_interface import TokenizerInterface
@ -12,13 +12,17 @@ class TikTokenTokenizer(TokenizerInterface):
def __init__(
self,
model: str,
model: Optional[str] = None,
max_tokens: int = 8191,
):
self.model = model
self.max_tokens = max_tokens
# Initialize TikToken for GPT based on model
self.tokenizer = tiktoken.encoding_for_model(self.model)
if model:
self.tokenizer = tiktoken.encoding_for_model(self.model)
else:
# Use default if model not provided
self.tokenizer = tiktoken.get_encoding("cl100k_base")
def extract_tokens(self, text: str) -> List[Any]:
"""

View file

@ -1,6 +1,7 @@
from datetime import datetime, timezone
from uuid import uuid4
from sqlalchemy import UUID, Column, DateTime, String, JSON, Integer
from sqlalchemy.ext.mutable import MutableDict
from sqlalchemy.orm import relationship
from cognee.infrastructure.databases.relational import Base
@ -21,7 +22,11 @@ class Data(Base):
tenant_id = Column(UUID, index=True, nullable=True)
content_hash = Column(String)
external_metadata = Column(JSON)
node_set = Column(JSON, nullable=True) # Store NodeSet as JSON list of strings
# Store NodeSet as JSON list of strings
node_set = Column(JSON, nullable=True)
# MutableDict allows SQLAlchemy to notice key-value pair changes, without it changing a value for a key
# wouldn't be noticed when commiting a database session
pipeline_status = Column(MutableDict.as_mutable(JSON))
token_count = Column(Integer)
data_size = Column(Integer, nullable=True) # File size in bytes
created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc))

View file

@ -5,7 +5,6 @@ from cognee.modules.chunking.Chunker import Chunker
from cognee.infrastructure.files.utils.open_data_file import open_data_file
from .Document import Document
from .exceptions.exceptions import PyPdfInternalError
logger = get_logger("PDFDocument")
@ -17,18 +16,12 @@ class PdfDocument(Document):
async with open_data_file(self.raw_data_location, mode="rb") as stream:
logger.info(f"Reading PDF: {self.raw_data_location}")
try:
file = PdfReader(stream, strict=False)
except Exception:
raise PyPdfInternalError()
file = PdfReader(stream, strict=False)
async def get_text():
try:
for page in file.pages:
page_text = page.extract_text()
yield page_text
except Exception:
raise PyPdfInternalError()
for page in file.pages:
page_text = page.extract_text()
yield page_text
chunker = chunker_cls(self, get_text=get_text, max_chunk_size=max_chunk_size)

View file

@ -0,0 +1,5 @@
from uuid import NAMESPACE_OID, uuid5
def generate_edge_id(edge_id: str) -> str:
return uuid5(NAMESPACE_OID, edge_id.lower().replace(" ", "_").replace("'", ""))

View file

@ -170,28 +170,19 @@ class CogneeGraph(CogneeAbstractGraph):
for edge in self.edges:
relationship_type = edge.attributes.get("relationship_type")
if relationship_type and relationship_type in embedding_map:
edge.attributes["vector_distance"] = embedding_map[relationship_type]
distance = embedding_map.get(relationship_type, None)
if distance is not None:
edge.attributes["vector_distance"] = distance
except Exception as ex:
logger.error(f"Error mapping vector distances to edges: {str(ex)}")
raise ex
async def calculate_top_triplet_importances(self, k: int) -> List:
min_heap = []
def score(edge):
n1 = edge.node1.attributes.get("vector_distance", 1)
n2 = edge.node2.attributes.get("vector_distance", 1)
e = edge.attributes.get("vector_distance", 1)
return n1 + n2 + e
for i, edge in enumerate(self.edges):
source_node = self.get_node(edge.node1.id)
target_node = self.get_node(edge.node2.id)
source_distance = source_node.attributes.get("vector_distance", 1) if source_node else 1
target_distance = target_node.attributes.get("vector_distance", 1) if target_node else 1
edge_distance = edge.attributes.get("vector_distance", 1)
total_distance = source_distance + target_distance + edge_distance
heapq.heappush(min_heap, (-total_distance, i, edge))
if len(min_heap) > k:
heapq.heappop(min_heap)
return [edge for _, _, edge in sorted(min_heap)]
return heapq.nsmallest(k, self.edges, key=score)

View file

@ -0,0 +1 @@
from .exceptions import PipelineRunFailedError

View file

@ -0,0 +1,12 @@
from cognee.exceptions import CogneeApiError
from fastapi import status
class PipelineRunFailedError(CogneeApiError):
def __init__(
self,
message: str = "Pipeline run failed.",
name: str = "PipelineRunFailedError",
status_code: int = status.HTTP_422_UNPROCESSABLE_ENTITY,
):
super().__init__(message, name, status_code)

View file

@ -0,0 +1,5 @@
import enum
class DataItemStatus(str, enum.Enum):
DATA_ITEM_PROCESSING_COMPLETED = "DATA_ITEM_PROCESSING_COMPLETED"

View file

@ -9,6 +9,7 @@ class PipelineRunInfo(BaseModel):
dataset_id: UUID
dataset_name: str
payload: Optional[Any] = None
data_ingestion_info: Optional[list] = None
model_config = {
"arbitrary_types_allowed": True,
@ -30,6 +31,11 @@ class PipelineRunCompleted(PipelineRunInfo):
pass
class PipelineRunAlreadyCompleted(PipelineRunInfo):
status: str = "PipelineRunAlreadyCompleted"
pass
class PipelineRunErrored(PipelineRunInfo):
status: str = "PipelineRunErrored"
pass

View file

@ -6,3 +6,4 @@ from .PipelineRunInfo import (
PipelineRunCompleted,
PipelineRunErrored,
)
from .DataItemStatus import DataItemStatus

View file

@ -52,6 +52,7 @@ async def cognee_pipeline(
pipeline_name: str = "custom_pipeline",
vector_db_config: dict = None,
graph_db_config: dict = None,
incremental_loading: bool = True,
):
# Note: These context variables allow different value assignment for databases in Cognee
# per async task, thread, process and etc.
@ -109,6 +110,7 @@ async def cognee_pipeline(
data=data,
pipeline_name=pipeline_name,
context={"dataset": dataset},
incremental_loading=incremental_loading,
):
yield run_info
@ -120,6 +122,7 @@ async def run_pipeline(
data=None,
pipeline_name: str = "custom_pipeline",
context: dict = None,
incremental_loading=True,
):
check_dataset_name(dataset.name)
@ -187,7 +190,9 @@ async def run_pipeline(
if not isinstance(task, Task):
raise ValueError(f"Task {task} is not an instance of Task")
pipeline_run = run_tasks(tasks, dataset_id, data, user, pipeline_name, context)
pipeline_run = run_tasks(
tasks, dataset_id, data, user, pipeline_name, context, incremental_loading
)
async for pipeline_run_info in pipeline_run:
yield pipeline_run_info

View file

@ -1,21 +1,31 @@
import os
from uuid import UUID
from typing import Any
from functools import wraps
import asyncio
from uuid import UUID
from typing import Any, List
from functools import wraps
from sqlalchemy import select
import cognee.modules.ingestion as ingestion
from cognee.infrastructure.databases.graph import get_graph_engine
from cognee.infrastructure.databases.relational import get_relational_engine
from cognee.modules.pipelines.operations.run_tasks_distributed import run_tasks_distributed
from cognee.modules.users.models import User
from cognee.modules.data.models import Data
from cognee.infrastructure.files.utils.open_data_file import open_data_file
from cognee.shared.logging_utils import get_logger
from cognee.modules.users.methods import get_default_user
from cognee.modules.pipelines.utils import generate_pipeline_id
from cognee.modules.pipelines.exceptions import PipelineRunFailedError
from cognee.tasks.ingestion import save_data_item_to_storage, resolve_data_directories
from cognee.modules.pipelines.models.PipelineRunInfo import (
PipelineRunCompleted,
PipelineRunErrored,
PipelineRunStarted,
PipelineRunYield,
PipelineRunAlreadyCompleted,
)
from cognee.modules.pipelines.models.DataItemStatus import DataItemStatus
from cognee.modules.pipelines.operations import (
log_pipeline_run_start,
@ -50,40 +60,122 @@ def override_run_tasks(new_gen):
@override_run_tasks(run_tasks_distributed)
async def run_tasks(
tasks: list[Task],
tasks: List[Task],
dataset_id: UUID,
data: Any = None,
data: List[Any] = None,
user: User = None,
pipeline_name: str = "unknown_pipeline",
context: dict = None,
incremental_loading: bool = True,
):
if not user:
user = await get_default_user()
async def _run_tasks_data_item_incremental(
data_item,
dataset,
tasks,
pipeline_name,
pipeline_id,
pipeline_run_id,
context,
user,
):
db_engine = get_relational_engine()
# If incremental_loading of data is set to True don't process documents already processed by pipeline
# If data is being added to Cognee for the first time calculate the id of the data
if not isinstance(data_item, Data):
file_path = await save_data_item_to_storage(data_item)
# Ingest data and add metadata
async with open_data_file(file_path) as file:
classified_data = ingestion.classify(file)
# data_id is the hash of file contents + owner id to avoid duplicate data
data_id = ingestion.identify(classified_data, user)
else:
# If data was already processed by Cognee get data id
data_id = data_item.id
# Get Dataset object
db_engine = get_relational_engine()
async with db_engine.get_async_session() as session:
from cognee.modules.data.models import Dataset
# Check pipeline status, if Data already processed for pipeline before skip current processing
async with db_engine.get_async_session() as session:
data_point = (
await session.execute(select(Data).filter(Data.id == data_id))
).scalar_one_or_none()
if data_point:
if (
data_point.pipeline_status.get(pipeline_name, {}).get(str(dataset.id))
== DataItemStatus.DATA_ITEM_PROCESSING_COMPLETED
):
yield {
"run_info": PipelineRunAlreadyCompleted(
pipeline_run_id=pipeline_run_id,
dataset_id=dataset.id,
dataset_name=dataset.name,
),
"data_id": data_id,
}
return
dataset = await session.get(Dataset, dataset_id)
try:
# Process data based on data_item and list of tasks
async for result in run_tasks_with_telemetry(
tasks=tasks,
data=[data_item],
user=user,
pipeline_name=pipeline_id,
context=context,
):
yield PipelineRunYield(
pipeline_run_id=pipeline_run_id,
dataset_id=dataset.id,
dataset_name=dataset.name,
payload=result,
)
pipeline_id = generate_pipeline_id(user.id, dataset.id, pipeline_name)
# Update pipeline status for Data element
async with db_engine.get_async_session() as session:
data_point = (
await session.execute(select(Data).filter(Data.id == data_id))
).scalar_one_or_none()
data_point.pipeline_status[pipeline_name] = {
str(dataset.id): DataItemStatus.DATA_ITEM_PROCESSING_COMPLETED
}
await session.merge(data_point)
await session.commit()
pipeline_run = await log_pipeline_run_start(pipeline_id, pipeline_name, dataset_id, data)
yield {
"run_info": PipelineRunCompleted(
pipeline_run_id=pipeline_run_id,
dataset_id=dataset.id,
dataset_name=dataset.name,
),
"data_id": data_id,
}
pipeline_run_id = pipeline_run.pipeline_run_id
except Exception as error:
# Temporarily swallow error and try to process rest of documents first, then re-raise error at end of data ingestion pipeline
logger.error(
f"Exception caught while processing data: {error}.\n Data processing failed for data item: {data_item}."
)
yield {
"run_info": PipelineRunErrored(
pipeline_run_id=pipeline_run_id,
payload=repr(error),
dataset_id=dataset.id,
dataset_name=dataset.name,
),
"data_id": data_id,
}
yield PipelineRunStarted(
pipeline_run_id=pipeline_run_id,
dataset_id=dataset.id,
dataset_name=dataset.name,
payload=data,
)
try:
async def _run_tasks_data_item_regular(
data_item,
dataset,
tasks,
pipeline_id,
pipeline_run_id,
context,
user,
):
# Process data based on data_item and list of tasks
async for result in run_tasks_with_telemetry(
tasks=tasks,
data=data,
data=[data_item],
user=user,
pipeline_name=pipeline_id,
context=context,
@ -95,6 +187,112 @@ async def run_tasks(
payload=result,
)
yield {
"run_info": PipelineRunCompleted(
pipeline_run_id=pipeline_run_id,
dataset_id=dataset.id,
dataset_name=dataset.name,
)
}
async def _run_tasks_data_item(
data_item,
dataset,
tasks,
pipeline_name,
pipeline_id,
pipeline_run_id,
context,
user,
incremental_loading,
):
# Go through async generator and return data item processing result. Result can be PipelineRunAlreadyCompleted when data item is skipped,
# PipelineRunCompleted when processing was successful and PipelineRunErrored if there were issues
result = None
if incremental_loading:
async for result in _run_tasks_data_item_incremental(
data_item=data_item,
dataset=dataset,
tasks=tasks,
pipeline_name=pipeline_name,
pipeline_id=pipeline_id,
pipeline_run_id=pipeline_run_id,
context=context,
user=user,
):
pass
else:
async for result in _run_tasks_data_item_regular(
data_item=data_item,
dataset=dataset,
tasks=tasks,
pipeline_id=pipeline_id,
pipeline_run_id=pipeline_run_id,
context=context,
user=user,
):
pass
return result
if not user:
user = await get_default_user()
# Get Dataset object
db_engine = get_relational_engine()
async with db_engine.get_async_session() as session:
from cognee.modules.data.models import Dataset
dataset = await session.get(Dataset, dataset_id)
pipeline_id = generate_pipeline_id(user.id, dataset.id, pipeline_name)
pipeline_run = await log_pipeline_run_start(pipeline_id, pipeline_name, dataset_id, data)
pipeline_run_id = pipeline_run.pipeline_run_id
yield PipelineRunStarted(
pipeline_run_id=pipeline_run_id,
dataset_id=dataset.id,
dataset_name=dataset.name,
payload=data,
)
try:
if not isinstance(data, list):
data = [data]
if incremental_loading:
data = await resolve_data_directories(data)
# Create async tasks per data item that will run the pipeline for the data item
data_item_tasks = [
asyncio.create_task(
_run_tasks_data_item(
data_item,
dataset,
tasks,
pipeline_name,
pipeline_id,
pipeline_run_id,
context,
user,
incremental_loading,
)
)
for data_item in data
]
results = await asyncio.gather(*data_item_tasks)
# Remove skipped data items from results
results = [result for result in results if result]
# If any data item could not be processed propagate error
errored_results = [
result for result in results if isinstance(result["run_info"], PipelineRunErrored)
]
if errored_results:
raise PipelineRunFailedError(
message="Pipeline run failed. Data item could not be processed."
)
await log_pipeline_run_complete(
pipeline_run_id, pipeline_id, pipeline_name, dataset_id, data
)
@ -103,6 +301,7 @@ async def run_tasks(
pipeline_run_id=pipeline_run_id,
dataset_id=dataset.id,
dataset_name=dataset.name,
data_ingestion_info=results,
)
graph_engine = await get_graph_engine()
@ -120,9 +319,14 @@ async def run_tasks(
yield PipelineRunErrored(
pipeline_run_id=pipeline_run_id,
payload=error,
payload=repr(error),
dataset_id=dataset.id,
dataset_name=dataset.name,
data_ingestion_info=locals().get(
"results"
), # Returns results if they exist or returns None
)
raise error
# In case of error during incremental loading of data just let the user know the pipeline Errored, don't raise error
if not isinstance(error, PipelineRunFailedError):
raise error

View file

@ -27,7 +27,7 @@ from cognee.modules.users.models import User
from cognee.modules.data.models import Dataset
from cognee.shared.utils import send_telemetry
from cognee.modules.users.permissions.methods import get_specific_user_permission_datasets
from cognee.modules.search.operations import log_query, log_result
from cognee.modules.search.operations import log_query, log_result, select_search_type
async def search(
@ -129,6 +129,10 @@ async def specific_search(
SearchType.NATURAL_LANGUAGE: NaturalLanguageRetriever().get_completion,
}
# If the query type is FEELING_LUCKY, select the search type intelligently
if query_type is SearchType.FEELING_LUCKY:
query_type = await select_search_type(query)
search_task = search_tasks.get(query_type)
if search_task is None:

View file

@ -1,3 +1,4 @@
from .log_query import log_query
from .log_result import log_result
from .get_history import get_history
from .select_search_type import select_search_type

View file

@ -0,0 +1,43 @@
from cognee.infrastructure.llm.get_llm_client import get_llm_client
from cognee.infrastructure.llm.prompts import read_query_prompt
from cognee.modules.search.types import SearchType
from cognee.shared.logging_utils import get_logger
logger = get_logger("SearchTypeSelector")
async def select_search_type(
query: str,
system_prompt_path: str = "search_type_selector_prompt.txt",
) -> SearchType:
"""
Analyzes the query and Selects the best search type.
Args:
query: The query to analyze.
system_prompt_path: The path to the system prompt.
Returns:
The best search type given by the LLM.
"""
default_search_type = SearchType.RAG_COMPLETION
system_prompt = read_query_prompt(system_prompt_path)
llm_client = get_llm_client()
try:
response = await llm_client.acreate_structured_output(
text_input=query,
system_prompt=system_prompt,
response_model=str,
)
if response.upper() in SearchType.__members__:
logger.info(f"Selected lucky search type: {response.upper()}")
return SearchType(response.upper())
# If the response is not a valid search type, return the default search type
logger.info(f"LLM gives an invalid search type: {response.upper()}")
return default_search_type
except Exception as e:
logger.error(f"Failed to select search type intelligently from LLM: {str(e)}")
return default_search_type

View file

@ -13,3 +13,4 @@ class SearchType(Enum):
NATURAL_LANGUAGE = "NATURAL_LANGUAGE"
GRAPH_COMPLETION_COT = "GRAPH_COMPLETION_COT"
GRAPH_COMPLETION_CONTEXT_EXTENSION = "GRAPH_COMPLETION_CONTEXT_EXTENSION"
FEELING_LUCKY = "FEELING_LUCKY"

View file

@ -8,7 +8,6 @@ from cognee.modules.data.models import Data
from cognee.infrastructure.databases.relational import get_relational_engine
from cognee.modules.chunking.TextChunker import TextChunker
from cognee.modules.chunking.Chunker import Chunker
from cognee.modules.data.processing.document_types.exceptions.exceptions import PyPdfInternalError
async def update_document_token_count(document_id: UUID, token_count: int) -> None:
@ -40,15 +39,14 @@ async def extract_chunks_from_documents(
"""
for document in documents:
document_token_count = 0
try:
async for document_chunk in document.read(
max_chunk_size=max_chunk_size, chunker_cls=chunker
):
document_token_count += document_chunk.chunk_size
document_chunk.belongs_to_set = document.belongs_to_set
yield document_chunk
await update_document_token_count(document.id, document_token_count)
except PyPdfInternalError:
pass
async for document_chunk in document.read(
max_chunk_size=max_chunk_size, chunker_cls=chunker
):
document_token_count += document_chunk.chunk_size
document_chunk.belongs_to_set = document.belongs_to_set
yield document_chunk
await update_document_token_count(document.id, document_token_count)
# todo rita

View file

@ -5,12 +5,12 @@ from uuid import UUID
from typing import Union, BinaryIO, Any, List, Optional
import cognee.modules.ingestion as ingestion
from cognee.infrastructure.files.utils.open_data_file import open_data_file
from cognee.infrastructure.databases.relational import get_relational_engine
from cognee.modules.data.models import Data
from cognee.modules.users.models import User
from cognee.modules.users.methods import get_default_user
from cognee.modules.users.permissions.methods import get_specific_user_permission_datasets
from cognee.infrastructure.files.utils.open_data_file import open_data_file
from cognee.modules.data.methods import (
get_authorized_existing_datasets,
get_dataset_data,
@ -134,6 +134,7 @@ async def ingest_data(
node_set=json.dumps(node_set) if node_set else None,
data_size=file_metadata["file_size"],
tenant_id=user.tenant_id if user.tenant_id else None,
pipeline_status={},
token_count=-1,
)

View file

@ -40,6 +40,9 @@ async def resolve_data_directories(
if include_subdirectories:
base_path = item if item.endswith("/") else item + "/"
s3_keys = fs.glob(base_path + "**")
# If path is not directory attempt to add item directly
if not s3_keys:
s3_keys = fs.ls(item)
else:
s3_keys = fs.ls(item)
# Filter out keys that represent directories using fs.isdir

View file

@ -103,6 +103,9 @@ async def get_repo_file_dependencies(
extraction of dependencies (default is False). (default False)
"""
if isinstance(repo_path, list) and len(repo_path) == 1:
repo_path = repo_path[0]
if not os.path.exists(repo_path):
raise FileNotFoundError(f"Repository path {repo_path} does not exist.")

View file

@ -1,3 +1,5 @@
import asyncio
from cognee.shared.logging_utils import get_logger
from cognee.infrastructure.databases.exceptions.EmbeddingException import EmbeddingException
@ -6,6 +8,9 @@ from cognee.infrastructure.engine import DataPoint
logger = get_logger("index_data_points")
# A single lock shared by all coroutines
vector_index_lock = asyncio.Lock()
async def index_data_points(data_points: list[DataPoint]):
created_indexes = {}
@ -22,9 +27,11 @@ async def index_data_points(data_points: list[DataPoint]):
index_name = f"{data_point_type.__name__}_{field_name}"
if index_name not in created_indexes:
await vector_engine.create_vector_index(data_point_type.__name__, field_name)
created_indexes[index_name] = True
# Add async lock to make sure two different coroutines won't create a table at the same time
async with vector_index_lock:
if index_name not in created_indexes:
await vector_engine.create_vector_index(data_point_type.__name__, field_name)
created_indexes[index_name] = True
if index_name not in index_points:
index_points[index_name] = []
@ -38,7 +45,7 @@ async def index_data_points(data_points: list[DataPoint]):
index_name = index_name_and_field[:first_occurence]
field_name = index_name_and_field[first_occurence + 1 :]
try:
# In case the ammount if indexable points is too large we need to send them in batches
# In case the amount of indexable points is too large we need to send them in batches
batch_size = 100
for i in range(0, len(indexable_points), batch_size):
batch = indexable_points[i : i + batch_size]

View file

@ -1,3 +1,4 @@
from cognee.modules.engine.utils.generate_edge_id import generate_edge_id
from cognee.shared.logging_utils import get_logger, ERROR
from collections import Counter
@ -49,7 +50,9 @@ async def index_graph_edges(batch_size: int = 1024):
)
for text, count in edge_types.items():
edge = EdgeType(relationship_name=text, number_of_edges=count)
edge = EdgeType(
id=generate_edge_id(edge_id=text), relationship_name=text, number_of_edges=count
)
data_point_type = type(edge)
for field_name in edge.metadata["index_fields"]:

View file

@ -26,8 +26,8 @@ async def test_deduplication():
explanation_file_path2 = os.path.join(
pathlib.Path(__file__).parent, "test_data/Natural_language_processing_copy.txt"
)
await cognee.add([explanation_file_path], dataset_name)
await cognee.add([explanation_file_path2], dataset_name2)
await cognee.add([explanation_file_path], dataset_name, incremental_loading=False)
await cognee.add([explanation_file_path2], dataset_name2, incremental_loading=False)
result = await relational_engine.get_all_data_from_table("data")
assert len(result) == 1, "More than one data entity was found."

View file

@ -155,6 +155,61 @@ async def test_specific_search_chunks(mock_send_telemetry, mock_chunks_retriever
assert results[0]["content"] == "Chunk result"
@pytest.mark.asyncio
@pytest.mark.parametrize(
"selected_type, retriever_name, expected_content, top_k",
[
(SearchType.RAG_COMPLETION, "CompletionRetriever", "RAG result from lucky search", 10),
(SearchType.CHUNKS, "ChunksRetriever", "Chunk result from lucky search", 5),
(SearchType.SUMMARIES, "SummariesRetriever", "Summary from lucky search", 15),
(SearchType.INSIGHTS, "InsightsRetriever", "Insight result from lucky search", 20),
],
)
@patch.object(search_module, "select_search_type")
@patch.object(search_module, "send_telemetry")
async def test_specific_search_feeling_lucky(
mock_send_telemetry,
mock_select_search_type,
selected_type,
retriever_name,
expected_content,
top_k,
mock_user,
):
with patch.object(search_module, retriever_name) as mock_retriever_class:
# Setup
query = f"test query for {retriever_name}"
query_type = SearchType.FEELING_LUCKY
# Mock the intelligent search type selection
mock_select_search_type.return_value = selected_type
# Mock the retriever
mock_retriever_instance = MagicMock()
mock_retriever_instance.get_completion = AsyncMock(
return_value=[{"content": expected_content}]
)
mock_retriever_class.return_value = mock_retriever_instance
# Execute
results = await specific_search(query_type, query, mock_user, top_k=top_k)
# Verify
mock_select_search_type.assert_called_once_with(query)
if retriever_name == "CompletionRetriever":
mock_retriever_class.assert_called_once_with(
system_prompt_path="answer_simple_question.txt", top_k=top_k
)
else:
mock_retriever_class.assert_called_once_with(top_k=top_k)
mock_retriever_instance.get_completion.assert_called_once_with(query)
mock_send_telemetry.assert_called()
assert len(results) == 1
assert results[0]["content"] == expected_content
@pytest.mark.asyncio
async def test_specific_search_invalid_type(mock_user):
# Setup

View file

@ -1,69 +1,77 @@
import os
import json
import asyncio
from typing import List, Any
from cognee import prune
from cognee import visualize_graph
from cognee.low_level import setup, DataPoint
from cognee.modules.data.methods import load_or_create_datasets
from cognee.modules.users.methods import get_default_user
from cognee.pipelines import run_tasks, Task
from cognee.tasks.storage import add_data_points
class Person(DataPoint):
name: str
# Metadata "index_fields" specifies which DataPoint fields should be embedded for vector search
metadata: dict = {"index_fields": ["name"]}
class Department(DataPoint):
name: str
employees: list[Person]
# Metadata "index_fields" specifies which DataPoint fields should be embedded for vector search
metadata: dict = {"index_fields": ["name"]}
class CompanyType(DataPoint):
name: str = "Company"
# Metadata "index_fields" specifies which DataPoint fields should be embedded for vector search
metadata: dict = {"index_fields": ["name"]}
class Company(DataPoint):
name: str
departments: list[Department]
is_type: CompanyType
# Metadata "index_fields" specifies which DataPoint fields should be embedded for vector search
metadata: dict = {"index_fields": ["name"]}
def ingest_files():
companies_file_path = os.path.join(os.path.dirname(__file__), "companies.json")
companies = json.loads(open(companies_file_path, "r").read())
people_file_path = os.path.join(os.path.dirname(__file__), "people.json")
people = json.loads(open(people_file_path, "r").read())
def ingest_files(data: List[Any]):
people_data_points = {}
departments_data_points = {}
for person in people:
new_person = Person(name=person["name"])
people_data_points[person["name"]] = new_person
if person["department"] not in departments_data_points:
departments_data_points[person["department"]] = Department(
name=person["department"], employees=[new_person]
)
else:
departments_data_points[person["department"]].employees.append(new_person)
companies_data_points = {}
# Create a single CompanyType node, so we connect all companies to it.
companyType = CompanyType()
for data_item in data:
people = data_item["people"]
companies = data_item["companies"]
for company in companies:
new_company = Company(name=company["name"], departments=[], is_type=companyType)
companies_data_points[company["name"]] = new_company
for person in people:
new_person = Person(name=person["name"])
people_data_points[person["name"]] = new_person
for department_name in company["departments"]:
if department_name not in departments_data_points:
departments_data_points[department_name] = Department(
name=department_name, employees=[]
if person["department"] not in departments_data_points:
departments_data_points[person["department"]] = Department(
name=person["department"], employees=[new_person]
)
else:
departments_data_points[person["department"]].employees.append(new_person)
new_company.departments.append(departments_data_points[department_name])
# Create a single CompanyType node, so we connect all companies to it.
companyType = CompanyType()
for company in companies:
new_company = Company(name=company["name"], departments=[], is_type=companyType)
companies_data_points[company["name"]] = new_company
for department_name in company["departments"]:
if department_name not in departments_data_points:
departments_data_points[department_name] = Department(
name=department_name, employees=[]
)
new_company.departments.append(departments_data_points[department_name])
return companies_data_points.values()
@ -72,9 +80,30 @@ async def main():
await prune.prune_data()
await prune.prune_system(metadata=True)
# Create relational database tables
await setup()
pipeline = run_tasks([Task(ingest_files), Task(add_data_points)])
# If no user is provided use default user
user = await get_default_user()
# Create dataset object to keep track of pipeline status
datasets = await load_or_create_datasets(["test_dataset"], [], user)
# Prepare data for pipeline
companies_file_path = os.path.join(os.path.dirname(__file__), "companies.json")
companies = json.loads(open(companies_file_path, "r").read())
people_file_path = os.path.join(os.path.dirname(__file__), "people.json")
people = json.loads(open(people_file_path, "r").read())
# Run tasks expects a list of data even if it is just one document
data = [{"companies": companies, "people": people}]
pipeline = run_tasks(
[Task(ingest_files), Task(add_data_points)],
dataset_id=datasets[0].id,
data=data,
incremental_loading=False,
)
async for status in pipeline:
print(status)

File diff suppressed because one or more lines are too long