Merge branch 'dev' into multi-tenant-neo4j

This commit is contained in:
Igor Ilic 2025-11-28 12:55:48 +01:00 committed by GitHub
commit 0c825b96ff
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
67 changed files with 2723 additions and 170 deletions


@ -21,6 +21,10 @@ LLM_PROVIDER="openai"
LLM_ENDPOINT=""
LLM_API_VERSION=""
LLM_MAX_TOKENS="16384"
# Instructor modes determine how structured data is requested from the LLM and extracted from its responses
# You can change the mode via this env variable
# Each provider has its own default, e.g. gpt-5 models use "json_schema_mode"
LLM_INSTRUCTOR_MODE=""
EMBEDDING_PROVIDER="openai"
EMBEDDING_MODEL="openai/text-embedding-3-large"
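For orientation, a minimal sketch (not project code) of how such a mode string maps onto the `instructor` library's `Mode` enum, assuming instructor's standard enum values:

```python
import instructor

# Hypothetical helper: an empty env value falls back to the adapter's default mode.
def resolve_mode(env_value: str, adapter_default: str) -> instructor.Mode:
    return instructor.Mode((env_value or adapter_default).lower())

assert resolve_mode("", "json_schema_mode") is instructor.Mode.JSON_SCHEMA
assert resolve_mode("anthropic_tools", "json_mode") is instructor.Mode.ANTHROPIC_TOOLS
```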


@ -194,7 +194,6 @@ async def cognify(
Prerequisites:
- **LLM_API_KEY**: Must be configured (required for entity extraction and graph generation)
- **Data Added**: Must have data previously added via `cognee.add()`
- **Vector Database**: Must be accessible for embeddings storage
- **Graph Database**: Must be accessible for relationship storage


@ -23,6 +23,7 @@ from cognee.api.v1.settings.routers import get_settings_router
from cognee.api.v1.datasets.routers import get_datasets_router
from cognee.api.v1.cognify.routers import get_code_pipeline_router, get_cognify_router
from cognee.api.v1.search.routers import get_search_router
from cognee.api.v1.ontologies.routers.get_ontology_router import get_ontology_router
from cognee.api.v1.memify.routers import get_memify_router
from cognee.api.v1.add.routers import get_add_router
from cognee.api.v1.delete.routers import get_delete_router
@ -263,6 +264,8 @@ app.include_router(
app.include_router(get_datasets_router(), prefix="/api/v1/datasets", tags=["datasets"])
app.include_router(get_ontology_router(), prefix="/api/v1/ontologies", tags=["ontologies"])
app.include_router(get_settings_router(), prefix="/api/v1/settings", tags=["settings"])
app.include_router(get_visualize_router(), prefix="/api/v1/visualize", tags=["visualize"])


@ -41,6 +41,9 @@ class CognifyPayloadDTO(InDTO):
custom_prompt: Optional[str] = Field(
default="", description="Custom prompt for entity extraction and graph generation"
)
ontology_key: Optional[List[str]] = Field(
default=None, description="Reference to one or more previously uploaded ontologies"
)
def get_cognify_router() -> APIRouter:
@ -68,6 +71,7 @@ def get_cognify_router() -> APIRouter:
- **dataset_ids** (Optional[List[UUID]]): List of existing dataset UUIDs to process. UUIDs allow processing of datasets not owned by the user (if permitted).
- **run_in_background** (Optional[bool]): Whether to execute processing asynchronously. Defaults to False (blocking).
- **custom_prompt** (Optional[str]): Custom prompt for entity extraction and graph generation. If provided, this prompt will be used instead of the default prompts for knowledge graph extraction.
- **ontology_key** (Optional[List[str]]): Reference to one or more previously uploaded ontology files to use for knowledge graph construction.
## Response
- **Blocking execution**: Complete pipeline run information with entity counts, processing duration, and success/failure status
@ -82,7 +86,8 @@ def get_cognify_router() -> APIRouter:
{
"datasets": ["research_papers", "documentation"],
"run_in_background": false,
"custom_prompt": "Extract entities focusing on technical concepts and their relationships. Identify key technologies, methodologies, and their interconnections."
"custom_prompt": "Extract entities focusing on technical concepts and their relationships. Identify key technologies, methodologies, and their interconnections.",
"ontology_key": ["medical_ontology_v1"]
}
```
@ -108,13 +113,35 @@ def get_cognify_router() -> APIRouter:
)
from cognee.api.v1.cognify import cognify as cognee_cognify
from cognee.api.v1.ontologies.ontologies import OntologyService
try:
datasets = payload.dataset_ids if payload.dataset_ids else payload.datasets
config_to_use = None
if payload.ontology_key:
ontology_service = OntologyService()
ontology_contents = ontology_service.get_ontology_contents(
payload.ontology_key, user
)
from cognee.modules.ontology.ontology_config import Config
from cognee.modules.ontology.rdf_xml.RDFLibOntologyResolver import (
RDFLibOntologyResolver,
)
from io import StringIO
ontology_streams = [StringIO(content) for content in ontology_contents]
config_to_use: Config = {
"ontology_config": {
"ontology_resolver": RDFLibOntologyResolver(ontology_file=ontology_streams)
}
}
cognify_run = await cognee_cognify(
datasets,
user,
config=config_to_use,
run_in_background=payload.run_in_background,
custom_prompt=payload.custom_prompt,
)
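For reference, a hypothetical client call exercising the new parameter; the base URL, route prefix, and bearer token are illustrative placeholders, not confirmed values:

```python
import requests

# Assumes the cognify router is mounted at /api/v1/cognify with token auth.
response = requests.post(
    "http://localhost:8000/api/v1/cognify",
    json={
        "datasets": ["research_papers"],
        "run_in_background": False,
        "ontology_key": ["medical_ontology_v1"],
    },
    headers={"Authorization": "Bearer <token>"},
)
response.raise_for_status()
print(response.json())
```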


@ -0,0 +1,4 @@
from .ontologies import OntologyService
from .routers.get_ontology_router import get_ontology_router
__all__ = ["OntologyService", "get_ontology_router"]


@ -0,0 +1,183 @@
import os
import json
import tempfile
from pathlib import Path
from datetime import datetime, timezone
from typing import Optional, List
from dataclasses import dataclass
@dataclass
class OntologyMetadata:
ontology_key: str
filename: str
size_bytes: int
uploaded_at: str
description: Optional[str] = None
class OntologyService:
def __init__(self):
pass
@property
def base_dir(self) -> Path:
return Path(tempfile.gettempdir()) / "ontologies"
def _get_user_dir(self, user_id: str) -> Path:
user_dir = self.base_dir / str(user_id)
user_dir.mkdir(parents=True, exist_ok=True)
return user_dir
def _get_metadata_path(self, user_dir: Path) -> Path:
return user_dir / "metadata.json"
def _load_metadata(self, user_dir: Path) -> dict:
metadata_path = self._get_metadata_path(user_dir)
if metadata_path.exists():
with open(metadata_path, "r") as f:
return json.load(f)
return {}
def _save_metadata(self, user_dir: Path, metadata: dict):
metadata_path = self._get_metadata_path(user_dir)
with open(metadata_path, "w") as f:
json.dump(metadata, f, indent=2)
async def upload_ontology(
self, ontology_key: str, file, user, description: Optional[str] = None
) -> OntologyMetadata:
if not file.filename.lower().endswith(".owl"):
raise ValueError("File must be in .owl format")
user_dir = self._get_user_dir(str(user.id))
metadata = self._load_metadata(user_dir)
if ontology_key in metadata:
raise ValueError(f"Ontology key '{ontology_key}' already exists")
content = await file.read()
if len(content) > 10 * 1024 * 1024:
raise ValueError("File size exceeds 10MB limit")
file_path = user_dir / f"{ontology_key}.owl"
with open(file_path, "wb") as f:
f.write(content)
ontology_metadata = {
"filename": file.filename,
"size_bytes": len(content),
"uploaded_at": datetime.now(timezone.utc).isoformat(),
"description": description,
}
metadata[ontology_key] = ontology_metadata
self._save_metadata(user_dir, metadata)
return OntologyMetadata(
ontology_key=ontology_key,
filename=file.filename,
size_bytes=len(content),
uploaded_at=ontology_metadata["uploaded_at"],
description=description,
)
async def upload_ontologies(
self, ontology_key: List[str], files: List, user, descriptions: Optional[List[str]] = None
) -> List[OntologyMetadata]:
"""
Upload ontology files with their respective keys.
Args:
ontology_key: List of unique keys for each ontology
files: List of UploadFile objects (same length as keys)
user: Authenticated user
descriptions: Optional list of descriptions for each file
Returns:
List of OntologyMetadata objects for uploaded files
Raises:
ValueError: If keys are duplicated, a file has an invalid format, or array lengths don't match
"""
if len(ontology_key) != len(files):
raise ValueError("Number of keys must match number of files")
if len(set(ontology_key)) != len(ontology_key):
raise ValueError("Duplicate ontology keys not allowed")
if descriptions and len(descriptions) != len(files):
raise ValueError("Number of descriptions must match number of files")
results = []
user_dir = self._get_user_dir(str(user.id))
metadata = self._load_metadata(user_dir)
for i, (key, file) in enumerate(zip(ontology_key, files)):
if key in metadata:
raise ValueError(f"Ontology key '{key}' already exists")
if not file.filename.lower().endswith(".owl"):
raise ValueError(f"File '{file.filename}' must be in .owl format")
content = await file.read()
if len(content) > 10 * 1024 * 1024:
raise ValueError(f"File '{file.filename}' exceeds 10MB limit")
file_path = user_dir / f"{key}.owl"
with open(file_path, "wb") as f:
f.write(content)
ontology_metadata = {
"filename": file.filename,
"size_bytes": len(content),
"uploaded_at": datetime.now(timezone.utc).isoformat(),
"description": descriptions[i] if descriptions else None,
}
metadata[key] = ontology_metadata
results.append(
OntologyMetadata(
ontology_key=key,
filename=file.filename,
size_bytes=len(content),
uploaded_at=ontology_metadata["uploaded_at"],
description=descriptions[i] if descriptions else None,
)
)
self._save_metadata(user_dir, metadata)
return results
def get_ontology_contents(self, ontology_key: List[str], user) -> List[str]:
"""
Retrieve ontology content for one or more keys.
Args:
ontology_key: List of ontology keys to retrieve (may contain a single item)
user: Authenticated user
Returns:
List of ontology content strings
Raises:
ValueError: If any ontology key not found
"""
user_dir = self._get_user_dir(str(user.id))
metadata = self._load_metadata(user_dir)
contents = []
for key in ontology_key:
if key not in metadata:
raise ValueError(f"Ontology key '{key}' not found")
file_path = user_dir / f"{key}.owl"
if not file_path.exists():
raise ValueError(f"Ontology file for key '{key}' not found")
with open(file_path, "r", encoding="utf-8") as f:
contents.append(f.read())
return contents
def list_ontologies(self, user) -> dict:
user_dir = self._get_user_dir(str(user.id))
return self._load_metadata(user_dir)
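A usage sketch of the service outside the router; `demo`, `user`, and `upload` are hypothetical names, standing in for an authenticated user object with an `id` and a FastAPI `UploadFile`:

```python
from cognee.api.v1.ontologies import OntologyService

async def demo(user, upload):
    service = OntologyService()
    # Store one ontology under a caller-chosen key.
    meta = await service.upload_ontology("medical_ontology_v1", upload, user)
    print(meta.filename, meta.size_bytes)
    # Later, fetch the raw OWL content back by key.
    [owl_xml] = service.get_ontology_contents(["medical_ontology_v1"], user)
```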


@ -0,0 +1,107 @@
from fastapi import APIRouter, File, Form, UploadFile, Depends, HTTPException
from fastapi.responses import JSONResponse
from typing import Optional, List
from cognee.modules.users.models import User
from cognee.modules.users.methods import get_authenticated_user
from cognee.shared.utils import send_telemetry
from cognee import __version__ as cognee_version
from ..ontologies import OntologyService
def get_ontology_router() -> APIRouter:
router = APIRouter()
ontology_service = OntologyService()
@router.post("", response_model=dict)
async def upload_ontology(
ontology_key: str = Form(...),
ontology_file: List[UploadFile] = File(...),
descriptions: Optional[str] = Form(None),
user: User = Depends(get_authenticated_user),
):
"""
Upload ontology files with their respective keys for later use in cognify operations.
Supports both single and multiple file uploads:
- Single file: ontology_key=["key"], ontology_file=[file]
- Multiple files: ontology_key=["key1", "key2"], ontology_file=[file1, file2]
## Request Parameters
- **ontology_key** (str): JSON array string of user-defined identifiers for the ontologies
- **ontology_file** (List[UploadFile]): OWL format ontology files
- **descriptions** (Optional[str]): JSON array string of optional descriptions
## Response
Returns metadata about uploaded ontologies including keys, filenames, sizes, and upload timestamps.
## Error Codes
- **400 Bad Request**: Invalid file format, duplicate keys, array length mismatches, file size exceeded
- **500 Internal Server Error**: File system or processing errors
"""
send_telemetry(
"Ontology Upload API Endpoint Invoked",
user.id,
additional_properties={
"endpoint": "POST /api/v1/ontologies",
"cognee_version": cognee_version,
},
)
try:
import json
ontology_keys = json.loads(ontology_key)
description_list = json.loads(descriptions) if descriptions else None
if not isinstance(ontology_keys, list):
raise ValueError("ontology_key must be a JSON array")
results = await ontology_service.upload_ontologies(
ontology_keys, ontology_file, user, description_list
)
return {
"uploaded_ontologies": [
{
"ontology_key": result.ontology_key,
"filename": result.filename,
"size_bytes": result.size_bytes,
"uploaded_at": result.uploaded_at,
"description": result.description,
}
for result in results
]
}
except (json.JSONDecodeError, ValueError) as e:
return JSONResponse(status_code=400, content={"error": str(e)})
except Exception as e:
return JSONResponse(status_code=500, content={"error": str(e)})
@router.get("", response_model=dict)
async def list_ontologies(user: User = Depends(get_authenticated_user)):
"""
List all uploaded ontologies for the authenticated user.
## Response
Returns a dictionary mapping ontology keys to their metadata including filename, size, and upload timestamp.
## Error Codes
- **500 Internal Server Error**: File system or processing errors
"""
send_telemetry(
"Ontology List API Endpoint Invoked",
user.id,
additional_properties={
"endpoint": "GET /api/v1/ontologies",
"cognee_version": cognee_version,
},
)
try:
metadata = ontology_service.list_ontologies(user)
return metadata
except Exception as e:
return JSONResponse(status_code=500, content={"error": str(e)})
return router
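A hypothetical multipart request against this router; note that `ontology_key` and `descriptions` travel as JSON array *strings*, matching the `json.loads` calls above (URL and file name are placeholders):

```python
import json
import requests

with open("medical.owl", "rb") as f:
    resp = requests.post(
        "http://localhost:8000/api/v1/ontologies",
        data={
            "ontology_key": json.dumps(["medical_ontology_v1"]),
            "descriptions": json.dumps(["SNOMED subset"]),
        },
        files=[("ontology_file", ("medical.owl", f, "application/xml"))],
    )
print(resp.json())
```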


@ -31,6 +31,8 @@ async def search(
only_context: bool = False,
use_combined_context: bool = False,
session_id: Optional[str] = None,
wide_search_top_k: Optional[int] = 100,
triplet_distance_penalty: Optional[float] = 3.5,
) -> Union[List[SearchResult], CombinedSearchResult]:
"""
Search and query the knowledge graph for insights, information, and connections.
@ -200,6 +202,8 @@ async def search(
only_context=only_context,
use_combined_context=use_combined_context,
session_id=session_id,
wide_search_top_k=wide_search_top_k,
triplet_distance_penalty=triplet_distance_penalty,
)
return filtered_search_results
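A sketch of the new knobs from the caller's side; the `SearchType` import is assumed to be re-exported at package level as in recent cognee versions:

```python
import asyncio
import cognee
from cognee import SearchType  # assumed re-export; adjust to your version

async def main():
    results = await cognee.search(
        query_type=SearchType.GRAPH_COMPLETION,
        query_text="How are the services connected?",
        wide_search_top_k=50,          # cap on candidates pulled per vector collection
        triplet_distance_penalty=2.0,  # distance assigned to unscored nodes and edges
    )
    print(results)

asyncio.run(main())
```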


@ -97,6 +97,13 @@ After successful cognify processing, use `cognee search` to query the knowledge
chunker_class = LangchainChunker
except ImportError:
fmt.warning("LangchainChunker not available, using TextChunker")
elif args.chunker == "CsvChunker":
try:
from cognee.modules.chunking.CsvChunker import CsvChunker
chunker_class = CsvChunker
except ImportError:
fmt.warning("CsvChunker not available, using TextChunker")
result = await cognee.cognify(
datasets=datasets,


@ -26,7 +26,7 @@ SEARCH_TYPE_CHOICES = [
]
# Chunker choices
CHUNKER_CHOICES = ["TextChunker", "LangchainChunker"]
CHUNKER_CHOICES = ["TextChunker", "LangchainChunker", "CsvChunker"]
# Output format choices
OUTPUT_FORMAT_CHOICES = ["json", "pretty", "simple"]


@ -0,0 +1,29 @@
FROM python:3.11-slim
# Set environment variables
ENV PIP_NO_CACHE_DIR=true
ENV PATH="${PATH}:/root/.poetry/bin"
ENV PYTHONPATH=/app
ENV SKIP_MIGRATIONS=true
# System dependencies
RUN apt-get update && apt-get install -y \
gcc \
libpq-dev \
git \
curl \
build-essential \
&& rm -rf /var/lib/apt/lists/*
WORKDIR /app
COPY pyproject.toml poetry.lock README.md /app/
RUN pip install poetry
RUN poetry config virtualenvs.create false
RUN poetry install --extras distributed --extras evals --extras deepeval --no-root
COPY cognee/ /app/cognee
COPY distributed/ /app/distributed


@ -35,6 +35,16 @@ class AnswerGeneratorExecutor:
retrieval_context = await retriever.get_context(query_text)
search_results = await retriever.get_completion(query_text, retrieval_context)
############
# TODO: This is a quick fix until retriever results are structured properly; do not leave it like this.
# It is needed for now because of the changed combined retriever structure.
if isinstance(retrieval_context, list):
retrieval_context = await retriever.convert_retrieved_objects_to_context(
triplets=retrieval_context
)
if isinstance(search_results, str):
search_results = [search_results]
#############
answer = {
"question": query_text,
"answer": search_results[0],


@ -35,7 +35,7 @@ async def create_and_insert_answers_table(questions_payload):
async def run_question_answering(
params: dict, system_prompt="answer_simple_question.txt", top_k: Optional[int] = None
params: dict, system_prompt="answer_simple_question_benchmark.txt", top_k: Optional[int] = None
) -> List[dict]:
if params.get("answering_questions"):
logger.info("Question answering started...")


@ -14,7 +14,7 @@ class EvalConfig(BaseSettings):
# Question answering params
answering_questions: bool = True
qa_engine: str = "cognee_completion" # Options: 'cognee_completion' or 'cognee_graph_completion' or 'cognee_graph_completion_cot' or 'cognee_graph_completion_context_extension'
qa_engine: str = "cognee_graph_completion" # Options: 'cognee_completion' or 'cognee_graph_completion' or 'cognee_graph_completion_cot' or 'cognee_graph_completion_context_extension'
# Evaluation params
evaluating_answers: bool = True
@ -25,7 +25,7 @@ class EvalConfig(BaseSettings):
"EM",
"f1",
] # Use only 'correctness' for DirectLLM
deepeval_model: str = "gpt-5-mini"
deepeval_model: str = "gpt-4o-mini"
# Metrics params
calculate_metrics: bool = True


@ -2,7 +2,6 @@ import modal
import os
import asyncio
import datetime
import hashlib
import json
from cognee.shared.logging_utils import get_logger
from cognee.eval_framework.eval_config import EvalConfig
@ -10,6 +9,9 @@ from cognee.eval_framework.corpus_builder.run_corpus_builder import run_corpus_b
from cognee.eval_framework.answer_generation.run_question_answering_module import (
run_question_answering,
)
import pathlib
from os import path
from modal import Image
from cognee.eval_framework.evaluation.run_evaluation_module import run_evaluation
from cognee.eval_framework.metrics_dashboard import create_dashboard
@ -38,22 +40,19 @@ def read_and_combine_metrics(eval_params: dict) -> dict:
app = modal.App("modal-run-eval")
image = (
modal.Image.from_dockerfile(path="Dockerfile_modal", force_build=False)
.copy_local_file("pyproject.toml", "pyproject.toml")
.copy_local_file("poetry.lock", "poetry.lock")
.env(
{
"ENV": os.getenv("ENV"),
"LLM_API_KEY": os.getenv("LLM_API_KEY"),
"OPENAI_API_KEY": os.getenv("OPENAI_API_KEY"),
}
)
.pip_install("protobuf", "h2", "deepeval", "gdown", "plotly")
image = Image.from_dockerfile(
path=pathlib.Path(path.join(path.dirname(__file__), "Dockerfile")).resolve(),
force_build=False,
).add_local_python_source("cognee")
@app.function(
image=image,
max_containers=10,
timeout=86400,
volumes={"/data": vol},
secrets=[modal.Secret.from_name("eval_secrets")],
)
@app.function(image=image, concurrency_limit=10, timeout=86400, volumes={"/data": vol})
async def modal_run_eval(eval_params=None):
"""Runs evaluation pipeline and returns combined metrics results."""
if eval_params is None:
@ -105,18 +104,7 @@ async def main():
configs = [
EvalConfig(
task_getter_type="Default",
number_of_samples_in_corpus=10,
benchmark="HotPotQA",
qa_engine="cognee_graph_completion",
building_corpus_from_scratch=True,
answering_questions=True,
evaluating_answers=True,
calculate_metrics=True,
dashboard=True,
),
EvalConfig(
task_getter_type="Default",
number_of_samples_in_corpus=10,
number_of_samples_in_corpus=25,
benchmark="TwoWikiMultiHop",
qa_engine="cognee_graph_completion",
building_corpus_from_scratch=True,
@ -127,7 +115,7 @@ async def main():
),
EvalConfig(
task_getter_type="Default",
number_of_samples_in_corpus=10,
number_of_samples_in_corpus=25,
benchmark="Musique",
qa_engine="cognee_graph_completion",
building_corpus_from_scratch=True,


@ -398,3 +398,18 @@ class GraphDBInterface(ABC):
- node_id (Union[str, UUID]): Unique identifier of the node for which to retrieve connections.
"""
raise NotImplementedError
@abstractmethod
async def get_filtered_graph_data(
self, attribute_filters: List[Dict[str, List[Union[str, int]]]]
) -> Tuple[List[Node], List[EdgeData]]:
"""
Retrieve nodes and edges filtered by the provided attribute criteria.
Parameters:
-----------
- attribute_filters: A list of dictionaries where keys are attribute names and values
are lists of attribute values to filter by.
"""
raise NotImplementedError
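Illustrative only: a filter value of the shape this abstract method expects, keeping nodes whose attribute matches one of the listed values (exact matching semantics are adapter-defined):

```python
attribute_filters = [
    {"type": ["Entity", "DocumentChunk"]},  # keep these node types
    {"community": [1, 2]},                  # and these community ids
]
```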


@ -12,6 +12,7 @@ from contextlib import asynccontextmanager
from concurrent.futures import ThreadPoolExecutor
from typing import Dict, Any, List, Union, Optional, Tuple, Type
from cognee.exceptions import CogneeValidationError
from cognee.shared.logging_utils import get_logger
from cognee.infrastructure.utils.run_sync import run_sync
from cognee.infrastructure.files.storage import get_file_storage
@ -1186,6 +1187,11 @@ class KuzuAdapter(GraphDBInterface):
A tuple with two elements: a list of tuples of (node_id, properties) and a list of
tuples of (source_id, target_id, relationship_name, properties).
"""
import time
start_time = time.time()
try:
nodes_query = """
MATCH (n:Node)
@ -1249,6 +1255,11 @@ class KuzuAdapter(GraphDBInterface):
},
)
)
retrieval_time = time.time() - start_time
logger.info(
f"Retrieved {len(nodes)} nodes and {len(edges)} edges in {retrieval_time:.2f} seconds"
)
return formatted_nodes, formatted_edges
except Exception as e:
logger.error(f"Failed to get graph data: {e}")
@ -1417,6 +1428,92 @@ class KuzuAdapter(GraphDBInterface):
formatted_edges.append((source_id, target_id, rel_type, props))
return formatted_nodes, formatted_edges
async def get_id_filtered_graph_data(self, target_ids: list[str]):
"""
Retrieve graph data filtered by specific node IDs, including their direct neighbors
and only edges where one endpoint matches those IDs.
Returns:
nodes: List[Tuple[str, dict]] -> (node_id, properties) tuples, one per node
edges: List[Tuple[str, str, str, dict]] -> (source, target, relationship, properties) tuples
"""
import time
start_time = time.time()
try:
if not target_ids:
logger.warning("No target IDs provided for ID-filtered graph retrieval.")
return [], []
if not all(isinstance(x, str) for x in target_ids):
raise CogneeValidationError("target_ids must be a list of strings")
query = """
MATCH (n:Node)-[r]->(m:Node)
WHERE n.id IN $target_ids OR m.id IN $target_ids
RETURN n.id, {
name: n.name,
type: n.type,
properties: n.properties
}, m.id, {
name: m.name,
type: m.type,
properties: m.properties
}, r.relationship_name, r.properties
"""
result = await self.query(query, {"target_ids": target_ids})
if not result:
logger.info("No data returned for the supplied IDs")
return [], []
nodes_dict = {}
edges = []
for n_id, n_props, m_id, m_props, r_type, r_props_raw in result:
if n_props.get("properties"):
try:
additional_props = json.loads(n_props["properties"])
n_props.update(additional_props)
del n_props["properties"]
except json.JSONDecodeError:
logger.warning(f"Failed to parse properties JSON for node {n_id}")
if m_props.get("properties"):
try:
additional_props = json.loads(m_props["properties"])
m_props.update(additional_props)
del m_props["properties"]
except json.JSONDecodeError:
logger.warning(f"Failed to parse properties JSON for node {m_id}")
nodes_dict[n_id] = (n_id, n_props)
nodes_dict[m_id] = (m_id, m_props)
edge_props = {}
if r_props_raw:
try:
edge_props = json.loads(r_props_raw)
except (json.JSONDecodeError, TypeError):
logger.warning(f"Failed to parse edge properties for {n_id}->{m_id}")
source_id = edge_props.get("source_node_id", n_id)
target_id = edge_props.get("target_node_id", m_id)
edges.append((source_id, target_id, r_type, edge_props))
retrieval_time = time.time() - start_time
logger.info(
f"ID-filtered retrieval: {len(nodes_dict)} nodes and {len(edges)} edges in {retrieval_time:.2f}s"
)
return list(nodes_dict.values()), edges
except Exception as e:
logger.error(f"Error during ID-filtered graph data retrieval: {str(e)}")
raise
async def get_graph_metrics(self, include_optional=False) -> Dict[str, Any]:
"""
Get metrics on graph structure and connectivity.


@ -964,6 +964,63 @@ class Neo4jAdapter(GraphDBInterface):
logger.error(f"Error during graph data retrieval: {str(e)}")
raise
async def get_id_filtered_graph_data(self, target_ids: list[str]):
"""
Retrieve graph data filtered by specific node IDs, including their direct neighbors
and only edges where one endpoint matches those IDs.
This version uses a single Cypher query for efficiency.
"""
import time
start_time = time.time()
try:
if not target_ids:
logger.warning("No target IDs provided for ID-filtered graph retrieval.")
return [], []
query = """
MATCH ()-[r]-()
WHERE startNode(r).id IN $target_ids
OR endNode(r).id IN $target_ids
WITH DISTINCT r, startNode(r) AS a, endNode(r) AS b
RETURN
properties(a) AS n_properties,
properties(b) AS m_properties,
type(r) AS type,
properties(r) AS properties
"""
result = await self.query(query, {"target_ids": target_ids})
nodes_dict = {}
edges = []
for record in result:
n_props = record["n_properties"]
m_props = record["m_properties"]
r_props = record["properties"]
r_type = record["type"]
nodes_dict[n_props["id"]] = (n_props["id"], n_props)
nodes_dict[m_props["id"]] = (m_props["id"], m_props)
source_id = r_props.get("source_node_id", n_props["id"])
target_id = r_props.get("target_node_id", m_props["id"])
edges.append((source_id, target_id, r_type, r_props))
retrieval_time = time.time() - start_time
logger.info(
f"ID-filtered retrieval: {len(nodes_dict)} nodes and {len(edges)} edges in {retrieval_time:.2f}s"
)
return list(nodes_dict.values()), edges
except Exception as e:
logger.error(f"Error during ID-filtered graph data retrieval: {str(e)}")
raise
async def get_nodeset_subgraph(
self, node_type: Type[Any], node_name: List[str]
) -> Tuple[List[Tuple[int, dict]], List[Tuple[int, int, str, dict]]]:


@ -55,6 +55,10 @@ def guess_file_type(file: BinaryIO, name: Optional[str] = None) -> filetype.Type
file_type = Type("text/plain", "txt")
return file_type
if ext in [".csv"]:
file_type = Type("text/csv", "csv")
return file_type
file_type = filetype.guess(file)
# If the file type could not be determined, treat it as plain text, since plain text files have no magic-number signature
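A quick sanity sketch of the new branch; the module path for `guess_file_type` is assumed from the repository layout:

```python
import io
from cognee.infrastructure.files.utils.guess_file_type import guess_file_type  # assumed path

with io.BytesIO(b"a,b\n1,2\n") as f:
    ft = guess_file_type(f, name="table.csv")
    assert (ft.mime, ft.extension) == ("text/csv", "csv")
```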


@ -38,6 +38,7 @@ class LLMConfig(BaseSettings):
"""
structured_output_framework: str = "instructor"
llm_instructor_mode: str = ""
llm_provider: str = "openai"
llm_model: str = "openai/gpt-5-mini"
llm_endpoint: str = ""
@ -181,6 +182,7 @@ class LLMConfig(BaseSettings):
instance.
"""
return {
"llm_instructor_mode": self.llm_instructor_mode.lower(),
"provider": self.llm_provider,
"model": self.llm_model,
"endpoint": self.llm_endpoint,


@ -28,13 +28,16 @@ class AnthropicAdapter(LLMInterface):
name = "Anthropic"
model: str
default_instructor_mode = "anthropic_tools"
def __init__(self, max_completion_tokens: int, model: str = None):
def __init__(self, max_completion_tokens: int, model: str = None, instructor_mode: str = None):
import anthropic
self.instructor_mode = instructor_mode if instructor_mode else self.default_instructor_mode
self.aclient = instructor.patch(
create=anthropic.AsyncAnthropic(api_key=get_llm_config().llm_api_key).messages.create,
mode=instructor.Mode.ANTHROPIC_TOOLS,
mode=instructor.Mode(self.instructor_mode),
)
self.model = model


@ -41,6 +41,7 @@ class GeminiAdapter(LLMInterface):
name: str
model: str
api_key: str
default_instructor_mode = "json_mode"
def __init__(
self,
@ -49,6 +50,7 @@ class GeminiAdapter(LLMInterface):
model: str,
api_version: str,
max_completion_tokens: int,
instructor_mode: str = None,
fallback_model: str = None,
fallback_api_key: str = None,
fallback_endpoint: str = None,
@ -63,7 +65,11 @@ class GeminiAdapter(LLMInterface):
self.fallback_api_key = fallback_api_key
self.fallback_endpoint = fallback_endpoint
self.aclient = instructor.from_litellm(litellm.acompletion, mode=instructor.Mode.JSON)
self.instructor_mode = instructor_mode if instructor_mode else self.default_instructor_mode
self.aclient = instructor.from_litellm(
litellm.acompletion, mode=instructor.Mode(self.instructor_mode)
)
@retry(
stop=stop_after_delay(128),


@ -41,6 +41,7 @@ class GenericAPIAdapter(LLMInterface):
name: str
model: str
api_key: str
default_instructor_mode = "json_mode"
def __init__(
self,
@ -49,6 +50,7 @@ class GenericAPIAdapter(LLMInterface):
model: str,
name: str,
max_completion_tokens: int,
instructor_mode: str = None,
fallback_model: str = None,
fallback_api_key: str = None,
fallback_endpoint: str = None,
@ -63,7 +65,11 @@ class GenericAPIAdapter(LLMInterface):
self.fallback_api_key = fallback_api_key
self.fallback_endpoint = fallback_endpoint
self.aclient = instructor.from_litellm(litellm.acompletion, mode=instructor.Mode.JSON)
self.instructor_mode = instructor_mode if instructor_mode else self.default_instructor_mode
self.aclient = instructor.from_litellm(
litellm.acompletion, mode=instructor.Mode(self.instructor_mode)
)
@retry(
stop=stop_after_delay(128),


@ -81,6 +81,7 @@ def get_llm_client(raise_api_key_error: bool = True):
model=llm_config.llm_model,
transcription_model=llm_config.transcription_model,
max_completion_tokens=max_completion_tokens,
instructor_mode=llm_config.llm_instructor_mode.lower(),
streaming=llm_config.llm_streaming,
fallback_api_key=llm_config.fallback_api_key,
fallback_endpoint=llm_config.fallback_endpoint,
@ -101,6 +102,7 @@ def get_llm_client(raise_api_key_error: bool = True):
llm_config.llm_model,
"Ollama",
max_completion_tokens=max_completion_tokens,
instructor_mode=llm_config.llm_instructor_mode.lower(),
)
elif provider == LLMProvider.ANTHROPIC:
@ -109,7 +111,9 @@ def get_llm_client(raise_api_key_error: bool = True):
)
return AnthropicAdapter(
max_completion_tokens=max_completion_tokens, model=llm_config.llm_model
max_completion_tokens=max_completion_tokens,
model=llm_config.llm_model,
instructor_mode=llm_config.llm_instructor_mode.lower(),
)
elif provider == LLMProvider.CUSTOM:
@ -126,6 +130,7 @@ def get_llm_client(raise_api_key_error: bool = True):
llm_config.llm_model,
"Custom",
max_completion_tokens=max_completion_tokens,
instructor_mode=llm_config.llm_instructor_mode.lower(),
fallback_api_key=llm_config.fallback_api_key,
fallback_endpoint=llm_config.fallback_endpoint,
fallback_model=llm_config.fallback_model,
@ -145,6 +150,7 @@ def get_llm_client(raise_api_key_error: bool = True):
max_completion_tokens=max_completion_tokens,
endpoint=llm_config.llm_endpoint,
api_version=llm_config.llm_api_version,
instructor_mode=llm_config.llm_instructor_mode.lower(),
)
elif provider == LLMProvider.MISTRAL:
@ -160,21 +166,7 @@ def get_llm_client(raise_api_key_error: bool = True):
model=llm_config.llm_model,
max_completion_tokens=max_completion_tokens,
endpoint=llm_config.llm_endpoint,
)
elif provider == LLMProvider.MISTRAL:
if llm_config.llm_api_key is None:
raise LLMAPIKeyNotSetError()
from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.mistral.adapter import (
MistralAdapter,
)
return MistralAdapter(
api_key=llm_config.llm_api_key,
model=llm_config.llm_model,
max_completion_tokens=max_completion_tokens,
endpoint=llm_config.llm_endpoint,
instructor_mode=llm_config.llm_instructor_mode.lower(),
)
else:


@ -37,16 +37,26 @@ class MistralAdapter(LLMInterface):
model: str
api_key: str
max_completion_tokens: int
default_instructor_mode = "mistral_tools"
def __init__(self, api_key: str, model: str, max_completion_tokens: int, endpoint: str = None):
def __init__(
self,
api_key: str,
model: str,
max_completion_tokens: int,
endpoint: str = None,
instructor_mode: str = None,
):
from mistralai import Mistral
self.model = model
self.max_completion_tokens = max_completion_tokens
self.instructor_mode = instructor_mode if instructor_mode else self.default_instructor_mode
self.aclient = instructor.from_litellm(
litellm.acompletion,
mode=instructor.Mode.MISTRAL_TOOLS,
mode=instructor.Mode(self.instructor_mode),
api_key=get_llm_config().llm_api_key,
)


@ -42,8 +42,16 @@ class OllamaAPIAdapter(LLMInterface):
- aclient
"""
default_instructor_mode = "json_mode"
def __init__(
self, endpoint: str, api_key: str, model: str, name: str, max_completion_tokens: int
self,
endpoint: str,
api_key: str,
model: str,
name: str,
max_completion_tokens: int,
instructor_mode: str = None,
):
self.name = name
self.model = model
@ -51,8 +59,11 @@ class OllamaAPIAdapter(LLMInterface):
self.endpoint = endpoint
self.max_completion_tokens = max_completion_tokens
self.instructor_mode = instructor_mode if instructor_mode else self.default_instructor_mode
self.aclient = instructor.from_openai(
OpenAI(base_url=self.endpoint, api_key=self.api_key), mode=instructor.Mode.JSON
OpenAI(base_url=self.endpoint, api_key=self.api_key),
mode=instructor.Mode(self.instructor_mode),
)
@retry(


@ -56,6 +56,7 @@ class OpenAIAdapter(LLMInterface):
model: str
api_key: str
api_version: str
default_instructor_mode = "json_schema_mode"
MAX_RETRIES = 5
@ -69,19 +70,21 @@ class OpenAIAdapter(LLMInterface):
model: str,
transcription_model: str,
max_completion_tokens: int,
instructor_mode: str = None,
streaming: bool = False,
fallback_model: str = None,
fallback_api_key: str = None,
fallback_endpoint: str = None,
):
self.instructor_mode = instructor_mode if instructor_mode else self.default_instructor_mode
# TODO: With gpt-5 series models OpenAI expects JSON_SCHEMA as the mode for structured outputs.
# Make sure all new GPT models work with this mode as well.
if "gpt-5" in model:
self.aclient = instructor.from_litellm(
litellm.acompletion, mode=instructor.Mode.JSON_SCHEMA
litellm.acompletion, mode=instructor.Mode(self.instructor_mode)
)
self.client = instructor.from_litellm(
litellm.completion, mode=instructor.Mode.JSON_SCHEMA
litellm.completion, mode=instructor.Mode(self.instructor_mode)
)
else:
self.aclient = instructor.from_litellm(litellm.acompletion)


@ -31,6 +31,7 @@ class LoaderEngine:
"pypdf_loader",
"image_loader",
"audio_loader",
"csv_loader",
"unstructured_loader",
"advanced_pdf_loader",
]


@ -3,5 +3,6 @@
from .text_loader import TextLoader
from .audio_loader import AudioLoader
from .image_loader import ImageLoader
from .csv_loader import CsvLoader
__all__ = ["TextLoader", "AudioLoader", "ImageLoader"]
__all__ = ["TextLoader", "AudioLoader", "ImageLoader", "CsvLoader"]


@ -0,0 +1,93 @@
import os
from typing import List
import csv
from cognee.infrastructure.loaders.LoaderInterface import LoaderInterface
from cognee.infrastructure.files.storage import get_file_storage, get_storage_config
from cognee.infrastructure.files.utils.get_file_metadata import get_file_metadata
class CsvLoader(LoaderInterface):
"""
Core CSV file loader that handles basic CSV file formats.
"""
@property
def supported_extensions(self) -> List[str]:
"""Supported text file extensions."""
return [
"csv",
]
@property
def supported_mime_types(self) -> List[str]:
"""Supported MIME types for text content."""
return [
"text/csv",
]
@property
def loader_name(self) -> str:
"""Unique identifier for this loader."""
return "csv_loader"
def can_handle(self, extension: str, mime_type: str) -> bool:
"""
Check if this loader can handle the given file.
Args:
extension: File extension
mime_type: Optional MIME type
Returns:
True if file can be handled, False otherwise
"""
if extension in self.supported_extensions and mime_type in self.supported_mime_types:
return True
return False
async def load(self, file_path: str, encoding: str = "utf-8", **kwargs):
"""
Load and process the CSV file.
Args:
file_path: Path to the file to load
encoding: Text encoding to use (default: utf-8)
**kwargs: Additional configuration (unused)
Returns:
Full path of the stored plain-text rendition of the CSV content
Raises:
FileNotFoundError: If file doesn't exist
UnicodeDecodeError: If file cannot be decoded with specified encoding
OSError: If file cannot be read
"""
if not os.path.exists(file_path):
raise FileNotFoundError(f"File not found: {file_path}")
with open(file_path, "rb") as f:
file_metadata = await get_file_metadata(f)
# Name the ingested file for this loader based on the original file's content hash
storage_file_name = "text_" + file_metadata["content_hash"] + ".txt"
row_texts = []
row_index = 1
with open(file_path, "r", encoding=encoding, newline="") as file:
reader = csv.DictReader(file)
for row in reader:
pairs = [f"{str(k)}: {str(v)}" for k, v in row.items()]
row_text = ", ".join(pairs)
row_texts.append(f"Row {row_index}:\n{row_text}\n")
row_index += 1
content = "\n".join(row_texts)
storage_config = get_storage_config()
data_root_directory = storage_config["data_root_directory"]
storage = get_file_storage(data_root_directory)
full_file_path = await storage.store(storage_file_name, content)
return full_file_path
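The loader's row flattening in isolation, for reference:

```python
import csv
import io

sample = "name,role\nAda,engineer\nGrace,admiral\n"
reader = csv.DictReader(io.StringIO(sample))
for index, row in enumerate(reader, start=1):
    # Each CSV row becomes one "Row N:" block of "key: value" pairs.
    print(f"Row {index}:\n" + ", ".join(f"{k}: {v}" for k, v in row.items()) + "\n")
# Row 1:
# name: Ada, role: engineer
# ...
```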


@ -16,7 +16,7 @@ class TextLoader(LoaderInterface):
@property
def supported_extensions(self) -> List[str]:
"""Supported text file extensions."""
return ["txt", "md", "csv", "json", "xml", "yaml", "yml", "log"]
return ["txt", "md", "json", "xml", "yaml", "yml", "log"]
@property
def supported_mime_types(self) -> List[str]:
@ -24,7 +24,6 @@ class TextLoader(LoaderInterface):
return [
"text/plain",
"text/markdown",
"text/csv",
"application/json",
"text/xml",
"application/xml",


@ -227,12 +227,3 @@ class AdvancedPdfLoader(LoaderInterface):
if value is None:
return ""
return str(value).replace("\xa0", " ").strip()
if __name__ == "__main__":
loader = AdvancedPdfLoader()
asyncio.run(
loader.load(
"/Users/xiaotao/work/cognee/cognee/infrastructure/loaders/external/attention_is_all_you_need.pdf"
)
)


@ -1,5 +1,5 @@
from cognee.infrastructure.loaders.external import PyPdfLoader
from cognee.infrastructure.loaders.core import TextLoader, AudioLoader, ImageLoader
from cognee.infrastructure.loaders.core import TextLoader, AudioLoader, ImageLoader, CsvLoader
# Registry for loader implementations
supported_loaders = {
@ -7,6 +7,7 @@ supported_loaders = {
TextLoader.loader_name: TextLoader,
ImageLoader.loader_name: ImageLoader,
AudioLoader.loader_name: AudioLoader,
CsvLoader.loader_name: CsvLoader,
}
# Try adding optional loaders


@ -0,0 +1,35 @@
from cognee.shared.logging_utils import get_logger
from cognee.tasks.chunks import chunk_by_row
from cognee.modules.chunking.Chunker import Chunker
from .models.DocumentChunk import DocumentChunk
logger = get_logger()
class CsvChunker(Chunker):
async def read(self):
async for content_text in self.get_text():
if content_text is None:
continue
for chunk_data in chunk_by_row(content_text, self.max_chunk_size):
if chunk_data["chunk_size"] <= self.max_chunk_size:
yield DocumentChunk(
id=chunk_data["chunk_id"],
text=chunk_data["text"],
chunk_size=chunk_data["chunk_size"],
is_part_of=self.document,
chunk_index=self.chunk_index,
cut_type=chunk_data["cut_type"],
contains=[],
metadata={
"index_fields": ["text"],
},
)
self.chunk_index += 1
else:
raise ValueError(
f"Chunk size is larger than the maximum chunk size {self.max_chunk_size}"
)


@ -0,0 +1,33 @@
import io
import csv
from typing import Type
from cognee.modules.chunking.Chunker import Chunker
from cognee.infrastructure.files.utils.open_data_file import open_data_file
from .Document import Document
class CsvDocument(Document):
type: str = "csv"
mime_type: str = "text/csv"
async def read(self, chunker_cls: Type[Chunker], max_chunk_size: int):
async def get_text():
async with open_data_file(
self.raw_data_location, mode="r", encoding="utf-8", newline=""
) as file:
content = file.read()
file_like_obj = io.StringIO(content)
reader = csv.DictReader(file_like_obj)
for row in reader:
pairs = [f"{str(k)}: {str(v)}" for k, v in row.items()]
row_text = ", ".join(pairs)
if not row_text.strip():
break
yield row_text
chunker = chunker_cls(self, max_chunk_size=max_chunk_size, get_text=get_text)
async for chunk in chunker.read():
yield chunk


@ -4,3 +4,4 @@ from .TextDocument import TextDocument
from .ImageDocument import ImageDocument
from .AudioDocument import AudioDocument
from .UnstructuredDocument import UnstructuredDocument
from .CsvDocument import CsvDocument


@ -56,6 +56,68 @@ class CogneeGraph(CogneeAbstractGraph):
def get_edges(self) -> List[Edge]:
return self.edges
async def _get_nodeset_subgraph(
self,
adapter,
node_type,
node_name,
):
"""Retrieve subgraph based on node type and name."""
logger.info("Retrieving graph filtered by node type and node name (NodeSet).")
nodes_data, edges_data = await adapter.get_nodeset_subgraph(
node_type=node_type, node_name=node_name
)
if not nodes_data or not edges_data:
raise EntityNotFoundError(
message="Nodeset does not exist, or empty nodeset projected from the database."
)
return nodes_data, edges_data
async def _get_full_or_id_filtered_graph(
self,
adapter,
relevant_ids_to_filter,
):
"""Retrieve full or ID-filtered graph with fallback."""
if relevant_ids_to_filter is None:
logger.info("Retrieving full graph.")
nodes_data, edges_data = await adapter.get_graph_data()
if not nodes_data or not edges_data:
raise EntityNotFoundError(message="Empty graph projected from the database.")
return nodes_data, edges_data
get_graph_data_fn = getattr(adapter, "get_id_filtered_graph_data", adapter.get_graph_data)
if getattr(adapter.__class__, "get_id_filtered_graph_data", None):
logger.info("Retrieving ID-filtered graph from database.")
nodes_data, edges_data = await get_graph_data_fn(target_ids=relevant_ids_to_filter)
else:
logger.info("Retrieving full graph from database.")
nodes_data, edges_data = await get_graph_data_fn()
if hasattr(adapter, "get_id_filtered_graph_data") and (not nodes_data or not edges_data):
logger.warning(
"Id filtered graph returned empty, falling back to full graph retrieval."
)
logger.info("Retrieving full graph")
nodes_data, edges_data = await adapter.get_graph_data()
if not nodes_data or not edges_data:
raise EntityNotFoundError("Empty graph projected from the database.")
return nodes_data, edges_data
async def _get_filtered_graph(
self,
adapter,
memory_fragment_filter,
):
"""Retrieve graph filtered by attributes."""
logger.info("Retrieving graph filtered by memory fragment")
nodes_data, edges_data = await adapter.get_filtered_graph_data(
attribute_filters=memory_fragment_filter
)
if not nodes_data or not edges_data:
raise EntityNotFoundError(message="Empty filtered graph projected from the database.")
return nodes_data, edges_data
async def project_graph_from_db(
self,
adapter: Union[GraphDBInterface],
@ -67,40 +129,39 @@ class CogneeGraph(CogneeAbstractGraph):
memory_fragment_filter=[],
node_type: Optional[Type] = None,
node_name: Optional[List[str]] = None,
relevant_ids_to_filter: Optional[List[str]] = None,
triplet_distance_penalty: float = 3.5,
) -> None:
if node_dimension < 1 or edge_dimension < 1:
raise InvalidDimensionsError()
try:
if node_type is not None and node_name not in [None, [], ""]:
nodes_data, edges_data = await self._get_nodeset_subgraph(
adapter, node_type, node_name
)
elif len(memory_fragment_filter) == 0:
nodes_data, edges_data = await self._get_full_or_id_filtered_graph(
adapter, relevant_ids_to_filter
)
else:
nodes_data, edges_data = await self._get_filtered_graph(
adapter, memory_fragment_filter
)
import time
start_time = time.time()
# Determine projection strategy
if node_type is not None and node_name not in [None, [], ""]:
nodes_data, edges_data = await adapter.get_nodeset_subgraph(
node_type=node_type, node_name=node_name
)
if not nodes_data or not edges_data:
raise EntityNotFoundError(
message="Nodeset does not exist, or empty nodetes projected from the database."
)
elif len(memory_fragment_filter) == 0:
nodes_data, edges_data = await adapter.get_graph_data()
if not nodes_data or not edges_data:
raise EntityNotFoundError(message="Empty graph projected from the database.")
else:
nodes_data, edges_data = await adapter.get_filtered_graph_data(
attribute_filters=memory_fragment_filter
)
if not nodes_data or not edges_data:
raise EntityNotFoundError(
message="Empty filtered graph projected from the database."
)
# Process nodes
for node_id, properties in nodes_data:
node_attributes = {key: properties.get(key) for key in node_properties_to_project}
self.add_node(Node(str(node_id), node_attributes, dimension=node_dimension))
self.add_node(
Node(
str(node_id),
node_attributes,
dimension=node_dimension,
node_penalty=triplet_distance_penalty,
)
)
# Process edges
for source_id, target_id, relationship_type, properties in edges_data:
@ -118,6 +179,7 @@ class CogneeGraph(CogneeAbstractGraph):
attributes=edge_attributes,
directed=directed,
dimension=edge_dimension,
edge_penalty=triplet_distance_penalty,
)
self.add_edge(edge)


@ -20,13 +20,17 @@ class Node:
status: np.ndarray
def __init__(
self, node_id: str, attributes: Optional[Dict[str, Any]] = None, dimension: int = 1
self,
node_id: str,
attributes: Optional[Dict[str, Any]] = None,
dimension: int = 1,
node_penalty: float = 3.5,
):
if dimension <= 0:
raise InvalidDimensionsError()
self.id = node_id
self.attributes = attributes if attributes is not None else {}
self.attributes["vector_distance"] = float("inf")
self.attributes["vector_distance"] = node_penalty
self.skeleton_neighbours = []
self.skeleton_edges = []
self.status = np.ones(dimension, dtype=int)
@ -105,13 +109,14 @@ class Edge:
attributes: Optional[Dict[str, Any]] = None,
directed: bool = True,
dimension: int = 1,
edge_penalty: float = 3.5,
):
if dimension <= 0:
raise InvalidDimensionsError()
self.node1 = node1
self.node2 = node2
self.attributes = attributes if attributes is not None else {}
self.attributes["vector_distance"] = float("inf")
self.attributes["vector_distance"] = edge_penalty
self.directed = directed
self.status = np.ones(dimension, dtype=int)
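The switch from `float("inf")` to a finite default means unscored nodes and edges still carry a comparable distance instead of sorting last unconditionally. A minimal check, with the import path assumed from the module layout:

```python
from cognee.modules.graph.cognee_graph.CogneeGraphElements import Node  # assumed path

node = Node("n1", {"name": "example"}, node_penalty=3.5)
assert node.attributes["vector_distance"] == 3.5  # finite default penalty, not inf
```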


@ -2,6 +2,8 @@ import io
import sys
import traceback
import cognee
def wrap_in_async_handler(user_code: str) -> str:
return (
@ -34,6 +36,7 @@ def run_in_local_sandbox(code, environment=None, loop=None):
environment["print"] = customPrintFunction
environment["running_loop"] = loop
environment["cognee"] = cognee
try:
exec(code, environment)


@ -2,7 +2,7 @@ import os
import difflib
from cognee.shared.logging_utils import get_logger
from collections import deque
from typing import List, Tuple, Dict, Optional, Any, Union
from typing import List, Tuple, Dict, Optional, Any, Union, IO
from rdflib import Graph, URIRef, RDF, RDFS, OWL
from cognee.modules.ontology.exceptions import (
@ -26,44 +26,76 @@ class RDFLibOntologyResolver(BaseOntologyResolver):
def __init__(
self,
ontology_file: Optional[Union[str, List[str]]] = None,
ontology_file: Optional[Union[str, List[str], IO, List[IO]]] = None,
matching_strategy: Optional[MatchingStrategy] = None,
) -> None:
super().__init__(matching_strategy)
self.ontology_file = ontology_file
try:
files_to_load = []
self.graph = None
if ontology_file is not None:
if isinstance(ontology_file, str):
files_to_load = []
file_objects = []
if hasattr(ontology_file, "read"):
file_objects = [ontology_file]
elif isinstance(ontology_file, str):
files_to_load = [ontology_file]
elif isinstance(ontology_file, list):
files_to_load = ontology_file
if all(hasattr(item, "read") for item in ontology_file):
file_objects = ontology_file
else:
files_to_load = ontology_file
else:
raise ValueError(
f"ontology_file must be a string, list of strings, or None. Got: {type(ontology_file)}"
f"ontology_file must be a string, list of strings, file-like object, list of file-like objects, or None. Got: {type(ontology_file)}"
)
if files_to_load:
self.graph = Graph()
loaded_files = []
for file_path in files_to_load:
if os.path.exists(file_path):
self.graph.parse(file_path)
loaded_files.append(file_path)
logger.info("Ontology loaded successfully from file: %s", file_path)
else:
logger.warning(
"Ontology file '%s' not found. Skipping this file.",
file_path,
if file_objects:
self.graph = Graph()
loaded_objects = []
for file_obj in file_objects:
try:
content = file_obj.read()
self.graph.parse(data=content, format="xml")
loaded_objects.append(file_obj)
logger.info("Ontology loaded successfully from file object")
except Exception as e:
logger.warning("Failed to parse ontology file object: %s", str(e))
if not loaded_objects:
logger.info(
"No valid ontology file objects found. No owl ontology will be attached to the graph."
)
self.graph = None
else:
logger.info("Total ontology file objects loaded: %d", len(loaded_objects))
if not loaded_files:
logger.info(
"No valid ontology files found. No owl ontology will be attached to the graph."
)
self.graph = None
elif files_to_load:
self.graph = Graph()
loaded_files = []
for file_path in files_to_load:
if os.path.exists(file_path):
self.graph.parse(file_path)
loaded_files.append(file_path)
logger.info("Ontology loaded successfully from file: %s", file_path)
else:
logger.warning(
"Ontology file '%s' not found. Skipping this file.",
file_path,
)
if not loaded_files:
logger.info(
"No valid ontology files found. No owl ontology will be attached to the graph."
)
self.graph = None
else:
logger.info("Total ontology files loaded: %d", len(loaded_files))
else:
logger.info("Total ontology files loaded: %d", len(loaded_files))
logger.info(
"No ontology file provided. No owl ontology will be attached to the graph."
)
else:
logger.info(
"No ontology file provided. No owl ontology will be attached to the graph."


@ -39,6 +39,8 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
node_type: Optional[Type] = None,
node_name: Optional[List[str]] = None,
save_interaction: bool = False,
wide_search_top_k: Optional[int] = 100,
triplet_distance_penalty: Optional[float] = 3.5,
):
super().__init__(
user_prompt_path=user_prompt_path,
@ -48,6 +50,8 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
node_name=node_name,
save_interaction=save_interaction,
system_prompt=system_prompt,
wide_search_top_k=wide_search_top_k,
triplet_distance_penalty=triplet_distance_penalty,
)
async def get_completion(


@ -65,6 +65,8 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
node_type: Optional[Type] = None,
node_name: Optional[List[str]] = None,
save_interaction: bool = False,
wide_search_top_k: Optional[int] = 100,
triplet_distance_penalty: Optional[float] = 3.5,
):
super().__init__(
user_prompt_path=user_prompt_path,
@ -74,6 +76,8 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
node_type=node_type,
node_name=node_name,
save_interaction=save_interaction,
wide_search_top_k=wide_search_top_k,
triplet_distance_penalty=triplet_distance_penalty,
)
self.validation_system_prompt_path = validation_system_prompt_path
self.validation_user_prompt_path = validation_user_prompt_path


@ -47,6 +47,8 @@ class GraphCompletionRetriever(BaseGraphRetriever):
node_type: Optional[Type] = None,
node_name: Optional[List[str]] = None,
save_interaction: bool = False,
wide_search_top_k: Optional[int] = 100,
triplet_distance_penalty: Optional[float] = 3.5,
):
"""Initialize retriever with prompt paths and search parameters."""
self.save_interaction = save_interaction
@ -54,8 +56,10 @@ class GraphCompletionRetriever(BaseGraphRetriever):
self.system_prompt_path = system_prompt_path
self.system_prompt = system_prompt
self.top_k = top_k if top_k is not None else 5
self.wide_search_top_k = wide_search_top_k
self.node_type = node_type
self.node_name = node_name
self.triplet_distance_penalty = triplet_distance_penalty
async def resolve_edges_to_text(self, retrieved_edges: list) -> str:
"""
@ -105,6 +109,8 @@ class GraphCompletionRetriever(BaseGraphRetriever):
collections=vector_index_collections or None,
node_type=self.node_type,
node_name=self.node_name,
wide_search_top_k=self.wide_search_top_k,
triplet_distance_penalty=self.triplet_distance_penalty,
)
return found_triplets
@ -141,6 +147,10 @@ class GraphCompletionRetriever(BaseGraphRetriever):
return triplets
async def convert_retrieved_objects_to_context(self, triplets: List[Edge]):
context = await self.resolve_edges_to_text(triplets)
return context
async def get_completion(
self,
query: str,


@ -26,6 +26,8 @@ class GraphSummaryCompletionRetriever(GraphCompletionRetriever):
node_type: Optional[Type] = None,
node_name: Optional[List[str]] = None,
save_interaction: bool = False,
wide_search_top_k: Optional[int] = 100,
triplet_distance_penalty: Optional[float] = 3.5,
):
"""Initialize retriever with default prompt paths and search parameters."""
super().__init__(
@ -36,6 +38,8 @@ class GraphSummaryCompletionRetriever(GraphCompletionRetriever):
node_name=node_name,
save_interaction=save_interaction,
system_prompt=system_prompt,
wide_search_top_k=wide_search_top_k,
triplet_distance_penalty=triplet_distance_penalty,
)
self.summarize_prompt_path = summarize_prompt_path


@ -47,6 +47,8 @@ class TemporalRetriever(GraphCompletionRetriever):
top_k: Optional[int] = 5,
node_type: Optional[Type] = None,
node_name: Optional[List[str]] = None,
wide_search_top_k: Optional[int] = 100,
triplet_distance_penalty: Optional[float] = 3.5,
):
super().__init__(
user_prompt_path=user_prompt_path,
@ -54,6 +56,8 @@ class TemporalRetriever(GraphCompletionRetriever):
top_k=top_k,
node_type=node_type,
node_name=node_name,
wide_search_top_k=wide_search_top_k,
triplet_distance_penalty=triplet_distance_penalty,
)
self.user_prompt_path = user_prompt_path
self.system_prompt_path = system_prompt_path


@ -58,6 +58,8 @@ async def get_memory_fragment(
properties_to_project: Optional[List[str]] = None,
node_type: Optional[Type] = None,
node_name: Optional[List[str]] = None,
relevant_ids_to_filter: Optional[List[str]] = None,
triplet_distance_penalty: Optional[float] = 3.5,
) -> CogneeGraph:
"""Creates and initializes a CogneeGraph memory fragment with optional property projections."""
if properties_to_project is None:
@ -74,6 +76,8 @@ async def get_memory_fragment(
edge_properties_to_project=["relationship_name", "edge_text"],
node_type=node_type,
node_name=node_name,
relevant_ids_to_filter=relevant_ids_to_filter,
triplet_distance_penalty=triplet_distance_penalty,
)
except EntityNotFoundError:
@ -95,6 +99,8 @@ async def brute_force_triplet_search(
memory_fragment: Optional[CogneeGraph] = None,
node_type: Optional[Type] = None,
node_name: Optional[List[str]] = None,
wide_search_top_k: Optional[int] = 100,
triplet_distance_penalty: Optional[float] = 3.5,
) -> List[Edge]:
"""
Performs a brute force search to retrieve the top triplets from the graph.
@ -107,6 +113,8 @@ async def brute_force_triplet_search(
memory_fragment (Optional[CogneeGraph]): Existing memory fragment to reuse.
node_type: node type to filter
node_name: node name to filter
wide_search_top_k (Optional[int]): Number of initial elements to retrieve from collections
triplet_distance_penalty (Optional[float]): Default vector distance assigned to unscored nodes and edges during graph projection
Returns:
list: The top triplet results.
@ -116,10 +124,10 @@ async def brute_force_triplet_search(
if top_k <= 0:
raise ValueError("top_k must be a positive integer.")
if memory_fragment is None:
memory_fragment = await get_memory_fragment(
properties_to_project, node_type=node_type, node_name=node_name
)
# Set the wide-search limit; only non-global searches (no node_name filter) are capped
non_global_search = node_name is None
wide_search_limit = wide_search_top_k if non_global_search else None
if collections is None:
collections = [
@ -140,7 +148,7 @@ async def brute_force_triplet_search(
async def search_in_collection(collection_name: str):
try:
return await vector_engine.search(
collection_name=collection_name, query_vector=query_vector, limit=None
collection_name=collection_name, query_vector=query_vector, limit=wide_search_limit
)
except CollectionNotFoundError:
return []
@ -156,15 +164,38 @@ async def brute_force_triplet_search(
return []
# Final statistics
projection_time = time.time() - start_time
vector_collection_search_time = time.time() - start_time
logger.info(
f"Vector collection retrieval completed: Retrieved distances from {sum(1 for res in results if res)} collections in {projection_time:.2f}s"
f"Vector collection retrieval completed: Retrieved distances from {sum(1 for res in results if res)} collections in {vector_collection_search_time:.2f}s"
)
node_distances = {collection: result for collection, result in zip(collections, results)}
edge_distances = node_distances.get("EdgeType_relationship_name", None)
if wide_search_limit is not None:
relevant_ids_to_filter = list(
{
str(getattr(scored_node, "id"))
for collection_name, score_collection in node_distances.items()
if collection_name != "EdgeType_relationship_name"
and isinstance(score_collection, (list, tuple))
for scored_node in score_collection
if getattr(scored_node, "id", None)
}
)
else:
relevant_ids_to_filter = None
if memory_fragment is None:
memory_fragment = await get_memory_fragment(
properties_to_project=properties_to_project,
node_type=node_type,
node_name=node_name,
relevant_ids_to_filter=relevant_ids_to_filter,
triplet_distance_penalty=triplet_distance_penalty,
)
await memory_fragment.map_vector_distances_to_graph_nodes(node_distances=node_distances)
await memory_fragment.map_vector_distances_to_graph_edges(
vector_engine=vector_engine, query_vector=query_vector, edge_distances=edge_distances


@ -37,6 +37,8 @@ async def get_search_type_tools(
node_name: Optional[List[str]] = None,
save_interaction: bool = False,
last_k: Optional[int] = None,
wide_search_top_k: Optional[int] = 100,
triplet_distance_penalty: Optional[float] = 3.5,
) -> list:
search_tasks: dict[SearchType, List[Callable]] = {
SearchType.SUMMARIES: [
@ -67,6 +69,8 @@ async def get_search_type_tools(
node_name=node_name,
save_interaction=save_interaction,
system_prompt=system_prompt,
wide_search_top_k=wide_search_top_k,
triplet_distance_penalty=triplet_distance_penalty,
).get_completion,
GraphCompletionRetriever(
system_prompt_path=system_prompt_path,
@ -75,6 +79,8 @@ async def get_search_type_tools(
node_name=node_name,
save_interaction=save_interaction,
system_prompt=system_prompt,
wide_search_top_k=wide_search_top_k,
triplet_distance_penalty=triplet_distance_penalty,
).get_context,
],
SearchType.GRAPH_COMPLETION_COT: [
@ -85,6 +91,8 @@ async def get_search_type_tools(
node_name=node_name,
save_interaction=save_interaction,
system_prompt=system_prompt,
wide_search_top_k=wide_search_top_k,
triplet_distance_penalty=triplet_distance_penalty,
).get_completion,
GraphCompletionCotRetriever(
system_prompt_path=system_prompt_path,
@ -93,6 +101,8 @@ async def get_search_type_tools(
node_name=node_name,
save_interaction=save_interaction,
system_prompt=system_prompt,
wide_search_top_k=wide_search_top_k,
triplet_distance_penalty=triplet_distance_penalty,
).get_context,
],
SearchType.GRAPH_COMPLETION_CONTEXT_EXTENSION: [
@ -103,6 +113,8 @@ async def get_search_type_tools(
node_name=node_name,
save_interaction=save_interaction,
system_prompt=system_prompt,
wide_search_top_k=wide_search_top_k,
triplet_distance_penalty=triplet_distance_penalty,
).get_completion,
GraphCompletionContextExtensionRetriever(
system_prompt_path=system_prompt_path,
@ -111,6 +123,8 @@ async def get_search_type_tools(
node_name=node_name,
save_interaction=save_interaction,
system_prompt=system_prompt,
wide_search_top_k=wide_search_top_k,
triplet_distance_penalty=triplet_distance_penalty,
).get_context,
],
SearchType.GRAPH_SUMMARY_COMPLETION: [
@ -121,6 +135,8 @@ async def get_search_type_tools(
node_name=node_name,
save_interaction=save_interaction,
system_prompt=system_prompt,
wide_search_top_k=wide_search_top_k,
triplet_distance_penalty=triplet_distance_penalty,
).get_completion,
GraphSummaryCompletionRetriever(
system_prompt_path=system_prompt_path,
@ -129,6 +145,8 @@ async def get_search_type_tools(
node_name=node_name,
save_interaction=save_interaction,
system_prompt=system_prompt,
wide_search_top_k=wide_search_top_k,
triplet_distance_penalty=triplet_distance_penalty,
).get_context,
],
SearchType.CODE: [
@ -145,8 +163,16 @@ async def get_search_type_tools(
],
SearchType.FEEDBACK: [UserQAFeedback(last_k=last_k).add_feedback],
SearchType.TEMPORAL: [
TemporalRetriever(top_k=top_k).get_completion,
TemporalRetriever(top_k=top_k).get_context,
TemporalRetriever(
top_k=top_k,
wide_search_top_k=wide_search_top_k,
triplet_distance_penalty=triplet_distance_penalty,
).get_completion,
TemporalRetriever(
top_k=top_k,
wide_search_top_k=wide_search_top_k,
triplet_distance_penalty=triplet_distance_penalty,
).get_context,
],
SearchType.CHUNKS_LEXICAL: (
lambda _r=JaccardChunksRetriever(top_k=top_k): [

View file

@ -24,6 +24,8 @@ async def no_access_control_search(
last_k: Optional[int] = None,
only_context: bool = False,
session_id: Optional[str] = None,
wide_search_top_k: Optional[int] = 100,
triplet_distance_penalty: Optional[float] = 3.5,
) -> Tuple[Any, Union[str, List[Edge]], List[Dataset]]:
search_tools = await get_search_type_tools(
query_type=query_type,
@ -35,6 +37,8 @@ async def no_access_control_search(
node_name=node_name,
save_interaction=save_interaction,
last_k=last_k,
wide_search_top_k=wide_search_top_k,
triplet_distance_penalty=triplet_distance_penalty,
)
graph_engine = await get_graph_engine()
is_empty = await graph_engine.is_empty()

View file

@ -47,6 +47,8 @@ async def search(
only_context: bool = False,
use_combined_context: bool = False,
session_id: Optional[str] = None,
wide_search_top_k: Optional[int] = 100,
triplet_distance_penalty: Optional[float] = 3.5,
) -> Union[CombinedSearchResult, List[SearchResult]]:
"""
@ -90,6 +92,8 @@ async def search(
only_context=only_context,
use_combined_context=use_combined_context,
session_id=session_id,
wide_search_top_k=wide_search_top_k,
triplet_distance_penalty=triplet_distance_penalty,
)
else:
search_results = [
@ -105,6 +109,8 @@ async def search(
last_k=last_k,
only_context=only_context,
session_id=session_id,
wide_search_top_k=wide_search_top_k,
triplet_distance_penalty=triplet_distance_penalty,
)
]
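The same parameters thread through the public search entry point; a minimal sketch, assuming the top-level cognee.search wrapper forwards these kwargs and exports SearchType (the query text is hypothetical):

import asyncio

import cognee
from cognee import SearchType

async def main():
    results = await cognee.search(
        query_type=SearchType.GRAPH_COMPLETION,
        query_text="What problems affect the light bulb?",
        wide_search_top_k=50,  # cap per-collection vector hits in the wide search
        triplet_distance_penalty=3.5,  # default distance for unmatched graph elements
    )
    print(results)

asyncio.run(main())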
@ -219,6 +225,8 @@ async def authorized_search(
only_context: bool = False,
use_combined_context: bool = False,
session_id: Optional[str] = None,
wide_search_top_k: Optional[int] = 100,
triplet_distance_penalty: Optional[float] = 3.5,
) -> Union[
Tuple[Any, Union[List[Edge], str], List[Dataset]],
List[Tuple[Any, Union[List[Edge], str], List[Dataset]]],
@ -246,6 +254,8 @@ async def authorized_search(
last_k=last_k,
only_context=True,
session_id=session_id,
wide_search_top_k=wide_search_top_k,
triplet_distance_penalty=triplet_distance_penalty,
)
context = {}
@ -267,6 +277,8 @@ async def authorized_search(
node_name=node_name,
save_interaction=save_interaction,
last_k=last_k,
wide_search_top_k=wide_search_top_k,
triplet_distance_penalty=triplet_distance_penalty,
)
search_tools = specific_search_tools
if len(search_tools) == 2:
@ -306,6 +318,7 @@ async def authorized_search(
last_k=last_k,
only_context=only_context,
session_id=session_id,
wide_search_top_k=wide_search_top_k,
)
return search_results
@ -325,6 +338,8 @@ async def search_in_datasets_context(
only_context: bool = False,
context: Optional[Any] = None,
session_id: Optional[str] = None,
wide_search_top_k: Optional[int] = 100,
triplet_distance_penalty: Optional[float] = 3.5,
) -> List[Tuple[Any, Union[str, List[Edge]], List[Dataset]]]:
"""
Searches all provided datasets and handles setting up of appropriate database context based on permissions.
@ -345,6 +360,8 @@ async def search_in_datasets_context(
only_context: bool = False,
context: Optional[Any] = None,
session_id: Optional[str] = None,
wide_search_top_k: Optional[int] = 100,
triplet_distance_penalty: Optional[float] = 3.5,
) -> Tuple[Any, Union[str, List[Edge]], List[Dataset]]:
# Set database configuration in async context for each dataset user has access for
await set_database_global_context_variables(dataset.id, dataset.owner_id)
@ -378,6 +395,8 @@ async def search_in_datasets_context(
node_name=node_name,
save_interaction=save_interaction,
last_k=last_k,
wide_search_top_k=wide_search_top_k,
triplet_distance_penalty=triplet_distance_penalty,
)
search_tools = specific_search_tools
if len(search_tools) == 2:
@ -413,6 +432,8 @@ async def search_in_datasets_context(
only_context=only_context,
context=context,
session_id=session_id,
wide_search_top_k=wide_search_top_k,
triplet_distance_penalty=triplet_distance_penalty,
)
)

View file

@ -1,4 +1,5 @@
from .chunk_by_word import chunk_by_word
from .chunk_by_sentence import chunk_by_sentence
from .chunk_by_paragraph import chunk_by_paragraph
from .chunk_by_row import chunk_by_row
from .remove_disconnected_chunks import remove_disconnected_chunks

View file

@ -0,0 +1,94 @@
from typing import Any, Dict, Iterator
from uuid import NAMESPACE_OID, uuid5
from cognee.infrastructure.databases.vector.embeddings import get_embedding_engine
def _get_pair_size(pair_text: str) -> int:
"""
Calculate the size of a given text in terms of tokens.
If an embedding engine's tokenizer is available, count the tokens for the provided pair text.
If the tokenizer is not available, assume the pair counts as three tokens (a rough estimate for a key: value pair).
Parameters:
-----------
- pair_text (str): The key:value pair text for which the token size is to be calculated.
Returns:
--------
- int: The number of tokens representing the text, as reported by the tokenizer, or the fixed fallback of 3.
"""
embedding_engine = get_embedding_engine()
if embedding_engine.tokenizer:
return embedding_engine.tokenizer.count_tokens(pair_text)
else:
return 3
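# With no tokenizer configured, every pair falls back to the fixed estimate,
# e.g. _get_pair_size("name: Eric") == 3.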
def chunk_by_row(
data: str,
max_chunk_size: int,
) -> Iterator[Dict[str, Any]]:
"""
Chunk the input text row by row.
This function divides the given text data into smaller chunks on a row-by-row basis,
ensuring that the token size of each chunk is less than or equal to the specified
maximum chunk size. Rows are split into key:value pairs and re-joined with ", ",
so rows that exceed the limit are cut into multiple chunks. The tokenization
process is handled by adapters compatible with the vector engine's embedding model.
Parameters:
-----------
- data (str): The input text to be chunked, with rows separated by blank lines.
- max_chunk_size (int): The maximum allowed size for each chunk, in tokens.
Yields:
-------
- Dict[str, Any]: Chunk payloads with text, chunk_size, chunk_id, chunk_index and
cut_type ("row_cut" or "row_end").
"""
current_chunk_list = []
chunk_index = 0
current_chunk_size = 0
lines = data.split("\n\n")
for line in lines:
pairs_text = line.split(", ")
for pair_text in pairs_text:
pair_size = _get_pair_size(pair_text)
if current_chunk_size > 0 and (current_chunk_size + pair_size > max_chunk_size):
# Yield current cut chunk
current_chunk = ", ".join(current_chunk_list)
chunk_dict = {
"text": current_chunk,
"chunk_size": current_chunk_size,
"chunk_id": uuid5(NAMESPACE_OID, current_chunk),
"chunk_index": chunk_index,
"cut_type": "row_cut",
}
yield chunk_dict
# Start new chunk with current pair text
current_chunk_list = []
current_chunk_size = 0
chunk_index += 1
current_chunk_list.append(pair_text)
current_chunk_size += pair_size
# Yield row chunk
current_chunk = ", ".join(current_chunk_list)
if current_chunk:
chunk_dict = {
"text": current_chunk,
"chunk_size": current_chunk_size,
"chunk_id": uuid5(NAMESPACE_OID, current_chunk),
"chunk_index": chunk_index,
"cut_type": "row_end",
}
yield chunk_dict
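A minimal usage sketch (assumes an embedding engine is configured, otherwise the three-token fallback above applies; the input string is hypothetical):

data = "id: 1, name: Eric, age: 30\n\nid: 2, name: Joe, age: 35"
for chunk in chunk_by_row(data, max_chunk_size=10):
    print(chunk["chunk_index"], chunk["cut_type"], chunk["text"])

Rows that overflow max_chunk_size are emitted as "row_cut" chunks; per the CsvDocument test below, the trailing chunk of each row is emitted with cut_type "row_end".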

View file

@ -7,6 +7,7 @@ from cognee.modules.data.processing.document_types import (
ImageDocument,
TextDocument,
UnstructuredDocument,
CsvDocument,
)
from cognee.modules.engine.models.node_set import NodeSet
from cognee.modules.engine.utils.generate_node_id import generate_node_id
@ -15,6 +16,7 @@ from cognee.tasks.documents.exceptions import WrongDataDocumentInputError
EXTENSION_TO_DOCUMENT_CLASS = {
"pdf": PdfDocument, # Text documents
"txt": TextDocument,
"csv": CsvDocument,
"docx": UnstructuredDocument,
"doc": UnstructuredDocument,
"odt": UnstructuredDocument,

View file

@ -0,0 +1,70 @@
import os
import sys
import uuid
import pytest
import pathlib
from unittest.mock import patch
from cognee.modules.chunking.CsvChunker import CsvChunker
from cognee.modules.data.processing.document_types.CsvDocument import CsvDocument
from cognee.tests.integration.documents.AudioDocument_test import mock_get_embedding_engine
from cognee.tests.integration.documents.async_gen_zip import async_gen_zip
chunk_by_row_module = sys.modules.get("cognee.tasks.chunks.chunk_by_row")
GROUND_TRUTH = {
"chunk_size_10": [
{"token_count": 9, "len_text": 26, "cut_type": "row_cut", "chunk_index": 0},
{"token_count": 6, "len_text": 29, "cut_type": "row_end", "chunk_index": 1},
{"token_count": 9, "len_text": 25, "cut_type": "row_cut", "chunk_index": 2},
{"token_count": 6, "len_text": 30, "cut_type": "row_end", "chunk_index": 3},
],
"chunk_size_128": [
{"token_count": 15, "len_text": 57, "cut_type": "row_end", "chunk_index": 0},
{"token_count": 15, "len_text": 57, "cut_type": "row_end", "chunk_index": 1},
],
}
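# These expectations assume CsvDocument flattens each row into "column: value"
# pairs joined by ", "; e.g. the first data row becomes
# "id: 1, name: Eric, age: 30, city: Beijing, country: China" (57 characters),
# which matches len_text for chunk_size_128.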
@pytest.mark.parametrize(
"input_file,chunk_size",
[("example_with_header.csv", 10), ("example_with_header.csv", 128)],
)
@patch.object(chunk_by_row_module, "get_embedding_engine", side_effect=mock_get_embedding_engine)
@pytest.mark.asyncio
async def test_CsvDocument(mock_engine, input_file, chunk_size):
# Define file paths of test data
csv_file_path = os.path.join(
pathlib.Path(__file__).parent.parent.parent,
"test_data",
input_file,
)
# Define test documents
csv_document = CsvDocument(
id=uuid.uuid4(),
name="example_with_header.csv",
raw_data_location=csv_file_path,
external_metadata="",
mime_type="text/csv",
)
# TEST CSV
ground_truth_key = f"chunk_size_{chunk_size}"
async for ground_truth, row_data in async_gen_zip(
GROUND_TRUTH[ground_truth_key],
csv_document.read(chunker_cls=CsvChunker, max_chunk_size=chunk_size),
):
assert ground_truth["token_count"] == row_data.chunk_size, (
f'{ground_truth["token_count"] = } != {row_data.chunk_size = }'
)
assert ground_truth["len_text"] == len(row_data.text), (
f'{ground_truth["len_text"] = } != {len(row_data.text) = }'
)
assert ground_truth["cut_type"] == row_data.cut_type, (
f'{ground_truth["cut_type"] = } != {row_data.cut_type = }'
)
assert ground_truth["chunk_index"] == row_data.chunk_index, (
f'{ground_truth["chunk_index"] = } != {row_data.chunk_index = }'
)

View file

@ -5,7 +5,7 @@ from cognee.tasks.web_scraper import DefaultUrlCrawler
@pytest.mark.asyncio
async def test_fetch():
crawler = DefaultUrlCrawler()
url = "https://en.wikipedia.org/wiki/Large_language_model"
url = "http://example.com/"
results = await crawler.fetch_urls(url)
assert len(results) == 1
assert isinstance(results, dict)

View file

@ -11,7 +11,7 @@ skip_in_ci = pytest.mark.skipif(
@skip_in_ci
@pytest.mark.asyncio
async def test_fetch():
url = "https://en.wikipedia.org/wiki/Large_language_model"
url = "http://example.com/"
results = await fetch_with_tavily(url)
assert isinstance(results, dict)
assert len(results) == 1

View file

@ -14,9 +14,7 @@ async def test_url_saves_as_html_file():
await cognee.prune.prune_system(metadata=True)
try:
original_file_path = await save_data_item_to_storage(
"https://en.wikipedia.org/wiki/Large_language_model"
)
original_file_path = await save_data_item_to_storage("http://example.com/")
file_path = get_data_file_path(original_file_path)
assert file_path.endswith(".html")
file = Path(file_path)
@ -44,9 +42,7 @@ async def test_saved_html_is_valid():
await cognee.prune.prune_system(metadata=True)
try:
original_file_path = await save_data_item_to_storage(
"https://en.wikipedia.org/wiki/Large_language_model"
)
original_file_path = await save_data_item_to_storage("http://example.com/")
file_path = get_data_file_path(original_file_path)
content = Path(file_path).read_text()
@ -72,7 +68,7 @@ async def test_add_url():
await cognee.prune.prune_data()
await cognee.prune.prune_system(metadata=True)
await cognee.add("https://en.wikipedia.org/wiki/Large_language_model")
await cognee.add("http://example.com/")
skip_in_ci = pytest.mark.skipif(
@ -88,7 +84,7 @@ async def test_add_url_with_tavily():
await cognee.prune.prune_data()
await cognee.prune.prune_system(metadata=True)
await cognee.add("https://en.wikipedia.org/wiki/Large_language_model")
await cognee.add("http://example.com/")
@pytest.mark.asyncio
@ -98,7 +94,7 @@ async def test_add_url_without_incremental_loading():
try:
await cognee.add(
"https://en.wikipedia.org/wiki/Large_language_model",
"http://example.com/",
incremental_loading=False,
)
except Exception as e:
@ -112,7 +108,7 @@ async def test_add_url_with_incremental_loading():
try:
await cognee.add(
"https://en.wikipedia.org/wiki/Large_language_model",
"http://example.com/",
incremental_loading=True,
)
except Exception as e:
@ -125,7 +121,7 @@ async def test_add_url_can_define_preferred_loader_as_list_of_str():
await cognee.prune.prune_system(metadata=True)
await cognee.add(
"https://en.wikipedia.org/wiki/Large_language_model",
"http://example.com/",
preferred_loaders=["beautiful_soup_loader"],
)
@ -144,7 +140,7 @@ async def test_add_url_with_extraction_rules():
try:
await cognee.add(
"https://en.wikipedia.org/wiki/Large_language_model",
"http://example.com/",
preferred_loaders={"beautiful_soup_loader": {"extraction_rules": extraction_rules}},
)
except Exception as e:
@ -163,9 +159,7 @@ async def test_loader_is_none_by_default():
}
try:
original_file_path = await save_data_item_to_storage(
"https://en.wikipedia.org/wiki/Large_language_model"
)
original_file_path = await save_data_item_to_storage("http://example.com/")
file_path = get_data_file_path(original_file_path)
assert file_path.endswith(".html")
file = Path(file_path)
@ -196,9 +190,7 @@ async def test_beautiful_soup_loader_is_selected_loader_if_preferred_loader_prov
}
try:
original_file_path = await save_data_item_to_storage(
"https://en.wikipedia.org/wiki/Large_language_model"
)
original_file_path = await save_data_item_to_storage("http://example.com/")
file_path = get_data_file_path(original_file_path)
assert file_path.endswith(".html")
file = Path(file_path)
@ -225,9 +217,7 @@ async def test_beautiful_soup_loader_works_with_and_without_arguments():
await cognee.prune.prune_system(metadata=True)
try:
original_file_path = await save_data_item_to_storage(
"https://en.wikipedia.org/wiki/Large_language_model"
)
original_file_path = await save_data_item_to_storage("http://example.com/")
file_path = get_data_file_path(original_file_path)
assert file_path.endswith(".html")
file = Path(file_path)
@ -263,9 +253,7 @@ async def test_beautiful_soup_loader_successfully_loads_file_if_required_args_pr
await cognee.prune.prune_system(metadata=True)
try:
original_file_path = await save_data_item_to_storage(
"https://en.wikipedia.org/wiki/Large_language_model"
)
original_file_path = await save_data_item_to_storage("http://example.com/")
file_path = get_data_file_path(original_file_path)
assert file_path.endswith(".html")
file = Path(file_path)
@ -302,9 +290,7 @@ async def test_beautiful_soup_loads_file_successfully():
}
try:
original_file_path = await save_data_item_to_storage(
"https://en.wikipedia.org/wiki/Large_language_model"
)
original_file_path = await save_data_item_to_storage("http://example.com/")
file_path = get_data_file_path(original_file_path)
assert file_path.endswith(".html")
original_file = Path(file_path)

View file

@ -7,6 +7,7 @@ import requests
from pathlib import Path
import sys
import uuid
import json
class TestCogneeServerStart(unittest.TestCase):
@ -90,12 +91,71 @@ class TestCogneeServerStart(unittest.TestCase):
)
}
payload = {"datasets": [dataset_name]}
ontology_key = f"test_ontology_{uuid.uuid4().hex[:8]}"
payload = {"datasets": [dataset_name], "ontology_key": [ontology_key]}
add_response = requests.post(url, headers=headers, data=form_data, files=file, timeout=50)
if add_response.status_code not in [200, 201]:
add_response.raise_for_status()
ontology_content = b"""<rdf:RDF xmlns:rdf="http://www.w3.org/1999/02/22-rdf-syntax-ns#"
xmlns:owl="http://www.w3.org/2002/07/owl#"
xmlns:rdfs="http://www.w3.org/2000/01/rdf-schema#"
xmlns="http://example.org/ontology#"
xml:base="http://example.org/ontology">
<owl:Ontology rdf:about="http://example.org/ontology"/>
<!-- Classes -->
<owl:Class rdf:ID="Problem"/>
<owl:Class rdf:ID="HardwareProblem"/>
<owl:Class rdf:ID="SoftwareProblem"/>
<owl:Class rdf:ID="Concept"/>
<owl:Class rdf:ID="Object"/>
<owl:Class rdf:ID="Joke"/>
<owl:Class rdf:ID="Image"/>
<owl:Class rdf:ID="Person"/>
<rdf:Description rdf:about="#HardwareProblem">
<rdfs:subClassOf rdf:resource="#Problem"/>
<rdfs:comment>A failure caused by physical components.</rdfs:comment>
</rdf:Description>
<rdf:Description rdf:about="#SoftwareProblem">
<rdfs:subClassOf rdf:resource="#Problem"/>
<rdfs:comment>An error caused by software logic or configuration.</rdfs:comment>
</rdf:Description>
<rdf:Description rdf:about="#Person">
<rdfs:comment>A human being or individual.</rdfs:comment>
</rdf:Description>
<!-- Individuals -->
<Person rdf:ID="programmers">
<rdfs:label>Programmers</rdfs:label>
</Person>
<Object rdf:ID="light_bulb">
<rdfs:label>Light Bulb</rdfs:label>
</Object>
<HardwareProblem rdf:ID="hardware_problem">
<rdfs:label>Hardware Problem</rdfs:label>
</HardwareProblem>
</rdf:RDF>"""
ontology_response = requests.post(
"http://127.0.0.1:8000/api/v1/ontologies",
headers=headers,
files=[("ontology_file", ("test.owl", ontology_content, "application/xml"))],
data={
"ontology_key": json.dumps([ontology_key]),
"description": json.dumps(["Test ontology"]),
},
)
self.assertEqual(ontology_response.status_code, 200)
# Cognify request
url = "http://127.0.0.1:8000/api/v1/cognify"
headers = {
@ -107,6 +167,29 @@ class TestCogneeServerStart(unittest.TestCase):
if cognify_response.status_code not in [200, 201]:
cognify_response.raise_for_status()
datasets_response = requests.get("http://127.0.0.1:8000/api/v1/datasets", headers=headers)
datasets = datasets_response.json()
dataset_id = None
for dataset in datasets:
if dataset["name"] == dataset_name:
dataset_id = dataset["id"]
break
graph_response = requests.get(
f"http://127.0.0.1:8000/api/v1/datasets/{dataset_id}/graph", headers=headers
)
self.assertEqual(graph_response.status_code, 200)
graph_data = graph_response.json()
ontology_nodes = [
node for node in graph_data.get("nodes", []) if node.get("properties", {}).get("ontology_valid")
]
self.assertGreater(
len(ontology_nodes), 0, "No ontology nodes found - ontology was not integrated"
)
# TODO: Add test to verify cognify pipeline is complete before testing search
# Search request

View file

@ -0,0 +1,3 @@
id,name,age,city,country
1,Eric,30,Beijing,China
2,Joe,35,Berlin,Germany

View file

@ -0,0 +1,272 @@
import pytest
import uuid
from fastapi.testclient import TestClient
from unittest.mock import patch, Mock, AsyncMock
from types import SimpleNamespace
import importlib
from cognee.api.client import app
gau_mod = importlib.import_module("cognee.modules.users.methods.get_authenticated_user")
@pytest.fixture
def client():
return TestClient(app)
@pytest.fixture
def mock_user():
user = Mock()
user.id = "test-user-123"
return user
@pytest.fixture
def mock_default_user():
"""Mock default user for testing."""
return SimpleNamespace(
id=str(uuid.uuid4()),
email="default@example.com",
is_active=True,
tenant_id=str(uuid.uuid4()),
)
@patch.object(gau_mod, "get_default_user", new_callable=AsyncMock)
def test_upload_ontology_success(mock_get_default_user, client, mock_default_user):
"""Test successful ontology upload"""
import json
mock_get_default_user.return_value = mock_default_user
ontology_content = (
b"<rdf:RDF xmlns:rdf='http://www.w3.org/1999/02/22-rdf-syntax-ns#'></rdf:RDF>"
)
unique_key = f"test_ontology_{uuid.uuid4().hex[:8]}"
response = client.post(
"/api/v1/ontologies",
files=[("ontology_file", ("test.owl", ontology_content, "application/xml"))],
data={"ontology_key": json.dumps([unique_key]), "description": json.dumps(["Test"])},
)
assert response.status_code == 200
data = response.json()
assert data["uploaded_ontologies"][0]["ontology_key"] == unique_key
assert "uploaded_at" in data["uploaded_ontologies"][0]
@patch.object(gau_mod, "get_default_user", new_callable=AsyncMock)
def test_upload_ontology_invalid_file(mock_get_default_user, client, mock_default_user):
"""Test 400 response for non-.owl files"""
mock_get_default_user.return_value = mock_default_user
unique_key = f"test_ontology_{uuid.uuid4().hex[:8]}"
response = client.post(
"/api/v1/ontologies",
files={"ontology_file": ("test.txt", b"not xml")},
data={"ontology_key": unique_key},
)
assert response.status_code == 400
@patch.object(gau_mod, "get_default_user", new_callable=AsyncMock)
def test_upload_ontology_missing_data(mock_get_default_user, client, mock_default_user):
"""Test 400 response for missing file or key"""
import json
mock_get_default_user.return_value = mock_default_user
# Missing file
response = client.post("/api/v1/ontologies", data={"ontology_key": json.dumps(["test"])})
assert response.status_code == 400
# Missing key
response = client.post(
"/api/v1/ontologies", files=[("ontology_file", ("test.owl", b"xml", "application/xml"))]
)
assert response.status_code == 400
@patch.object(gau_mod, "get_default_user", new_callable=AsyncMock)
def test_upload_ontology_unauthorized(mock_get_default_user, client, mock_default_user):
"""Test behavior when default user is provided (no explicit authentication)"""
import json
unique_key = f"test_ontology_{uuid.uuid4().hex[:8]}"
mock_get_default_user.return_value = mock_default_user
response = client.post(
"/api/v1/ontologies",
files=[("ontology_file", ("test.owl", b"<rdf></rdf>", "application/xml"))],
data={"ontology_key": json.dumps([unique_key])},
)
# The current system provides a default user when no explicit authentication is given
# This test verifies the system works with conditional authentication
assert response.status_code == 200
data = response.json()
assert data["uploaded_ontologies"][0]["ontology_key"] == unique_key
assert "uploaded_at" in data["uploaded_ontologies"][0]
@patch.object(gau_mod, "get_default_user", new_callable=AsyncMock)
def test_upload_multiple_ontologies(mock_get_default_user, client, mock_default_user):
"""Test uploading multiple ontology files in single request"""
import io
mock_get_default_user.return_value = mock_default_user
# Create mock files
file1_content = b"<rdf:RDF xmlns:rdf='http://www.w3.org/1999/02/22-rdf-syntax-ns#'></rdf:RDF>"
file2_content = b"<rdf:RDF xmlns:rdf='http://www.w3.org/1999/02/22-rdf-syntax-ns#'></rdf:RDF>"
files = [
("ontology_file", ("vehicles.owl", io.BytesIO(file1_content), "application/xml")),
("ontology_file", ("manufacturers.owl", io.BytesIO(file2_content), "application/xml")),
]
data = {
"ontology_key": '["vehicles", "manufacturers"]',
"descriptions": '["Base vehicles", "Car manufacturers"]',
}
response = client.post("/api/v1/ontologies", files=files, data=data)
assert response.status_code == 200
result = response.json()
assert "uploaded_ontologies" in result
assert len(result["uploaded_ontologies"]) == 2
assert result["uploaded_ontologies"][0]["ontology_key"] == "vehicles"
assert result["uploaded_ontologies"][1]["ontology_key"] == "manufacturers"
@patch.object(gau_mod, "get_default_user", new_callable=AsyncMock)
def test_upload_endpoint_accepts_arrays(mock_get_default_user, client, mock_default_user):
"""Test that upload endpoint accepts array parameters"""
import io
import json
mock_get_default_user.return_value = mock_default_user
file_content = b"<rdf:RDF xmlns:rdf='http://www.w3.org/1999/02/22-rdf-syntax-ns#'></rdf:RDF>"
files = [("ontology_file", ("single.owl", io.BytesIO(file_content), "application/xml"))]
data = {
"ontology_key": json.dumps(["single_key"]),
"descriptions": json.dumps(["Single ontology"]),
}
response = client.post("/api/v1/ontologies", files=files, data=data)
assert response.status_code == 200
result = response.json()
assert result["uploaded_ontologies"][0]["ontology_key"] == "single_key"
@patch.object(gau_mod, "get_default_user", new_callable=AsyncMock)
def test_cognify_with_multiple_ontologies(mock_get_default_user, client, mock_default_user):
"""Test cognify endpoint accepts multiple ontology keys"""
payload = {
"datasets": ["test_dataset"],
"ontology_key": ["ontology1", "ontology2"], # Array instead of string
"run_in_background": False,
}
response = client.post("/api/v1/cognify", json=payload)
# Should not fail due to ontology_key type
assert response.status_code in [200, 400, 409] # May fail for other reasons, not type
@patch.object(gau_mod, "get_default_user", new_callable=AsyncMock)
def test_complete_multifile_workflow(mock_get_default_user, client, mock_default_user):
"""Test complete workflow: upload multiple ontologies → cognify with multiple keys"""
import io
import json
mock_get_default_user.return_value = mock_default_user
# Step 1: Upload multiple ontologies
file1_content = b"""<?xml version="1.0"?>
<rdf:RDF xmlns:rdf="http://www.w3.org/1999/02/22-rdf-syntax-ns#"
xmlns:owl="http://www.w3.org/2002/07/owl#">
<owl:Class rdf:ID="Vehicle"/>
</rdf:RDF>"""
file2_content = b"""<?xml version="1.0"?>
<rdf:RDF xmlns:rdf="http://www.w3.org/1999/02/22-rdf-syntax-ns#"
xmlns:owl="http://www.w3.org/2002/07/owl#">
<owl:Class rdf:ID="Manufacturer"/>
</rdf:RDF>"""
files = [
("ontology_file", ("vehicles.owl", io.BytesIO(file1_content), "application/xml")),
("ontology_file", ("manufacturers.owl", io.BytesIO(file2_content), "application/xml")),
]
data = {
"ontology_key": json.dumps(["vehicles", "manufacturers"]),
"descriptions": json.dumps(["Vehicle ontology", "Manufacturer ontology"]),
}
upload_response = client.post("/api/v1/ontologies", files=files, data=data)
assert upload_response.status_code == 200
# Step 2: Verify ontologies are listed
list_response = client.get("/api/v1/ontologies")
assert list_response.status_code == 200
ontologies = list_response.json()
assert "vehicles" in ontologies
assert "manufacturers" in ontologies
# Step 3: Test cognify with multiple ontologies
cognify_payload = {
"datasets": ["test_dataset"],
"ontology_key": ["vehicles", "manufacturers"],
"run_in_background": False,
}
cognify_response = client.post("/api/v1/cognify", json=cognify_payload)
# Should not fail due to ontology handling (may fail for dataset reasons)
assert cognify_response.status_code != 400 # Not a validation error
@patch.object(gau_mod, "get_default_user", new_callable=AsyncMock)
def test_multifile_error_handling(mock_get_default_user, client, mock_default_user):
"""Test error handling for invalid multifile uploads"""
import io
import json
# Test mismatched array lengths
file_content = b"<rdf:RDF></rdf:RDF>"
files = [("ontology_file", ("test.owl", io.BytesIO(file_content), "application/xml"))]
data = {
"ontology_key": json.dumps(["key1", "key2"]), # 2 keys, 1 file
"descriptions": json.dumps(["desc1"]),
}
response = client.post("/api/v1/ontologies", files=files, data=data)
assert response.status_code == 400
assert "Number of keys must match number of files" in response.json()["error"]
# Test duplicate keys
files = [
("ontology_file", ("test1.owl", io.BytesIO(file_content), "application/xml")),
("ontology_file", ("test2.owl", io.BytesIO(file_content), "application/xml")),
]
data = {
"ontology_key": json.dumps(["duplicate", "duplicate"]),
"descriptions": json.dumps(["desc1", "desc2"]),
}
response = client.post("/api/v1/ontologies", files=files, data=data)
assert response.status_code == 400
assert "Duplicate ontology keys not allowed" in response.json()["error"]
@patch.object(gau_mod, "get_default_user", new_callable=AsyncMock)
def test_cognify_missing_ontology_key(mock_get_default_user, client, mock_default_user):
"""Test cognify with non-existent ontology key"""
mock_get_default_user.return_value = mock_default_user
payload = {
"datasets": ["test_dataset"],
"ontology_key": ["nonexistent_key"],
"run_in_background": False,
}
response = client.post("/api/v1/cognify", json=payload)
assert response.status_code == 409
assert "Ontology key 'nonexistent_key' not found" in response.json()["error"]

View file

@ -9,7 +9,7 @@ def test_node_initialization():
"""Test that a Node is initialized correctly."""
node = Node("node1", {"attr1": "value1"}, dimension=2)
assert node.id == "node1"
assert node.attributes == {"attr1": "value1", "vector_distance": np.inf}
assert node.attributes == {"attr1": "value1", "vector_distance": 3.5}
assert len(node.status) == 2
assert np.all(node.status == 1)
@ -96,7 +96,7 @@ def test_edge_initialization():
edge = Edge(node1, node2, {"weight": 10}, directed=False, dimension=2)
assert edge.node1 == node1
assert edge.node2 == node2
assert edge.attributes == {"vector_distance": np.inf, "weight": 10}
assert edge.attributes == {"vector_distance": 3.5, "weight": 10}
assert edge.directed is False
assert len(edge.status) == 2
assert np.all(edge.status == 1)
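A minimal sketch of the new default (import path assumed from the module under test): unmatched elements now carry the finite triplet distance penalty instead of np.inf, so they can still be ranked in triplet search.

from cognee.modules.graph.cognee_graph.CogneeGraphElements import Node

node = Node("n1", {"attr1": "value1"})
assert node.attributes["vector_distance"] == 3.5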

View file

@ -1,4 +1,5 @@
import pytest
from unittest.mock import AsyncMock
from cognee.modules.graph.exceptions import EntityNotFoundError, EntityAlreadyExistsError
from cognee.modules.graph.cognee_graph.CogneeGraph import CogneeGraph
@ -11,6 +12,30 @@ def setup_graph():
return CogneeGraph()
@pytest.fixture
def mock_adapter():
"""Fixture to create a mock adapter for database operations."""
adapter = AsyncMock()
return adapter
@pytest.fixture
def mock_vector_engine():
"""Fixture to create a mock vector engine."""
engine = AsyncMock()
engine.search = AsyncMock()
return engine
class MockScoredResult:
"""Mock class for vector search results."""
def __init__(self, id, score, payload=None):
self.id = id
self.score = score
self.payload = payload or {}
def test_add_node_success(setup_graph):
"""Test successful addition of a node."""
graph = setup_graph
@ -73,3 +98,433 @@ def test_get_edges_nonexistent_node(setup_graph):
graph = setup_graph
with pytest.raises(EntityNotFoundError, match="Node with id nonexistent does not exist."):
graph.get_edges_from_node("nonexistent")
@pytest.mark.asyncio
async def test_project_graph_from_db_full_graph(setup_graph, mock_adapter):
"""Test projecting a full graph from database."""
graph = setup_graph
nodes_data = [
("1", {"name": "Node1", "description": "First node"}),
("2", {"name": "Node2", "description": "Second node"}),
]
edges_data = [
("1", "2", "CONNECTS_TO", {"relationship_name": "connects"}),
]
mock_adapter.get_graph_data = AsyncMock(return_value=(nodes_data, edges_data))
await graph.project_graph_from_db(
adapter=mock_adapter,
node_properties_to_project=["name", "description"],
edge_properties_to_project=["relationship_name"],
)
assert len(graph.nodes) == 2
assert len(graph.edges) == 1
assert graph.get_node("1") is not None
assert graph.get_node("2") is not None
assert graph.edges[0].node1.id == "1"
assert graph.edges[0].node2.id == "2"
@pytest.mark.asyncio
async def test_project_graph_from_db_id_filtered(setup_graph, mock_adapter):
"""Test projecting an ID-filtered graph from database."""
graph = setup_graph
nodes_data = [
("1", {"name": "Node1"}),
("2", {"name": "Node2"}),
]
edges_data = [
("1", "2", "CONNECTS_TO", {"relationship_name": "connects"}),
]
mock_adapter.get_id_filtered_graph_data = AsyncMock(return_value=(nodes_data, edges_data))
await graph.project_graph_from_db(
adapter=mock_adapter,
node_properties_to_project=["name"],
edge_properties_to_project=["relationship_name"],
relevant_ids_to_filter=["1", "2"],
)
assert len(graph.nodes) == 2
assert len(graph.edges) == 1
mock_adapter.get_id_filtered_graph_data.assert_called_once()
@pytest.mark.asyncio
async def test_project_graph_from_db_nodeset_subgraph(setup_graph, mock_adapter):
"""Test projecting a nodeset subgraph filtered by node type and name."""
graph = setup_graph
nodes_data = [
("1", {"name": "Alice", "type": "Person"}),
("2", {"name": "Bob", "type": "Person"}),
]
edges_data = [
("1", "2", "KNOWS", {"relationship_name": "knows"}),
]
mock_adapter.get_nodeset_subgraph = AsyncMock(return_value=(nodes_data, edges_data))
await graph.project_graph_from_db(
adapter=mock_adapter,
node_properties_to_project=["name", "type"],
edge_properties_to_project=["relationship_name"],
node_type="Person",
node_name=["Alice"],
)
assert len(graph.nodes) == 2
assert graph.get_node("1") is not None
assert len(graph.edges) == 1
mock_adapter.get_nodeset_subgraph.assert_called_once()
@pytest.mark.asyncio
async def test_project_graph_from_db_empty_graph(setup_graph, mock_adapter):
"""Test projecting empty graph raises EntityNotFoundError."""
graph = setup_graph
mock_adapter.get_graph_data = AsyncMock(return_value=([], []))
with pytest.raises(EntityNotFoundError, match="Empty graph projected from the database."):
await graph.project_graph_from_db(
adapter=mock_adapter,
node_properties_to_project=["name"],
edge_properties_to_project=[],
)
@pytest.mark.asyncio
async def test_project_graph_from_db_missing_nodes(setup_graph, mock_adapter):
"""Test that edges referencing missing nodes raise error."""
graph = setup_graph
nodes_data = [
("1", {"name": "Node1"}),
]
edges_data = [
("1", "999", "CONNECTS_TO", {"relationship_name": "connects"}),
]
mock_adapter.get_graph_data = AsyncMock(return_value=(nodes_data, edges_data))
with pytest.raises(EntityNotFoundError, match="Edge references nonexistent nodes"):
await graph.project_graph_from_db(
adapter=mock_adapter,
node_properties_to_project=["name"],
edge_properties_to_project=["relationship_name"],
)
@pytest.mark.asyncio
async def test_map_vector_distances_to_graph_nodes(setup_graph):
"""Test mapping vector distances to graph nodes."""
graph = setup_graph
node1 = Node("1", {"name": "Node1"})
node2 = Node("2", {"name": "Node2"})
graph.add_node(node1)
graph.add_node(node2)
node_distances = {
"Entity_name": [
MockScoredResult("1", 0.95),
MockScoredResult("2", 0.87),
]
}
await graph.map_vector_distances_to_graph_nodes(node_distances)
assert graph.get_node("1").attributes.get("vector_distance") == 0.95
assert graph.get_node("2").attributes.get("vector_distance") == 0.87
@pytest.mark.asyncio
async def test_map_vector_distances_partial_node_coverage(setup_graph):
"""Test mapping vector distances when only some nodes have results."""
graph = setup_graph
node1 = Node("1", {"name": "Node1"})
node2 = Node("2", {"name": "Node2"})
node3 = Node("3", {"name": "Node3"})
graph.add_node(node1)
graph.add_node(node2)
graph.add_node(node3)
node_distances = {
"Entity_name": [
MockScoredResult("1", 0.95),
MockScoredResult("2", 0.87),
]
}
await graph.map_vector_distances_to_graph_nodes(node_distances)
assert graph.get_node("1").attributes.get("vector_distance") == 0.95
assert graph.get_node("2").attributes.get("vector_distance") == 0.87
assert graph.get_node("3").attributes.get("vector_distance") == 3.5
@pytest.mark.asyncio
async def test_map_vector_distances_multiple_categories(setup_graph):
"""Test mapping vector distances from multiple collection categories."""
graph = setup_graph
# Create nodes
node1 = Node("1")
node2 = Node("2")
node3 = Node("3")
node4 = Node("4")
graph.add_node(node1)
graph.add_node(node2)
graph.add_node(node3)
graph.add_node(node4)
node_distances = {
"Entity_name": [
MockScoredResult("1", 0.95),
MockScoredResult("2", 0.87),
],
"TextSummary_text": [
MockScoredResult("3", 0.92),
],
}
await graph.map_vector_distances_to_graph_nodes(node_distances)
assert graph.get_node("1").attributes.get("vector_distance") == 0.95
assert graph.get_node("2").attributes.get("vector_distance") == 0.87
assert graph.get_node("3").attributes.get("vector_distance") == 0.92
assert graph.get_node("4").attributes.get("vector_distance") == 3.5
@pytest.mark.asyncio
async def test_map_vector_distances_to_graph_edges_with_payload(setup_graph, mock_vector_engine):
"""Test mapping vector distances to edges when edge_distances provided."""
graph = setup_graph
node1 = Node("1")
node2 = Node("2")
graph.add_node(node1)
graph.add_node(node2)
edge = Edge(
node1,
node2,
attributes={"edge_text": "CONNECTS_TO", "relationship_type": "connects"},
)
graph.add_edge(edge)
edge_distances = [
MockScoredResult("e1", 0.92, payload={"text": "CONNECTS_TO"}),
]
await graph.map_vector_distances_to_graph_edges(
vector_engine=mock_vector_engine,
query_vector=[0.1, 0.2, 0.3],
edge_distances=edge_distances,
)
assert graph.edges[0].attributes.get("vector_distance") == 0.92
@pytest.mark.asyncio
async def test_map_vector_distances_to_graph_edges_search(setup_graph, mock_vector_engine):
"""Test mapping edge distances when searching for them."""
graph = setup_graph
node1 = Node("1")
node2 = Node("2")
graph.add_node(node1)
graph.add_node(node2)
edge = Edge(
node1,
node2,
attributes={"edge_text": "CONNECTS_TO", "relationship_type": "connects"},
)
graph.add_edge(edge)
mock_vector_engine.search.return_value = [
MockScoredResult("e1", 0.88, payload={"text": "CONNECTS_TO"}),
]
await graph.map_vector_distances_to_graph_edges(
vector_engine=mock_vector_engine,
query_vector=[0.1, 0.2, 0.3],
edge_distances=None,
)
mock_vector_engine.search.assert_called_once()
assert graph.edges[0].attributes.get("vector_distance") == 0.88
@pytest.mark.asyncio
async def test_map_vector_distances_partial_edge_coverage(setup_graph, mock_vector_engine):
"""Test mapping edge distances when only some edges have results."""
graph = setup_graph
node1 = Node("1")
node2 = Node("2")
node3 = Node("3")
graph.add_node(node1)
graph.add_node(node2)
graph.add_node(node3)
edge1 = Edge(node1, node2, attributes={"edge_text": "CONNECTS_TO"})
edge2 = Edge(node2, node3, attributes={"edge_text": "DEPENDS_ON"})
graph.add_edge(edge1)
graph.add_edge(edge2)
edge_distances = [
MockScoredResult("e1", 0.92, payload={"text": "CONNECTS_TO"}),
]
await graph.map_vector_distances_to_graph_edges(
vector_engine=mock_vector_engine,
query_vector=[0.1, 0.2, 0.3],
edge_distances=edge_distances,
)
assert graph.edges[0].attributes.get("vector_distance") == 0.92
assert graph.edges[1].attributes.get("vector_distance") == 3.5
@pytest.mark.asyncio
async def test_map_vector_distances_edges_fallback_to_relationship_type(
setup_graph, mock_vector_engine
):
"""Test that edge mapping falls back to relationship_type when edge_text is missing."""
graph = setup_graph
node1 = Node("1")
node2 = Node("2")
graph.add_node(node1)
graph.add_node(node2)
edge = Edge(
node1,
node2,
attributes={"relationship_type": "KNOWS"},
)
graph.add_edge(edge)
edge_distances = [
MockScoredResult("e1", 0.85, payload={"text": "KNOWS"}),
]
await graph.map_vector_distances_to_graph_edges(
vector_engine=mock_vector_engine,
query_vector=[0.1, 0.2, 0.3],
edge_distances=edge_distances,
)
assert graph.edges[0].attributes.get("vector_distance") == 0.85
@pytest.mark.asyncio
async def test_map_vector_distances_no_edge_matches(setup_graph, mock_vector_engine):
"""Test edge mapping when no edges match the distance results."""
graph = setup_graph
node1 = Node("1")
node2 = Node("2")
graph.add_node(node1)
graph.add_node(node2)
edge = Edge(
node1,
node2,
attributes={"edge_text": "CONNECTS_TO", "relationship_type": "connects"},
)
graph.add_edge(edge)
edge_distances = [
MockScoredResult("e1", 0.92, payload={"text": "SOME_OTHER_EDGE"}),
]
await graph.map_vector_distances_to_graph_edges(
vector_engine=mock_vector_engine,
query_vector=[0.1, 0.2, 0.3],
edge_distances=edge_distances,
)
assert graph.edges[0].attributes.get("vector_distance") == 3.5
@pytest.mark.asyncio
async def test_map_vector_distances_invalid_query_vector(setup_graph, mock_vector_engine):
"""Test that invalid query vector raises error."""
graph = setup_graph
with pytest.raises(ValueError, match="Failed to generate query embedding"):
await graph.map_vector_distances_to_graph_edges(
vector_engine=mock_vector_engine,
query_vector=[],
edge_distances=None,
)
@pytest.mark.asyncio
async def test_calculate_top_triplet_importances(setup_graph):
"""Test calculating top triplet importances by score."""
graph = setup_graph
node1 = Node("1")
node2 = Node("2")
node3 = Node("3")
node4 = Node("4")
node1.add_attribute("vector_distance", 0.9)
node2.add_attribute("vector_distance", 0.8)
node3.add_attribute("vector_distance", 0.7)
node4.add_attribute("vector_distance", 0.6)
graph.add_node(node1)
graph.add_node(node2)
graph.add_node(node3)
graph.add_node(node4)
edge1 = Edge(node1, node2)
edge2 = Edge(node2, node3)
edge3 = Edge(node3, node4)
edge1.add_attribute("vector_distance", 0.85)
edge2.add_attribute("vector_distance", 0.75)
edge3.add_attribute("vector_distance", 0.65)
graph.add_edge(edge1)
graph.add_edge(edge2)
graph.add_edge(edge3)
top_triplets = await graph.calculate_top_triplet_importances(k=2)
assert len(top_triplets) == 2
assert top_triplets[0] == edge3
assert top_triplets[1] == edge2
@pytest.mark.asyncio
async def test_calculate_top_triplet_importances_default_distances(setup_graph):
"""Test calculating importances when nodes/edges have no vector distances."""
graph = setup_graph
node1 = Node("1")
node2 = Node("2")
graph.add_node(node1)
graph.add_node(node2)
edge = Edge(node1, node2)
graph.add_edge(edge)
top_triplets = await graph.calculate_top_triplet_importances(k=1)
assert len(top_triplets) == 1
assert top_triplets[0] == edge

View file

@ -0,0 +1,582 @@
import pytest
from unittest.mock import AsyncMock, patch
from cognee.modules.retrieval.utils.brute_force_triplet_search import (
brute_force_triplet_search,
get_memory_fragment,
)
from cognee.modules.graph.cognee_graph.CogneeGraph import CogneeGraph
from cognee.modules.graph.exceptions.exceptions import EntityNotFoundError
class MockScoredResult:
"""Mock class for vector search results."""
def __init__(self, id, score, payload=None):
self.id = id
self.score = score
self.payload = payload or {}
@pytest.mark.asyncio
async def test_brute_force_triplet_search_empty_query():
"""Test that empty query raises ValueError."""
with pytest.raises(ValueError, match="The query must be a non-empty string."):
await brute_force_triplet_search(query="")
@pytest.mark.asyncio
async def test_brute_force_triplet_search_none_query():
"""Test that None query raises ValueError."""
with pytest.raises(ValueError, match="The query must be a non-empty string."):
await brute_force_triplet_search(query=None)
@pytest.mark.asyncio
async def test_brute_force_triplet_search_negative_top_k():
"""Test that negative top_k raises ValueError."""
with pytest.raises(ValueError, match="top_k must be a positive integer."):
await brute_force_triplet_search(query="test query", top_k=-1)
@pytest.mark.asyncio
async def test_brute_force_triplet_search_zero_top_k():
"""Test that zero top_k raises ValueError."""
with pytest.raises(ValueError, match="top_k must be a positive integer."):
await brute_force_triplet_search(query="test query", top_k=0)
@pytest.mark.asyncio
async def test_brute_force_triplet_search_wide_search_limit_global_search():
"""Test that wide_search_limit is applied for global search (node_name=None)."""
mock_vector_engine = AsyncMock()
mock_vector_engine.embedding_engine = AsyncMock()
mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
mock_vector_engine.search = AsyncMock(return_value=[])
with patch(
"cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
return_value=mock_vector_engine,
):
await brute_force_triplet_search(
query="test",
node_name=None, # Global search
wide_search_top_k=75,
)
for call in mock_vector_engine.search.call_args_list:
assert call[1]["limit"] == 75
@pytest.mark.asyncio
async def test_brute_force_triplet_search_wide_search_limit_filtered_search():
"""Test that wide_search_limit is None for filtered search (node_name provided)."""
mock_vector_engine = AsyncMock()
mock_vector_engine.embedding_engine = AsyncMock()
mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
mock_vector_engine.search = AsyncMock(return_value=[])
with patch(
"cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
return_value=mock_vector_engine,
):
await brute_force_triplet_search(
query="test",
node_name=["Node1"],
wide_search_top_k=50,
)
for call in mock_vector_engine.search.call_args_list:
assert call[1]["limit"] is None
@pytest.mark.asyncio
async def test_brute_force_triplet_search_wide_search_default():
"""Test that wide_search_top_k defaults to 100."""
mock_vector_engine = AsyncMock()
mock_vector_engine.embedding_engine = AsyncMock()
mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
mock_vector_engine.search = AsyncMock(return_value=[])
with patch(
"cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
return_value=mock_vector_engine,
):
await brute_force_triplet_search(query="test", node_name=None)
for call in mock_vector_engine.search.call_args_list:
assert call[1]["limit"] == 100
@pytest.mark.asyncio
async def test_brute_force_triplet_search_default_collections():
"""Test that default collections are used when none provided."""
mock_vector_engine = AsyncMock()
mock_vector_engine.embedding_engine = AsyncMock()
mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
mock_vector_engine.search = AsyncMock(return_value=[])
with patch(
"cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
return_value=mock_vector_engine,
):
await brute_force_triplet_search(query="test")
expected_collections = [
"Entity_name",
"TextSummary_text",
"EntityType_name",
"DocumentChunk_text",
]
call_collections = [
call[1]["collection_name"] for call in mock_vector_engine.search.call_args_list
]
assert call_collections == expected_collections
@pytest.mark.asyncio
async def test_brute_force_triplet_search_custom_collections():
"""Test that custom collections are used when provided."""
mock_vector_engine = AsyncMock()
mock_vector_engine.embedding_engine = AsyncMock()
mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
mock_vector_engine.search = AsyncMock(return_value=[])
custom_collections = ["CustomCol1", "CustomCol2"]
with patch(
"cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
return_value=mock_vector_engine,
):
await brute_force_triplet_search(query="test", collections=custom_collections)
call_collections = [
call[1]["collection_name"] for call in mock_vector_engine.search.call_args_list
]
assert call_collections == custom_collections
@pytest.mark.asyncio
async def test_brute_force_triplet_search_all_collections_empty():
"""Test that empty list is returned when all collections return no results."""
mock_vector_engine = AsyncMock()
mock_vector_engine.embedding_engine = AsyncMock()
mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
mock_vector_engine.search = AsyncMock(return_value=[])
with patch(
"cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
return_value=mock_vector_engine,
):
results = await brute_force_triplet_search(query="test")
assert results == []
# Tests for query embedding
@pytest.mark.asyncio
async def test_brute_force_triplet_search_embeds_query():
"""Test that query is embedded before searching."""
query_text = "test query"
expected_vector = [0.1, 0.2, 0.3]
mock_vector_engine = AsyncMock()
mock_vector_engine.embedding_engine = AsyncMock()
mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[expected_vector])
mock_vector_engine.search = AsyncMock(return_value=[])
with patch(
"cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
return_value=mock_vector_engine,
):
await brute_force_triplet_search(query=query_text)
mock_vector_engine.embedding_engine.embed_text.assert_called_once_with([query_text])
for call in mock_vector_engine.search.call_args_list:
assert call[1]["query_vector"] == expected_vector
@pytest.mark.asyncio
async def test_brute_force_triplet_search_extracts_node_ids_global_search():
"""Test that node IDs are extracted from search results for global search."""
scored_results = [
MockScoredResult("node1", 0.95),
MockScoredResult("node2", 0.87),
MockScoredResult("node3", 0.92),
]
mock_vector_engine = AsyncMock()
mock_vector_engine.embedding_engine = AsyncMock()
mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
mock_vector_engine.search = AsyncMock(return_value=scored_results)
mock_fragment = AsyncMock(
map_vector_distances_to_graph_nodes=AsyncMock(),
map_vector_distances_to_graph_edges=AsyncMock(),
calculate_top_triplet_importances=AsyncMock(return_value=[]),
)
with (
patch(
"cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
return_value=mock_vector_engine,
),
patch(
"cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment",
return_value=mock_fragment,
) as mock_get_fragment_fn,
):
await brute_force_triplet_search(query="test", node_name=None)
call_kwargs = mock_get_fragment_fn.call_args[1]
assert set(call_kwargs["relevant_ids_to_filter"]) == {"node1", "node2", "node3"}
@pytest.mark.asyncio
async def test_brute_force_triplet_search_reuses_provided_fragment():
"""Test that provided memory fragment is reused instead of creating new one."""
provided_fragment = AsyncMock(
map_vector_distances_to_graph_nodes=AsyncMock(),
map_vector_distances_to_graph_edges=AsyncMock(),
calculate_top_triplet_importances=AsyncMock(return_value=[]),
)
mock_vector_engine = AsyncMock()
mock_vector_engine.embedding_engine = AsyncMock()
mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
mock_vector_engine.search = AsyncMock(return_value=[MockScoredResult("n1", 0.95)])
with (
patch(
"cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
return_value=mock_vector_engine,
),
patch(
"cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment"
) as mock_get_fragment,
):
await brute_force_triplet_search(
query="test",
memory_fragment=provided_fragment,
node_name=["node"],
)
mock_get_fragment.assert_not_called()
@pytest.mark.asyncio
async def test_brute_force_triplet_search_creates_fragment_when_not_provided():
"""Test that memory fragment is created when not provided."""
mock_vector_engine = AsyncMock()
mock_vector_engine.embedding_engine = AsyncMock()
mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
mock_vector_engine.search = AsyncMock(return_value=[MockScoredResult("n1", 0.95)])
mock_fragment = AsyncMock(
map_vector_distances_to_graph_nodes=AsyncMock(),
map_vector_distances_to_graph_edges=AsyncMock(),
calculate_top_triplet_importances=AsyncMock(return_value=[]),
)
with (
patch(
"cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
return_value=mock_vector_engine,
),
patch(
"cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment",
return_value=mock_fragment,
) as mock_get_fragment,
):
await brute_force_triplet_search(query="test", node_name=["node"])
mock_get_fragment.assert_called_once()
@pytest.mark.asyncio
async def test_brute_force_triplet_search_passes_top_k_to_importance_calculation():
"""Test that custom top_k is passed to importance calculation."""
mock_vector_engine = AsyncMock()
mock_vector_engine.embedding_engine = AsyncMock()
mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
mock_vector_engine.search = AsyncMock(return_value=[MockScoredResult("n1", 0.95)])
mock_fragment = AsyncMock(
map_vector_distances_to_graph_nodes=AsyncMock(),
map_vector_distances_to_graph_edges=AsyncMock(),
calculate_top_triplet_importances=AsyncMock(return_value=[]),
)
with (
patch(
"cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
return_value=mock_vector_engine,
),
patch(
"cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment",
return_value=mock_fragment,
),
):
custom_top_k = 15
await brute_force_triplet_search(query="test", top_k=custom_top_k, node_name=["n"])
mock_fragment.calculate_top_triplet_importances.assert_called_once_with(k=custom_top_k)
@pytest.mark.asyncio
async def test_get_memory_fragment_returns_empty_graph_on_entity_not_found():
"""Test that get_memory_fragment returns empty graph when entity not found."""
mock_graph_engine = AsyncMock()
mock_graph_engine.project_graph_from_db = AsyncMock(
side_effect=EntityNotFoundError("Entity not found")
)
with patch(
"cognee.modules.retrieval.utils.brute_force_triplet_search.get_graph_engine",
return_value=mock_graph_engine,
):
fragment = await get_memory_fragment()
assert isinstance(fragment, CogneeGraph)
assert len(fragment.nodes) == 0
@pytest.mark.asyncio
async def test_get_memory_fragment_returns_empty_graph_on_error():
"""Test that get_memory_fragment returns empty graph on generic error."""
mock_graph_engine = AsyncMock()
mock_graph_engine.project_graph_from_db = AsyncMock(side_effect=Exception("Generic error"))
with patch(
"cognee.modules.retrieval.utils.brute_force_triplet_search.get_graph_engine",
return_value=mock_graph_engine,
):
fragment = await get_memory_fragment()
assert isinstance(fragment, CogneeGraph)
assert len(fragment.nodes) == 0
@pytest.mark.asyncio
async def test_brute_force_triplet_search_deduplicates_node_ids():
"""Test that duplicate node IDs across collections are deduplicated."""
def search_side_effect(*args, **kwargs):
collection_name = kwargs.get("collection_name")
if collection_name == "Entity_name":
return [
MockScoredResult("node1", 0.95),
MockScoredResult("node2", 0.87),
]
elif collection_name == "TextSummary_text":
return [
MockScoredResult("node1", 0.90),
MockScoredResult("node3", 0.92),
]
else:
return []
mock_vector_engine = AsyncMock()
mock_vector_engine.embedding_engine = AsyncMock()
mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
mock_vector_engine.search = AsyncMock(side_effect=search_side_effect)
mock_fragment = AsyncMock(
map_vector_distances_to_graph_nodes=AsyncMock(),
map_vector_distances_to_graph_edges=AsyncMock(),
calculate_top_triplet_importances=AsyncMock(return_value=[]),
)
with (
patch(
"cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
return_value=mock_vector_engine,
),
patch(
"cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment",
return_value=mock_fragment,
) as mock_get_fragment_fn,
):
await brute_force_triplet_search(query="test", node_name=None)
call_kwargs = mock_get_fragment_fn.call_args[1]
assert set(call_kwargs["relevant_ids_to_filter"]) == {"node1", "node2", "node3"}
assert len(call_kwargs["relevant_ids_to_filter"]) == 3


@pytest.mark.asyncio
async def test_brute_force_triplet_search_excludes_edge_collection():
    """Test that EdgeType_relationship_name collection is excluded from ID extraction."""

    def search_side_effect(*args, **kwargs):
        collection_name = kwargs.get("collection_name")
        if collection_name == "Entity_name":
            return [MockScoredResult("node1", 0.95)]
        elif collection_name == "EdgeType_relationship_name":
            return [MockScoredResult("edge1", 0.88)]
        else:
            return []

    mock_vector_engine = AsyncMock()
    mock_vector_engine.embedding_engine = AsyncMock()
    mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
    mock_vector_engine.search = AsyncMock(side_effect=search_side_effect)

    mock_fragment = AsyncMock(
        map_vector_distances_to_graph_nodes=AsyncMock(),
        map_vector_distances_to_graph_edges=AsyncMock(),
        calculate_top_triplet_importances=AsyncMock(return_value=[]),
    )

    with (
        patch(
            "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
            return_value=mock_vector_engine,
        ),
        patch(
            "cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment",
            return_value=mock_fragment,
        ) as mock_get_fragment_fn,
    ):
        await brute_force_triplet_search(
            query="test",
            node_name=None,
            collections=["Entity_name", "EdgeType_relationship_name"],
        )

        call_kwargs = mock_get_fragment_fn.call_args[1]
        assert call_kwargs["relevant_ids_to_filter"] == ["node1"]


@pytest.mark.asyncio
async def test_brute_force_triplet_search_skips_nodes_without_ids():
    """Test that nodes without ID attribute are skipped."""

    class ScoredResultNoId:
        """Mock result without id attribute."""

        def __init__(self, score):
            self.score = score

    def search_side_effect(*args, **kwargs):
        collection_name = kwargs.get("collection_name")
        if collection_name == "Entity_name":
            return [
                MockScoredResult("node1", 0.95),
                ScoredResultNoId(0.90),
                MockScoredResult("node2", 0.87),
            ]
        else:
            return []

    mock_vector_engine = AsyncMock()
    mock_vector_engine.embedding_engine = AsyncMock()
    mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
    mock_vector_engine.search = AsyncMock(side_effect=search_side_effect)

    mock_fragment = AsyncMock(
        map_vector_distances_to_graph_nodes=AsyncMock(),
        map_vector_distances_to_graph_edges=AsyncMock(),
        calculate_top_triplet_importances=AsyncMock(return_value=[]),
    )

    with (
        patch(
            "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
            return_value=mock_vector_engine,
        ),
        patch(
            "cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment",
            return_value=mock_fragment,
        ) as mock_get_fragment_fn,
    ):
        await brute_force_triplet_search(query="test", node_name=None)

        call_kwargs = mock_get_fragment_fn.call_args[1]
        assert set(call_kwargs["relevant_ids_to_filter"]) == {"node1", "node2"}


@pytest.mark.asyncio
async def test_brute_force_triplet_search_handles_tuple_results():
    """Test that both list and tuple results are handled correctly."""

    def search_side_effect(*args, **kwargs):
        collection_name = kwargs.get("collection_name")
        if collection_name == "Entity_name":
            return (
                MockScoredResult("node1", 0.95),
                MockScoredResult("node2", 0.87),
            )
        else:
            return []

    mock_vector_engine = AsyncMock()
    mock_vector_engine.embedding_engine = AsyncMock()
    mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
    mock_vector_engine.search = AsyncMock(side_effect=search_side_effect)

    mock_fragment = AsyncMock(
        map_vector_distances_to_graph_nodes=AsyncMock(),
        map_vector_distances_to_graph_edges=AsyncMock(),
        calculate_top_triplet_importances=AsyncMock(return_value=[]),
    )

    with (
        patch(
            "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
            return_value=mock_vector_engine,
        ),
        patch(
            "cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment",
            return_value=mock_fragment,
        ) as mock_get_fragment_fn,
    ):
        await brute_force_triplet_search(query="test", node_name=None)

        call_kwargs = mock_get_fragment_fn.call_args[1]
        assert set(call_kwargs["relevant_ids_to_filter"]) == {"node1", "node2"}


@pytest.mark.asyncio
async def test_brute_force_triplet_search_mixed_empty_collections():
    """Test ID extraction with mixed empty and non-empty collections."""

    def search_side_effect(*args, **kwargs):
        collection_name = kwargs.get("collection_name")
        if collection_name == "Entity_name":
            return [MockScoredResult("node1", 0.95)]
        elif collection_name == "TextSummary_text":
            return []
        elif collection_name == "EntityType_name":
            return [MockScoredResult("node2", 0.92)]
        else:
            return []

    mock_vector_engine = AsyncMock()
    mock_vector_engine.embedding_engine = AsyncMock()
    mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
    mock_vector_engine.search = AsyncMock(side_effect=search_side_effect)

    mock_fragment = AsyncMock(
        map_vector_distances_to_graph_nodes=AsyncMock(),
        map_vector_distances_to_graph_edges=AsyncMock(),
        calculate_top_triplet_importances=AsyncMock(return_value=[]),
    )

    with (
        patch(
            "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
            return_value=mock_vector_engine,
        ),
        patch(
            "cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment",
            return_value=mock_fragment,
        ) as mock_get_fragment_fn,
    ):
        await brute_force_triplet_search(query="test", node_name=None)

        call_kwargs = mock_get_fragment_fn.call_args[1]
        assert set(call_kwargs["relevant_ids_to_filter"]) == {"node1", "node2"}
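

# A minimal sketch (an assumption, not the shipped implementation) of the
# ID-extraction contract the tests above collectively specify: collect result
# ids across all searched collections, skip the EdgeType_relationship_name
# collection, tolerate both list and tuple result sets, drop results that
# carry no id attribute, and deduplicate ids across collections. The function
# name and input shape are hypothetical, chosen only to make the contract
# concrete.
def extract_relevant_ids(results_by_collection):
    seen = {}  # dict preserves first-seen order while deduplicating
    for collection_name, results in results_by_collection.items():
        if collection_name == "EdgeType_relationship_name":
            continue
        for result in results:  # results may be a list or a tuple
            node_id = getattr(result, "id", None)
            if node_id is not None:
                seen[node_id] = None
    return list(seen)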

View file

@ -0,0 +1,52 @@
from itertools import product

import numpy as np
import pytest

from cognee.infrastructure.databases.vector.embeddings import get_embedding_engine
from cognee.tasks.chunks import chunk_by_row

INPUT_TEXTS = "name: John, age: 30, city: New York, country: USA"
max_chunk_size_vals = [8, 32]


@pytest.mark.parametrize(
    "input_text,max_chunk_size",
    list(product([INPUT_TEXTS], max_chunk_size_vals)),
)
def test_chunk_by_row_isomorphism(input_text, max_chunk_size):
    chunks = chunk_by_row(input_text, max_chunk_size)
    reconstructed_text = ", ".join([chunk["text"] for chunk in chunks])

    assert reconstructed_text == input_text, (
        f"texts are not identical: {len(input_text) = }, {len(reconstructed_text) = }"
    )


@pytest.mark.parametrize(
    "input_text,max_chunk_size",
    list(product([INPUT_TEXTS], max_chunk_size_vals)),
)
def test_row_chunk_length(input_text, max_chunk_size):
    chunks = list(chunk_by_row(data=input_text, max_chunk_size=max_chunk_size))

    embedding_engine = get_embedding_engine()
    chunk_lengths = np.array(
        [embedding_engine.tokenizer.count_tokens(chunk["text"]) for chunk in chunks]
    )

    larger_chunks = chunk_lengths[chunk_lengths > max_chunk_size]
    assert np.all(chunk_lengths <= max_chunk_size), (
        f"{max_chunk_size = }: {larger_chunks} are too large"
    )


@pytest.mark.parametrize(
    "input_text,max_chunk_size",
    list(product([INPUT_TEXTS], max_chunk_size_vals)),
)
def test_chunk_by_row_chunk_numbering(input_text, max_chunk_size):
    chunks = chunk_by_row(data=input_text, max_chunk_size=max_chunk_size)

    chunk_indices = np.array([chunk["chunk_index"] for chunk in chunks])
    assert np.all(chunk_indices == np.arange(len(chunk_indices))), (
        f"{chunk_indices = } are not monotonically increasing"
    )
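

# Usage sketch (illustrative only): the tests above imply that chunk_by_row
# yields dicts with "text" and "chunk_index" keys, that joining the texts
# with ", " reconstructs the input row, and that no chunk exceeds
# max_chunk_size tokens. Exact chunk boundaries depend on the configured
# embedding tokenizer, so the output below is a plausible shape, not a
# guaranteed result.
row = "name: John, age: 30, city: New York, country: USA"
for chunk in chunk_by_row(data=row, max_chunk_size=8):
    print(chunk["chunk_index"], repr(chunk["text"]))
# Plausible output:
#   0 'name: John'
#   1 'age: 30'
#   2 'city: New York'
#   3 'country: USA'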

10
poetry.lock generated
View file

@ -11656,7 +11656,9 @@ groups = ["main"]
files = [
{file = "SQLAlchemy-2.0.43-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:21ba7a08a4253c5825d1db389d4299f64a100ef9800e4624c8bf70d8f136e6ed"},
{file = "SQLAlchemy-2.0.43-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:11b9503fa6f8721bef9b8567730f664c5a5153d25e247aadc69247c4bc605227"},
{file = "SQLAlchemy-2.0.43-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:07097c0a1886c150ef2adba2ff7437e84d40c0f7dcb44a2c2b9c905ccfc6361c"},
{file = "SQLAlchemy-2.0.43-cp37-cp37m-musllinux_1_2_aarch64.whl", hash = "sha256:cdeff998cb294896a34e5b2f00e383e7c5c4ef3b4bfa375d9104723f15186443"},
{file = "SQLAlchemy-2.0.43-cp37-cp37m-musllinux_1_2_x86_64.whl", hash = "sha256:bcf0724a62a5670e5718957e05c56ec2d6850267ea859f8ad2481838f889b42c"},
{file = "SQLAlchemy-2.0.43-cp37-cp37m-win32.whl", hash = "sha256:c697575d0e2b0a5f0433f679bda22f63873821d991e95a90e9e52aae517b2e32"},
{file = "SQLAlchemy-2.0.43-cp37-cp37m-win_amd64.whl", hash = "sha256:d34c0f6dbefd2e816e8f341d0df7d4763d382e3f452423e752ffd1e213da2512"},
{file = "sqlalchemy-2.0.43-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:70322986c0c699dca241418fcf18e637a4369e0ec50540a2b907b184c8bca069"},
@ -11691,12 +11693,20 @@ files = [
{file = "sqlalchemy-2.0.43-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:9df7126fd9db49e3a5a3999442cc67e9ee8971f3cb9644250107d7296cb2a164"},
{file = "sqlalchemy-2.0.43-cp313-cp313-win32.whl", hash = "sha256:7f1ac7828857fcedb0361b48b9ac4821469f7694089d15550bbcf9ab22564a1d"},
{file = "sqlalchemy-2.0.43-cp313-cp313-win_amd64.whl", hash = "sha256:971ba928fcde01869361f504fcff3b7143b47d30de188b11c6357c0505824197"},
{file = "sqlalchemy-2.0.43-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:4e6aeb2e0932f32950cf56a8b4813cb15ff792fc0c9b3752eaf067cfe298496a"},
{file = "sqlalchemy-2.0.43-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:61f964a05356f4bca4112e6334ed7c208174511bd56e6b8fc86dad4d024d4185"},
{file = "sqlalchemy-2.0.43-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:46293c39252f93ea0910aababa8752ad628bcce3a10d3f260648dd472256983f"},
{file = "sqlalchemy-2.0.43-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:136063a68644eca9339d02e6693932116f6a8591ac013b0014479a1de664e40a"},
{file = "sqlalchemy-2.0.43-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:6e2bf13d9256398d037fef09fd8bf9b0bf77876e22647d10761d35593b9ac547"},
{file = "sqlalchemy-2.0.43-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:44337823462291f17f994d64282a71c51d738fc9ef561bf265f1d0fd9116a782"},
{file = "sqlalchemy-2.0.43-cp38-cp38-win32.whl", hash = "sha256:13194276e69bb2af56198fef7909d48fd34820de01d9c92711a5fa45497cc7ed"},
{file = "sqlalchemy-2.0.43-cp38-cp38-win_amd64.whl", hash = "sha256:334f41fa28de9f9be4b78445e68530da3c5fa054c907176460c81494f4ae1f5e"},
{file = "sqlalchemy-2.0.43-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:ceb5c832cc30663aeaf5e39657712f4c4241ad1f638d487ef7216258f6d41fe7"},
{file = "sqlalchemy-2.0.43-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:11f43c39b4b2ec755573952bbcc58d976779d482f6f832d7f33a8d869ae891bf"},
{file = "sqlalchemy-2.0.43-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:413391b2239db55be14fa4223034d7e13325a1812c8396ecd4f2c08696d5ccad"},
{file = "sqlalchemy-2.0.43-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c379e37b08c6c527181a397212346be39319fb64323741d23e46abd97a400d34"},
{file = "sqlalchemy-2.0.43-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:03d73ab2a37d9e40dec4984d1813d7878e01dbdc742448d44a7341b7a9f408c7"},
{file = "sqlalchemy-2.0.43-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:8cee08f15d9e238ede42e9bbc1d6e7158d0ca4f176e4eab21f88ac819ae3bd7b"},
{file = "sqlalchemy-2.0.43-cp39-cp39-win32.whl", hash = "sha256:b3edaec7e8b6dc5cd94523c6df4f294014df67097c8217a89929c99975811414"},
{file = "sqlalchemy-2.0.43-cp39-cp39-win_amd64.whl", hash = "sha256:227119ce0a89e762ecd882dc661e0aa677a690c914e358f0dd8932a2e8b2765b"},
{file = "sqlalchemy-2.0.43-py3-none-any.whl", hash = "sha256:1681c21dd2ccee222c2fe0bef671d1aef7c504087c9c4e800371cfcc8ac966fc"},

View file

@ -1,7 +1,7 @@
[project]
name = "cognee"
version = "0.3.9"
version = "0.5.0.dev0"
description = "Cognee - is a library for enriching LLM context with a semantic layer for better understanding and reasoning."
authors = [
{ name = "Vasilije Markovic" },
@ -156,7 +156,6 @@ Homepage = "https://www.cognee.ai"
Repository = "https://github.com/topoteretes/cognee"
[project.scripts]
cognee = "cognee.cli._cognee:main"
cognee-cli = "cognee.cli._cognee:main"
[build-system]

10
uv.lock generated
View file

@ -929,7 +929,7 @@ wheels = [
[[package]]
name = "cognee"
version = "0.3.9"
version = "0.5.0.dev0"
source = { editable = "." }
dependencies = [
{ name = "aiofiles" },
@ -2560,6 +2560,8 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/7f/91/ae2eb6b7979e2f9b035a9f612cf70f1bf54aad4e1d125129bef1eae96f19/greenlet-3.2.4-cp310-cp310-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:c2ca18a03a8cfb5b25bc1cbe20f3d9a4c80d8c3b13ba3df49ac3961af0b1018d", size = 584358, upload-time = "2025-08-07T13:18:23.708Z" },
{ url = "https://files.pythonhosted.org/packages/f7/85/433de0c9c0252b22b16d413c9407e6cb3b41df7389afc366ca204dbc1393/greenlet-3.2.4-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:9fe0a28a7b952a21e2c062cd5756d34354117796c6d9215a87f55e38d15402c5", size = 1113550, upload-time = "2025-08-07T13:42:37.467Z" },
{ url = "https://files.pythonhosted.org/packages/a1/8d/88f3ebd2bc96bf7747093696f4335a0a8a4c5acfcf1b757717c0d2474ba3/greenlet-3.2.4-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:8854167e06950ca75b898b104b63cc646573aa5fef1353d4508ecdd1ee76254f", size = 1137126, upload-time = "2025-08-07T13:18:20.239Z" },
{ url = "https://files.pythonhosted.org/packages/f1/29/74242b7d72385e29bcc5563fba67dad94943d7cd03552bac320d597f29b2/greenlet-3.2.4-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:f47617f698838ba98f4ff4189aef02e7343952df3a615f847bb575c3feb177a7", size = 1544904, upload-time = "2025-11-04T12:42:04.763Z" },
{ url = "https://files.pythonhosted.org/packages/c8/e2/1572b8eeab0f77df5f6729d6ab6b141e4a84ee8eb9bc8c1e7918f94eda6d/greenlet-3.2.4-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:af41be48a4f60429d5cad9d22175217805098a9ef7c40bfef44f7669fb9d74d8", size = 1611228, upload-time = "2025-11-04T12:42:08.423Z" },
{ url = "https://files.pythonhosted.org/packages/d6/6f/b60b0291d9623c496638c582297ead61f43c4b72eef5e9c926ef4565ec13/greenlet-3.2.4-cp310-cp310-win_amd64.whl", hash = "sha256:73f49b5368b5359d04e18d15828eecc1806033db5233397748f4ca813ff1056c", size = 298654, upload-time = "2025-08-07T13:50:00.469Z" },
{ url = "https://files.pythonhosted.org/packages/a4/de/f28ced0a67749cac23fecb02b694f6473f47686dff6afaa211d186e2ef9c/greenlet-3.2.4-cp311-cp311-macosx_11_0_universal2.whl", hash = "sha256:96378df1de302bc38e99c3a9aa311967b7dc80ced1dcc6f171e99842987882a2", size = 272305, upload-time = "2025-08-07T13:15:41.288Z" },
{ url = "https://files.pythonhosted.org/packages/09/16/2c3792cba130000bf2a31c5272999113f4764fd9d874fb257ff588ac779a/greenlet-3.2.4-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:1ee8fae0519a337f2329cb78bd7a8e128ec0f881073d43f023c7b8d4831d5246", size = 632472, upload-time = "2025-08-07T13:42:55.044Z" },
@ -2569,6 +2571,8 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/1f/8e/abdd3f14d735b2929290a018ecf133c901be4874b858dd1c604b9319f064/greenlet-3.2.4-cp311-cp311-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:2523e5246274f54fdadbce8494458a2ebdcdbc7b802318466ac5606d3cded1f8", size = 587684, upload-time = "2025-08-07T13:18:25.164Z" },
{ url = "https://files.pythonhosted.org/packages/5d/65/deb2a69c3e5996439b0176f6651e0052542bb6c8f8ec2e3fba97c9768805/greenlet-3.2.4-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:1987de92fec508535687fb807a5cea1560f6196285a4cde35c100b8cd632cc52", size = 1116647, upload-time = "2025-08-07T13:42:38.655Z" },
{ url = "https://files.pythonhosted.org/packages/3f/cc/b07000438a29ac5cfb2194bfc128151d52f333cee74dd7dfe3fb733fc16c/greenlet-3.2.4-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:55e9c5affaa6775e2c6b67659f3a71684de4c549b3dd9afca3bc773533d284fa", size = 1142073, upload-time = "2025-08-07T13:18:21.737Z" },
{ url = "https://files.pythonhosted.org/packages/67/24/28a5b2fa42d12b3d7e5614145f0bd89714c34c08be6aabe39c14dd52db34/greenlet-3.2.4-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:c9c6de1940a7d828635fbd254d69db79e54619f165ee7ce32fda763a9cb6a58c", size = 1548385, upload-time = "2025-11-04T12:42:11.067Z" },
{ url = "https://files.pythonhosted.org/packages/6a/05/03f2f0bdd0b0ff9a4f7b99333d57b53a7709c27723ec8123056b084e69cd/greenlet-3.2.4-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:03c5136e7be905045160b1b9fdca93dd6727b180feeafda6818e6496434ed8c5", size = 1613329, upload-time = "2025-11-04T12:42:12.928Z" },
{ url = "https://files.pythonhosted.org/packages/d8/0f/30aef242fcab550b0b3520b8e3561156857c94288f0332a79928c31a52cf/greenlet-3.2.4-cp311-cp311-win_amd64.whl", hash = "sha256:9c40adce87eaa9ddb593ccb0fa6a07caf34015a29bf8d344811665b573138db9", size = 299100, upload-time = "2025-08-07T13:44:12.287Z" },
{ url = "https://files.pythonhosted.org/packages/44/69/9b804adb5fd0671f367781560eb5eb586c4d495277c93bde4307b9e28068/greenlet-3.2.4-cp312-cp312-macosx_11_0_universal2.whl", hash = "sha256:3b67ca49f54cede0186854a008109d6ee71f66bd57bb36abd6d0a0267b540cdd", size = 274079, upload-time = "2025-08-07T13:15:45.033Z" },
{ url = "https://files.pythonhosted.org/packages/46/e9/d2a80c99f19a153eff70bc451ab78615583b8dac0754cfb942223d2c1a0d/greenlet-3.2.4-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:ddf9164e7a5b08e9d22511526865780a576f19ddd00d62f8a665949327fde8bb", size = 640997, upload-time = "2025-08-07T13:42:56.234Z" },
@ -2578,6 +2582,8 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/19/0d/6660d55f7373b2ff8152401a83e02084956da23ae58cddbfb0b330978fe9/greenlet-3.2.4-cp312-cp312-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:3b3812d8d0c9579967815af437d96623f45c0f2ae5f04e366de62a12d83a8fb0", size = 607586, upload-time = "2025-08-07T13:18:28.544Z" },
{ url = "https://files.pythonhosted.org/packages/8e/1a/c953fdedd22d81ee4629afbb38d2f9d71e37d23caace44775a3a969147d4/greenlet-3.2.4-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:abbf57b5a870d30c4675928c37278493044d7c14378350b3aa5d484fa65575f0", size = 1123281, upload-time = "2025-08-07T13:42:39.858Z" },
{ url = "https://files.pythonhosted.org/packages/3f/c7/12381b18e21aef2c6bd3a636da1088b888b97b7a0362fac2e4de92405f97/greenlet-3.2.4-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:20fb936b4652b6e307b8f347665e2c615540d4b42b3b4c8a321d8286da7e520f", size = 1151142, upload-time = "2025-08-07T13:18:22.981Z" },
{ url = "https://files.pythonhosted.org/packages/27/45/80935968b53cfd3f33cf99ea5f08227f2646e044568c9b1555b58ffd61c2/greenlet-3.2.4-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:ee7a6ec486883397d70eec05059353b8e83eca9168b9f3f9a361971e77e0bcd0", size = 1564846, upload-time = "2025-11-04T12:42:15.191Z" },
{ url = "https://files.pythonhosted.org/packages/69/02/b7c30e5e04752cb4db6202a3858b149c0710e5453b71a3b2aec5d78a1aab/greenlet-3.2.4-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:326d234cbf337c9c3def0676412eb7040a35a768efc92504b947b3e9cfc7543d", size = 1633814, upload-time = "2025-11-04T12:42:17.175Z" },
{ url = "https://files.pythonhosted.org/packages/e9/08/b0814846b79399e585f974bbeebf5580fbe59e258ea7be64d9dfb253c84f/greenlet-3.2.4-cp312-cp312-win_amd64.whl", hash = "sha256:a7d4e128405eea3814a12cc2605e0e6aedb4035bf32697f72deca74de4105e02", size = 299899, upload-time = "2025-08-07T13:38:53.448Z" },
{ url = "https://files.pythonhosted.org/packages/49/e8/58c7f85958bda41dafea50497cbd59738c5c43dbbea5ee83d651234398f4/greenlet-3.2.4-cp313-cp313-macosx_11_0_universal2.whl", hash = "sha256:1a921e542453fe531144e91e1feedf12e07351b1cf6c9e8a3325ea600a715a31", size = 272814, upload-time = "2025-08-07T13:15:50.011Z" },
{ url = "https://files.pythonhosted.org/packages/62/dd/b9f59862e9e257a16e4e610480cfffd29e3fae018a68c2332090b53aac3d/greenlet-3.2.4-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:cd3c8e693bff0fff6ba55f140bf390fa92c994083f838fece0f63be121334945", size = 641073, upload-time = "2025-08-07T13:42:57.23Z" },
@ -2587,6 +2593,8 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/ee/43/3cecdc0349359e1a527cbf2e3e28e5f8f06d3343aaf82ca13437a9aa290f/greenlet-3.2.4-cp313-cp313-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:23768528f2911bcd7e475210822ffb5254ed10d71f4028387e5a99b4c6699671", size = 610497, upload-time = "2025-08-07T13:18:31.636Z" },
{ url = "https://files.pythonhosted.org/packages/b8/19/06b6cf5d604e2c382a6f31cafafd6f33d5dea706f4db7bdab184bad2b21d/greenlet-3.2.4-cp313-cp313-musllinux_1_1_aarch64.whl", hash = "sha256:00fadb3fedccc447f517ee0d3fd8fe49eae949e1cd0f6a611818f4f6fb7dc83b", size = 1121662, upload-time = "2025-08-07T13:42:41.117Z" },
{ url = "https://files.pythonhosted.org/packages/a2/15/0d5e4e1a66fab130d98168fe984c509249c833c1a3c16806b90f253ce7b9/greenlet-3.2.4-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:d25c5091190f2dc0eaa3f950252122edbbadbb682aa7b1ef2f8af0f8c0afefae", size = 1149210, upload-time = "2025-08-07T13:18:24.072Z" },
{ url = "https://files.pythonhosted.org/packages/1c/53/f9c440463b3057485b8594d7a638bed53ba531165ef0ca0e6c364b5cc807/greenlet-3.2.4-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:6e343822feb58ac4d0a1211bd9399de2b3a04963ddeec21530fc426cc121f19b", size = 1564759, upload-time = "2025-11-04T12:42:19.395Z" },
{ url = "https://files.pythonhosted.org/packages/47/e4/3bb4240abdd0a8d23f4f88adec746a3099f0d86bfedb623f063b2e3b4df0/greenlet-3.2.4-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:ca7f6f1f2649b89ce02f6f229d7c19f680a6238af656f61e0115b24857917929", size = 1634288, upload-time = "2025-11-04T12:42:21.174Z" },
{ url = "https://files.pythonhosted.org/packages/0b/55/2321e43595e6801e105fcfdee02b34c0f996eb71e6ddffca6b10b7e1d771/greenlet-3.2.4-cp313-cp313-win_amd64.whl", hash = "sha256:554b03b6e73aaabec3745364d6239e9e012d64c68ccd0b8430c64ccc14939a8b", size = 299685, upload-time = "2025-08-07T13:24:38.824Z" },
]