diff --git a/.env.template b/.env.template index d178965e8..7defaee09 100644 --- a/.env.template +++ b/.env.template @@ -21,6 +21,10 @@ LLM_PROVIDER="openai" LLM_ENDPOINT="" LLM_API_VERSION="" LLM_MAX_TOKENS="16384" +# Instructor's modes determine how structured data is requested from and extracted from LLM responses +# You can change this type (i.e. mode) via this env variable +# Each LLM has its own default value, e.g. gpt-5 models have "json_schema_mode" +LLM_INSTRUCTOR_MODE="" EMBEDDING_PROVIDER="openai" EMBEDDING_MODEL="openai/text-embedding-3-large" diff --git a/cognee-mcp/src/server.py b/cognee-mcp/src/server.py index 7c708638c..4131be988 100755 --- a/cognee-mcp/src/server.py +++ b/cognee-mcp/src/server.py @@ -194,7 +194,6 @@ async def cognify( Prerequisites: - **LLM_API_KEY**: Must be configured (required for entity extraction and graph generation) - - **Data Added**: Must have data previously added via `cognee.add()` - **Vector Database**: Must be accessible for embeddings storage - **Graph Database**: Must be accessible for relationship storage diff --git a/cognee/api/client.py b/cognee/api/client.py index 19a607ff0..1a08aed56 100644 --- a/cognee/api/client.py +++ b/cognee/api/client.py @@ -23,6 +23,7 @@ from cognee.api.v1.settings.routers import get_settings_router from cognee.api.v1.datasets.routers import get_datasets_router from cognee.api.v1.cognify.routers import get_code_pipeline_router, get_cognify_router from cognee.api.v1.search.routers import get_search_router +from cognee.api.v1.ontologies.routers.get_ontology_router import get_ontology_router from cognee.api.v1.memify.routers import get_memify_router from cognee.api.v1.add.routers import get_add_router from cognee.api.v1.delete.routers import get_delete_router @@ -263,6 +264,8 @@ app.include_router( app.include_router(get_datasets_router(), prefix="/api/v1/datasets", tags=["datasets"]) +app.include_router(get_ontology_router(), prefix="/api/v1/ontologies", tags=["ontologies"]) + app.include_router(get_settings_router(), prefix="/api/v1/settings", tags=["settings"]) app.include_router(get_visualize_router(), prefix="/api/v1/visualize", tags=["visualize"]) diff --git a/cognee/api/v1/cognify/routers/get_cognify_router.py b/cognee/api/v1/cognify/routers/get_cognify_router.py index 231bbcd11..4f1497e3c 100644 --- a/cognee/api/v1/cognify/routers/get_cognify_router.py +++ b/cognee/api/v1/cognify/routers/get_cognify_router.py @@ -41,6 +41,9 @@ class CognifyPayloadDTO(InDTO): custom_prompt: Optional[str] = Field( default="", description="Custom prompt for entity extraction and graph generation" ) + ontology_key: Optional[List[str]] = Field( + default=None, description="Reference to one or more previously uploaded ontologies" + ) def get_cognify_router() -> APIRouter: @@ -68,6 +71,7 @@ def get_cognify_router() -> APIRouter: - **dataset_ids** (Optional[List[UUID]]): List of existing dataset UUIDs to process. UUIDs allow processing of datasets not owned by the user (if permitted). - **run_in_background** (Optional[bool]): Whether to execute processing asynchronously. Defaults to False (blocking). - **custom_prompt** (Optional[str]): Custom prompt for entity extraction and graph generation. If provided, this prompt will be used instead of the default prompts for knowledge graph extraction. + - **ontology_key** (Optional[List[str]]): Reference to one or more previously uploaded ontology files to use for knowledge graph construction. ## Response - **Blocking execution**: Complete pipeline run information with entity counts, processing duration, and success/failure status @@ -82,7 +86,8 @@ def get_cognify_router() -> APIRouter: { "datasets": ["research_papers", "documentation"], "run_in_background": false, - "custom_prompt": "Extract entities focusing on technical concepts and their relationships. Identify key technologies, methodologies, and their interconnections." + "custom_prompt": "Extract entities focusing on technical concepts and their relationships. Identify key technologies, methodologies, and their interconnections.", + "ontology_key": ["medical_ontology_v1"] } ``` @@ -108,13 +113,35 @@ def get_cognify_router() -> APIRouter: ) from cognee.api.v1.cognify import cognify as cognee_cognify + from cognee.api.v1.ontologies.ontologies import OntologyService try: datasets = payload.dataset_ids if payload.dataset_ids else payload.datasets + config_to_use = None + + if payload.ontology_key: + ontology_service = OntologyService() + ontology_contents = ontology_service.get_ontology_contents( + payload.ontology_key, user + ) + + from cognee.modules.ontology.ontology_config import Config + from cognee.modules.ontology.rdf_xml.RDFLibOntologyResolver import ( + RDFLibOntologyResolver, + ) + from io import StringIO + + ontology_streams = [StringIO(content) for content in ontology_contents] + config_to_use: Config = { + "ontology_config": { + "ontology_resolver": RDFLibOntologyResolver(ontology_file=ontology_streams) + } + } cognify_run = await cognee_cognify( datasets, user, + config=config_to_use, run_in_background=payload.run_in_background, custom_prompt=payload.custom_prompt, ) diff --git a/cognee/api/v1/ontologies/__init__.py b/cognee/api/v1/ontologies/__init__.py new file mode 100644 index 000000000..b90d46c3d --- /dev/null +++ b/cognee/api/v1/ontologies/__init__.py @@ -0,0 +1,4 @@ +from .ontologies import OntologyService +from .routers.get_ontology_router import get_ontology_router + +__all__ = ["OntologyService", "get_ontology_router"] diff --git a/cognee/api/v1/ontologies/ontologies.py b/cognee/api/v1/ontologies/ontologies.py new file mode 100644 index 000000000..130b4a862 --- /dev/null +++ b/cognee/api/v1/ontologies/ontologies.py @@ -0,0 +1,183 @@ +import os +import json +import tempfile +from pathlib import Path +from datetime import datetime, timezone +from typing import Optional, List +from dataclasses import dataclass + + +@dataclass +class OntologyMetadata: + ontology_key: str + filename: str + size_bytes: int + uploaded_at: str + description: Optional[str] = None + + +class OntologyService: + def __init__(self): + pass + + @property + def base_dir(self) -> Path: + return Path(tempfile.gettempdir()) / "ontologies" + + def _get_user_dir(self, user_id: str) -> Path: + user_dir = self.base_dir / str(user_id) + user_dir.mkdir(parents=True, exist_ok=True) + return user_dir + + def _get_metadata_path(self, user_dir: Path) -> Path: + return user_dir / "metadata.json" + + def _load_metadata(self, user_dir: Path) -> dict: + metadata_path = self._get_metadata_path(user_dir) + if metadata_path.exists(): + with open(metadata_path, "r") as f: + return json.load(f) + return {} + + def _save_metadata(self, user_dir: Path, metadata: dict): + metadata_path = self._get_metadata_path(user_dir) + with open(metadata_path, "w") as f: + json.dump(metadata, f, indent=2) + + async def upload_ontology( + self, ontology_key: str, file, user, description: Optional[str] = None + ) -> OntologyMetadata: + if not file.filename.lower().endswith(".owl"): + raise ValueError("File must be in .owl format") + + user_dir = self._get_user_dir(str(user.id)) + metadata = self._load_metadata(user_dir) + + if ontology_key in metadata: + raise ValueError(f"Ontology key '{ontology_key}' already exists") + + content = await file.read() + if len(content) > 10 * 1024 * 1024: + raise ValueError("File size exceeds 10MB limit") + + file_path = user_dir / f"{ontology_key}.owl" + with open(file_path, "wb") as f: + f.write(content) + + ontology_metadata = { + "filename": file.filename, + "size_bytes": len(content), + "uploaded_at": datetime.now(timezone.utc).isoformat(), + "description": description, + } + metadata[ontology_key] = ontology_metadata + self._save_metadata(user_dir, metadata) + + return OntologyMetadata( + ontology_key=ontology_key, + filename=file.filename, + size_bytes=len(content), + uploaded_at=ontology_metadata["uploaded_at"], + description=description, + ) + + async def upload_ontologies( + self, ontology_key: List[str], files: List, user, descriptions: Optional[List[str]] = None + ) -> List[OntologyMetadata]: + """ + Upload ontology files with their respective keys. + + Args: + ontology_key: List of unique keys for each ontology + files: List of UploadFile objects (same length as keys) + user: Authenticated user + descriptions: Optional list of descriptions for each file + + Returns: + List of OntologyMetadata objects for uploaded files + + Raises: + ValueError: If keys duplicate, file format invalid, or array lengths don't match + """ + if len(ontology_key) != len(files): + raise ValueError("Number of keys must match number of files") + + if len(set(ontology_key)) != len(ontology_key): + raise ValueError("Duplicate ontology keys not allowed") + + if descriptions and len(descriptions) != len(files): + raise ValueError("Number of descriptions must match number of files") + + results = [] + user_dir = self._get_user_dir(str(user.id)) + metadata = self._load_metadata(user_dir) + + for i, (key, file) in enumerate(zip(ontology_key, files)): + if key in metadata: + raise ValueError(f"Ontology key '{key}' already exists") + + if not file.filename.lower().endswith(".owl"): + raise ValueError(f"File '{file.filename}' must be in .owl format") + + content = await file.read() + if len(content) > 10 * 1024 * 1024: + raise ValueError(f"File '{file.filename}' exceeds 10MB limit") + + file_path = user_dir / f"{key}.owl" + with open(file_path, "wb") as f: + f.write(content) + + ontology_metadata = { + "filename": file.filename, + "size_bytes": len(content), + "uploaded_at": datetime.now(timezone.utc).isoformat(), + "description": descriptions[i] if descriptions else None, + } + metadata[key] = ontology_metadata + + results.append( + OntologyMetadata( + ontology_key=key, + filename=file.filename, + size_bytes=len(content), + uploaded_at=ontology_metadata["uploaded_at"], + description=descriptions[i] if descriptions else None, + ) + ) + + self._save_metadata(user_dir, metadata) + return results + + def get_ontology_contents(self, ontology_key: List[str], user) -> List[str]: + """ + Retrieve ontology content for one or more keys. + + Args: + ontology_key: List of ontology keys to retrieve (can contain single item) + user: Authenticated user + + Returns: + List of ontology content strings + + Raises: + ValueError: If any ontology key not found + """ + user_dir = self._get_user_dir(str(user.id)) + metadata = self._load_metadata(user_dir) + + contents = [] + for key in ontology_key: + if key not in metadata: + raise ValueError(f"Ontology key '{key}' not found") + + file_path = user_dir / f"{key}.owl" + if not file_path.exists(): + raise ValueError(f"Ontology file for key '{key}' not found") + + with open(file_path, "r", encoding="utf-8") as f: + contents.append(f.read()) + return contents + + def list_ontologies(self, user) -> dict: + user_dir = self._get_user_dir(str(user.id)) + return self._load_metadata(user_dir) diff --git a/cognee/api/v1/ontologies/routers/__init__.py b/cognee/api/v1/ontologies/routers/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/cognee/api/v1/ontologies/routers/get_ontology_router.py b/cognee/api/v1/ontologies/routers/get_ontology_router.py new file mode 100644 index 000000000..ee31c683f --- /dev/null +++ b/cognee/api/v1/ontologies/routers/get_ontology_router.py @@ -0,0 +1,107 @@ +from fastapi import APIRouter, File, Form, UploadFile, Depends, HTTPException +from fastapi.responses import JSONResponse +from typing import Optional, List + +from cognee.modules.users.models import User +from cognee.modules.users.methods import get_authenticated_user +from cognee.shared.utils import send_telemetry +from cognee import __version__ as cognee_version +from ..ontologies import OntologyService + + +def get_ontology_router() -> APIRouter: + router = APIRouter() + ontology_service = OntologyService() + + @router.post("", response_model=dict) + async def upload_ontology( + ontology_key: str = Form(...), + ontology_file: List[UploadFile] = File(...), + descriptions: Optional[str] = Form(None), + user: User = Depends(get_authenticated_user), + ): + """ + Upload ontology files with their respective keys for later use in cognify operations. + + Supports both single and multiple file uploads: + - Single file: ontology_key=["key"], ontology_file=[file] + - Multiple files: ontology_key=["key1", "key2"], ontology_file=[file1, file2] + + ## Request Parameters + - **ontology_key** (str): JSON array string of user-defined identifiers for the ontologies + - **ontology_file** (List[UploadFile]): OWL format ontology files + - **descriptions** (Optional[str]): JSON array string of optional descriptions + + ## Response + Returns metadata about uploaded ontologies including keys, filenames, sizes, and upload timestamps. + + ## Error Codes + - **400 Bad Request**: Invalid file format, duplicate keys, array length mismatches, file size exceeded + - **500 Internal Server Error**: File system or processing errors + """ + send_telemetry( + "Ontology Upload API Endpoint Invoked", + user.id, + additional_properties={ + "endpoint": "POST /api/v1/ontologies", + "cognee_version": cognee_version, + }, + ) + + try: + import json + + ontology_keys = json.loads(ontology_key) + description_list = json.loads(descriptions) if descriptions else None + + if not isinstance(ontology_keys, list): + raise ValueError("ontology_key must be a JSON array") + + results = await ontology_service.upload_ontologies( + ontology_keys, ontology_file, user, description_list + ) + + return { + "uploaded_ontologies": [ + { + "ontology_key": result.ontology_key, + "filename": result.filename, + "size_bytes": result.size_bytes, + "uploaded_at": result.uploaded_at, + "description": result.description, + } + for result in results + ] + } + except (json.JSONDecodeError, ValueError) as e: + return JSONResponse(status_code=400, content={"error": str(e)}) + except Exception as e: + return JSONResponse(status_code=500, content={"error": str(e)}) + + @router.get("", response_model=dict) + async def list_ontologies(user: User = Depends(get_authenticated_user)): + """ + List all uploaded ontologies for the authenticated user. + + ## Response + Returns a dictionary mapping ontology keys to their metadata including filename, size, and upload timestamp. + + ## Error Codes + - **500 Internal Server Error**: File system or processing errors + """ + send_telemetry( + "Ontology List API Endpoint Invoked", + user.id, + additional_properties={ + "endpoint": "GET /api/v1/ontologies", + "cognee_version": cognee_version, + }, + ) + + try: + metadata = ontology_service.list_ontologies(user) + return metadata + except Exception as e: + return JSONResponse(status_code=500, content={"error": str(e)}) + + return router diff --git a/cognee/api/v1/search/search.py b/cognee/api/v1/search/search.py index d4e5fbbe6..354331c57 100644 --- a/cognee/api/v1/search/search.py +++ b/cognee/api/v1/search/search.py @@ -31,6 +31,8 @@ async def search( only_context: bool = False, use_combined_context: bool = False, session_id: Optional[str] = None, + wide_search_top_k: Optional[int] = 100, + triplet_distance_penalty: Optional[float] = 3.5, ) -> Union[List[SearchResult], CombinedSearchResult]: """ Search and query the knowledge graph for insights, information, and connections. @@ -200,6 +202,8 @@ async def search( only_context=only_context, use_combined_context=use_combined_context, session_id=session_id, + wide_search_top_k=wide_search_top_k, + triplet_distance_penalty=triplet_distance_penalty, ) return filtered_search_results diff --git a/cognee/cli/commands/cognify_command.py b/cognee/cli/commands/cognify_command.py index 16eaf0454..b89c1f70e 100644 --- a/cognee/cli/commands/cognify_command.py +++ b/cognee/cli/commands/cognify_command.py @@ -22,7 +22,7 @@ relationships, and creates semantic connections for enhanced search and reasonin Processing Pipeline: 1. **Document Classification**: Identifies document types and structures -2. **Permission Validation**: Ensures user has processing rights +2. **Permission Validation**: Ensures user has processing rights 3. **Text Chunking**: Breaks content into semantically meaningful segments 4. **Entity Extraction**: Identifies key concepts, people, places, organizations 5. **Relationship Detection**: Discovers connections between entities @@ -97,6 +97,13 @@ After successful cognify processing, use `cognee search` to query the knowledge chunker_class = LangchainChunker except ImportError: fmt.warning("LangchainChunker not available, using TextChunker") + elif args.chunker == "CsvChunker": + try: + from cognee.modules.chunking.CsvChunker import CsvChunker + + chunker_class = CsvChunker + except ImportError: + fmt.warning("CsvChunker not available, using TextChunker") result = await cognee.cognify( datasets=datasets, diff --git a/cognee/cli/config.py b/cognee/cli/config.py index d016608c1..082adbaec 100644 --- a/cognee/cli/config.py +++ b/cognee/cli/config.py @@ -26,7 +26,7 @@ SEARCH_TYPE_CHOICES = [ ] # Chunker choices -CHUNKER_CHOICES = ["TextChunker", "LangchainChunker"] +CHUNKER_CHOICES = ["TextChunker", "LangchainChunker", "CsvChunker"] # Output format choices OUTPUT_FORMAT_CHOICES = ["json", "pretty", "simple"] diff --git a/cognee/eval_framework/Dockerfile b/cognee/eval_framework/Dockerfile new file mode 100644 index 000000000..e83be3da4 --- /dev/null +++ b/cognee/eval_framework/Dockerfile @@ -0,0 +1,29 @@ +FROM python:3.11-slim + +# Set environment variables +ENV PIP_NO_CACHE_DIR=true +ENV PATH="${PATH}:/root/.poetry/bin" +ENV PYTHONPATH=/app +ENV SKIP_MIGRATIONS=true + +# System dependencies +RUN apt-get update && apt-get install -y \ + gcc \ + libpq-dev \ + git \ + curl \ + build-essential \ + && rm -rf /var/lib/apt/lists/* + +WORKDIR /app + +COPY pyproject.toml poetry.lock README.md /app/ + +RUN pip install poetry + +RUN poetry config virtualenvs.create false + +RUN poetry install --extras distributed --extras evals --extras deepeval --no-root + +COPY cognee/ /app/cognee +COPY distributed/ /app/distributed diff --git a/cognee/eval_framework/answer_generation/answer_generation_executor.py b/cognee/eval_framework/answer_generation/answer_generation_executor.py index 6f166657e..29b3ede68 100644 --- a/cognee/eval_framework/answer_generation/answer_generation_executor.py +++ b/cognee/eval_framework/answer_generation/answer_generation_executor.py @@ -35,6 +35,16 @@ class AnswerGeneratorExecutor: retrieval_context = await retriever.get_context(query_text) search_results = await retriever.get_completion(query_text, retrieval_context) + ############ + #:TODO This is a quick fix until we don't structure retriever results properly but lets not leave it like this...this is needed now due to the changed combined retriever structure.. + if isinstance(retrieval_context, list): + retrieval_context = await retriever.convert_retrieved_objects_to_context( + triplets=retrieval_context + ) + + if isinstance(search_results, str): + search_results = [search_results] + ############# answer = { "question": query_text, "answer": search_results[0], diff --git a/cognee/eval_framework/answer_generation/run_question_answering_module.py b/cognee/eval_framework/answer_generation/run_question_answering_module.py index d0a2ebe1e..6b55d84b2 100644 --- a/cognee/eval_framework/answer_generation/run_question_answering_module.py +++ b/cognee/eval_framework/answer_generation/run_question_answering_module.py @@ -35,7 +35,7 @@ async def create_and_insert_answers_table(questions_payload): async def run_question_answering( - params: dict, system_prompt="answer_simple_question.txt", top_k: Optional[int] = None + params: dict, system_prompt="answer_simple_question_benchmark.txt", top_k: Optional[int] = None ) -> List[dict]: if params.get("answering_questions"): logger.info("Question answering started...") diff --git a/cognee/eval_framework/eval_config.py b/cognee/eval_framework/eval_config.py index 6edcc0454..9e6f26688 100644 --- a/cognee/eval_framework/eval_config.py +++ b/cognee/eval_framework/eval_config.py @@ -14,7 +14,7 @@ class EvalConfig(BaseSettings): # Question answering params answering_questions: bool = True - qa_engine: str = "cognee_completion" # Options: 'cognee_completion' or 'cognee_graph_completion' or 'cognee_graph_completion_cot' or 'cognee_graph_completion_context_extension' + qa_engine: str = "cognee_graph_completion" # Options: 'cognee_completion' or 'cognee_graph_completion' or 'cognee_graph_completion_cot' or 'cognee_graph_completion_context_extension' # Evaluation params evaluating_answers: bool = True @@ -25,7 +25,7 @@ class EvalConfig(BaseSettings): "EM", "f1", ] # Use only 'correctness' for DirectLLM - deepeval_model: str = "gpt-5-mini" + deepeval_model: str = "gpt-4o-mini" # Metrics params calculate_metrics: bool = True diff --git a/cognee/eval_framework/modal_run_eval.py b/cognee/eval_framework/modal_run_eval.py index aca2686a5..bc2ff77c5 100644 --- a/cognee/eval_framework/modal_run_eval.py +++ b/cognee/eval_framework/modal_run_eval.py @@ -2,7 +2,6 @@ import modal import os import asyncio import datetime -import hashlib import json from cognee.shared.logging_utils import get_logger from cognee.eval_framework.eval_config import EvalConfig @@ -10,6 +9,9 @@ from cognee.eval_framework.corpus_builder.run_corpus_builder import run_corpus_b from cognee.eval_framework.answer_generation.run_question_answering_module import ( run_question_answering, ) +import pathlib +from os import path +from modal import Image from cognee.eval_framework.evaluation.run_evaluation_module import run_evaluation from cognee.eval_framework.metrics_dashboard import create_dashboard @@ -38,22 +40,19 @@ def read_and_combine_metrics(eval_params: dict) -> dict: app = modal.App("modal-run-eval") -image = ( - modal.Image.from_dockerfile(path="Dockerfile_modal", force_build=False) - .copy_local_file("pyproject.toml", "pyproject.toml") - .copy_local_file("poetry.lock", "poetry.lock") - .env( - { - "ENV": os.getenv("ENV"), - "LLM_API_KEY": os.getenv("LLM_API_KEY"), - "OPENAI_API_KEY": os.getenv("OPENAI_API_KEY"), - } - ) - .pip_install("protobuf", "h2", "deepeval", "gdown", "plotly") +image = Image.from_dockerfile( + path=pathlib.Path(path.join(path.dirname(__file__), "Dockerfile")).resolve(), + force_build=False, +).add_local_python_source("cognee") + + +@app.function( + image=image, + max_containers=10, + timeout=86400, + volumes={"/data": vol}, + secrets=[modal.Secret.from_name("eval_secrets")], ) - - -@app.function(image=image, concurrency_limit=10, timeout=86400, volumes={"/data": vol}) async def modal_run_eval(eval_params=None): """Runs evaluation pipeline and returns combined metrics results.""" if eval_params is None: @@ -105,18 +104,7 @@ async def main(): configs = [ EvalConfig( task_getter_type="Default", - number_of_samples_in_corpus=10, - benchmark="HotPotQA", - qa_engine="cognee_graph_completion", - building_corpus_from_scratch=True, - answering_questions=True, - evaluating_answers=True, - calculate_metrics=True, - dashboard=True, - ), - EvalConfig( - task_getter_type="Default", - number_of_samples_in_corpus=10, + number_of_samples_in_corpus=25, benchmark="TwoWikiMultiHop", qa_engine="cognee_graph_completion", building_corpus_from_scratch=True, @@ -127,7 +115,7 @@ async def main(): ), EvalConfig( task_getter_type="Default", - number_of_samples_in_corpus=10, + number_of_samples_in_corpus=25, benchmark="Musique", qa_engine="cognee_graph_completion", building_corpus_from_scratch=True, diff --git a/cognee/infrastructure/databases/graph/graph_db_interface.py b/cognee/infrastructure/databases/graph/graph_db_interface.py index 67df1a27c..8f8c96e79 100644 --- a/cognee/infrastructure/databases/graph/graph_db_interface.py +++ b/cognee/infrastructure/databases/graph/graph_db_interface.py @@ -398,3 +398,18 @@ class GraphDBInterface(ABC): - node_id (Union[str, UUID]): Unique identifier of the node for which to retrieve connections. """ raise NotImplementedError + + @abstractmethod + async def get_filtered_graph_data( + self, attribute_filters: List[Dict[str, List[Union[str, int]]]] + ) -> Tuple[List[Node], List[EdgeData]]: + """ + Retrieve nodes and edges filtered by the provided attribute criteria. + + Parameters: + ----------- + + - attribute_filters: A list of dictionaries where keys are attribute names and values + are lists of attribute values to filter by. + """ + raise NotImplementedError diff --git a/cognee/infrastructure/databases/graph/kuzu/adapter.py b/cognee/infrastructure/databases/graph/kuzu/adapter.py index 8dd160665..9dbc9c1bc 100644 --- a/cognee/infrastructure/databases/graph/kuzu/adapter.py +++ b/cognee/infrastructure/databases/graph/kuzu/adapter.py @@ -12,6 +12,7 @@ from contextlib import asynccontextmanager from concurrent.futures import ThreadPoolExecutor from typing import Dict, Any, List, Union, Optional, Tuple, Type +from cognee.exceptions import CogneeValidationError from cognee.shared.logging_utils import get_logger from cognee.infrastructure.utils.run_sync import run_sync from cognee.infrastructure.files.storage import get_file_storage @@ -1186,6 +1187,11 @@ class KuzuAdapter(GraphDBInterface): A tuple with two elements: a list of tuples of (node_id, properties) and a list of tuples of (source_id, target_id, relationship_name, properties). """ + + import time + + start_time = time.time() + try: nodes_query = """ MATCH (n:Node) @@ -1249,6 +1255,11 @@ class KuzuAdapter(GraphDBInterface): }, ) ) + + retrieval_time = time.time() - start_time + logger.info( + f"Retrieved {len(nodes)} nodes and {len(edges)} edges in {retrieval_time:.2f} seconds" + ) return formatted_nodes, formatted_edges except Exception as e: logger.error(f"Failed to get graph data: {e}") @@ -1417,6 +1428,92 @@ class KuzuAdapter(GraphDBInterface): formatted_edges.append((source_id, target_id, rel_type, props)) return formatted_nodes, formatted_edges + async def get_id_filtered_graph_data(self, target_ids: list[str]): + """ + Retrieve graph data filtered by specific node IDs, including their direct neighbors + and only edges where one endpoint matches those IDs. + + Returns: + nodes: List[dict] -> Each dict includes "id" and all node properties + edges: List[dict] -> Each dict includes "source", "target", "type", "properties" + """ + import time + + start_time = time.time() + + try: + if not target_ids: + logger.warning("No target IDs provided for ID-filtered graph retrieval.") + return [], [] + + if not all(isinstance(x, str) for x in target_ids): + raise CogneeValidationError("target_ids must be a list of strings") + + query = """ + MATCH (n:Node)-[r]->(m:Node) + WHERE n.id IN $target_ids OR m.id IN $target_ids + RETURN n.id, { + name: n.name, + type: n.type, + properties: n.properties + }, m.id, { + name: m.name, + type: m.type, + properties: m.properties + }, r.relationship_name, r.properties + """ + + result = await self.query(query, {"target_ids": target_ids}) + + if not result: + logger.info("No data returned for the supplied IDs") + return [], [] + + nodes_dict = {} + edges = [] + + for n_id, n_props, m_id, m_props, r_type, r_props_raw in result: + if n_props.get("properties"): + try: + additional_props = json.loads(n_props["properties"]) + n_props.update(additional_props) + del n_props["properties"] + except json.JSONDecodeError: + logger.warning(f"Failed to parse properties JSON for node {n_id}") + + if m_props.get("properties"): + try: + additional_props = json.loads(m_props["properties"]) + m_props.update(additional_props) + del m_props["properties"] + except json.JSONDecodeError: + logger.warning(f"Failed to parse properties JSON for node {m_id}") + + nodes_dict[n_id] = (n_id, n_props) + nodes_dict[m_id] = (m_id, m_props) + + edge_props = {} + if r_props_raw: + try: + edge_props = json.loads(r_props_raw) + except (json.JSONDecodeError, TypeError): + logger.warning(f"Failed to parse edge properties for {n_id}->{m_id}") + + source_id = edge_props.get("source_node_id", n_id) + target_id = edge_props.get("target_node_id", m_id) + edges.append((source_id, target_id, r_type, edge_props)) + + retrieval_time = time.time() - start_time + logger.info( + f"ID-filtered retrieval: {len(nodes_dict)} nodes and {len(edges)} edges in {retrieval_time:.2f}s" + ) + + return list(nodes_dict.values()), edges + + except Exception as e: + logger.error(f"Error during ID-filtered graph data retrieval: {str(e)}") + raise + async def get_graph_metrics(self, include_optional=False) -> Dict[str, Any]: """ Get metrics on graph structure and connectivity. diff --git a/cognee/infrastructure/databases/graph/neo4j_driver/adapter.py b/cognee/infrastructure/databases/graph/neo4j_driver/adapter.py index 6216e107e..f3bb8e173 100644 --- a/cognee/infrastructure/databases/graph/neo4j_driver/adapter.py +++ b/cognee/infrastructure/databases/graph/neo4j_driver/adapter.py @@ -964,6 +964,63 @@ class Neo4jAdapter(GraphDBInterface): logger.error(f"Error during graph data retrieval: {str(e)}") raise + async def get_id_filtered_graph_data(self, target_ids: list[str]): + """ + Retrieve graph data filtered by specific node IDs, including their direct neighbors + and only edges where one endpoint matches those IDs. + + This version uses a single Cypher query for efficiency. + """ + import time + + start_time = time.time() + + try: + if not target_ids: + logger.warning("No target IDs provided for ID-filtered graph retrieval.") + return [], [] + + query = """ + MATCH ()-[r]-() + WHERE startNode(r).id IN $target_ids + OR endNode(r).id IN $target_ids + WITH DISTINCT r, startNode(r) AS a, endNode(r) AS b + RETURN + properties(a) AS n_properties, + properties(b) AS m_properties, + type(r) AS type, + properties(r) AS properties + """ + + result = await self.query(query, {"target_ids": target_ids}) + + nodes_dict = {} + edges = [] + + for record in result: + n_props = record["n_properties"] + m_props = record["m_properties"] + r_props = record["properties"] + r_type = record["type"] + + nodes_dict[n_props["id"]] = (n_props["id"], n_props) + nodes_dict[m_props["id"]] = (m_props["id"], m_props) + + source_id = r_props.get("source_node_id", n_props["id"]) + target_id = r_props.get("target_node_id", m_props["id"]) + edges.append((source_id, target_id, r_type, r_props)) + + retrieval_time = time.time() - start_time + logger.info( + f"ID-filtered retrieval: {len(nodes_dict)} nodes and {len(edges)} edges in {retrieval_time:.2f}s" + ) + + return list(nodes_dict.values()), edges + + except Exception as e: + logger.error(f"Error during ID-filtered graph data retrieval: {str(e)}") + raise + async def get_nodeset_subgraph( self, node_type: Type[Any], node_name: List[str] ) -> Tuple[List[Tuple[int, dict]], List[Tuple[int, int, str, dict]]]: diff --git a/cognee/infrastructure/files/utils/guess_file_type.py b/cognee/infrastructure/files/utils/guess_file_type.py index 78b20c93d..4bc96fe80 100644 --- a/cognee/infrastructure/files/utils/guess_file_type.py +++ b/cognee/infrastructure/files/utils/guess_file_type.py @@ -55,6 +55,10 @@ def guess_file_type(file: BinaryIO, name: Optional[str] = None) -> filetype.Type file_type = Type("text/plain", "txt") return file_type + if ext in [".csv"]: + file_type = Type("text/csv", "csv") + return file_type + file_type = filetype.guess(file) # If file type could not be determined consider it a plain text file as they don't have magic number encoding diff --git a/cognee/infrastructure/llm/config.py b/cognee/infrastructure/llm/config.py index 8fd196eaf..2e300dc0c 100644 --- a/cognee/infrastructure/llm/config.py +++ b/cognee/infrastructure/llm/config.py @@ -38,6 +38,7 @@ class LLMConfig(BaseSettings): """ structured_output_framework: str = "instructor" + llm_instructor_mode: str = "" llm_provider: str = "openai" llm_model: str = "openai/gpt-5-mini" llm_endpoint: str = "" @@ -181,6 +182,7 @@ class LLMConfig(BaseSettings): instance. """ return { + "llm_instructor_mode": self.llm_instructor_mode.lower(), "provider": self.llm_provider, "model": self.llm_model, "endpoint": self.llm_endpoint, diff --git a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/anthropic/adapter.py b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/anthropic/adapter.py index bf19d6e86..dbf0dfbea 100644 --- a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/anthropic/adapter.py +++ b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/anthropic/adapter.py @@ -28,13 +28,16 @@ class AnthropicAdapter(LLMInterface): name = "Anthropic" model: str + default_instructor_mode = "anthropic_tools" - def __init__(self, max_completion_tokens: int, model: str = None): + def __init__(self, max_completion_tokens: int, model: str = None, instructor_mode: str = None): import anthropic + self.instructor_mode = instructor_mode if instructor_mode else self.default_instructor_mode + self.aclient = instructor.patch( create=anthropic.AsyncAnthropic(api_key=get_llm_config().llm_api_key).messages.create, - mode=instructor.Mode.ANTHROPIC_TOOLS, + mode=instructor.Mode(self.instructor_mode), ) self.model = model diff --git a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/gemini/adapter.py b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/gemini/adapter.py index 1187e0cad..226f291d7 100644 --- a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/gemini/adapter.py +++ b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/gemini/adapter.py @@ -41,6 +41,7 @@ class GeminiAdapter(LLMInterface): name: str model: str api_key: str + default_instructor_mode = "json_mode" def __init__( self, @@ -49,6 +50,7 @@ class GeminiAdapter(LLMInterface): model: str, api_version: str, max_completion_tokens: int, + instructor_mode: str = None, fallback_model: str = None, fallback_api_key: str = None, fallback_endpoint: str = None, @@ -63,7 +65,11 @@ class GeminiAdapter(LLMInterface): self.fallback_api_key = fallback_api_key self.fallback_endpoint = fallback_endpoint - self.aclient = instructor.from_litellm(litellm.acompletion, mode=instructor.Mode.JSON) + self.instructor_mode = instructor_mode if instructor_mode else self.default_instructor_mode + + self.aclient = instructor.from_litellm( + litellm.acompletion, mode=instructor.Mode(self.instructor_mode) + ) @retry( stop=stop_after_delay(128), diff --git a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/generic_llm_api/adapter.py b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/generic_llm_api/adapter.py index 8bbbaa2cc..9d7f25fc5 100644 --- a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/generic_llm_api/adapter.py +++ b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/generic_llm_api/adapter.py @@ -41,6 +41,7 @@ class GenericAPIAdapter(LLMInterface): name: str model: str api_key: str + default_instructor_mode = "json_mode" def __init__( self, @@ -49,6 +50,7 @@ class GenericAPIAdapter(LLMInterface): model: str, name: str, max_completion_tokens: int, + instructor_mode: str = None, fallback_model: str = None, fallback_api_key: str = None, fallback_endpoint: str = None, @@ -63,7 +65,11 @@ class GenericAPIAdapter(LLMInterface): self.fallback_api_key = fallback_api_key self.fallback_endpoint = fallback_endpoint - self.aclient = instructor.from_litellm(litellm.acompletion, mode=instructor.Mode.JSON) + self.instructor_mode = instructor_mode if instructor_mode else self.default_instructor_mode + + self.aclient = instructor.from_litellm( + litellm.acompletion, mode=instructor.Mode(self.instructor_mode) + ) @retry( stop=stop_after_delay(128), diff --git a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/get_llm_client.py b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/get_llm_client.py index c7dcecc56..39558f36d 100644 --- a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/get_llm_client.py +++ b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/get_llm_client.py @@ -81,6 +81,7 @@ def get_llm_client(raise_api_key_error: bool = True): model=llm_config.llm_model, transcription_model=llm_config.transcription_model, max_completion_tokens=max_completion_tokens, + instructor_mode=llm_config.llm_instructor_mode.lower(), streaming=llm_config.llm_streaming, fallback_api_key=llm_config.fallback_api_key, fallback_endpoint=llm_config.fallback_endpoint, @@ -101,6 +102,7 @@ def get_llm_client(raise_api_key_error: bool = True): llm_config.llm_model, "Ollama", max_completion_tokens=max_completion_tokens, + instructor_mode=llm_config.llm_instructor_mode.lower(), ) elif provider == LLMProvider.ANTHROPIC: @@ -109,7 +111,9 @@ def get_llm_client(raise_api_key_error: bool = True): ) return AnthropicAdapter( - max_completion_tokens=max_completion_tokens, model=llm_config.llm_model + max_completion_tokens=max_completion_tokens, + model=llm_config.llm_model, + instructor_mode=llm_config.llm_instructor_mode.lower(), ) elif provider == LLMProvider.CUSTOM: @@ -126,6 +130,7 @@ def get_llm_client(raise_api_key_error: bool = True): llm_config.llm_model, "Custom", max_completion_tokens=max_completion_tokens, + instructor_mode=llm_config.llm_instructor_mode.lower(), fallback_api_key=llm_config.fallback_api_key, fallback_endpoint=llm_config.fallback_endpoint, fallback_model=llm_config.fallback_model, @@ -145,6 +150,7 @@ def get_llm_client(raise_api_key_error: bool = True): max_completion_tokens=max_completion_tokens, endpoint=llm_config.llm_endpoint, api_version=llm_config.llm_api_version, + instructor_mode=llm_config.llm_instructor_mode.lower(), ) elif provider == LLMProvider.MISTRAL: @@ -160,21 +166,7 @@ def get_llm_client(raise_api_key_error: bool = True): model=llm_config.llm_model, max_completion_tokens=max_completion_tokens, endpoint=llm_config.llm_endpoint, - ) - - elif provider == LLMProvider.MISTRAL: - if llm_config.llm_api_key is None: - raise LLMAPIKeyNotSetError() - - from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.mistral.adapter import ( - MistralAdapter, - ) - - return MistralAdapter( - api_key=llm_config.llm_api_key, - model=llm_config.llm_model, - max_completion_tokens=max_completion_tokens, - endpoint=llm_config.llm_endpoint, + instructor_mode=llm_config.llm_instructor_mode.lower(), ) else: diff --git a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/mistral/adapter.py b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/mistral/adapter.py index 78a3cbff5..355cdae0b 100644 --- a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/mistral/adapter.py +++ b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/mistral/adapter.py @@ -37,16 +37,26 @@ class MistralAdapter(LLMInterface): model: str api_key: str max_completion_tokens: int + default_instructor_mode = "mistral_tools" - def __init__(self, api_key: str, model: str, max_completion_tokens: int, endpoint: str = None): + def __init__( + self, + api_key: str, + model: str, + max_completion_tokens: int, + endpoint: str = None, + instructor_mode: str = None, + ): from mistralai import Mistral self.model = model self.max_completion_tokens = max_completion_tokens + self.instructor_mode = instructor_mode if instructor_mode else self.default_instructor_mode + self.aclient = instructor.from_litellm( litellm.acompletion, - mode=instructor.Mode.MISTRAL_TOOLS, + mode=instructor.Mode(self.instructor_mode), api_key=get_llm_config().llm_api_key, ) diff --git a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/ollama/adapter.py b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/ollama/adapter.py index 9c3d185aa..aabd19867 100644 --- a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/ollama/adapter.py +++ b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/ollama/adapter.py @@ -42,8 +42,16 @@ class OllamaAPIAdapter(LLMInterface): - aclient """ + default_instructor_mode = "json_mode" + def __init__( - self, endpoint: str, api_key: str, model: str, name: str, max_completion_tokens: int + self, + endpoint: str, + api_key: str, + model: str, + name: str, + max_completion_tokens: int, + instructor_mode: str = None, ): self.name = name self.model = model @@ -51,8 +59,11 @@ class OllamaAPIAdapter(LLMInterface): self.endpoint = endpoint self.max_completion_tokens = max_completion_tokens + self.instructor_mode = instructor_mode if instructor_mode else self.default_instructor_mode + self.aclient = instructor.from_openai( - OpenAI(base_url=self.endpoint, api_key=self.api_key), mode=instructor.Mode.JSON + OpenAI(base_url=self.endpoint, api_key=self.api_key), + mode=instructor.Mode(self.instructor_mode), ) @retry( diff --git a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/openai/adapter.py b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/openai/adapter.py index 305b426b8..778c8eec7 100644 --- a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/openai/adapter.py +++ b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/openai/adapter.py @@ -56,6 +56,7 @@ class OpenAIAdapter(LLMInterface): model: str api_key: str api_version: str + default_instructor_mode = "json_schema_mode" MAX_RETRIES = 5 @@ -69,19 +70,21 @@ class OpenAIAdapter(LLMInterface): model: str, transcription_model: str, max_completion_tokens: int, + instructor_mode: str = None, streaming: bool = False, fallback_model: str = None, fallback_api_key: str = None, fallback_endpoint: str = None, ): + self.instructor_mode = instructor_mode if instructor_mode else self.default_instructor_mode # TODO: With gpt5 series models OpenAI expects JSON_SCHEMA as a mode for structured outputs. # Make sure all new gpt models will work with this mode as well. if "gpt-5" in model: self.aclient = instructor.from_litellm( - litellm.acompletion, mode=instructor.Mode.JSON_SCHEMA + litellm.acompletion, mode=instructor.Mode(self.instructor_mode) ) self.client = instructor.from_litellm( - litellm.completion, mode=instructor.Mode.JSON_SCHEMA + litellm.completion, mode=instructor.Mode(self.instructor_mode) ) else: self.aclient = instructor.from_litellm(litellm.acompletion) diff --git a/cognee/infrastructure/loaders/LoaderEngine.py b/cognee/infrastructure/loaders/LoaderEngine.py index f9511e7c5..4a363a0e6 100644 --- a/cognee/infrastructure/loaders/LoaderEngine.py +++ b/cognee/infrastructure/loaders/LoaderEngine.py @@ -31,6 +31,7 @@ class LoaderEngine: "pypdf_loader", "image_loader", "audio_loader", + "csv_loader", "unstructured_loader", "advanced_pdf_loader", ] diff --git a/cognee/infrastructure/loaders/core/__init__.py b/cognee/infrastructure/loaders/core/__init__.py index 8a2df80f9..09819fbd2 100644 --- a/cognee/infrastructure/loaders/core/__init__.py +++ b/cognee/infrastructure/loaders/core/__init__.py @@ -3,5 +3,6 @@ from .text_loader import TextLoader from .audio_loader import AudioLoader from .image_loader import ImageLoader +from .csv_loader import CsvLoader -__all__ = ["TextLoader", "AudioLoader", "ImageLoader"] +__all__ = ["TextLoader", "AudioLoader", "ImageLoader", "CsvLoader"] diff --git a/cognee/infrastructure/loaders/core/csv_loader.py b/cognee/infrastructure/loaders/core/csv_loader.py new file mode 100644 index 000000000..a314a7a24 --- /dev/null +++ b/cognee/infrastructure/loaders/core/csv_loader.py @@ -0,0 +1,93 @@ +import os +from typing import List +import csv +from cognee.infrastructure.loaders.LoaderInterface import LoaderInterface +from cognee.infrastructure.files.storage import get_file_storage, get_storage_config +from cognee.infrastructure.files.utils.get_file_metadata import get_file_metadata + + +class CsvLoader(LoaderInterface): + """ + Core CSV file loader that handles basic CSV file formats. + """ + + @property + def supported_extensions(self) -> List[str]: + """Supported text file extensions.""" + return [ + "csv", + ] + + @property + def supported_mime_types(self) -> List[str]: + """Supported MIME types for text content.""" + return [ + "text/csv", + ] + + @property + def loader_name(self) -> str: + """Unique identifier for this loader.""" + return "csv_loader" + + def can_handle(self, extension: str, mime_type: str) -> bool: + """ + Check if this loader can handle the given file. + + Args: + extension: File extension + mime_type: Optional MIME type + + Returns: + True if file can be handled, False otherwise + """ + if extension in self.supported_extensions and mime_type in self.supported_mime_types: + return True + + return False + + async def load(self, file_path: str, encoding: str = "utf-8", **kwargs): + """ + Load and process the csv file. + + Args: + file_path: Path to the file to load + encoding: Text encoding to use (default: utf-8) + **kwargs: Additional configuration (unused) + + Returns: + LoaderResult containing the file content and metadata + + Raises: + FileNotFoundError: If file doesn't exist + UnicodeDecodeError: If file cannot be decoded with specified encoding + OSError: If file cannot be read + """ + if not os.path.exists(file_path): + raise FileNotFoundError(f"File not found: {file_path}") + + with open(file_path, "rb") as f: + file_metadata = await get_file_metadata(f) + # Name ingested file of current loader based on original file content hash + storage_file_name = "text_" + file_metadata["content_hash"] + ".txt" + + row_texts = [] + row_index = 1 + + with open(file_path, "r", encoding=encoding, newline="") as file: + reader = csv.DictReader(file) + for row in reader: + pairs = [f"{str(k)}: {str(v)}" for k, v in row.items()] + row_text = ", ".join(pairs) + row_texts.append(f"Row {row_index}:\n{row_text}\n") + row_index += 1 + + content = "\n".join(row_texts) + + storage_config = get_storage_config() + data_root_directory = storage_config["data_root_directory"] + storage = get_file_storage(data_root_directory) + + full_file_path = await storage.store(storage_file_name, content) + + return full_file_path diff --git a/cognee/infrastructure/loaders/core/text_loader.py b/cognee/infrastructure/loaders/core/text_loader.py index a6f94be9b..e478edb22 100644 --- a/cognee/infrastructure/loaders/core/text_loader.py +++ b/cognee/infrastructure/loaders/core/text_loader.py @@ -16,7 +16,7 @@ class TextLoader(LoaderInterface): @property def supported_extensions(self) -> List[str]: """Supported text file extensions.""" - return ["txt", "md", "csv", "json", "xml", "yaml", "yml", "log"] + return ["txt", "md", "json", "xml", "yaml", "yml", "log"] @property def supported_mime_types(self) -> List[str]: @@ -24,7 +24,6 @@ class TextLoader(LoaderInterface): return [ "text/plain", "text/markdown", - "text/csv", "application/json", "text/xml", "application/xml", diff --git a/cognee/infrastructure/loaders/external/advanced_pdf_loader.py b/cognee/infrastructure/loaders/external/advanced_pdf_loader.py index 6d1412b77..4b3ba296a 100644 --- a/cognee/infrastructure/loaders/external/advanced_pdf_loader.py +++ b/cognee/infrastructure/loaders/external/advanced_pdf_loader.py @@ -227,12 +227,3 @@ class AdvancedPdfLoader(LoaderInterface): if value is None: return "" return str(value).replace("\xa0", " ").strip() - - -if __name__ == "__main__": - loader = AdvancedPdfLoader() - asyncio.run( - loader.load( - "/Users/xiaotao/work/cognee/cognee/infrastructure/loaders/external/attention_is_all_you_need.pdf" - ) - ) diff --git a/cognee/infrastructure/loaders/supported_loaders.py b/cognee/infrastructure/loaders/supported_loaders.py index 156253b53..2b8c3e0b4 100644 --- a/cognee/infrastructure/loaders/supported_loaders.py +++ b/cognee/infrastructure/loaders/supported_loaders.py @@ -1,5 +1,5 @@ from cognee.infrastructure.loaders.external import PyPdfLoader -from cognee.infrastructure.loaders.core import TextLoader, AudioLoader, ImageLoader +from cognee.infrastructure.loaders.core import TextLoader, AudioLoader, ImageLoader, CsvLoader # Registry for loader implementations supported_loaders = { @@ -7,6 +7,7 @@ supported_loaders = { TextLoader.loader_name: TextLoader, ImageLoader.loader_name: ImageLoader, AudioLoader.loader_name: AudioLoader, + CsvLoader.loader_name: CsvLoader, } # Try adding optional loaders diff --git a/cognee/modules/chunking/CsvChunker.py b/cognee/modules/chunking/CsvChunker.py new file mode 100644 index 000000000..4ba4a969e --- /dev/null +++ b/cognee/modules/chunking/CsvChunker.py @@ -0,0 +1,35 @@ +from cognee.shared.logging_utils import get_logger + + +from cognee.tasks.chunks import chunk_by_row +from cognee.modules.chunking.Chunker import Chunker +from .models.DocumentChunk import DocumentChunk + +logger = get_logger() + + +class CsvChunker(Chunker): + async def read(self): + async for content_text in self.get_text(): + if content_text is None: + continue + + for chunk_data in chunk_by_row(content_text, self.max_chunk_size): + if chunk_data["chunk_size"] <= self.max_chunk_size: + yield DocumentChunk( + id=chunk_data["chunk_id"], + text=chunk_data["text"], + chunk_size=chunk_data["chunk_size"], + is_part_of=self.document, + chunk_index=self.chunk_index, + cut_type=chunk_data["cut_type"], + contains=[], + metadata={ + "index_fields": ["text"], + }, + ) + self.chunk_index += 1 + else: + raise ValueError( + f"Chunk size is larger than the maximum chunk size {self.max_chunk_size}" + ) diff --git a/cognee/modules/data/processing/document_types/CsvDocument.py b/cognee/modules/data/processing/document_types/CsvDocument.py new file mode 100644 index 000000000..3381275bd --- /dev/null +++ b/cognee/modules/data/processing/document_types/CsvDocument.py @@ -0,0 +1,33 @@ +import io +import csv +from typing import Type + +from cognee.modules.chunking.Chunker import Chunker +from cognee.infrastructure.files.utils.open_data_file import open_data_file +from .Document import Document + + +class CsvDocument(Document): + type: str = "csv" + mime_type: str = "text/csv" + + async def read(self, chunker_cls: Type[Chunker], max_chunk_size: int): + async def get_text(): + async with open_data_file( + self.raw_data_location, mode="r", encoding="utf-8", newline="" + ) as file: + content = file.read() + file_like_obj = io.StringIO(content) + reader = csv.DictReader(file_like_obj) + + for row in reader: + pairs = [f"{str(k)}: {str(v)}" for k, v in row.items()] + row_text = ", ".join(pairs) + if not row_text.strip(): + break + yield row_text + + chunker = chunker_cls(self, max_chunk_size=max_chunk_size, get_text=get_text) + + async for chunk in chunker.read(): + yield chunk diff --git a/cognee/modules/data/processing/document_types/__init__.py b/cognee/modules/data/processing/document_types/__init__.py index 2e862f4ba..133dd53f8 100644 --- a/cognee/modules/data/processing/document_types/__init__.py +++ b/cognee/modules/data/processing/document_types/__init__.py @@ -4,3 +4,4 @@ from .TextDocument import TextDocument from .ImageDocument import ImageDocument from .AudioDocument import AudioDocument from .UnstructuredDocument import UnstructuredDocument +from .CsvDocument import CsvDocument diff --git a/cognee/modules/graph/cognee_graph/CogneeGraph.py b/cognee/modules/graph/cognee_graph/CogneeGraph.py index cb7562422..2e0b82e8d 100644 --- a/cognee/modules/graph/cognee_graph/CogneeGraph.py +++ b/cognee/modules/graph/cognee_graph/CogneeGraph.py @@ -56,6 +56,68 @@ class CogneeGraph(CogneeAbstractGraph): def get_edges(self) -> List[Edge]: return self.edges + async def _get_nodeset_subgraph( + self, + adapter, + node_type, + node_name, + ): + """Retrieve subgraph based on node type and name.""" + logger.info("Retrieving graph filtered by node type and node name (NodeSet).") + nodes_data, edges_data = await adapter.get_nodeset_subgraph( + node_type=node_type, node_name=node_name + ) + if not nodes_data or not edges_data: + raise EntityNotFoundError( + message="Nodeset does not exist, or empty nodeset projected from the database." + ) + return nodes_data, edges_data + + async def _get_full_or_id_filtered_graph( + self, + adapter, + relevant_ids_to_filter, + ): + """Retrieve full or ID-filtered graph with fallback.""" + if relevant_ids_to_filter is None: + logger.info("Retrieving full graph.") + nodes_data, edges_data = await adapter.get_graph_data() + if not nodes_data or not edges_data: + raise EntityNotFoundError(message="Empty graph projected from the database.") + return nodes_data, edges_data + + get_graph_data_fn = getattr(adapter, "get_id_filtered_graph_data", adapter.get_graph_data) + if getattr(adapter.__class__, "get_id_filtered_graph_data", None): + logger.info("Retrieving ID-filtered graph from database.") + nodes_data, edges_data = await get_graph_data_fn(target_ids=relevant_ids_to_filter) + else: + logger.info("Retrieving full graph from database.") + nodes_data, edges_data = await get_graph_data_fn() + if hasattr(adapter, "get_id_filtered_graph_data") and (not nodes_data or not edges_data): + logger.warning( + "Id filtered graph returned empty, falling back to full graph retrieval." + ) + logger.info("Retrieving full graph") + nodes_data, edges_data = await adapter.get_graph_data() + + if not nodes_data or not edges_data: + raise EntityNotFoundError("Empty graph projected from the database.") + return nodes_data, edges_data + + async def _get_filtered_graph( + self, + adapter, + memory_fragment_filter, + ): + """Retrieve graph filtered by attributes.""" + logger.info("Retrieving graph filtered by memory fragment") + nodes_data, edges_data = await adapter.get_filtered_graph_data( + attribute_filters=memory_fragment_filter + ) + if not nodes_data or not edges_data: + raise EntityNotFoundError(message="Empty filtered graph projected from the database.") + return nodes_data, edges_data + async def project_graph_from_db( self, adapter: Union[GraphDBInterface], @@ -67,40 +129,39 @@ class CogneeGraph(CogneeAbstractGraph): memory_fragment_filter=[], node_type: Optional[Type] = None, node_name: Optional[List[str]] = None, + relevant_ids_to_filter: Optional[List[str]] = None, + triplet_distance_penalty: float = 3.5, ) -> None: if node_dimension < 1 or edge_dimension < 1: raise InvalidDimensionsError() try: + if node_type is not None and node_name not in [None, [], ""]: + nodes_data, edges_data = await self._get_nodeset_subgraph( + adapter, node_type, node_name + ) + elif len(memory_fragment_filter) == 0: + nodes_data, edges_data = await self._get_full_or_id_filtered_graph( + adapter, relevant_ids_to_filter + ) + else: + nodes_data, edges_data = await self._get_filtered_graph( + adapter, memory_fragment_filter + ) + import time start_time = time.time() - - # Determine projection strategy - if node_type is not None and node_name not in [None, [], ""]: - nodes_data, edges_data = await adapter.get_nodeset_subgraph( - node_type=node_type, node_name=node_name - ) - if not nodes_data or not edges_data: - raise EntityNotFoundError( - message="Nodeset does not exist, or empty nodetes projected from the database." - ) - elif len(memory_fragment_filter) == 0: - nodes_data, edges_data = await adapter.get_graph_data() - if not nodes_data or not edges_data: - raise EntityNotFoundError(message="Empty graph projected from the database.") - else: - nodes_data, edges_data = await adapter.get_filtered_graph_data( - attribute_filters=memory_fragment_filter - ) - if not nodes_data or not edges_data: - raise EntityNotFoundError( - message="Empty filtered graph projected from the database." - ) - # Process nodes for node_id, properties in nodes_data: node_attributes = {key: properties.get(key) for key in node_properties_to_project} - self.add_node(Node(str(node_id), node_attributes, dimension=node_dimension)) + self.add_node( + Node( + str(node_id), + node_attributes, + dimension=node_dimension, + node_penalty=triplet_distance_penalty, + ) + ) # Process edges for source_id, target_id, relationship_type, properties in edges_data: @@ -118,6 +179,7 @@ class CogneeGraph(CogneeAbstractGraph): attributes=edge_attributes, directed=directed, dimension=edge_dimension, + edge_penalty=triplet_distance_penalty, ) self.add_edge(edge) diff --git a/cognee/modules/graph/cognee_graph/CogneeGraphElements.py b/cognee/modules/graph/cognee_graph/CogneeGraphElements.py index 0ca9c4fb9..62ef8d9fd 100644 --- a/cognee/modules/graph/cognee_graph/CogneeGraphElements.py +++ b/cognee/modules/graph/cognee_graph/CogneeGraphElements.py @@ -20,13 +20,17 @@ class Node: status: np.ndarray def __init__( - self, node_id: str, attributes: Optional[Dict[str, Any]] = None, dimension: int = 1 + self, + node_id: str, + attributes: Optional[Dict[str, Any]] = None, + dimension: int = 1, + node_penalty: float = 3.5, ): if dimension <= 0: raise InvalidDimensionsError() self.id = node_id self.attributes = attributes if attributes is not None else {} - self.attributes["vector_distance"] = float("inf") + self.attributes["vector_distance"] = node_penalty self.skeleton_neighbours = [] self.skeleton_edges = [] self.status = np.ones(dimension, dtype=int) @@ -105,13 +109,14 @@ class Edge: attributes: Optional[Dict[str, Any]] = None, directed: bool = True, dimension: int = 1, + edge_penalty: float = 3.5, ): if dimension <= 0: raise InvalidDimensionsError() self.node1 = node1 self.node2 = node2 self.attributes = attributes if attributes is not None else {} - self.attributes["vector_distance"] = float("inf") + self.attributes["vector_distance"] = edge_penalty self.directed = directed self.status = np.ones(dimension, dtype=int) diff --git a/cognee/modules/notebooks/operations/run_in_local_sandbox.py b/cognee/modules/notebooks/operations/run_in_local_sandbox.py index 071deafb7..46499186e 100644 --- a/cognee/modules/notebooks/operations/run_in_local_sandbox.py +++ b/cognee/modules/notebooks/operations/run_in_local_sandbox.py @@ -2,6 +2,8 @@ import io import sys import traceback +import cognee + def wrap_in_async_handler(user_code: str) -> str: return ( @@ -34,6 +36,7 @@ def run_in_local_sandbox(code, environment=None, loop=None): environment["print"] = customPrintFunction environment["running_loop"] = loop + environment["cognee"] = cognee try: exec(code, environment) diff --git a/cognee/modules/ontology/rdf_xml/RDFLibOntologyResolver.py b/cognee/modules/ontology/rdf_xml/RDFLibOntologyResolver.py index 45e32936a..34d7a946a 100644 --- a/cognee/modules/ontology/rdf_xml/RDFLibOntologyResolver.py +++ b/cognee/modules/ontology/rdf_xml/RDFLibOntologyResolver.py @@ -2,7 +2,7 @@ import os import difflib from cognee.shared.logging_utils import get_logger from collections import deque -from typing import List, Tuple, Dict, Optional, Any, Union +from typing import List, Tuple, Dict, Optional, Any, Union, IO from rdflib import Graph, URIRef, RDF, RDFS, OWL from cognee.modules.ontology.exceptions import ( @@ -26,44 +26,76 @@ class RDFLibOntologyResolver(BaseOntologyResolver): def __init__( self, - ontology_file: Optional[Union[str, List[str]]] = None, + ontology_file: Optional[Union[str, List[str], IO, List[IO]]] = None, matching_strategy: Optional[MatchingStrategy] = None, ) -> None: super().__init__(matching_strategy) self.ontology_file = ontology_file try: - files_to_load = [] + self.graph = None if ontology_file is not None: - if isinstance(ontology_file, str): + files_to_load = [] + file_objects = [] + + if hasattr(ontology_file, "read"): + file_objects = [ontology_file] + elif isinstance(ontology_file, str): files_to_load = [ontology_file] elif isinstance(ontology_file, list): - files_to_load = ontology_file + if all(hasattr(item, "read") for item in ontology_file): + file_objects = ontology_file + else: + files_to_load = ontology_file else: raise ValueError( - f"ontology_file must be a string, list of strings, or None. Got: {type(ontology_file)}" + f"ontology_file must be a string, list of strings, file-like object, list of file-like objects, or None. Got: {type(ontology_file)}" ) - if files_to_load: - self.graph = Graph() - loaded_files = [] - for file_path in files_to_load: - if os.path.exists(file_path): - self.graph.parse(file_path) - loaded_files.append(file_path) - logger.info("Ontology loaded successfully from file: %s", file_path) - else: - logger.warning( - "Ontology file '%s' not found. Skipping this file.", - file_path, + if file_objects: + self.graph = Graph() + loaded_objects = [] + for file_obj in file_objects: + try: + content = file_obj.read() + self.graph.parse(data=content, format="xml") + loaded_objects.append(file_obj) + logger.info("Ontology loaded successfully from file object") + except Exception as e: + logger.warning("Failed to parse ontology file object: %s", str(e)) + + if not loaded_objects: + logger.info( + "No valid ontology file objects found. No owl ontology will be attached to the graph." ) + self.graph = None + else: + logger.info("Total ontology file objects loaded: %d", len(loaded_objects)) - if not loaded_files: - logger.info( - "No valid ontology files found. No owl ontology will be attached to the graph." - ) - self.graph = None + elif files_to_load: + self.graph = Graph() + loaded_files = [] + for file_path in files_to_load: + if os.path.exists(file_path): + self.graph.parse(file_path) + loaded_files.append(file_path) + logger.info("Ontology loaded successfully from file: %s", file_path) + else: + logger.warning( + "Ontology file '%s' not found. Skipping this file.", + file_path, + ) + + if not loaded_files: + logger.info( + "No valid ontology files found. No owl ontology will be attached to the graph." + ) + self.graph = None + else: + logger.info("Total ontology files loaded: %d", len(loaded_files)) else: - logger.info("Total ontology files loaded: %d", len(loaded_files)) + logger.info( + "No ontology file provided. No owl ontology will be attached to the graph." + ) else: logger.info( "No ontology file provided. No owl ontology will be attached to the graph." diff --git a/cognee/modules/retrieval/graph_completion_context_extension_retriever.py b/cognee/modules/retrieval/graph_completion_context_extension_retriever.py index b07d11fd2..fc49a139b 100644 --- a/cognee/modules/retrieval/graph_completion_context_extension_retriever.py +++ b/cognee/modules/retrieval/graph_completion_context_extension_retriever.py @@ -39,6 +39,8 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever): node_type: Optional[Type] = None, node_name: Optional[List[str]] = None, save_interaction: bool = False, + wide_search_top_k: Optional[int] = 100, + triplet_distance_penalty: Optional[float] = 3.5, ): super().__init__( user_prompt_path=user_prompt_path, @@ -48,6 +50,8 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever): node_name=node_name, save_interaction=save_interaction, system_prompt=system_prompt, + wide_search_top_k=wide_search_top_k, + triplet_distance_penalty=triplet_distance_penalty, ) async def get_completion( diff --git a/cognee/modules/retrieval/graph_completion_cot_retriever.py b/cognee/modules/retrieval/graph_completion_cot_retriever.py index eb8f502cb..70fcb6cdb 100644 --- a/cognee/modules/retrieval/graph_completion_cot_retriever.py +++ b/cognee/modules/retrieval/graph_completion_cot_retriever.py @@ -65,6 +65,8 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever): node_type: Optional[Type] = None, node_name: Optional[List[str]] = None, save_interaction: bool = False, + wide_search_top_k: Optional[int] = 100, + triplet_distance_penalty: Optional[float] = 3.5, ): super().__init__( user_prompt_path=user_prompt_path, @@ -74,6 +76,8 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever): node_type=node_type, node_name=node_name, save_interaction=save_interaction, + wide_search_top_k=wide_search_top_k, + triplet_distance_penalty=triplet_distance_penalty, ) self.validation_system_prompt_path = validation_system_prompt_path self.validation_user_prompt_path = validation_user_prompt_path diff --git a/cognee/modules/retrieval/graph_completion_retriever.py b/cognee/modules/retrieval/graph_completion_retriever.py index df77a11ac..89e9e47ce 100644 --- a/cognee/modules/retrieval/graph_completion_retriever.py +++ b/cognee/modules/retrieval/graph_completion_retriever.py @@ -47,6 +47,8 @@ class GraphCompletionRetriever(BaseGraphRetriever): node_type: Optional[Type] = None, node_name: Optional[List[str]] = None, save_interaction: bool = False, + wide_search_top_k: Optional[int] = 100, + triplet_distance_penalty: Optional[float] = 3.5, ): """Initialize retriever with prompt paths and search parameters.""" self.save_interaction = save_interaction @@ -54,8 +56,10 @@ class GraphCompletionRetriever(BaseGraphRetriever): self.system_prompt_path = system_prompt_path self.system_prompt = system_prompt self.top_k = top_k if top_k is not None else 5 + self.wide_search_top_k = wide_search_top_k self.node_type = node_type self.node_name = node_name + self.triplet_distance_penalty = triplet_distance_penalty async def resolve_edges_to_text(self, retrieved_edges: list) -> str: """ @@ -105,6 +109,8 @@ class GraphCompletionRetriever(BaseGraphRetriever): collections=vector_index_collections or None, node_type=self.node_type, node_name=self.node_name, + wide_search_top_k=self.wide_search_top_k, + triplet_distance_penalty=self.triplet_distance_penalty, ) return found_triplets @@ -141,6 +147,10 @@ class GraphCompletionRetriever(BaseGraphRetriever): return triplets + async def convert_retrieved_objects_to_context(self, triplets: List[Edge]): + context = await self.resolve_edges_to_text(triplets) + return context + async def get_completion( self, query: str, diff --git a/cognee/modules/retrieval/graph_summary_completion_retriever.py b/cognee/modules/retrieval/graph_summary_completion_retriever.py index 051f39b22..e31ad126e 100644 --- a/cognee/modules/retrieval/graph_summary_completion_retriever.py +++ b/cognee/modules/retrieval/graph_summary_completion_retriever.py @@ -26,6 +26,8 @@ class GraphSummaryCompletionRetriever(GraphCompletionRetriever): node_type: Optional[Type] = None, node_name: Optional[List[str]] = None, save_interaction: bool = False, + wide_search_top_k: Optional[int] = 100, + triplet_distance_penalty: Optional[float] = 3.5, ): """Initialize retriever with default prompt paths and search parameters.""" super().__init__( @@ -36,6 +38,8 @@ class GraphSummaryCompletionRetriever(GraphCompletionRetriever): node_name=node_name, save_interaction=save_interaction, system_prompt=system_prompt, + wide_search_top_k=wide_search_top_k, + triplet_distance_penalty=triplet_distance_penalty, ) self.summarize_prompt_path = summarize_prompt_path diff --git a/cognee/modules/retrieval/temporal_retriever.py b/cognee/modules/retrieval/temporal_retriever.py index f3da02c15..87d2ab009 100644 --- a/cognee/modules/retrieval/temporal_retriever.py +++ b/cognee/modules/retrieval/temporal_retriever.py @@ -47,6 +47,8 @@ class TemporalRetriever(GraphCompletionRetriever): top_k: Optional[int] = 5, node_type: Optional[Type] = None, node_name: Optional[List[str]] = None, + wide_search_top_k: Optional[int] = 100, + triplet_distance_penalty: Optional[float] = 3.5, ): super().__init__( user_prompt_path=user_prompt_path, @@ -54,6 +56,8 @@ class TemporalRetriever(GraphCompletionRetriever): top_k=top_k, node_type=node_type, node_name=node_name, + wide_search_top_k=wide_search_top_k, + triplet_distance_penalty=triplet_distance_penalty, ) self.user_prompt_path = user_prompt_path self.system_prompt_path = system_prompt_path diff --git a/cognee/modules/retrieval/utils/brute_force_triplet_search.py b/cognee/modules/retrieval/utils/brute_force_triplet_search.py index f8bdbb97d..2f8a545f7 100644 --- a/cognee/modules/retrieval/utils/brute_force_triplet_search.py +++ b/cognee/modules/retrieval/utils/brute_force_triplet_search.py @@ -58,6 +58,8 @@ async def get_memory_fragment( properties_to_project: Optional[List[str]] = None, node_type: Optional[Type] = None, node_name: Optional[List[str]] = None, + relevant_ids_to_filter: Optional[List[str]] = None, + triplet_distance_penalty: Optional[float] = 3.5, ) -> CogneeGraph: """Creates and initializes a CogneeGraph memory fragment with optional property projections.""" if properties_to_project is None: @@ -74,6 +76,8 @@ async def get_memory_fragment( edge_properties_to_project=["relationship_name", "edge_text"], node_type=node_type, node_name=node_name, + relevant_ids_to_filter=relevant_ids_to_filter, + triplet_distance_penalty=triplet_distance_penalty, ) except EntityNotFoundError: @@ -95,6 +99,8 @@ async def brute_force_triplet_search( memory_fragment: Optional[CogneeGraph] = None, node_type: Optional[Type] = None, node_name: Optional[List[str]] = None, + wide_search_top_k: Optional[int] = 100, + triplet_distance_penalty: Optional[float] = 3.5, ) -> List[Edge]: """ Performs a brute force search to retrieve the top triplets from the graph. @@ -107,6 +113,8 @@ async def brute_force_triplet_search( memory_fragment (Optional[CogneeGraph]): Existing memory fragment to reuse. node_type: node type to filter node_name: node name to filter + wide_search_top_k (Optional[int]): Number of initial elements to retrieve from collections + triplet_distance_penalty (Optional[float]): Default distance penalty in graph projection Returns: list: The top triplet results. @@ -116,10 +124,10 @@ async def brute_force_triplet_search( if top_k <= 0: raise ValueError("top_k must be a positive integer.") - if memory_fragment is None: - memory_fragment = await get_memory_fragment( - properties_to_project, node_type=node_type, node_name=node_name - ) + # Setting wide search limit based on the parameters + non_global_search = node_name is None + + wide_search_limit = wide_search_top_k if non_global_search else None if collections is None: collections = [ @@ -140,7 +148,7 @@ async def brute_force_triplet_search( async def search_in_collection(collection_name: str): try: return await vector_engine.search( - collection_name=collection_name, query_vector=query_vector, limit=None + collection_name=collection_name, query_vector=query_vector, limit=wide_search_limit ) except CollectionNotFoundError: return [] @@ -156,15 +164,38 @@ async def brute_force_triplet_search( return [] # Final statistics - projection_time = time.time() - start_time + vector_collection_search_time = time.time() - start_time logger.info( - f"Vector collection retrieval completed: Retrieved distances from {sum(1 for res in results if res)} collections in {projection_time:.2f}s" + f"Vector collection retrieval completed: Retrieved distances from {sum(1 for res in results if res)} collections in {vector_collection_search_time:.2f}s" ) node_distances = {collection: result for collection, result in zip(collections, results)} edge_distances = node_distances.get("EdgeType_relationship_name", None) + if wide_search_limit is not None: + relevant_ids_to_filter = list( + { + str(getattr(scored_node, "id")) + for collection_name, score_collection in node_distances.items() + if collection_name != "EdgeType_relationship_name" + and isinstance(score_collection, (list, tuple)) + for scored_node in score_collection + if getattr(scored_node, "id", None) + } + ) + else: + relevant_ids_to_filter = None + + if memory_fragment is None: + memory_fragment = await get_memory_fragment( + properties_to_project=properties_to_project, + node_type=node_type, + node_name=node_name, + relevant_ids_to_filter=relevant_ids_to_filter, + triplet_distance_penalty=triplet_distance_penalty, + ) + await memory_fragment.map_vector_distances_to_graph_nodes(node_distances=node_distances) await memory_fragment.map_vector_distances_to_graph_edges( vector_engine=vector_engine, query_vector=query_vector, edge_distances=edge_distances diff --git a/cognee/modules/search/methods/get_search_type_tools.py b/cognee/modules/search/methods/get_search_type_tools.py index 72e2db89a..165ec379b 100644 --- a/cognee/modules/search/methods/get_search_type_tools.py +++ b/cognee/modules/search/methods/get_search_type_tools.py @@ -37,6 +37,8 @@ async def get_search_type_tools( node_name: Optional[List[str]] = None, save_interaction: bool = False, last_k: Optional[int] = None, + wide_search_top_k: Optional[int] = 100, + triplet_distance_penalty: Optional[float] = 3.5, ) -> list: search_tasks: dict[SearchType, List[Callable]] = { SearchType.SUMMARIES: [ @@ -67,6 +69,8 @@ async def get_search_type_tools( node_name=node_name, save_interaction=save_interaction, system_prompt=system_prompt, + wide_search_top_k=wide_search_top_k, + triplet_distance_penalty=triplet_distance_penalty, ).get_completion, GraphCompletionRetriever( system_prompt_path=system_prompt_path, @@ -75,6 +79,8 @@ async def get_search_type_tools( node_name=node_name, save_interaction=save_interaction, system_prompt=system_prompt, + wide_search_top_k=wide_search_top_k, + triplet_distance_penalty=triplet_distance_penalty, ).get_context, ], SearchType.GRAPH_COMPLETION_COT: [ @@ -85,6 +91,8 @@ async def get_search_type_tools( node_name=node_name, save_interaction=save_interaction, system_prompt=system_prompt, + wide_search_top_k=wide_search_top_k, + triplet_distance_penalty=triplet_distance_penalty, ).get_completion, GraphCompletionCotRetriever( system_prompt_path=system_prompt_path, @@ -93,6 +101,8 @@ async def get_search_type_tools( node_name=node_name, save_interaction=save_interaction, system_prompt=system_prompt, + wide_search_top_k=wide_search_top_k, + triplet_distance_penalty=triplet_distance_penalty, ).get_context, ], SearchType.GRAPH_COMPLETION_CONTEXT_EXTENSION: [ @@ -103,6 +113,8 @@ async def get_search_type_tools( node_name=node_name, save_interaction=save_interaction, system_prompt=system_prompt, + wide_search_top_k=wide_search_top_k, + triplet_distance_penalty=triplet_distance_penalty, ).get_completion, GraphCompletionContextExtensionRetriever( system_prompt_path=system_prompt_path, @@ -111,6 +123,8 @@ async def get_search_type_tools( node_name=node_name, save_interaction=save_interaction, system_prompt=system_prompt, + wide_search_top_k=wide_search_top_k, + triplet_distance_penalty=triplet_distance_penalty, ).get_context, ], SearchType.GRAPH_SUMMARY_COMPLETION: [ @@ -121,6 +135,8 @@ async def get_search_type_tools( node_name=node_name, save_interaction=save_interaction, system_prompt=system_prompt, + wide_search_top_k=wide_search_top_k, + triplet_distance_penalty=triplet_distance_penalty, ).get_completion, GraphSummaryCompletionRetriever( system_prompt_path=system_prompt_path, @@ -129,6 +145,8 @@ async def get_search_type_tools( node_name=node_name, save_interaction=save_interaction, system_prompt=system_prompt, + wide_search_top_k=wide_search_top_k, + triplet_distance_penalty=triplet_distance_penalty, ).get_context, ], SearchType.CODE: [ @@ -145,8 +163,16 @@ async def get_search_type_tools( ], SearchType.FEEDBACK: [UserQAFeedback(last_k=last_k).add_feedback], SearchType.TEMPORAL: [ - TemporalRetriever(top_k=top_k).get_completion, - TemporalRetriever(top_k=top_k).get_context, + TemporalRetriever( + top_k=top_k, + wide_search_top_k=wide_search_top_k, + triplet_distance_penalty=triplet_distance_penalty, + ).get_completion, + TemporalRetriever( + top_k=top_k, + wide_search_top_k=wide_search_top_k, + triplet_distance_penalty=triplet_distance_penalty, + ).get_context, ], SearchType.CHUNKS_LEXICAL: ( lambda _r=JaccardChunksRetriever(top_k=top_k): [ diff --git a/cognee/modules/search/methods/no_access_control_search.py b/cognee/modules/search/methods/no_access_control_search.py index fcb02da46..3a703bbc9 100644 --- a/cognee/modules/search/methods/no_access_control_search.py +++ b/cognee/modules/search/methods/no_access_control_search.py @@ -24,6 +24,8 @@ async def no_access_control_search( last_k: Optional[int] = None, only_context: bool = False, session_id: Optional[str] = None, + wide_search_top_k: Optional[int] = 100, + triplet_distance_penalty: Optional[float] = 3.5, ) -> Tuple[Any, Union[str, List[Edge]], List[Dataset]]: search_tools = await get_search_type_tools( query_type=query_type, @@ -35,6 +37,8 @@ async def no_access_control_search( node_name=node_name, save_interaction=save_interaction, last_k=last_k, + wide_search_top_k=wide_search_top_k, + triplet_distance_penalty=triplet_distance_penalty, ) graph_engine = await get_graph_engine() is_empty = await graph_engine.is_empty() diff --git a/cognee/modules/search/methods/search.py b/cognee/modules/search/methods/search.py index b4278424b..9f180d607 100644 --- a/cognee/modules/search/methods/search.py +++ b/cognee/modules/search/methods/search.py @@ -47,6 +47,8 @@ async def search( only_context: bool = False, use_combined_context: bool = False, session_id: Optional[str] = None, + wide_search_top_k: Optional[int] = 100, + triplet_distance_penalty: Optional[float] = 3.5, ) -> Union[CombinedSearchResult, List[SearchResult]]: """ @@ -90,6 +92,8 @@ async def search( only_context=only_context, use_combined_context=use_combined_context, session_id=session_id, + wide_search_top_k=wide_search_top_k, + triplet_distance_penalty=triplet_distance_penalty, ) else: search_results = [ @@ -105,6 +109,8 @@ async def search( last_k=last_k, only_context=only_context, session_id=session_id, + wide_search_top_k=wide_search_top_k, + triplet_distance_penalty=triplet_distance_penalty, ) ] @@ -219,6 +225,8 @@ async def authorized_search( only_context: bool = False, use_combined_context: bool = False, session_id: Optional[str] = None, + wide_search_top_k: Optional[int] = 100, + triplet_distance_penalty: Optional[float] = 3.5, ) -> Union[ Tuple[Any, Union[List[Edge], str], List[Dataset]], List[Tuple[Any, Union[List[Edge], str], List[Dataset]]], @@ -246,6 +254,8 @@ async def authorized_search( last_k=last_k, only_context=True, session_id=session_id, + wide_search_top_k=wide_search_top_k, + triplet_distance_penalty=triplet_distance_penalty, ) context = {} @@ -267,6 +277,8 @@ async def authorized_search( node_name=node_name, save_interaction=save_interaction, last_k=last_k, + wide_search_top_k=wide_search_top_k, + triplet_distance_penalty=triplet_distance_penalty, ) search_tools = specific_search_tools if len(search_tools) == 2: @@ -306,6 +318,7 @@ async def authorized_search( last_k=last_k, only_context=only_context, session_id=session_id, + wide_search_top_k=wide_search_top_k, ) return search_results @@ -325,6 +338,8 @@ async def search_in_datasets_context( only_context: bool = False, context: Optional[Any] = None, session_id: Optional[str] = None, + wide_search_top_k: Optional[int] = 100, + triplet_distance_penalty: Optional[float] = 3.5, ) -> List[Tuple[Any, Union[str, List[Edge]], List[Dataset]]]: """ Searches all provided datasets and handles setting up of appropriate database context based on permissions. @@ -345,6 +360,8 @@ async def search_in_datasets_context( only_context: bool = False, context: Optional[Any] = None, session_id: Optional[str] = None, + wide_search_top_k: Optional[int] = 100, + triplet_distance_penalty: Optional[float] = 3.5, ) -> Tuple[Any, Union[str, List[Edge]], List[Dataset]]: # Set database configuration in async context for each dataset user has access for await set_database_global_context_variables(dataset.id, dataset.owner_id) @@ -378,6 +395,8 @@ async def search_in_datasets_context( node_name=node_name, save_interaction=save_interaction, last_k=last_k, + wide_search_top_k=wide_search_top_k, + triplet_distance_penalty=triplet_distance_penalty, ) search_tools = specific_search_tools if len(search_tools) == 2: @@ -413,6 +432,8 @@ async def search_in_datasets_context( only_context=only_context, context=context, session_id=session_id, + wide_search_top_k=wide_search_top_k, + triplet_distance_penalty=triplet_distance_penalty, ) ) diff --git a/cognee/tasks/chunks/__init__.py b/cognee/tasks/chunks/__init__.py index 22ce96be8..37d4de73e 100644 --- a/cognee/tasks/chunks/__init__.py +++ b/cognee/tasks/chunks/__init__.py @@ -1,4 +1,5 @@ from .chunk_by_word import chunk_by_word from .chunk_by_sentence import chunk_by_sentence from .chunk_by_paragraph import chunk_by_paragraph +from .chunk_by_row import chunk_by_row from .remove_disconnected_chunks import remove_disconnected_chunks diff --git a/cognee/tasks/chunks/chunk_by_row.py b/cognee/tasks/chunks/chunk_by_row.py new file mode 100644 index 000000000..8daf13689 --- /dev/null +++ b/cognee/tasks/chunks/chunk_by_row.py @@ -0,0 +1,94 @@ +from typing import Any, Dict, Iterator +from uuid import NAMESPACE_OID, uuid5 + +from cognee.infrastructure.databases.vector.embeddings import get_embedding_engine + + +def _get_pair_size(pair_text: str) -> int: + """ + Calculate the size of a given text in terms of tokens. + + If an embedding engine's tokenizer is available, count the tokens for the provided word. + If the tokenizer is not available, assume the word counts as one token. + + Parameters: + ----------- + + - pair_text (str): The key:value pair text for which the token size is to be calculated. + + Returns: + -------- + + - int: The number of tokens representing the text, typically an integer, depending + on the tokenizer's output. + """ + embedding_engine = get_embedding_engine() + if embedding_engine.tokenizer: + return embedding_engine.tokenizer.count_tokens(pair_text) + else: + return 3 + + +def chunk_by_row( + data: str, + max_chunk_size, +) -> Iterator[Dict[str, Any]]: + """ + Chunk the input text by row while enabling exact text reconstruction. + + This function divides the given text data into smaller chunks on a line-by-line basis, + ensuring that the size of each chunk is less than or equal to the specified maximum + chunk size. It guarantees that when the generated chunks are concatenated, they + reproduce the original text accurately. The tokenization process is handled by + adapters compatible with the vector engine's embedding model. + + Parameters: + ----------- + + - data (str): The input text to be chunked. + - max_chunk_size: The maximum allowed size for each chunk, in terms of tokens or + words. + """ + current_chunk_list = [] + chunk_index = 0 + current_chunk_size = 0 + + lines = data.split("\n\n") + for line in lines: + pairs_text = line.split(", ") + + for pair_text in pairs_text: + pair_size = _get_pair_size(pair_text) + if current_chunk_size > 0 and (current_chunk_size + pair_size > max_chunk_size): + # Yield current cut chunk + current_chunk = ", ".join(current_chunk_list) + chunk_dict = { + "text": current_chunk, + "chunk_size": current_chunk_size, + "chunk_id": uuid5(NAMESPACE_OID, current_chunk), + "chunk_index": chunk_index, + "cut_type": "row_cut", + } + + yield chunk_dict + + # Start new chunk with current pair text + current_chunk_list = [] + current_chunk_size = 0 + chunk_index += 1 + + current_chunk_list.append(pair_text) + current_chunk_size += pair_size + + # Yield row chunk + current_chunk = ", ".join(current_chunk_list) + if current_chunk: + chunk_dict = { + "text": current_chunk, + "chunk_size": current_chunk_size, + "chunk_id": uuid5(NAMESPACE_OID, current_chunk), + "chunk_index": chunk_index, + "cut_type": "row_end", + } + + yield chunk_dict diff --git a/cognee/tasks/documents/classify_documents.py b/cognee/tasks/documents/classify_documents.py index 9fa512906..e4f13ebd1 100644 --- a/cognee/tasks/documents/classify_documents.py +++ b/cognee/tasks/documents/classify_documents.py @@ -7,6 +7,7 @@ from cognee.modules.data.processing.document_types import ( ImageDocument, TextDocument, UnstructuredDocument, + CsvDocument, ) from cognee.modules.engine.models.node_set import NodeSet from cognee.modules.engine.utils.generate_node_id import generate_node_id @@ -15,6 +16,7 @@ from cognee.tasks.documents.exceptions import WrongDataDocumentInputError EXTENSION_TO_DOCUMENT_CLASS = { "pdf": PdfDocument, # Text documents "txt": TextDocument, + "csv": CsvDocument, "docx": UnstructuredDocument, "doc": UnstructuredDocument, "odt": UnstructuredDocument, diff --git a/cognee/tests/integration/documents/CsvDocument_test.py b/cognee/tests/integration/documents/CsvDocument_test.py new file mode 100644 index 000000000..421bb81bd --- /dev/null +++ b/cognee/tests/integration/documents/CsvDocument_test.py @@ -0,0 +1,70 @@ +import os +import sys +import uuid +import pytest +import pathlib +from unittest.mock import patch + +from cognee.modules.chunking.CsvChunker import CsvChunker +from cognee.modules.data.processing.document_types.CsvDocument import CsvDocument +from cognee.tests.integration.documents.AudioDocument_test import mock_get_embedding_engine +from cognee.tests.integration.documents.async_gen_zip import async_gen_zip + +chunk_by_row_module = sys.modules.get("cognee.tasks.chunks.chunk_by_row") + + +GROUND_TRUTH = { + "chunk_size_10": [ + {"token_count": 9, "len_text": 26, "cut_type": "row_cut", "chunk_index": 0}, + {"token_count": 6, "len_text": 29, "cut_type": "row_end", "chunk_index": 1}, + {"token_count": 9, "len_text": 25, "cut_type": "row_cut", "chunk_index": 2}, + {"token_count": 6, "len_text": 30, "cut_type": "row_end", "chunk_index": 3}, + ], + "chunk_size_128": [ + {"token_count": 15, "len_text": 57, "cut_type": "row_end", "chunk_index": 0}, + {"token_count": 15, "len_text": 57, "cut_type": "row_end", "chunk_index": 1}, + ], +} + + +@pytest.mark.parametrize( + "input_file,chunk_size", + [("example_with_header.csv", 10), ("example_with_header.csv", 128)], +) +@patch.object(chunk_by_row_module, "get_embedding_engine", side_effect=mock_get_embedding_engine) +@pytest.mark.asyncio +async def test_CsvDocument(mock_engine, input_file, chunk_size): + # Define file paths of test data + csv_file_path = os.path.join( + pathlib.Path(__file__).parent.parent.parent, + "test_data", + input_file, + ) + + # Define test documents + csv_document = CsvDocument( + id=uuid.uuid4(), + name="example_with_header.csv", + raw_data_location=csv_file_path, + external_metadata="", + mime_type="text/csv", + ) + + # TEST CSV + ground_truth_key = f"chunk_size_{chunk_size}" + async for ground_truth, row_data in async_gen_zip( + GROUND_TRUTH[ground_truth_key], + csv_document.read(chunker_cls=CsvChunker, max_chunk_size=chunk_size), + ): + assert ground_truth["token_count"] == row_data.chunk_size, ( + f'{ground_truth["token_count"] = } != {row_data.chunk_size = }' + ) + assert ground_truth["len_text"] == len(row_data.text), ( + f'{ground_truth["len_text"] = } != {len(row_data.text) = }' + ) + assert ground_truth["cut_type"] == row_data.cut_type, ( + f'{ground_truth["cut_type"] = } != {row_data.cut_type = }' + ) + assert ground_truth["chunk_index"] == row_data.chunk_index, ( + f'{ground_truth["chunk_index"] = } != {row_data.chunk_index = }' + ) diff --git a/cognee/tests/integration/web_url_crawler/test_default_url_crawler.py b/cognee/tests/integration/web_url_crawler/test_default_url_crawler.py index 156cc87a4..af2595b14 100644 --- a/cognee/tests/integration/web_url_crawler/test_default_url_crawler.py +++ b/cognee/tests/integration/web_url_crawler/test_default_url_crawler.py @@ -5,7 +5,7 @@ from cognee.tasks.web_scraper import DefaultUrlCrawler @pytest.mark.asyncio async def test_fetch(): crawler = DefaultUrlCrawler() - url = "https://en.wikipedia.org/wiki/Large_language_model" + url = "http://example.com/" results = await crawler.fetch_urls(url) assert len(results) == 1 assert isinstance(results, dict) diff --git a/cognee/tests/integration/web_url_crawler/test_tavily_crawler.py b/cognee/tests/integration/web_url_crawler/test_tavily_crawler.py index 946ce8378..5db9b58ce 100644 --- a/cognee/tests/integration/web_url_crawler/test_tavily_crawler.py +++ b/cognee/tests/integration/web_url_crawler/test_tavily_crawler.py @@ -11,7 +11,7 @@ skip_in_ci = pytest.mark.skipif( @skip_in_ci @pytest.mark.asyncio async def test_fetch(): - url = "https://en.wikipedia.org/wiki/Large_language_model" + url = "http://example.com/" results = await fetch_with_tavily(url) assert isinstance(results, dict) assert len(results) == 1 diff --git a/cognee/tests/integration/web_url_crawler/test_url_adding_e2e.py b/cognee/tests/integration/web_url_crawler/test_url_adding_e2e.py index d91b075aa..200f40a94 100644 --- a/cognee/tests/integration/web_url_crawler/test_url_adding_e2e.py +++ b/cognee/tests/integration/web_url_crawler/test_url_adding_e2e.py @@ -14,9 +14,7 @@ async def test_url_saves_as_html_file(): await cognee.prune.prune_system(metadata=True) try: - original_file_path = await save_data_item_to_storage( - "https://en.wikipedia.org/wiki/Large_language_model" - ) + original_file_path = await save_data_item_to_storage("http://example.com/") file_path = get_data_file_path(original_file_path) assert file_path.endswith(".html") file = Path(file_path) @@ -44,9 +42,7 @@ async def test_saved_html_is_valid(): await cognee.prune.prune_system(metadata=True) try: - original_file_path = await save_data_item_to_storage( - "https://en.wikipedia.org/wiki/Large_language_model" - ) + original_file_path = await save_data_item_to_storage("http://example.com/") file_path = get_data_file_path(original_file_path) content = Path(file_path).read_text() @@ -72,7 +68,7 @@ async def test_add_url(): await cognee.prune.prune_data() await cognee.prune.prune_system(metadata=True) - await cognee.add("https://en.wikipedia.org/wiki/Large_language_model") + await cognee.add("http://example.com/") skip_in_ci = pytest.mark.skipif( @@ -88,7 +84,7 @@ async def test_add_url_with_tavily(): await cognee.prune.prune_data() await cognee.prune.prune_system(metadata=True) - await cognee.add("https://en.wikipedia.org/wiki/Large_language_model") + await cognee.add("http://example.com/") @pytest.mark.asyncio @@ -98,7 +94,7 @@ async def test_add_url_without_incremental_loading(): try: await cognee.add( - "https://en.wikipedia.org/wiki/Large_language_model", + "http://example.com/", incremental_loading=False, ) except Exception as e: @@ -112,7 +108,7 @@ async def test_add_url_with_incremental_loading(): try: await cognee.add( - "https://en.wikipedia.org/wiki/Large_language_model", + "http://example.com/", incremental_loading=True, ) except Exception as e: @@ -125,7 +121,7 @@ async def test_add_url_can_define_preferred_loader_as_list_of_str(): await cognee.prune.prune_system(metadata=True) await cognee.add( - "https://en.wikipedia.org/wiki/Large_language_model", + "http://example.com/", preferred_loaders=["beautiful_soup_loader"], ) @@ -144,7 +140,7 @@ async def test_add_url_with_extraction_rules(): try: await cognee.add( - "https://en.wikipedia.org/wiki/Large_language_model", + "http://example.com/", preferred_loaders={"beautiful_soup_loader": {"extraction_rules": extraction_rules}}, ) except Exception as e: @@ -163,9 +159,7 @@ async def test_loader_is_none_by_default(): } try: - original_file_path = await save_data_item_to_storage( - "https://en.wikipedia.org/wiki/Large_language_model" - ) + original_file_path = await save_data_item_to_storage("http://example.com/") file_path = get_data_file_path(original_file_path) assert file_path.endswith(".html") file = Path(file_path) @@ -196,9 +190,7 @@ async def test_beautiful_soup_loader_is_selected_loader_if_preferred_loader_prov } try: - original_file_path = await save_data_item_to_storage( - "https://en.wikipedia.org/wiki/Large_language_model" - ) + original_file_path = await save_data_item_to_storage("http://example.com/") file_path = get_data_file_path(original_file_path) assert file_path.endswith(".html") file = Path(file_path) @@ -225,9 +217,7 @@ async def test_beautiful_soup_loader_works_with_and_without_arguments(): await cognee.prune.prune_system(metadata=True) try: - original_file_path = await save_data_item_to_storage( - "https://en.wikipedia.org/wiki/Large_language_model" - ) + original_file_path = await save_data_item_to_storage("http://example.com/") file_path = get_data_file_path(original_file_path) assert file_path.endswith(".html") file = Path(file_path) @@ -263,9 +253,7 @@ async def test_beautiful_soup_loader_successfully_loads_file_if_required_args_pr await cognee.prune.prune_system(metadata=True) try: - original_file_path = await save_data_item_to_storage( - "https://en.wikipedia.org/wiki/Large_language_model" - ) + original_file_path = await save_data_item_to_storage("http://example.com/") file_path = get_data_file_path(original_file_path) assert file_path.endswith(".html") file = Path(file_path) @@ -302,9 +290,7 @@ async def test_beautiful_soup_loads_file_successfully(): } try: - original_file_path = await save_data_item_to_storage( - "https://en.wikipedia.org/wiki/Large_language_model" - ) + original_file_path = await save_data_item_to_storage("http://example.com/") file_path = get_data_file_path(original_file_path) assert file_path.endswith(".html") original_file = Path(file_path) diff --git a/cognee/tests/test_cognee_server_start.py b/cognee/tests/test_cognee_server_start.py index ab68a8ef1..ddffe53a4 100644 --- a/cognee/tests/test_cognee_server_start.py +++ b/cognee/tests/test_cognee_server_start.py @@ -7,6 +7,7 @@ import requests from pathlib import Path import sys import uuid +import json class TestCogneeServerStart(unittest.TestCase): @@ -90,12 +91,71 @@ class TestCogneeServerStart(unittest.TestCase): ) } - payload = {"datasets": [dataset_name]} + ontology_key = f"test_ontology_{uuid.uuid4().hex[:8]}" + payload = {"datasets": [dataset_name], "ontology_key": [ontology_key]} add_response = requests.post(url, headers=headers, data=form_data, files=file, timeout=50) if add_response.status_code not in [200, 201]: add_response.raise_for_status() + ontology_content = b""" + + + + + + + + + + + + + + + + A failure caused by physical components. + + + + + An error caused by software logic or configuration. + + + + A human being or individual. + + + + + Programmers + + + + Light Bulb + + + + Hardware Problem + + + """ + + ontology_response = requests.post( + "http://127.0.0.1:8000/api/v1/ontologies", + headers=headers, + files=[("ontology_file", ("test.owl", ontology_content, "application/xml"))], + data={ + "ontology_key": json.dumps([ontology_key]), + "description": json.dumps(["Test ontology"]), + }, + ) + self.assertEqual(ontology_response.status_code, 200) + # Cognify request url = "http://127.0.0.1:8000/api/v1/cognify" headers = { @@ -107,6 +167,29 @@ class TestCogneeServerStart(unittest.TestCase): if cognify_response.status_code not in [200, 201]: cognify_response.raise_for_status() + datasets_response = requests.get("http://127.0.0.1:8000/api/v1/datasets", headers=headers) + + datasets = datasets_response.json() + dataset_id = None + for dataset in datasets: + if dataset["name"] == dataset_name: + dataset_id = dataset["id"] + break + + graph_response = requests.get( + f"http://127.0.0.1:8000/api/v1/datasets/{dataset_id}/graph", headers=headers + ) + self.assertEqual(graph_response.status_code, 200) + + graph_data = graph_response.json() + ontology_nodes = [ + node for node in graph_data.get("nodes") if node.get("properties").get("ontology_valid") + ] + + self.assertGreater( + len(ontology_nodes), 0, "No ontology nodes found - ontology was not integrated" + ) + # TODO: Add test to verify cognify pipeline is complete before testing search # Search request diff --git a/cognee/tests/test_data/example_with_header.csv b/cognee/tests/test_data/example_with_header.csv new file mode 100644 index 000000000..dc900e5ef --- /dev/null +++ b/cognee/tests/test_data/example_with_header.csv @@ -0,0 +1,3 @@ +id,name,age,city,country +1,Eric,30,Beijing,China +2,Joe,35,Berlin,Germany diff --git a/cognee/tests/unit/api/test_ontology_endpoint.py b/cognee/tests/unit/api/test_ontology_endpoint.py new file mode 100644 index 000000000..af3a4d90e --- /dev/null +++ b/cognee/tests/unit/api/test_ontology_endpoint.py @@ -0,0 +1,272 @@ +import pytest +import uuid +from fastapi.testclient import TestClient +from unittest.mock import patch, Mock, AsyncMock +from types import SimpleNamespace +import importlib +from cognee.api.client import app + +gau_mod = importlib.import_module("cognee.modules.users.methods.get_authenticated_user") + + +@pytest.fixture +def client(): + return TestClient(app) + + +@pytest.fixture +def mock_user(): + user = Mock() + user.id = "test-user-123" + return user + + +@pytest.fixture +def mock_default_user(): + """Mock default user for testing.""" + return SimpleNamespace( + id=str(uuid.uuid4()), + email="default@example.com", + is_active=True, + tenant_id=str(uuid.uuid4()), + ) + + +@patch.object(gau_mod, "get_default_user", new_callable=AsyncMock) +def test_upload_ontology_success(mock_get_default_user, client, mock_default_user): + """Test successful ontology upload""" + import json + + mock_get_default_user.return_value = mock_default_user + ontology_content = ( + b"" + ) + unique_key = f"test_ontology_{uuid.uuid4().hex[:8]}" + + response = client.post( + "/api/v1/ontologies", + files=[("ontology_file", ("test.owl", ontology_content, "application/xml"))], + data={"ontology_key": json.dumps([unique_key]), "description": json.dumps(["Test"])}, + ) + + assert response.status_code == 200 + data = response.json() + assert data["uploaded_ontologies"][0]["ontology_key"] == unique_key + assert "uploaded_at" in data["uploaded_ontologies"][0] + + +@patch.object(gau_mod, "get_default_user", new_callable=AsyncMock) +def test_upload_ontology_invalid_file(mock_get_default_user, client, mock_default_user): + """Test 400 response for non-.owl files""" + mock_get_default_user.return_value = mock_default_user + unique_key = f"test_ontology_{uuid.uuid4().hex[:8]}" + response = client.post( + "/api/v1/ontologies", + files={"ontology_file": ("test.txt", b"not xml")}, + data={"ontology_key": unique_key}, + ) + assert response.status_code == 400 + + +@patch.object(gau_mod, "get_default_user", new_callable=AsyncMock) +def test_upload_ontology_missing_data(mock_get_default_user, client, mock_default_user): + """Test 400 response for missing file or key""" + import json + + mock_get_default_user.return_value = mock_default_user + # Missing file + response = client.post("/api/v1/ontologies", data={"ontology_key": json.dumps(["test"])}) + assert response.status_code == 400 + + # Missing key + response = client.post( + "/api/v1/ontologies", files=[("ontology_file", ("test.owl", b"xml", "application/xml"))] + ) + assert response.status_code == 400 + + +@patch.object(gau_mod, "get_default_user", new_callable=AsyncMock) +def test_upload_ontology_unauthorized(mock_get_default_user, client, mock_default_user): + """Test behavior when default user is provided (no explicit authentication)""" + import json + + unique_key = f"test_ontology_{uuid.uuid4().hex[:8]}" + mock_get_default_user.return_value = mock_default_user + response = client.post( + "/api/v1/ontologies", + files=[("ontology_file", ("test.owl", b"", "application/xml"))], + data={"ontology_key": json.dumps([unique_key])}, + ) + + # The current system provides a default user when no explicit authentication is given + # This test verifies the system works with conditional authentication + assert response.status_code == 200 + data = response.json() + assert data["uploaded_ontologies"][0]["ontology_key"] == unique_key + assert "uploaded_at" in data["uploaded_ontologies"][0] + + +@patch.object(gau_mod, "get_default_user", new_callable=AsyncMock) +def test_upload_multiple_ontologies(mock_get_default_user, client, mock_default_user): + """Test uploading multiple ontology files in single request""" + import io + + mock_get_default_user.return_value = mock_default_user + # Create mock files + file1_content = b"" + file2_content = b"" + + files = [ + ("ontology_file", ("vehicles.owl", io.BytesIO(file1_content), "application/xml")), + ("ontology_file", ("manufacturers.owl", io.BytesIO(file2_content), "application/xml")), + ] + data = { + "ontology_key": '["vehicles", "manufacturers"]', + "descriptions": '["Base vehicles", "Car manufacturers"]', + } + + response = client.post("/api/v1/ontologies", files=files, data=data) + + assert response.status_code == 200 + result = response.json() + assert "uploaded_ontologies" in result + assert len(result["uploaded_ontologies"]) == 2 + assert result["uploaded_ontologies"][0]["ontology_key"] == "vehicles" + assert result["uploaded_ontologies"][1]["ontology_key"] == "manufacturers" + + +@patch.object(gau_mod, "get_default_user", new_callable=AsyncMock) +def test_upload_endpoint_accepts_arrays(mock_get_default_user, client, mock_default_user): + """Test that upload endpoint accepts array parameters""" + import io + import json + + mock_get_default_user.return_value = mock_default_user + file_content = b"" + + files = [("ontology_file", ("single.owl", io.BytesIO(file_content), "application/xml"))] + data = { + "ontology_key": json.dumps(["single_key"]), + "descriptions": json.dumps(["Single ontology"]), + } + + response = client.post("/api/v1/ontologies", files=files, data=data) + + assert response.status_code == 200 + result = response.json() + assert result["uploaded_ontologies"][0]["ontology_key"] == "single_key" + + +@patch.object(gau_mod, "get_default_user", new_callable=AsyncMock) +def test_cognify_with_multiple_ontologies(mock_get_default_user, client, mock_default_user): + """Test cognify endpoint accepts multiple ontology keys""" + payload = { + "datasets": ["test_dataset"], + "ontology_key": ["ontology1", "ontology2"], # Array instead of string + "run_in_background": False, + } + + response = client.post("/api/v1/cognify", json=payload) + + # Should not fail due to ontology_key type + assert response.status_code in [200, 400, 409] # May fail for other reasons, not type + + +@patch.object(gau_mod, "get_default_user", new_callable=AsyncMock) +def test_complete_multifile_workflow(mock_get_default_user, client, mock_default_user): + """Test complete workflow: upload multiple ontologies → cognify with multiple keys""" + import io + import json + + mock_get_default_user.return_value = mock_default_user + # Step 1: Upload multiple ontologies + file1_content = b""" + + + """ + + file2_content = b""" + + + """ + + files = [ + ("ontology_file", ("vehicles.owl", io.BytesIO(file1_content), "application/xml")), + ("ontology_file", ("manufacturers.owl", io.BytesIO(file2_content), "application/xml")), + ] + data = { + "ontology_key": json.dumps(["vehicles", "manufacturers"]), + "descriptions": json.dumps(["Vehicle ontology", "Manufacturer ontology"]), + } + + upload_response = client.post("/api/v1/ontologies", files=files, data=data) + assert upload_response.status_code == 200 + + # Step 2: Verify ontologies are listed + list_response = client.get("/api/v1/ontologies") + assert list_response.status_code == 200 + ontologies = list_response.json() + assert "vehicles" in ontologies + assert "manufacturers" in ontologies + + # Step 3: Test cognify with multiple ontologies + cognify_payload = { + "datasets": ["test_dataset"], + "ontology_key": ["vehicles", "manufacturers"], + "run_in_background": False, + } + + cognify_response = client.post("/api/v1/cognify", json=cognify_payload) + # Should not fail due to ontology handling (may fail for dataset reasons) + assert cognify_response.status_code != 400 # Not a validation error + + +@patch.object(gau_mod, "get_default_user", new_callable=AsyncMock) +def test_multifile_error_handling(mock_get_default_user, client, mock_default_user): + """Test error handling for invalid multifile uploads""" + import io + import json + + # Test mismatched array lengths + file_content = b"" + files = [("ontology_file", ("test.owl", io.BytesIO(file_content), "application/xml"))] + data = { + "ontology_key": json.dumps(["key1", "key2"]), # 2 keys, 1 file + "descriptions": json.dumps(["desc1"]), + } + + response = client.post("/api/v1/ontologies", files=files, data=data) + assert response.status_code == 400 + assert "Number of keys must match number of files" in response.json()["error"] + + # Test duplicate keys + files = [ + ("ontology_file", ("test1.owl", io.BytesIO(file_content), "application/xml")), + ("ontology_file", ("test2.owl", io.BytesIO(file_content), "application/xml")), + ] + data = { + "ontology_key": json.dumps(["duplicate", "duplicate"]), + "descriptions": json.dumps(["desc1", "desc2"]), + } + + response = client.post("/api/v1/ontologies", files=files, data=data) + assert response.status_code == 400 + assert "Duplicate ontology keys not allowed" in response.json()["error"] + + +@patch.object(gau_mod, "get_default_user", new_callable=AsyncMock) +def test_cognify_missing_ontology_key(mock_get_default_user, client, mock_default_user): + """Test cognify with non-existent ontology key""" + mock_get_default_user.return_value = mock_default_user + + payload = { + "datasets": ["test_dataset"], + "ontology_key": ["nonexistent_key"], + "run_in_background": False, + } + + response = client.post("/api/v1/cognify", json=payload) + assert response.status_code == 409 + assert "Ontology key 'nonexistent_key' not found" in response.json()["error"] diff --git a/cognee/tests/unit/modules/graph/cognee_graph_elements_test.py b/cognee/tests/unit/modules/graph/cognee_graph_elements_test.py index 37ba113b5..1d2b79cf9 100644 --- a/cognee/tests/unit/modules/graph/cognee_graph_elements_test.py +++ b/cognee/tests/unit/modules/graph/cognee_graph_elements_test.py @@ -9,7 +9,7 @@ def test_node_initialization(): """Test that a Node is initialized correctly.""" node = Node("node1", {"attr1": "value1"}, dimension=2) assert node.id == "node1" - assert node.attributes == {"attr1": "value1", "vector_distance": np.inf} + assert node.attributes == {"attr1": "value1", "vector_distance": 3.5} assert len(node.status) == 2 assert np.all(node.status == 1) @@ -96,7 +96,7 @@ def test_edge_initialization(): edge = Edge(node1, node2, {"weight": 10}, directed=False, dimension=2) assert edge.node1 == node1 assert edge.node2 == node2 - assert edge.attributes == {"vector_distance": np.inf, "weight": 10} + assert edge.attributes == {"vector_distance": 3.5, "weight": 10} assert edge.directed is False assert len(edge.status) == 2 assert np.all(edge.status == 1) diff --git a/cognee/tests/unit/modules/graph/cognee_graph_test.py b/cognee/tests/unit/modules/graph/cognee_graph_test.py index 6888648c3..711479387 100644 --- a/cognee/tests/unit/modules/graph/cognee_graph_test.py +++ b/cognee/tests/unit/modules/graph/cognee_graph_test.py @@ -1,4 +1,5 @@ import pytest +from unittest.mock import AsyncMock from cognee.modules.graph.exceptions import EntityNotFoundError, EntityAlreadyExistsError from cognee.modules.graph.cognee_graph.CogneeGraph import CogneeGraph @@ -11,6 +12,30 @@ def setup_graph(): return CogneeGraph() +@pytest.fixture +def mock_adapter(): + """Fixture to create a mock adapter for database operations.""" + adapter = AsyncMock() + return adapter + + +@pytest.fixture +def mock_vector_engine(): + """Fixture to create a mock vector engine.""" + engine = AsyncMock() + engine.search = AsyncMock() + return engine + + +class MockScoredResult: + """Mock class for vector search results.""" + + def __init__(self, id, score, payload=None): + self.id = id + self.score = score + self.payload = payload or {} + + def test_add_node_success(setup_graph): """Test successful addition of a node.""" graph = setup_graph @@ -73,3 +98,433 @@ def test_get_edges_nonexistent_node(setup_graph): graph = setup_graph with pytest.raises(EntityNotFoundError, match="Node with id nonexistent does not exist."): graph.get_edges_from_node("nonexistent") + + +@pytest.mark.asyncio +async def test_project_graph_from_db_full_graph(setup_graph, mock_adapter): + """Test projecting a full graph from database.""" + graph = setup_graph + + nodes_data = [ + ("1", {"name": "Node1", "description": "First node"}), + ("2", {"name": "Node2", "description": "Second node"}), + ] + edges_data = [ + ("1", "2", "CONNECTS_TO", {"relationship_name": "connects"}), + ] + + mock_adapter.get_graph_data = AsyncMock(return_value=(nodes_data, edges_data)) + + await graph.project_graph_from_db( + adapter=mock_adapter, + node_properties_to_project=["name", "description"], + edge_properties_to_project=["relationship_name"], + ) + + assert len(graph.nodes) == 2 + assert len(graph.edges) == 1 + assert graph.get_node("1") is not None + assert graph.get_node("2") is not None + assert graph.edges[0].node1.id == "1" + assert graph.edges[0].node2.id == "2" + + +@pytest.mark.asyncio +async def test_project_graph_from_db_id_filtered(setup_graph, mock_adapter): + """Test projecting an ID-filtered graph from database.""" + graph = setup_graph + + nodes_data = [ + ("1", {"name": "Node1"}), + ("2", {"name": "Node2"}), + ] + edges_data = [ + ("1", "2", "CONNECTS_TO", {"relationship_name": "connects"}), + ] + + mock_adapter.get_id_filtered_graph_data = AsyncMock(return_value=(nodes_data, edges_data)) + + await graph.project_graph_from_db( + adapter=mock_adapter, + node_properties_to_project=["name"], + edge_properties_to_project=["relationship_name"], + relevant_ids_to_filter=["1", "2"], + ) + + assert len(graph.nodes) == 2 + assert len(graph.edges) == 1 + mock_adapter.get_id_filtered_graph_data.assert_called_once() + + +@pytest.mark.asyncio +async def test_project_graph_from_db_nodeset_subgraph(setup_graph, mock_adapter): + """Test projecting a nodeset subgraph filtered by node type and name.""" + graph = setup_graph + + nodes_data = [ + ("1", {"name": "Alice", "type": "Person"}), + ("2", {"name": "Bob", "type": "Person"}), + ] + edges_data = [ + ("1", "2", "KNOWS", {"relationship_name": "knows"}), + ] + + mock_adapter.get_nodeset_subgraph = AsyncMock(return_value=(nodes_data, edges_data)) + + await graph.project_graph_from_db( + adapter=mock_adapter, + node_properties_to_project=["name", "type"], + edge_properties_to_project=["relationship_name"], + node_type="Person", + node_name=["Alice"], + ) + + assert len(graph.nodes) == 2 + assert graph.get_node("1") is not None + assert len(graph.edges) == 1 + mock_adapter.get_nodeset_subgraph.assert_called_once() + + +@pytest.mark.asyncio +async def test_project_graph_from_db_empty_graph(setup_graph, mock_adapter): + """Test projecting empty graph raises EntityNotFoundError.""" + graph = setup_graph + + mock_adapter.get_graph_data = AsyncMock(return_value=([], [])) + + with pytest.raises(EntityNotFoundError, match="Empty graph projected from the database."): + await graph.project_graph_from_db( + adapter=mock_adapter, + node_properties_to_project=["name"], + edge_properties_to_project=[], + ) + + +@pytest.mark.asyncio +async def test_project_graph_from_db_missing_nodes(setup_graph, mock_adapter): + """Test that edges referencing missing nodes raise error.""" + graph = setup_graph + + nodes_data = [ + ("1", {"name": "Node1"}), + ] + edges_data = [ + ("1", "999", "CONNECTS_TO", {"relationship_name": "connects"}), + ] + + mock_adapter.get_graph_data = AsyncMock(return_value=(nodes_data, edges_data)) + + with pytest.raises(EntityNotFoundError, match="Edge references nonexistent nodes"): + await graph.project_graph_from_db( + adapter=mock_adapter, + node_properties_to_project=["name"], + edge_properties_to_project=["relationship_name"], + ) + + +@pytest.mark.asyncio +async def test_map_vector_distances_to_graph_nodes(setup_graph): + """Test mapping vector distances to graph nodes.""" + graph = setup_graph + + node1 = Node("1", {"name": "Node1"}) + node2 = Node("2", {"name": "Node2"}) + graph.add_node(node1) + graph.add_node(node2) + + node_distances = { + "Entity_name": [ + MockScoredResult("1", 0.95), + MockScoredResult("2", 0.87), + ] + } + + await graph.map_vector_distances_to_graph_nodes(node_distances) + + assert graph.get_node("1").attributes.get("vector_distance") == 0.95 + assert graph.get_node("2").attributes.get("vector_distance") == 0.87 + + +@pytest.mark.asyncio +async def test_map_vector_distances_partial_node_coverage(setup_graph): + """Test mapping vector distances when only some nodes have results.""" + graph = setup_graph + + node1 = Node("1", {"name": "Node1"}) + node2 = Node("2", {"name": "Node2"}) + node3 = Node("3", {"name": "Node3"}) + graph.add_node(node1) + graph.add_node(node2) + graph.add_node(node3) + + node_distances = { + "Entity_name": [ + MockScoredResult("1", 0.95), + MockScoredResult("2", 0.87), + ] + } + + await graph.map_vector_distances_to_graph_nodes(node_distances) + + assert graph.get_node("1").attributes.get("vector_distance") == 0.95 + assert graph.get_node("2").attributes.get("vector_distance") == 0.87 + assert graph.get_node("3").attributes.get("vector_distance") == 3.5 + + +@pytest.mark.asyncio +async def test_map_vector_distances_multiple_categories(setup_graph): + """Test mapping vector distances from multiple collection categories.""" + graph = setup_graph + + # Create nodes + node1 = Node("1") + node2 = Node("2") + node3 = Node("3") + node4 = Node("4") + graph.add_node(node1) + graph.add_node(node2) + graph.add_node(node3) + graph.add_node(node4) + + node_distances = { + "Entity_name": [ + MockScoredResult("1", 0.95), + MockScoredResult("2", 0.87), + ], + "TextSummary_text": [ + MockScoredResult("3", 0.92), + ], + } + + await graph.map_vector_distances_to_graph_nodes(node_distances) + + assert graph.get_node("1").attributes.get("vector_distance") == 0.95 + assert graph.get_node("2").attributes.get("vector_distance") == 0.87 + assert graph.get_node("3").attributes.get("vector_distance") == 0.92 + assert graph.get_node("4").attributes.get("vector_distance") == 3.5 + + +@pytest.mark.asyncio +async def test_map_vector_distances_to_graph_edges_with_payload(setup_graph, mock_vector_engine): + """Test mapping vector distances to edges when edge_distances provided.""" + graph = setup_graph + + node1 = Node("1") + node2 = Node("2") + graph.add_node(node1) + graph.add_node(node2) + + edge = Edge( + node1, + node2, + attributes={"edge_text": "CONNECTS_TO", "relationship_type": "connects"}, + ) + graph.add_edge(edge) + + edge_distances = [ + MockScoredResult("e1", 0.92, payload={"text": "CONNECTS_TO"}), + ] + + await graph.map_vector_distances_to_graph_edges( + vector_engine=mock_vector_engine, + query_vector=[0.1, 0.2, 0.3], + edge_distances=edge_distances, + ) + + assert graph.edges[0].attributes.get("vector_distance") == 0.92 + + +@pytest.mark.asyncio +async def test_map_vector_distances_to_graph_edges_search(setup_graph, mock_vector_engine): + """Test mapping edge distances when searching for them.""" + graph = setup_graph + + node1 = Node("1") + node2 = Node("2") + graph.add_node(node1) + graph.add_node(node2) + + edge = Edge( + node1, + node2, + attributes={"edge_text": "CONNECTS_TO", "relationship_type": "connects"}, + ) + graph.add_edge(edge) + + mock_vector_engine.search.return_value = [ + MockScoredResult("e1", 0.88, payload={"text": "CONNECTS_TO"}), + ] + + await graph.map_vector_distances_to_graph_edges( + vector_engine=mock_vector_engine, + query_vector=[0.1, 0.2, 0.3], + edge_distances=None, + ) + + mock_vector_engine.search.assert_called_once() + assert graph.edges[0].attributes.get("vector_distance") == 0.88 + + +@pytest.mark.asyncio +async def test_map_vector_distances_partial_edge_coverage(setup_graph, mock_vector_engine): + """Test mapping edge distances when only some edges have results.""" + graph = setup_graph + + node1 = Node("1") + node2 = Node("2") + node3 = Node("3") + graph.add_node(node1) + graph.add_node(node2) + graph.add_node(node3) + + edge1 = Edge(node1, node2, attributes={"edge_text": "CONNECTS_TO"}) + edge2 = Edge(node2, node3, attributes={"edge_text": "DEPENDS_ON"}) + graph.add_edge(edge1) + graph.add_edge(edge2) + + edge_distances = [ + MockScoredResult("e1", 0.92, payload={"text": "CONNECTS_TO"}), + ] + + await graph.map_vector_distances_to_graph_edges( + vector_engine=mock_vector_engine, + query_vector=[0.1, 0.2, 0.3], + edge_distances=edge_distances, + ) + + assert graph.edges[0].attributes.get("vector_distance") == 0.92 + assert graph.edges[1].attributes.get("vector_distance") == 3.5 + + +@pytest.mark.asyncio +async def test_map_vector_distances_edges_fallback_to_relationship_type( + setup_graph, mock_vector_engine +): + """Test that edge mapping falls back to relationship_type when edge_text is missing.""" + graph = setup_graph + + node1 = Node("1") + node2 = Node("2") + graph.add_node(node1) + graph.add_node(node2) + + edge = Edge( + node1, + node2, + attributes={"relationship_type": "KNOWS"}, + ) + graph.add_edge(edge) + + edge_distances = [ + MockScoredResult("e1", 0.85, payload={"text": "KNOWS"}), + ] + + await graph.map_vector_distances_to_graph_edges( + vector_engine=mock_vector_engine, + query_vector=[0.1, 0.2, 0.3], + edge_distances=edge_distances, + ) + + assert graph.edges[0].attributes.get("vector_distance") == 0.85 + + +@pytest.mark.asyncio +async def test_map_vector_distances_no_edge_matches(setup_graph, mock_vector_engine): + """Test edge mapping when no edges match the distance results.""" + graph = setup_graph + + node1 = Node("1") + node2 = Node("2") + graph.add_node(node1) + graph.add_node(node2) + + edge = Edge( + node1, + node2, + attributes={"edge_text": "CONNECTS_TO", "relationship_type": "connects"}, + ) + graph.add_edge(edge) + + edge_distances = [ + MockScoredResult("e1", 0.92, payload={"text": "SOME_OTHER_EDGE"}), + ] + + await graph.map_vector_distances_to_graph_edges( + vector_engine=mock_vector_engine, + query_vector=[0.1, 0.2, 0.3], + edge_distances=edge_distances, + ) + + assert graph.edges[0].attributes.get("vector_distance") == 3.5 + + +@pytest.mark.asyncio +async def test_map_vector_distances_invalid_query_vector(setup_graph, mock_vector_engine): + """Test that invalid query vector raises error.""" + graph = setup_graph + + with pytest.raises(ValueError, match="Failed to generate query embedding"): + await graph.map_vector_distances_to_graph_edges( + vector_engine=mock_vector_engine, + query_vector=[], + edge_distances=None, + ) + + +@pytest.mark.asyncio +async def test_calculate_top_triplet_importances(setup_graph): + """Test calculating top triplet importances by score.""" + graph = setup_graph + + node1 = Node("1") + node2 = Node("2") + node3 = Node("3") + node4 = Node("4") + + node1.add_attribute("vector_distance", 0.9) + node2.add_attribute("vector_distance", 0.8) + node3.add_attribute("vector_distance", 0.7) + node4.add_attribute("vector_distance", 0.6) + + graph.add_node(node1) + graph.add_node(node2) + graph.add_node(node3) + graph.add_node(node4) + + edge1 = Edge(node1, node2) + edge2 = Edge(node2, node3) + edge3 = Edge(node3, node4) + + edge1.add_attribute("vector_distance", 0.85) + edge2.add_attribute("vector_distance", 0.75) + edge3.add_attribute("vector_distance", 0.65) + + graph.add_edge(edge1) + graph.add_edge(edge2) + graph.add_edge(edge3) + + top_triplets = await graph.calculate_top_triplet_importances(k=2) + + assert len(top_triplets) == 2 + + assert top_triplets[0] == edge3 + assert top_triplets[1] == edge2 + + +@pytest.mark.asyncio +async def test_calculate_top_triplet_importances_default_distances(setup_graph): + """Test calculating importances when nodes/edges have no vector distances.""" + graph = setup_graph + + node1 = Node("1") + node2 = Node("2") + graph.add_node(node1) + graph.add_node(node2) + + edge = Edge(node1, node2) + graph.add_edge(edge) + + top_triplets = await graph.calculate_top_triplet_importances(k=1) + + assert len(top_triplets) == 1 + assert top_triplets[0] == edge diff --git a/cognee/tests/unit/modules/retrieval/test_brute_force_triplet_search.py b/cognee/tests/unit/modules/retrieval/test_brute_force_triplet_search.py new file mode 100644 index 000000000..5eb6fb105 --- /dev/null +++ b/cognee/tests/unit/modules/retrieval/test_brute_force_triplet_search.py @@ -0,0 +1,582 @@ +import pytest +from unittest.mock import AsyncMock, patch + +from cognee.modules.retrieval.utils.brute_force_triplet_search import ( + brute_force_triplet_search, + get_memory_fragment, +) +from cognee.modules.graph.cognee_graph.CogneeGraph import CogneeGraph +from cognee.modules.graph.exceptions.exceptions import EntityNotFoundError + + +class MockScoredResult: + """Mock class for vector search results.""" + + def __init__(self, id, score, payload=None): + self.id = id + self.score = score + self.payload = payload or {} + + +@pytest.mark.asyncio +async def test_brute_force_triplet_search_empty_query(): + """Test that empty query raises ValueError.""" + with pytest.raises(ValueError, match="The query must be a non-empty string."): + await brute_force_triplet_search(query="") + + +@pytest.mark.asyncio +async def test_brute_force_triplet_search_none_query(): + """Test that None query raises ValueError.""" + with pytest.raises(ValueError, match="The query must be a non-empty string."): + await brute_force_triplet_search(query=None) + + +@pytest.mark.asyncio +async def test_brute_force_triplet_search_negative_top_k(): + """Test that negative top_k raises ValueError.""" + with pytest.raises(ValueError, match="top_k must be a positive integer."): + await brute_force_triplet_search(query="test query", top_k=-1) + + +@pytest.mark.asyncio +async def test_brute_force_triplet_search_zero_top_k(): + """Test that zero top_k raises ValueError.""" + with pytest.raises(ValueError, match="top_k must be a positive integer."): + await brute_force_triplet_search(query="test query", top_k=0) + + +@pytest.mark.asyncio +async def test_brute_force_triplet_search_wide_search_limit_global_search(): + """Test that wide_search_limit is applied for global search (node_name=None).""" + mock_vector_engine = AsyncMock() + mock_vector_engine.embedding_engine = AsyncMock() + mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]]) + mock_vector_engine.search = AsyncMock(return_value=[]) + + with patch( + "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine", + return_value=mock_vector_engine, + ): + await brute_force_triplet_search( + query="test", + node_name=None, # Global search + wide_search_top_k=75, + ) + + for call in mock_vector_engine.search.call_args_list: + assert call[1]["limit"] == 75 + + +@pytest.mark.asyncio +async def test_brute_force_triplet_search_wide_search_limit_filtered_search(): + """Test that wide_search_limit is None for filtered search (node_name provided).""" + mock_vector_engine = AsyncMock() + mock_vector_engine.embedding_engine = AsyncMock() + mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]]) + mock_vector_engine.search = AsyncMock(return_value=[]) + + with patch( + "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine", + return_value=mock_vector_engine, + ): + await brute_force_triplet_search( + query="test", + node_name=["Node1"], + wide_search_top_k=50, + ) + + for call in mock_vector_engine.search.call_args_list: + assert call[1]["limit"] is None + + +@pytest.mark.asyncio +async def test_brute_force_triplet_search_wide_search_default(): + """Test that wide_search_top_k defaults to 100.""" + mock_vector_engine = AsyncMock() + mock_vector_engine.embedding_engine = AsyncMock() + mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]]) + mock_vector_engine.search = AsyncMock(return_value=[]) + + with patch( + "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine", + return_value=mock_vector_engine, + ): + await brute_force_triplet_search(query="test", node_name=None) + + for call in mock_vector_engine.search.call_args_list: + assert call[1]["limit"] == 100 + + +@pytest.mark.asyncio +async def test_brute_force_triplet_search_default_collections(): + """Test that default collections are used when none provided.""" + mock_vector_engine = AsyncMock() + mock_vector_engine.embedding_engine = AsyncMock() + mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]]) + mock_vector_engine.search = AsyncMock(return_value=[]) + + with patch( + "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine", + return_value=mock_vector_engine, + ): + await brute_force_triplet_search(query="test") + + expected_collections = [ + "Entity_name", + "TextSummary_text", + "EntityType_name", + "DocumentChunk_text", + ] + + call_collections = [ + call[1]["collection_name"] for call in mock_vector_engine.search.call_args_list + ] + assert call_collections == expected_collections + + +@pytest.mark.asyncio +async def test_brute_force_triplet_search_custom_collections(): + """Test that custom collections are used when provided.""" + mock_vector_engine = AsyncMock() + mock_vector_engine.embedding_engine = AsyncMock() + mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]]) + mock_vector_engine.search = AsyncMock(return_value=[]) + + custom_collections = ["CustomCol1", "CustomCol2"] + + with patch( + "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine", + return_value=mock_vector_engine, + ): + await brute_force_triplet_search(query="test", collections=custom_collections) + + call_collections = [ + call[1]["collection_name"] for call in mock_vector_engine.search.call_args_list + ] + assert call_collections == custom_collections + + +@pytest.mark.asyncio +async def test_brute_force_triplet_search_all_collections_empty(): + """Test that empty list is returned when all collections return no results.""" + mock_vector_engine = AsyncMock() + mock_vector_engine.embedding_engine = AsyncMock() + mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]]) + mock_vector_engine.search = AsyncMock(return_value=[]) + + with patch( + "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine", + return_value=mock_vector_engine, + ): + results = await brute_force_triplet_search(query="test") + assert results == [] + + +# Tests for query embedding + + +@pytest.mark.asyncio +async def test_brute_force_triplet_search_embeds_query(): + """Test that query is embedded before searching.""" + query_text = "test query" + expected_vector = [0.1, 0.2, 0.3] + + mock_vector_engine = AsyncMock() + mock_vector_engine.embedding_engine = AsyncMock() + mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[expected_vector]) + mock_vector_engine.search = AsyncMock(return_value=[]) + + with patch( + "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine", + return_value=mock_vector_engine, + ): + await brute_force_triplet_search(query=query_text) + + mock_vector_engine.embedding_engine.embed_text.assert_called_once_with([query_text]) + + for call in mock_vector_engine.search.call_args_list: + assert call[1]["query_vector"] == expected_vector + + +@pytest.mark.asyncio +async def test_brute_force_triplet_search_extracts_node_ids_global_search(): + """Test that node IDs are extracted from search results for global search.""" + scored_results = [ + MockScoredResult("node1", 0.95), + MockScoredResult("node2", 0.87), + MockScoredResult("node3", 0.92), + ] + + mock_vector_engine = AsyncMock() + mock_vector_engine.embedding_engine = AsyncMock() + mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]]) + mock_vector_engine.search = AsyncMock(return_value=scored_results) + + mock_fragment = AsyncMock( + map_vector_distances_to_graph_nodes=AsyncMock(), + map_vector_distances_to_graph_edges=AsyncMock(), + calculate_top_triplet_importances=AsyncMock(return_value=[]), + ) + + with ( + patch( + "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine", + return_value=mock_vector_engine, + ), + patch( + "cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment", + return_value=mock_fragment, + ) as mock_get_fragment_fn, + ): + await brute_force_triplet_search(query="test", node_name=None) + + call_kwargs = mock_get_fragment_fn.call_args[1] + assert set(call_kwargs["relevant_ids_to_filter"]) == {"node1", "node2", "node3"} + + +@pytest.mark.asyncio +async def test_brute_force_triplet_search_reuses_provided_fragment(): + """Test that provided memory fragment is reused instead of creating new one.""" + provided_fragment = AsyncMock( + map_vector_distances_to_graph_nodes=AsyncMock(), + map_vector_distances_to_graph_edges=AsyncMock(), + calculate_top_triplet_importances=AsyncMock(return_value=[]), + ) + + mock_vector_engine = AsyncMock() + mock_vector_engine.embedding_engine = AsyncMock() + mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]]) + mock_vector_engine.search = AsyncMock(return_value=[MockScoredResult("n1", 0.95)]) + + with ( + patch( + "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine", + return_value=mock_vector_engine, + ), + patch( + "cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment" + ) as mock_get_fragment, + ): + await brute_force_triplet_search( + query="test", + memory_fragment=provided_fragment, + node_name=["node"], + ) + + mock_get_fragment.assert_not_called() + + +@pytest.mark.asyncio +async def test_brute_force_triplet_search_creates_fragment_when_not_provided(): + """Test that memory fragment is created when not provided.""" + mock_vector_engine = AsyncMock() + mock_vector_engine.embedding_engine = AsyncMock() + mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]]) + mock_vector_engine.search = AsyncMock(return_value=[MockScoredResult("n1", 0.95)]) + + mock_fragment = AsyncMock( + map_vector_distances_to_graph_nodes=AsyncMock(), + map_vector_distances_to_graph_edges=AsyncMock(), + calculate_top_triplet_importances=AsyncMock(return_value=[]), + ) + + with ( + patch( + "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine", + return_value=mock_vector_engine, + ), + patch( + "cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment", + return_value=mock_fragment, + ) as mock_get_fragment, + ): + await brute_force_triplet_search(query="test", node_name=["node"]) + + mock_get_fragment.assert_called_once() + + +@pytest.mark.asyncio +async def test_brute_force_triplet_search_passes_top_k_to_importance_calculation(): + """Test that custom top_k is passed to importance calculation.""" + mock_vector_engine = AsyncMock() + mock_vector_engine.embedding_engine = AsyncMock() + mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]]) + mock_vector_engine.search = AsyncMock(return_value=[MockScoredResult("n1", 0.95)]) + + mock_fragment = AsyncMock( + map_vector_distances_to_graph_nodes=AsyncMock(), + map_vector_distances_to_graph_edges=AsyncMock(), + calculate_top_triplet_importances=AsyncMock(return_value=[]), + ) + + with ( + patch( + "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine", + return_value=mock_vector_engine, + ), + patch( + "cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment", + return_value=mock_fragment, + ), + ): + custom_top_k = 15 + await brute_force_triplet_search(query="test", top_k=custom_top_k, node_name=["n"]) + + mock_fragment.calculate_top_triplet_importances.assert_called_once_with(k=custom_top_k) + + +@pytest.mark.asyncio +async def test_get_memory_fragment_returns_empty_graph_on_entity_not_found(): + """Test that get_memory_fragment returns empty graph when entity not found.""" + mock_graph_engine = AsyncMock() + mock_graph_engine.project_graph_from_db = AsyncMock( + side_effect=EntityNotFoundError("Entity not found") + ) + + with patch( + "cognee.modules.retrieval.utils.brute_force_triplet_search.get_graph_engine", + return_value=mock_graph_engine, + ): + fragment = await get_memory_fragment() + + assert isinstance(fragment, CogneeGraph) + assert len(fragment.nodes) == 0 + + +@pytest.mark.asyncio +async def test_get_memory_fragment_returns_empty_graph_on_error(): + """Test that get_memory_fragment returns empty graph on generic error.""" + mock_graph_engine = AsyncMock() + mock_graph_engine.project_graph_from_db = AsyncMock(side_effect=Exception("Generic error")) + + with patch( + "cognee.modules.retrieval.utils.brute_force_triplet_search.get_graph_engine", + return_value=mock_graph_engine, + ): + fragment = await get_memory_fragment() + + assert isinstance(fragment, CogneeGraph) + assert len(fragment.nodes) == 0 + + +@pytest.mark.asyncio +async def test_brute_force_triplet_search_deduplicates_node_ids(): + """Test that duplicate node IDs across collections are deduplicated.""" + + def search_side_effect(*args, **kwargs): + collection_name = kwargs.get("collection_name") + if collection_name == "Entity_name": + return [ + MockScoredResult("node1", 0.95), + MockScoredResult("node2", 0.87), + ] + elif collection_name == "TextSummary_text": + return [ + MockScoredResult("node1", 0.90), + MockScoredResult("node3", 0.92), + ] + else: + return [] + + mock_vector_engine = AsyncMock() + mock_vector_engine.embedding_engine = AsyncMock() + mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]]) + mock_vector_engine.search = AsyncMock(side_effect=search_side_effect) + + mock_fragment = AsyncMock( + map_vector_distances_to_graph_nodes=AsyncMock(), + map_vector_distances_to_graph_edges=AsyncMock(), + calculate_top_triplet_importances=AsyncMock(return_value=[]), + ) + + with ( + patch( + "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine", + return_value=mock_vector_engine, + ), + patch( + "cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment", + return_value=mock_fragment, + ) as mock_get_fragment_fn, + ): + await brute_force_triplet_search(query="test", node_name=None) + + call_kwargs = mock_get_fragment_fn.call_args[1] + assert set(call_kwargs["relevant_ids_to_filter"]) == {"node1", "node2", "node3"} + assert len(call_kwargs["relevant_ids_to_filter"]) == 3 + + +@pytest.mark.asyncio +async def test_brute_force_triplet_search_excludes_edge_collection(): + """Test that EdgeType_relationship_name collection is excluded from ID extraction.""" + + def search_side_effect(*args, **kwargs): + collection_name = kwargs.get("collection_name") + if collection_name == "Entity_name": + return [MockScoredResult("node1", 0.95)] + elif collection_name == "EdgeType_relationship_name": + return [MockScoredResult("edge1", 0.88)] + else: + return [] + + mock_vector_engine = AsyncMock() + mock_vector_engine.embedding_engine = AsyncMock() + mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]]) + mock_vector_engine.search = AsyncMock(side_effect=search_side_effect) + + mock_fragment = AsyncMock( + map_vector_distances_to_graph_nodes=AsyncMock(), + map_vector_distances_to_graph_edges=AsyncMock(), + calculate_top_triplet_importances=AsyncMock(return_value=[]), + ) + + with ( + patch( + "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine", + return_value=mock_vector_engine, + ), + patch( + "cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment", + return_value=mock_fragment, + ) as mock_get_fragment_fn, + ): + await brute_force_triplet_search( + query="test", + node_name=None, + collections=["Entity_name", "EdgeType_relationship_name"], + ) + + call_kwargs = mock_get_fragment_fn.call_args[1] + assert call_kwargs["relevant_ids_to_filter"] == ["node1"] + + +@pytest.mark.asyncio +async def test_brute_force_triplet_search_skips_nodes_without_ids(): + """Test that nodes without ID attribute are skipped.""" + + class ScoredResultNoId: + """Mock result without id attribute.""" + + def __init__(self, score): + self.score = score + + def search_side_effect(*args, **kwargs): + collection_name = kwargs.get("collection_name") + if collection_name == "Entity_name": + return [ + MockScoredResult("node1", 0.95), + ScoredResultNoId(0.90), + MockScoredResult("node2", 0.87), + ] + else: + return [] + + mock_vector_engine = AsyncMock() + mock_vector_engine.embedding_engine = AsyncMock() + mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]]) + mock_vector_engine.search = AsyncMock(side_effect=search_side_effect) + + mock_fragment = AsyncMock( + map_vector_distances_to_graph_nodes=AsyncMock(), + map_vector_distances_to_graph_edges=AsyncMock(), + calculate_top_triplet_importances=AsyncMock(return_value=[]), + ) + + with ( + patch( + "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine", + return_value=mock_vector_engine, + ), + patch( + "cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment", + return_value=mock_fragment, + ) as mock_get_fragment_fn, + ): + await brute_force_triplet_search(query="test", node_name=None) + + call_kwargs = mock_get_fragment_fn.call_args[1] + assert set(call_kwargs["relevant_ids_to_filter"]) == {"node1", "node2"} + + +@pytest.mark.asyncio +async def test_brute_force_triplet_search_handles_tuple_results(): + """Test that both list and tuple results are handled correctly.""" + + def search_side_effect(*args, **kwargs): + collection_name = kwargs.get("collection_name") + if collection_name == "Entity_name": + return ( + MockScoredResult("node1", 0.95), + MockScoredResult("node2", 0.87), + ) + else: + return [] + + mock_vector_engine = AsyncMock() + mock_vector_engine.embedding_engine = AsyncMock() + mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]]) + mock_vector_engine.search = AsyncMock(side_effect=search_side_effect) + + mock_fragment = AsyncMock( + map_vector_distances_to_graph_nodes=AsyncMock(), + map_vector_distances_to_graph_edges=AsyncMock(), + calculate_top_triplet_importances=AsyncMock(return_value=[]), + ) + + with ( + patch( + "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine", + return_value=mock_vector_engine, + ), + patch( + "cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment", + return_value=mock_fragment, + ) as mock_get_fragment_fn, + ): + await brute_force_triplet_search(query="test", node_name=None) + + call_kwargs = mock_get_fragment_fn.call_args[1] + assert set(call_kwargs["relevant_ids_to_filter"]) == {"node1", "node2"} + + +@pytest.mark.asyncio +async def test_brute_force_triplet_search_mixed_empty_collections(): + """Test ID extraction with mixed empty and non-empty collections.""" + + def search_side_effect(*args, **kwargs): + collection_name = kwargs.get("collection_name") + if collection_name == "Entity_name": + return [MockScoredResult("node1", 0.95)] + elif collection_name == "TextSummary_text": + return [] + elif collection_name == "EntityType_name": + return [MockScoredResult("node2", 0.92)] + else: + return [] + + mock_vector_engine = AsyncMock() + mock_vector_engine.embedding_engine = AsyncMock() + mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]]) + mock_vector_engine.search = AsyncMock(side_effect=search_side_effect) + + mock_fragment = AsyncMock( + map_vector_distances_to_graph_nodes=AsyncMock(), + map_vector_distances_to_graph_edges=AsyncMock(), + calculate_top_triplet_importances=AsyncMock(return_value=[]), + ) + + with ( + patch( + "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine", + return_value=mock_vector_engine, + ), + patch( + "cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment", + return_value=mock_fragment, + ) as mock_get_fragment_fn, + ): + await brute_force_triplet_search(query="test", node_name=None) + + call_kwargs = mock_get_fragment_fn.call_args[1] + assert set(call_kwargs["relevant_ids_to_filter"]) == {"node1", "node2"} diff --git a/cognee/tests/unit/processing/chunks/chunk_by_row_test.py b/cognee/tests/unit/processing/chunks/chunk_by_row_test.py new file mode 100644 index 000000000..7d6a73a06 --- /dev/null +++ b/cognee/tests/unit/processing/chunks/chunk_by_row_test.py @@ -0,0 +1,52 @@ +from itertools import product + +import numpy as np +import pytest + +from cognee.infrastructure.databases.vector.embeddings import get_embedding_engine +from cognee.tasks.chunks import chunk_by_row + +INPUT_TEXTS = "name: John, age: 30, city: New York, country: USA" +max_chunk_size_vals = [8, 32] + + +@pytest.mark.parametrize( + "input_text,max_chunk_size", + list(product([INPUT_TEXTS], max_chunk_size_vals)), +) +def test_chunk_by_row_isomorphism(input_text, max_chunk_size): + chunks = chunk_by_row(input_text, max_chunk_size) + reconstructed_text = ", ".join([chunk["text"] for chunk in chunks]) + assert reconstructed_text == input_text, ( + f"texts are not identical: {len(input_text) = }, {len(reconstructed_text) = }" + ) + + +@pytest.mark.parametrize( + "input_text,max_chunk_size", + list(product([INPUT_TEXTS], max_chunk_size_vals)), +) +def test_row_chunk_length(input_text, max_chunk_size): + chunks = list(chunk_by_row(data=input_text, max_chunk_size=max_chunk_size)) + embedding_engine = get_embedding_engine() + + chunk_lengths = np.array( + [embedding_engine.tokenizer.count_tokens(chunk["text"]) for chunk in chunks] + ) + + larger_chunks = chunk_lengths[chunk_lengths > max_chunk_size] + assert np.all(chunk_lengths <= max_chunk_size), ( + f"{max_chunk_size = }: {larger_chunks} are too large" + ) + + +@pytest.mark.parametrize( + "input_text,max_chunk_size", + list(product([INPUT_TEXTS], max_chunk_size_vals)), +) +def test_chunk_by_row_chunk_numbering(input_text, max_chunk_size): + chunks = chunk_by_row(data=input_text, max_chunk_size=max_chunk_size) + chunk_indices = np.array([chunk["chunk_index"] for chunk in chunks]) + assert np.all(chunk_indices == np.arange(len(chunk_indices))), ( + f"{chunk_indices = } are not monotonically increasing" + ) diff --git a/poetry.lock b/poetry.lock index 67de51633..6e88ccd22 100644 --- a/poetry.lock +++ b/poetry.lock @@ -11656,7 +11656,9 @@ groups = ["main"] files = [ {file = "SQLAlchemy-2.0.43-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:21ba7a08a4253c5825d1db389d4299f64a100ef9800e4624c8bf70d8f136e6ed"}, {file = "SQLAlchemy-2.0.43-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:11b9503fa6f8721bef9b8567730f664c5a5153d25e247aadc69247c4bc605227"}, + {file = "SQLAlchemy-2.0.43-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:07097c0a1886c150ef2adba2ff7437e84d40c0f7dcb44a2c2b9c905ccfc6361c"}, {file = "SQLAlchemy-2.0.43-cp37-cp37m-musllinux_1_2_aarch64.whl", hash = "sha256:cdeff998cb294896a34e5b2f00e383e7c5c4ef3b4bfa375d9104723f15186443"}, + {file = "SQLAlchemy-2.0.43-cp37-cp37m-musllinux_1_2_x86_64.whl", hash = "sha256:bcf0724a62a5670e5718957e05c56ec2d6850267ea859f8ad2481838f889b42c"}, {file = "SQLAlchemy-2.0.43-cp37-cp37m-win32.whl", hash = "sha256:c697575d0e2b0a5f0433f679bda22f63873821d991e95a90e9e52aae517b2e32"}, {file = "SQLAlchemy-2.0.43-cp37-cp37m-win_amd64.whl", hash = "sha256:d34c0f6dbefd2e816e8f341d0df7d4763d382e3f452423e752ffd1e213da2512"}, {file = "sqlalchemy-2.0.43-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:70322986c0c699dca241418fcf18e637a4369e0ec50540a2b907b184c8bca069"}, @@ -11691,12 +11693,20 @@ files = [ {file = "sqlalchemy-2.0.43-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:9df7126fd9db49e3a5a3999442cc67e9ee8971f3cb9644250107d7296cb2a164"}, {file = "sqlalchemy-2.0.43-cp313-cp313-win32.whl", hash = "sha256:7f1ac7828857fcedb0361b48b9ac4821469f7694089d15550bbcf9ab22564a1d"}, {file = "sqlalchemy-2.0.43-cp313-cp313-win_amd64.whl", hash = "sha256:971ba928fcde01869361f504fcff3b7143b47d30de188b11c6357c0505824197"}, + {file = "sqlalchemy-2.0.43-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:4e6aeb2e0932f32950cf56a8b4813cb15ff792fc0c9b3752eaf067cfe298496a"}, + {file = "sqlalchemy-2.0.43-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:61f964a05356f4bca4112e6334ed7c208174511bd56e6b8fc86dad4d024d4185"}, {file = "sqlalchemy-2.0.43-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:46293c39252f93ea0910aababa8752ad628bcce3a10d3f260648dd472256983f"}, + {file = "sqlalchemy-2.0.43-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:136063a68644eca9339d02e6693932116f6a8591ac013b0014479a1de664e40a"}, {file = "sqlalchemy-2.0.43-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:6e2bf13d9256398d037fef09fd8bf9b0bf77876e22647d10761d35593b9ac547"}, + {file = "sqlalchemy-2.0.43-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:44337823462291f17f994d64282a71c51d738fc9ef561bf265f1d0fd9116a782"}, {file = "sqlalchemy-2.0.43-cp38-cp38-win32.whl", hash = "sha256:13194276e69bb2af56198fef7909d48fd34820de01d9c92711a5fa45497cc7ed"}, {file = "sqlalchemy-2.0.43-cp38-cp38-win_amd64.whl", hash = "sha256:334f41fa28de9f9be4b78445e68530da3c5fa054c907176460c81494f4ae1f5e"}, + {file = "sqlalchemy-2.0.43-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:ceb5c832cc30663aeaf5e39657712f4c4241ad1f638d487ef7216258f6d41fe7"}, + {file = "sqlalchemy-2.0.43-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:11f43c39b4b2ec755573952bbcc58d976779d482f6f832d7f33a8d869ae891bf"}, {file = "sqlalchemy-2.0.43-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:413391b2239db55be14fa4223034d7e13325a1812c8396ecd4f2c08696d5ccad"}, + {file = "sqlalchemy-2.0.43-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c379e37b08c6c527181a397212346be39319fb64323741d23e46abd97a400d34"}, {file = "sqlalchemy-2.0.43-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:03d73ab2a37d9e40dec4984d1813d7878e01dbdc742448d44a7341b7a9f408c7"}, + {file = "sqlalchemy-2.0.43-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:8cee08f15d9e238ede42e9bbc1d6e7158d0ca4f176e4eab21f88ac819ae3bd7b"}, {file = "sqlalchemy-2.0.43-cp39-cp39-win32.whl", hash = "sha256:b3edaec7e8b6dc5cd94523c6df4f294014df67097c8217a89929c99975811414"}, {file = "sqlalchemy-2.0.43-cp39-cp39-win_amd64.whl", hash = "sha256:227119ce0a89e762ecd882dc661e0aa677a690c914e358f0dd8932a2e8b2765b"}, {file = "sqlalchemy-2.0.43-py3-none-any.whl", hash = "sha256:1681c21dd2ccee222c2fe0bef671d1aef7c504087c9c4e800371cfcc8ac966fc"}, diff --git a/pyproject.toml b/pyproject.toml index 13266f83e..a9b895dfb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,7 +1,7 @@ [project] name = "cognee" -version = "0.3.9" +version = "0.5.0.dev0" description = "Cognee - is a library for enriching LLM context with a semantic layer for better understanding and reasoning." authors = [ { name = "Vasilije Markovic" }, @@ -156,7 +156,6 @@ Homepage = "https://www.cognee.ai" Repository = "https://github.com/topoteretes/cognee" [project.scripts] -cognee = "cognee.cli._cognee:main" cognee-cli = "cognee.cli._cognee:main" [build-system] diff --git a/uv.lock b/uv.lock index 8c35a3366..cc66c3d7e 100644 --- a/uv.lock +++ b/uv.lock @@ -929,7 +929,7 @@ wheels = [ [[package]] name = "cognee" -version = "0.3.9" +version = "0.5.0.dev0" source = { editable = "." } dependencies = [ { name = "aiofiles" }, @@ -2560,6 +2560,8 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/7f/91/ae2eb6b7979e2f9b035a9f612cf70f1bf54aad4e1d125129bef1eae96f19/greenlet-3.2.4-cp310-cp310-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:c2ca18a03a8cfb5b25bc1cbe20f3d9a4c80d8c3b13ba3df49ac3961af0b1018d", size = 584358, upload-time = "2025-08-07T13:18:23.708Z" }, { url = "https://files.pythonhosted.org/packages/f7/85/433de0c9c0252b22b16d413c9407e6cb3b41df7389afc366ca204dbc1393/greenlet-3.2.4-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:9fe0a28a7b952a21e2c062cd5756d34354117796c6d9215a87f55e38d15402c5", size = 1113550, upload-time = "2025-08-07T13:42:37.467Z" }, { url = "https://files.pythonhosted.org/packages/a1/8d/88f3ebd2bc96bf7747093696f4335a0a8a4c5acfcf1b757717c0d2474ba3/greenlet-3.2.4-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:8854167e06950ca75b898b104b63cc646573aa5fef1353d4508ecdd1ee76254f", size = 1137126, upload-time = "2025-08-07T13:18:20.239Z" }, + { url = "https://files.pythonhosted.org/packages/f1/29/74242b7d72385e29bcc5563fba67dad94943d7cd03552bac320d597f29b2/greenlet-3.2.4-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:f47617f698838ba98f4ff4189aef02e7343952df3a615f847bb575c3feb177a7", size = 1544904, upload-time = "2025-11-04T12:42:04.763Z" }, + { url = "https://files.pythonhosted.org/packages/c8/e2/1572b8eeab0f77df5f6729d6ab6b141e4a84ee8eb9bc8c1e7918f94eda6d/greenlet-3.2.4-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:af41be48a4f60429d5cad9d22175217805098a9ef7c40bfef44f7669fb9d74d8", size = 1611228, upload-time = "2025-11-04T12:42:08.423Z" }, { url = "https://files.pythonhosted.org/packages/d6/6f/b60b0291d9623c496638c582297ead61f43c4b72eef5e9c926ef4565ec13/greenlet-3.2.4-cp310-cp310-win_amd64.whl", hash = "sha256:73f49b5368b5359d04e18d15828eecc1806033db5233397748f4ca813ff1056c", size = 298654, upload-time = "2025-08-07T13:50:00.469Z" }, { url = "https://files.pythonhosted.org/packages/a4/de/f28ced0a67749cac23fecb02b694f6473f47686dff6afaa211d186e2ef9c/greenlet-3.2.4-cp311-cp311-macosx_11_0_universal2.whl", hash = "sha256:96378df1de302bc38e99c3a9aa311967b7dc80ced1dcc6f171e99842987882a2", size = 272305, upload-time = "2025-08-07T13:15:41.288Z" }, { url = "https://files.pythonhosted.org/packages/09/16/2c3792cba130000bf2a31c5272999113f4764fd9d874fb257ff588ac779a/greenlet-3.2.4-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:1ee8fae0519a337f2329cb78bd7a8e128ec0f881073d43f023c7b8d4831d5246", size = 632472, upload-time = "2025-08-07T13:42:55.044Z" }, @@ -2569,6 +2571,8 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/1f/8e/abdd3f14d735b2929290a018ecf133c901be4874b858dd1c604b9319f064/greenlet-3.2.4-cp311-cp311-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:2523e5246274f54fdadbce8494458a2ebdcdbc7b802318466ac5606d3cded1f8", size = 587684, upload-time = "2025-08-07T13:18:25.164Z" }, { url = "https://files.pythonhosted.org/packages/5d/65/deb2a69c3e5996439b0176f6651e0052542bb6c8f8ec2e3fba97c9768805/greenlet-3.2.4-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:1987de92fec508535687fb807a5cea1560f6196285a4cde35c100b8cd632cc52", size = 1116647, upload-time = "2025-08-07T13:42:38.655Z" }, { url = "https://files.pythonhosted.org/packages/3f/cc/b07000438a29ac5cfb2194bfc128151d52f333cee74dd7dfe3fb733fc16c/greenlet-3.2.4-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:55e9c5affaa6775e2c6b67659f3a71684de4c549b3dd9afca3bc773533d284fa", size = 1142073, upload-time = "2025-08-07T13:18:21.737Z" }, + { url = "https://files.pythonhosted.org/packages/67/24/28a5b2fa42d12b3d7e5614145f0bd89714c34c08be6aabe39c14dd52db34/greenlet-3.2.4-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:c9c6de1940a7d828635fbd254d69db79e54619f165ee7ce32fda763a9cb6a58c", size = 1548385, upload-time = "2025-11-04T12:42:11.067Z" }, + { url = "https://files.pythonhosted.org/packages/6a/05/03f2f0bdd0b0ff9a4f7b99333d57b53a7709c27723ec8123056b084e69cd/greenlet-3.2.4-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:03c5136e7be905045160b1b9fdca93dd6727b180feeafda6818e6496434ed8c5", size = 1613329, upload-time = "2025-11-04T12:42:12.928Z" }, { url = "https://files.pythonhosted.org/packages/d8/0f/30aef242fcab550b0b3520b8e3561156857c94288f0332a79928c31a52cf/greenlet-3.2.4-cp311-cp311-win_amd64.whl", hash = "sha256:9c40adce87eaa9ddb593ccb0fa6a07caf34015a29bf8d344811665b573138db9", size = 299100, upload-time = "2025-08-07T13:44:12.287Z" }, { url = "https://files.pythonhosted.org/packages/44/69/9b804adb5fd0671f367781560eb5eb586c4d495277c93bde4307b9e28068/greenlet-3.2.4-cp312-cp312-macosx_11_0_universal2.whl", hash = "sha256:3b67ca49f54cede0186854a008109d6ee71f66bd57bb36abd6d0a0267b540cdd", size = 274079, upload-time = "2025-08-07T13:15:45.033Z" }, { url = "https://files.pythonhosted.org/packages/46/e9/d2a80c99f19a153eff70bc451ab78615583b8dac0754cfb942223d2c1a0d/greenlet-3.2.4-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:ddf9164e7a5b08e9d22511526865780a576f19ddd00d62f8a665949327fde8bb", size = 640997, upload-time = "2025-08-07T13:42:56.234Z" }, @@ -2578,6 +2582,8 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/19/0d/6660d55f7373b2ff8152401a83e02084956da23ae58cddbfb0b330978fe9/greenlet-3.2.4-cp312-cp312-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:3b3812d8d0c9579967815af437d96623f45c0f2ae5f04e366de62a12d83a8fb0", size = 607586, upload-time = "2025-08-07T13:18:28.544Z" }, { url = "https://files.pythonhosted.org/packages/8e/1a/c953fdedd22d81ee4629afbb38d2f9d71e37d23caace44775a3a969147d4/greenlet-3.2.4-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:abbf57b5a870d30c4675928c37278493044d7c14378350b3aa5d484fa65575f0", size = 1123281, upload-time = "2025-08-07T13:42:39.858Z" }, { url = "https://files.pythonhosted.org/packages/3f/c7/12381b18e21aef2c6bd3a636da1088b888b97b7a0362fac2e4de92405f97/greenlet-3.2.4-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:20fb936b4652b6e307b8f347665e2c615540d4b42b3b4c8a321d8286da7e520f", size = 1151142, upload-time = "2025-08-07T13:18:22.981Z" }, + { url = "https://files.pythonhosted.org/packages/27/45/80935968b53cfd3f33cf99ea5f08227f2646e044568c9b1555b58ffd61c2/greenlet-3.2.4-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:ee7a6ec486883397d70eec05059353b8e83eca9168b9f3f9a361971e77e0bcd0", size = 1564846, upload-time = "2025-11-04T12:42:15.191Z" }, + { url = "https://files.pythonhosted.org/packages/69/02/b7c30e5e04752cb4db6202a3858b149c0710e5453b71a3b2aec5d78a1aab/greenlet-3.2.4-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:326d234cbf337c9c3def0676412eb7040a35a768efc92504b947b3e9cfc7543d", size = 1633814, upload-time = "2025-11-04T12:42:17.175Z" }, { url = "https://files.pythonhosted.org/packages/e9/08/b0814846b79399e585f974bbeebf5580fbe59e258ea7be64d9dfb253c84f/greenlet-3.2.4-cp312-cp312-win_amd64.whl", hash = "sha256:a7d4e128405eea3814a12cc2605e0e6aedb4035bf32697f72deca74de4105e02", size = 299899, upload-time = "2025-08-07T13:38:53.448Z" }, { url = "https://files.pythonhosted.org/packages/49/e8/58c7f85958bda41dafea50497cbd59738c5c43dbbea5ee83d651234398f4/greenlet-3.2.4-cp313-cp313-macosx_11_0_universal2.whl", hash = "sha256:1a921e542453fe531144e91e1feedf12e07351b1cf6c9e8a3325ea600a715a31", size = 272814, upload-time = "2025-08-07T13:15:50.011Z" }, { url = "https://files.pythonhosted.org/packages/62/dd/b9f59862e9e257a16e4e610480cfffd29e3fae018a68c2332090b53aac3d/greenlet-3.2.4-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:cd3c8e693bff0fff6ba55f140bf390fa92c994083f838fece0f63be121334945", size = 641073, upload-time = "2025-08-07T13:42:57.23Z" }, @@ -2587,6 +2593,8 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/ee/43/3cecdc0349359e1a527cbf2e3e28e5f8f06d3343aaf82ca13437a9aa290f/greenlet-3.2.4-cp313-cp313-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:23768528f2911bcd7e475210822ffb5254ed10d71f4028387e5a99b4c6699671", size = 610497, upload-time = "2025-08-07T13:18:31.636Z" }, { url = "https://files.pythonhosted.org/packages/b8/19/06b6cf5d604e2c382a6f31cafafd6f33d5dea706f4db7bdab184bad2b21d/greenlet-3.2.4-cp313-cp313-musllinux_1_1_aarch64.whl", hash = "sha256:00fadb3fedccc447f517ee0d3fd8fe49eae949e1cd0f6a611818f4f6fb7dc83b", size = 1121662, upload-time = "2025-08-07T13:42:41.117Z" }, { url = "https://files.pythonhosted.org/packages/a2/15/0d5e4e1a66fab130d98168fe984c509249c833c1a3c16806b90f253ce7b9/greenlet-3.2.4-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:d25c5091190f2dc0eaa3f950252122edbbadbb682aa7b1ef2f8af0f8c0afefae", size = 1149210, upload-time = "2025-08-07T13:18:24.072Z" }, + { url = "https://files.pythonhosted.org/packages/1c/53/f9c440463b3057485b8594d7a638bed53ba531165ef0ca0e6c364b5cc807/greenlet-3.2.4-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:6e343822feb58ac4d0a1211bd9399de2b3a04963ddeec21530fc426cc121f19b", size = 1564759, upload-time = "2025-11-04T12:42:19.395Z" }, + { url = "https://files.pythonhosted.org/packages/47/e4/3bb4240abdd0a8d23f4f88adec746a3099f0d86bfedb623f063b2e3b4df0/greenlet-3.2.4-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:ca7f6f1f2649b89ce02f6f229d7c19f680a6238af656f61e0115b24857917929", size = 1634288, upload-time = "2025-11-04T12:42:21.174Z" }, { url = "https://files.pythonhosted.org/packages/0b/55/2321e43595e6801e105fcfdee02b34c0f996eb71e6ddffca6b10b7e1d771/greenlet-3.2.4-cp313-cp313-win_amd64.whl", hash = "sha256:554b03b6e73aaabec3745364d6239e9e012d64c68ccd0b8430c64ccc14939a8b", size = 299685, upload-time = "2025-08-07T13:24:38.824Z" }, ]