diff --git a/.env.template b/.env.template
index d178965e8..7defaee09 100644
--- a/.env.template
+++ b/.env.template
@@ -21,6 +21,10 @@ LLM_PROVIDER="openai"
LLM_ENDPOINT=""
LLM_API_VERSION=""
LLM_MAX_TOKENS="16384"
+# Instructor's modes determine how structured data is requested from and extracted from LLM responses
+# You can change this type (i.e. mode) via this env variable
+# Each LLM has its own default value, e.g. gpt-5 models have "json_schema_mode"
+LLM_INSTRUCTOR_MODE=""
EMBEDDING_PROVIDER="openai"
EMBEDDING_MODEL="openai/text-embedding-3-large"
diff --git a/cognee-mcp/src/server.py b/cognee-mcp/src/server.py
index 7c708638c..4131be988 100755
--- a/cognee-mcp/src/server.py
+++ b/cognee-mcp/src/server.py
@@ -194,7 +194,6 @@ async def cognify(
Prerequisites:
- **LLM_API_KEY**: Must be configured (required for entity extraction and graph generation)
- - **Data Added**: Must have data previously added via `cognee.add()`
- **Vector Database**: Must be accessible for embeddings storage
- **Graph Database**: Must be accessible for relationship storage
diff --git a/cognee/api/client.py b/cognee/api/client.py
index 19a607ff0..1a08aed56 100644
--- a/cognee/api/client.py
+++ b/cognee/api/client.py
@@ -23,6 +23,7 @@ from cognee.api.v1.settings.routers import get_settings_router
from cognee.api.v1.datasets.routers import get_datasets_router
from cognee.api.v1.cognify.routers import get_code_pipeline_router, get_cognify_router
from cognee.api.v1.search.routers import get_search_router
+from cognee.api.v1.ontologies.routers.get_ontology_router import get_ontology_router
from cognee.api.v1.memify.routers import get_memify_router
from cognee.api.v1.add.routers import get_add_router
from cognee.api.v1.delete.routers import get_delete_router
@@ -263,6 +264,8 @@ app.include_router(
app.include_router(get_datasets_router(), prefix="/api/v1/datasets", tags=["datasets"])
+app.include_router(get_ontology_router(), prefix="/api/v1/ontologies", tags=["ontologies"])
+
app.include_router(get_settings_router(), prefix="/api/v1/settings", tags=["settings"])
app.include_router(get_visualize_router(), prefix="/api/v1/visualize", tags=["visualize"])
diff --git a/cognee/api/v1/cognify/routers/get_cognify_router.py b/cognee/api/v1/cognify/routers/get_cognify_router.py
index 231bbcd11..4f1497e3c 100644
--- a/cognee/api/v1/cognify/routers/get_cognify_router.py
+++ b/cognee/api/v1/cognify/routers/get_cognify_router.py
@@ -41,6 +41,9 @@ class CognifyPayloadDTO(InDTO):
custom_prompt: Optional[str] = Field(
default="", description="Custom prompt for entity extraction and graph generation"
)
+ ontology_key: Optional[List[str]] = Field(
+ default=None, description="Reference to one or more previously uploaded ontologies"
+ )
def get_cognify_router() -> APIRouter:
@@ -68,6 +71,7 @@ def get_cognify_router() -> APIRouter:
- **dataset_ids** (Optional[List[UUID]]): List of existing dataset UUIDs to process. UUIDs allow processing of datasets not owned by the user (if permitted).
- **run_in_background** (Optional[bool]): Whether to execute processing asynchronously. Defaults to False (blocking).
- **custom_prompt** (Optional[str]): Custom prompt for entity extraction and graph generation. If provided, this prompt will be used instead of the default prompts for knowledge graph extraction.
+ - **ontology_key** (Optional[List[str]]): Reference to one or more previously uploaded ontology files to use for knowledge graph construction.
## Response
- **Blocking execution**: Complete pipeline run information with entity counts, processing duration, and success/failure status
@@ -82,7 +86,8 @@ def get_cognify_router() -> APIRouter:
{
"datasets": ["research_papers", "documentation"],
"run_in_background": false,
- "custom_prompt": "Extract entities focusing on technical concepts and their relationships. Identify key technologies, methodologies, and their interconnections."
+ "custom_prompt": "Extract entities focusing on technical concepts and their relationships. Identify key technologies, methodologies, and their interconnections.",
+ "ontology_key": ["medical_ontology_v1"]
}
```
@@ -108,13 +113,35 @@ def get_cognify_router() -> APIRouter:
)
from cognee.api.v1.cognify import cognify as cognee_cognify
+ from cognee.api.v1.ontologies.ontologies import OntologyService
try:
datasets = payload.dataset_ids if payload.dataset_ids else payload.datasets
+ config_to_use = None
+
+ if payload.ontology_key:
+ ontology_service = OntologyService()
+ ontology_contents = ontology_service.get_ontology_contents(
+ payload.ontology_key, user
+ )
+
+ from cognee.modules.ontology.ontology_config import Config
+ from cognee.modules.ontology.rdf_xml.RDFLibOntologyResolver import (
+ RDFLibOntologyResolver,
+ )
+ from io import StringIO
+
+ ontology_streams = [StringIO(content) for content in ontology_contents]
+ config_to_use: Config = {
+ "ontology_config": {
+ "ontology_resolver": RDFLibOntologyResolver(ontology_file=ontology_streams)
+ }
+ }
cognify_run = await cognee_cognify(
datasets,
user,
+ config=config_to_use,
run_in_background=payload.run_in_background,
custom_prompt=payload.custom_prompt,
)
diff --git a/cognee/api/v1/ontologies/__init__.py b/cognee/api/v1/ontologies/__init__.py
new file mode 100644
index 000000000..b90d46c3d
--- /dev/null
+++ b/cognee/api/v1/ontologies/__init__.py
@@ -0,0 +1,4 @@
+from .ontologies import OntologyService
+from .routers.get_ontology_router import get_ontology_router
+
+__all__ = ["OntologyService", "get_ontology_router"]
diff --git a/cognee/api/v1/ontologies/ontologies.py b/cognee/api/v1/ontologies/ontologies.py
new file mode 100644
index 000000000..130b4a862
--- /dev/null
+++ b/cognee/api/v1/ontologies/ontologies.py
@@ -0,0 +1,183 @@
+import os
+import json
+import tempfile
+from pathlib import Path
+from datetime import datetime, timezone
+from typing import Optional, List
+from dataclasses import dataclass
+
+
+@dataclass
+class OntologyMetadata:
+ ontology_key: str
+ filename: str
+ size_bytes: int
+ uploaded_at: str
+ description: Optional[str] = None
+
+
+class OntologyService:
+ def __init__(self):
+ pass
+
+ @property
+ def base_dir(self) -> Path:
+ return Path(tempfile.gettempdir()) / "ontologies"
+
+ def _get_user_dir(self, user_id: str) -> Path:
+ user_dir = self.base_dir / str(user_id)
+ user_dir.mkdir(parents=True, exist_ok=True)
+ return user_dir
+
+ def _get_metadata_path(self, user_dir: Path) -> Path:
+ return user_dir / "metadata.json"
+
+ def _load_metadata(self, user_dir: Path) -> dict:
+ metadata_path = self._get_metadata_path(user_dir)
+ if metadata_path.exists():
+ with open(metadata_path, "r") as f:
+ return json.load(f)
+ return {}
+
+ def _save_metadata(self, user_dir: Path, metadata: dict):
+ metadata_path = self._get_metadata_path(user_dir)
+ with open(metadata_path, "w") as f:
+ json.dump(metadata, f, indent=2)
+
+ async def upload_ontology(
+ self, ontology_key: str, file, user, description: Optional[str] = None
+ ) -> OntologyMetadata:
+ if not file.filename.lower().endswith(".owl"):
+ raise ValueError("File must be in .owl format")
+
+ user_dir = self._get_user_dir(str(user.id))
+ metadata = self._load_metadata(user_dir)
+
+ if ontology_key in metadata:
+ raise ValueError(f"Ontology key '{ontology_key}' already exists")
+
+ content = await file.read()
+ if len(content) > 10 * 1024 * 1024:
+ raise ValueError("File size exceeds 10MB limit")
+
+ file_path = user_dir / f"{ontology_key}.owl"
+ with open(file_path, "wb") as f:
+ f.write(content)
+
+ ontology_metadata = {
+ "filename": file.filename,
+ "size_bytes": len(content),
+ "uploaded_at": datetime.now(timezone.utc).isoformat(),
+ "description": description,
+ }
+ metadata[ontology_key] = ontology_metadata
+ self._save_metadata(user_dir, metadata)
+
+ return OntologyMetadata(
+ ontology_key=ontology_key,
+ filename=file.filename,
+ size_bytes=len(content),
+ uploaded_at=ontology_metadata["uploaded_at"],
+ description=description,
+ )
+
+ async def upload_ontologies(
+ self, ontology_key: List[str], files: List, user, descriptions: Optional[List[str]] = None
+ ) -> List[OntologyMetadata]:
+ """
+ Upload ontology files with their respective keys.
+
+ Args:
+ ontology_key: List of unique keys for each ontology
+ files: List of UploadFile objects (same length as keys)
+ user: Authenticated user
+ descriptions: Optional list of descriptions for each file
+
+ Returns:
+ List of OntologyMetadata objects for uploaded files
+
+ Raises:
+ ValueError: If keys duplicate, file format invalid, or array lengths don't match
+ """
+ if len(ontology_key) != len(files):
+ raise ValueError("Number of keys must match number of files")
+
+ if len(set(ontology_key)) != len(ontology_key):
+ raise ValueError("Duplicate ontology keys not allowed")
+
+ if descriptions and len(descriptions) != len(files):
+ raise ValueError("Number of descriptions must match number of files")
+
+ results = []
+ user_dir = self._get_user_dir(str(user.id))
+ metadata = self._load_metadata(user_dir)
+
+ for i, (key, file) in enumerate(zip(ontology_key, files)):
+ if key in metadata:
+ raise ValueError(f"Ontology key '{key}' already exists")
+
+ if not file.filename.lower().endswith(".owl"):
+ raise ValueError(f"File '{file.filename}' must be in .owl format")
+
+ content = await file.read()
+ if len(content) > 10 * 1024 * 1024:
+ raise ValueError(f"File '{file.filename}' exceeds 10MB limit")
+
+ file_path = user_dir / f"{key}.owl"
+ with open(file_path, "wb") as f:
+ f.write(content)
+
+ ontology_metadata = {
+ "filename": file.filename,
+ "size_bytes": len(content),
+ "uploaded_at": datetime.now(timezone.utc).isoformat(),
+ "description": descriptions[i] if descriptions else None,
+ }
+ metadata[key] = ontology_metadata
+
+ results.append(
+ OntologyMetadata(
+ ontology_key=key,
+ filename=file.filename,
+ size_bytes=len(content),
+ uploaded_at=ontology_metadata["uploaded_at"],
+ description=descriptions[i] if descriptions else None,
+ )
+ )
+
+ self._save_metadata(user_dir, metadata)
+ return results
+
+ def get_ontology_contents(self, ontology_key: List[str], user) -> List[str]:
+ """
+ Retrieve ontology content for one or more keys.
+
+ Args:
+ ontology_key: List of ontology keys to retrieve (can contain single item)
+ user: Authenticated user
+
+ Returns:
+ List of ontology content strings
+
+ Raises:
+ ValueError: If any ontology key not found
+ """
+ user_dir = self._get_user_dir(str(user.id))
+ metadata = self._load_metadata(user_dir)
+
+ contents = []
+ for key in ontology_key:
+ if key not in metadata:
+ raise ValueError(f"Ontology key '{key}' not found")
+
+ file_path = user_dir / f"{key}.owl"
+ if not file_path.exists():
+ raise ValueError(f"Ontology file for key '{key}' not found")
+
+ with open(file_path, "r", encoding="utf-8") as f:
+ contents.append(f.read())
+ return contents
+
+ def list_ontologies(self, user) -> dict:
+ user_dir = self._get_user_dir(str(user.id))
+ return self._load_metadata(user_dir)
diff --git a/cognee/api/v1/ontologies/routers/__init__.py b/cognee/api/v1/ontologies/routers/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/cognee/api/v1/ontologies/routers/get_ontology_router.py b/cognee/api/v1/ontologies/routers/get_ontology_router.py
new file mode 100644
index 000000000..ee31c683f
--- /dev/null
+++ b/cognee/api/v1/ontologies/routers/get_ontology_router.py
@@ -0,0 +1,107 @@
+from fastapi import APIRouter, File, Form, UploadFile, Depends, HTTPException
+from fastapi.responses import JSONResponse
+from typing import Optional, List
+
+from cognee.modules.users.models import User
+from cognee.modules.users.methods import get_authenticated_user
+from cognee.shared.utils import send_telemetry
+from cognee import __version__ as cognee_version
+from ..ontologies import OntologyService
+
+
+def get_ontology_router() -> APIRouter:
+ router = APIRouter()
+ ontology_service = OntologyService()
+
+ @router.post("", response_model=dict)
+ async def upload_ontology(
+ ontology_key: str = Form(...),
+ ontology_file: List[UploadFile] = File(...),
+ descriptions: Optional[str] = Form(None),
+ user: User = Depends(get_authenticated_user),
+ ):
+ """
+ Upload ontology files with their respective keys for later use in cognify operations.
+
+ Supports both single and multiple file uploads:
+ - Single file: ontology_key=["key"], ontology_file=[file]
+ - Multiple files: ontology_key=["key1", "key2"], ontology_file=[file1, file2]
+
+ ## Request Parameters
+ - **ontology_key** (str): JSON array string of user-defined identifiers for the ontologies
+ - **ontology_file** (List[UploadFile]): OWL format ontology files
+ - **descriptions** (Optional[str]): JSON array string of optional descriptions
+
+ ## Response
+ Returns metadata about uploaded ontologies including keys, filenames, sizes, and upload timestamps.
+
+ ## Error Codes
+ - **400 Bad Request**: Invalid file format, duplicate keys, array length mismatches, file size exceeded
+ - **500 Internal Server Error**: File system or processing errors
+ """
+ send_telemetry(
+ "Ontology Upload API Endpoint Invoked",
+ user.id,
+ additional_properties={
+ "endpoint": "POST /api/v1/ontologies",
+ "cognee_version": cognee_version,
+ },
+ )
+
+ try:
+ import json
+
+ ontology_keys = json.loads(ontology_key)
+ description_list = json.loads(descriptions) if descriptions else None
+
+ if not isinstance(ontology_keys, list):
+ raise ValueError("ontology_key must be a JSON array")
+
+ results = await ontology_service.upload_ontologies(
+ ontology_keys, ontology_file, user, description_list
+ )
+
+ return {
+ "uploaded_ontologies": [
+ {
+ "ontology_key": result.ontology_key,
+ "filename": result.filename,
+ "size_bytes": result.size_bytes,
+ "uploaded_at": result.uploaded_at,
+ "description": result.description,
+ }
+ for result in results
+ ]
+ }
+ except (json.JSONDecodeError, ValueError) as e:
+ return JSONResponse(status_code=400, content={"error": str(e)})
+ except Exception as e:
+ return JSONResponse(status_code=500, content={"error": str(e)})
+
+ @router.get("", response_model=dict)
+ async def list_ontologies(user: User = Depends(get_authenticated_user)):
+ """
+ List all uploaded ontologies for the authenticated user.
+
+ ## Response
+ Returns a dictionary mapping ontology keys to their metadata including filename, size, and upload timestamp.
+
+ ## Error Codes
+ - **500 Internal Server Error**: File system or processing errors
+ """
+ send_telemetry(
+ "Ontology List API Endpoint Invoked",
+ user.id,
+ additional_properties={
+ "endpoint": "GET /api/v1/ontologies",
+ "cognee_version": cognee_version,
+ },
+ )
+
+ try:
+ metadata = ontology_service.list_ontologies(user)
+ return metadata
+ except Exception as e:
+ return JSONResponse(status_code=500, content={"error": str(e)})
+
+ return router
diff --git a/cognee/api/v1/search/search.py b/cognee/api/v1/search/search.py
index d4e5fbbe6..354331c57 100644
--- a/cognee/api/v1/search/search.py
+++ b/cognee/api/v1/search/search.py
@@ -31,6 +31,8 @@ async def search(
only_context: bool = False,
use_combined_context: bool = False,
session_id: Optional[str] = None,
+ wide_search_top_k: Optional[int] = 100,
+ triplet_distance_penalty: Optional[float] = 3.5,
) -> Union[List[SearchResult], CombinedSearchResult]:
"""
Search and query the knowledge graph for insights, information, and connections.
@@ -200,6 +202,8 @@ async def search(
only_context=only_context,
use_combined_context=use_combined_context,
session_id=session_id,
+ wide_search_top_k=wide_search_top_k,
+ triplet_distance_penalty=triplet_distance_penalty,
)
return filtered_search_results
diff --git a/cognee/cli/commands/cognify_command.py b/cognee/cli/commands/cognify_command.py
index 16eaf0454..b89c1f70e 100644
--- a/cognee/cli/commands/cognify_command.py
+++ b/cognee/cli/commands/cognify_command.py
@@ -22,7 +22,7 @@ relationships, and creates semantic connections for enhanced search and reasonin
Processing Pipeline:
1. **Document Classification**: Identifies document types and structures
-2. **Permission Validation**: Ensures user has processing rights
+2. **Permission Validation**: Ensures user has processing rights
3. **Text Chunking**: Breaks content into semantically meaningful segments
4. **Entity Extraction**: Identifies key concepts, people, places, organizations
5. **Relationship Detection**: Discovers connections between entities
@@ -97,6 +97,13 @@ After successful cognify processing, use `cognee search` to query the knowledge
chunker_class = LangchainChunker
except ImportError:
fmt.warning("LangchainChunker not available, using TextChunker")
+ elif args.chunker == "CsvChunker":
+ try:
+ from cognee.modules.chunking.CsvChunker import CsvChunker
+
+ chunker_class = CsvChunker
+ except ImportError:
+ fmt.warning("CsvChunker not available, using TextChunker")
result = await cognee.cognify(
datasets=datasets,
diff --git a/cognee/cli/config.py b/cognee/cli/config.py
index d016608c1..082adbaec 100644
--- a/cognee/cli/config.py
+++ b/cognee/cli/config.py
@@ -26,7 +26,7 @@ SEARCH_TYPE_CHOICES = [
]
# Chunker choices
-CHUNKER_CHOICES = ["TextChunker", "LangchainChunker"]
+CHUNKER_CHOICES = ["TextChunker", "LangchainChunker", "CsvChunker"]
# Output format choices
OUTPUT_FORMAT_CHOICES = ["json", "pretty", "simple"]
diff --git a/cognee/eval_framework/Dockerfile b/cognee/eval_framework/Dockerfile
new file mode 100644
index 000000000..e83be3da4
--- /dev/null
+++ b/cognee/eval_framework/Dockerfile
@@ -0,0 +1,29 @@
+FROM python:3.11-slim
+
+# Set environment variables
+ENV PIP_NO_CACHE_DIR=true
+ENV PATH="${PATH}:/root/.poetry/bin"
+ENV PYTHONPATH=/app
+ENV SKIP_MIGRATIONS=true
+
+# System dependencies
+RUN apt-get update && apt-get install -y \
+ gcc \
+ libpq-dev \
+ git \
+ curl \
+ build-essential \
+ && rm -rf /var/lib/apt/lists/*
+
+WORKDIR /app
+
+COPY pyproject.toml poetry.lock README.md /app/
+
+RUN pip install poetry
+
+RUN poetry config virtualenvs.create false
+
+RUN poetry install --extras distributed --extras evals --extras deepeval --no-root
+
+COPY cognee/ /app/cognee
+COPY distributed/ /app/distributed
diff --git a/cognee/eval_framework/answer_generation/answer_generation_executor.py b/cognee/eval_framework/answer_generation/answer_generation_executor.py
index 6f166657e..29b3ede68 100644
--- a/cognee/eval_framework/answer_generation/answer_generation_executor.py
+++ b/cognee/eval_framework/answer_generation/answer_generation_executor.py
@@ -35,6 +35,16 @@ class AnswerGeneratorExecutor:
retrieval_context = await retriever.get_context(query_text)
search_results = await retriever.get_completion(query_text, retrieval_context)
+ ############
+ #:TODO This is a quick fix until we don't structure retriever results properly but lets not leave it like this...this is needed now due to the changed combined retriever structure..
+ if isinstance(retrieval_context, list):
+ retrieval_context = await retriever.convert_retrieved_objects_to_context(
+ triplets=retrieval_context
+ )
+
+ if isinstance(search_results, str):
+ search_results = [search_results]
+ #############
answer = {
"question": query_text,
"answer": search_results[0],
diff --git a/cognee/eval_framework/answer_generation/run_question_answering_module.py b/cognee/eval_framework/answer_generation/run_question_answering_module.py
index d0a2ebe1e..6b55d84b2 100644
--- a/cognee/eval_framework/answer_generation/run_question_answering_module.py
+++ b/cognee/eval_framework/answer_generation/run_question_answering_module.py
@@ -35,7 +35,7 @@ async def create_and_insert_answers_table(questions_payload):
async def run_question_answering(
- params: dict, system_prompt="answer_simple_question.txt", top_k: Optional[int] = None
+ params: dict, system_prompt="answer_simple_question_benchmark.txt", top_k: Optional[int] = None
) -> List[dict]:
if params.get("answering_questions"):
logger.info("Question answering started...")
diff --git a/cognee/eval_framework/eval_config.py b/cognee/eval_framework/eval_config.py
index 6edcc0454..9e6f26688 100644
--- a/cognee/eval_framework/eval_config.py
+++ b/cognee/eval_framework/eval_config.py
@@ -14,7 +14,7 @@ class EvalConfig(BaseSettings):
# Question answering params
answering_questions: bool = True
- qa_engine: str = "cognee_completion" # Options: 'cognee_completion' or 'cognee_graph_completion' or 'cognee_graph_completion_cot' or 'cognee_graph_completion_context_extension'
+ qa_engine: str = "cognee_graph_completion" # Options: 'cognee_completion' or 'cognee_graph_completion' or 'cognee_graph_completion_cot' or 'cognee_graph_completion_context_extension'
# Evaluation params
evaluating_answers: bool = True
@@ -25,7 +25,7 @@ class EvalConfig(BaseSettings):
"EM",
"f1",
] # Use only 'correctness' for DirectLLM
- deepeval_model: str = "gpt-5-mini"
+ deepeval_model: str = "gpt-4o-mini"
# Metrics params
calculate_metrics: bool = True
diff --git a/cognee/eval_framework/modal_run_eval.py b/cognee/eval_framework/modal_run_eval.py
index aca2686a5..bc2ff77c5 100644
--- a/cognee/eval_framework/modal_run_eval.py
+++ b/cognee/eval_framework/modal_run_eval.py
@@ -2,7 +2,6 @@ import modal
import os
import asyncio
import datetime
-import hashlib
import json
from cognee.shared.logging_utils import get_logger
from cognee.eval_framework.eval_config import EvalConfig
@@ -10,6 +9,9 @@ from cognee.eval_framework.corpus_builder.run_corpus_builder import run_corpus_b
from cognee.eval_framework.answer_generation.run_question_answering_module import (
run_question_answering,
)
+import pathlib
+from os import path
+from modal import Image
from cognee.eval_framework.evaluation.run_evaluation_module import run_evaluation
from cognee.eval_framework.metrics_dashboard import create_dashboard
@@ -38,22 +40,19 @@ def read_and_combine_metrics(eval_params: dict) -> dict:
app = modal.App("modal-run-eval")
-image = (
- modal.Image.from_dockerfile(path="Dockerfile_modal", force_build=False)
- .copy_local_file("pyproject.toml", "pyproject.toml")
- .copy_local_file("poetry.lock", "poetry.lock")
- .env(
- {
- "ENV": os.getenv("ENV"),
- "LLM_API_KEY": os.getenv("LLM_API_KEY"),
- "OPENAI_API_KEY": os.getenv("OPENAI_API_KEY"),
- }
- )
- .pip_install("protobuf", "h2", "deepeval", "gdown", "plotly")
+image = Image.from_dockerfile(
+ path=pathlib.Path(path.join(path.dirname(__file__), "Dockerfile")).resolve(),
+ force_build=False,
+).add_local_python_source("cognee")
+
+
+@app.function(
+ image=image,
+ max_containers=10,
+ timeout=86400,
+ volumes={"/data": vol},
+ secrets=[modal.Secret.from_name("eval_secrets")],
)
-
-
-@app.function(image=image, concurrency_limit=10, timeout=86400, volumes={"/data": vol})
async def modal_run_eval(eval_params=None):
"""Runs evaluation pipeline and returns combined metrics results."""
if eval_params is None:
@@ -105,18 +104,7 @@ async def main():
configs = [
EvalConfig(
task_getter_type="Default",
- number_of_samples_in_corpus=10,
- benchmark="HotPotQA",
- qa_engine="cognee_graph_completion",
- building_corpus_from_scratch=True,
- answering_questions=True,
- evaluating_answers=True,
- calculate_metrics=True,
- dashboard=True,
- ),
- EvalConfig(
- task_getter_type="Default",
- number_of_samples_in_corpus=10,
+ number_of_samples_in_corpus=25,
benchmark="TwoWikiMultiHop",
qa_engine="cognee_graph_completion",
building_corpus_from_scratch=True,
@@ -127,7 +115,7 @@ async def main():
),
EvalConfig(
task_getter_type="Default",
- number_of_samples_in_corpus=10,
+ number_of_samples_in_corpus=25,
benchmark="Musique",
qa_engine="cognee_graph_completion",
building_corpus_from_scratch=True,
diff --git a/cognee/infrastructure/databases/graph/graph_db_interface.py b/cognee/infrastructure/databases/graph/graph_db_interface.py
index 67df1a27c..8f8c96e79 100644
--- a/cognee/infrastructure/databases/graph/graph_db_interface.py
+++ b/cognee/infrastructure/databases/graph/graph_db_interface.py
@@ -398,3 +398,18 @@ class GraphDBInterface(ABC):
- node_id (Union[str, UUID]): Unique identifier of the node for which to retrieve connections.
"""
raise NotImplementedError
+
+ @abstractmethod
+ async def get_filtered_graph_data(
+ self, attribute_filters: List[Dict[str, List[Union[str, int]]]]
+ ) -> Tuple[List[Node], List[EdgeData]]:
+ """
+ Retrieve nodes and edges filtered by the provided attribute criteria.
+
+ Parameters:
+ -----------
+
+ - attribute_filters: A list of dictionaries where keys are attribute names and values
+ are lists of attribute values to filter by.
+ """
+ raise NotImplementedError
diff --git a/cognee/infrastructure/databases/graph/kuzu/adapter.py b/cognee/infrastructure/databases/graph/kuzu/adapter.py
index 8dd160665..9dbc9c1bc 100644
--- a/cognee/infrastructure/databases/graph/kuzu/adapter.py
+++ b/cognee/infrastructure/databases/graph/kuzu/adapter.py
@@ -12,6 +12,7 @@ from contextlib import asynccontextmanager
from concurrent.futures import ThreadPoolExecutor
from typing import Dict, Any, List, Union, Optional, Tuple, Type
+from cognee.exceptions import CogneeValidationError
from cognee.shared.logging_utils import get_logger
from cognee.infrastructure.utils.run_sync import run_sync
from cognee.infrastructure.files.storage import get_file_storage
@@ -1186,6 +1187,11 @@ class KuzuAdapter(GraphDBInterface):
A tuple with two elements: a list of tuples of (node_id, properties) and a list of
tuples of (source_id, target_id, relationship_name, properties).
"""
+
+ import time
+
+ start_time = time.time()
+
try:
nodes_query = """
MATCH (n:Node)
@@ -1249,6 +1255,11 @@ class KuzuAdapter(GraphDBInterface):
},
)
)
+
+ retrieval_time = time.time() - start_time
+ logger.info(
+ f"Retrieved {len(nodes)} nodes and {len(edges)} edges in {retrieval_time:.2f} seconds"
+ )
return formatted_nodes, formatted_edges
except Exception as e:
logger.error(f"Failed to get graph data: {e}")
@@ -1417,6 +1428,92 @@ class KuzuAdapter(GraphDBInterface):
formatted_edges.append((source_id, target_id, rel_type, props))
return formatted_nodes, formatted_edges
+ async def get_id_filtered_graph_data(self, target_ids: list[str]):
+ """
+ Retrieve graph data filtered by specific node IDs, including their direct neighbors
+ and only edges where one endpoint matches those IDs.
+
+ Returns:
+ nodes: List[dict] -> Each dict includes "id" and all node properties
+ edges: List[dict] -> Each dict includes "source", "target", "type", "properties"
+ """
+ import time
+
+ start_time = time.time()
+
+ try:
+ if not target_ids:
+ logger.warning("No target IDs provided for ID-filtered graph retrieval.")
+ return [], []
+
+ if not all(isinstance(x, str) for x in target_ids):
+ raise CogneeValidationError("target_ids must be a list of strings")
+
+ query = """
+ MATCH (n:Node)-[r]->(m:Node)
+ WHERE n.id IN $target_ids OR m.id IN $target_ids
+ RETURN n.id, {
+ name: n.name,
+ type: n.type,
+ properties: n.properties
+ }, m.id, {
+ name: m.name,
+ type: m.type,
+ properties: m.properties
+ }, r.relationship_name, r.properties
+ """
+
+ result = await self.query(query, {"target_ids": target_ids})
+
+ if not result:
+ logger.info("No data returned for the supplied IDs")
+ return [], []
+
+ nodes_dict = {}
+ edges = []
+
+ for n_id, n_props, m_id, m_props, r_type, r_props_raw in result:
+ if n_props.get("properties"):
+ try:
+ additional_props = json.loads(n_props["properties"])
+ n_props.update(additional_props)
+ del n_props["properties"]
+ except json.JSONDecodeError:
+ logger.warning(f"Failed to parse properties JSON for node {n_id}")
+
+ if m_props.get("properties"):
+ try:
+ additional_props = json.loads(m_props["properties"])
+ m_props.update(additional_props)
+ del m_props["properties"]
+ except json.JSONDecodeError:
+ logger.warning(f"Failed to parse properties JSON for node {m_id}")
+
+ nodes_dict[n_id] = (n_id, n_props)
+ nodes_dict[m_id] = (m_id, m_props)
+
+ edge_props = {}
+ if r_props_raw:
+ try:
+ edge_props = json.loads(r_props_raw)
+ except (json.JSONDecodeError, TypeError):
+ logger.warning(f"Failed to parse edge properties for {n_id}->{m_id}")
+
+ source_id = edge_props.get("source_node_id", n_id)
+ target_id = edge_props.get("target_node_id", m_id)
+ edges.append((source_id, target_id, r_type, edge_props))
+
+ retrieval_time = time.time() - start_time
+ logger.info(
+ f"ID-filtered retrieval: {len(nodes_dict)} nodes and {len(edges)} edges in {retrieval_time:.2f}s"
+ )
+
+ return list(nodes_dict.values()), edges
+
+ except Exception as e:
+ logger.error(f"Error during ID-filtered graph data retrieval: {str(e)}")
+ raise
+
async def get_graph_metrics(self, include_optional=False) -> Dict[str, Any]:
"""
Get metrics on graph structure and connectivity.
diff --git a/cognee/infrastructure/databases/graph/neo4j_driver/adapter.py b/cognee/infrastructure/databases/graph/neo4j_driver/adapter.py
index 6216e107e..f3bb8e173 100644
--- a/cognee/infrastructure/databases/graph/neo4j_driver/adapter.py
+++ b/cognee/infrastructure/databases/graph/neo4j_driver/adapter.py
@@ -964,6 +964,63 @@ class Neo4jAdapter(GraphDBInterface):
logger.error(f"Error during graph data retrieval: {str(e)}")
raise
+ async def get_id_filtered_graph_data(self, target_ids: list[str]):
+ """
+ Retrieve graph data filtered by specific node IDs, including their direct neighbors
+ and only edges where one endpoint matches those IDs.
+
+ This version uses a single Cypher query for efficiency.
+ """
+ import time
+
+ start_time = time.time()
+
+ try:
+ if not target_ids:
+ logger.warning("No target IDs provided for ID-filtered graph retrieval.")
+ return [], []
+
+ query = """
+ MATCH ()-[r]-()
+ WHERE startNode(r).id IN $target_ids
+ OR endNode(r).id IN $target_ids
+ WITH DISTINCT r, startNode(r) AS a, endNode(r) AS b
+ RETURN
+ properties(a) AS n_properties,
+ properties(b) AS m_properties,
+ type(r) AS type,
+ properties(r) AS properties
+ """
+
+ result = await self.query(query, {"target_ids": target_ids})
+
+ nodes_dict = {}
+ edges = []
+
+ for record in result:
+ n_props = record["n_properties"]
+ m_props = record["m_properties"]
+ r_props = record["properties"]
+ r_type = record["type"]
+
+ nodes_dict[n_props["id"]] = (n_props["id"], n_props)
+ nodes_dict[m_props["id"]] = (m_props["id"], m_props)
+
+ source_id = r_props.get("source_node_id", n_props["id"])
+ target_id = r_props.get("target_node_id", m_props["id"])
+ edges.append((source_id, target_id, r_type, r_props))
+
+ retrieval_time = time.time() - start_time
+ logger.info(
+ f"ID-filtered retrieval: {len(nodes_dict)} nodes and {len(edges)} edges in {retrieval_time:.2f}s"
+ )
+
+ return list(nodes_dict.values()), edges
+
+ except Exception as e:
+ logger.error(f"Error during ID-filtered graph data retrieval: {str(e)}")
+ raise
+
async def get_nodeset_subgraph(
self, node_type: Type[Any], node_name: List[str]
) -> Tuple[List[Tuple[int, dict]], List[Tuple[int, int, str, dict]]]:
diff --git a/cognee/infrastructure/files/utils/guess_file_type.py b/cognee/infrastructure/files/utils/guess_file_type.py
index 78b20c93d..4bc96fe80 100644
--- a/cognee/infrastructure/files/utils/guess_file_type.py
+++ b/cognee/infrastructure/files/utils/guess_file_type.py
@@ -55,6 +55,10 @@ def guess_file_type(file: BinaryIO, name: Optional[str] = None) -> filetype.Type
file_type = Type("text/plain", "txt")
return file_type
+ if ext in [".csv"]:
+ file_type = Type("text/csv", "csv")
+ return file_type
+
file_type = filetype.guess(file)
# If file type could not be determined consider it a plain text file as they don't have magic number encoding
diff --git a/cognee/infrastructure/llm/config.py b/cognee/infrastructure/llm/config.py
index 8fd196eaf..2e300dc0c 100644
--- a/cognee/infrastructure/llm/config.py
+++ b/cognee/infrastructure/llm/config.py
@@ -38,6 +38,7 @@ class LLMConfig(BaseSettings):
"""
structured_output_framework: str = "instructor"
+ llm_instructor_mode: str = ""
llm_provider: str = "openai"
llm_model: str = "openai/gpt-5-mini"
llm_endpoint: str = ""
@@ -181,6 +182,7 @@ class LLMConfig(BaseSettings):
instance.
"""
return {
+ "llm_instructor_mode": self.llm_instructor_mode.lower(),
"provider": self.llm_provider,
"model": self.llm_model,
"endpoint": self.llm_endpoint,
diff --git a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/anthropic/adapter.py b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/anthropic/adapter.py
index bf19d6e86..dbf0dfbea 100644
--- a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/anthropic/adapter.py
+++ b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/anthropic/adapter.py
@@ -28,13 +28,16 @@ class AnthropicAdapter(LLMInterface):
name = "Anthropic"
model: str
+ default_instructor_mode = "anthropic_tools"
- def __init__(self, max_completion_tokens: int, model: str = None):
+ def __init__(self, max_completion_tokens: int, model: str = None, instructor_mode: str = None):
import anthropic
+ self.instructor_mode = instructor_mode if instructor_mode else self.default_instructor_mode
+
self.aclient = instructor.patch(
create=anthropic.AsyncAnthropic(api_key=get_llm_config().llm_api_key).messages.create,
- mode=instructor.Mode.ANTHROPIC_TOOLS,
+ mode=instructor.Mode(self.instructor_mode),
)
self.model = model
diff --git a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/gemini/adapter.py b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/gemini/adapter.py
index 1187e0cad..226f291d7 100644
--- a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/gemini/adapter.py
+++ b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/gemini/adapter.py
@@ -41,6 +41,7 @@ class GeminiAdapter(LLMInterface):
name: str
model: str
api_key: str
+ default_instructor_mode = "json_mode"
def __init__(
self,
@@ -49,6 +50,7 @@ class GeminiAdapter(LLMInterface):
model: str,
api_version: str,
max_completion_tokens: int,
+ instructor_mode: str = None,
fallback_model: str = None,
fallback_api_key: str = None,
fallback_endpoint: str = None,
@@ -63,7 +65,11 @@ class GeminiAdapter(LLMInterface):
self.fallback_api_key = fallback_api_key
self.fallback_endpoint = fallback_endpoint
- self.aclient = instructor.from_litellm(litellm.acompletion, mode=instructor.Mode.JSON)
+ self.instructor_mode = instructor_mode if instructor_mode else self.default_instructor_mode
+
+ self.aclient = instructor.from_litellm(
+ litellm.acompletion, mode=instructor.Mode(self.instructor_mode)
+ )
@retry(
stop=stop_after_delay(128),
diff --git a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/generic_llm_api/adapter.py b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/generic_llm_api/adapter.py
index 8bbbaa2cc..9d7f25fc5 100644
--- a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/generic_llm_api/adapter.py
+++ b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/generic_llm_api/adapter.py
@@ -41,6 +41,7 @@ class GenericAPIAdapter(LLMInterface):
name: str
model: str
api_key: str
+ default_instructor_mode = "json_mode"
def __init__(
self,
@@ -49,6 +50,7 @@ class GenericAPIAdapter(LLMInterface):
model: str,
name: str,
max_completion_tokens: int,
+ instructor_mode: str = None,
fallback_model: str = None,
fallback_api_key: str = None,
fallback_endpoint: str = None,
@@ -63,7 +65,11 @@ class GenericAPIAdapter(LLMInterface):
self.fallback_api_key = fallback_api_key
self.fallback_endpoint = fallback_endpoint
- self.aclient = instructor.from_litellm(litellm.acompletion, mode=instructor.Mode.JSON)
+ self.instructor_mode = instructor_mode if instructor_mode else self.default_instructor_mode
+
+ self.aclient = instructor.from_litellm(
+ litellm.acompletion, mode=instructor.Mode(self.instructor_mode)
+ )
@retry(
stop=stop_after_delay(128),
diff --git a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/get_llm_client.py b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/get_llm_client.py
index c7dcecc56..39558f36d 100644
--- a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/get_llm_client.py
+++ b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/get_llm_client.py
@@ -81,6 +81,7 @@ def get_llm_client(raise_api_key_error: bool = True):
model=llm_config.llm_model,
transcription_model=llm_config.transcription_model,
max_completion_tokens=max_completion_tokens,
+ instructor_mode=llm_config.llm_instructor_mode.lower(),
streaming=llm_config.llm_streaming,
fallback_api_key=llm_config.fallback_api_key,
fallback_endpoint=llm_config.fallback_endpoint,
@@ -101,6 +102,7 @@ def get_llm_client(raise_api_key_error: bool = True):
llm_config.llm_model,
"Ollama",
max_completion_tokens=max_completion_tokens,
+ instructor_mode=llm_config.llm_instructor_mode.lower(),
)
elif provider == LLMProvider.ANTHROPIC:
@@ -109,7 +111,9 @@ def get_llm_client(raise_api_key_error: bool = True):
)
return AnthropicAdapter(
- max_completion_tokens=max_completion_tokens, model=llm_config.llm_model
+ max_completion_tokens=max_completion_tokens,
+ model=llm_config.llm_model,
+ instructor_mode=llm_config.llm_instructor_mode.lower(),
)
elif provider == LLMProvider.CUSTOM:
@@ -126,6 +130,7 @@ def get_llm_client(raise_api_key_error: bool = True):
llm_config.llm_model,
"Custom",
max_completion_tokens=max_completion_tokens,
+ instructor_mode=llm_config.llm_instructor_mode.lower(),
fallback_api_key=llm_config.fallback_api_key,
fallback_endpoint=llm_config.fallback_endpoint,
fallback_model=llm_config.fallback_model,
@@ -145,6 +150,7 @@ def get_llm_client(raise_api_key_error: bool = True):
max_completion_tokens=max_completion_tokens,
endpoint=llm_config.llm_endpoint,
api_version=llm_config.llm_api_version,
+ instructor_mode=llm_config.llm_instructor_mode.lower(),
)
elif provider == LLMProvider.MISTRAL:
@@ -160,21 +166,7 @@ def get_llm_client(raise_api_key_error: bool = True):
model=llm_config.llm_model,
max_completion_tokens=max_completion_tokens,
endpoint=llm_config.llm_endpoint,
- )
-
- elif provider == LLMProvider.MISTRAL:
- if llm_config.llm_api_key is None:
- raise LLMAPIKeyNotSetError()
-
- from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.mistral.adapter import (
- MistralAdapter,
- )
-
- return MistralAdapter(
- api_key=llm_config.llm_api_key,
- model=llm_config.llm_model,
- max_completion_tokens=max_completion_tokens,
- endpoint=llm_config.llm_endpoint,
+ instructor_mode=llm_config.llm_instructor_mode.lower(),
)
else:
diff --git a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/mistral/adapter.py b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/mistral/adapter.py
index 78a3cbff5..355cdae0b 100644
--- a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/mistral/adapter.py
+++ b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/mistral/adapter.py
@@ -37,16 +37,26 @@ class MistralAdapter(LLMInterface):
model: str
api_key: str
max_completion_tokens: int
+ default_instructor_mode = "mistral_tools"
- def __init__(self, api_key: str, model: str, max_completion_tokens: int, endpoint: str = None):
+ def __init__(
+ self,
+ api_key: str,
+ model: str,
+ max_completion_tokens: int,
+ endpoint: str = None,
+ instructor_mode: str = None,
+ ):
from mistralai import Mistral
self.model = model
self.max_completion_tokens = max_completion_tokens
+ self.instructor_mode = instructor_mode if instructor_mode else self.default_instructor_mode
+
self.aclient = instructor.from_litellm(
litellm.acompletion,
- mode=instructor.Mode.MISTRAL_TOOLS,
+ mode=instructor.Mode(self.instructor_mode),
api_key=get_llm_config().llm_api_key,
)
diff --git a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/ollama/adapter.py b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/ollama/adapter.py
index 9c3d185aa..aabd19867 100644
--- a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/ollama/adapter.py
+++ b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/ollama/adapter.py
@@ -42,8 +42,16 @@ class OllamaAPIAdapter(LLMInterface):
- aclient
"""
+ default_instructor_mode = "json_mode"
+
def __init__(
- self, endpoint: str, api_key: str, model: str, name: str, max_completion_tokens: int
+ self,
+ endpoint: str,
+ api_key: str,
+ model: str,
+ name: str,
+ max_completion_tokens: int,
+ instructor_mode: str = None,
):
self.name = name
self.model = model
@@ -51,8 +59,11 @@ class OllamaAPIAdapter(LLMInterface):
self.endpoint = endpoint
self.max_completion_tokens = max_completion_tokens
+ self.instructor_mode = instructor_mode if instructor_mode else self.default_instructor_mode
+
self.aclient = instructor.from_openai(
- OpenAI(base_url=self.endpoint, api_key=self.api_key), mode=instructor.Mode.JSON
+ OpenAI(base_url=self.endpoint, api_key=self.api_key),
+ mode=instructor.Mode(self.instructor_mode),
)
@retry(
diff --git a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/openai/adapter.py b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/openai/adapter.py
index 305b426b8..778c8eec7 100644
--- a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/openai/adapter.py
+++ b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/openai/adapter.py
@@ -56,6 +56,7 @@ class OpenAIAdapter(LLMInterface):
model: str
api_key: str
api_version: str
+ default_instructor_mode = "json_schema_mode"
MAX_RETRIES = 5
@@ -69,19 +70,21 @@ class OpenAIAdapter(LLMInterface):
model: str,
transcription_model: str,
max_completion_tokens: int,
+ instructor_mode: str = None,
streaming: bool = False,
fallback_model: str = None,
fallback_api_key: str = None,
fallback_endpoint: str = None,
):
+ self.instructor_mode = instructor_mode if instructor_mode else self.default_instructor_mode
# TODO: With gpt5 series models OpenAI expects JSON_SCHEMA as a mode for structured outputs.
# Make sure all new gpt models will work with this mode as well.
if "gpt-5" in model:
self.aclient = instructor.from_litellm(
- litellm.acompletion, mode=instructor.Mode.JSON_SCHEMA
+ litellm.acompletion, mode=instructor.Mode(self.instructor_mode)
)
self.client = instructor.from_litellm(
- litellm.completion, mode=instructor.Mode.JSON_SCHEMA
+ litellm.completion, mode=instructor.Mode(self.instructor_mode)
)
else:
self.aclient = instructor.from_litellm(litellm.acompletion)
diff --git a/cognee/infrastructure/loaders/LoaderEngine.py b/cognee/infrastructure/loaders/LoaderEngine.py
index f9511e7c5..4a363a0e6 100644
--- a/cognee/infrastructure/loaders/LoaderEngine.py
+++ b/cognee/infrastructure/loaders/LoaderEngine.py
@@ -31,6 +31,7 @@ class LoaderEngine:
"pypdf_loader",
"image_loader",
"audio_loader",
+ "csv_loader",
"unstructured_loader",
"advanced_pdf_loader",
]
diff --git a/cognee/infrastructure/loaders/core/__init__.py b/cognee/infrastructure/loaders/core/__init__.py
index 8a2df80f9..09819fbd2 100644
--- a/cognee/infrastructure/loaders/core/__init__.py
+++ b/cognee/infrastructure/loaders/core/__init__.py
@@ -3,5 +3,6 @@
from .text_loader import TextLoader
from .audio_loader import AudioLoader
from .image_loader import ImageLoader
+from .csv_loader import CsvLoader
-__all__ = ["TextLoader", "AudioLoader", "ImageLoader"]
+__all__ = ["TextLoader", "AudioLoader", "ImageLoader", "CsvLoader"]
diff --git a/cognee/infrastructure/loaders/core/csv_loader.py b/cognee/infrastructure/loaders/core/csv_loader.py
new file mode 100644
index 000000000..a314a7a24
--- /dev/null
+++ b/cognee/infrastructure/loaders/core/csv_loader.py
@@ -0,0 +1,93 @@
+import os
+from typing import List
+import csv
+from cognee.infrastructure.loaders.LoaderInterface import LoaderInterface
+from cognee.infrastructure.files.storage import get_file_storage, get_storage_config
+from cognee.infrastructure.files.utils.get_file_metadata import get_file_metadata
+
+
+class CsvLoader(LoaderInterface):
+ """
+ Core CSV file loader that handles basic CSV file formats.
+ """
+
+ @property
+ def supported_extensions(self) -> List[str]:
+ """Supported text file extensions."""
+ return [
+ "csv",
+ ]
+
+ @property
+ def supported_mime_types(self) -> List[str]:
+ """Supported MIME types for text content."""
+ return [
+ "text/csv",
+ ]
+
+ @property
+ def loader_name(self) -> str:
+ """Unique identifier for this loader."""
+ return "csv_loader"
+
+ def can_handle(self, extension: str, mime_type: str) -> bool:
+ """
+ Check if this loader can handle the given file.
+
+ Args:
+ extension: File extension
+ mime_type: Optional MIME type
+
+ Returns:
+ True if file can be handled, False otherwise
+ """
+ if extension in self.supported_extensions and mime_type in self.supported_mime_types:
+ return True
+
+ return False
+
+ async def load(self, file_path: str, encoding: str = "utf-8", **kwargs):
+ """
+ Load and process the csv file.
+
+ Args:
+ file_path: Path to the file to load
+ encoding: Text encoding to use (default: utf-8)
+ **kwargs: Additional configuration (unused)
+
+ Returns:
+ LoaderResult containing the file content and metadata
+
+ Raises:
+ FileNotFoundError: If file doesn't exist
+ UnicodeDecodeError: If file cannot be decoded with specified encoding
+ OSError: If file cannot be read
+ """
+ if not os.path.exists(file_path):
+ raise FileNotFoundError(f"File not found: {file_path}")
+
+ with open(file_path, "rb") as f:
+ file_metadata = await get_file_metadata(f)
+ # Name ingested file of current loader based on original file content hash
+ storage_file_name = "text_" + file_metadata["content_hash"] + ".txt"
+
+ row_texts = []
+ row_index = 1
+
+ with open(file_path, "r", encoding=encoding, newline="") as file:
+ reader = csv.DictReader(file)
+ for row in reader:
+ pairs = [f"{str(k)}: {str(v)}" for k, v in row.items()]
+ row_text = ", ".join(pairs)
+ row_texts.append(f"Row {row_index}:\n{row_text}\n")
+ row_index += 1
+
+ content = "\n".join(row_texts)
+
+ storage_config = get_storage_config()
+ data_root_directory = storage_config["data_root_directory"]
+ storage = get_file_storage(data_root_directory)
+
+ full_file_path = await storage.store(storage_file_name, content)
+
+ return full_file_path
diff --git a/cognee/infrastructure/loaders/core/text_loader.py b/cognee/infrastructure/loaders/core/text_loader.py
index a6f94be9b..e478edb22 100644
--- a/cognee/infrastructure/loaders/core/text_loader.py
+++ b/cognee/infrastructure/loaders/core/text_loader.py
@@ -16,7 +16,7 @@ class TextLoader(LoaderInterface):
@property
def supported_extensions(self) -> List[str]:
"""Supported text file extensions."""
- return ["txt", "md", "csv", "json", "xml", "yaml", "yml", "log"]
+ return ["txt", "md", "json", "xml", "yaml", "yml", "log"]
@property
def supported_mime_types(self) -> List[str]:
@@ -24,7 +24,6 @@ class TextLoader(LoaderInterface):
return [
"text/plain",
"text/markdown",
- "text/csv",
"application/json",
"text/xml",
"application/xml",
diff --git a/cognee/infrastructure/loaders/external/advanced_pdf_loader.py b/cognee/infrastructure/loaders/external/advanced_pdf_loader.py
index 6d1412b77..4b3ba296a 100644
--- a/cognee/infrastructure/loaders/external/advanced_pdf_loader.py
+++ b/cognee/infrastructure/loaders/external/advanced_pdf_loader.py
@@ -227,12 +227,3 @@ class AdvancedPdfLoader(LoaderInterface):
if value is None:
return ""
return str(value).replace("\xa0", " ").strip()
-
-
-if __name__ == "__main__":
- loader = AdvancedPdfLoader()
- asyncio.run(
- loader.load(
- "/Users/xiaotao/work/cognee/cognee/infrastructure/loaders/external/attention_is_all_you_need.pdf"
- )
- )
diff --git a/cognee/infrastructure/loaders/supported_loaders.py b/cognee/infrastructure/loaders/supported_loaders.py
index 156253b53..2b8c3e0b4 100644
--- a/cognee/infrastructure/loaders/supported_loaders.py
+++ b/cognee/infrastructure/loaders/supported_loaders.py
@@ -1,5 +1,5 @@
from cognee.infrastructure.loaders.external import PyPdfLoader
-from cognee.infrastructure.loaders.core import TextLoader, AudioLoader, ImageLoader
+from cognee.infrastructure.loaders.core import TextLoader, AudioLoader, ImageLoader, CsvLoader
# Registry for loader implementations
supported_loaders = {
@@ -7,6 +7,7 @@ supported_loaders = {
TextLoader.loader_name: TextLoader,
ImageLoader.loader_name: ImageLoader,
AudioLoader.loader_name: AudioLoader,
+ CsvLoader.loader_name: CsvLoader,
}
# Try adding optional loaders
diff --git a/cognee/modules/chunking/CsvChunker.py b/cognee/modules/chunking/CsvChunker.py
new file mode 100644
index 000000000..4ba4a969e
--- /dev/null
+++ b/cognee/modules/chunking/CsvChunker.py
@@ -0,0 +1,35 @@
+from cognee.shared.logging_utils import get_logger
+
+
+from cognee.tasks.chunks import chunk_by_row
+from cognee.modules.chunking.Chunker import Chunker
+from .models.DocumentChunk import DocumentChunk
+
+logger = get_logger()
+
+
+class CsvChunker(Chunker):
+ async def read(self):
+ async for content_text in self.get_text():
+ if content_text is None:
+ continue
+
+ for chunk_data in chunk_by_row(content_text, self.max_chunk_size):
+ if chunk_data["chunk_size"] <= self.max_chunk_size:
+ yield DocumentChunk(
+ id=chunk_data["chunk_id"],
+ text=chunk_data["text"],
+ chunk_size=chunk_data["chunk_size"],
+ is_part_of=self.document,
+ chunk_index=self.chunk_index,
+ cut_type=chunk_data["cut_type"],
+ contains=[],
+ metadata={
+ "index_fields": ["text"],
+ },
+ )
+ self.chunk_index += 1
+ else:
+ raise ValueError(
+ f"Chunk size is larger than the maximum chunk size {self.max_chunk_size}"
+ )
diff --git a/cognee/modules/data/processing/document_types/CsvDocument.py b/cognee/modules/data/processing/document_types/CsvDocument.py
new file mode 100644
index 000000000..3381275bd
--- /dev/null
+++ b/cognee/modules/data/processing/document_types/CsvDocument.py
@@ -0,0 +1,33 @@
+import io
+import csv
+from typing import Type
+
+from cognee.modules.chunking.Chunker import Chunker
+from cognee.infrastructure.files.utils.open_data_file import open_data_file
+from .Document import Document
+
+
+class CsvDocument(Document):
+ type: str = "csv"
+ mime_type: str = "text/csv"
+
+ async def read(self, chunker_cls: Type[Chunker], max_chunk_size: int):
+ async def get_text():
+ async with open_data_file(
+ self.raw_data_location, mode="r", encoding="utf-8", newline=""
+ ) as file:
+ content = file.read()
+ file_like_obj = io.StringIO(content)
+ reader = csv.DictReader(file_like_obj)
+
+ for row in reader:
+ pairs = [f"{str(k)}: {str(v)}" for k, v in row.items()]
+ row_text = ", ".join(pairs)
+ if not row_text.strip():
+ break
+ yield row_text
+
+ chunker = chunker_cls(self, max_chunk_size=max_chunk_size, get_text=get_text)
+
+ async for chunk in chunker.read():
+ yield chunk
diff --git a/cognee/modules/data/processing/document_types/__init__.py b/cognee/modules/data/processing/document_types/__init__.py
index 2e862f4ba..133dd53f8 100644
--- a/cognee/modules/data/processing/document_types/__init__.py
+++ b/cognee/modules/data/processing/document_types/__init__.py
@@ -4,3 +4,4 @@ from .TextDocument import TextDocument
from .ImageDocument import ImageDocument
from .AudioDocument import AudioDocument
from .UnstructuredDocument import UnstructuredDocument
+from .CsvDocument import CsvDocument
diff --git a/cognee/modules/graph/cognee_graph/CogneeGraph.py b/cognee/modules/graph/cognee_graph/CogneeGraph.py
index cb7562422..2e0b82e8d 100644
--- a/cognee/modules/graph/cognee_graph/CogneeGraph.py
+++ b/cognee/modules/graph/cognee_graph/CogneeGraph.py
@@ -56,6 +56,68 @@ class CogneeGraph(CogneeAbstractGraph):
def get_edges(self) -> List[Edge]:
return self.edges
+ async def _get_nodeset_subgraph(
+ self,
+ adapter,
+ node_type,
+ node_name,
+ ):
+ """Retrieve subgraph based on node type and name."""
+ logger.info("Retrieving graph filtered by node type and node name (NodeSet).")
+ nodes_data, edges_data = await adapter.get_nodeset_subgraph(
+ node_type=node_type, node_name=node_name
+ )
+ if not nodes_data or not edges_data:
+ raise EntityNotFoundError(
+ message="Nodeset does not exist, or empty nodeset projected from the database."
+ )
+ return nodes_data, edges_data
+
+ async def _get_full_or_id_filtered_graph(
+ self,
+ adapter,
+ relevant_ids_to_filter,
+ ):
+ """Retrieve full or ID-filtered graph with fallback."""
+ if relevant_ids_to_filter is None:
+ logger.info("Retrieving full graph.")
+ nodes_data, edges_data = await adapter.get_graph_data()
+ if not nodes_data or not edges_data:
+ raise EntityNotFoundError(message="Empty graph projected from the database.")
+ return nodes_data, edges_data
+
+ get_graph_data_fn = getattr(adapter, "get_id_filtered_graph_data", adapter.get_graph_data)
+ if getattr(adapter.__class__, "get_id_filtered_graph_data", None):
+ logger.info("Retrieving ID-filtered graph from database.")
+ nodes_data, edges_data = await get_graph_data_fn(target_ids=relevant_ids_to_filter)
+ else:
+ logger.info("Retrieving full graph from database.")
+ nodes_data, edges_data = await get_graph_data_fn()
+ if hasattr(adapter, "get_id_filtered_graph_data") and (not nodes_data or not edges_data):
+ logger.warning(
+ "Id filtered graph returned empty, falling back to full graph retrieval."
+ )
+ logger.info("Retrieving full graph")
+ nodes_data, edges_data = await adapter.get_graph_data()
+
+ if not nodes_data or not edges_data:
+ raise EntityNotFoundError("Empty graph projected from the database.")
+ return nodes_data, edges_data
+
+ async def _get_filtered_graph(
+ self,
+ adapter,
+ memory_fragment_filter,
+ ):
+ """Retrieve graph filtered by attributes."""
+ logger.info("Retrieving graph filtered by memory fragment")
+ nodes_data, edges_data = await adapter.get_filtered_graph_data(
+ attribute_filters=memory_fragment_filter
+ )
+ if not nodes_data or not edges_data:
+ raise EntityNotFoundError(message="Empty filtered graph projected from the database.")
+ return nodes_data, edges_data
+
async def project_graph_from_db(
self,
adapter: Union[GraphDBInterface],
@@ -67,40 +129,39 @@ class CogneeGraph(CogneeAbstractGraph):
memory_fragment_filter=[],
node_type: Optional[Type] = None,
node_name: Optional[List[str]] = None,
+ relevant_ids_to_filter: Optional[List[str]] = None,
+ triplet_distance_penalty: float = 3.5,
) -> None:
if node_dimension < 1 or edge_dimension < 1:
raise InvalidDimensionsError()
try:
+ if node_type is not None and node_name not in [None, [], ""]:
+ nodes_data, edges_data = await self._get_nodeset_subgraph(
+ adapter, node_type, node_name
+ )
+ elif len(memory_fragment_filter) == 0:
+ nodes_data, edges_data = await self._get_full_or_id_filtered_graph(
+ adapter, relevant_ids_to_filter
+ )
+ else:
+ nodes_data, edges_data = await self._get_filtered_graph(
+ adapter, memory_fragment_filter
+ )
+
import time
start_time = time.time()
-
- # Determine projection strategy
- if node_type is not None and node_name not in [None, [], ""]:
- nodes_data, edges_data = await adapter.get_nodeset_subgraph(
- node_type=node_type, node_name=node_name
- )
- if not nodes_data or not edges_data:
- raise EntityNotFoundError(
- message="Nodeset does not exist, or empty nodetes projected from the database."
- )
- elif len(memory_fragment_filter) == 0:
- nodes_data, edges_data = await adapter.get_graph_data()
- if not nodes_data or not edges_data:
- raise EntityNotFoundError(message="Empty graph projected from the database.")
- else:
- nodes_data, edges_data = await adapter.get_filtered_graph_data(
- attribute_filters=memory_fragment_filter
- )
- if not nodes_data or not edges_data:
- raise EntityNotFoundError(
- message="Empty filtered graph projected from the database."
- )
-
# Process nodes
for node_id, properties in nodes_data:
node_attributes = {key: properties.get(key) for key in node_properties_to_project}
- self.add_node(Node(str(node_id), node_attributes, dimension=node_dimension))
+ self.add_node(
+ Node(
+ str(node_id),
+ node_attributes,
+ dimension=node_dimension,
+ node_penalty=triplet_distance_penalty,
+ )
+ )
# Process edges
for source_id, target_id, relationship_type, properties in edges_data:
@@ -118,6 +179,7 @@ class CogneeGraph(CogneeAbstractGraph):
attributes=edge_attributes,
directed=directed,
dimension=edge_dimension,
+ edge_penalty=triplet_distance_penalty,
)
self.add_edge(edge)
diff --git a/cognee/modules/graph/cognee_graph/CogneeGraphElements.py b/cognee/modules/graph/cognee_graph/CogneeGraphElements.py
index 0ca9c4fb9..62ef8d9fd 100644
--- a/cognee/modules/graph/cognee_graph/CogneeGraphElements.py
+++ b/cognee/modules/graph/cognee_graph/CogneeGraphElements.py
@@ -20,13 +20,17 @@ class Node:
status: np.ndarray
def __init__(
- self, node_id: str, attributes: Optional[Dict[str, Any]] = None, dimension: int = 1
+ self,
+ node_id: str,
+ attributes: Optional[Dict[str, Any]] = None,
+ dimension: int = 1,
+ node_penalty: float = 3.5,
):
if dimension <= 0:
raise InvalidDimensionsError()
self.id = node_id
self.attributes = attributes if attributes is not None else {}
- self.attributes["vector_distance"] = float("inf")
+ self.attributes["vector_distance"] = node_penalty
self.skeleton_neighbours = []
self.skeleton_edges = []
self.status = np.ones(dimension, dtype=int)
@@ -105,13 +109,14 @@ class Edge:
attributes: Optional[Dict[str, Any]] = None,
directed: bool = True,
dimension: int = 1,
+ edge_penalty: float = 3.5,
):
if dimension <= 0:
raise InvalidDimensionsError()
self.node1 = node1
self.node2 = node2
self.attributes = attributes if attributes is not None else {}
- self.attributes["vector_distance"] = float("inf")
+ self.attributes["vector_distance"] = edge_penalty
self.directed = directed
self.status = np.ones(dimension, dtype=int)
diff --git a/cognee/modules/notebooks/operations/run_in_local_sandbox.py b/cognee/modules/notebooks/operations/run_in_local_sandbox.py
index 071deafb7..46499186e 100644
--- a/cognee/modules/notebooks/operations/run_in_local_sandbox.py
+++ b/cognee/modules/notebooks/operations/run_in_local_sandbox.py
@@ -2,6 +2,8 @@ import io
import sys
import traceback
+import cognee
+
def wrap_in_async_handler(user_code: str) -> str:
return (
@@ -34,6 +36,7 @@ def run_in_local_sandbox(code, environment=None, loop=None):
environment["print"] = customPrintFunction
environment["running_loop"] = loop
+ environment["cognee"] = cognee
try:
exec(code, environment)
diff --git a/cognee/modules/ontology/rdf_xml/RDFLibOntologyResolver.py b/cognee/modules/ontology/rdf_xml/RDFLibOntologyResolver.py
index 45e32936a..34d7a946a 100644
--- a/cognee/modules/ontology/rdf_xml/RDFLibOntologyResolver.py
+++ b/cognee/modules/ontology/rdf_xml/RDFLibOntologyResolver.py
@@ -2,7 +2,7 @@ import os
import difflib
from cognee.shared.logging_utils import get_logger
from collections import deque
-from typing import List, Tuple, Dict, Optional, Any, Union
+from typing import List, Tuple, Dict, Optional, Any, Union, IO
from rdflib import Graph, URIRef, RDF, RDFS, OWL
from cognee.modules.ontology.exceptions import (
@@ -26,44 +26,76 @@ class RDFLibOntologyResolver(BaseOntologyResolver):
def __init__(
self,
- ontology_file: Optional[Union[str, List[str]]] = None,
+ ontology_file: Optional[Union[str, List[str], IO, List[IO]]] = None,
matching_strategy: Optional[MatchingStrategy] = None,
) -> None:
super().__init__(matching_strategy)
self.ontology_file = ontology_file
try:
- files_to_load = []
+ self.graph = None
if ontology_file is not None:
- if isinstance(ontology_file, str):
+ files_to_load = []
+ file_objects = []
+
+ if hasattr(ontology_file, "read"):
+ file_objects = [ontology_file]
+ elif isinstance(ontology_file, str):
files_to_load = [ontology_file]
elif isinstance(ontology_file, list):
- files_to_load = ontology_file
+ if all(hasattr(item, "read") for item in ontology_file):
+ file_objects = ontology_file
+ else:
+ files_to_load = ontology_file
else:
raise ValueError(
- f"ontology_file must be a string, list of strings, or None. Got: {type(ontology_file)}"
+ f"ontology_file must be a string, list of strings, file-like object, list of file-like objects, or None. Got: {type(ontology_file)}"
)
- if files_to_load:
- self.graph = Graph()
- loaded_files = []
- for file_path in files_to_load:
- if os.path.exists(file_path):
- self.graph.parse(file_path)
- loaded_files.append(file_path)
- logger.info("Ontology loaded successfully from file: %s", file_path)
- else:
- logger.warning(
- "Ontology file '%s' not found. Skipping this file.",
- file_path,
+ if file_objects:
+ self.graph = Graph()
+ loaded_objects = []
+ for file_obj in file_objects:
+ try:
+ content = file_obj.read()
+ self.graph.parse(data=content, format="xml")
+ loaded_objects.append(file_obj)
+ logger.info("Ontology loaded successfully from file object")
+ except Exception as e:
+ logger.warning("Failed to parse ontology file object: %s", str(e))
+
+ if not loaded_objects:
+ logger.info(
+ "No valid ontology file objects found. No owl ontology will be attached to the graph."
)
+ self.graph = None
+ else:
+ logger.info("Total ontology file objects loaded: %d", len(loaded_objects))
- if not loaded_files:
- logger.info(
- "No valid ontology files found. No owl ontology will be attached to the graph."
- )
- self.graph = None
+ elif files_to_load:
+ self.graph = Graph()
+ loaded_files = []
+ for file_path in files_to_load:
+ if os.path.exists(file_path):
+ self.graph.parse(file_path)
+ loaded_files.append(file_path)
+ logger.info("Ontology loaded successfully from file: %s", file_path)
+ else:
+ logger.warning(
+ "Ontology file '%s' not found. Skipping this file.",
+ file_path,
+ )
+
+ if not loaded_files:
+ logger.info(
+ "No valid ontology files found. No owl ontology will be attached to the graph."
+ )
+ self.graph = None
+ else:
+ logger.info("Total ontology files loaded: %d", len(loaded_files))
else:
- logger.info("Total ontology files loaded: %d", len(loaded_files))
+ logger.info(
+ "No ontology file provided. No owl ontology will be attached to the graph."
+ )
else:
logger.info(
"No ontology file provided. No owl ontology will be attached to the graph."
diff --git a/cognee/modules/retrieval/graph_completion_context_extension_retriever.py b/cognee/modules/retrieval/graph_completion_context_extension_retriever.py
index b07d11fd2..fc49a139b 100644
--- a/cognee/modules/retrieval/graph_completion_context_extension_retriever.py
+++ b/cognee/modules/retrieval/graph_completion_context_extension_retriever.py
@@ -39,6 +39,8 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
node_type: Optional[Type] = None,
node_name: Optional[List[str]] = None,
save_interaction: bool = False,
+ wide_search_top_k: Optional[int] = 100,
+ triplet_distance_penalty: Optional[float] = 3.5,
):
super().__init__(
user_prompt_path=user_prompt_path,
@@ -48,6 +50,8 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
node_name=node_name,
save_interaction=save_interaction,
system_prompt=system_prompt,
+ wide_search_top_k=wide_search_top_k,
+ triplet_distance_penalty=triplet_distance_penalty,
)
async def get_completion(
diff --git a/cognee/modules/retrieval/graph_completion_cot_retriever.py b/cognee/modules/retrieval/graph_completion_cot_retriever.py
index eb8f502cb..70fcb6cdb 100644
--- a/cognee/modules/retrieval/graph_completion_cot_retriever.py
+++ b/cognee/modules/retrieval/graph_completion_cot_retriever.py
@@ -65,6 +65,8 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
node_type: Optional[Type] = None,
node_name: Optional[List[str]] = None,
save_interaction: bool = False,
+ wide_search_top_k: Optional[int] = 100,
+ triplet_distance_penalty: Optional[float] = 3.5,
):
super().__init__(
user_prompt_path=user_prompt_path,
@@ -74,6 +76,8 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
node_type=node_type,
node_name=node_name,
save_interaction=save_interaction,
+ wide_search_top_k=wide_search_top_k,
+ triplet_distance_penalty=triplet_distance_penalty,
)
self.validation_system_prompt_path = validation_system_prompt_path
self.validation_user_prompt_path = validation_user_prompt_path
diff --git a/cognee/modules/retrieval/graph_completion_retriever.py b/cognee/modules/retrieval/graph_completion_retriever.py
index df77a11ac..89e9e47ce 100644
--- a/cognee/modules/retrieval/graph_completion_retriever.py
+++ b/cognee/modules/retrieval/graph_completion_retriever.py
@@ -47,6 +47,8 @@ class GraphCompletionRetriever(BaseGraphRetriever):
node_type: Optional[Type] = None,
node_name: Optional[List[str]] = None,
save_interaction: bool = False,
+ wide_search_top_k: Optional[int] = 100,
+ triplet_distance_penalty: Optional[float] = 3.5,
):
"""Initialize retriever with prompt paths and search parameters."""
self.save_interaction = save_interaction
@@ -54,8 +56,10 @@ class GraphCompletionRetriever(BaseGraphRetriever):
self.system_prompt_path = system_prompt_path
self.system_prompt = system_prompt
self.top_k = top_k if top_k is not None else 5
+ self.wide_search_top_k = wide_search_top_k
self.node_type = node_type
self.node_name = node_name
+ self.triplet_distance_penalty = triplet_distance_penalty
async def resolve_edges_to_text(self, retrieved_edges: list) -> str:
"""
@@ -105,6 +109,8 @@ class GraphCompletionRetriever(BaseGraphRetriever):
collections=vector_index_collections or None,
node_type=self.node_type,
node_name=self.node_name,
+ wide_search_top_k=self.wide_search_top_k,
+ triplet_distance_penalty=self.triplet_distance_penalty,
)
return found_triplets
@@ -141,6 +147,10 @@ class GraphCompletionRetriever(BaseGraphRetriever):
return triplets
+ async def convert_retrieved_objects_to_context(self, triplets: List[Edge]):
+ context = await self.resolve_edges_to_text(triplets)
+ return context
+
async def get_completion(
self,
query: str,
diff --git a/cognee/modules/retrieval/graph_summary_completion_retriever.py b/cognee/modules/retrieval/graph_summary_completion_retriever.py
index 051f39b22..e31ad126e 100644
--- a/cognee/modules/retrieval/graph_summary_completion_retriever.py
+++ b/cognee/modules/retrieval/graph_summary_completion_retriever.py
@@ -26,6 +26,8 @@ class GraphSummaryCompletionRetriever(GraphCompletionRetriever):
node_type: Optional[Type] = None,
node_name: Optional[List[str]] = None,
save_interaction: bool = False,
+ wide_search_top_k: Optional[int] = 100,
+ triplet_distance_penalty: Optional[float] = 3.5,
):
"""Initialize retriever with default prompt paths and search parameters."""
super().__init__(
@@ -36,6 +38,8 @@ class GraphSummaryCompletionRetriever(GraphCompletionRetriever):
node_name=node_name,
save_interaction=save_interaction,
system_prompt=system_prompt,
+ wide_search_top_k=wide_search_top_k,
+ triplet_distance_penalty=triplet_distance_penalty,
)
self.summarize_prompt_path = summarize_prompt_path
diff --git a/cognee/modules/retrieval/temporal_retriever.py b/cognee/modules/retrieval/temporal_retriever.py
index f3da02c15..87d2ab009 100644
--- a/cognee/modules/retrieval/temporal_retriever.py
+++ b/cognee/modules/retrieval/temporal_retriever.py
@@ -47,6 +47,8 @@ class TemporalRetriever(GraphCompletionRetriever):
top_k: Optional[int] = 5,
node_type: Optional[Type] = None,
node_name: Optional[List[str]] = None,
+ wide_search_top_k: Optional[int] = 100,
+ triplet_distance_penalty: Optional[float] = 3.5,
):
super().__init__(
user_prompt_path=user_prompt_path,
@@ -54,6 +56,8 @@ class TemporalRetriever(GraphCompletionRetriever):
top_k=top_k,
node_type=node_type,
node_name=node_name,
+ wide_search_top_k=wide_search_top_k,
+ triplet_distance_penalty=triplet_distance_penalty,
)
self.user_prompt_path = user_prompt_path
self.system_prompt_path = system_prompt_path
diff --git a/cognee/modules/retrieval/utils/brute_force_triplet_search.py b/cognee/modules/retrieval/utils/brute_force_triplet_search.py
index f8bdbb97d..2f8a545f7 100644
--- a/cognee/modules/retrieval/utils/brute_force_triplet_search.py
+++ b/cognee/modules/retrieval/utils/brute_force_triplet_search.py
@@ -58,6 +58,8 @@ async def get_memory_fragment(
properties_to_project: Optional[List[str]] = None,
node_type: Optional[Type] = None,
node_name: Optional[List[str]] = None,
+ relevant_ids_to_filter: Optional[List[str]] = None,
+ triplet_distance_penalty: Optional[float] = 3.5,
) -> CogneeGraph:
"""Creates and initializes a CogneeGraph memory fragment with optional property projections."""
if properties_to_project is None:
@@ -74,6 +76,8 @@ async def get_memory_fragment(
edge_properties_to_project=["relationship_name", "edge_text"],
node_type=node_type,
node_name=node_name,
+ relevant_ids_to_filter=relevant_ids_to_filter,
+ triplet_distance_penalty=triplet_distance_penalty,
)
except EntityNotFoundError:
@@ -95,6 +99,8 @@ async def brute_force_triplet_search(
memory_fragment: Optional[CogneeGraph] = None,
node_type: Optional[Type] = None,
node_name: Optional[List[str]] = None,
+ wide_search_top_k: Optional[int] = 100,
+ triplet_distance_penalty: Optional[float] = 3.5,
) -> List[Edge]:
"""
Performs a brute force search to retrieve the top triplets from the graph.
@@ -107,6 +113,8 @@ async def brute_force_triplet_search(
memory_fragment (Optional[CogneeGraph]): Existing memory fragment to reuse.
node_type: node type to filter
node_name: node name to filter
+ wide_search_top_k (Optional[int]): Number of initial elements to retrieve from collections
+ triplet_distance_penalty (Optional[float]): Default distance penalty in graph projection
Returns:
list: The top triplet results.
@@ -116,10 +124,10 @@ async def brute_force_triplet_search(
if top_k <= 0:
raise ValueError("top_k must be a positive integer.")
- if memory_fragment is None:
- memory_fragment = await get_memory_fragment(
- properties_to_project, node_type=node_type, node_name=node_name
- )
+ # Setting wide search limit based on the parameters
+ non_global_search = node_name is None
+
+ wide_search_limit = wide_search_top_k if non_global_search else None
if collections is None:
collections = [
@@ -140,7 +148,7 @@ async def brute_force_triplet_search(
async def search_in_collection(collection_name: str):
try:
return await vector_engine.search(
- collection_name=collection_name, query_vector=query_vector, limit=None
+ collection_name=collection_name, query_vector=query_vector, limit=wide_search_limit
)
except CollectionNotFoundError:
return []
@@ -156,15 +164,38 @@ async def brute_force_triplet_search(
return []
# Final statistics
- projection_time = time.time() - start_time
+ vector_collection_search_time = time.time() - start_time
logger.info(
- f"Vector collection retrieval completed: Retrieved distances from {sum(1 for res in results if res)} collections in {projection_time:.2f}s"
+ f"Vector collection retrieval completed: Retrieved distances from {sum(1 for res in results if res)} collections in {vector_collection_search_time:.2f}s"
)
node_distances = {collection: result for collection, result in zip(collections, results)}
edge_distances = node_distances.get("EdgeType_relationship_name", None)
+ if wide_search_limit is not None:
+ relevant_ids_to_filter = list(
+ {
+ str(getattr(scored_node, "id"))
+ for collection_name, score_collection in node_distances.items()
+ if collection_name != "EdgeType_relationship_name"
+ and isinstance(score_collection, (list, tuple))
+ for scored_node in score_collection
+ if getattr(scored_node, "id", None)
+ }
+ )
+ else:
+ relevant_ids_to_filter = None
+
+ if memory_fragment is None:
+ memory_fragment = await get_memory_fragment(
+ properties_to_project=properties_to_project,
+ node_type=node_type,
+ node_name=node_name,
+ relevant_ids_to_filter=relevant_ids_to_filter,
+ triplet_distance_penalty=triplet_distance_penalty,
+ )
+
await memory_fragment.map_vector_distances_to_graph_nodes(node_distances=node_distances)
await memory_fragment.map_vector_distances_to_graph_edges(
vector_engine=vector_engine, query_vector=query_vector, edge_distances=edge_distances
diff --git a/cognee/modules/search/methods/get_search_type_tools.py b/cognee/modules/search/methods/get_search_type_tools.py
index 72e2db89a..165ec379b 100644
--- a/cognee/modules/search/methods/get_search_type_tools.py
+++ b/cognee/modules/search/methods/get_search_type_tools.py
@@ -37,6 +37,8 @@ async def get_search_type_tools(
node_name: Optional[List[str]] = None,
save_interaction: bool = False,
last_k: Optional[int] = None,
+ wide_search_top_k: Optional[int] = 100,
+ triplet_distance_penalty: Optional[float] = 3.5,
) -> list:
search_tasks: dict[SearchType, List[Callable]] = {
SearchType.SUMMARIES: [
@@ -67,6 +69,8 @@ async def get_search_type_tools(
node_name=node_name,
save_interaction=save_interaction,
system_prompt=system_prompt,
+ wide_search_top_k=wide_search_top_k,
+ triplet_distance_penalty=triplet_distance_penalty,
).get_completion,
GraphCompletionRetriever(
system_prompt_path=system_prompt_path,
@@ -75,6 +79,8 @@ async def get_search_type_tools(
node_name=node_name,
save_interaction=save_interaction,
system_prompt=system_prompt,
+ wide_search_top_k=wide_search_top_k,
+ triplet_distance_penalty=triplet_distance_penalty,
).get_context,
],
SearchType.GRAPH_COMPLETION_COT: [
@@ -85,6 +91,8 @@ async def get_search_type_tools(
node_name=node_name,
save_interaction=save_interaction,
system_prompt=system_prompt,
+ wide_search_top_k=wide_search_top_k,
+ triplet_distance_penalty=triplet_distance_penalty,
).get_completion,
GraphCompletionCotRetriever(
system_prompt_path=system_prompt_path,
@@ -93,6 +101,8 @@ async def get_search_type_tools(
node_name=node_name,
save_interaction=save_interaction,
system_prompt=system_prompt,
+ wide_search_top_k=wide_search_top_k,
+ triplet_distance_penalty=triplet_distance_penalty,
).get_context,
],
SearchType.GRAPH_COMPLETION_CONTEXT_EXTENSION: [
@@ -103,6 +113,8 @@ async def get_search_type_tools(
node_name=node_name,
save_interaction=save_interaction,
system_prompt=system_prompt,
+ wide_search_top_k=wide_search_top_k,
+ triplet_distance_penalty=triplet_distance_penalty,
).get_completion,
GraphCompletionContextExtensionRetriever(
system_prompt_path=system_prompt_path,
@@ -111,6 +123,8 @@ async def get_search_type_tools(
node_name=node_name,
save_interaction=save_interaction,
system_prompt=system_prompt,
+ wide_search_top_k=wide_search_top_k,
+ triplet_distance_penalty=triplet_distance_penalty,
).get_context,
],
SearchType.GRAPH_SUMMARY_COMPLETION: [
@@ -121,6 +135,8 @@ async def get_search_type_tools(
node_name=node_name,
save_interaction=save_interaction,
system_prompt=system_prompt,
+ wide_search_top_k=wide_search_top_k,
+ triplet_distance_penalty=triplet_distance_penalty,
).get_completion,
GraphSummaryCompletionRetriever(
system_prompt_path=system_prompt_path,
@@ -129,6 +145,8 @@ async def get_search_type_tools(
node_name=node_name,
save_interaction=save_interaction,
system_prompt=system_prompt,
+ wide_search_top_k=wide_search_top_k,
+ triplet_distance_penalty=triplet_distance_penalty,
).get_context,
],
SearchType.CODE: [
@@ -145,8 +163,16 @@ async def get_search_type_tools(
],
SearchType.FEEDBACK: [UserQAFeedback(last_k=last_k).add_feedback],
SearchType.TEMPORAL: [
- TemporalRetriever(top_k=top_k).get_completion,
- TemporalRetriever(top_k=top_k).get_context,
+ TemporalRetriever(
+ top_k=top_k,
+ wide_search_top_k=wide_search_top_k,
+ triplet_distance_penalty=triplet_distance_penalty,
+ ).get_completion,
+ TemporalRetriever(
+ top_k=top_k,
+ wide_search_top_k=wide_search_top_k,
+ triplet_distance_penalty=triplet_distance_penalty,
+ ).get_context,
],
SearchType.CHUNKS_LEXICAL: (
lambda _r=JaccardChunksRetriever(top_k=top_k): [
diff --git a/cognee/modules/search/methods/no_access_control_search.py b/cognee/modules/search/methods/no_access_control_search.py
index fcb02da46..3a703bbc9 100644
--- a/cognee/modules/search/methods/no_access_control_search.py
+++ b/cognee/modules/search/methods/no_access_control_search.py
@@ -24,6 +24,8 @@ async def no_access_control_search(
last_k: Optional[int] = None,
only_context: bool = False,
session_id: Optional[str] = None,
+ wide_search_top_k: Optional[int] = 100,
+ triplet_distance_penalty: Optional[float] = 3.5,
) -> Tuple[Any, Union[str, List[Edge]], List[Dataset]]:
search_tools = await get_search_type_tools(
query_type=query_type,
@@ -35,6 +37,8 @@ async def no_access_control_search(
node_name=node_name,
save_interaction=save_interaction,
last_k=last_k,
+ wide_search_top_k=wide_search_top_k,
+ triplet_distance_penalty=triplet_distance_penalty,
)
graph_engine = await get_graph_engine()
is_empty = await graph_engine.is_empty()
diff --git a/cognee/modules/search/methods/search.py b/cognee/modules/search/methods/search.py
index b4278424b..9f180d607 100644
--- a/cognee/modules/search/methods/search.py
+++ b/cognee/modules/search/methods/search.py
@@ -47,6 +47,8 @@ async def search(
only_context: bool = False,
use_combined_context: bool = False,
session_id: Optional[str] = None,
+ wide_search_top_k: Optional[int] = 100,
+ triplet_distance_penalty: Optional[float] = 3.5,
) -> Union[CombinedSearchResult, List[SearchResult]]:
"""
@@ -90,6 +92,8 @@ async def search(
only_context=only_context,
use_combined_context=use_combined_context,
session_id=session_id,
+ wide_search_top_k=wide_search_top_k,
+ triplet_distance_penalty=triplet_distance_penalty,
)
else:
search_results = [
@@ -105,6 +109,8 @@ async def search(
last_k=last_k,
only_context=only_context,
session_id=session_id,
+ wide_search_top_k=wide_search_top_k,
+ triplet_distance_penalty=triplet_distance_penalty,
)
]
@@ -219,6 +225,8 @@ async def authorized_search(
only_context: bool = False,
use_combined_context: bool = False,
session_id: Optional[str] = None,
+ wide_search_top_k: Optional[int] = 100,
+ triplet_distance_penalty: Optional[float] = 3.5,
) -> Union[
Tuple[Any, Union[List[Edge], str], List[Dataset]],
List[Tuple[Any, Union[List[Edge], str], List[Dataset]]],
@@ -246,6 +254,8 @@ async def authorized_search(
last_k=last_k,
only_context=True,
session_id=session_id,
+ wide_search_top_k=wide_search_top_k,
+ triplet_distance_penalty=triplet_distance_penalty,
)
context = {}
@@ -267,6 +277,8 @@ async def authorized_search(
node_name=node_name,
save_interaction=save_interaction,
last_k=last_k,
+ wide_search_top_k=wide_search_top_k,
+ triplet_distance_penalty=triplet_distance_penalty,
)
search_tools = specific_search_tools
if len(search_tools) == 2:
@@ -306,6 +318,7 @@ async def authorized_search(
last_k=last_k,
only_context=only_context,
session_id=session_id,
+ wide_search_top_k=wide_search_top_k,
)
return search_results
@@ -325,6 +338,8 @@ async def search_in_datasets_context(
only_context: bool = False,
context: Optional[Any] = None,
session_id: Optional[str] = None,
+ wide_search_top_k: Optional[int] = 100,
+ triplet_distance_penalty: Optional[float] = 3.5,
) -> List[Tuple[Any, Union[str, List[Edge]], List[Dataset]]]:
"""
Searches all provided datasets and handles setting up of appropriate database context based on permissions.
@@ -345,6 +360,8 @@ async def search_in_datasets_context(
only_context: bool = False,
context: Optional[Any] = None,
session_id: Optional[str] = None,
+ wide_search_top_k: Optional[int] = 100,
+ triplet_distance_penalty: Optional[float] = 3.5,
) -> Tuple[Any, Union[str, List[Edge]], List[Dataset]]:
# Set database configuration in async context for each dataset user has access for
await set_database_global_context_variables(dataset.id, dataset.owner_id)
@@ -378,6 +395,8 @@ async def search_in_datasets_context(
node_name=node_name,
save_interaction=save_interaction,
last_k=last_k,
+ wide_search_top_k=wide_search_top_k,
+ triplet_distance_penalty=triplet_distance_penalty,
)
search_tools = specific_search_tools
if len(search_tools) == 2:
@@ -413,6 +432,8 @@ async def search_in_datasets_context(
only_context=only_context,
context=context,
session_id=session_id,
+ wide_search_top_k=wide_search_top_k,
+ triplet_distance_penalty=triplet_distance_penalty,
)
)
diff --git a/cognee/tasks/chunks/__init__.py b/cognee/tasks/chunks/__init__.py
index 22ce96be8..37d4de73e 100644
--- a/cognee/tasks/chunks/__init__.py
+++ b/cognee/tasks/chunks/__init__.py
@@ -1,4 +1,5 @@
from .chunk_by_word import chunk_by_word
from .chunk_by_sentence import chunk_by_sentence
from .chunk_by_paragraph import chunk_by_paragraph
+from .chunk_by_row import chunk_by_row
from .remove_disconnected_chunks import remove_disconnected_chunks
diff --git a/cognee/tasks/chunks/chunk_by_row.py b/cognee/tasks/chunks/chunk_by_row.py
new file mode 100644
index 000000000..8daf13689
--- /dev/null
+++ b/cognee/tasks/chunks/chunk_by_row.py
@@ -0,0 +1,94 @@
+from typing import Any, Dict, Iterator
+from uuid import NAMESPACE_OID, uuid5
+
+from cognee.infrastructure.databases.vector.embeddings import get_embedding_engine
+
+
+def _get_pair_size(pair_text: str) -> int:
+ """
+ Calculate the size of a given text in terms of tokens.
+
+ If an embedding engine's tokenizer is available, count the tokens for the provided word.
+ If the tokenizer is not available, assume the word counts as one token.
+
+ Parameters:
+ -----------
+
+ - pair_text (str): The key:value pair text for which the token size is to be calculated.
+
+ Returns:
+ --------
+
+ - int: The number of tokens representing the text, typically an integer, depending
+ on the tokenizer's output.
+ """
+ embedding_engine = get_embedding_engine()
+ if embedding_engine.tokenizer:
+ return embedding_engine.tokenizer.count_tokens(pair_text)
+ else:
+ return 3
+
+
+def chunk_by_row(
+ data: str,
+ max_chunk_size,
+) -> Iterator[Dict[str, Any]]:
+ """
+ Chunk the input text by row while enabling exact text reconstruction.
+
+ This function divides the given text data into smaller chunks on a line-by-line basis,
+ ensuring that the size of each chunk is less than or equal to the specified maximum
+ chunk size. It guarantees that when the generated chunks are concatenated, they
+ reproduce the original text accurately. The tokenization process is handled by
+ adapters compatible with the vector engine's embedding model.
+
+ Parameters:
+ -----------
+
+ - data (str): The input text to be chunked.
+ - max_chunk_size: The maximum allowed size for each chunk, in terms of tokens or
+ words.
+ """
+ current_chunk_list = []
+ chunk_index = 0
+ current_chunk_size = 0
+
+ lines = data.split("\n\n")
+ for line in lines:
+ pairs_text = line.split(", ")
+
+ for pair_text in pairs_text:
+ pair_size = _get_pair_size(pair_text)
+ if current_chunk_size > 0 and (current_chunk_size + pair_size > max_chunk_size):
+ # Yield current cut chunk
+ current_chunk = ", ".join(current_chunk_list)
+ chunk_dict = {
+ "text": current_chunk,
+ "chunk_size": current_chunk_size,
+ "chunk_id": uuid5(NAMESPACE_OID, current_chunk),
+ "chunk_index": chunk_index,
+ "cut_type": "row_cut",
+ }
+
+ yield chunk_dict
+
+ # Start new chunk with current pair text
+ current_chunk_list = []
+ current_chunk_size = 0
+ chunk_index += 1
+
+ current_chunk_list.append(pair_text)
+ current_chunk_size += pair_size
+
+ # Yield row chunk
+ current_chunk = ", ".join(current_chunk_list)
+ if current_chunk:
+ chunk_dict = {
+ "text": current_chunk,
+ "chunk_size": current_chunk_size,
+ "chunk_id": uuid5(NAMESPACE_OID, current_chunk),
+ "chunk_index": chunk_index,
+ "cut_type": "row_end",
+ }
+
+ yield chunk_dict
diff --git a/cognee/tasks/documents/classify_documents.py b/cognee/tasks/documents/classify_documents.py
index 9fa512906..e4f13ebd1 100644
--- a/cognee/tasks/documents/classify_documents.py
+++ b/cognee/tasks/documents/classify_documents.py
@@ -7,6 +7,7 @@ from cognee.modules.data.processing.document_types import (
ImageDocument,
TextDocument,
UnstructuredDocument,
+ CsvDocument,
)
from cognee.modules.engine.models.node_set import NodeSet
from cognee.modules.engine.utils.generate_node_id import generate_node_id
@@ -15,6 +16,7 @@ from cognee.tasks.documents.exceptions import WrongDataDocumentInputError
EXTENSION_TO_DOCUMENT_CLASS = {
"pdf": PdfDocument, # Text documents
"txt": TextDocument,
+ "csv": CsvDocument,
"docx": UnstructuredDocument,
"doc": UnstructuredDocument,
"odt": UnstructuredDocument,
diff --git a/cognee/tests/integration/documents/CsvDocument_test.py b/cognee/tests/integration/documents/CsvDocument_test.py
new file mode 100644
index 000000000..421bb81bd
--- /dev/null
+++ b/cognee/tests/integration/documents/CsvDocument_test.py
@@ -0,0 +1,70 @@
+import os
+import sys
+import uuid
+import pytest
+import pathlib
+from unittest.mock import patch
+
+from cognee.modules.chunking.CsvChunker import CsvChunker
+from cognee.modules.data.processing.document_types.CsvDocument import CsvDocument
+from cognee.tests.integration.documents.AudioDocument_test import mock_get_embedding_engine
+from cognee.tests.integration.documents.async_gen_zip import async_gen_zip
+
+chunk_by_row_module = sys.modules.get("cognee.tasks.chunks.chunk_by_row")
+
+
+GROUND_TRUTH = {
+ "chunk_size_10": [
+ {"token_count": 9, "len_text": 26, "cut_type": "row_cut", "chunk_index": 0},
+ {"token_count": 6, "len_text": 29, "cut_type": "row_end", "chunk_index": 1},
+ {"token_count": 9, "len_text": 25, "cut_type": "row_cut", "chunk_index": 2},
+ {"token_count": 6, "len_text": 30, "cut_type": "row_end", "chunk_index": 3},
+ ],
+ "chunk_size_128": [
+ {"token_count": 15, "len_text": 57, "cut_type": "row_end", "chunk_index": 0},
+ {"token_count": 15, "len_text": 57, "cut_type": "row_end", "chunk_index": 1},
+ ],
+}
+
+
+@pytest.mark.parametrize(
+ "input_file,chunk_size",
+ [("example_with_header.csv", 10), ("example_with_header.csv", 128)],
+)
+@patch.object(chunk_by_row_module, "get_embedding_engine", side_effect=mock_get_embedding_engine)
+@pytest.mark.asyncio
+async def test_CsvDocument(mock_engine, input_file, chunk_size):
+ # Define file paths of test data
+ csv_file_path = os.path.join(
+ pathlib.Path(__file__).parent.parent.parent,
+ "test_data",
+ input_file,
+ )
+
+ # Define test documents
+ csv_document = CsvDocument(
+ id=uuid.uuid4(),
+ name="example_with_header.csv",
+ raw_data_location=csv_file_path,
+ external_metadata="",
+ mime_type="text/csv",
+ )
+
+ # TEST CSV
+ ground_truth_key = f"chunk_size_{chunk_size}"
+ async for ground_truth, row_data in async_gen_zip(
+ GROUND_TRUTH[ground_truth_key],
+ csv_document.read(chunker_cls=CsvChunker, max_chunk_size=chunk_size),
+ ):
+ assert ground_truth["token_count"] == row_data.chunk_size, (
+ f'{ground_truth["token_count"] = } != {row_data.chunk_size = }'
+ )
+ assert ground_truth["len_text"] == len(row_data.text), (
+ f'{ground_truth["len_text"] = } != {len(row_data.text) = }'
+ )
+ assert ground_truth["cut_type"] == row_data.cut_type, (
+ f'{ground_truth["cut_type"] = } != {row_data.cut_type = }'
+ )
+ assert ground_truth["chunk_index"] == row_data.chunk_index, (
+ f'{ground_truth["chunk_index"] = } != {row_data.chunk_index = }'
+ )
diff --git a/cognee/tests/integration/web_url_crawler/test_default_url_crawler.py b/cognee/tests/integration/web_url_crawler/test_default_url_crawler.py
index 156cc87a4..af2595b14 100644
--- a/cognee/tests/integration/web_url_crawler/test_default_url_crawler.py
+++ b/cognee/tests/integration/web_url_crawler/test_default_url_crawler.py
@@ -5,7 +5,7 @@ from cognee.tasks.web_scraper import DefaultUrlCrawler
@pytest.mark.asyncio
async def test_fetch():
crawler = DefaultUrlCrawler()
- url = "https://en.wikipedia.org/wiki/Large_language_model"
+ url = "http://example.com/"
results = await crawler.fetch_urls(url)
assert len(results) == 1
assert isinstance(results, dict)
diff --git a/cognee/tests/integration/web_url_crawler/test_tavily_crawler.py b/cognee/tests/integration/web_url_crawler/test_tavily_crawler.py
index 946ce8378..5db9b58ce 100644
--- a/cognee/tests/integration/web_url_crawler/test_tavily_crawler.py
+++ b/cognee/tests/integration/web_url_crawler/test_tavily_crawler.py
@@ -11,7 +11,7 @@ skip_in_ci = pytest.mark.skipif(
@skip_in_ci
@pytest.mark.asyncio
async def test_fetch():
- url = "https://en.wikipedia.org/wiki/Large_language_model"
+ url = "http://example.com/"
results = await fetch_with_tavily(url)
assert isinstance(results, dict)
assert len(results) == 1
diff --git a/cognee/tests/integration/web_url_crawler/test_url_adding_e2e.py b/cognee/tests/integration/web_url_crawler/test_url_adding_e2e.py
index d91b075aa..200f40a94 100644
--- a/cognee/tests/integration/web_url_crawler/test_url_adding_e2e.py
+++ b/cognee/tests/integration/web_url_crawler/test_url_adding_e2e.py
@@ -14,9 +14,7 @@ async def test_url_saves_as_html_file():
await cognee.prune.prune_system(metadata=True)
try:
- original_file_path = await save_data_item_to_storage(
- "https://en.wikipedia.org/wiki/Large_language_model"
- )
+ original_file_path = await save_data_item_to_storage("http://example.com/")
file_path = get_data_file_path(original_file_path)
assert file_path.endswith(".html")
file = Path(file_path)
@@ -44,9 +42,7 @@ async def test_saved_html_is_valid():
await cognee.prune.prune_system(metadata=True)
try:
- original_file_path = await save_data_item_to_storage(
- "https://en.wikipedia.org/wiki/Large_language_model"
- )
+ original_file_path = await save_data_item_to_storage("http://example.com/")
file_path = get_data_file_path(original_file_path)
content = Path(file_path).read_text()
@@ -72,7 +68,7 @@ async def test_add_url():
await cognee.prune.prune_data()
await cognee.prune.prune_system(metadata=True)
- await cognee.add("https://en.wikipedia.org/wiki/Large_language_model")
+ await cognee.add("http://example.com/")
skip_in_ci = pytest.mark.skipif(
@@ -88,7 +84,7 @@ async def test_add_url_with_tavily():
await cognee.prune.prune_data()
await cognee.prune.prune_system(metadata=True)
- await cognee.add("https://en.wikipedia.org/wiki/Large_language_model")
+ await cognee.add("http://example.com/")
@pytest.mark.asyncio
@@ -98,7 +94,7 @@ async def test_add_url_without_incremental_loading():
try:
await cognee.add(
- "https://en.wikipedia.org/wiki/Large_language_model",
+ "http://example.com/",
incremental_loading=False,
)
except Exception as e:
@@ -112,7 +108,7 @@ async def test_add_url_with_incremental_loading():
try:
await cognee.add(
- "https://en.wikipedia.org/wiki/Large_language_model",
+ "http://example.com/",
incremental_loading=True,
)
except Exception as e:
@@ -125,7 +121,7 @@ async def test_add_url_can_define_preferred_loader_as_list_of_str():
await cognee.prune.prune_system(metadata=True)
await cognee.add(
- "https://en.wikipedia.org/wiki/Large_language_model",
+ "http://example.com/",
preferred_loaders=["beautiful_soup_loader"],
)
@@ -144,7 +140,7 @@ async def test_add_url_with_extraction_rules():
try:
await cognee.add(
- "https://en.wikipedia.org/wiki/Large_language_model",
+ "http://example.com/",
preferred_loaders={"beautiful_soup_loader": {"extraction_rules": extraction_rules}},
)
except Exception as e:
@@ -163,9 +159,7 @@ async def test_loader_is_none_by_default():
}
try:
- original_file_path = await save_data_item_to_storage(
- "https://en.wikipedia.org/wiki/Large_language_model"
- )
+ original_file_path = await save_data_item_to_storage("http://example.com/")
file_path = get_data_file_path(original_file_path)
assert file_path.endswith(".html")
file = Path(file_path)
@@ -196,9 +190,7 @@ async def test_beautiful_soup_loader_is_selected_loader_if_preferred_loader_prov
}
try:
- original_file_path = await save_data_item_to_storage(
- "https://en.wikipedia.org/wiki/Large_language_model"
- )
+ original_file_path = await save_data_item_to_storage("http://example.com/")
file_path = get_data_file_path(original_file_path)
assert file_path.endswith(".html")
file = Path(file_path)
@@ -225,9 +217,7 @@ async def test_beautiful_soup_loader_works_with_and_without_arguments():
await cognee.prune.prune_system(metadata=True)
try:
- original_file_path = await save_data_item_to_storage(
- "https://en.wikipedia.org/wiki/Large_language_model"
- )
+ original_file_path = await save_data_item_to_storage("http://example.com/")
file_path = get_data_file_path(original_file_path)
assert file_path.endswith(".html")
file = Path(file_path)
@@ -263,9 +253,7 @@ async def test_beautiful_soup_loader_successfully_loads_file_if_required_args_pr
await cognee.prune.prune_system(metadata=True)
try:
- original_file_path = await save_data_item_to_storage(
- "https://en.wikipedia.org/wiki/Large_language_model"
- )
+ original_file_path = await save_data_item_to_storage("http://example.com/")
file_path = get_data_file_path(original_file_path)
assert file_path.endswith(".html")
file = Path(file_path)
@@ -302,9 +290,7 @@ async def test_beautiful_soup_loads_file_successfully():
}
try:
- original_file_path = await save_data_item_to_storage(
- "https://en.wikipedia.org/wiki/Large_language_model"
- )
+ original_file_path = await save_data_item_to_storage("http://example.com/")
file_path = get_data_file_path(original_file_path)
assert file_path.endswith(".html")
original_file = Path(file_path)
diff --git a/cognee/tests/test_cognee_server_start.py b/cognee/tests/test_cognee_server_start.py
index ab68a8ef1..ddffe53a4 100644
--- a/cognee/tests/test_cognee_server_start.py
+++ b/cognee/tests/test_cognee_server_start.py
@@ -7,6 +7,7 @@ import requests
from pathlib import Path
import sys
import uuid
+import json
class TestCogneeServerStart(unittest.TestCase):
@@ -90,12 +91,71 @@ class TestCogneeServerStart(unittest.TestCase):
)
}
- payload = {"datasets": [dataset_name]}
+ ontology_key = f"test_ontology_{uuid.uuid4().hex[:8]}"
+ payload = {"datasets": [dataset_name], "ontology_key": [ontology_key]}
add_response = requests.post(url, headers=headers, data=form_data, files=file, timeout=50)
if add_response.status_code not in [200, 201]:
add_response.raise_for_status()
+ ontology_content = b"""
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+ A failure caused by physical components.
+
+
+
+
+ An error caused by software logic or configuration.
+
+
+
+ A human being or individual.
+
+
+
+
+ Programmers
+
+
+
+
+
+ Hardware Problem
+
+
+ """
+
+ ontology_response = requests.post(
+ "http://127.0.0.1:8000/api/v1/ontologies",
+ headers=headers,
+ files=[("ontology_file", ("test.owl", ontology_content, "application/xml"))],
+ data={
+ "ontology_key": json.dumps([ontology_key]),
+ "description": json.dumps(["Test ontology"]),
+ },
+ )
+ self.assertEqual(ontology_response.status_code, 200)
+
# Cognify request
url = "http://127.0.0.1:8000/api/v1/cognify"
headers = {
@@ -107,6 +167,29 @@ class TestCogneeServerStart(unittest.TestCase):
if cognify_response.status_code not in [200, 201]:
cognify_response.raise_for_status()
+ datasets_response = requests.get("http://127.0.0.1:8000/api/v1/datasets", headers=headers)
+
+ datasets = datasets_response.json()
+ dataset_id = None
+ for dataset in datasets:
+ if dataset["name"] == dataset_name:
+ dataset_id = dataset["id"]
+ break
+
+ graph_response = requests.get(
+ f"http://127.0.0.1:8000/api/v1/datasets/{dataset_id}/graph", headers=headers
+ )
+ self.assertEqual(graph_response.status_code, 200)
+
+ graph_data = graph_response.json()
+ ontology_nodes = [
+ node for node in graph_data.get("nodes") if node.get("properties").get("ontology_valid")
+ ]
+
+ self.assertGreater(
+ len(ontology_nodes), 0, "No ontology nodes found - ontology was not integrated"
+ )
+
# TODO: Add test to verify cognify pipeline is complete before testing search
# Search request
diff --git a/cognee/tests/test_data/example_with_header.csv b/cognee/tests/test_data/example_with_header.csv
new file mode 100644
index 000000000..dc900e5ef
--- /dev/null
+++ b/cognee/tests/test_data/example_with_header.csv
@@ -0,0 +1,3 @@
+id,name,age,city,country
+1,Eric,30,Beijing,China
+2,Joe,35,Berlin,Germany
diff --git a/cognee/tests/unit/api/test_ontology_endpoint.py b/cognee/tests/unit/api/test_ontology_endpoint.py
new file mode 100644
index 000000000..af3a4d90e
--- /dev/null
+++ b/cognee/tests/unit/api/test_ontology_endpoint.py
@@ -0,0 +1,272 @@
+import pytest
+import uuid
+from fastapi.testclient import TestClient
+from unittest.mock import patch, Mock, AsyncMock
+from types import SimpleNamespace
+import importlib
+from cognee.api.client import app
+
+gau_mod = importlib.import_module("cognee.modules.users.methods.get_authenticated_user")
+
+
+@pytest.fixture
+def client():
+ return TestClient(app)
+
+
+@pytest.fixture
+def mock_user():
+ user = Mock()
+ user.id = "test-user-123"
+ return user
+
+
+@pytest.fixture
+def mock_default_user():
+ """Mock default user for testing."""
+ return SimpleNamespace(
+ id=str(uuid.uuid4()),
+ email="default@example.com",
+ is_active=True,
+ tenant_id=str(uuid.uuid4()),
+ )
+
+
+@patch.object(gau_mod, "get_default_user", new_callable=AsyncMock)
+def test_upload_ontology_success(mock_get_default_user, client, mock_default_user):
+ """Test successful ontology upload"""
+ import json
+
+ mock_get_default_user.return_value = mock_default_user
+ ontology_content = (
+ b""
+ )
+ unique_key = f"test_ontology_{uuid.uuid4().hex[:8]}"
+
+ response = client.post(
+ "/api/v1/ontologies",
+ files=[("ontology_file", ("test.owl", ontology_content, "application/xml"))],
+ data={"ontology_key": json.dumps([unique_key]), "description": json.dumps(["Test"])},
+ )
+
+ assert response.status_code == 200
+ data = response.json()
+ assert data["uploaded_ontologies"][0]["ontology_key"] == unique_key
+ assert "uploaded_at" in data["uploaded_ontologies"][0]
+
+
+@patch.object(gau_mod, "get_default_user", new_callable=AsyncMock)
+def test_upload_ontology_invalid_file(mock_get_default_user, client, mock_default_user):
+ """Test 400 response for non-.owl files"""
+ mock_get_default_user.return_value = mock_default_user
+ unique_key = f"test_ontology_{uuid.uuid4().hex[:8]}"
+ response = client.post(
+ "/api/v1/ontologies",
+ files={"ontology_file": ("test.txt", b"not xml")},
+ data={"ontology_key": unique_key},
+ )
+ assert response.status_code == 400
+
+
+@patch.object(gau_mod, "get_default_user", new_callable=AsyncMock)
+def test_upload_ontology_missing_data(mock_get_default_user, client, mock_default_user):
+ """Test 400 response for missing file or key"""
+ import json
+
+ mock_get_default_user.return_value = mock_default_user
+ # Missing file
+ response = client.post("/api/v1/ontologies", data={"ontology_key": json.dumps(["test"])})
+ assert response.status_code == 400
+
+ # Missing key
+ response = client.post(
+ "/api/v1/ontologies", files=[("ontology_file", ("test.owl", b"xml", "application/xml"))]
+ )
+ assert response.status_code == 400
+
+
+@patch.object(gau_mod, "get_default_user", new_callable=AsyncMock)
+def test_upload_ontology_unauthorized(mock_get_default_user, client, mock_default_user):
+ """Test behavior when default user is provided (no explicit authentication)"""
+ import json
+
+ unique_key = f"test_ontology_{uuid.uuid4().hex[:8]}"
+ mock_get_default_user.return_value = mock_default_user
+ response = client.post(
+ "/api/v1/ontologies",
+ files=[("ontology_file", ("test.owl", b"", "application/xml"))],
+ data={"ontology_key": json.dumps([unique_key])},
+ )
+
+ # The current system provides a default user when no explicit authentication is given
+ # This test verifies the system works with conditional authentication
+ assert response.status_code == 200
+ data = response.json()
+ assert data["uploaded_ontologies"][0]["ontology_key"] == unique_key
+ assert "uploaded_at" in data["uploaded_ontologies"][0]
+
+
+@patch.object(gau_mod, "get_default_user", new_callable=AsyncMock)
+def test_upload_multiple_ontologies(mock_get_default_user, client, mock_default_user):
+ """Test uploading multiple ontology files in single request"""
+ import io
+
+ mock_get_default_user.return_value = mock_default_user
+ # Create mock files
+ file1_content = b""
+ file2_content = b""
+
+ files = [
+ ("ontology_file", ("vehicles.owl", io.BytesIO(file1_content), "application/xml")),
+ ("ontology_file", ("manufacturers.owl", io.BytesIO(file2_content), "application/xml")),
+ ]
+ data = {
+ "ontology_key": '["vehicles", "manufacturers"]',
+ "descriptions": '["Base vehicles", "Car manufacturers"]',
+ }
+
+ response = client.post("/api/v1/ontologies", files=files, data=data)
+
+ assert response.status_code == 200
+ result = response.json()
+ assert "uploaded_ontologies" in result
+ assert len(result["uploaded_ontologies"]) == 2
+ assert result["uploaded_ontologies"][0]["ontology_key"] == "vehicles"
+ assert result["uploaded_ontologies"][1]["ontology_key"] == "manufacturers"
+
+
+@patch.object(gau_mod, "get_default_user", new_callable=AsyncMock)
+def test_upload_endpoint_accepts_arrays(mock_get_default_user, client, mock_default_user):
+ """Test that upload endpoint accepts array parameters"""
+ import io
+ import json
+
+ mock_get_default_user.return_value = mock_default_user
+ file_content = b""
+
+ files = [("ontology_file", ("single.owl", io.BytesIO(file_content), "application/xml"))]
+ data = {
+ "ontology_key": json.dumps(["single_key"]),
+ "descriptions": json.dumps(["Single ontology"]),
+ }
+
+ response = client.post("/api/v1/ontologies", files=files, data=data)
+
+ assert response.status_code == 200
+ result = response.json()
+ assert result["uploaded_ontologies"][0]["ontology_key"] == "single_key"
+
+
+@patch.object(gau_mod, "get_default_user", new_callable=AsyncMock)
+def test_cognify_with_multiple_ontologies(mock_get_default_user, client, mock_default_user):
+ """Test cognify endpoint accepts multiple ontology keys"""
+ payload = {
+ "datasets": ["test_dataset"],
+ "ontology_key": ["ontology1", "ontology2"], # Array instead of string
+ "run_in_background": False,
+ }
+
+ response = client.post("/api/v1/cognify", json=payload)
+
+ # Should not fail due to ontology_key type
+ assert response.status_code in [200, 400, 409] # May fail for other reasons, not type
+
+
+@patch.object(gau_mod, "get_default_user", new_callable=AsyncMock)
+def test_complete_multifile_workflow(mock_get_default_user, client, mock_default_user):
+ """Test complete workflow: upload multiple ontologies → cognify with multiple keys"""
+ import io
+ import json
+
+ mock_get_default_user.return_value = mock_default_user
+ # Step 1: Upload multiple ontologies
+ file1_content = b"""
+
+
+ """
+
+ file2_content = b"""
+
+
+ """
+
+ files = [
+ ("ontology_file", ("vehicles.owl", io.BytesIO(file1_content), "application/xml")),
+ ("ontology_file", ("manufacturers.owl", io.BytesIO(file2_content), "application/xml")),
+ ]
+ data = {
+ "ontology_key": json.dumps(["vehicles", "manufacturers"]),
+ "descriptions": json.dumps(["Vehicle ontology", "Manufacturer ontology"]),
+ }
+
+ upload_response = client.post("/api/v1/ontologies", files=files, data=data)
+ assert upload_response.status_code == 200
+
+ # Step 2: Verify ontologies are listed
+ list_response = client.get("/api/v1/ontologies")
+ assert list_response.status_code == 200
+ ontologies = list_response.json()
+ assert "vehicles" in ontologies
+ assert "manufacturers" in ontologies
+
+ # Step 3: Test cognify with multiple ontologies
+ cognify_payload = {
+ "datasets": ["test_dataset"],
+ "ontology_key": ["vehicles", "manufacturers"],
+ "run_in_background": False,
+ }
+
+ cognify_response = client.post("/api/v1/cognify", json=cognify_payload)
+ # Should not fail due to ontology handling (may fail for dataset reasons)
+ assert cognify_response.status_code != 400 # Not a validation error
+
+
+@patch.object(gau_mod, "get_default_user", new_callable=AsyncMock)
+def test_multifile_error_handling(mock_get_default_user, client, mock_default_user):
+ """Test error handling for invalid multifile uploads"""
+ import io
+ import json
+
+ # Test mismatched array lengths
+ file_content = b""
+ files = [("ontology_file", ("test.owl", io.BytesIO(file_content), "application/xml"))]
+ data = {
+ "ontology_key": json.dumps(["key1", "key2"]), # 2 keys, 1 file
+ "descriptions": json.dumps(["desc1"]),
+ }
+
+ response = client.post("/api/v1/ontologies", files=files, data=data)
+ assert response.status_code == 400
+ assert "Number of keys must match number of files" in response.json()["error"]
+
+ # Test duplicate keys
+ files = [
+ ("ontology_file", ("test1.owl", io.BytesIO(file_content), "application/xml")),
+ ("ontology_file", ("test2.owl", io.BytesIO(file_content), "application/xml")),
+ ]
+ data = {
+ "ontology_key": json.dumps(["duplicate", "duplicate"]),
+ "descriptions": json.dumps(["desc1", "desc2"]),
+ }
+
+ response = client.post("/api/v1/ontologies", files=files, data=data)
+ assert response.status_code == 400
+ assert "Duplicate ontology keys not allowed" in response.json()["error"]
+
+
+@patch.object(gau_mod, "get_default_user", new_callable=AsyncMock)
+def test_cognify_missing_ontology_key(mock_get_default_user, client, mock_default_user):
+ """Test cognify with non-existent ontology key"""
+ mock_get_default_user.return_value = mock_default_user
+
+ payload = {
+ "datasets": ["test_dataset"],
+ "ontology_key": ["nonexistent_key"],
+ "run_in_background": False,
+ }
+
+ response = client.post("/api/v1/cognify", json=payload)
+ assert response.status_code == 409
+ assert "Ontology key 'nonexistent_key' not found" in response.json()["error"]
diff --git a/cognee/tests/unit/modules/graph/cognee_graph_elements_test.py b/cognee/tests/unit/modules/graph/cognee_graph_elements_test.py
index 37ba113b5..1d2b79cf9 100644
--- a/cognee/tests/unit/modules/graph/cognee_graph_elements_test.py
+++ b/cognee/tests/unit/modules/graph/cognee_graph_elements_test.py
@@ -9,7 +9,7 @@ def test_node_initialization():
"""Test that a Node is initialized correctly."""
node = Node("node1", {"attr1": "value1"}, dimension=2)
assert node.id == "node1"
- assert node.attributes == {"attr1": "value1", "vector_distance": np.inf}
+ assert node.attributes == {"attr1": "value1", "vector_distance": 3.5}
assert len(node.status) == 2
assert np.all(node.status == 1)
@@ -96,7 +96,7 @@ def test_edge_initialization():
edge = Edge(node1, node2, {"weight": 10}, directed=False, dimension=2)
assert edge.node1 == node1
assert edge.node2 == node2
- assert edge.attributes == {"vector_distance": np.inf, "weight": 10}
+ assert edge.attributes == {"vector_distance": 3.5, "weight": 10}
assert edge.directed is False
assert len(edge.status) == 2
assert np.all(edge.status == 1)
diff --git a/cognee/tests/unit/modules/graph/cognee_graph_test.py b/cognee/tests/unit/modules/graph/cognee_graph_test.py
index 6888648c3..711479387 100644
--- a/cognee/tests/unit/modules/graph/cognee_graph_test.py
+++ b/cognee/tests/unit/modules/graph/cognee_graph_test.py
@@ -1,4 +1,5 @@
import pytest
+from unittest.mock import AsyncMock
from cognee.modules.graph.exceptions import EntityNotFoundError, EntityAlreadyExistsError
from cognee.modules.graph.cognee_graph.CogneeGraph import CogneeGraph
@@ -11,6 +12,30 @@ def setup_graph():
return CogneeGraph()
+@pytest.fixture
+def mock_adapter():
+ """Fixture to create a mock adapter for database operations."""
+ adapter = AsyncMock()
+ return adapter
+
+
+@pytest.fixture
+def mock_vector_engine():
+ """Fixture to create a mock vector engine."""
+ engine = AsyncMock()
+ engine.search = AsyncMock()
+ return engine
+
+
+class MockScoredResult:
+ """Mock class for vector search results."""
+
+ def __init__(self, id, score, payload=None):
+ self.id = id
+ self.score = score
+ self.payload = payload or {}
+
+
def test_add_node_success(setup_graph):
"""Test successful addition of a node."""
graph = setup_graph
@@ -73,3 +98,433 @@ def test_get_edges_nonexistent_node(setup_graph):
graph = setup_graph
with pytest.raises(EntityNotFoundError, match="Node with id nonexistent does not exist."):
graph.get_edges_from_node("nonexistent")
+
+
+@pytest.mark.asyncio
+async def test_project_graph_from_db_full_graph(setup_graph, mock_adapter):
+ """Test projecting a full graph from database."""
+ graph = setup_graph
+
+ nodes_data = [
+ ("1", {"name": "Node1", "description": "First node"}),
+ ("2", {"name": "Node2", "description": "Second node"}),
+ ]
+ edges_data = [
+ ("1", "2", "CONNECTS_TO", {"relationship_name": "connects"}),
+ ]
+
+ mock_adapter.get_graph_data = AsyncMock(return_value=(nodes_data, edges_data))
+
+ await graph.project_graph_from_db(
+ adapter=mock_adapter,
+ node_properties_to_project=["name", "description"],
+ edge_properties_to_project=["relationship_name"],
+ )
+
+ assert len(graph.nodes) == 2
+ assert len(graph.edges) == 1
+ assert graph.get_node("1") is not None
+ assert graph.get_node("2") is not None
+ assert graph.edges[0].node1.id == "1"
+ assert graph.edges[0].node2.id == "2"
+
+
+@pytest.mark.asyncio
+async def test_project_graph_from_db_id_filtered(setup_graph, mock_adapter):
+ """Test projecting an ID-filtered graph from database."""
+ graph = setup_graph
+
+ nodes_data = [
+ ("1", {"name": "Node1"}),
+ ("2", {"name": "Node2"}),
+ ]
+ edges_data = [
+ ("1", "2", "CONNECTS_TO", {"relationship_name": "connects"}),
+ ]
+
+ mock_adapter.get_id_filtered_graph_data = AsyncMock(return_value=(nodes_data, edges_data))
+
+ await graph.project_graph_from_db(
+ adapter=mock_adapter,
+ node_properties_to_project=["name"],
+ edge_properties_to_project=["relationship_name"],
+ relevant_ids_to_filter=["1", "2"],
+ )
+
+ assert len(graph.nodes) == 2
+ assert len(graph.edges) == 1
+ mock_adapter.get_id_filtered_graph_data.assert_called_once()
+
+
+@pytest.mark.asyncio
+async def test_project_graph_from_db_nodeset_subgraph(setup_graph, mock_adapter):
+ """Test projecting a nodeset subgraph filtered by node type and name."""
+ graph = setup_graph
+
+ nodes_data = [
+ ("1", {"name": "Alice", "type": "Person"}),
+ ("2", {"name": "Bob", "type": "Person"}),
+ ]
+ edges_data = [
+ ("1", "2", "KNOWS", {"relationship_name": "knows"}),
+ ]
+
+ mock_adapter.get_nodeset_subgraph = AsyncMock(return_value=(nodes_data, edges_data))
+
+ await graph.project_graph_from_db(
+ adapter=mock_adapter,
+ node_properties_to_project=["name", "type"],
+ edge_properties_to_project=["relationship_name"],
+ node_type="Person",
+ node_name=["Alice"],
+ )
+
+ assert len(graph.nodes) == 2
+ assert graph.get_node("1") is not None
+ assert len(graph.edges) == 1
+ mock_adapter.get_nodeset_subgraph.assert_called_once()
+
+
+@pytest.mark.asyncio
+async def test_project_graph_from_db_empty_graph(setup_graph, mock_adapter):
+ """Test projecting empty graph raises EntityNotFoundError."""
+ graph = setup_graph
+
+ mock_adapter.get_graph_data = AsyncMock(return_value=([], []))
+
+ with pytest.raises(EntityNotFoundError, match="Empty graph projected from the database."):
+ await graph.project_graph_from_db(
+ adapter=mock_adapter,
+ node_properties_to_project=["name"],
+ edge_properties_to_project=[],
+ )
+
+
+@pytest.mark.asyncio
+async def test_project_graph_from_db_missing_nodes(setup_graph, mock_adapter):
+ """Test that edges referencing missing nodes raise error."""
+ graph = setup_graph
+
+ nodes_data = [
+ ("1", {"name": "Node1"}),
+ ]
+ edges_data = [
+ ("1", "999", "CONNECTS_TO", {"relationship_name": "connects"}),
+ ]
+
+ mock_adapter.get_graph_data = AsyncMock(return_value=(nodes_data, edges_data))
+
+ with pytest.raises(EntityNotFoundError, match="Edge references nonexistent nodes"):
+ await graph.project_graph_from_db(
+ adapter=mock_adapter,
+ node_properties_to_project=["name"],
+ edge_properties_to_project=["relationship_name"],
+ )
+
+
+@pytest.mark.asyncio
+async def test_map_vector_distances_to_graph_nodes(setup_graph):
+ """Test mapping vector distances to graph nodes."""
+ graph = setup_graph
+
+ node1 = Node("1", {"name": "Node1"})
+ node2 = Node("2", {"name": "Node2"})
+ graph.add_node(node1)
+ graph.add_node(node2)
+
+ node_distances = {
+ "Entity_name": [
+ MockScoredResult("1", 0.95),
+ MockScoredResult("2", 0.87),
+ ]
+ }
+
+ await graph.map_vector_distances_to_graph_nodes(node_distances)
+
+ assert graph.get_node("1").attributes.get("vector_distance") == 0.95
+ assert graph.get_node("2").attributes.get("vector_distance") == 0.87
+
+
+@pytest.mark.asyncio
+async def test_map_vector_distances_partial_node_coverage(setup_graph):
+ """Test mapping vector distances when only some nodes have results."""
+ graph = setup_graph
+
+ node1 = Node("1", {"name": "Node1"})
+ node2 = Node("2", {"name": "Node2"})
+ node3 = Node("3", {"name": "Node3"})
+ graph.add_node(node1)
+ graph.add_node(node2)
+ graph.add_node(node3)
+
+ node_distances = {
+ "Entity_name": [
+ MockScoredResult("1", 0.95),
+ MockScoredResult("2", 0.87),
+ ]
+ }
+
+ await graph.map_vector_distances_to_graph_nodes(node_distances)
+
+ assert graph.get_node("1").attributes.get("vector_distance") == 0.95
+ assert graph.get_node("2").attributes.get("vector_distance") == 0.87
+ assert graph.get_node("3").attributes.get("vector_distance") == 3.5
+
+
+@pytest.mark.asyncio
+async def test_map_vector_distances_multiple_categories(setup_graph):
+ """Test mapping vector distances from multiple collection categories."""
+ graph = setup_graph
+
+ # Create nodes
+ node1 = Node("1")
+ node2 = Node("2")
+ node3 = Node("3")
+ node4 = Node("4")
+ graph.add_node(node1)
+ graph.add_node(node2)
+ graph.add_node(node3)
+ graph.add_node(node4)
+
+ node_distances = {
+ "Entity_name": [
+ MockScoredResult("1", 0.95),
+ MockScoredResult("2", 0.87),
+ ],
+ "TextSummary_text": [
+ MockScoredResult("3", 0.92),
+ ],
+ }
+
+ await graph.map_vector_distances_to_graph_nodes(node_distances)
+
+ assert graph.get_node("1").attributes.get("vector_distance") == 0.95
+ assert graph.get_node("2").attributes.get("vector_distance") == 0.87
+ assert graph.get_node("3").attributes.get("vector_distance") == 0.92
+ assert graph.get_node("4").attributes.get("vector_distance") == 3.5
+
+
+@pytest.mark.asyncio
+async def test_map_vector_distances_to_graph_edges_with_payload(setup_graph, mock_vector_engine):
+ """Test mapping vector distances to edges when edge_distances provided."""
+ graph = setup_graph
+
+ node1 = Node("1")
+ node2 = Node("2")
+ graph.add_node(node1)
+ graph.add_node(node2)
+
+ edge = Edge(
+ node1,
+ node2,
+ attributes={"edge_text": "CONNECTS_TO", "relationship_type": "connects"},
+ )
+ graph.add_edge(edge)
+
+ edge_distances = [
+ MockScoredResult("e1", 0.92, payload={"text": "CONNECTS_TO"}),
+ ]
+
+ await graph.map_vector_distances_to_graph_edges(
+ vector_engine=mock_vector_engine,
+ query_vector=[0.1, 0.2, 0.3],
+ edge_distances=edge_distances,
+ )
+
+ assert graph.edges[0].attributes.get("vector_distance") == 0.92
+
+
+@pytest.mark.asyncio
+async def test_map_vector_distances_to_graph_edges_search(setup_graph, mock_vector_engine):
+ """Test mapping edge distances when searching for them."""
+ graph = setup_graph
+
+ node1 = Node("1")
+ node2 = Node("2")
+ graph.add_node(node1)
+ graph.add_node(node2)
+
+ edge = Edge(
+ node1,
+ node2,
+ attributes={"edge_text": "CONNECTS_TO", "relationship_type": "connects"},
+ )
+ graph.add_edge(edge)
+
+ mock_vector_engine.search.return_value = [
+ MockScoredResult("e1", 0.88, payload={"text": "CONNECTS_TO"}),
+ ]
+
+ await graph.map_vector_distances_to_graph_edges(
+ vector_engine=mock_vector_engine,
+ query_vector=[0.1, 0.2, 0.3],
+ edge_distances=None,
+ )
+
+ mock_vector_engine.search.assert_called_once()
+ assert graph.edges[0].attributes.get("vector_distance") == 0.88
+
+
+@pytest.mark.asyncio
+async def test_map_vector_distances_partial_edge_coverage(setup_graph, mock_vector_engine):
+ """Test mapping edge distances when only some edges have results."""
+ graph = setup_graph
+
+ node1 = Node("1")
+ node2 = Node("2")
+ node3 = Node("3")
+ graph.add_node(node1)
+ graph.add_node(node2)
+ graph.add_node(node3)
+
+ edge1 = Edge(node1, node2, attributes={"edge_text": "CONNECTS_TO"})
+ edge2 = Edge(node2, node3, attributes={"edge_text": "DEPENDS_ON"})
+ graph.add_edge(edge1)
+ graph.add_edge(edge2)
+
+ edge_distances = [
+ MockScoredResult("e1", 0.92, payload={"text": "CONNECTS_TO"}),
+ ]
+
+ await graph.map_vector_distances_to_graph_edges(
+ vector_engine=mock_vector_engine,
+ query_vector=[0.1, 0.2, 0.3],
+ edge_distances=edge_distances,
+ )
+
+ assert graph.edges[0].attributes.get("vector_distance") == 0.92
+ assert graph.edges[1].attributes.get("vector_distance") == 3.5
+
+
+@pytest.mark.asyncio
+async def test_map_vector_distances_edges_fallback_to_relationship_type(
+ setup_graph, mock_vector_engine
+):
+ """Test that edge mapping falls back to relationship_type when edge_text is missing."""
+ graph = setup_graph
+
+ node1 = Node("1")
+ node2 = Node("2")
+ graph.add_node(node1)
+ graph.add_node(node2)
+
+ edge = Edge(
+ node1,
+ node2,
+ attributes={"relationship_type": "KNOWS"},
+ )
+ graph.add_edge(edge)
+
+ edge_distances = [
+ MockScoredResult("e1", 0.85, payload={"text": "KNOWS"}),
+ ]
+
+ await graph.map_vector_distances_to_graph_edges(
+ vector_engine=mock_vector_engine,
+ query_vector=[0.1, 0.2, 0.3],
+ edge_distances=edge_distances,
+ )
+
+ assert graph.edges[0].attributes.get("vector_distance") == 0.85
+
+
+@pytest.mark.asyncio
+async def test_map_vector_distances_no_edge_matches(setup_graph, mock_vector_engine):
+ """Test edge mapping when no edges match the distance results."""
+ graph = setup_graph
+
+ node1 = Node("1")
+ node2 = Node("2")
+ graph.add_node(node1)
+ graph.add_node(node2)
+
+ edge = Edge(
+ node1,
+ node2,
+ attributes={"edge_text": "CONNECTS_TO", "relationship_type": "connects"},
+ )
+ graph.add_edge(edge)
+
+ edge_distances = [
+ MockScoredResult("e1", 0.92, payload={"text": "SOME_OTHER_EDGE"}),
+ ]
+
+ await graph.map_vector_distances_to_graph_edges(
+ vector_engine=mock_vector_engine,
+ query_vector=[0.1, 0.2, 0.3],
+ edge_distances=edge_distances,
+ )
+
+ assert graph.edges[0].attributes.get("vector_distance") == 3.5
+
+
+@pytest.mark.asyncio
+async def test_map_vector_distances_invalid_query_vector(setup_graph, mock_vector_engine):
+ """Test that invalid query vector raises error."""
+ graph = setup_graph
+
+ with pytest.raises(ValueError, match="Failed to generate query embedding"):
+ await graph.map_vector_distances_to_graph_edges(
+ vector_engine=mock_vector_engine,
+ query_vector=[],
+ edge_distances=None,
+ )
+
+
+@pytest.mark.asyncio
+async def test_calculate_top_triplet_importances(setup_graph):
+ """Test calculating top triplet importances by score."""
+ graph = setup_graph
+
+ node1 = Node("1")
+ node2 = Node("2")
+ node3 = Node("3")
+ node4 = Node("4")
+
+ node1.add_attribute("vector_distance", 0.9)
+ node2.add_attribute("vector_distance", 0.8)
+ node3.add_attribute("vector_distance", 0.7)
+ node4.add_attribute("vector_distance", 0.6)
+
+ graph.add_node(node1)
+ graph.add_node(node2)
+ graph.add_node(node3)
+ graph.add_node(node4)
+
+ edge1 = Edge(node1, node2)
+ edge2 = Edge(node2, node3)
+ edge3 = Edge(node3, node4)
+
+ edge1.add_attribute("vector_distance", 0.85)
+ edge2.add_attribute("vector_distance", 0.75)
+ edge3.add_attribute("vector_distance", 0.65)
+
+ graph.add_edge(edge1)
+ graph.add_edge(edge2)
+ graph.add_edge(edge3)
+
+ top_triplets = await graph.calculate_top_triplet_importances(k=2)
+
+ assert len(top_triplets) == 2
+
+ assert top_triplets[0] == edge3
+ assert top_triplets[1] == edge2
+
+
+@pytest.mark.asyncio
+async def test_calculate_top_triplet_importances_default_distances(setup_graph):
+ """Test calculating importances when nodes/edges have no vector distances."""
+ graph = setup_graph
+
+ node1 = Node("1")
+ node2 = Node("2")
+ graph.add_node(node1)
+ graph.add_node(node2)
+
+ edge = Edge(node1, node2)
+ graph.add_edge(edge)
+
+ top_triplets = await graph.calculate_top_triplet_importances(k=1)
+
+ assert len(top_triplets) == 1
+ assert top_triplets[0] == edge
diff --git a/cognee/tests/unit/modules/retrieval/test_brute_force_triplet_search.py b/cognee/tests/unit/modules/retrieval/test_brute_force_triplet_search.py
new file mode 100644
index 000000000..5eb6fb105
--- /dev/null
+++ b/cognee/tests/unit/modules/retrieval/test_brute_force_triplet_search.py
@@ -0,0 +1,582 @@
+import pytest
+from unittest.mock import AsyncMock, patch
+
+from cognee.modules.retrieval.utils.brute_force_triplet_search import (
+ brute_force_triplet_search,
+ get_memory_fragment,
+)
+from cognee.modules.graph.cognee_graph.CogneeGraph import CogneeGraph
+from cognee.modules.graph.exceptions.exceptions import EntityNotFoundError
+
+
+class MockScoredResult:
+ """Mock class for vector search results."""
+
+ def __init__(self, id, score, payload=None):
+ self.id = id
+ self.score = score
+ self.payload = payload or {}
+
+
+@pytest.mark.asyncio
+async def test_brute_force_triplet_search_empty_query():
+ """Test that empty query raises ValueError."""
+ with pytest.raises(ValueError, match="The query must be a non-empty string."):
+ await brute_force_triplet_search(query="")
+
+
+@pytest.mark.asyncio
+async def test_brute_force_triplet_search_none_query():
+ """Test that None query raises ValueError."""
+ with pytest.raises(ValueError, match="The query must be a non-empty string."):
+ await brute_force_triplet_search(query=None)
+
+
+@pytest.mark.asyncio
+async def test_brute_force_triplet_search_negative_top_k():
+ """Test that negative top_k raises ValueError."""
+ with pytest.raises(ValueError, match="top_k must be a positive integer."):
+ await brute_force_triplet_search(query="test query", top_k=-1)
+
+
+@pytest.mark.asyncio
+async def test_brute_force_triplet_search_zero_top_k():
+ """Test that zero top_k raises ValueError."""
+ with pytest.raises(ValueError, match="top_k must be a positive integer."):
+ await brute_force_triplet_search(query="test query", top_k=0)
+
+
+@pytest.mark.asyncio
+async def test_brute_force_triplet_search_wide_search_limit_global_search():
+ """Test that wide_search_limit is applied for global search (node_name=None)."""
+ mock_vector_engine = AsyncMock()
+ mock_vector_engine.embedding_engine = AsyncMock()
+ mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
+ mock_vector_engine.search = AsyncMock(return_value=[])
+
+ with patch(
+ "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
+ return_value=mock_vector_engine,
+ ):
+ await brute_force_triplet_search(
+ query="test",
+ node_name=None, # Global search
+ wide_search_top_k=75,
+ )
+
+ for call in mock_vector_engine.search.call_args_list:
+ assert call[1]["limit"] == 75
+
+
+@pytest.mark.asyncio
+async def test_brute_force_triplet_search_wide_search_limit_filtered_search():
+ """Test that wide_search_limit is None for filtered search (node_name provided)."""
+ mock_vector_engine = AsyncMock()
+ mock_vector_engine.embedding_engine = AsyncMock()
+ mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
+ mock_vector_engine.search = AsyncMock(return_value=[])
+
+ with patch(
+ "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
+ return_value=mock_vector_engine,
+ ):
+ await brute_force_triplet_search(
+ query="test",
+ node_name=["Node1"],
+ wide_search_top_k=50,
+ )
+
+ for call in mock_vector_engine.search.call_args_list:
+ assert call[1]["limit"] is None
+
+
+@pytest.mark.asyncio
+async def test_brute_force_triplet_search_wide_search_default():
+ """Test that wide_search_top_k defaults to 100."""
+ mock_vector_engine = AsyncMock()
+ mock_vector_engine.embedding_engine = AsyncMock()
+ mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
+ mock_vector_engine.search = AsyncMock(return_value=[])
+
+ with patch(
+ "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
+ return_value=mock_vector_engine,
+ ):
+ await brute_force_triplet_search(query="test", node_name=None)
+
+ for call in mock_vector_engine.search.call_args_list:
+ assert call[1]["limit"] == 100
+
+
+@pytest.mark.asyncio
+async def test_brute_force_triplet_search_default_collections():
+ """Test that default collections are used when none provided."""
+ mock_vector_engine = AsyncMock()
+ mock_vector_engine.embedding_engine = AsyncMock()
+ mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
+ mock_vector_engine.search = AsyncMock(return_value=[])
+
+ with patch(
+ "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
+ return_value=mock_vector_engine,
+ ):
+ await brute_force_triplet_search(query="test")
+
+ expected_collections = [
+ "Entity_name",
+ "TextSummary_text",
+ "EntityType_name",
+ "DocumentChunk_text",
+ ]
+
+ call_collections = [
+ call[1]["collection_name"] for call in mock_vector_engine.search.call_args_list
+ ]
+ assert call_collections == expected_collections
+
+
+@pytest.mark.asyncio
+async def test_brute_force_triplet_search_custom_collections():
+ """Test that custom collections are used when provided."""
+ mock_vector_engine = AsyncMock()
+ mock_vector_engine.embedding_engine = AsyncMock()
+ mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
+ mock_vector_engine.search = AsyncMock(return_value=[])
+
+ custom_collections = ["CustomCol1", "CustomCol2"]
+
+ with patch(
+ "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
+ return_value=mock_vector_engine,
+ ):
+ await brute_force_triplet_search(query="test", collections=custom_collections)
+
+ call_collections = [
+ call[1]["collection_name"] for call in mock_vector_engine.search.call_args_list
+ ]
+ assert call_collections == custom_collections
+
+
+@pytest.mark.asyncio
+async def test_brute_force_triplet_search_all_collections_empty():
+ """Test that empty list is returned when all collections return no results."""
+ mock_vector_engine = AsyncMock()
+ mock_vector_engine.embedding_engine = AsyncMock()
+ mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
+ mock_vector_engine.search = AsyncMock(return_value=[])
+
+ with patch(
+ "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
+ return_value=mock_vector_engine,
+ ):
+ results = await brute_force_triplet_search(query="test")
+ assert results == []
+
+
+# Tests for query embedding
+
+
+@pytest.mark.asyncio
+async def test_brute_force_triplet_search_embeds_query():
+ """Test that query is embedded before searching."""
+ query_text = "test query"
+ expected_vector = [0.1, 0.2, 0.3]
+
+ mock_vector_engine = AsyncMock()
+ mock_vector_engine.embedding_engine = AsyncMock()
+ mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[expected_vector])
+ mock_vector_engine.search = AsyncMock(return_value=[])
+
+ with patch(
+ "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
+ return_value=mock_vector_engine,
+ ):
+ await brute_force_triplet_search(query=query_text)
+
+ mock_vector_engine.embedding_engine.embed_text.assert_called_once_with([query_text])
+
+ for call in mock_vector_engine.search.call_args_list:
+ assert call[1]["query_vector"] == expected_vector
+
+
+@pytest.mark.asyncio
+async def test_brute_force_triplet_search_extracts_node_ids_global_search():
+ """Test that node IDs are extracted from search results for global search."""
+ scored_results = [
+ MockScoredResult("node1", 0.95),
+ MockScoredResult("node2", 0.87),
+ MockScoredResult("node3", 0.92),
+ ]
+
+ mock_vector_engine = AsyncMock()
+ mock_vector_engine.embedding_engine = AsyncMock()
+ mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
+ mock_vector_engine.search = AsyncMock(return_value=scored_results)
+
+ mock_fragment = AsyncMock(
+ map_vector_distances_to_graph_nodes=AsyncMock(),
+ map_vector_distances_to_graph_edges=AsyncMock(),
+ calculate_top_triplet_importances=AsyncMock(return_value=[]),
+ )
+
+ with (
+ patch(
+ "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
+ return_value=mock_vector_engine,
+ ),
+ patch(
+ "cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment",
+ return_value=mock_fragment,
+ ) as mock_get_fragment_fn,
+ ):
+ await brute_force_triplet_search(query="test", node_name=None)
+
+ call_kwargs = mock_get_fragment_fn.call_args[1]
+ assert set(call_kwargs["relevant_ids_to_filter"]) == {"node1", "node2", "node3"}
+
+
+@pytest.mark.asyncio
+async def test_brute_force_triplet_search_reuses_provided_fragment():
+ """Test that provided memory fragment is reused instead of creating new one."""
+ provided_fragment = AsyncMock(
+ map_vector_distances_to_graph_nodes=AsyncMock(),
+ map_vector_distances_to_graph_edges=AsyncMock(),
+ calculate_top_triplet_importances=AsyncMock(return_value=[]),
+ )
+
+ mock_vector_engine = AsyncMock()
+ mock_vector_engine.embedding_engine = AsyncMock()
+ mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
+ mock_vector_engine.search = AsyncMock(return_value=[MockScoredResult("n1", 0.95)])
+
+ with (
+ patch(
+ "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
+ return_value=mock_vector_engine,
+ ),
+ patch(
+ "cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment"
+ ) as mock_get_fragment,
+ ):
+ await brute_force_triplet_search(
+ query="test",
+ memory_fragment=provided_fragment,
+ node_name=["node"],
+ )
+
+ mock_get_fragment.assert_not_called()
+
+
+@pytest.mark.asyncio
+async def test_brute_force_triplet_search_creates_fragment_when_not_provided():
+ """Test that memory fragment is created when not provided."""
+ mock_vector_engine = AsyncMock()
+ mock_vector_engine.embedding_engine = AsyncMock()
+ mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
+ mock_vector_engine.search = AsyncMock(return_value=[MockScoredResult("n1", 0.95)])
+
+ mock_fragment = AsyncMock(
+ map_vector_distances_to_graph_nodes=AsyncMock(),
+ map_vector_distances_to_graph_edges=AsyncMock(),
+ calculate_top_triplet_importances=AsyncMock(return_value=[]),
+ )
+
+ with (
+ patch(
+ "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
+ return_value=mock_vector_engine,
+ ),
+ patch(
+ "cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment",
+ return_value=mock_fragment,
+ ) as mock_get_fragment,
+ ):
+ await brute_force_triplet_search(query="test", node_name=["node"])
+
+ mock_get_fragment.assert_called_once()
+
+
+@pytest.mark.asyncio
+async def test_brute_force_triplet_search_passes_top_k_to_importance_calculation():
+ """Test that custom top_k is passed to importance calculation."""
+ mock_vector_engine = AsyncMock()
+ mock_vector_engine.embedding_engine = AsyncMock()
+ mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
+ mock_vector_engine.search = AsyncMock(return_value=[MockScoredResult("n1", 0.95)])
+
+ mock_fragment = AsyncMock(
+ map_vector_distances_to_graph_nodes=AsyncMock(),
+ map_vector_distances_to_graph_edges=AsyncMock(),
+ calculate_top_triplet_importances=AsyncMock(return_value=[]),
+ )
+
+ with (
+ patch(
+ "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
+ return_value=mock_vector_engine,
+ ),
+ patch(
+ "cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment",
+ return_value=mock_fragment,
+ ),
+ ):
+ custom_top_k = 15
+ await brute_force_triplet_search(query="test", top_k=custom_top_k, node_name=["n"])
+
+ mock_fragment.calculate_top_triplet_importances.assert_called_once_with(k=custom_top_k)
+
+
+@pytest.mark.asyncio
+async def test_get_memory_fragment_returns_empty_graph_on_entity_not_found():
+ """Test that get_memory_fragment returns empty graph when entity not found."""
+ mock_graph_engine = AsyncMock()
+ mock_graph_engine.project_graph_from_db = AsyncMock(
+ side_effect=EntityNotFoundError("Entity not found")
+ )
+
+ with patch(
+ "cognee.modules.retrieval.utils.brute_force_triplet_search.get_graph_engine",
+ return_value=mock_graph_engine,
+ ):
+ fragment = await get_memory_fragment()
+
+ assert isinstance(fragment, CogneeGraph)
+ assert len(fragment.nodes) == 0
+
+
+@pytest.mark.asyncio
+async def test_get_memory_fragment_returns_empty_graph_on_error():
+ """Test that get_memory_fragment returns empty graph on generic error."""
+ mock_graph_engine = AsyncMock()
+ mock_graph_engine.project_graph_from_db = AsyncMock(side_effect=Exception("Generic error"))
+
+ with patch(
+ "cognee.modules.retrieval.utils.brute_force_triplet_search.get_graph_engine",
+ return_value=mock_graph_engine,
+ ):
+ fragment = await get_memory_fragment()
+
+ assert isinstance(fragment, CogneeGraph)
+ assert len(fragment.nodes) == 0
+
+
+@pytest.mark.asyncio
+async def test_brute_force_triplet_search_deduplicates_node_ids():
+ """Test that duplicate node IDs across collections are deduplicated."""
+
+ def search_side_effect(*args, **kwargs):
+ collection_name = kwargs.get("collection_name")
+ if collection_name == "Entity_name":
+ return [
+ MockScoredResult("node1", 0.95),
+ MockScoredResult("node2", 0.87),
+ ]
+ elif collection_name == "TextSummary_text":
+ return [
+ MockScoredResult("node1", 0.90),
+ MockScoredResult("node3", 0.92),
+ ]
+ else:
+ return []
+
+ mock_vector_engine = AsyncMock()
+ mock_vector_engine.embedding_engine = AsyncMock()
+ mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
+ mock_vector_engine.search = AsyncMock(side_effect=search_side_effect)
+
+ mock_fragment = AsyncMock(
+ map_vector_distances_to_graph_nodes=AsyncMock(),
+ map_vector_distances_to_graph_edges=AsyncMock(),
+ calculate_top_triplet_importances=AsyncMock(return_value=[]),
+ )
+
+ with (
+ patch(
+ "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
+ return_value=mock_vector_engine,
+ ),
+ patch(
+ "cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment",
+ return_value=mock_fragment,
+ ) as mock_get_fragment_fn,
+ ):
+ await brute_force_triplet_search(query="test", node_name=None)
+
+ call_kwargs = mock_get_fragment_fn.call_args[1]
+ assert set(call_kwargs["relevant_ids_to_filter"]) == {"node1", "node2", "node3"}
+ assert len(call_kwargs["relevant_ids_to_filter"]) == 3
+
+
+@pytest.mark.asyncio
+async def test_brute_force_triplet_search_excludes_edge_collection():
+ """Test that EdgeType_relationship_name collection is excluded from ID extraction."""
+
+ def search_side_effect(*args, **kwargs):
+ collection_name = kwargs.get("collection_name")
+ if collection_name == "Entity_name":
+ return [MockScoredResult("node1", 0.95)]
+ elif collection_name == "EdgeType_relationship_name":
+ return [MockScoredResult("edge1", 0.88)]
+ else:
+ return []
+
+ mock_vector_engine = AsyncMock()
+ mock_vector_engine.embedding_engine = AsyncMock()
+ mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
+ mock_vector_engine.search = AsyncMock(side_effect=search_side_effect)
+
+ mock_fragment = AsyncMock(
+ map_vector_distances_to_graph_nodes=AsyncMock(),
+ map_vector_distances_to_graph_edges=AsyncMock(),
+ calculate_top_triplet_importances=AsyncMock(return_value=[]),
+ )
+
+ with (
+ patch(
+ "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
+ return_value=mock_vector_engine,
+ ),
+ patch(
+ "cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment",
+ return_value=mock_fragment,
+ ) as mock_get_fragment_fn,
+ ):
+ await brute_force_triplet_search(
+ query="test",
+ node_name=None,
+ collections=["Entity_name", "EdgeType_relationship_name"],
+ )
+
+ call_kwargs = mock_get_fragment_fn.call_args[1]
+ assert call_kwargs["relevant_ids_to_filter"] == ["node1"]
+
+
+@pytest.mark.asyncio
+async def test_brute_force_triplet_search_skips_nodes_without_ids():
+ """Test that nodes without ID attribute are skipped."""
+
+ class ScoredResultNoId:
+ """Mock result without id attribute."""
+
+ def __init__(self, score):
+ self.score = score
+
+ def search_side_effect(*args, **kwargs):
+ collection_name = kwargs.get("collection_name")
+ if collection_name == "Entity_name":
+ return [
+ MockScoredResult("node1", 0.95),
+ ScoredResultNoId(0.90),
+ MockScoredResult("node2", 0.87),
+ ]
+ else:
+ return []
+
+ mock_vector_engine = AsyncMock()
+ mock_vector_engine.embedding_engine = AsyncMock()
+ mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
+ mock_vector_engine.search = AsyncMock(side_effect=search_side_effect)
+
+ mock_fragment = AsyncMock(
+ map_vector_distances_to_graph_nodes=AsyncMock(),
+ map_vector_distances_to_graph_edges=AsyncMock(),
+ calculate_top_triplet_importances=AsyncMock(return_value=[]),
+ )
+
+ with (
+ patch(
+ "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
+ return_value=mock_vector_engine,
+ ),
+ patch(
+ "cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment",
+ return_value=mock_fragment,
+ ) as mock_get_fragment_fn,
+ ):
+ await brute_force_triplet_search(query="test", node_name=None)
+
+ call_kwargs = mock_get_fragment_fn.call_args[1]
+ assert set(call_kwargs["relevant_ids_to_filter"]) == {"node1", "node2"}
+
+
+@pytest.mark.asyncio
+async def test_brute_force_triplet_search_handles_tuple_results():
+ """Test that both list and tuple results are handled correctly."""
+
+ def search_side_effect(*args, **kwargs):
+ collection_name = kwargs.get("collection_name")
+ if collection_name == "Entity_name":
+ return (
+ MockScoredResult("node1", 0.95),
+ MockScoredResult("node2", 0.87),
+ )
+ else:
+ return []
+
+ mock_vector_engine = AsyncMock()
+ mock_vector_engine.embedding_engine = AsyncMock()
+ mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
+ mock_vector_engine.search = AsyncMock(side_effect=search_side_effect)
+
+ mock_fragment = AsyncMock(
+ map_vector_distances_to_graph_nodes=AsyncMock(),
+ map_vector_distances_to_graph_edges=AsyncMock(),
+ calculate_top_triplet_importances=AsyncMock(return_value=[]),
+ )
+
+ with (
+ patch(
+ "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
+ return_value=mock_vector_engine,
+ ),
+ patch(
+ "cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment",
+ return_value=mock_fragment,
+ ) as mock_get_fragment_fn,
+ ):
+ await brute_force_triplet_search(query="test", node_name=None)
+
+ call_kwargs = mock_get_fragment_fn.call_args[1]
+ assert set(call_kwargs["relevant_ids_to_filter"]) == {"node1", "node2"}
+
+
+@pytest.mark.asyncio
+async def test_brute_force_triplet_search_mixed_empty_collections():
+ """Test ID extraction with mixed empty and non-empty collections."""
+
+ def search_side_effect(*args, **kwargs):
+ collection_name = kwargs.get("collection_name")
+ if collection_name == "Entity_name":
+ return [MockScoredResult("node1", 0.95)]
+ elif collection_name == "TextSummary_text":
+ return []
+ elif collection_name == "EntityType_name":
+ return [MockScoredResult("node2", 0.92)]
+ else:
+ return []
+
+ mock_vector_engine = AsyncMock()
+ mock_vector_engine.embedding_engine = AsyncMock()
+ mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
+ mock_vector_engine.search = AsyncMock(side_effect=search_side_effect)
+
+ mock_fragment = AsyncMock(
+ map_vector_distances_to_graph_nodes=AsyncMock(),
+ map_vector_distances_to_graph_edges=AsyncMock(),
+ calculate_top_triplet_importances=AsyncMock(return_value=[]),
+ )
+
+ with (
+ patch(
+ "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
+ return_value=mock_vector_engine,
+ ),
+ patch(
+ "cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment",
+ return_value=mock_fragment,
+ ) as mock_get_fragment_fn,
+ ):
+ await brute_force_triplet_search(query="test", node_name=None)
+
+ call_kwargs = mock_get_fragment_fn.call_args[1]
+ assert set(call_kwargs["relevant_ids_to_filter"]) == {"node1", "node2"}
diff --git a/cognee/tests/unit/processing/chunks/chunk_by_row_test.py b/cognee/tests/unit/processing/chunks/chunk_by_row_test.py
new file mode 100644
index 000000000..7d6a73a06
--- /dev/null
+++ b/cognee/tests/unit/processing/chunks/chunk_by_row_test.py
@@ -0,0 +1,52 @@
+from itertools import product
+
+import numpy as np
+import pytest
+
+from cognee.infrastructure.databases.vector.embeddings import get_embedding_engine
+from cognee.tasks.chunks import chunk_by_row
+
+INPUT_TEXTS = "name: John, age: 30, city: New York, country: USA"
+max_chunk_size_vals = [8, 32]
+
+
+@pytest.mark.parametrize(
+ "input_text,max_chunk_size",
+ list(product([INPUT_TEXTS], max_chunk_size_vals)),
+)
+def test_chunk_by_row_isomorphism(input_text, max_chunk_size):
+ chunks = chunk_by_row(input_text, max_chunk_size)
+ reconstructed_text = ", ".join([chunk["text"] for chunk in chunks])
+ assert reconstructed_text == input_text, (
+ f"texts are not identical: {len(input_text) = }, {len(reconstructed_text) = }"
+ )
+
+
+@pytest.mark.parametrize(
+ "input_text,max_chunk_size",
+ list(product([INPUT_TEXTS], max_chunk_size_vals)),
+)
+def test_row_chunk_length(input_text, max_chunk_size):
+ chunks = list(chunk_by_row(data=input_text, max_chunk_size=max_chunk_size))
+ embedding_engine = get_embedding_engine()
+
+ chunk_lengths = np.array(
+ [embedding_engine.tokenizer.count_tokens(chunk["text"]) for chunk in chunks]
+ )
+
+ larger_chunks = chunk_lengths[chunk_lengths > max_chunk_size]
+ assert np.all(chunk_lengths <= max_chunk_size), (
+ f"{max_chunk_size = }: {larger_chunks} are too large"
+ )
+
+
+@pytest.mark.parametrize(
+ "input_text,max_chunk_size",
+ list(product([INPUT_TEXTS], max_chunk_size_vals)),
+)
+def test_chunk_by_row_chunk_numbering(input_text, max_chunk_size):
+ chunks = chunk_by_row(data=input_text, max_chunk_size=max_chunk_size)
+ chunk_indices = np.array([chunk["chunk_index"] for chunk in chunks])
+ assert np.all(chunk_indices == np.arange(len(chunk_indices))), (
+ f"{chunk_indices = } are not monotonically increasing"
+ )
diff --git a/poetry.lock b/poetry.lock
index 67de51633..6e88ccd22 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -11656,7 +11656,9 @@ groups = ["main"]
files = [
{file = "SQLAlchemy-2.0.43-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:21ba7a08a4253c5825d1db389d4299f64a100ef9800e4624c8bf70d8f136e6ed"},
{file = "SQLAlchemy-2.0.43-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:11b9503fa6f8721bef9b8567730f664c5a5153d25e247aadc69247c4bc605227"},
+ {file = "SQLAlchemy-2.0.43-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:07097c0a1886c150ef2adba2ff7437e84d40c0f7dcb44a2c2b9c905ccfc6361c"},
{file = "SQLAlchemy-2.0.43-cp37-cp37m-musllinux_1_2_aarch64.whl", hash = "sha256:cdeff998cb294896a34e5b2f00e383e7c5c4ef3b4bfa375d9104723f15186443"},
+ {file = "SQLAlchemy-2.0.43-cp37-cp37m-musllinux_1_2_x86_64.whl", hash = "sha256:bcf0724a62a5670e5718957e05c56ec2d6850267ea859f8ad2481838f889b42c"},
{file = "SQLAlchemy-2.0.43-cp37-cp37m-win32.whl", hash = "sha256:c697575d0e2b0a5f0433f679bda22f63873821d991e95a90e9e52aae517b2e32"},
{file = "SQLAlchemy-2.0.43-cp37-cp37m-win_amd64.whl", hash = "sha256:d34c0f6dbefd2e816e8f341d0df7d4763d382e3f452423e752ffd1e213da2512"},
{file = "sqlalchemy-2.0.43-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:70322986c0c699dca241418fcf18e637a4369e0ec50540a2b907b184c8bca069"},
@@ -11691,12 +11693,20 @@ files = [
{file = "sqlalchemy-2.0.43-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:9df7126fd9db49e3a5a3999442cc67e9ee8971f3cb9644250107d7296cb2a164"},
{file = "sqlalchemy-2.0.43-cp313-cp313-win32.whl", hash = "sha256:7f1ac7828857fcedb0361b48b9ac4821469f7694089d15550bbcf9ab22564a1d"},
{file = "sqlalchemy-2.0.43-cp313-cp313-win_amd64.whl", hash = "sha256:971ba928fcde01869361f504fcff3b7143b47d30de188b11c6357c0505824197"},
+ {file = "sqlalchemy-2.0.43-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:4e6aeb2e0932f32950cf56a8b4813cb15ff792fc0c9b3752eaf067cfe298496a"},
+ {file = "sqlalchemy-2.0.43-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:61f964a05356f4bca4112e6334ed7c208174511bd56e6b8fc86dad4d024d4185"},
{file = "sqlalchemy-2.0.43-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:46293c39252f93ea0910aababa8752ad628bcce3a10d3f260648dd472256983f"},
+ {file = "sqlalchemy-2.0.43-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:136063a68644eca9339d02e6693932116f6a8591ac013b0014479a1de664e40a"},
{file = "sqlalchemy-2.0.43-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:6e2bf13d9256398d037fef09fd8bf9b0bf77876e22647d10761d35593b9ac547"},
+ {file = "sqlalchemy-2.0.43-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:44337823462291f17f994d64282a71c51d738fc9ef561bf265f1d0fd9116a782"},
{file = "sqlalchemy-2.0.43-cp38-cp38-win32.whl", hash = "sha256:13194276e69bb2af56198fef7909d48fd34820de01d9c92711a5fa45497cc7ed"},
{file = "sqlalchemy-2.0.43-cp38-cp38-win_amd64.whl", hash = "sha256:334f41fa28de9f9be4b78445e68530da3c5fa054c907176460c81494f4ae1f5e"},
+ {file = "sqlalchemy-2.0.43-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:ceb5c832cc30663aeaf5e39657712f4c4241ad1f638d487ef7216258f6d41fe7"},
+ {file = "sqlalchemy-2.0.43-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:11f43c39b4b2ec755573952bbcc58d976779d482f6f832d7f33a8d869ae891bf"},
{file = "sqlalchemy-2.0.43-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:413391b2239db55be14fa4223034d7e13325a1812c8396ecd4f2c08696d5ccad"},
+ {file = "sqlalchemy-2.0.43-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c379e37b08c6c527181a397212346be39319fb64323741d23e46abd97a400d34"},
{file = "sqlalchemy-2.0.43-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:03d73ab2a37d9e40dec4984d1813d7878e01dbdc742448d44a7341b7a9f408c7"},
+ {file = "sqlalchemy-2.0.43-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:8cee08f15d9e238ede42e9bbc1d6e7158d0ca4f176e4eab21f88ac819ae3bd7b"},
{file = "sqlalchemy-2.0.43-cp39-cp39-win32.whl", hash = "sha256:b3edaec7e8b6dc5cd94523c6df4f294014df67097c8217a89929c99975811414"},
{file = "sqlalchemy-2.0.43-cp39-cp39-win_amd64.whl", hash = "sha256:227119ce0a89e762ecd882dc661e0aa677a690c914e358f0dd8932a2e8b2765b"},
{file = "sqlalchemy-2.0.43-py3-none-any.whl", hash = "sha256:1681c21dd2ccee222c2fe0bef671d1aef7c504087c9c4e800371cfcc8ac966fc"},
diff --git a/pyproject.toml b/pyproject.toml
index 13266f83e..a9b895dfb 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,7 +1,7 @@
[project]
name = "cognee"
-version = "0.3.9"
+version = "0.5.0.dev0"
description = "Cognee - is a library for enriching LLM context with a semantic layer for better understanding and reasoning."
authors = [
{ name = "Vasilije Markovic" },
@@ -156,7 +156,6 @@ Homepage = "https://www.cognee.ai"
Repository = "https://github.com/topoteretes/cognee"
[project.scripts]
-cognee = "cognee.cli._cognee:main"
cognee-cli = "cognee.cli._cognee:main"
[build-system]
diff --git a/uv.lock b/uv.lock
index 8c35a3366..cc66c3d7e 100644
--- a/uv.lock
+++ b/uv.lock
@@ -929,7 +929,7 @@ wheels = [
[[package]]
name = "cognee"
-version = "0.3.9"
+version = "0.5.0.dev0"
source = { editable = "." }
dependencies = [
{ name = "aiofiles" },
@@ -2560,6 +2560,8 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/7f/91/ae2eb6b7979e2f9b035a9f612cf70f1bf54aad4e1d125129bef1eae96f19/greenlet-3.2.4-cp310-cp310-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:c2ca18a03a8cfb5b25bc1cbe20f3d9a4c80d8c3b13ba3df49ac3961af0b1018d", size = 584358, upload-time = "2025-08-07T13:18:23.708Z" },
{ url = "https://files.pythonhosted.org/packages/f7/85/433de0c9c0252b22b16d413c9407e6cb3b41df7389afc366ca204dbc1393/greenlet-3.2.4-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:9fe0a28a7b952a21e2c062cd5756d34354117796c6d9215a87f55e38d15402c5", size = 1113550, upload-time = "2025-08-07T13:42:37.467Z" },
{ url = "https://files.pythonhosted.org/packages/a1/8d/88f3ebd2bc96bf7747093696f4335a0a8a4c5acfcf1b757717c0d2474ba3/greenlet-3.2.4-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:8854167e06950ca75b898b104b63cc646573aa5fef1353d4508ecdd1ee76254f", size = 1137126, upload-time = "2025-08-07T13:18:20.239Z" },
+ { url = "https://files.pythonhosted.org/packages/f1/29/74242b7d72385e29bcc5563fba67dad94943d7cd03552bac320d597f29b2/greenlet-3.2.4-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:f47617f698838ba98f4ff4189aef02e7343952df3a615f847bb575c3feb177a7", size = 1544904, upload-time = "2025-11-04T12:42:04.763Z" },
+ { url = "https://files.pythonhosted.org/packages/c8/e2/1572b8eeab0f77df5f6729d6ab6b141e4a84ee8eb9bc8c1e7918f94eda6d/greenlet-3.2.4-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:af41be48a4f60429d5cad9d22175217805098a9ef7c40bfef44f7669fb9d74d8", size = 1611228, upload-time = "2025-11-04T12:42:08.423Z" },
{ url = "https://files.pythonhosted.org/packages/d6/6f/b60b0291d9623c496638c582297ead61f43c4b72eef5e9c926ef4565ec13/greenlet-3.2.4-cp310-cp310-win_amd64.whl", hash = "sha256:73f49b5368b5359d04e18d15828eecc1806033db5233397748f4ca813ff1056c", size = 298654, upload-time = "2025-08-07T13:50:00.469Z" },
{ url = "https://files.pythonhosted.org/packages/a4/de/f28ced0a67749cac23fecb02b694f6473f47686dff6afaa211d186e2ef9c/greenlet-3.2.4-cp311-cp311-macosx_11_0_universal2.whl", hash = "sha256:96378df1de302bc38e99c3a9aa311967b7dc80ced1dcc6f171e99842987882a2", size = 272305, upload-time = "2025-08-07T13:15:41.288Z" },
{ url = "https://files.pythonhosted.org/packages/09/16/2c3792cba130000bf2a31c5272999113f4764fd9d874fb257ff588ac779a/greenlet-3.2.4-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:1ee8fae0519a337f2329cb78bd7a8e128ec0f881073d43f023c7b8d4831d5246", size = 632472, upload-time = "2025-08-07T13:42:55.044Z" },
@@ -2569,6 +2571,8 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/1f/8e/abdd3f14d735b2929290a018ecf133c901be4874b858dd1c604b9319f064/greenlet-3.2.4-cp311-cp311-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:2523e5246274f54fdadbce8494458a2ebdcdbc7b802318466ac5606d3cded1f8", size = 587684, upload-time = "2025-08-07T13:18:25.164Z" },
{ url = "https://files.pythonhosted.org/packages/5d/65/deb2a69c3e5996439b0176f6651e0052542bb6c8f8ec2e3fba97c9768805/greenlet-3.2.4-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:1987de92fec508535687fb807a5cea1560f6196285a4cde35c100b8cd632cc52", size = 1116647, upload-time = "2025-08-07T13:42:38.655Z" },
{ url = "https://files.pythonhosted.org/packages/3f/cc/b07000438a29ac5cfb2194bfc128151d52f333cee74dd7dfe3fb733fc16c/greenlet-3.2.4-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:55e9c5affaa6775e2c6b67659f3a71684de4c549b3dd9afca3bc773533d284fa", size = 1142073, upload-time = "2025-08-07T13:18:21.737Z" },
+ { url = "https://files.pythonhosted.org/packages/67/24/28a5b2fa42d12b3d7e5614145f0bd89714c34c08be6aabe39c14dd52db34/greenlet-3.2.4-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:c9c6de1940a7d828635fbd254d69db79e54619f165ee7ce32fda763a9cb6a58c", size = 1548385, upload-time = "2025-11-04T12:42:11.067Z" },
+ { url = "https://files.pythonhosted.org/packages/6a/05/03f2f0bdd0b0ff9a4f7b99333d57b53a7709c27723ec8123056b084e69cd/greenlet-3.2.4-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:03c5136e7be905045160b1b9fdca93dd6727b180feeafda6818e6496434ed8c5", size = 1613329, upload-time = "2025-11-04T12:42:12.928Z" },
{ url = "https://files.pythonhosted.org/packages/d8/0f/30aef242fcab550b0b3520b8e3561156857c94288f0332a79928c31a52cf/greenlet-3.2.4-cp311-cp311-win_amd64.whl", hash = "sha256:9c40adce87eaa9ddb593ccb0fa6a07caf34015a29bf8d344811665b573138db9", size = 299100, upload-time = "2025-08-07T13:44:12.287Z" },
{ url = "https://files.pythonhosted.org/packages/44/69/9b804adb5fd0671f367781560eb5eb586c4d495277c93bde4307b9e28068/greenlet-3.2.4-cp312-cp312-macosx_11_0_universal2.whl", hash = "sha256:3b67ca49f54cede0186854a008109d6ee71f66bd57bb36abd6d0a0267b540cdd", size = 274079, upload-time = "2025-08-07T13:15:45.033Z" },
{ url = "https://files.pythonhosted.org/packages/46/e9/d2a80c99f19a153eff70bc451ab78615583b8dac0754cfb942223d2c1a0d/greenlet-3.2.4-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:ddf9164e7a5b08e9d22511526865780a576f19ddd00d62f8a665949327fde8bb", size = 640997, upload-time = "2025-08-07T13:42:56.234Z" },
@@ -2578,6 +2582,8 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/19/0d/6660d55f7373b2ff8152401a83e02084956da23ae58cddbfb0b330978fe9/greenlet-3.2.4-cp312-cp312-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:3b3812d8d0c9579967815af437d96623f45c0f2ae5f04e366de62a12d83a8fb0", size = 607586, upload-time = "2025-08-07T13:18:28.544Z" },
{ url = "https://files.pythonhosted.org/packages/8e/1a/c953fdedd22d81ee4629afbb38d2f9d71e37d23caace44775a3a969147d4/greenlet-3.2.4-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:abbf57b5a870d30c4675928c37278493044d7c14378350b3aa5d484fa65575f0", size = 1123281, upload-time = "2025-08-07T13:42:39.858Z" },
{ url = "https://files.pythonhosted.org/packages/3f/c7/12381b18e21aef2c6bd3a636da1088b888b97b7a0362fac2e4de92405f97/greenlet-3.2.4-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:20fb936b4652b6e307b8f347665e2c615540d4b42b3b4c8a321d8286da7e520f", size = 1151142, upload-time = "2025-08-07T13:18:22.981Z" },
+ { url = "https://files.pythonhosted.org/packages/27/45/80935968b53cfd3f33cf99ea5f08227f2646e044568c9b1555b58ffd61c2/greenlet-3.2.4-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:ee7a6ec486883397d70eec05059353b8e83eca9168b9f3f9a361971e77e0bcd0", size = 1564846, upload-time = "2025-11-04T12:42:15.191Z" },
+ { url = "https://files.pythonhosted.org/packages/69/02/b7c30e5e04752cb4db6202a3858b149c0710e5453b71a3b2aec5d78a1aab/greenlet-3.2.4-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:326d234cbf337c9c3def0676412eb7040a35a768efc92504b947b3e9cfc7543d", size = 1633814, upload-time = "2025-11-04T12:42:17.175Z" },
{ url = "https://files.pythonhosted.org/packages/e9/08/b0814846b79399e585f974bbeebf5580fbe59e258ea7be64d9dfb253c84f/greenlet-3.2.4-cp312-cp312-win_amd64.whl", hash = "sha256:a7d4e128405eea3814a12cc2605e0e6aedb4035bf32697f72deca74de4105e02", size = 299899, upload-time = "2025-08-07T13:38:53.448Z" },
{ url = "https://files.pythonhosted.org/packages/49/e8/58c7f85958bda41dafea50497cbd59738c5c43dbbea5ee83d651234398f4/greenlet-3.2.4-cp313-cp313-macosx_11_0_universal2.whl", hash = "sha256:1a921e542453fe531144e91e1feedf12e07351b1cf6c9e8a3325ea600a715a31", size = 272814, upload-time = "2025-08-07T13:15:50.011Z" },
{ url = "https://files.pythonhosted.org/packages/62/dd/b9f59862e9e257a16e4e610480cfffd29e3fae018a68c2332090b53aac3d/greenlet-3.2.4-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:cd3c8e693bff0fff6ba55f140bf390fa92c994083f838fece0f63be121334945", size = 641073, upload-time = "2025-08-07T13:42:57.23Z" },
@@ -2587,6 +2593,8 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/ee/43/3cecdc0349359e1a527cbf2e3e28e5f8f06d3343aaf82ca13437a9aa290f/greenlet-3.2.4-cp313-cp313-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:23768528f2911bcd7e475210822ffb5254ed10d71f4028387e5a99b4c6699671", size = 610497, upload-time = "2025-08-07T13:18:31.636Z" },
{ url = "https://files.pythonhosted.org/packages/b8/19/06b6cf5d604e2c382a6f31cafafd6f33d5dea706f4db7bdab184bad2b21d/greenlet-3.2.4-cp313-cp313-musllinux_1_1_aarch64.whl", hash = "sha256:00fadb3fedccc447f517ee0d3fd8fe49eae949e1cd0f6a611818f4f6fb7dc83b", size = 1121662, upload-time = "2025-08-07T13:42:41.117Z" },
{ url = "https://files.pythonhosted.org/packages/a2/15/0d5e4e1a66fab130d98168fe984c509249c833c1a3c16806b90f253ce7b9/greenlet-3.2.4-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:d25c5091190f2dc0eaa3f950252122edbbadbb682aa7b1ef2f8af0f8c0afefae", size = 1149210, upload-time = "2025-08-07T13:18:24.072Z" },
+ { url = "https://files.pythonhosted.org/packages/1c/53/f9c440463b3057485b8594d7a638bed53ba531165ef0ca0e6c364b5cc807/greenlet-3.2.4-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:6e343822feb58ac4d0a1211bd9399de2b3a04963ddeec21530fc426cc121f19b", size = 1564759, upload-time = "2025-11-04T12:42:19.395Z" },
+ { url = "https://files.pythonhosted.org/packages/47/e4/3bb4240abdd0a8d23f4f88adec746a3099f0d86bfedb623f063b2e3b4df0/greenlet-3.2.4-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:ca7f6f1f2649b89ce02f6f229d7c19f680a6238af656f61e0115b24857917929", size = 1634288, upload-time = "2025-11-04T12:42:21.174Z" },
{ url = "https://files.pythonhosted.org/packages/0b/55/2321e43595e6801e105fcfdee02b34c0f996eb71e6ddffca6b10b7e1d771/greenlet-3.2.4-cp313-cp313-win_amd64.whl", hash = "sha256:554b03b6e73aaabec3745364d6239e9e012d64c68ccd0b8430c64ccc14939a8b", size = 299685, upload-time = "2025-08-07T13:24:38.824Z" },
]