From 742866b4c9f1d4aa53ab60fb54b79474fbfea0d2 Mon Sep 17 00:00:00 2001 From: EricXiao Date: Wed, 22 Oct 2025 16:56:46 +0800 Subject: [PATCH 01/24] feat: csv ingestion loader & chunk Signed-off-by: EricXiao --- cognee/cli/commands/cognify_command.py | 9 +- cognee/cli/config.py | 2 +- .../files/utils/guess_file_type.py | 43 +++++ .../files/utils/is_csv_content.py | 181 ++++++++++++++++++ cognee/infrastructure/loaders/LoaderEngine.py | 1 + .../infrastructure/loaders/core/__init__.py | 3 +- .../infrastructure/loaders/core/csv_loader.py | 93 +++++++++ .../loaders/core/text_loader.py | 3 +- .../loaders/supported_loaders.py | 3 +- cognee/modules/chunking/CsvChunker.py | 35 ++++ .../processing/document_types/CsvDocument.py | 33 ++++ .../processing/document_types/__init__.py | 1 + cognee/tasks/chunks/__init__.py | 1 + cognee/tasks/chunks/chunk_by_row.py | 94 +++++++++ cognee/tasks/documents/classify_documents.py | 2 + .../integration/documents/CsvDocument_test.py | 70 +++++++ .../tests/test_data/example_with_header.csv | 3 + .../processing/chunks/chunk_by_row_test.py | 52 +++++ 18 files changed, 623 insertions(+), 6 deletions(-) create mode 100644 cognee/infrastructure/files/utils/is_csv_content.py create mode 100644 cognee/infrastructure/loaders/core/csv_loader.py create mode 100644 cognee/modules/chunking/CsvChunker.py create mode 100644 cognee/modules/data/processing/document_types/CsvDocument.py create mode 100644 cognee/tasks/chunks/chunk_by_row.py create mode 100644 cognee/tests/integration/documents/CsvDocument_test.py create mode 100644 cognee/tests/test_data/example_with_header.csv create mode 100644 cognee/tests/unit/processing/chunks/chunk_by_row_test.py diff --git a/cognee/cli/commands/cognify_command.py b/cognee/cli/commands/cognify_command.py index 16eaf0454..b89c1f70e 100644 --- a/cognee/cli/commands/cognify_command.py +++ b/cognee/cli/commands/cognify_command.py @@ -22,7 +22,7 @@ relationships, and creates semantic connections for enhanced search and reasonin Processing Pipeline: 1. **Document Classification**: Identifies document types and structures -2. **Permission Validation**: Ensures user has processing rights +2. **Permission Validation**: Ensures user has processing rights 3. **Text Chunking**: Breaks content into semantically meaningful segments 4. **Entity Extraction**: Identifies key concepts, people, places, organizations 5. 
**Relationship Detection**: Discovers connections between entities @@ -97,6 +97,13 @@ After successful cognify processing, use `cognee search` to query the knowledge chunker_class = LangchainChunker except ImportError: fmt.warning("LangchainChunker not available, using TextChunker") + elif args.chunker == "CsvChunker": + try: + from cognee.modules.chunking.CsvChunker import CsvChunker + + chunker_class = CsvChunker + except ImportError: + fmt.warning("CsvChunker not available, using TextChunker") result = await cognee.cognify( datasets=datasets, diff --git a/cognee/cli/config.py b/cognee/cli/config.py index d016608c1..082adbaec 100644 --- a/cognee/cli/config.py +++ b/cognee/cli/config.py @@ -26,7 +26,7 @@ SEARCH_TYPE_CHOICES = [ ] # Chunker choices -CHUNKER_CHOICES = ["TextChunker", "LangchainChunker"] +CHUNKER_CHOICES = ["TextChunker", "LangchainChunker", "CsvChunker"] # Output format choices OUTPUT_FORMAT_CHOICES = ["json", "pretty", "simple"] diff --git a/cognee/infrastructure/files/utils/guess_file_type.py b/cognee/infrastructure/files/utils/guess_file_type.py index dcdd68cad..10f59a400 100644 --- a/cognee/infrastructure/files/utils/guess_file_type.py +++ b/cognee/infrastructure/files/utils/guess_file_type.py @@ -1,6 +1,8 @@ from typing import BinaryIO import filetype + from .is_text_content import is_text_content +from .is_csv_content import is_csv_content class FileTypeException(Exception): @@ -134,3 +136,44 @@ def guess_file_type(file: BinaryIO) -> filetype.Type: raise FileTypeException(f"Unknown file detected: {file.name}.") return file_type + + +class CsvFileType(filetype.Type): + """ + Match CSV file types based on MIME type and extension. + + Public methods: + - match + + Instance variables: + - MIME: The MIME type of the CSV. + - EXTENSION: The file extension of the CSV. + """ + + MIME = "text/csv" + EXTENSION = "csv" + + def __init__(self): + super().__init__(mime=self.MIME, extension=self.EXTENSION) + + def match(self, buf): + """ + Determine if the given buffer contains csv content. + + Parameters: + ----------- + + - buf: The buffer to check for csv content. + + Returns: + -------- + + Returns True if the buffer is identified as csv content, otherwise False. + """ + + return is_csv_content(buf) + + +csv_file_type = CsvFileType() + +filetype.add_type(csv_file_type) diff --git a/cognee/infrastructure/files/utils/is_csv_content.py b/cognee/infrastructure/files/utils/is_csv_content.py new file mode 100644 index 000000000..07b7ea69b --- /dev/null +++ b/cognee/infrastructure/files/utils/is_csv_content.py @@ -0,0 +1,181 @@ +import csv +from collections import Counter + + +def is_csv_content(content): + """ + Heuristically determine whether a bytes-like object is CSV text. + + Strategy (fail-fast and cheap to expensive): + 1) Decode: Try a small ordered list of common encodings with strict errors. + 2) Line sampling: require >= 2 non-empty lines; sample up to 50 lines. + 3) Delimiter detection: + - Prefer csv.Sniffer() with common delimiters. + - Fallback to a lightweight consistency heuristic. + 4) Lightweight parse check: + - Parse a few lines with the delimiter. + - Ensure at least 2 valid rows and relatively stable column counts. + + Returns: + bool: True if the buffer looks like CSV; False otherwise. + """ + try: + encoding_list = [ + "utf-8", + "utf-8-sig", + "utf-32-le", + "utf-32-be", + "utf-16-le", + "utf-16-be", + "gb18030", + "shift_jis", + "cp949", + "cp1252", + "iso-8859-1", + ] + + # Try to decode strictly—if decoding fails for all encodings, it's not text/CSV. 
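+        # Note: iso-8859-1 maps every possible byte value, so strict decoding
+        # rarely fails for all candidates; the structural checks below do most
+        # of the real CSV filtering.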
+ text = None + for enc in encoding_list: + try: + text = content.decode(enc, errors="strict") + break + except UnicodeDecodeError: + continue + if text is None: + return False + + # Reject empty/whitespace-only payloads. + stripped = text.strip() + if not stripped: + return False + + # Split into logical lines and drop empty ones. Require at least two lines. + lines = [ln for ln in text.splitlines() if ln.strip()] + if len(lines) < 2: + return False + + # Take a small sample to keep sniffing cheap and predictable. + sample_lines = lines[:50] + + # Detect delimiter using csv.Sniffer first; if that fails, use our heuristic. + delimiter = _sniff_delimiter(sample_lines) or _heuristic_delimiter(sample_lines) + if not delimiter: + return False + + # Finally, do a lightweight parse sanity check with the chosen delimiter. + return _lightweight_parse_check(sample_lines, delimiter) + except Exception: + return False + + +def _sniff_delimiter(lines): + """ + Try Python's built-in csv.Sniffer on a sample. + + Args: + lines (list[str]): Sample lines (already decoded). + + Returns: + str | None: The detected delimiter if sniffing succeeds; otherwise None. + """ + # Join up to 50 lines to form the sample string Sniffer will inspect. + sample = "\n".join(lines[:50]) + try: + dialect = csv.Sniffer().sniff(sample, delimiters=",\t;|") + return dialect.delimiter + except Exception: + # Sniffer is known to be brittle on small/dirty samples—silently fallback. + return None + + +def _heuristic_delimiter(lines): + """ + Fallback delimiter detection based on count consistency per line. + + Heuristic: + - For each candidate delimiter, count occurrences per line. + - Keep only lines with count > 0 (line must contain the delimiter). + - Require at least half of lines to contain the delimiter (min 2). + - Compute the mode (most common count). If the proportion of lines that + exhibit the modal count is >= 80%, accept that delimiter. + + Args: + lines (list[str]): Sample lines. + + Returns: + str | None: Best delimiter if one meets the consistency threshold; else None. + """ + candidates = [",", "\t", ";", "|"] + best = None + best_score = 0.0 + + for d in candidates: + # Count how many times the delimiter appears in each line. + counts = [ln.count(d) for ln in lines] + # Consider only lines that actually contain the delimiter at least once. + nonzero = [c for c in counts if c > 0] + + # Require that more than half of lines (and at least 2) contain the delimiter. + if len(nonzero) < max(2, int(0.5 * len(lines))): + continue + + # Find the modal count and its frequency. + cnt = Counter(nonzero) + pairs = cnt.most_common(1) + if not pairs: + continue + + mode, mode_freq = pairs[0] + # Consistency ratio: lines with the modal count / total lines in the sample. + consistency = mode_freq / len(lines) + # Accept if consistent enough and better than any previous candidate. + if mode >= 1 and consistency >= 0.80 and consistency > best_score: + best = d + best_score = consistency + + return best + + +def _lightweight_parse_check(lines, delimiter): + """ + Parse a few lines with csv.reader and check structural stability. + + Heuristic: + - Parse up to 5 lines with the given delimiter. + - Count column widths per parsed row. + - Require at least 2 non-empty rows. + - Allow at most 1 row whose width deviates by >2 columns from the first row. + + Args: + lines (list[str]): Sample lines (decoded). + delimiter (str): Delimiter chosen by sniffing/heuristics. + + Returns: + bool: True if parsing looks stable; False otherwise. 
+ """ + try: + # csv.reader accepts any iterable of strings; feeding the first 10 lines is fine. + reader = csv.reader(lines[:10], delimiter=delimiter) + widths = [] + valid_rows = 0 + for row in reader: + if not row: + continue + + widths.append(len(row)) + valid_rows += 1 + + # Need at least two meaningful rows to make a judgment. + if valid_rows < 2: + return False + + if widths: + first = widths[0] + # Count rows whose width deviates significantly (>2) from the first row. + unstable = sum(1 for w in widths if abs(w - first) > 2) + # Permit at most 1 unstable row among the parsed sample. + return unstable <= 1 + return False + except Exception: + return False diff --git a/cognee/infrastructure/loaders/LoaderEngine.py b/cognee/infrastructure/loaders/LoaderEngine.py index 6b62f7641..37e63c9fc 100644 --- a/cognee/infrastructure/loaders/LoaderEngine.py +++ b/cognee/infrastructure/loaders/LoaderEngine.py @@ -30,6 +30,7 @@ class LoaderEngine: "pypdf_loader", "image_loader", "audio_loader", + "csv_loader", "unstructured_loader", "advanced_pdf_loader", ] diff --git a/cognee/infrastructure/loaders/core/__init__.py b/cognee/infrastructure/loaders/core/__init__.py index 8a2df80f9..09819fbd2 100644 --- a/cognee/infrastructure/loaders/core/__init__.py +++ b/cognee/infrastructure/loaders/core/__init__.py @@ -3,5 +3,6 @@ from .text_loader import TextLoader from .audio_loader import AudioLoader from .image_loader import ImageLoader +from .csv_loader import CsvLoader -__all__ = ["TextLoader", "AudioLoader", "ImageLoader"] +__all__ = ["TextLoader", "AudioLoader", "ImageLoader", "CsvLoader"] diff --git a/cognee/infrastructure/loaders/core/csv_loader.py b/cognee/infrastructure/loaders/core/csv_loader.py new file mode 100644 index 000000000..a314a7a24 --- /dev/null +++ b/cognee/infrastructure/loaders/core/csv_loader.py @@ -0,0 +1,93 @@ +import os +from typing import List +import csv +from cognee.infrastructure.loaders.LoaderInterface import LoaderInterface +from cognee.infrastructure.files.storage import get_file_storage, get_storage_config +from cognee.infrastructure.files.utils.get_file_metadata import get_file_metadata + + +class CsvLoader(LoaderInterface): + """ + Core CSV file loader that handles basic CSV file formats. + """ + + @property + def supported_extensions(self) -> List[str]: + """Supported text file extensions.""" + return [ + "csv", + ] + + @property + def supported_mime_types(self) -> List[str]: + """Supported MIME types for text content.""" + return [ + "text/csv", + ] + + @property + def loader_name(self) -> str: + """Unique identifier for this loader.""" + return "csv_loader" + + def can_handle(self, extension: str, mime_type: str) -> bool: + """ + Check if this loader can handle the given file. + + Args: + extension: File extension + mime_type: Optional MIME type + + Returns: + True if file can be handled, False otherwise + """ + if extension in self.supported_extensions and mime_type in self.supported_mime_types: + return True + + return False + + async def load(self, file_path: str, encoding: str = "utf-8", **kwargs): + """ + Load and process the csv file. 
+ + Args: + file_path: Path to the file to load + encoding: Text encoding to use (default: utf-8) + **kwargs: Additional configuration (unused) + + Returns: + LoaderResult containing the file content and metadata + + Raises: + FileNotFoundError: If file doesn't exist + UnicodeDecodeError: If file cannot be decoded with specified encoding + OSError: If file cannot be read + """ + if not os.path.exists(file_path): + raise FileNotFoundError(f"File not found: {file_path}") + + with open(file_path, "rb") as f: + file_metadata = await get_file_metadata(f) + # Name ingested file of current loader based on original file content hash + storage_file_name = "text_" + file_metadata["content_hash"] + ".txt" + + row_texts = [] + row_index = 1 + + with open(file_path, "r", encoding=encoding, newline="") as file: + reader = csv.DictReader(file) + for row in reader: + pairs = [f"{str(k)}: {str(v)}" for k, v in row.items()] + row_text = ", ".join(pairs) + row_texts.append(f"Row {row_index}:\n{row_text}\n") + row_index += 1 + + content = "\n".join(row_texts) + + storage_config = get_storage_config() + data_root_directory = storage_config["data_root_directory"] + storage = get_file_storage(data_root_directory) + + full_file_path = await storage.store(storage_file_name, content) + + return full_file_path diff --git a/cognee/infrastructure/loaders/core/text_loader.py b/cognee/infrastructure/loaders/core/text_loader.py index a6f94be9b..e478edb22 100644 --- a/cognee/infrastructure/loaders/core/text_loader.py +++ b/cognee/infrastructure/loaders/core/text_loader.py @@ -16,7 +16,7 @@ class TextLoader(LoaderInterface): @property def supported_extensions(self) -> List[str]: """Supported text file extensions.""" - return ["txt", "md", "csv", "json", "xml", "yaml", "yml", "log"] + return ["txt", "md", "json", "xml", "yaml", "yml", "log"] @property def supported_mime_types(self) -> List[str]: @@ -24,7 +24,6 @@ class TextLoader(LoaderInterface): return [ "text/plain", "text/markdown", - "text/csv", "application/json", "text/xml", "application/xml", diff --git a/cognee/infrastructure/loaders/supported_loaders.py b/cognee/infrastructure/loaders/supported_loaders.py index d103babe3..b506df5f3 100644 --- a/cognee/infrastructure/loaders/supported_loaders.py +++ b/cognee/infrastructure/loaders/supported_loaders.py @@ -1,5 +1,5 @@ from cognee.infrastructure.loaders.external import PyPdfLoader -from cognee.infrastructure.loaders.core import TextLoader, AudioLoader, ImageLoader +from cognee.infrastructure.loaders.core import TextLoader, AudioLoader, ImageLoader, CsvLoader # Registry for loader implementations supported_loaders = { @@ -7,6 +7,7 @@ supported_loaders = { TextLoader.loader_name: TextLoader, ImageLoader.loader_name: ImageLoader, AudioLoader.loader_name: AudioLoader, + CsvLoader.loader_name: CsvLoader, } # Try adding optional loaders diff --git a/cognee/modules/chunking/CsvChunker.py b/cognee/modules/chunking/CsvChunker.py new file mode 100644 index 000000000..4ba4a969e --- /dev/null +++ b/cognee/modules/chunking/CsvChunker.py @@ -0,0 +1,35 @@ +from cognee.shared.logging_utils import get_logger + + +from cognee.tasks.chunks import chunk_by_row +from cognee.modules.chunking.Chunker import Chunker +from .models.DocumentChunk import DocumentChunk + +logger = get_logger() + + +class CsvChunker(Chunker): + async def read(self): + async for content_text in self.get_text(): + if content_text is None: + continue + + for chunk_data in chunk_by_row(content_text, self.max_chunk_size): + if chunk_data["chunk_size"] <= 
self.max_chunk_size: + yield DocumentChunk( + id=chunk_data["chunk_id"], + text=chunk_data["text"], + chunk_size=chunk_data["chunk_size"], + is_part_of=self.document, + chunk_index=self.chunk_index, + cut_type=chunk_data["cut_type"], + contains=[], + metadata={ + "index_fields": ["text"], + }, + ) + self.chunk_index += 1 + else: + raise ValueError( + f"Chunk size is larger than the maximum chunk size {self.max_chunk_size}" + ) diff --git a/cognee/modules/data/processing/document_types/CsvDocument.py b/cognee/modules/data/processing/document_types/CsvDocument.py new file mode 100644 index 000000000..3381275bd --- /dev/null +++ b/cognee/modules/data/processing/document_types/CsvDocument.py @@ -0,0 +1,33 @@ +import io +import csv +from typing import Type + +from cognee.modules.chunking.Chunker import Chunker +from cognee.infrastructure.files.utils.open_data_file import open_data_file +from .Document import Document + + +class CsvDocument(Document): + type: str = "csv" + mime_type: str = "text/csv" + + async def read(self, chunker_cls: Type[Chunker], max_chunk_size: int): + async def get_text(): + async with open_data_file( + self.raw_data_location, mode="r", encoding="utf-8", newline="" + ) as file: + content = file.read() + file_like_obj = io.StringIO(content) + reader = csv.DictReader(file_like_obj) + + for row in reader: + pairs = [f"{str(k)}: {str(v)}" for k, v in row.items()] + row_text = ", ".join(pairs) + if not row_text.strip(): + break + yield row_text + + chunker = chunker_cls(self, max_chunk_size=max_chunk_size, get_text=get_text) + + async for chunk in chunker.read(): + yield chunk diff --git a/cognee/modules/data/processing/document_types/__init__.py b/cognee/modules/data/processing/document_types/__init__.py index 2e862f4ba..133dd53f8 100644 --- a/cognee/modules/data/processing/document_types/__init__.py +++ b/cognee/modules/data/processing/document_types/__init__.py @@ -4,3 +4,4 @@ from .TextDocument import TextDocument from .ImageDocument import ImageDocument from .AudioDocument import AudioDocument from .UnstructuredDocument import UnstructuredDocument +from .CsvDocument import CsvDocument diff --git a/cognee/tasks/chunks/__init__.py b/cognee/tasks/chunks/__init__.py index 22ce96be8..37d4de73e 100644 --- a/cognee/tasks/chunks/__init__.py +++ b/cognee/tasks/chunks/__init__.py @@ -1,4 +1,5 @@ from .chunk_by_word import chunk_by_word from .chunk_by_sentence import chunk_by_sentence from .chunk_by_paragraph import chunk_by_paragraph +from .chunk_by_row import chunk_by_row from .remove_disconnected_chunks import remove_disconnected_chunks diff --git a/cognee/tasks/chunks/chunk_by_row.py b/cognee/tasks/chunks/chunk_by_row.py new file mode 100644 index 000000000..8daf13689 --- /dev/null +++ b/cognee/tasks/chunks/chunk_by_row.py @@ -0,0 +1,94 @@ +from typing import Any, Dict, Iterator +from uuid import NAMESPACE_OID, uuid5 + +from cognee.infrastructure.databases.vector.embeddings import get_embedding_engine + + +def _get_pair_size(pair_text: str) -> int: + """ + Calculate the size of a given text in terms of tokens. + + If an embedding engine's tokenizer is available, count the tokens for the provided word. + If the tokenizer is not available, assume the word counts as one token. + + Parameters: + ----------- + + - pair_text (str): The key:value pair text for which the token size is to be calculated. + + Returns: + -------- + + - int: The number of tokens representing the text, typically an integer, depending + on the tokenizer's output. 
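+
+    Note: when no tokenizer is configured, the implementation below charges a
+    flat 3 tokens per pair as a rough estimate rather than exactly one.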
+ """ + embedding_engine = get_embedding_engine() + if embedding_engine.tokenizer: + return embedding_engine.tokenizer.count_tokens(pair_text) + else: + return 3 + + +def chunk_by_row( + data: str, + max_chunk_size, +) -> Iterator[Dict[str, Any]]: + """ + Chunk the input text by row while enabling exact text reconstruction. + + This function divides the given text data into smaller chunks on a line-by-line basis, + ensuring that the size of each chunk is less than or equal to the specified maximum + chunk size. It guarantees that when the generated chunks are concatenated, they + reproduce the original text accurately. The tokenization process is handled by + adapters compatible with the vector engine's embedding model. + + Parameters: + ----------- + + - data (str): The input text to be chunked. + - max_chunk_size: The maximum allowed size for each chunk, in terms of tokens or + words. + """ + current_chunk_list = [] + chunk_index = 0 + current_chunk_size = 0 + + lines = data.split("\n\n") + for line in lines: + pairs_text = line.split(", ") + + for pair_text in pairs_text: + pair_size = _get_pair_size(pair_text) + if current_chunk_size > 0 and (current_chunk_size + pair_size > max_chunk_size): + # Yield current cut chunk + current_chunk = ", ".join(current_chunk_list) + chunk_dict = { + "text": current_chunk, + "chunk_size": current_chunk_size, + "chunk_id": uuid5(NAMESPACE_OID, current_chunk), + "chunk_index": chunk_index, + "cut_type": "row_cut", + } + + yield chunk_dict + + # Start new chunk with current pair text + current_chunk_list = [] + current_chunk_size = 0 + chunk_index += 1 + + current_chunk_list.append(pair_text) + current_chunk_size += pair_size + + # Yield row chunk + current_chunk = ", ".join(current_chunk_list) + if current_chunk: + chunk_dict = { + "text": current_chunk, + "chunk_size": current_chunk_size, + "chunk_id": uuid5(NAMESPACE_OID, current_chunk), + "chunk_index": chunk_index, + "cut_type": "row_end", + } + + yield chunk_dict diff --git a/cognee/tasks/documents/classify_documents.py b/cognee/tasks/documents/classify_documents.py index 9fa512906..e4f13ebd1 100644 --- a/cognee/tasks/documents/classify_documents.py +++ b/cognee/tasks/documents/classify_documents.py @@ -7,6 +7,7 @@ from cognee.modules.data.processing.document_types import ( ImageDocument, TextDocument, UnstructuredDocument, + CsvDocument, ) from cognee.modules.engine.models.node_set import NodeSet from cognee.modules.engine.utils.generate_node_id import generate_node_id @@ -15,6 +16,7 @@ from cognee.tasks.documents.exceptions import WrongDataDocumentInputError EXTENSION_TO_DOCUMENT_CLASS = { "pdf": PdfDocument, # Text documents "txt": TextDocument, + "csv": CsvDocument, "docx": UnstructuredDocument, "doc": UnstructuredDocument, "odt": UnstructuredDocument, diff --git a/cognee/tests/integration/documents/CsvDocument_test.py b/cognee/tests/integration/documents/CsvDocument_test.py new file mode 100644 index 000000000..421bb81bd --- /dev/null +++ b/cognee/tests/integration/documents/CsvDocument_test.py @@ -0,0 +1,70 @@ +import os +import sys +import uuid +import pytest +import pathlib +from unittest.mock import patch + +from cognee.modules.chunking.CsvChunker import CsvChunker +from cognee.modules.data.processing.document_types.CsvDocument import CsvDocument +from cognee.tests.integration.documents.AudioDocument_test import mock_get_embedding_engine +from cognee.tests.integration.documents.async_gen_zip import async_gen_zip + +chunk_by_row_module = sys.modules.get("cognee.tasks.chunks.chunk_by_row") + 
+ +GROUND_TRUTH = { + "chunk_size_10": [ + {"token_count": 9, "len_text": 26, "cut_type": "row_cut", "chunk_index": 0}, + {"token_count": 6, "len_text": 29, "cut_type": "row_end", "chunk_index": 1}, + {"token_count": 9, "len_text": 25, "cut_type": "row_cut", "chunk_index": 2}, + {"token_count": 6, "len_text": 30, "cut_type": "row_end", "chunk_index": 3}, + ], + "chunk_size_128": [ + {"token_count": 15, "len_text": 57, "cut_type": "row_end", "chunk_index": 0}, + {"token_count": 15, "len_text": 57, "cut_type": "row_end", "chunk_index": 1}, + ], +} + + +@pytest.mark.parametrize( + "input_file,chunk_size", + [("example_with_header.csv", 10), ("example_with_header.csv", 128)], +) +@patch.object(chunk_by_row_module, "get_embedding_engine", side_effect=mock_get_embedding_engine) +@pytest.mark.asyncio +async def test_CsvDocument(mock_engine, input_file, chunk_size): + # Define file paths of test data + csv_file_path = os.path.join( + pathlib.Path(__file__).parent.parent.parent, + "test_data", + input_file, + ) + + # Define test documents + csv_document = CsvDocument( + id=uuid.uuid4(), + name="example_with_header.csv", + raw_data_location=csv_file_path, + external_metadata="", + mime_type="text/csv", + ) + + # TEST CSV + ground_truth_key = f"chunk_size_{chunk_size}" + async for ground_truth, row_data in async_gen_zip( + GROUND_TRUTH[ground_truth_key], + csv_document.read(chunker_cls=CsvChunker, max_chunk_size=chunk_size), + ): + assert ground_truth["token_count"] == row_data.chunk_size, ( + f'{ground_truth["token_count"] = } != {row_data.chunk_size = }' + ) + assert ground_truth["len_text"] == len(row_data.text), ( + f'{ground_truth["len_text"] = } != {len(row_data.text) = }' + ) + assert ground_truth["cut_type"] == row_data.cut_type, ( + f'{ground_truth["cut_type"] = } != {row_data.cut_type = }' + ) + assert ground_truth["chunk_index"] == row_data.chunk_index, ( + f'{ground_truth["chunk_index"] = } != {row_data.chunk_index = }' + ) diff --git a/cognee/tests/test_data/example_with_header.csv b/cognee/tests/test_data/example_with_header.csv new file mode 100644 index 000000000..dc900e5ef --- /dev/null +++ b/cognee/tests/test_data/example_with_header.csv @@ -0,0 +1,3 @@ +id,name,age,city,country +1,Eric,30,Beijing,China +2,Joe,35,Berlin,Germany diff --git a/cognee/tests/unit/processing/chunks/chunk_by_row_test.py b/cognee/tests/unit/processing/chunks/chunk_by_row_test.py new file mode 100644 index 000000000..7d6a73a06 --- /dev/null +++ b/cognee/tests/unit/processing/chunks/chunk_by_row_test.py @@ -0,0 +1,52 @@ +from itertools import product + +import numpy as np +import pytest + +from cognee.infrastructure.databases.vector.embeddings import get_embedding_engine +from cognee.tasks.chunks import chunk_by_row + +INPUT_TEXTS = "name: John, age: 30, city: New York, country: USA" +max_chunk_size_vals = [8, 32] + + +@pytest.mark.parametrize( + "input_text,max_chunk_size", + list(product([INPUT_TEXTS], max_chunk_size_vals)), +) +def test_chunk_by_row_isomorphism(input_text, max_chunk_size): + chunks = chunk_by_row(input_text, max_chunk_size) + reconstructed_text = ", ".join([chunk["text"] for chunk in chunks]) + assert reconstructed_text == input_text, ( + f"texts are not identical: {len(input_text) = }, {len(reconstructed_text) = }" + ) + + +@pytest.mark.parametrize( + "input_text,max_chunk_size", + list(product([INPUT_TEXTS], max_chunk_size_vals)), +) +def test_row_chunk_length(input_text, max_chunk_size): + chunks = list(chunk_by_row(data=input_text, max_chunk_size=max_chunk_size)) + embedding_engine = 
get_embedding_engine() + + chunk_lengths = np.array( + [embedding_engine.tokenizer.count_tokens(chunk["text"]) for chunk in chunks] + ) + + larger_chunks = chunk_lengths[chunk_lengths > max_chunk_size] + assert np.all(chunk_lengths <= max_chunk_size), ( + f"{max_chunk_size = }: {larger_chunks} are too large" + ) + + +@pytest.mark.parametrize( + "input_text,max_chunk_size", + list(product([INPUT_TEXTS], max_chunk_size_vals)), +) +def test_chunk_by_row_chunk_numbering(input_text, max_chunk_size): + chunks = chunk_by_row(data=input_text, max_chunk_size=max_chunk_size) + chunk_indices = np.array([chunk["chunk_index"] for chunk in chunks]) + assert np.all(chunk_indices == np.arange(len(chunk_indices))), ( + f"{chunk_indices = } are not monotonically increasing" + ) From 8566516ceca89a0e85db6aa5ba967f5d8070b2c7 Mon Sep 17 00:00:00 2001 From: EricXiao Date: Wed, 22 Oct 2025 16:59:07 +0800 Subject: [PATCH 02/24] chore: Remove local test code Signed-off-by: EricXiao --- .../loaders/external/advanced_pdf_loader.py | 9 --------- 1 file changed, 9 deletions(-) diff --git a/cognee/infrastructure/loaders/external/advanced_pdf_loader.py b/cognee/infrastructure/loaders/external/advanced_pdf_loader.py index 6d1412b77..4b3ba296a 100644 --- a/cognee/infrastructure/loaders/external/advanced_pdf_loader.py +++ b/cognee/infrastructure/loaders/external/advanced_pdf_loader.py @@ -227,12 +227,3 @@ class AdvancedPdfLoader(LoaderInterface): if value is None: return "" return str(value).replace("\xa0", " ").strip() - - -if __name__ == "__main__": - loader = AdvancedPdfLoader() - asyncio.run( - loader.load( - "/Users/xiaotao/work/cognee/cognee/infrastructure/loaders/external/attention_is_all_you_need.pdf" - ) - ) From a4a9e762465ecaf0dcdb9b0132db7951d11b437c Mon Sep 17 00:00:00 2001 From: Fahad Shoaib Date: Sun, 2 Nov 2025 17:05:03 +0500 Subject: [PATCH 03/24] feat: add ontology endpoint in REST API - Add POST /api/v1/ontologies endpoint for file upload - Add GET /api/v1/ontologies endpoint for listing ontologies - Implement OntologyService for file management and metadata - Integrate ontology_key parameter in cognify endpoint - Update RDFLibOntologyResolver to support file-like objects - Add essential test suite for ontology endpoints --- cognee/api/client.py | 3 + .../v1/cognify/routers/get_cognify_router.py | 31 +++++- cognee/api/v1/ontologies/__init__.py | 4 + cognee/api/v1/ontologies/ontologies.py | 101 ++++++++++++++++++ cognee/api/v1/ontologies/routers/__init__.py | 0 .../ontologies/routers/get_ontology_router.py | 89 +++++++++++++++ .../rdf_xml/RDFLibOntologyResolver.py | 69 +++++++----- cognee/tests/test_ontology_endpoint.py | 89 +++++++++++++++ 8 files changed, 356 insertions(+), 30 deletions(-) create mode 100644 cognee/api/v1/ontologies/__init__.py create mode 100644 cognee/api/v1/ontologies/ontologies.py create mode 100644 cognee/api/v1/ontologies/routers/__init__.py create mode 100644 cognee/api/v1/ontologies/routers/get_ontology_router.py create mode 100644 cognee/tests/test_ontology_endpoint.py diff --git a/cognee/api/client.py b/cognee/api/client.py index 6766c12de..89e9eb2f5 100644 --- a/cognee/api/client.py +++ b/cognee/api/client.py @@ -23,6 +23,7 @@ from cognee.api.v1.settings.routers import get_settings_router from cognee.api.v1.datasets.routers import get_datasets_router from cognee.api.v1.cognify.routers import get_code_pipeline_router, get_cognify_router from cognee.api.v1.search.routers import get_search_router +from cognee.api.v1.ontologies.routers.get_ontology_router import 
get_ontology_router from cognee.api.v1.memify.routers import get_memify_router from cognee.api.v1.add.routers import get_add_router from cognee.api.v1.delete.routers import get_delete_router @@ -258,6 +259,8 @@ app.include_router( app.include_router(get_datasets_router(), prefix="/api/v1/datasets", tags=["datasets"]) +app.include_router(get_ontology_router(), prefix="/api/v1/ontologies", tags=["ontologies"]) + app.include_router(get_settings_router(), prefix="/api/v1/settings", tags=["settings"]) app.include_router(get_visualize_router(), prefix="/api/v1/visualize", tags=["visualize"]) diff --git a/cognee/api/v1/cognify/routers/get_cognify_router.py b/cognee/api/v1/cognify/routers/get_cognify_router.py index 231bbcd11..246cc6c56 100644 --- a/cognee/api/v1/cognify/routers/get_cognify_router.py +++ b/cognee/api/v1/cognify/routers/get_cognify_router.py @@ -41,6 +41,10 @@ class CognifyPayloadDTO(InDTO): custom_prompt: Optional[str] = Field( default="", description="Custom prompt for entity extraction and graph generation" ) + ontology_key: Optional[str] = Field( + default=None, + description="Reference to previously uploaded ontology" + ) def get_cognify_router() -> APIRouter: @@ -68,6 +72,7 @@ def get_cognify_router() -> APIRouter: - **dataset_ids** (Optional[List[UUID]]): List of existing dataset UUIDs to process. UUIDs allow processing of datasets not owned by the user (if permitted). - **run_in_background** (Optional[bool]): Whether to execute processing asynchronously. Defaults to False (blocking). - **custom_prompt** (Optional[str]): Custom prompt for entity extraction and graph generation. If provided, this prompt will be used instead of the default prompts for knowledge graph extraction. + - **ontology_key** (Optional[str]): Reference to a previously uploaded ontology file to use for knowledge graph construction. ## Response - **Blocking execution**: Complete pipeline run information with entity counts, processing duration, and success/failure status @@ -82,7 +87,8 @@ def get_cognify_router() -> APIRouter: { "datasets": ["research_papers", "documentation"], "run_in_background": false, - "custom_prompt": "Extract entities focusing on technical concepts and their relationships. Identify key technologies, methodologies, and their interconnections." + "custom_prompt": "Extract entities focusing on technical concepts and their relationships. 
Identify key technologies, methodologies, and their interconnections.", + "ontology_key": "medical_ontology_v1" } ``` @@ -108,13 +114,36 @@ def get_cognify_router() -> APIRouter: ) from cognee.api.v1.cognify import cognify as cognee_cognify + from cognee.api.v1.ontologies.ontologies import OntologyService try: datasets = payload.dataset_ids if payload.dataset_ids else payload.datasets + config_to_use = None + + if payload.ontology_key: + ontology_service = OntologyService() + try: + ontology_content = ontology_service.get_ontology_content(payload.ontology_key, user) + + from cognee.modules.ontology.ontology_config import Config + from cognee.modules.ontology.rdf_xml.RDFLibOntologyResolver import RDFLibOntologyResolver + from io import StringIO + + ontology_stream = StringIO(ontology_content) + config_to_use: Config = { + "ontology_config": { + "ontology_resolver": RDFLibOntologyResolver(ontology_file=ontology_stream) + } + } + except ValueError as e: + return JSONResponse( + status_code=400, content={"error": f"Ontology error: {str(e)}"} + ) cognify_run = await cognee_cognify( datasets, user, + config=config_to_use, run_in_background=payload.run_in_background, custom_prompt=payload.custom_prompt, ) diff --git a/cognee/api/v1/ontologies/__init__.py b/cognee/api/v1/ontologies/__init__.py new file mode 100644 index 000000000..c25064edc --- /dev/null +++ b/cognee/api/v1/ontologies/__init__.py @@ -0,0 +1,4 @@ +from .ontologies import OntologyService +from .routers.get_ontology_router import get_ontology_router + +__all__ = ["OntologyService", "get_ontology_router"] \ No newline at end of file diff --git a/cognee/api/v1/ontologies/ontologies.py b/cognee/api/v1/ontologies/ontologies.py new file mode 100644 index 000000000..fb7f3cd9a --- /dev/null +++ b/cognee/api/v1/ontologies/ontologies.py @@ -0,0 +1,101 @@ +import os +import json +import tempfile +from pathlib import Path +from datetime import datetime, timezone +from typing import Optional +from dataclasses import dataclass + +@dataclass +class OntologyMetadata: + ontology_key: str + filename: str + size_bytes: int + uploaded_at: str + description: Optional[str] = None + +class OntologyService: + def __init__(self): + pass + + @property + def base_dir(self) -> Path: + return Path(tempfile.gettempdir()) / "ontologies" + + def _get_user_dir(self, user_id: str) -> Path: + user_dir = self.base_dir / str(user_id) + user_dir.mkdir(parents=True, exist_ok=True) + return user_dir + + def _get_metadata_path(self, user_dir: Path) -> Path: + return user_dir / "metadata.json" + + def _load_metadata(self, user_dir: Path) -> dict: + metadata_path = self._get_metadata_path(user_dir) + if metadata_path.exists(): + with open(metadata_path, 'r') as f: + return json.load(f) + return {} + + def _save_metadata(self, user_dir: Path, metadata: dict): + metadata_path = self._get_metadata_path(user_dir) + with open(metadata_path, 'w') as f: + json.dump(metadata, f, indent=2) + + async def upload_ontology(self, ontology_key: str, file, user, description: Optional[str] = None) -> OntologyMetadata: + # Validate file format + if not file.filename.lower().endswith('.owl'): + raise ValueError("File must be in .owl format") + + user_dir = self._get_user_dir(str(user.id)) + metadata = self._load_metadata(user_dir) + + # Check for duplicate key + if ontology_key in metadata: + raise ValueError(f"Ontology key '{ontology_key}' already exists") + + # Read file content + content = await file.read() + if len(content) > 10 * 1024 * 1024: # 10MB limit + raise ValueError("File size 
exceeds 10MB limit") + + # Save file + file_path = user_dir / f"{ontology_key}.owl" + with open(file_path, 'wb') as f: + f.write(content) + + # Update metadata + ontology_metadata = { + "filename": file.filename, + "size_bytes": len(content), + "uploaded_at": datetime.now(timezone.utc).isoformat(), + "description": description + } + metadata[ontology_key] = ontology_metadata + self._save_metadata(user_dir, metadata) + + return OntologyMetadata( + ontology_key=ontology_key, + filename=file.filename, + size_bytes=len(content), + uploaded_at=ontology_metadata["uploaded_at"], + description=description + ) + + def get_ontology_content(self, ontology_key: str, user) -> str: + user_dir = self._get_user_dir(str(user.id)) + metadata = self._load_metadata(user_dir) + + if ontology_key not in metadata: + raise ValueError(f"Ontology key '{ontology_key}' not found") + + file_path = user_dir / f"{ontology_key}.owl" + if not file_path.exists(): + raise ValueError(f"Ontology file for key '{ontology_key}' not found") + + with open(file_path, 'r', encoding='utf-8') as f: + return f.read() + + def list_ontologies(self, user) -> dict: + user_dir = self._get_user_dir(str(user.id)) + return self._load_metadata(user_dir) \ No newline at end of file diff --git a/cognee/api/v1/ontologies/routers/__init__.py b/cognee/api/v1/ontologies/routers/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/cognee/api/v1/ontologies/routers/get_ontology_router.py b/cognee/api/v1/ontologies/routers/get_ontology_router.py new file mode 100644 index 000000000..c171fa7bb --- /dev/null +++ b/cognee/api/v1/ontologies/routers/get_ontology_router.py @@ -0,0 +1,89 @@ +from fastapi import APIRouter, File, Form, UploadFile, Depends, HTTPException +from fastapi.responses import JSONResponse +from typing import Optional + +from cognee.modules.users.models import User +from cognee.modules.users.methods import get_authenticated_user +from cognee.shared.utils import send_telemetry +from cognee import __version__ as cognee_version +from ..ontologies import OntologyService + +def get_ontology_router() -> APIRouter: + router = APIRouter() + ontology_service = OntologyService() + + @router.post("", response_model=dict) + async def upload_ontology( + ontology_key: str = Form(...), + ontology_file: UploadFile = File(...), + description: Optional[str] = Form(None), + user: User = Depends(get_authenticated_user) + ): + """ + Upload an ontology file with a named key for later use in cognify operations. + + ## Request Parameters + - **ontology_key** (str): User-defined identifier for the ontology + - **ontology_file** (UploadFile): OWL format ontology file + - **description** (Optional[str]): Optional description of the ontology + + ## Response + Returns metadata about the uploaded ontology including key, filename, size, and upload timestamp. 
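+
+        ## Example Response (illustrative values)
+        ```json
+        {
+            "ontology_key": "medical_ontology_v1",
+            "filename": "medical.owl",
+            "size_bytes": 2048,
+            "uploaded_at": "2025-11-02T12:00:00+00:00"
+        }
+        ```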
+ + ## Error Codes + - **400 Bad Request**: Invalid file format, duplicate key, file size exceeded + - **500 Internal Server Error**: File system or processing errors + """ + send_telemetry( + "Ontology Upload API Endpoint Invoked", + user.id, + additional_properties={ + "endpoint": "POST /api/v1/ontologies", + "cognee_version": cognee_version, + }, + ) + + try: + result = await ontology_service.upload_ontology( + ontology_key, ontology_file, user, description + ) + return { + "ontology_key": result.ontology_key, + "filename": result.filename, + "size_bytes": result.size_bytes, + "uploaded_at": result.uploaded_at + } + except ValueError as e: + return JSONResponse(status_code=400, content={"error": str(e)}) + except Exception as e: + return JSONResponse(status_code=500, content={"error": str(e)}) + + @router.get("", response_model=dict) + async def list_ontologies( + user: User = Depends(get_authenticated_user) + ): + """ + List all uploaded ontologies for the authenticated user. + + ## Response + Returns a dictionary mapping ontology keys to their metadata including filename, size, and upload timestamp. + + ## Error Codes + - **500 Internal Server Error**: File system or processing errors + """ + send_telemetry( + "Ontology List API Endpoint Invoked", + user.id, + additional_properties={ + "endpoint": "GET /api/v1/ontologies", + "cognee_version": cognee_version, + }, + ) + + try: + metadata = ontology_service.list_ontologies(user) + return metadata + except Exception as e: + return JSONResponse(status_code=500, content={"error": str(e)}) + + return router \ No newline at end of file diff --git a/cognee/modules/ontology/rdf_xml/RDFLibOntologyResolver.py b/cognee/modules/ontology/rdf_xml/RDFLibOntologyResolver.py index 45e32936a..4acc8861b 100644 --- a/cognee/modules/ontology/rdf_xml/RDFLibOntologyResolver.py +++ b/cognee/modules/ontology/rdf_xml/RDFLibOntologyResolver.py @@ -2,7 +2,7 @@ import os import difflib from cognee.shared.logging_utils import get_logger from collections import deque -from typing import List, Tuple, Dict, Optional, Any, Union +from typing import List, Tuple, Dict, Optional, Any, Union, IO from rdflib import Graph, URIRef, RDF, RDFS, OWL from cognee.modules.ontology.exceptions import ( @@ -26,44 +26,55 @@ class RDFLibOntologyResolver(BaseOntologyResolver): def __init__( self, - ontology_file: Optional[Union[str, List[str]]] = None, + ontology_file: Optional[Union[str, List[str], IO]] = None, matching_strategy: Optional[MatchingStrategy] = None, ) -> None: super().__init__(matching_strategy) self.ontology_file = ontology_file try: - files_to_load = [] + self.graph = None if ontology_file is not None: - if isinstance(ontology_file, str): - files_to_load = [ontology_file] - elif isinstance(ontology_file, list): - files_to_load = ontology_file + if hasattr(ontology_file, "read"): + self.graph = Graph() + content = ontology_file.read() + self.graph.parse(data=content, format="xml") + logger.info("Ontology loaded successfully from file object") else: - raise ValueError( - f"ontology_file must be a string, list of strings, or None. 
Got: {type(ontology_file)}" - ) - - if files_to_load: - self.graph = Graph() - loaded_files = [] - for file_path in files_to_load: - if os.path.exists(file_path): - self.graph.parse(file_path) - loaded_files.append(file_path) - logger.info("Ontology loaded successfully from file: %s", file_path) + files_to_load = [] + if isinstance(ontology_file, str): + files_to_load = [ontology_file] + elif isinstance(ontology_file, list): + files_to_load = ontology_file else: - logger.warning( - "Ontology file '%s' not found. Skipping this file.", - file_path, + raise ValueError( + f"ontology_file must be a string, list of strings, file-like object, or None. Got: {type(ontology_file)}" ) - if not loaded_files: - logger.info( - "No valid ontology files found. No owl ontology will be attached to the graph." - ) - self.graph = None - else: - logger.info("Total ontology files loaded: %d", len(loaded_files)) + if files_to_load: + self.graph = Graph() + loaded_files = [] + for file_path in files_to_load: + if os.path.exists(file_path): + self.graph.parse(file_path) + loaded_files.append(file_path) + logger.info("Ontology loaded successfully from file: %s", file_path) + else: + logger.warning( + "Ontology file '%s' not found. Skipping this file.", + file_path, + ) + + if not loaded_files: + logger.info( + "No valid ontology files found. No owl ontology will be attached to the graph." + ) + self.graph = None + else: + logger.info("Total ontology files loaded: %d", len(loaded_files)) + else: + logger.info( + "No ontology file provided. No owl ontology will be attached to the graph." + ) else: logger.info( "No ontology file provided. No owl ontology will be attached to the graph." diff --git a/cognee/tests/test_ontology_endpoint.py b/cognee/tests/test_ontology_endpoint.py new file mode 100644 index 000000000..4849f8649 --- /dev/null +++ b/cognee/tests/test_ontology_endpoint.py @@ -0,0 +1,89 @@ +import pytest +import uuid +from fastapi.testclient import TestClient +from unittest.mock import patch, Mock, AsyncMock +from types import SimpleNamespace +import importlib +from cognee.api.client import app + +gau_mod = importlib.import_module("cognee.modules.users.methods.get_authenticated_user") + +@pytest.fixture +def client(): + return TestClient(app) + +@pytest.fixture +def mock_user(): + user = Mock() + user.id = "test-user-123" + return user + +@pytest.fixture +def mock_default_user(): + """Mock default user for testing.""" + return SimpleNamespace( + id=uuid.uuid4(), + email="default@example.com", + is_active=True, + tenant_id=uuid.uuid4() + ) + +@patch.object(gau_mod, "get_default_user", new_callable=AsyncMock) +def test_upload_ontology_success(mock_get_default_user, client, mock_default_user): + """Test successful ontology upload""" + mock_get_default_user.return_value = mock_default_user + ontology_content = b"" + unique_key = f"test_ontology_{uuid.uuid4().hex[:8]}" + + response = client.post( + "/api/v1/ontologies", + files={"ontology_file": ("test.owl", ontology_content)}, + data={"ontology_key": unique_key, "description": "Test"} + ) + + assert response.status_code == 200 + data = response.json() + assert data["ontology_key"] == unique_key + assert "uploaded_at" in data + +@patch.object(gau_mod, "get_default_user", new_callable=AsyncMock) +def test_upload_ontology_invalid_file(mock_get_default_user, client, mock_default_user): + """Test 400 response for non-.owl files""" + mock_get_default_user.return_value = mock_default_user + unique_key = f"test_ontology_{uuid.uuid4().hex[:8]}" + response = client.post( + 
"/api/v1/ontologies", + files={"ontology_file": ("test.txt", b"not xml")}, + data={"ontology_key": unique_key} + ) + assert response.status_code == 400 + +@patch.object(gau_mod, "get_default_user", new_callable=AsyncMock) +def test_upload_ontology_missing_data(mock_get_default_user, client, mock_default_user): + """Test 400 response for missing file or key""" + mock_get_default_user.return_value = mock_default_user + # Missing file + response = client.post("/api/v1/ontologies", data={"ontology_key": "test"}) + assert response.status_code == 400 + + # Missing key + response = client.post("/api/v1/ontologies", files={"ontology_file": ("test.owl", b"xml")}) + assert response.status_code == 400 + +@patch.object(gau_mod, "get_default_user", new_callable=AsyncMock) +def test_upload_ontology_unauthorized(mock_get_default_user, client, mock_default_user): + """Test behavior when default user is provided (no explicit authentication)""" + unique_key = f"test_ontology_{uuid.uuid4().hex[:8]}" + mock_get_default_user.return_value = mock_default_user + response = client.post( + "/api/v1/ontologies", + files={"ontology_file": ("test.owl", b"")}, + data={"ontology_key": unique_key} + ) + + # The current system provides a default user when no explicit authentication is given + # This test verifies the system works with conditional authentication + assert response.status_code == 200 + data = response.json() + assert data["ontology_key"] == unique_key + assert "uploaded_at" in data \ No newline at end of file From 79bd2b2576b913528feb92ca6832242133bf9822 Mon Sep 17 00:00:00 2001 From: hajdul88 <52442977+hajdul88@users.noreply.github.com> Date: Thu, 6 Nov 2025 09:48:01 +0100 Subject: [PATCH 04/24] chore: fixes ruff formatting --- .../v1/cognify/routers/get_cognify_router.py | 15 ++++++++---- cognee/api/v1/ontologies/__init__.py | 2 +- cognee/api/v1/ontologies/ontologies.py | 22 ++++++++++------- .../ontologies/routers/get_ontology_router.py | 11 ++++----- cognee/tests/test_ontology_endpoint.py | 24 ++++++++++++------- 5 files changed, 44 insertions(+), 30 deletions(-) diff --git a/cognee/api/v1/cognify/routers/get_cognify_router.py b/cognee/api/v1/cognify/routers/get_cognify_router.py index 246cc6c56..252ffe7bf 100644 --- a/cognee/api/v1/cognify/routers/get_cognify_router.py +++ b/cognee/api/v1/cognify/routers/get_cognify_router.py @@ -42,8 +42,7 @@ class CognifyPayloadDTO(InDTO): default="", description="Custom prompt for entity extraction and graph generation" ) ontology_key: Optional[str] = Field( - default=None, - description="Reference to previously uploaded ontology" + default=None, description="Reference to previously uploaded ontology" ) @@ -123,16 +122,22 @@ def get_cognify_router() -> APIRouter: if payload.ontology_key: ontology_service = OntologyService() try: - ontology_content = ontology_service.get_ontology_content(payload.ontology_key, user) + ontology_content = ontology_service.get_ontology_content( + payload.ontology_key, user + ) from cognee.modules.ontology.ontology_config import Config - from cognee.modules.ontology.rdf_xml.RDFLibOntologyResolver import RDFLibOntologyResolver + from cognee.modules.ontology.rdf_xml.RDFLibOntologyResolver import ( + RDFLibOntologyResolver, + ) from io import StringIO ontology_stream = StringIO(ontology_content) config_to_use: Config = { "ontology_config": { - "ontology_resolver": RDFLibOntologyResolver(ontology_file=ontology_stream) + "ontology_resolver": RDFLibOntologyResolver( + ontology_file=ontology_stream + ) } } except ValueError as e: diff --git 
a/cognee/api/v1/ontologies/__init__.py b/cognee/api/v1/ontologies/__init__.py index c25064edc..b90d46c3d 100644 --- a/cognee/api/v1/ontologies/__init__.py +++ b/cognee/api/v1/ontologies/__init__.py @@ -1,4 +1,4 @@ from .ontologies import OntologyService from .routers.get_ontology_router import get_ontology_router -__all__ = ["OntologyService", "get_ontology_router"] \ No newline at end of file +__all__ = ["OntologyService", "get_ontology_router"] diff --git a/cognee/api/v1/ontologies/ontologies.py b/cognee/api/v1/ontologies/ontologies.py index fb7f3cd9a..6bfb7658e 100644 --- a/cognee/api/v1/ontologies/ontologies.py +++ b/cognee/api/v1/ontologies/ontologies.py @@ -6,6 +6,7 @@ from datetime import datetime, timezone from typing import Optional from dataclasses import dataclass + @dataclass class OntologyMetadata: ontology_key: str @@ -14,6 +15,7 @@ class OntologyMetadata: uploaded_at: str description: Optional[str] = None + class OntologyService: def __init__(self): pass @@ -33,18 +35,20 @@ class OntologyService: def _load_metadata(self, user_dir: Path) -> dict: metadata_path = self._get_metadata_path(user_dir) if metadata_path.exists(): - with open(metadata_path, 'r') as f: + with open(metadata_path, "r") as f: return json.load(f) return {} def _save_metadata(self, user_dir: Path, metadata: dict): metadata_path = self._get_metadata_path(user_dir) - with open(metadata_path, 'w') as f: + with open(metadata_path, "w") as f: json.dump(metadata, f, indent=2) - async def upload_ontology(self, ontology_key: str, file, user, description: Optional[str] = None) -> OntologyMetadata: + async def upload_ontology( + self, ontology_key: str, file, user, description: Optional[str] = None + ) -> OntologyMetadata: # Validate file format - if not file.filename.lower().endswith('.owl'): + if not file.filename.lower().endswith(".owl"): raise ValueError("File must be in .owl format") user_dir = self._get_user_dir(str(user.id)) @@ -61,7 +65,7 @@ class OntologyService: # Save file file_path = user_dir / f"{ontology_key}.owl" - with open(file_path, 'wb') as f: + with open(file_path, "wb") as f: f.write(content) # Update metadata @@ -69,7 +73,7 @@ class OntologyService: "filename": file.filename, "size_bytes": len(content), "uploaded_at": datetime.now(timezone.utc).isoformat(), - "description": description + "description": description, } metadata[ontology_key] = ontology_metadata self._save_metadata(user_dir, metadata) @@ -79,7 +83,7 @@ class OntologyService: filename=file.filename, size_bytes=len(content), uploaded_at=ontology_metadata["uploaded_at"], - description=description + description=description, ) def get_ontology_content(self, ontology_key: str, user) -> str: @@ -93,9 +97,9 @@ class OntologyService: if not file_path.exists(): raise ValueError(f"Ontology file for key '{ontology_key}' not found") - with open(file_path, 'r', encoding='utf-8') as f: + with open(file_path, "r", encoding="utf-8") as f: return f.read() def list_ontologies(self, user) -> dict: user_dir = self._get_user_dir(str(user.id)) - return self._load_metadata(user_dir) \ No newline at end of file + return self._load_metadata(user_dir) diff --git a/cognee/api/v1/ontologies/routers/get_ontology_router.py b/cognee/api/v1/ontologies/routers/get_ontology_router.py index c171fa7bb..f5c51ba21 100644 --- a/cognee/api/v1/ontologies/routers/get_ontology_router.py +++ b/cognee/api/v1/ontologies/routers/get_ontology_router.py @@ -8,6 +8,7 @@ from cognee.shared.utils import send_telemetry from cognee import __version__ as cognee_version from ..ontologies 
import OntologyService + def get_ontology_router() -> APIRouter: router = APIRouter() ontology_service = OntologyService() @@ -17,7 +18,7 @@ def get_ontology_router() -> APIRouter: ontology_key: str = Form(...), ontology_file: UploadFile = File(...), description: Optional[str] = Form(None), - user: User = Depends(get_authenticated_user) + user: User = Depends(get_authenticated_user), ): """ Upload an ontology file with a named key for later use in cognify operations. @@ -51,7 +52,7 @@ def get_ontology_router() -> APIRouter: "ontology_key": result.ontology_key, "filename": result.filename, "size_bytes": result.size_bytes, - "uploaded_at": result.uploaded_at + "uploaded_at": result.uploaded_at, } except ValueError as e: return JSONResponse(status_code=400, content={"error": str(e)}) @@ -59,9 +60,7 @@ def get_ontology_router() -> APIRouter: return JSONResponse(status_code=500, content={"error": str(e)}) @router.get("", response_model=dict) - async def list_ontologies( - user: User = Depends(get_authenticated_user) - ): + async def list_ontologies(user: User = Depends(get_authenticated_user)): """ List all uploaded ontologies for the authenticated user. @@ -86,4 +85,4 @@ def get_ontology_router() -> APIRouter: except Exception as e: return JSONResponse(status_code=500, content={"error": str(e)}) - return router \ No newline at end of file + return router diff --git a/cognee/tests/test_ontology_endpoint.py b/cognee/tests/test_ontology_endpoint.py index 4849f8649..b5cedfafe 100644 --- a/cognee/tests/test_ontology_endpoint.py +++ b/cognee/tests/test_ontology_endpoint.py @@ -8,37 +8,40 @@ from cognee.api.client import app gau_mod = importlib.import_module("cognee.modules.users.methods.get_authenticated_user") + @pytest.fixture def client(): return TestClient(app) + @pytest.fixture def mock_user(): user = Mock() user.id = "test-user-123" return user + @pytest.fixture def mock_default_user(): """Mock default user for testing.""" return SimpleNamespace( - id=uuid.uuid4(), - email="default@example.com", - is_active=True, - tenant_id=uuid.uuid4() + id=uuid.uuid4(), email="default@example.com", is_active=True, tenant_id=uuid.uuid4() ) + @patch.object(gau_mod, "get_default_user", new_callable=AsyncMock) def test_upload_ontology_success(mock_get_default_user, client, mock_default_user): """Test successful ontology upload""" mock_get_default_user.return_value = mock_default_user - ontology_content = b"" + ontology_content = ( + b"" + ) unique_key = f"test_ontology_{uuid.uuid4().hex[:8]}" response = client.post( "/api/v1/ontologies", files={"ontology_file": ("test.owl", ontology_content)}, - data={"ontology_key": unique_key, "description": "Test"} + data={"ontology_key": unique_key, "description": "Test"}, ) assert response.status_code == 200 @@ -46,6 +49,7 @@ def test_upload_ontology_success(mock_get_default_user, client, mock_default_use assert data["ontology_key"] == unique_key assert "uploaded_at" in data + @patch.object(gau_mod, "get_default_user", new_callable=AsyncMock) def test_upload_ontology_invalid_file(mock_get_default_user, client, mock_default_user): """Test 400 response for non-.owl files""" @@ -54,10 +58,11 @@ def test_upload_ontology_invalid_file(mock_get_default_user, client, mock_defaul response = client.post( "/api/v1/ontologies", files={"ontology_file": ("test.txt", b"not xml")}, - data={"ontology_key": unique_key} + data={"ontology_key": unique_key}, ) assert response.status_code == 400 + @patch.object(gau_mod, "get_default_user", new_callable=AsyncMock) def 
test_upload_ontology_missing_data(mock_get_default_user, client, mock_default_user): """Test 400 response for missing file or key""" @@ -70,6 +75,7 @@ def test_upload_ontology_missing_data(mock_get_default_user, client, mock_defaul response = client.post("/api/v1/ontologies", files={"ontology_file": ("test.owl", b"xml")}) assert response.status_code == 400 + @patch.object(gau_mod, "get_default_user", new_callable=AsyncMock) def test_upload_ontology_unauthorized(mock_get_default_user, client, mock_default_user): """Test behavior when default user is provided (no explicit authentication)""" @@ -78,7 +84,7 @@ def test_upload_ontology_unauthorized(mock_get_default_user, client, mock_defaul response = client.post( "/api/v1/ontologies", files={"ontology_file": ("test.owl", b"")}, - data={"ontology_key": unique_key} + data={"ontology_key": unique_key}, ) # The current system provides a default user when no explicit authentication is given @@ -86,4 +92,4 @@ def test_upload_ontology_unauthorized(mock_get_default_user, client, mock_defaul assert response.status_code == 200 data = response.json() assert data["ontology_key"] == unique_key - assert "uploaded_at" in data \ No newline at end of file + assert "uploaded_at" in data From a058250c95390df84f1c44419cff8adf3a4e8269 Mon Sep 17 00:00:00 2001 From: Boris Arzentar Date: Thu, 6 Nov 2025 13:03:11 +0100 Subject: [PATCH 05/24] fix: add cognee to the local run environment --- cognee/modules/notebooks/operations/run_in_local_sandbox.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/cognee/modules/notebooks/operations/run_in_local_sandbox.py b/cognee/modules/notebooks/operations/run_in_local_sandbox.py index 071deafb7..46499186e 100644 --- a/cognee/modules/notebooks/operations/run_in_local_sandbox.py +++ b/cognee/modules/notebooks/operations/run_in_local_sandbox.py @@ -2,6 +2,8 @@ import io import sys import traceback +import cognee + def wrap_in_async_handler(user_code: str) -> str: return ( @@ -34,6 +36,7 @@ def run_in_local_sandbox(code, environment=None, loop=None): environment["print"] = customPrintFunction environment["running_loop"] = loop + environment["cognee"] = cognee try: exec(code, environment) From f9cde2f375be2accf6c9bd7fb5f5c681971f692a Mon Sep 17 00:00:00 2001 From: Pavel Zorin Date: Thu, 13 Nov 2025 13:35:07 +0100 Subject: [PATCH 06/24] Fix: Remove cognee script from pyproject.toml --- pyproject.toml | 1 - 1 file changed, 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 13266f83e..2436911e8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -156,7 +156,6 @@ Homepage = "https://www.cognee.ai" Repository = "https://github.com/topoteretes/cognee" [project.scripts] -cognee = "cognee.cli._cognee:main" cognee-cli = "cognee.cli._cognee:main" [build-system] From 3b7d030817cea67f08af121d936a9e31312ae38c Mon Sep 17 00:00:00 2001 From: Andrej Milicevic Date: Thu, 13 Nov 2025 16:06:07 +0100 Subject: [PATCH 07/24] fix: remove duplicate mistral adapter creation --- .../litellm_instructor/llm/get_llm_client.py | 15 --------------- 1 file changed, 15 deletions(-) diff --git a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/get_llm_client.py b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/get_llm_client.py index c7dcecc56..bbdfe49e9 100644 --- a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/get_llm_client.py +++ b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/get_llm_client.py @@ -162,20 +162,5 @@ def 
get_llm_client(raise_api_key_error: bool = True):
 endpoint=llm_config.llm_endpoint,
 )

- elif provider == LLMProvider.MISTRAL:
- if llm_config.llm_api_key is None:
- raise LLMAPIKeyNotSetError()
-
- from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.mistral.adapter import (
- MistralAdapter,
- )
-
- return MistralAdapter(
- api_key=llm_config.llm_api_key,
- model=llm_config.llm_model,
- max_completion_tokens=max_completion_tokens,
- endpoint=llm_config.llm_endpoint,
- )
-
 else:
 raise UnsupportedLLMProviderError(provider)

From c6454338f9374c0e871938523eca237d6e5a1d16 Mon Sep 17 00:00:00 2001
From: Pavel Zorin
Date: Thu, 13 Nov 2025 17:35:16 +0100
Subject: [PATCH 08/24] Fix: MCP remove cognee.add() prerequisite from the doc

---
 cognee-mcp/src/server.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/cognee-mcp/src/server.py b/cognee-mcp/src/server.py
index 7c708638c..4131be988 100755
--- a/cognee-mcp/src/server.py
+++ b/cognee-mcp/src/server.py
@@ -194,7 +194,6 @@ async def cognify(
 Prerequisites:
 - **LLM_API_KEY**: Must be configured (required for entity extraction and graph generation)
- - **Data Added**: Must have data previously added via `cognee.add()`
 - **Vector Database**: Must be accessible for embeddings storage
 - **Graph Database**: Must be accessible for relationship storage

From 2337d36f7b3968cfeff06b00613f7464c8d0ca93 Mon Sep 17 00:00:00 2001
From: Andrej Milicevic
Date: Thu, 13 Nov 2025 18:25:07 +0100
Subject: [PATCH 09/24] feat: add variable to control instructor mode

---
 cognee/infrastructure/llm/config.py | 2 ++
 .../litellm_instructor/llm/anthropic/adapter.py | 8 +++++++-
 .../litellm_instructor/llm/gemini/adapter.py | 12 +++++++++++-
 .../llm/generic_llm_api/adapter.py | 12 +++++++++++-
 .../litellm_instructor/llm/mistral/adapter.py | 8 +++++++-
 .../litellm_instructor/llm/ollama/adapter.py | 12 +++++++++++-
 .../litellm_instructor/llm/openai/adapter.py | 11 +++++++++--
 7 files changed, 58 insertions(+), 7 deletions(-)

diff --git a/cognee/infrastructure/llm/config.py b/cognee/infrastructure/llm/config.py
index 8fd196eaf..c87054ff6 100644
--- a/cognee/infrastructure/llm/config.py
+++ b/cognee/infrastructure/llm/config.py
@@ -38,6 +38,7 @@ class LLMConfig(BaseSettings):
 """

 structured_output_framework: str = "instructor"
+ llm_instructor_mode: Optional[str] = None
 llm_provider: str = "openai"
 llm_model: str = "openai/gpt-5-mini"
 llm_endpoint: str = ""
@@ -181,6 +182,7 @@ class LLMConfig(BaseSettings):
 instance.
""" return { + "llm_instructor_mode": self.llm_instructor_mode, "provider": self.llm_provider, "model": self.llm_model, "endpoint": self.llm_endpoint, diff --git a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/anthropic/adapter.py b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/anthropic/adapter.py index bf19d6e86..6fb78718e 100644 --- a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/anthropic/adapter.py +++ b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/anthropic/adapter.py @@ -28,13 +28,19 @@ class AnthropicAdapter(LLMInterface): name = "Anthropic" model: str + default_instructor_mode = "anthropic_tools" def __init__(self, max_completion_tokens: int, model: str = None): import anthropic + config_instructor_mode = get_llm_config().llm_instructor_mode + instructor_mode = ( + config_instructor_mode if config_instructor_mode else self.default_instructor_mode + ) + self.aclient = instructor.patch( create=anthropic.AsyncAnthropic(api_key=get_llm_config().llm_api_key).messages.create, - mode=instructor.Mode.ANTHROPIC_TOOLS, + mode=instructor.Mode(instructor_mode), ) self.model = model diff --git a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/gemini/adapter.py b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/gemini/adapter.py index 1187e0cad..68dddc7b7 100644 --- a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/gemini/adapter.py +++ b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/gemini/adapter.py @@ -41,6 +41,7 @@ class GeminiAdapter(LLMInterface): name: str model: str api_key: str + default_instructor_mode = "json_mode" def __init__( self, @@ -63,7 +64,16 @@ class GeminiAdapter(LLMInterface): self.fallback_api_key = fallback_api_key self.fallback_endpoint = fallback_endpoint - self.aclient = instructor.from_litellm(litellm.acompletion, mode=instructor.Mode.JSON) + from cognee.infrastructure.llm.config import get_llm_config + + config_instructor_mode = get_llm_config().llm_instructor_mode + instructor_mode = ( + config_instructor_mode if config_instructor_mode else self.default_instructor_mode + ) + + self.aclient = instructor.from_litellm( + litellm.acompletion, mode=instructor.Mode(instructor_mode) + ) @retry( stop=stop_after_delay(128), diff --git a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/generic_llm_api/adapter.py b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/generic_llm_api/adapter.py index 8bbbaa2cc..ea32dced1 100644 --- a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/generic_llm_api/adapter.py +++ b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/generic_llm_api/adapter.py @@ -41,6 +41,7 @@ class GenericAPIAdapter(LLMInterface): name: str model: str api_key: str + default_instructor_mode = "json_mode" def __init__( self, @@ -63,7 +64,16 @@ class GenericAPIAdapter(LLMInterface): self.fallback_api_key = fallback_api_key self.fallback_endpoint = fallback_endpoint - self.aclient = instructor.from_litellm(litellm.acompletion, mode=instructor.Mode.JSON) + from cognee.infrastructure.llm.config import get_llm_config + + config_instructor_mode = get_llm_config().llm_instructor_mode + instructor_mode = ( + config_instructor_mode if config_instructor_mode else self.default_instructor_mode + ) + + self.aclient = 
instructor.from_litellm( + litellm.acompletion, mode=instructor.Mode(instructor_mode) + ) @retry( stop=stop_after_delay(128), diff --git a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/mistral/adapter.py b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/mistral/adapter.py index 78a3cbff5..bed88ce3c 100644 --- a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/mistral/adapter.py +++ b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/mistral/adapter.py @@ -37,6 +37,7 @@ class MistralAdapter(LLMInterface): model: str api_key: str max_completion_tokens: int + default_instructor_mode = "mistral_tools" def __init__(self, api_key: str, model: str, max_completion_tokens: int, endpoint: str = None): from mistralai import Mistral @@ -44,9 +45,14 @@ class MistralAdapter(LLMInterface): self.model = model self.max_completion_tokens = max_completion_tokens + config_instructor_mode = get_llm_config().llm_instructor_mode + instructor_mode = ( + config_instructor_mode if config_instructor_mode else self.default_instructor_mode + ) + self.aclient = instructor.from_litellm( litellm.acompletion, - mode=instructor.Mode.MISTRAL_TOOLS, + mode=instructor.Mode(instructor_mode), api_key=get_llm_config().llm_api_key, ) diff --git a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/ollama/adapter.py b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/ollama/adapter.py index 9c3d185aa..aa24a7911 100644 --- a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/ollama/adapter.py +++ b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/ollama/adapter.py @@ -42,6 +42,8 @@ class OllamaAPIAdapter(LLMInterface): - aclient """ + default_instructor_mode = "json_mode" + def __init__( self, endpoint: str, api_key: str, model: str, name: str, max_completion_tokens: int ): @@ -51,8 +53,16 @@ class OllamaAPIAdapter(LLMInterface): self.endpoint = endpoint self.max_completion_tokens = max_completion_tokens + from cognee.infrastructure.llm.config import get_llm_config + + config_instructor_mode = get_llm_config().llm_instructor_mode + instructor_mode = ( + config_instructor_mode if config_instructor_mode else self.default_instructor_mode + ) + self.aclient = instructor.from_openai( - OpenAI(base_url=self.endpoint, api_key=self.api_key), mode=instructor.Mode.JSON + OpenAI(base_url=self.endpoint, api_key=self.api_key), + mode=instructor.Mode(instructor_mode), ) @retry( diff --git a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/openai/adapter.py b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/openai/adapter.py index 305b426b8..69367602d 100644 --- a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/openai/adapter.py +++ b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/openai/adapter.py @@ -56,6 +56,7 @@ class OpenAIAdapter(LLMInterface): model: str api_key: str api_version: str + default_instructor_mode = "json_schema_mode" MAX_RETRIES = 5 @@ -74,14 +75,20 @@ class OpenAIAdapter(LLMInterface): fallback_api_key: str = None, fallback_endpoint: str = None, ): + from cognee.infrastructure.llm.config import get_llm_config + + config_instructor_mode = get_llm_config().llm_instructor_mode + instructor_mode = ( + config_instructor_mode if config_instructor_mode else self.default_instructor_mode + ) # TODO: With 
gpt5 series models OpenAI expects JSON_SCHEMA as a mode for structured outputs. # Make sure all new gpt models will work with this mode as well. if "gpt-5" in model: self.aclient = instructor.from_litellm( - litellm.acompletion, mode=instructor.Mode.JSON_SCHEMA + litellm.acompletion, mode=instructor.Mode(instructor_mode) ) self.client = instructor.from_litellm( - litellm.completion, mode=instructor.Mode.JSON_SCHEMA + litellm.completion, mode=instructor.Mode(instructor_mode) ) else: self.aclient = instructor.from_litellm(litellm.acompletion) From 661c194f97df5053f70a52d1638c77c23e0d50e3 Mon Sep 17 00:00:00 2001 From: EricXiao Date: Fri, 14 Nov 2025 15:21:47 +0800 Subject: [PATCH 10/24] fix: Resolve issue with csv suffix classification Signed-off-by: EricXiao --- cognee/infrastructure/files/utils/guess_file_type.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/cognee/infrastructure/files/utils/guess_file_type.py b/cognee/infrastructure/files/utils/guess_file_type.py index 78b20c93d..4bc96fe80 100644 --- a/cognee/infrastructure/files/utils/guess_file_type.py +++ b/cognee/infrastructure/files/utils/guess_file_type.py @@ -55,6 +55,10 @@ def guess_file_type(file: BinaryIO, name: Optional[str] = None) -> filetype.Type file_type = Type("text/plain", "txt") return file_type + if ext in [".csv"]: + file_type = Type("text/csv", "csv") + return file_type + file_type = filetype.guess(file) # If file type could not be determined consider it a plain text file as they don't have magic number encoding From 205f5a9e0c6d0fb72cc94fa20ce2ba814ebef0d5 Mon Sep 17 00:00:00 2001 From: Andrej Milicevic Date: Fri, 14 Nov 2025 11:05:39 +0100 Subject: [PATCH 11/24] fix: Fix based on PR comments --- .../litellm_instructor/llm/anthropic/adapter.py | 9 +++------ .../litellm_instructor/llm/gemini/adapter.py | 10 +++------- .../llm/generic_llm_api/adapter.py | 10 +++------- .../litellm_instructor/llm/get_llm_client.py | 9 ++++++++- .../litellm_instructor/llm/mistral/adapter.py | 16 ++++++++++------ .../litellm_instructor/llm/ollama/adapter.py | 17 +++++++++-------- .../litellm_instructor/llm/openai/adapter.py | 12 ++++-------- 7 files changed, 40 insertions(+), 43 deletions(-) diff --git a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/anthropic/adapter.py b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/anthropic/adapter.py index 6fb78718e..dbf0dfbea 100644 --- a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/anthropic/adapter.py +++ b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/anthropic/adapter.py @@ -30,17 +30,14 @@ class AnthropicAdapter(LLMInterface): model: str default_instructor_mode = "anthropic_tools" - def __init__(self, max_completion_tokens: int, model: str = None): + def __init__(self, max_completion_tokens: int, model: str = None, instructor_mode: str = None): import anthropic - config_instructor_mode = get_llm_config().llm_instructor_mode - instructor_mode = ( - config_instructor_mode if config_instructor_mode else self.default_instructor_mode - ) + self.instructor_mode = instructor_mode if instructor_mode else self.default_instructor_mode self.aclient = instructor.patch( create=anthropic.AsyncAnthropic(api_key=get_llm_config().llm_api_key).messages.create, - mode=instructor.Mode(instructor_mode), + mode=instructor.Mode(self.instructor_mode), ) self.model = model diff --git a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/gemini/adapter.py 
b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/gemini/adapter.py index 68dddc7b7..226f291d7 100644 --- a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/gemini/adapter.py +++ b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/gemini/adapter.py @@ -50,6 +50,7 @@ class GeminiAdapter(LLMInterface): model: str, api_version: str, max_completion_tokens: int, + instructor_mode: str = None, fallback_model: str = None, fallback_api_key: str = None, fallback_endpoint: str = None, @@ -64,15 +65,10 @@ class GeminiAdapter(LLMInterface): self.fallback_api_key = fallback_api_key self.fallback_endpoint = fallback_endpoint - from cognee.infrastructure.llm.config import get_llm_config - - config_instructor_mode = get_llm_config().llm_instructor_mode - instructor_mode = ( - config_instructor_mode if config_instructor_mode else self.default_instructor_mode - ) + self.instructor_mode = instructor_mode if instructor_mode else self.default_instructor_mode self.aclient = instructor.from_litellm( - litellm.acompletion, mode=instructor.Mode(instructor_mode) + litellm.acompletion, mode=instructor.Mode(self.instructor_mode) ) @retry( diff --git a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/generic_llm_api/adapter.py b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/generic_llm_api/adapter.py index ea32dced1..9d7f25fc5 100644 --- a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/generic_llm_api/adapter.py +++ b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/generic_llm_api/adapter.py @@ -50,6 +50,7 @@ class GenericAPIAdapter(LLMInterface): model: str, name: str, max_completion_tokens: int, + instructor_mode: str = None, fallback_model: str = None, fallback_api_key: str = None, fallback_endpoint: str = None, @@ -64,15 +65,10 @@ class GenericAPIAdapter(LLMInterface): self.fallback_api_key = fallback_api_key self.fallback_endpoint = fallback_endpoint - from cognee.infrastructure.llm.config import get_llm_config - - config_instructor_mode = get_llm_config().llm_instructor_mode - instructor_mode = ( - config_instructor_mode if config_instructor_mode else self.default_instructor_mode - ) + self.instructor_mode = instructor_mode if instructor_mode else self.default_instructor_mode self.aclient = instructor.from_litellm( - litellm.acompletion, mode=instructor.Mode(instructor_mode) + litellm.acompletion, mode=instructor.Mode(self.instructor_mode) ) @retry( diff --git a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/get_llm_client.py b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/get_llm_client.py index c7dcecc56..537eda1b2 100644 --- a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/get_llm_client.py +++ b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/get_llm_client.py @@ -81,6 +81,7 @@ def get_llm_client(raise_api_key_error: bool = True): model=llm_config.llm_model, transcription_model=llm_config.transcription_model, max_completion_tokens=max_completion_tokens, + instructor_mode=llm_config.llm_instructor_mode, streaming=llm_config.llm_streaming, fallback_api_key=llm_config.fallback_api_key, fallback_endpoint=llm_config.fallback_endpoint, @@ -101,6 +102,7 @@ def get_llm_client(raise_api_key_error: bool = True): llm_config.llm_model, "Ollama", max_completion_tokens=max_completion_tokens, + 
instructor_mode=llm_config.llm_instructor_mode, ) elif provider == LLMProvider.ANTHROPIC: @@ -109,7 +111,9 @@ def get_llm_client(raise_api_key_error: bool = True): ) return AnthropicAdapter( - max_completion_tokens=max_completion_tokens, model=llm_config.llm_model + max_completion_tokens=max_completion_tokens, + model=llm_config.llm_model, + instructor_mode=llm_config.llm_instructor_mode, ) elif provider == LLMProvider.CUSTOM: @@ -126,6 +130,7 @@ def get_llm_client(raise_api_key_error: bool = True): llm_config.llm_model, "Custom", max_completion_tokens=max_completion_tokens, + instructor_mode=llm_config.llm_instructor_mode, fallback_api_key=llm_config.fallback_api_key, fallback_endpoint=llm_config.fallback_endpoint, fallback_model=llm_config.fallback_model, @@ -145,6 +150,7 @@ def get_llm_client(raise_api_key_error: bool = True): max_completion_tokens=max_completion_tokens, endpoint=llm_config.llm_endpoint, api_version=llm_config.llm_api_version, + instructor_mode=llm_config.llm_instructor_mode, ) elif provider == LLMProvider.MISTRAL: @@ -160,6 +166,7 @@ def get_llm_client(raise_api_key_error: bool = True): model=llm_config.llm_model, max_completion_tokens=max_completion_tokens, endpoint=llm_config.llm_endpoint, + instructor_mode=llm_config.llm_instructor_mode, ) elif provider == LLMProvider.MISTRAL: diff --git a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/mistral/adapter.py b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/mistral/adapter.py index bed88ce3c..355cdae0b 100644 --- a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/mistral/adapter.py +++ b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/mistral/adapter.py @@ -39,20 +39,24 @@ class MistralAdapter(LLMInterface): max_completion_tokens: int default_instructor_mode = "mistral_tools" - def __init__(self, api_key: str, model: str, max_completion_tokens: int, endpoint: str = None): + def __init__( + self, + api_key: str, + model: str, + max_completion_tokens: int, + endpoint: str = None, + instructor_mode: str = None, + ): from mistralai import Mistral self.model = model self.max_completion_tokens = max_completion_tokens - config_instructor_mode = get_llm_config().llm_instructor_mode - instructor_mode = ( - config_instructor_mode if config_instructor_mode else self.default_instructor_mode - ) + self.instructor_mode = instructor_mode if instructor_mode else self.default_instructor_mode self.aclient = instructor.from_litellm( litellm.acompletion, - mode=instructor.Mode(instructor_mode), + mode=instructor.Mode(self.instructor_mode), api_key=get_llm_config().llm_api_key, ) diff --git a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/ollama/adapter.py b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/ollama/adapter.py index aa24a7911..aabd19867 100644 --- a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/ollama/adapter.py +++ b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/ollama/adapter.py @@ -45,7 +45,13 @@ class OllamaAPIAdapter(LLMInterface): default_instructor_mode = "json_mode" def __init__( - self, endpoint: str, api_key: str, model: str, name: str, max_completion_tokens: int + self, + endpoint: str, + api_key: str, + model: str, + name: str, + max_completion_tokens: int, + instructor_mode: str = None, ): self.name = name self.model = model @@ -53,16 +59,11 @@ class OllamaAPIAdapter(LLMInterface): 
self.endpoint = endpoint self.max_completion_tokens = max_completion_tokens - from cognee.infrastructure.llm.config import get_llm_config - - config_instructor_mode = get_llm_config().llm_instructor_mode - instructor_mode = ( - config_instructor_mode if config_instructor_mode else self.default_instructor_mode - ) + self.instructor_mode = instructor_mode if instructor_mode else self.default_instructor_mode self.aclient = instructor.from_openai( OpenAI(base_url=self.endpoint, api_key=self.api_key), - mode=instructor.Mode(instructor_mode), + mode=instructor.Mode(self.instructor_mode), ) @retry( diff --git a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/openai/adapter.py b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/openai/adapter.py index 69367602d..778c8eec7 100644 --- a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/openai/adapter.py +++ b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/openai/adapter.py @@ -70,25 +70,21 @@ class OpenAIAdapter(LLMInterface): model: str, transcription_model: str, max_completion_tokens: int, + instructor_mode: str = None, streaming: bool = False, fallback_model: str = None, fallback_api_key: str = None, fallback_endpoint: str = None, ): - from cognee.infrastructure.llm.config import get_llm_config - - config_instructor_mode = get_llm_config().llm_instructor_mode - instructor_mode = ( - config_instructor_mode if config_instructor_mode else self.default_instructor_mode - ) + self.instructor_mode = instructor_mode if instructor_mode else self.default_instructor_mode # TODO: With gpt5 series models OpenAI expects JSON_SCHEMA as a mode for structured outputs. # Make sure all new gpt models will work with this mode as well. 
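        # Editor's note: a minimal sketch of the mode-resolution order these
        # adapters now share, assuming only names visible in this patch series.
        # The explicit `instructor_mode` argument, which get_llm_client forwards
        # from llm_config.llm_instructor_mode, wins; otherwise the adapter's own
        # default applies:
        #
        #     self.instructor_mode = instructor_mode if instructor_mode else self.default_instructor_mode
        #     mode = instructor.Mode(self.instructor_mode)
        #
        # instructor.Mode is an enum looked up by value ("json_mode",
        # "anthropic_tools", "mistral_tools", "json_schema_mode", ...), so a
        # misspelled mode string fails fast here with a ValueError instead of
        # surfacing later at request time.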
if "gpt-5" in model: self.aclient = instructor.from_litellm( - litellm.acompletion, mode=instructor.Mode(instructor_mode) + litellm.acompletion, mode=instructor.Mode(self.instructor_mode) ) self.client = instructor.from_litellm( - litellm.completion, mode=instructor.Mode(instructor_mode) + litellm.completion, mode=instructor.Mode(self.instructor_mode) ) else: self.aclient = instructor.from_litellm(litellm.acompletion) From 844b8d635a7646750dc63dd4be13de07f2996940 Mon Sep 17 00:00:00 2001 From: Fahad Shoaib Date: Fri, 14 Nov 2025 22:13:00 +0500 Subject: [PATCH 12/24] feat: enhance ontology handling to support multiple uploads and retrievals --- .../v1/cognify/routers/get_cognify_router.py | 41 +++---- cognee/api/v1/ontologies/ontologies.py | 108 +++++++++++++++--- .../ontologies/routers/get_ontology_router.py | 51 ++++++--- .../rdf_xml/RDFLibOntologyResolver.py | 93 +++++++++------ 4 files changed, 202 insertions(+), 91 deletions(-) diff --git a/cognee/api/v1/cognify/routers/get_cognify_router.py b/cognee/api/v1/cognify/routers/get_cognify_router.py index 252ffe7bf..4f1497e3c 100644 --- a/cognee/api/v1/cognify/routers/get_cognify_router.py +++ b/cognee/api/v1/cognify/routers/get_cognify_router.py @@ -41,8 +41,8 @@ class CognifyPayloadDTO(InDTO): custom_prompt: Optional[str] = Field( default="", description="Custom prompt for entity extraction and graph generation" ) - ontology_key: Optional[str] = Field( - default=None, description="Reference to previously uploaded ontology" + ontology_key: Optional[List[str]] = Field( + default=None, description="Reference to one or more previously uploaded ontologies" ) @@ -71,7 +71,7 @@ def get_cognify_router() -> APIRouter: - **dataset_ids** (Optional[List[UUID]]): List of existing dataset UUIDs to process. UUIDs allow processing of datasets not owned by the user (if permitted). - **run_in_background** (Optional[bool]): Whether to execute processing asynchronously. Defaults to False (blocking). - **custom_prompt** (Optional[str]): Custom prompt for entity extraction and graph generation. If provided, this prompt will be used instead of the default prompts for knowledge graph extraction. - - **ontology_key** (Optional[str]): Reference to a previously uploaded ontology file to use for knowledge graph construction. + - **ontology_key** (Optional[List[str]]): Reference to one or more previously uploaded ontology files to use for knowledge graph construction. ## Response - **Blocking execution**: Complete pipeline run information with entity counts, processing duration, and success/failure status @@ -87,7 +87,7 @@ def get_cognify_router() -> APIRouter: "datasets": ["research_papers", "documentation"], "run_in_background": false, "custom_prompt": "Extract entities focusing on technical concepts and their relationships. 
Identify key technologies, methodologies, and their interconnections.", - "ontology_key": "medical_ontology_v1" + "ontology_key": ["medical_ontology_v1"] } ``` @@ -121,29 +121,22 @@ def get_cognify_router() -> APIRouter: if payload.ontology_key: ontology_service = OntologyService() - try: - ontology_content = ontology_service.get_ontology_content( - payload.ontology_key, user - ) + ontology_contents = ontology_service.get_ontology_contents( + payload.ontology_key, user + ) - from cognee.modules.ontology.ontology_config import Config - from cognee.modules.ontology.rdf_xml.RDFLibOntologyResolver import ( - RDFLibOntologyResolver, - ) - from io import StringIO + from cognee.modules.ontology.ontology_config import Config + from cognee.modules.ontology.rdf_xml.RDFLibOntologyResolver import ( + RDFLibOntologyResolver, + ) + from io import StringIO - ontology_stream = StringIO(ontology_content) - config_to_use: Config = { - "ontology_config": { - "ontology_resolver": RDFLibOntologyResolver( - ontology_file=ontology_stream - ) - } + ontology_streams = [StringIO(content) for content in ontology_contents] + config_to_use: Config = { + "ontology_config": { + "ontology_resolver": RDFLibOntologyResolver(ontology_file=ontology_streams) } - except ValueError as e: - return JSONResponse( - status_code=400, content={"error": f"Ontology error: {str(e)}"} - ) + } cognify_run = await cognee_cognify( datasets, diff --git a/cognee/api/v1/ontologies/ontologies.py b/cognee/api/v1/ontologies/ontologies.py index 6bfb7658e..130b4a862 100644 --- a/cognee/api/v1/ontologies/ontologies.py +++ b/cognee/api/v1/ontologies/ontologies.py @@ -3,7 +3,7 @@ import json import tempfile from pathlib import Path from datetime import datetime, timezone -from typing import Optional +from typing import Optional, List from dataclasses import dataclass @@ -47,28 +47,23 @@ class OntologyService: async def upload_ontology( self, ontology_key: str, file, user, description: Optional[str] = None ) -> OntologyMetadata: - # Validate file format if not file.filename.lower().endswith(".owl"): raise ValueError("File must be in .owl format") user_dir = self._get_user_dir(str(user.id)) metadata = self._load_metadata(user_dir) - # Check for duplicate key if ontology_key in metadata: raise ValueError(f"Ontology key '{ontology_key}' already exists") - # Read file content content = await file.read() - if len(content) > 10 * 1024 * 1024: # 10MB limit + if len(content) > 10 * 1024 * 1024: raise ValueError("File size exceeds 10MB limit") - # Save file file_path = user_dir / f"{ontology_key}.owl" with open(file_path, "wb") as f: f.write(content) - # Update metadata ontology_metadata = { "filename": file.filename, "size_bytes": len(content), @@ -86,19 +81,102 @@ class OntologyService: description=description, ) - def get_ontology_content(self, ontology_key: str, user) -> str: + async def upload_ontologies( + self, ontology_key: List[str], files: List, user, descriptions: Optional[List[str]] = None + ) -> List[OntologyMetadata]: + """ + Upload ontology files with their respective keys. 
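+
+        A hedged usage sketch (``service`` is an OntologyService instance;
+        ``file1`` and ``file2`` stand in for FastAPI UploadFile objects and are
+        illustrative names, not part of this patch):
+
+            results = await service.upload_ontologies(
+                ["vehicles", "manufacturers"],
+                [file1, file2],
+                user,
+                descriptions=["Base vehicles", "Car manufacturers"],
+            )
+
+        Length and duplicate-key checks run before any file is written, and the
+        metadata index is saved only after every file passes validation.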
+ + Args: + ontology_key: List of unique keys for each ontology + files: List of UploadFile objects (same length as keys) + user: Authenticated user + descriptions: Optional list of descriptions for each file + + Returns: + List of OntologyMetadata objects for uploaded files + + Raises: + ValueError: If keys duplicate, file format invalid, or array lengths don't match + """ + if len(ontology_key) != len(files): + raise ValueError("Number of keys must match number of files") + + if len(set(ontology_key)) != len(ontology_key): + raise ValueError("Duplicate ontology keys not allowed") + + if descriptions and len(descriptions) != len(files): + raise ValueError("Number of descriptions must match number of files") + + results = [] user_dir = self._get_user_dir(str(user.id)) metadata = self._load_metadata(user_dir) - if ontology_key not in metadata: - raise ValueError(f"Ontology key '{ontology_key}' not found") + for i, (key, file) in enumerate(zip(ontology_key, files)): + if key in metadata: + raise ValueError(f"Ontology key '{key}' already exists") - file_path = user_dir / f"{ontology_key}.owl" - if not file_path.exists(): - raise ValueError(f"Ontology file for key '{ontology_key}' not found") + if not file.filename.lower().endswith(".owl"): + raise ValueError(f"File '{file.filename}' must be in .owl format") - with open(file_path, "r", encoding="utf-8") as f: - return f.read() + content = await file.read() + if len(content) > 10 * 1024 * 1024: + raise ValueError(f"File '{file.filename}' exceeds 10MB limit") + + file_path = user_dir / f"{key}.owl" + with open(file_path, "wb") as f: + f.write(content) + + ontology_metadata = { + "filename": file.filename, + "size_bytes": len(content), + "uploaded_at": datetime.now(timezone.utc).isoformat(), + "description": descriptions[i] if descriptions else None, + } + metadata[key] = ontology_metadata + + results.append( + OntologyMetadata( + ontology_key=key, + filename=file.filename, + size_bytes=len(content), + uploaded_at=ontology_metadata["uploaded_at"], + description=descriptions[i] if descriptions else None, + ) + ) + + self._save_metadata(user_dir, metadata) + return results + + def get_ontology_contents(self, ontology_key: List[str], user) -> List[str]: + """ + Retrieve ontology content for one or more keys. 
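+
+        A usage sketch (key names are illustrative and assume both ontologies
+        were uploaded earlier):
+
+            contents = service.get_ontology_contents(["vehicles", "manufacturers"], user)
+            # contents holds the raw OWL/XML strings in the same order as the
+            # requested keys; a single missing key raises ValueError for the call.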
+ + Args: + ontology_key: List of ontology keys to retrieve (can contain single item) + user: Authenticated user + + Returns: + List of ontology content strings + + Raises: + ValueError: If any ontology key not found + """ + user_dir = self._get_user_dir(str(user.id)) + metadata = self._load_metadata(user_dir) + + contents = [] + for key in ontology_key: + if key not in metadata: + raise ValueError(f"Ontology key '{key}' not found") + + file_path = user_dir / f"{key}.owl" + if not file_path.exists(): + raise ValueError(f"Ontology file for key '{key}' not found") + + with open(file_path, "r", encoding="utf-8") as f: + contents.append(f.read()) + return contents def list_ontologies(self, user) -> dict: user_dir = self._get_user_dir(str(user.id)) diff --git a/cognee/api/v1/ontologies/routers/get_ontology_router.py b/cognee/api/v1/ontologies/routers/get_ontology_router.py index f5c51ba21..ee31c683f 100644 --- a/cognee/api/v1/ontologies/routers/get_ontology_router.py +++ b/cognee/api/v1/ontologies/routers/get_ontology_router.py @@ -1,6 +1,6 @@ from fastapi import APIRouter, File, Form, UploadFile, Depends, HTTPException from fastapi.responses import JSONResponse -from typing import Optional +from typing import Optional, List from cognee.modules.users.models import User from cognee.modules.users.methods import get_authenticated_user @@ -16,23 +16,27 @@ def get_ontology_router() -> APIRouter: @router.post("", response_model=dict) async def upload_ontology( ontology_key: str = Form(...), - ontology_file: UploadFile = File(...), - description: Optional[str] = Form(None), + ontology_file: List[UploadFile] = File(...), + descriptions: Optional[str] = Form(None), user: User = Depends(get_authenticated_user), ): """ - Upload an ontology file with a named key for later use in cognify operations. + Upload ontology files with their respective keys for later use in cognify operations. + + Supports both single and multiple file uploads: + - Single file: ontology_key=["key"], ontology_file=[file] + - Multiple files: ontology_key=["key1", "key2"], ontology_file=[file1, file2] ## Request Parameters - - **ontology_key** (str): User-defined identifier for the ontology - - **ontology_file** (UploadFile): OWL format ontology file - - **description** (Optional[str]): Optional description of the ontology + - **ontology_key** (str): JSON array string of user-defined identifiers for the ontologies + - **ontology_file** (List[UploadFile]): OWL format ontology files + - **descriptions** (Optional[str]): JSON array string of optional descriptions ## Response - Returns metadata about the uploaded ontology including key, filename, size, and upload timestamp. + Returns metadata about uploaded ontologies including keys, filenames, sizes, and upload timestamps. 
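
    ## Example Request

    A hedged `requests` sketch; the host, file names, keys, and descriptions are
    illustrative (they mirror this patch's test suite, not fixed values):

    ```python
    import json
    import requests

    files = [
        ("ontology_file", ("vehicles.owl", open("vehicles.owl", "rb"), "application/xml")),
        ("ontology_file", ("manufacturers.owl", open("manufacturers.owl", "rb"), "application/xml")),
    ]
    data = {
        "ontology_key": json.dumps(["vehicles", "manufacturers"]),
        "descriptions": json.dumps(["Base vehicles", "Car manufacturers"]),
    }
    response = requests.post("http://127.0.0.1:8000/api/v1/ontologies", files=files, data=data)
    response.raise_for_status()
    ```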
## Error Codes - - **400 Bad Request**: Invalid file format, duplicate key, file size exceeded + - **400 Bad Request**: Invalid file format, duplicate keys, array length mismatches, file size exceeded - **500 Internal Server Error**: File system or processing errors """ send_telemetry( @@ -45,16 +49,31 @@ def get_ontology_router() -> APIRouter: ) try: - result = await ontology_service.upload_ontology( - ontology_key, ontology_file, user, description + import json + + ontology_keys = json.loads(ontology_key) + description_list = json.loads(descriptions) if descriptions else None + + if not isinstance(ontology_keys, list): + raise ValueError("ontology_key must be a JSON array") + + results = await ontology_service.upload_ontologies( + ontology_keys, ontology_file, user, description_list ) + return { - "ontology_key": result.ontology_key, - "filename": result.filename, - "size_bytes": result.size_bytes, - "uploaded_at": result.uploaded_at, + "uploaded_ontologies": [ + { + "ontology_key": result.ontology_key, + "filename": result.filename, + "size_bytes": result.size_bytes, + "uploaded_at": result.uploaded_at, + "description": result.description, + } + for result in results + ] } - except ValueError as e: + except (json.JSONDecodeError, ValueError) as e: return JSONResponse(status_code=400, content={"error": str(e)}) except Exception as e: return JSONResponse(status_code=500, content={"error": str(e)}) diff --git a/cognee/modules/ontology/rdf_xml/RDFLibOntologyResolver.py b/cognee/modules/ontology/rdf_xml/RDFLibOntologyResolver.py index 4acc8861b..34d7a946a 100644 --- a/cognee/modules/ontology/rdf_xml/RDFLibOntologyResolver.py +++ b/cognee/modules/ontology/rdf_xml/RDFLibOntologyResolver.py @@ -26,7 +26,7 @@ class RDFLibOntologyResolver(BaseOntologyResolver): def __init__( self, - ontology_file: Optional[Union[str, List[str], IO]] = None, + ontology_file: Optional[Union[str, List[str], IO, List[IO]]] = None, matching_strategy: Optional[MatchingStrategy] = None, ) -> None: super().__init__(matching_strategy) @@ -34,47 +34,68 @@ class RDFLibOntologyResolver(BaseOntologyResolver): try: self.graph = None if ontology_file is not None: + files_to_load = [] + file_objects = [] + if hasattr(ontology_file, "read"): - self.graph = Graph() - content = ontology_file.read() - self.graph.parse(data=content, format="xml") - logger.info("Ontology loaded successfully from file object") - else: - files_to_load = [] - if isinstance(ontology_file, str): - files_to_load = [ontology_file] - elif isinstance(ontology_file, list): + file_objects = [ontology_file] + elif isinstance(ontology_file, str): + files_to_load = [ontology_file] + elif isinstance(ontology_file, list): + if all(hasattr(item, "read") for item in ontology_file): + file_objects = ontology_file + else: files_to_load = ontology_file - else: - raise ValueError( - f"ontology_file must be a string, list of strings, file-like object, or None. Got: {type(ontology_file)}" - ) + else: + raise ValueError( + f"ontology_file must be a string, list of strings, file-like object, list of file-like objects, or None. Got: {type(ontology_file)}" + ) - if files_to_load: - self.graph = Graph() - loaded_files = [] - for file_path in files_to_load: - if os.path.exists(file_path): - self.graph.parse(file_path) - loaded_files.append(file_path) - logger.info("Ontology loaded successfully from file: %s", file_path) - else: - logger.warning( - "Ontology file '%s' not found. 
Skipping this file.", - file_path, - ) + if file_objects: + self.graph = Graph() + loaded_objects = [] + for file_obj in file_objects: + try: + content = file_obj.read() + self.graph.parse(data=content, format="xml") + loaded_objects.append(file_obj) + logger.info("Ontology loaded successfully from file object") + except Exception as e: + logger.warning("Failed to parse ontology file object: %s", str(e)) - if not loaded_files: - logger.info( - "No valid ontology files found. No owl ontology will be attached to the graph." - ) - self.graph = None - else: - logger.info("Total ontology files loaded: %d", len(loaded_files)) - else: + if not loaded_objects: logger.info( - "No ontology file provided. No owl ontology will be attached to the graph." + "No valid ontology file objects found. No owl ontology will be attached to the graph." ) + self.graph = None + else: + logger.info("Total ontology file objects loaded: %d", len(loaded_objects)) + + elif files_to_load: + self.graph = Graph() + loaded_files = [] + for file_path in files_to_load: + if os.path.exists(file_path): + self.graph.parse(file_path) + loaded_files.append(file_path) + logger.info("Ontology loaded successfully from file: %s", file_path) + else: + logger.warning( + "Ontology file '%s' not found. Skipping this file.", + file_path, + ) + + if not loaded_files: + logger.info( + "No valid ontology files found. No owl ontology will be attached to the graph." + ) + self.graph = None + else: + logger.info("Total ontology files loaded: %d", len(loaded_files)) + else: + logger.info( + "No ontology file provided. No owl ontology will be attached to the graph." + ) else: logger.info( "No ontology file provided. No owl ontology will be attached to the graph." From 01f1c099cc972e2222f1174c515e1baf87fbb9d6 Mon Sep 17 00:00:00 2001 From: Fahad Shoaib Date: Fri, 14 Nov 2025 22:20:54 +0500 Subject: [PATCH 13/24] test: enhance server start test with ontology upload verification - Extend test_cognee_server_start to upload ontology and verify integration - Move test_ontology_endpoint from tests/ to tests/unit/api/ --- cognee/tests/test_cognee_server_start.py | 45 ++++++++++++++++++- .../{ => unit/api}/test_ontology_endpoint.py | 0 2 files changed, 44 insertions(+), 1 deletion(-) rename cognee/tests/{ => unit/api}/test_ontology_endpoint.py (100%) diff --git a/cognee/tests/test_cognee_server_start.py b/cognee/tests/test_cognee_server_start.py index ab68a8ef1..d6aa55a98 100644 --- a/cognee/tests/test_cognee_server_start.py +++ b/cognee/tests/test_cognee_server_start.py @@ -7,6 +7,7 @@ import requests from pathlib import Path import sys import uuid +import json class TestCogneeServerStart(unittest.TestCase): @@ -90,12 +91,31 @@ class TestCogneeServerStart(unittest.TestCase): ) } - payload = {"datasets": [dataset_name]} + ontology_key = f"test_ontology_{uuid.uuid4().hex[:8]}" + payload = {"datasets": [dataset_name], "ontology_key": [ontology_key]} add_response = requests.post(url, headers=headers, data=form_data, files=file, timeout=50) if add_response.status_code not in [200, 201]: add_response.raise_for_status() + ontology_content = b""" + + + + + """ + + ontology_response = requests.post( + "http://127.0.0.1:8000/api/v1/ontologies", + files=[("ontology_file", ("test.owl", ontology_content, "application/xml"))], + data={ + "ontology_key": json.dumps([ontology_key]), + "description": json.dumps(["Test ontology"]), + }, + ) + self.assertEqual(ontology_response.status_code, 200) + # Cognify request url = "http://127.0.0.1:8000/api/v1/cognify" headers = { @@ 
-107,6 +127,29 @@ class TestCogneeServerStart(unittest.TestCase): if cognify_response.status_code not in [200, 201]: cognify_response.raise_for_status() + datasets_response = requests.get("http://127.0.0.1:8000/api/v1/datasets", headers=headers) + + datasets = datasets_response.json() + dataset_id = None + for dataset in datasets: + if dataset["name"] == dataset_name: + dataset_id = dataset["id"] + break + + graph_response = requests.get( + f"http://127.0.0.1:8000/api/v1/datasets/{dataset_id}/graph", headers=headers + ) + self.assertEqual(graph_response.status_code, 200) + + graph_data = graph_response.json() + ontology_nodes = [ + node for node in graph_data.get("nodes") if node.get("properties").get("ontology_valid") + ] + + self.assertGreater( + len(ontology_nodes), 0, "No ontology nodes found - ontology was not integrated" + ) + # TODO: Add test to verify cognify pipeline is complete before testing search # Search request diff --git a/cognee/tests/test_ontology_endpoint.py b/cognee/tests/unit/api/test_ontology_endpoint.py similarity index 100% rename from cognee/tests/test_ontology_endpoint.py rename to cognee/tests/unit/api/test_ontology_endpoint.py From 1ded09d0f995fa57c9eaa3feafbe64089525a92f Mon Sep 17 00:00:00 2001 From: Fahad Shoaib Date: Sat, 15 Nov 2025 00:06:55 +0500 Subject: [PATCH 14/24] fix: fixed test ontology file content. Added tests to support multiple files and improved validation. --- cognee/tests/test_cognee_server_start.py | 13 +- .../tests/unit/api/test_ontology_endpoint.py | 189 +++++++++++++++++- 2 files changed, 185 insertions(+), 17 deletions(-) diff --git a/cognee/tests/test_cognee_server_start.py b/cognee/tests/test_cognee_server_start.py index d6aa55a98..b266fc7bf 100644 --- a/cognee/tests/test_cognee_server_start.py +++ b/cognee/tests/test_cognee_server_start.py @@ -98,13 +98,12 @@ class TestCogneeServerStart(unittest.TestCase): if add_response.status_code not in [200, 201]: add_response.raise_for_status() - ontology_content = b""" - - - - - """ + ontology_content = b""" + + + + + """ ontology_response = requests.post( "http://127.0.0.1:8000/api/v1/ontologies", diff --git a/cognee/tests/unit/api/test_ontology_endpoint.py b/cognee/tests/unit/api/test_ontology_endpoint.py index b5cedfafe..c04959998 100644 --- a/cognee/tests/unit/api/test_ontology_endpoint.py +++ b/cognee/tests/unit/api/test_ontology_endpoint.py @@ -32,6 +32,8 @@ def mock_default_user(): @patch.object(gau_mod, "get_default_user", new_callable=AsyncMock) def test_upload_ontology_success(mock_get_default_user, client, mock_default_user): """Test successful ontology upload""" + import json + mock_get_default_user.return_value = mock_default_user ontology_content = ( b"" @@ -40,14 +42,14 @@ def test_upload_ontology_success(mock_get_default_user, client, mock_default_use response = client.post( "/api/v1/ontologies", - files={"ontology_file": ("test.owl", ontology_content)}, - data={"ontology_key": unique_key, "description": "Test"}, + files=[("ontology_file", ("test.owl", ontology_content, "application/xml"))], + data={"ontology_key": json.dumps([unique_key]), "description": json.dumps(["Test"])}, ) assert response.status_code == 200 data = response.json() - assert data["ontology_key"] == unique_key - assert "uploaded_at" in data + assert data["uploaded_ontologies"][0]["ontology_key"] == unique_key + assert "uploaded_at" in data["uploaded_ontologies"][0] @patch.object(gau_mod, "get_default_user", new_callable=AsyncMock) @@ -66,30 +68,197 @@ def 
test_upload_ontology_invalid_file(mock_get_default_user, client, mock_defaul @patch.object(gau_mod, "get_default_user", new_callable=AsyncMock) def test_upload_ontology_missing_data(mock_get_default_user, client, mock_default_user): """Test 400 response for missing file or key""" + import json + mock_get_default_user.return_value = mock_default_user # Missing file - response = client.post("/api/v1/ontologies", data={"ontology_key": "test"}) + response = client.post("/api/v1/ontologies", data={"ontology_key": json.dumps(["test"])}) assert response.status_code == 400 # Missing key - response = client.post("/api/v1/ontologies", files={"ontology_file": ("test.owl", b"xml")}) + response = client.post( + "/api/v1/ontologies", files=[("ontology_file", ("test.owl", b"xml", "application/xml"))] + ) assert response.status_code == 400 @patch.object(gau_mod, "get_default_user", new_callable=AsyncMock) def test_upload_ontology_unauthorized(mock_get_default_user, client, mock_default_user): """Test behavior when default user is provided (no explicit authentication)""" + import json + unique_key = f"test_ontology_{uuid.uuid4().hex[:8]}" mock_get_default_user.return_value = mock_default_user response = client.post( "/api/v1/ontologies", - files={"ontology_file": ("test.owl", b"")}, - data={"ontology_key": unique_key}, + files=[("ontology_file", ("test.owl", b"", "application/xml"))], + data={"ontology_key": json.dumps([unique_key])}, ) # The current system provides a default user when no explicit authentication is given # This test verifies the system works with conditional authentication assert response.status_code == 200 data = response.json() - assert data["ontology_key"] == unique_key - assert "uploaded_at" in data + assert data["uploaded_ontologies"][0]["ontology_key"] == unique_key + assert "uploaded_at" in data["uploaded_ontologies"][0] + + +@patch.object(gau_mod, "get_default_user", new_callable=AsyncMock) +async def test_upload_multiple_ontologies(mock_get_default_user, client, mock_default_user): + """Test uploading multiple ontology files in single request""" + import io + + # Create mock files + file1_content = b"" + file2_content = b"" + + files = [ + ("ontology_file", ("vehicles.owl", io.BytesIO(file1_content), "application/xml")), + ("ontology_file", ("manufacturers.owl", io.BytesIO(file2_content), "application/xml")), + ] + data = { + "ontology_key": '["vehicles", "manufacturers"]', + "descriptions": '["Base vehicles", "Car manufacturers"]', + } + + response = client.post("/api/v1/ontologies", files=files, data=data) + + assert response.status_code == 200 + result = response.json() + assert "uploaded_ontologies" in result + assert len(result["uploaded_ontologies"]) == 2 + assert result["uploaded_ontologies"][0]["ontology_key"] == "vehicles" + assert result["uploaded_ontologies"][1]["ontology_key"] == "manufacturers" + + +@patch.object(gau_mod, "get_default_user", new_callable=AsyncMock) +async def test_upload_endpoint_accepts_arrays(mock_get_default_user, client, mock_default_user): + """Test that upload endpoint accepts array parameters""" + import io + import json + + file_content = b"" + + files = [("ontology_file", ("single.owl", io.BytesIO(file_content), "application/xml"))] + data = { + "ontology_key": json.dumps(["single_key"]), + "descriptions": json.dumps(["Single ontology"]), + } + + response = client.post("/api/v1/ontologies", files=files, data=data) + + assert response.status_code == 200 + result = response.json() + assert result["uploaded_ontologies"][0]["ontology_key"] == 
"single_key" + + +@patch.object(gau_mod, "get_default_user", new_callable=AsyncMock) +async def test_cognify_with_multiple_ontologies(mock_get_default_user, client, mock_default_user): + """Test cognify endpoint accepts multiple ontology keys""" + payload = { + "datasets": ["test_dataset"], + "ontology_key": ["ontology1", "ontology2"], # Array instead of string + "run_in_background": False, + } + + response = client.post("/api/v1/cognify", json=payload) + + # Should not fail due to ontology_key type + assert response.status_code in [200, 400, 409] # May fail for other reasons, not type + + +@patch.object(gau_mod, "get_default_user", new_callable=AsyncMock) +async def test_complete_multifile_workflow(mock_get_default_user, client, mock_default_user): + """Test complete workflow: upload multiple ontologies → cognify with multiple keys""" + import io + import json + + # Step 1: Upload multiple ontologies + file1_content = b""" + + + """ + + file2_content = b""" + + + """ + + files = [ + ("ontology_file", ("vehicles.owl", io.BytesIO(file1_content), "application/xml")), + ("ontology_file", ("manufacturers.owl", io.BytesIO(file2_content), "application/xml")), + ] + data = { + "ontology_key": json.dumps(["vehicles", "manufacturers"]), + "descriptions": json.dumps(["Vehicle ontology", "Manufacturer ontology"]), + } + + upload_response = client.post("/api/v1/ontologies", files=files, data=data) + assert upload_response.status_code == 200 + + # Step 2: Verify ontologies are listed + list_response = client.get("/api/v1/ontologies") + assert list_response.status_code == 200 + ontologies = list_response.json() + assert "vehicles" in ontologies + assert "manufacturers" in ontologies + + # Step 3: Test cognify with multiple ontologies + cognify_payload = { + "datasets": ["test_dataset"], + "ontology_key": ["vehicles", "manufacturers"], + "run_in_background": False, + } + + cognify_response = client.post("/api/v1/cognify", json=cognify_payload) + # Should not fail due to ontology handling (may fail for dataset reasons) + assert cognify_response.status_code != 400 # Not a validation error + + +@patch.object(gau_mod, "get_default_user", new_callable=AsyncMock) +async def test_multifile_error_handling(mock_get_default_user, client, mock_default_user): + """Test error handling for invalid multifile uploads""" + import io + import json + + # Test mismatched array lengths + file_content = b"" + files = [("ontology_file", ("test.owl", io.BytesIO(file_content), "application/xml"))] + data = { + "ontology_key": json.dumps(["key1", "key2"]), # 2 keys, 1 file + "descriptions": json.dumps(["desc1"]), + } + + response = client.post("/api/v1/ontologies", files=files, data=data) + assert response.status_code == 400 + assert "Number of keys must match number of files" in response.json()["error"] + + # Test duplicate keys + files = [ + ("ontology_file", ("test1.owl", io.BytesIO(file_content), "application/xml")), + ("ontology_file", ("test2.owl", io.BytesIO(file_content), "application/xml")), + ] + data = { + "ontology_key": json.dumps(["duplicate", "duplicate"]), + "descriptions": json.dumps(["desc1", "desc2"]), + } + + response = client.post("/api/v1/ontologies", files=files, data=data) + assert response.status_code == 400 + assert "Duplicate ontology keys not allowed" in response.json()["error"] + + +@patch.object(gau_mod, "get_default_user", new_callable=AsyncMock) +async def test_cognify_missing_ontology_key(mock_get_default_user, client, mock_default_user): + """Test cognify with non-existent ontology key""" + 
payload = { + "datasets": ["test_dataset"], + "ontology_key": ["nonexistent_key"], + "run_in_background": False, + } + + response = client.post("/api/v1/cognify", json=payload) + assert response.status_code == 409 + assert "Ontology key 'nonexistent_key' not found" in response.json()["error"] From 983bfae4fcc9046fd520f0c24733e491679d54cf Mon Sep 17 00:00:00 2001 From: EricXiao Date: Mon, 17 Nov 2025 14:41:55 +0800 Subject: [PATCH 15/24] chore: remove unnecessary csv file type Signed-off-by: EricXiao --- .../files/utils/is_csv_content.py | 181 ------------------ 1 file changed, 181 deletions(-) delete mode 100644 cognee/infrastructure/files/utils/is_csv_content.py diff --git a/cognee/infrastructure/files/utils/is_csv_content.py b/cognee/infrastructure/files/utils/is_csv_content.py deleted file mode 100644 index 07b7ea69b..000000000 --- a/cognee/infrastructure/files/utils/is_csv_content.py +++ /dev/null @@ -1,181 +0,0 @@ -import csv -from collections import Counter - - -def is_csv_content(content): - """ - Heuristically determine whether a bytes-like object is CSV text. - - Strategy (fail-fast and cheap to expensive): - 1) Decode: Try a small ordered list of common encodings with strict errors. - 2) Line sampling: require >= 2 non-empty lines; sample up to 50 lines. - 3) Delimiter detection: - - Prefer csv.Sniffer() with common delimiters. - - Fallback to a lightweight consistency heuristic. - 4) Lightweight parse check: - - Parse a few lines with the delimiter. - - Ensure at least 2 valid rows and relatively stable column counts. - - Returns: - bool: True if the buffer looks like CSV; False otherwise. - """ - try: - encoding_list = [ - "utf-8", - "utf-8-sig", - "utf-32-le", - "utf-32-be", - "utf-16-le", - "utf-16-be", - "gb18030", - "shift_jis", - "cp949", - "cp1252", - "iso-8859-1", - ] - - # Try to decode strictly—if decoding fails for all encodings, it's not text/CSV. - text = None - for enc in encoding_list: - try: - text = content.decode(enc, errors="strict") - break - except UnicodeDecodeError: - continue - if text is None: - return False - - # Reject empty/whitespace-only payloads. - stripped = text.strip() - if not stripped: - return False - - # Split into logical lines and drop empty ones. Require at least two lines. - lines = [ln for ln in text.splitlines() if ln.strip()] - if len(lines) < 2: - return False - - # Take a small sample to keep sniffing cheap and predictable. - sample_lines = lines[:50] - - # Detect delimiter using csv.Sniffer first; if that fails, use our heuristic. - delimiter = _sniff_delimiter(sample_lines) or _heuristic_delimiter(sample_lines) - if not delimiter: - return False - - # Finally, do a lightweight parse sanity check with the chosen delimiter. - return _lightweight_parse_check(sample_lines, delimiter) - except Exception: - return False - - -def _sniff_delimiter(lines): - """ - Try Python's built-in csv.Sniffer on a sample. - - Args: - lines (list[str]): Sample lines (already decoded). - - Returns: - str | None: The detected delimiter if sniffing succeeds; otherwise None. - """ - # Join up to 50 lines to form the sample string Sniffer will inspect. - sample = "\n".join(lines[:50]) - try: - dialect = csv.Sniffer().sniff(sample, delimiters=",\t;|") - return dialect.delimiter - except Exception: - # Sniffer is known to be brittle on small/dirty samples—silently fallback. - return None - - -def _heuristic_delimiter(lines): - """ - Fallback delimiter detection based on count consistency per line. 
- - Heuristic: - - For each candidate delimiter, count occurrences per line. - - Keep only lines with count > 0 (line must contain the delimiter). - - Require at least half of lines to contain the delimiter (min 2). - - Compute the mode (most common count). If the proportion of lines that - exhibit the modal count is >= 80%, accept that delimiter. - - Args: - lines (list[str]): Sample lines. - - Returns: - str | None: Best delimiter if one meets the consistency threshold; else None. - """ - candidates = [",", "\t", ";", "|"] - best = None - best_score = 0.0 - - for d in candidates: - # Count how many times the delimiter appears in each line. - counts = [ln.count(d) for ln in lines] - # Consider only lines that actually contain the delimiter at least once. - nonzero = [c for c in counts if c > 0] - - # Require that more than half of lines (and at least 2) contain the delimiter. - if len(nonzero) < max(2, int(0.5 * len(lines))): - continue - - # Find the modal count and its frequency. - cnt = Counter(nonzero) - pairs = cnt.most_common(1) - if not pairs: - continue - - mode, mode_freq = pairs[0] - # Consistency ratio: lines with the modal count / total lines in the sample. - consistency = mode_freq / len(lines) - # Accept if consistent enough and better than any previous candidate. - if mode >= 1 and consistency >= 0.80 and consistency > best_score: - best = d - best_score = consistency - - return best - - -def _lightweight_parse_check(lines, delimiter): - """ - Parse a few lines with csv.reader and check structural stability. - - Heuristic: - - Parse up to 5 lines with the given delimiter. - - Count column widths per parsed row. - - Require at least 2 non-empty rows. - - Allow at most 1 row whose width deviates by >2 columns from the first row. - - Args: - lines (list[str]): Sample lines (decoded). - delimiter (str): Delimiter chosen by sniffing/heuristics. - - Returns: - bool: True if parsing looks stable; False otherwise. - """ - try: - # csv.reader accepts any iterable of strings; feeding the first 10 lines is fine. - reader = csv.reader(lines[:10], delimiter=delimiter) - widths = [] - valid_rows = 0 - for row in reader: - if not row: - continue - - widths.append(len(row)) - valid_rows += 1 - - # Need at least two meaningful rows to make a judgment. - if valid_rows < 2: - return False - - if widths: - first = widths[0] - # Count rows whose width deviates significantly (>2) from the first row. - unstable = sum(1 for w in widths if abs(w - first) > 2) - # Permit at most 1 unstable row among the parsed sample. 
- return unstable <= 1 - return False - except Exception: - return False From fe55071849c06ecd3fe70602d56bb1f2904538a4 Mon Sep 17 00:00:00 2001 From: hajdul88 <52442977+hajdul88@users.noreply.github.com> Date: Wed, 19 Nov 2025 15:33:50 +0100 Subject: [PATCH 16/24] Feature/cog 3407 fixing integration test in ci (#1810) ## Description This PR should fix the web crawler integration test issue in our CI ## Type of Change - [x] Bug fix (non-breaking change that fixes an issue) - [ ] New feature (non-breaking change that adds functionality) - [ ] Breaking change (fix or feature that would cause existing functionality to change) - [ ] Documentation update - [ ] Code refactoring - [ ] Performance improvement - [ ] Other (please specify): ## Screenshots/Videos (if applicable) ## Pre-submission Checklist - [ ] **I have tested my changes thoroughly before submitting this PR** - [ ] **This PR contains minimal changes necessary to address the issue/feature** - [ ] My code follows the project's coding standards and style guidelines - [ ] I have added tests that prove my fix is effective or that my feature works - [ ] I have added necessary documentation (if applicable) - [ ] All new and existing tests pass - [ ] I have searched existing PRs to ensure this change hasn't been submitted already - [ ] I have linked any relevant issues in the description - [ ] My commits have clear and descriptive messages ## DCO Affirmation I affirm that all code in every commit of this pull request conforms to the terms of the Topoteretes Developer Certificate of Origin. --- .../test_default_url_crawler.py | 2 +- .../web_url_crawler/test_tavily_crawler.py | 2 +- .../web_url_crawler/test_url_adding_e2e.py | 40 ++++++------------- 3 files changed, 15 insertions(+), 29 deletions(-) diff --git a/cognee/tests/integration/web_url_crawler/test_default_url_crawler.py b/cognee/tests/integration/web_url_crawler/test_default_url_crawler.py index 156cc87a4..f48c1cedc 100644 --- a/cognee/tests/integration/web_url_crawler/test_default_url_crawler.py +++ b/cognee/tests/integration/web_url_crawler/test_default_url_crawler.py @@ -5,7 +5,7 @@ from cognee.tasks.web_scraper import DefaultUrlCrawler @pytest.mark.asyncio async def test_fetch(): crawler = DefaultUrlCrawler() - url = "https://en.wikipedia.org/wiki/Large_language_model" + url = "https://httpbin.org/html" results = await crawler.fetch_urls(url) assert len(results) == 1 assert isinstance(results, dict) diff --git a/cognee/tests/integration/web_url_crawler/test_tavily_crawler.py b/cognee/tests/integration/web_url_crawler/test_tavily_crawler.py index 946ce8378..19ffdc4ea 100644 --- a/cognee/tests/integration/web_url_crawler/test_tavily_crawler.py +++ b/cognee/tests/integration/web_url_crawler/test_tavily_crawler.py @@ -11,7 +11,7 @@ skip_in_ci = pytest.mark.skipif( @skip_in_ci @pytest.mark.asyncio async def test_fetch(): - url = "https://en.wikipedia.org/wiki/Large_language_model" + url = "https://httpbin.org/html" results = await fetch_with_tavily(url) assert isinstance(results, dict) assert len(results) == 1 diff --git a/cognee/tests/integration/web_url_crawler/test_url_adding_e2e.py b/cognee/tests/integration/web_url_crawler/test_url_adding_e2e.py index d91b075aa..cc8ae24d0 100644 --- a/cognee/tests/integration/web_url_crawler/test_url_adding_e2e.py +++ b/cognee/tests/integration/web_url_crawler/test_url_adding_e2e.py @@ -14,9 +14,7 @@ async def test_url_saves_as_html_file(): await cognee.prune.prune_system(metadata=True) try: - original_file_path = await save_data_item_to_storage( 
- "https://en.wikipedia.org/wiki/Large_language_model" - ) + original_file_path = await save_data_item_to_storage("https://httpbin.org/html") file_path = get_data_file_path(original_file_path) assert file_path.endswith(".html") file = Path(file_path) @@ -44,9 +42,7 @@ async def test_saved_html_is_valid(): await cognee.prune.prune_system(metadata=True) try: - original_file_path = await save_data_item_to_storage( - "https://en.wikipedia.org/wiki/Large_language_model" - ) + original_file_path = await save_data_item_to_storage("https://httpbin.org/html") file_path = get_data_file_path(original_file_path) content = Path(file_path).read_text() @@ -72,7 +68,7 @@ async def test_add_url(): await cognee.prune.prune_data() await cognee.prune.prune_system(metadata=True) - await cognee.add("https://en.wikipedia.org/wiki/Large_language_model") + await cognee.add("https://httpbin.org/html") skip_in_ci = pytest.mark.skipif( @@ -88,7 +84,7 @@ async def test_add_url_with_tavily(): await cognee.prune.prune_data() await cognee.prune.prune_system(metadata=True) - await cognee.add("https://en.wikipedia.org/wiki/Large_language_model") + await cognee.add("https://httpbin.org/html") @pytest.mark.asyncio @@ -98,7 +94,7 @@ async def test_add_url_without_incremental_loading(): try: await cognee.add( - "https://en.wikipedia.org/wiki/Large_language_model", + "https://httpbin.org/html", incremental_loading=False, ) except Exception as e: @@ -112,7 +108,7 @@ async def test_add_url_with_incremental_loading(): try: await cognee.add( - "https://en.wikipedia.org/wiki/Large_language_model", + "https://httpbin.org/html", incremental_loading=True, ) except Exception as e: @@ -125,7 +121,7 @@ async def test_add_url_can_define_preferred_loader_as_list_of_str(): await cognee.prune.prune_system(metadata=True) await cognee.add( - "https://en.wikipedia.org/wiki/Large_language_model", + "https://httpbin.org/html", preferred_loaders=["beautiful_soup_loader"], ) @@ -144,7 +140,7 @@ async def test_add_url_with_extraction_rules(): try: await cognee.add( - "https://en.wikipedia.org/wiki/Large_language_model", + "https://httpbin.org/html", preferred_loaders={"beautiful_soup_loader": {"extraction_rules": extraction_rules}}, ) except Exception as e: @@ -163,9 +159,7 @@ async def test_loader_is_none_by_default(): } try: - original_file_path = await save_data_item_to_storage( - "https://en.wikipedia.org/wiki/Large_language_model" - ) + original_file_path = await save_data_item_to_storage("https://httpbin.org/html") file_path = get_data_file_path(original_file_path) assert file_path.endswith(".html") file = Path(file_path) @@ -196,9 +190,7 @@ async def test_beautiful_soup_loader_is_selected_loader_if_preferred_loader_prov } try: - original_file_path = await save_data_item_to_storage( - "https://en.wikipedia.org/wiki/Large_language_model" - ) + original_file_path = await save_data_item_to_storage("https://httpbin.org/html") file_path = get_data_file_path(original_file_path) assert file_path.endswith(".html") file = Path(file_path) @@ -225,9 +217,7 @@ async def test_beautiful_soup_loader_works_with_and_without_arguments(): await cognee.prune.prune_system(metadata=True) try: - original_file_path = await save_data_item_to_storage( - "https://en.wikipedia.org/wiki/Large_language_model" - ) + original_file_path = await save_data_item_to_storage("https://httpbin.org/html") file_path = get_data_file_path(original_file_path) assert file_path.endswith(".html") file = Path(file_path) @@ -263,9 +253,7 @@ async def 
test_beautiful_soup_loader_successfully_loads_file_if_required_args_pr await cognee.prune.prune_system(metadata=True) try: - original_file_path = await save_data_item_to_storage( - "https://en.wikipedia.org/wiki/Large_language_model" - ) + original_file_path = await save_data_item_to_storage("https://httpbin.org/html") file_path = get_data_file_path(original_file_path) assert file_path.endswith(".html") file = Path(file_path) @@ -302,9 +290,7 @@ async def test_beautiful_soup_loads_file_successfully(): } try: - original_file_path = await save_data_item_to_storage( - "https://en.wikipedia.org/wiki/Large_language_model" - ) + original_file_path = await save_data_item_to_storage("https://httpbin.org/html") file_path = get_data_file_path(original_file_path) assert file_path.endswith(".html") original_file = Path(file_path) From 30e3971d44816db50d9e83eee39bf6d69b98a328 Mon Sep 17 00:00:00 2001 From: Fahad Shoaib Date: Thu, 20 Nov 2025 15:36:15 +0500 Subject: [PATCH 17/24] fix: add auth headers to ontology upload request and enhance ontology content --- cognee/tests/test_cognee_server_start.py | 53 +++++++++++++++++++++--- 1 file changed, 47 insertions(+), 6 deletions(-) diff --git a/cognee/tests/test_cognee_server_start.py b/cognee/tests/test_cognee_server_start.py index b266fc7bf..ddffe53a4 100644 --- a/cognee/tests/test_cognee_server_start.py +++ b/cognee/tests/test_cognee_server_start.py @@ -98,15 +98,56 @@ class TestCogneeServerStart(unittest.TestCase): if add_response.status_code not in [200, 201]: add_response.raise_for_status() - ontology_content = b""" - - - - - """ + ontology_content = b""" + + + + + + + + + + + + + + + + A failure caused by physical components. + + + + + An error caused by software logic or configuration. + + + + A human being or individual. 
+ + + + + Programmers + + + + Light Bulb + + + + Hardware Problem + + + """ ontology_response = requests.post( "http://127.0.0.1:8000/api/v1/ontologies", + headers=headers, files=[("ontology_file", ("test.owl", ontology_content, "application/xml"))], data={ "ontology_key": json.dumps([ontology_key]), From 8cfb6c41eeca3b2ad0e34fb4b6043f4b5f6a8c00 Mon Sep 17 00:00:00 2001 From: Fahad Shoaib Date: Thu, 20 Nov 2025 15:54:09 +0500 Subject: [PATCH 18/24] fix: remove async from ontology endpoint test functions --- cognee/tests/unit/api/test_ontology_endpoint.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/cognee/tests/unit/api/test_ontology_endpoint.py b/cognee/tests/unit/api/test_ontology_endpoint.py index c04959998..d53c5ab44 100644 --- a/cognee/tests/unit/api/test_ontology_endpoint.py +++ b/cognee/tests/unit/api/test_ontology_endpoint.py @@ -104,7 +104,7 @@ def test_upload_ontology_unauthorized(mock_get_default_user, client, mock_defaul @patch.object(gau_mod, "get_default_user", new_callable=AsyncMock) -async def test_upload_multiple_ontologies(mock_get_default_user, client, mock_default_user): +def test_upload_multiple_ontologies(mock_get_default_user, client, mock_default_user): """Test uploading multiple ontology files in single request""" import io @@ -132,7 +132,7 @@ async def test_upload_multiple_ontologies(mock_get_default_user, client, mock_de @patch.object(gau_mod, "get_default_user", new_callable=AsyncMock) -async def test_upload_endpoint_accepts_arrays(mock_get_default_user, client, mock_default_user): +def test_upload_endpoint_accepts_arrays(mock_get_default_user, client, mock_default_user): """Test that upload endpoint accepts array parameters""" import io import json @@ -153,7 +153,7 @@ async def test_upload_endpoint_accepts_arrays(mock_get_default_user, client, moc @patch.object(gau_mod, "get_default_user", new_callable=AsyncMock) -async def test_cognify_with_multiple_ontologies(mock_get_default_user, client, mock_default_user): +def test_cognify_with_multiple_ontologies(mock_get_default_user, client, mock_default_user): """Test cognify endpoint accepts multiple ontology keys""" payload = { "datasets": ["test_dataset"], @@ -168,7 +168,7 @@ async def test_cognify_with_multiple_ontologies(mock_get_default_user, client, m @patch.object(gau_mod, "get_default_user", new_callable=AsyncMock) -async def test_complete_multifile_workflow(mock_get_default_user, client, mock_default_user): +def test_complete_multifile_workflow(mock_get_default_user, client, mock_default_user): """Test complete workflow: upload multiple ontologies → cognify with multiple keys""" import io import json @@ -218,7 +218,7 @@ async def test_complete_multifile_workflow(mock_get_default_user, client, mock_d @patch.object(gau_mod, "get_default_user", new_callable=AsyncMock) -async def test_multifile_error_handling(mock_get_default_user, client, mock_default_user): +def test_multifile_error_handling(mock_get_default_user, client, mock_default_user): """Test error handling for invalid multifile uploads""" import io import json @@ -251,7 +251,7 @@ async def test_multifile_error_handling(mock_get_default_user, client, mock_defa @patch.object(gau_mod, "get_default_user", new_callable=AsyncMock) -async def test_cognify_missing_ontology_key(mock_get_default_user, client, mock_default_user): +def test_cognify_missing_ontology_key(mock_get_default_user, client, mock_default_user): """Test cognify with non-existent ontology key""" payload = { "datasets": ["test_dataset"], From 
4e880eca8422872b26a568b18b9f1339ce362c18 Mon Sep 17 00:00:00 2001
From: Andrej Milicevic
Date: Thu, 20 Nov 2025 15:47:22 +0100
Subject: [PATCH 19/24] chore: update env template

---
 .env.template | 1 +
 1 file changed, 1 insertion(+)

diff --git a/.env.template b/.env.template
index ae2cb1338..376233b1f 100644
--- a/.env.template
+++ b/.env.template
@@ -21,6 +21,7 @@ LLM_PROVIDER="openai"
 LLM_ENDPOINT=""
 LLM_API_VERSION=""
 LLM_MAX_TOKENS="16384"
+LLM_INSTRUCTOR_MODE="json_schema_mode" # this mode is used for gpt-5 models

 EMBEDDING_PROVIDER="openai"
 EMBEDDING_MODEL="openai/text-embedding-3-large"

From 2176ec16b8e440087f96410fed979528e8159ca2 Mon Sep 17 00:00:00 2001
From: hajdul88 <52442977+hajdul88@users.noreply.github.com>
Date: Thu, 20 Nov 2025 17:03:36 +0100
Subject: [PATCH 20/24] chore: changes url for crawler tests (#1816)

## Description

Updates the crawler test URL to avoid blocked and unavailable sites in CI. A minimal sketch of the resulting test shape follows the checklist below.

## Type of Change
- [ ] Bug fix (non-breaking change that fixes an issue)
- [ ] New feature (non-breaking change that adds functionality)
- [ ] Breaking change (fix or feature that would cause existing functionality to change)
- [ ] Documentation update
- [ ] Code refactoring
- [ ] Performance improvement
- [ ] Other (please specify):

## Screenshots/Videos (if applicable)

## Pre-submission Checklist
- [ ] **I have tested my changes thoroughly before submitting this PR**
- [ ] **This PR contains minimal changes necessary to address the issue/feature**
- [ ] My code follows the project's coding standards and style guidelines
- [ ] I have added tests that prove my fix is effective or that my feature works
- [ ] I have added necessary documentation (if applicable)
- [ ] All new and existing tests pass
- [ ] I have searched existing PRs to ensure this change hasn't been submitted already
- [ ] I have linked any relevant issues in the description
- [ ] My commits have clear and descriptive messages

## DCO Affirmation
I affirm that all code in every commit of this pull request conforms to the terms of the Topoteretes Developer Certificate of Origin.
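For reference, a minimal sketch of the resulting test shape, with the target URL hoisted into a module-level constant. The constant name is illustrative; the diff below inlines the URL in each test. example.com is an IANA-reserved static page, which is why it is a safe CI target:

```python
import pytest

from cognee.tasks.web_scraper import DefaultUrlCrawler

# IANA-reserved, static page: CI runs are not affected by rate limiting,
# bot blocking, or site outages the way a live Wikipedia article can be.
STABLE_TEST_URL = "http://example.com/"


@pytest.mark.asyncio
async def test_fetch():
    crawler = DefaultUrlCrawler()
    results = await crawler.fetch_urls(STABLE_TEST_URL)
    # fetch_urls returns a mapping with one entry per fetched URL.
    assert isinstance(results, dict)
    assert len(results) == 1
```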
--- .../test_default_url_crawler.py | 2 +- .../web_url_crawler/test_tavily_crawler.py | 2 +- .../web_url_crawler/test_url_adding_e2e.py | 26 +++++++++---------- 3 files changed, 15 insertions(+), 15 deletions(-) diff --git a/cognee/tests/integration/web_url_crawler/test_default_url_crawler.py b/cognee/tests/integration/web_url_crawler/test_default_url_crawler.py index f48c1cedc..af2595b14 100644 --- a/cognee/tests/integration/web_url_crawler/test_default_url_crawler.py +++ b/cognee/tests/integration/web_url_crawler/test_default_url_crawler.py @@ -5,7 +5,7 @@ from cognee.tasks.web_scraper import DefaultUrlCrawler @pytest.mark.asyncio async def test_fetch(): crawler = DefaultUrlCrawler() - url = "https://httpbin.org/html" + url = "http://example.com/" results = await crawler.fetch_urls(url) assert len(results) == 1 assert isinstance(results, dict) diff --git a/cognee/tests/integration/web_url_crawler/test_tavily_crawler.py b/cognee/tests/integration/web_url_crawler/test_tavily_crawler.py index 19ffdc4ea..5db9b58ce 100644 --- a/cognee/tests/integration/web_url_crawler/test_tavily_crawler.py +++ b/cognee/tests/integration/web_url_crawler/test_tavily_crawler.py @@ -11,7 +11,7 @@ skip_in_ci = pytest.mark.skipif( @skip_in_ci @pytest.mark.asyncio async def test_fetch(): - url = "https://httpbin.org/html" + url = "http://example.com/" results = await fetch_with_tavily(url) assert isinstance(results, dict) assert len(results) == 1 diff --git a/cognee/tests/integration/web_url_crawler/test_url_adding_e2e.py b/cognee/tests/integration/web_url_crawler/test_url_adding_e2e.py index cc8ae24d0..200f40a94 100644 --- a/cognee/tests/integration/web_url_crawler/test_url_adding_e2e.py +++ b/cognee/tests/integration/web_url_crawler/test_url_adding_e2e.py @@ -14,7 +14,7 @@ async def test_url_saves_as_html_file(): await cognee.prune.prune_system(metadata=True) try: - original_file_path = await save_data_item_to_storage("https://httpbin.org/html") + original_file_path = await save_data_item_to_storage("http://example.com/") file_path = get_data_file_path(original_file_path) assert file_path.endswith(".html") file = Path(file_path) @@ -42,7 +42,7 @@ async def test_saved_html_is_valid(): await cognee.prune.prune_system(metadata=True) try: - original_file_path = await save_data_item_to_storage("https://httpbin.org/html") + original_file_path = await save_data_item_to_storage("http://example.com/") file_path = get_data_file_path(original_file_path) content = Path(file_path).read_text() @@ -68,7 +68,7 @@ async def test_add_url(): await cognee.prune.prune_data() await cognee.prune.prune_system(metadata=True) - await cognee.add("https://httpbin.org/html") + await cognee.add("http://example.com/") skip_in_ci = pytest.mark.skipif( @@ -84,7 +84,7 @@ async def test_add_url_with_tavily(): await cognee.prune.prune_data() await cognee.prune.prune_system(metadata=True) - await cognee.add("https://httpbin.org/html") + await cognee.add("http://example.com/") @pytest.mark.asyncio @@ -94,7 +94,7 @@ async def test_add_url_without_incremental_loading(): try: await cognee.add( - "https://httpbin.org/html", + "http://example.com/", incremental_loading=False, ) except Exception as e: @@ -108,7 +108,7 @@ async def test_add_url_with_incremental_loading(): try: await cognee.add( - "https://httpbin.org/html", + "http://example.com/", incremental_loading=True, ) except Exception as e: @@ -121,7 +121,7 @@ async def test_add_url_can_define_preferred_loader_as_list_of_str(): await cognee.prune.prune_system(metadata=True) await cognee.add( - 
"https://httpbin.org/html", + "http://example.com/", preferred_loaders=["beautiful_soup_loader"], ) @@ -140,7 +140,7 @@ async def test_add_url_with_extraction_rules(): try: await cognee.add( - "https://httpbin.org/html", + "http://example.com/", preferred_loaders={"beautiful_soup_loader": {"extraction_rules": extraction_rules}}, ) except Exception as e: @@ -159,7 +159,7 @@ async def test_loader_is_none_by_default(): } try: - original_file_path = await save_data_item_to_storage("https://httpbin.org/html") + original_file_path = await save_data_item_to_storage("http://example.com/") file_path = get_data_file_path(original_file_path) assert file_path.endswith(".html") file = Path(file_path) @@ -190,7 +190,7 @@ async def test_beautiful_soup_loader_is_selected_loader_if_preferred_loader_prov } try: - original_file_path = await save_data_item_to_storage("https://httpbin.org/html") + original_file_path = await save_data_item_to_storage("http://example.com/") file_path = get_data_file_path(original_file_path) assert file_path.endswith(".html") file = Path(file_path) @@ -217,7 +217,7 @@ async def test_beautiful_soup_loader_works_with_and_without_arguments(): await cognee.prune.prune_system(metadata=True) try: - original_file_path = await save_data_item_to_storage("https://httpbin.org/html") + original_file_path = await save_data_item_to_storage("http://example.com/") file_path = get_data_file_path(original_file_path) assert file_path.endswith(".html") file = Path(file_path) @@ -253,7 +253,7 @@ async def test_beautiful_soup_loader_successfully_loads_file_if_required_args_pr await cognee.prune.prune_system(metadata=True) try: - original_file_path = await save_data_item_to_storage("https://httpbin.org/html") + original_file_path = await save_data_item_to_storage("http://example.com/") file_path = get_data_file_path(original_file_path) assert file_path.endswith(".html") file = Path(file_path) @@ -290,7 +290,7 @@ async def test_beautiful_soup_loads_file_successfully(): } try: - original_file_path = await save_data_item_to_storage("https://httpbin.org/html") + original_file_path = await save_data_item_to_storage("http://example.com/") file_path = get_data_file_path(original_file_path) assert file_path.endswith(".html") original_file = Path(file_path) From 204f9c2e4ad6dc706c03deed68adfbc4744ae6df Mon Sep 17 00:00:00 2001 From: Andrej Milicevic Date: Fri, 21 Nov 2025 16:20:19 +0100 Subject: [PATCH 21/24] fix: PR comment changes --- .env.template | 5 ++++- cognee/infrastructure/llm/config.py | 4 ++-- .../litellm_instructor/llm/get_llm_client.py | 12 ++++++------ 3 files changed, 12 insertions(+), 9 deletions(-) diff --git a/.env.template b/.env.template index 376233b1f..61853b983 100644 --- a/.env.template +++ b/.env.template @@ -21,7 +21,10 @@ LLM_PROVIDER="openai" LLM_ENDPOINT="" LLM_API_VERSION="" LLM_MAX_TOKENS="16384" -LLM_INSTRUCTOR_MODE="json_schema_mode" # this mode is used for gpt-5 models +# Instructor's modes determine how structured data is requested from and extracted from LLM responses +# You can change this type (i.e. mode) via this env variable +# Each LLM has its own default value, e.g. 
gpt-5 models have "json_schema_mode" +LLM_INSTRUCTOR_MODE="" EMBEDDING_PROVIDER="openai" EMBEDDING_MODEL="openai/text-embedding-3-large" diff --git a/cognee/infrastructure/llm/config.py b/cognee/infrastructure/llm/config.py index c87054ff6..2e300dc0c 100644 --- a/cognee/infrastructure/llm/config.py +++ b/cognee/infrastructure/llm/config.py @@ -38,7 +38,7 @@ class LLMConfig(BaseSettings): """ structured_output_framework: str = "instructor" - llm_instructor_mode: Optional[str] = None + llm_instructor_mode: str = "" llm_provider: str = "openai" llm_model: str = "openai/gpt-5-mini" llm_endpoint: str = "" @@ -182,7 +182,7 @@ class LLMConfig(BaseSettings): instance. """ return { - "llm_instructor_mode": self.llm_instructor_mode, + "llm_instructor_mode": self.llm_instructor_mode.lower(), "provider": self.llm_provider, "model": self.llm_model, "endpoint": self.llm_endpoint, diff --git a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/get_llm_client.py b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/get_llm_client.py index 537eda1b2..6ab3b91ad 100644 --- a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/get_llm_client.py +++ b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/get_llm_client.py @@ -81,7 +81,7 @@ def get_llm_client(raise_api_key_error: bool = True): model=llm_config.llm_model, transcription_model=llm_config.transcription_model, max_completion_tokens=max_completion_tokens, - instructor_mode=llm_config.llm_instructor_mode, + instructor_mode=llm_config.llm_instructor_mode.lower(), streaming=llm_config.llm_streaming, fallback_api_key=llm_config.fallback_api_key, fallback_endpoint=llm_config.fallback_endpoint, @@ -102,7 +102,7 @@ def get_llm_client(raise_api_key_error: bool = True): llm_config.llm_model, "Ollama", max_completion_tokens=max_completion_tokens, - instructor_mode=llm_config.llm_instructor_mode, + instructor_mode=llm_config.llm_instructor_mode.lower(), ) elif provider == LLMProvider.ANTHROPIC: @@ -113,7 +113,7 @@ def get_llm_client(raise_api_key_error: bool = True): return AnthropicAdapter( max_completion_tokens=max_completion_tokens, model=llm_config.llm_model, - instructor_mode=llm_config.llm_instructor_mode, + instructor_mode=llm_config.llm_instructor_mode.lower(), ) elif provider == LLMProvider.CUSTOM: @@ -130,7 +130,7 @@ def get_llm_client(raise_api_key_error: bool = True): llm_config.llm_model, "Custom", max_completion_tokens=max_completion_tokens, - instructor_mode=llm_config.llm_instructor_mode, + instructor_mode=llm_config.llm_instructor_mode.lower(), fallback_api_key=llm_config.fallback_api_key, fallback_endpoint=llm_config.fallback_endpoint, fallback_model=llm_config.fallback_model, @@ -150,7 +150,7 @@ def get_llm_client(raise_api_key_error: bool = True): max_completion_tokens=max_completion_tokens, endpoint=llm_config.llm_endpoint, api_version=llm_config.llm_api_version, - instructor_mode=llm_config.llm_instructor_mode, + instructor_mode=llm_config.llm_instructor_mode.lower(), ) elif provider == LLMProvider.MISTRAL: @@ -166,7 +166,7 @@ def get_llm_client(raise_api_key_error: bool = True): model=llm_config.llm_model, max_completion_tokens=max_completion_tokens, endpoint=llm_config.llm_endpoint, - instructor_mode=llm_config.llm_instructor_mode, + instructor_mode=llm_config.llm_instructor_mode.lower(), ) elif provider == LLMProvider.MISTRAL: From af8c55e82bd5dcefb526ffc106f3dcbc1a881a24 Mon Sep 17 00:00:00 2001 From: Boris Arzentar Date: Mon, 24 Nov 2025 
16:16:47 +0100 Subject: [PATCH 22/24] version: 0.5.0.dev0 --- poetry.lock | 62 +++++++++++++++++++++----------------------------- pyproject.toml | 2 +- uv.lock | 10 +++++++- 3 files changed, 36 insertions(+), 38 deletions(-) diff --git a/poetry.lock b/poetry.lock index 67de51633..0736a7bb7 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 2.2.1 and should not be changed by hand. +# This file is automatically @generated by Poetry 2.1.3 and should not be changed by hand. [[package]] name = "accelerate" @@ -1231,12 +1231,12 @@ version = "0.4.6" description = "Cross-platform colored terminal text." optional = false python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*,>=2.7" -groups = ["main", "dev"] +groups = ["main"] +markers = "(platform_system == \"Windows\" or extra == \"llama-index\" or extra == \"dev\" or extra == \"chromadb\" or sys_platform == \"win32\") and (platform_system == \"Windows\" or os_name == \"nt\" or extra == \"llama-index\" or extra == \"dev\" or sys_platform == \"win32\")" files = [ {file = "colorama-0.4.6-py2.py3-none-any.whl", hash = "sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6"}, {file = "colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44"}, ] -markers = {main = "(platform_system == \"Windows\" or extra == \"llama-index\" or extra == \"dev\" or extra == \"chromadb\" or sys_platform == \"win32\") and (platform_system == \"Windows\" or os_name == \"nt\" or extra == \"llama-index\" or extra == \"dev\" or sys_platform == \"win32\")", dev = "sys_platform == \"win32\""} [[package]] name = "coloredlogs" @@ -2347,7 +2347,7 @@ version = "1.3.0" description = "Backport of PEP 654 (exception groups)" optional = false python-versions = ">=3.7" -groups = ["main", "dev"] +groups = ["main"] markers = "python_version == \"3.10\"" files = [ {file = "exceptiongroup-1.3.0-py3-none-any.whl", hash = "sha256:4d111e6e0c13d0644cad6ddaa7ed0261a0b36971f6d23e7ec9b4b9097da78a10"}, @@ -3732,14 +3732,14 @@ type = ["pytest-mypy"] name = "iniconfig" version = "2.1.0" description = "brain-dead simple config-ini parsing" -optional = false +optional = true python-versions = ">=3.8" -groups = ["main", "dev"] +groups = ["main"] +markers = "extra == \"deepeval\" or extra == \"dev\"" files = [ {file = "iniconfig-2.1.0-py3-none-any.whl", hash = "sha256:9deba5723312380e77435581c6bf4935c94cbfab9b1ed33ef8d238ea168eb760"}, {file = "iniconfig-2.1.0.tar.gz", hash = "sha256:3abbd2e30b36733fee78f9c7f7308f2d0050e88f0087fd25c2645f63c773e1c7"}, ] -markers = {main = "extra == \"deepeval\" or extra == \"dev\""} [[package]] name = "instructor" @@ -4196,8 +4196,6 @@ groups = ["main"] markers = "extra == \"dlt\"" files = [ {file = "jsonpath-ng-1.7.0.tar.gz", hash = "sha256:f6f5f7fd4e5ff79c785f1573b394043b39849fb2bb47bcead935d12b00beab3c"}, - {file = "jsonpath_ng-1.7.0-py2-none-any.whl", hash = "sha256:898c93fc173f0c336784a3fa63d7434297544b7198124a68f9a3ef9597b0ae6e"}, - {file = "jsonpath_ng-1.7.0-py3-none-any.whl", hash = "sha256:f3d7f9e848cba1b6da28c55b1c26ff915dc9e0b1ba7e752a53d6da8d5cbd00b6"}, ] [package.dependencies] @@ -7634,7 +7632,7 @@ version = "24.2" description = "Core utilities for Python packages" optional = false python-versions = ">=3.8" -groups = ["main", "dev"] +groups = ["main"] files = [ {file = "packaging-24.2-py3-none-any.whl", hash = "sha256:09abb1bccd265c01f4a3aa3f7a7db064b36514d2cba19a2f694fe6150451a759"}, {file = "packaging-24.2.tar.gz", 
hash = "sha256:c228a6dc5e932d346bc5739379109d49e8853dd8223571c7c5b55260edc0b97f"}, @@ -8289,14 +8287,14 @@ kaleido = ["kaleido (>=1.0.0)"] name = "pluggy" version = "1.6.0" description = "plugin and hook calling mechanisms for python" -optional = false +optional = true python-versions = ">=3.9" -groups = ["main", "dev"] +groups = ["main"] +markers = "extra == \"deepeval\" or extra == \"dev\" or extra == \"dlt\" or extra == \"docling\"" files = [ {file = "pluggy-1.6.0-py3-none-any.whl", hash = "sha256:e920276dd6813095e9377c0bc5566d94c932c33b27a3e3945d8389c374dd4746"}, {file = "pluggy-1.6.0.tar.gz", hash = "sha256:7dcc130b76258d33b90f61b658791dede3486c3e6bfb003ee5c9bfb396dd22f3"}, ] -markers = {main = "extra == \"deepeval\" or extra == \"dev\" or extra == \"dlt\" or extra == \"docling\""} [package.extras] dev = ["pre-commit", "tox"] @@ -8656,7 +8654,6 @@ files = [ {file = "psycopg2-2.9.10-cp311-cp311-win_amd64.whl", hash = "sha256:0435034157049f6846e95103bd8f5a668788dd913a7c30162ca9503fdf542cb4"}, {file = "psycopg2-2.9.10-cp312-cp312-win32.whl", hash = "sha256:65a63d7ab0e067e2cdb3cf266de39663203d38d6a8ed97f5ca0cb315c73fe067"}, {file = "psycopg2-2.9.10-cp312-cp312-win_amd64.whl", hash = "sha256:4a579d6243da40a7b3182e0430493dbd55950c493d8c68f4eec0b302f6bbf20e"}, - {file = "psycopg2-2.9.10-cp313-cp313-win_amd64.whl", hash = "sha256:91fd603a2155da8d0cfcdbf8ab24a2d54bca72795b90d2a3ed2b6da8d979dee2"}, {file = "psycopg2-2.9.10-cp39-cp39-win32.whl", hash = "sha256:9d5b3b94b79a844a986d029eee38998232451119ad653aea42bb9220a8c5066b"}, {file = "psycopg2-2.9.10-cp39-cp39-win_amd64.whl", hash = "sha256:88138c8dedcbfa96408023ea2b0c369eda40fe5d75002c0964c78f46f11fa442"}, {file = "psycopg2-2.9.10.tar.gz", hash = "sha256:12ec0b40b0273f95296233e8750441339298e6a572f7039da5b260e3c8b60e11"}, @@ -8718,7 +8715,6 @@ files = [ {file = "psycopg2_binary-2.9.10-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:bb89f0a835bcfc1d42ccd5f41f04870c1b936d8507c6df12b7737febc40f0909"}, {file = "psycopg2_binary-2.9.10-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:f0c2d907a1e102526dd2986df638343388b94c33860ff3bbe1384130828714b1"}, {file = "psycopg2_binary-2.9.10-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:f8157bed2f51db683f31306aa497311b560f2265998122abe1dce6428bd86567"}, - {file = "psycopg2_binary-2.9.10-cp313-cp313-win_amd64.whl", hash = "sha256:27422aa5f11fbcd9b18da48373eb67081243662f9b46e6fd07c3eb46e4535142"}, {file = "psycopg2_binary-2.9.10-cp38-cp38-macosx_12_0_x86_64.whl", hash = "sha256:eb09aa7f9cecb45027683bb55aebaaf45a0df8bf6de68801a6afdc7947bb09d4"}, {file = "psycopg2_binary-2.9.10-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b73d6d7f0ccdad7bc43e6d34273f70d587ef62f824d7261c4ae9b8b1b6af90e8"}, {file = "psycopg2_binary-2.9.10-cp38-cp38-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ce5ab4bf46a211a8e924d307c1b1fcda82368586a19d0a24f8ae166f5c784864"}, @@ -9698,14 +9694,14 @@ files = [ name = "pytest" version = "7.4.4" description = "pytest: simple powerful testing with Python" -optional = false +optional = true python-versions = ">=3.7" -groups = ["main", "dev"] +groups = ["main"] +markers = "extra == \"deepeval\" or extra == \"dev\"" files = [ {file = "pytest-7.4.4-py3-none-any.whl", hash = "sha256:b090cdf5ed60bf4c45261be03239c2c1c22df034fbffe691abe93cd80cea01d8"}, {file = "pytest-7.4.4.tar.gz", hash = "sha256:2cf0005922c6ace4a3e2ec8b4080eb0d9753fdc93107415332f50ce9e7994280"}, ] -markers = {main = "extra == \"deepeval\" or extra == \"dev\""} 
[package.dependencies] colorama = {version = "*", markers = "sys_platform == \"win32\""} @@ -9792,21 +9788,6 @@ files = [ packaging = ">=17.1" pytest = ">=6.2" -[[package]] -name = "pytest-timeout" -version = "2.4.0" -description = "pytest plugin to abort hanging tests" -optional = false -python-versions = ">=3.7" -groups = ["dev"] -files = [ - {file = "pytest_timeout-2.4.0-py3-none-any.whl", hash = "sha256:c42667e5cdadb151aeb5b26d114aff6bdf5a907f176a007a30b940d3d865b5c2"}, - {file = "pytest_timeout-2.4.0.tar.gz", hash = "sha256:7e68e90b01f9eff71332b25001f85c75495fc4e3a836701876183c4bcfd0540a"}, -] - -[package.dependencies] -pytest = ">=7.0.0" - [[package]] name = "pytest-xdist" version = "3.8.0" @@ -11656,7 +11637,9 @@ groups = ["main"] files = [ {file = "SQLAlchemy-2.0.43-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:21ba7a08a4253c5825d1db389d4299f64a100ef9800e4624c8bf70d8f136e6ed"}, {file = "SQLAlchemy-2.0.43-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:11b9503fa6f8721bef9b8567730f664c5a5153d25e247aadc69247c4bc605227"}, + {file = "SQLAlchemy-2.0.43-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:07097c0a1886c150ef2adba2ff7437e84d40c0f7dcb44a2c2b9c905ccfc6361c"}, {file = "SQLAlchemy-2.0.43-cp37-cp37m-musllinux_1_2_aarch64.whl", hash = "sha256:cdeff998cb294896a34e5b2f00e383e7c5c4ef3b4bfa375d9104723f15186443"}, + {file = "SQLAlchemy-2.0.43-cp37-cp37m-musllinux_1_2_x86_64.whl", hash = "sha256:bcf0724a62a5670e5718957e05c56ec2d6850267ea859f8ad2481838f889b42c"}, {file = "SQLAlchemy-2.0.43-cp37-cp37m-win32.whl", hash = "sha256:c697575d0e2b0a5f0433f679bda22f63873821d991e95a90e9e52aae517b2e32"}, {file = "SQLAlchemy-2.0.43-cp37-cp37m-win_amd64.whl", hash = "sha256:d34c0f6dbefd2e816e8f341d0df7d4763d382e3f452423e752ffd1e213da2512"}, {file = "sqlalchemy-2.0.43-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:70322986c0c699dca241418fcf18e637a4369e0ec50540a2b907b184c8bca069"}, @@ -11691,12 +11674,20 @@ files = [ {file = "sqlalchemy-2.0.43-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:9df7126fd9db49e3a5a3999442cc67e9ee8971f3cb9644250107d7296cb2a164"}, {file = "sqlalchemy-2.0.43-cp313-cp313-win32.whl", hash = "sha256:7f1ac7828857fcedb0361b48b9ac4821469f7694089d15550bbcf9ab22564a1d"}, {file = "sqlalchemy-2.0.43-cp313-cp313-win_amd64.whl", hash = "sha256:971ba928fcde01869361f504fcff3b7143b47d30de188b11c6357c0505824197"}, + {file = "sqlalchemy-2.0.43-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:4e6aeb2e0932f32950cf56a8b4813cb15ff792fc0c9b3752eaf067cfe298496a"}, + {file = "sqlalchemy-2.0.43-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:61f964a05356f4bca4112e6334ed7c208174511bd56e6b8fc86dad4d024d4185"}, {file = "sqlalchemy-2.0.43-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:46293c39252f93ea0910aababa8752ad628bcce3a10d3f260648dd472256983f"}, + {file = "sqlalchemy-2.0.43-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:136063a68644eca9339d02e6693932116f6a8591ac013b0014479a1de664e40a"}, {file = "sqlalchemy-2.0.43-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:6e2bf13d9256398d037fef09fd8bf9b0bf77876e22647d10761d35593b9ac547"}, + {file = "sqlalchemy-2.0.43-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:44337823462291f17f994d64282a71c51d738fc9ef561bf265f1d0fd9116a782"}, {file = "sqlalchemy-2.0.43-cp38-cp38-win32.whl", hash = "sha256:13194276e69bb2af56198fef7909d48fd34820de01d9c92711a5fa45497cc7ed"}, {file = "sqlalchemy-2.0.43-cp38-cp38-win_amd64.whl", hash = 
"sha256:334f41fa28de9f9be4b78445e68530da3c5fa054c907176460c81494f4ae1f5e"}, + {file = "sqlalchemy-2.0.43-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:ceb5c832cc30663aeaf5e39657712f4c4241ad1f638d487ef7216258f6d41fe7"}, + {file = "sqlalchemy-2.0.43-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:11f43c39b4b2ec755573952bbcc58d976779d482f6f832d7f33a8d869ae891bf"}, {file = "sqlalchemy-2.0.43-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:413391b2239db55be14fa4223034d7e13325a1812c8396ecd4f2c08696d5ccad"}, + {file = "sqlalchemy-2.0.43-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c379e37b08c6c527181a397212346be39319fb64323741d23e46abd97a400d34"}, {file = "sqlalchemy-2.0.43-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:03d73ab2a37d9e40dec4984d1813d7878e01dbdc742448d44a7341b7a9f408c7"}, + {file = "sqlalchemy-2.0.43-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:8cee08f15d9e238ede42e9bbc1d6e7158d0ca4f176e4eab21f88ac819ae3bd7b"}, {file = "sqlalchemy-2.0.43-cp39-cp39-win32.whl", hash = "sha256:b3edaec7e8b6dc5cd94523c6df4f294014df67097c8217a89929c99975811414"}, {file = "sqlalchemy-2.0.43-cp39-cp39-win_amd64.whl", hash = "sha256:227119ce0a89e762ecd882dc661e0aa677a690c914e358f0dd8932a2e8b2765b"}, {file = "sqlalchemy-2.0.43-py3-none-any.whl", hash = "sha256:1681c21dd2ccee222c2fe0bef671d1aef7c504087c9c4e800371cfcc8ac966fc"}, @@ -12065,7 +12056,7 @@ version = "2.2.1" description = "A lil' TOML parser" optional = false python-versions = ">=3.8" -groups = ["main", "dev"] +groups = ["main"] markers = "python_version == \"3.10\"" files = [ {file = "tomli-2.2.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:678e4fa69e4575eb77d103de3df8a895e1591b48e740211bd1067378c69e8249"}, @@ -12537,12 +12528,11 @@ version = "4.15.0" description = "Backported and Experimental Type Hints for Python 3.9+" optional = false python-versions = ">=3.9" -groups = ["main", "dev"] +groups = ["main"] files = [ {file = "typing_extensions-4.15.0-py3-none-any.whl", hash = "sha256:f0fa19c6845758ab08074a0cfa8b7aecb71c999ca73d62883bc25cc018c4e548"}, {file = "typing_extensions-4.15.0.tar.gz", hash = "sha256:0cea48d173cc12fa28ecabc3b837ea3cf6f38c6d1136f85cbaaf598984861466"}, ] -markers = {dev = "python_version == \"3.10\""} [[package]] name = "typing-inspect" diff --git a/pyproject.toml b/pyproject.toml index 2436911e8..a9b895dfb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,7 +1,7 @@ [project] name = "cognee" -version = "0.3.9" +version = "0.5.0.dev0" description = "Cognee - is a library for enriching LLM context with a semantic layer for better understanding and reasoning." authors = [ { name = "Vasilije Markovic" }, diff --git a/uv.lock b/uv.lock index 8c35a3366..cc66c3d7e 100644 --- a/uv.lock +++ b/uv.lock @@ -929,7 +929,7 @@ wheels = [ [[package]] name = "cognee" -version = "0.3.9" +version = "0.5.0.dev0" source = { editable = "." 
} dependencies = [ { name = "aiofiles" }, @@ -2560,6 +2560,8 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/7f/91/ae2eb6b7979e2f9b035a9f612cf70f1bf54aad4e1d125129bef1eae96f19/greenlet-3.2.4-cp310-cp310-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:c2ca18a03a8cfb5b25bc1cbe20f3d9a4c80d8c3b13ba3df49ac3961af0b1018d", size = 584358, upload-time = "2025-08-07T13:18:23.708Z" }, { url = "https://files.pythonhosted.org/packages/f7/85/433de0c9c0252b22b16d413c9407e6cb3b41df7389afc366ca204dbc1393/greenlet-3.2.4-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:9fe0a28a7b952a21e2c062cd5756d34354117796c6d9215a87f55e38d15402c5", size = 1113550, upload-time = "2025-08-07T13:42:37.467Z" }, { url = "https://files.pythonhosted.org/packages/a1/8d/88f3ebd2bc96bf7747093696f4335a0a8a4c5acfcf1b757717c0d2474ba3/greenlet-3.2.4-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:8854167e06950ca75b898b104b63cc646573aa5fef1353d4508ecdd1ee76254f", size = 1137126, upload-time = "2025-08-07T13:18:20.239Z" }, + { url = "https://files.pythonhosted.org/packages/f1/29/74242b7d72385e29bcc5563fba67dad94943d7cd03552bac320d597f29b2/greenlet-3.2.4-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:f47617f698838ba98f4ff4189aef02e7343952df3a615f847bb575c3feb177a7", size = 1544904, upload-time = "2025-11-04T12:42:04.763Z" }, + { url = "https://files.pythonhosted.org/packages/c8/e2/1572b8eeab0f77df5f6729d6ab6b141e4a84ee8eb9bc8c1e7918f94eda6d/greenlet-3.2.4-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:af41be48a4f60429d5cad9d22175217805098a9ef7c40bfef44f7669fb9d74d8", size = 1611228, upload-time = "2025-11-04T12:42:08.423Z" }, { url = "https://files.pythonhosted.org/packages/d6/6f/b60b0291d9623c496638c582297ead61f43c4b72eef5e9c926ef4565ec13/greenlet-3.2.4-cp310-cp310-win_amd64.whl", hash = "sha256:73f49b5368b5359d04e18d15828eecc1806033db5233397748f4ca813ff1056c", size = 298654, upload-time = "2025-08-07T13:50:00.469Z" }, { url = "https://files.pythonhosted.org/packages/a4/de/f28ced0a67749cac23fecb02b694f6473f47686dff6afaa211d186e2ef9c/greenlet-3.2.4-cp311-cp311-macosx_11_0_universal2.whl", hash = "sha256:96378df1de302bc38e99c3a9aa311967b7dc80ced1dcc6f171e99842987882a2", size = 272305, upload-time = "2025-08-07T13:15:41.288Z" }, { url = "https://files.pythonhosted.org/packages/09/16/2c3792cba130000bf2a31c5272999113f4764fd9d874fb257ff588ac779a/greenlet-3.2.4-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:1ee8fae0519a337f2329cb78bd7a8e128ec0f881073d43f023c7b8d4831d5246", size = 632472, upload-time = "2025-08-07T13:42:55.044Z" }, @@ -2569,6 +2571,8 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/1f/8e/abdd3f14d735b2929290a018ecf133c901be4874b858dd1c604b9319f064/greenlet-3.2.4-cp311-cp311-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:2523e5246274f54fdadbce8494458a2ebdcdbc7b802318466ac5606d3cded1f8", size = 587684, upload-time = "2025-08-07T13:18:25.164Z" }, { url = "https://files.pythonhosted.org/packages/5d/65/deb2a69c3e5996439b0176f6651e0052542bb6c8f8ec2e3fba97c9768805/greenlet-3.2.4-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:1987de92fec508535687fb807a5cea1560f6196285a4cde35c100b8cd632cc52", size = 1116647, upload-time = "2025-08-07T13:42:38.655Z" }, { url = "https://files.pythonhosted.org/packages/3f/cc/b07000438a29ac5cfb2194bfc128151d52f333cee74dd7dfe3fb733fc16c/greenlet-3.2.4-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:55e9c5affaa6775e2c6b67659f3a71684de4c549b3dd9afca3bc773533d284fa", size = 
1142073, upload-time = "2025-08-07T13:18:21.737Z" }, + { url = "https://files.pythonhosted.org/packages/67/24/28a5b2fa42d12b3d7e5614145f0bd89714c34c08be6aabe39c14dd52db34/greenlet-3.2.4-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:c9c6de1940a7d828635fbd254d69db79e54619f165ee7ce32fda763a9cb6a58c", size = 1548385, upload-time = "2025-11-04T12:42:11.067Z" }, + { url = "https://files.pythonhosted.org/packages/6a/05/03f2f0bdd0b0ff9a4f7b99333d57b53a7709c27723ec8123056b084e69cd/greenlet-3.2.4-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:03c5136e7be905045160b1b9fdca93dd6727b180feeafda6818e6496434ed8c5", size = 1613329, upload-time = "2025-11-04T12:42:12.928Z" }, { url = "https://files.pythonhosted.org/packages/d8/0f/30aef242fcab550b0b3520b8e3561156857c94288f0332a79928c31a52cf/greenlet-3.2.4-cp311-cp311-win_amd64.whl", hash = "sha256:9c40adce87eaa9ddb593ccb0fa6a07caf34015a29bf8d344811665b573138db9", size = 299100, upload-time = "2025-08-07T13:44:12.287Z" }, { url = "https://files.pythonhosted.org/packages/44/69/9b804adb5fd0671f367781560eb5eb586c4d495277c93bde4307b9e28068/greenlet-3.2.4-cp312-cp312-macosx_11_0_universal2.whl", hash = "sha256:3b67ca49f54cede0186854a008109d6ee71f66bd57bb36abd6d0a0267b540cdd", size = 274079, upload-time = "2025-08-07T13:15:45.033Z" }, { url = "https://files.pythonhosted.org/packages/46/e9/d2a80c99f19a153eff70bc451ab78615583b8dac0754cfb942223d2c1a0d/greenlet-3.2.4-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:ddf9164e7a5b08e9d22511526865780a576f19ddd00d62f8a665949327fde8bb", size = 640997, upload-time = "2025-08-07T13:42:56.234Z" }, @@ -2578,6 +2582,8 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/19/0d/6660d55f7373b2ff8152401a83e02084956da23ae58cddbfb0b330978fe9/greenlet-3.2.4-cp312-cp312-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:3b3812d8d0c9579967815af437d96623f45c0f2ae5f04e366de62a12d83a8fb0", size = 607586, upload-time = "2025-08-07T13:18:28.544Z" }, { url = "https://files.pythonhosted.org/packages/8e/1a/c953fdedd22d81ee4629afbb38d2f9d71e37d23caace44775a3a969147d4/greenlet-3.2.4-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:abbf57b5a870d30c4675928c37278493044d7c14378350b3aa5d484fa65575f0", size = 1123281, upload-time = "2025-08-07T13:42:39.858Z" }, { url = "https://files.pythonhosted.org/packages/3f/c7/12381b18e21aef2c6bd3a636da1088b888b97b7a0362fac2e4de92405f97/greenlet-3.2.4-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:20fb936b4652b6e307b8f347665e2c615540d4b42b3b4c8a321d8286da7e520f", size = 1151142, upload-time = "2025-08-07T13:18:22.981Z" }, + { url = "https://files.pythonhosted.org/packages/27/45/80935968b53cfd3f33cf99ea5f08227f2646e044568c9b1555b58ffd61c2/greenlet-3.2.4-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:ee7a6ec486883397d70eec05059353b8e83eca9168b9f3f9a361971e77e0bcd0", size = 1564846, upload-time = "2025-11-04T12:42:15.191Z" }, + { url = "https://files.pythonhosted.org/packages/69/02/b7c30e5e04752cb4db6202a3858b149c0710e5453b71a3b2aec5d78a1aab/greenlet-3.2.4-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:326d234cbf337c9c3def0676412eb7040a35a768efc92504b947b3e9cfc7543d", size = 1633814, upload-time = "2025-11-04T12:42:17.175Z" }, { url = "https://files.pythonhosted.org/packages/e9/08/b0814846b79399e585f974bbeebf5580fbe59e258ea7be64d9dfb253c84f/greenlet-3.2.4-cp312-cp312-win_amd64.whl", hash = "sha256:a7d4e128405eea3814a12cc2605e0e6aedb4035bf32697f72deca74de4105e02", size = 299899, upload-time = "2025-08-07T13:38:53.448Z" }, { 
url = "https://files.pythonhosted.org/packages/49/e8/58c7f85958bda41dafea50497cbd59738c5c43dbbea5ee83d651234398f4/greenlet-3.2.4-cp313-cp313-macosx_11_0_universal2.whl", hash = "sha256:1a921e542453fe531144e91e1feedf12e07351b1cf6c9e8a3325ea600a715a31", size = 272814, upload-time = "2025-08-07T13:15:50.011Z" }, { url = "https://files.pythonhosted.org/packages/62/dd/b9f59862e9e257a16e4e610480cfffd29e3fae018a68c2332090b53aac3d/greenlet-3.2.4-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:cd3c8e693bff0fff6ba55f140bf390fa92c994083f838fece0f63be121334945", size = 641073, upload-time = "2025-08-07T13:42:57.23Z" }, @@ -2587,6 +2593,8 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/ee/43/3cecdc0349359e1a527cbf2e3e28e5f8f06d3343aaf82ca13437a9aa290f/greenlet-3.2.4-cp313-cp313-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:23768528f2911bcd7e475210822ffb5254ed10d71f4028387e5a99b4c6699671", size = 610497, upload-time = "2025-08-07T13:18:31.636Z" }, { url = "https://files.pythonhosted.org/packages/b8/19/06b6cf5d604e2c382a6f31cafafd6f33d5dea706f4db7bdab184bad2b21d/greenlet-3.2.4-cp313-cp313-musllinux_1_1_aarch64.whl", hash = "sha256:00fadb3fedccc447f517ee0d3fd8fe49eae949e1cd0f6a611818f4f6fb7dc83b", size = 1121662, upload-time = "2025-08-07T13:42:41.117Z" }, { url = "https://files.pythonhosted.org/packages/a2/15/0d5e4e1a66fab130d98168fe984c509249c833c1a3c16806b90f253ce7b9/greenlet-3.2.4-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:d25c5091190f2dc0eaa3f950252122edbbadbb682aa7b1ef2f8af0f8c0afefae", size = 1149210, upload-time = "2025-08-07T13:18:24.072Z" }, + { url = "https://files.pythonhosted.org/packages/1c/53/f9c440463b3057485b8594d7a638bed53ba531165ef0ca0e6c364b5cc807/greenlet-3.2.4-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:6e343822feb58ac4d0a1211bd9399de2b3a04963ddeec21530fc426cc121f19b", size = 1564759, upload-time = "2025-11-04T12:42:19.395Z" }, + { url = "https://files.pythonhosted.org/packages/47/e4/3bb4240abdd0a8d23f4f88adec746a3099f0d86bfedb623f063b2e3b4df0/greenlet-3.2.4-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:ca7f6f1f2649b89ce02f6f229d7c19f680a6238af656f61e0115b24857917929", size = 1634288, upload-time = "2025-11-04T12:42:21.174Z" }, { url = "https://files.pythonhosted.org/packages/0b/55/2321e43595e6801e105fcfdee02b34c0f996eb71e6ddffca6b10b7e1d771/greenlet-3.2.4-cp313-cp313-win_amd64.whl", hash = "sha256:554b03b6e73aaabec3745364d6239e9e012d64c68ccd0b8430c64ccc14939a8b", size = 299685, upload-time = "2025-08-07T13:24:38.824Z" }, ] From c2c64a417c4805c5d5aeb59b4e5f9519b729ee85 Mon Sep 17 00:00:00 2001 From: hajdul88 <52442977+hajdul88@users.noreply.github.com> Date: Mon, 24 Nov 2025 17:44:51 +0100 Subject: [PATCH 23/24] fix: fixes ontology api endpoint tests + poetry lock(#1824) ## Description This PR fixes the failing CI tests related to the new ontology api endpoint. 
## Type of Change - [x] Bug fix (non-breaking change that fixes an issue) - [ ] New feature (non-breaking change that adds functionality) - [ ] Breaking change (fix or feature that would cause existing functionality to change) - [ ] Documentation update - [ ] Code refactoring - [ ] Performance improvement - [ ] Other (please specify): ## Screenshots/Videos (if applicable) ## Pre-submission Checklist - [x] **I have tested my changes thoroughly before submitting this PR** - [x] **This PR contains minimal changes necessary to address the issue/feature** - [x] My code follows the project's coding standards and style guidelines - [x] I have added tests that prove my fix is effective or that my feature works - [x] I have added necessary documentation (if applicable) - [x] All new and existing tests pass - [x] I have searched existing PRs to ensure this change hasn't been submitted already - [x] I have linked any relevant issues in the description - [x] My commits have clear and descriptive messages ## DCO Affirmation I affirm that all code in every commit of this pull request conforms to the terms of the Topoteretes Developer Certificate of Origin. --- .../tests/unit/api/test_ontology_endpoint.py | 10 +++- poetry.lock | 52 +++++++++++++------ 2 files changed, 45 insertions(+), 17 deletions(-) diff --git a/cognee/tests/unit/api/test_ontology_endpoint.py b/cognee/tests/unit/api/test_ontology_endpoint.py index d53c5ab44..af3a4d90e 100644 --- a/cognee/tests/unit/api/test_ontology_endpoint.py +++ b/cognee/tests/unit/api/test_ontology_endpoint.py @@ -25,7 +25,10 @@ def mock_user(): def mock_default_user(): """Mock default user for testing.""" return SimpleNamespace( - id=uuid.uuid4(), email="default@example.com", is_active=True, tenant_id=uuid.uuid4() + id=str(uuid.uuid4()), + email="default@example.com", + is_active=True, + tenant_id=str(uuid.uuid4()), ) @@ -108,6 +111,7 @@ def test_upload_multiple_ontologies(mock_get_default_user, client, mock_default_ """Test uploading multiple ontology files in single request""" import io + mock_get_default_user.return_value = mock_default_user # Create mock files file1_content = b"" file2_content = b"" @@ -137,6 +141,7 @@ def test_upload_endpoint_accepts_arrays(mock_get_default_user, client, mock_defa import io import json + mock_get_default_user.return_value = mock_default_user file_content = b"" files = [("ontology_file", ("single.owl", io.BytesIO(file_content), "application/xml"))] @@ -173,6 +178,7 @@ def test_complete_multifile_workflow(mock_get_default_user, client, mock_default import io import json + mock_get_default_user.return_value = mock_default_user # Step 1: Upload multiple ontologies file1_content = b""" =1.0.0)"] name = "pluggy" version = "1.6.0" description = "plugin and hook calling mechanisms for python" -optional = true +optional = false python-versions = ">=3.9" -groups = ["main"] -markers = "extra == \"deepeval\" or extra == \"dev\" or extra == \"dlt\" or extra == \"docling\"" +groups = ["main", "dev"] files = [ {file = "pluggy-1.6.0-py3-none-any.whl", hash = "sha256:e920276dd6813095e9377c0bc5566d94c932c33b27a3e3945d8389c374dd4746"}, {file = "pluggy-1.6.0.tar.gz", hash = "sha256:7dcc130b76258d33b90f61b658791dede3486c3e6bfb003ee5c9bfb396dd22f3"}, ] +markers = {main = "extra == \"deepeval\" or extra == \"dev\" or extra == \"dlt\" or extra == \"docling\""} [package.extras] dev = ["pre-commit", "tox"] @@ -8654,6 +8656,7 @@ files = [ {file = "psycopg2-2.9.10-cp311-cp311-win_amd64.whl", hash = 
"sha256:0435034157049f6846e95103bd8f5a668788dd913a7c30162ca9503fdf542cb4"}, {file = "psycopg2-2.9.10-cp312-cp312-win32.whl", hash = "sha256:65a63d7ab0e067e2cdb3cf266de39663203d38d6a8ed97f5ca0cb315c73fe067"}, {file = "psycopg2-2.9.10-cp312-cp312-win_amd64.whl", hash = "sha256:4a579d6243da40a7b3182e0430493dbd55950c493d8c68f4eec0b302f6bbf20e"}, + {file = "psycopg2-2.9.10-cp313-cp313-win_amd64.whl", hash = "sha256:91fd603a2155da8d0cfcdbf8ab24a2d54bca72795b90d2a3ed2b6da8d979dee2"}, {file = "psycopg2-2.9.10-cp39-cp39-win32.whl", hash = "sha256:9d5b3b94b79a844a986d029eee38998232451119ad653aea42bb9220a8c5066b"}, {file = "psycopg2-2.9.10-cp39-cp39-win_amd64.whl", hash = "sha256:88138c8dedcbfa96408023ea2b0c369eda40fe5d75002c0964c78f46f11fa442"}, {file = "psycopg2-2.9.10.tar.gz", hash = "sha256:12ec0b40b0273f95296233e8750441339298e6a572f7039da5b260e3c8b60e11"}, @@ -8715,6 +8718,7 @@ files = [ {file = "psycopg2_binary-2.9.10-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:bb89f0a835bcfc1d42ccd5f41f04870c1b936d8507c6df12b7737febc40f0909"}, {file = "psycopg2_binary-2.9.10-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:f0c2d907a1e102526dd2986df638343388b94c33860ff3bbe1384130828714b1"}, {file = "psycopg2_binary-2.9.10-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:f8157bed2f51db683f31306aa497311b560f2265998122abe1dce6428bd86567"}, + {file = "psycopg2_binary-2.9.10-cp313-cp313-win_amd64.whl", hash = "sha256:27422aa5f11fbcd9b18da48373eb67081243662f9b46e6fd07c3eb46e4535142"}, {file = "psycopg2_binary-2.9.10-cp38-cp38-macosx_12_0_x86_64.whl", hash = "sha256:eb09aa7f9cecb45027683bb55aebaaf45a0df8bf6de68801a6afdc7947bb09d4"}, {file = "psycopg2_binary-2.9.10-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b73d6d7f0ccdad7bc43e6d34273f70d587ef62f824d7261c4ae9b8b1b6af90e8"}, {file = "psycopg2_binary-2.9.10-cp38-cp38-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ce5ab4bf46a211a8e924d307c1b1fcda82368586a19d0a24f8ae166f5c784864"}, @@ -9694,14 +9698,14 @@ files = [ name = "pytest" version = "7.4.4" description = "pytest: simple powerful testing with Python" -optional = true +optional = false python-versions = ">=3.7" -groups = ["main"] -markers = "extra == \"deepeval\" or extra == \"dev\"" +groups = ["main", "dev"] files = [ {file = "pytest-7.4.4-py3-none-any.whl", hash = "sha256:b090cdf5ed60bf4c45261be03239c2c1c22df034fbffe691abe93cd80cea01d8"}, {file = "pytest-7.4.4.tar.gz", hash = "sha256:2cf0005922c6ace4a3e2ec8b4080eb0d9753fdc93107415332f50ce9e7994280"}, ] +markers = {main = "extra == \"deepeval\" or extra == \"dev\""} [package.dependencies] colorama = {version = "*", markers = "sys_platform == \"win32\""} @@ -9788,6 +9792,21 @@ files = [ packaging = ">=17.1" pytest = ">=6.2" +[[package]] +name = "pytest-timeout" +version = "2.4.0" +description = "pytest plugin to abort hanging tests" +optional = false +python-versions = ">=3.7" +groups = ["dev"] +files = [ + {file = "pytest_timeout-2.4.0-py3-none-any.whl", hash = "sha256:c42667e5cdadb151aeb5b26d114aff6bdf5a907f176a007a30b940d3d865b5c2"}, + {file = "pytest_timeout-2.4.0.tar.gz", hash = "sha256:7e68e90b01f9eff71332b25001f85c75495fc4e3a836701876183c4bcfd0540a"}, +] + +[package.dependencies] +pytest = ">=7.0.0" + [[package]] name = "pytest-xdist" version = "3.8.0" @@ -12056,7 +12075,7 @@ version = "2.2.1" description = "A lil' TOML parser" optional = false python-versions = ">=3.8" -groups = ["main"] +groups = ["main", "dev"] markers = "python_version == \"3.10\"" files = [ {file = 
"tomli-2.2.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:678e4fa69e4575eb77d103de3df8a895e1591b48e740211bd1067378c69e8249"}, @@ -12528,11 +12547,12 @@ version = "4.15.0" description = "Backported and Experimental Type Hints for Python 3.9+" optional = false python-versions = ">=3.9" -groups = ["main"] +groups = ["main", "dev"] files = [ {file = "typing_extensions-4.15.0-py3-none-any.whl", hash = "sha256:f0fa19c6845758ab08074a0cfa8b7aecb71c999ca73d62883bc25cc018c4e548"}, {file = "typing_extensions-4.15.0.tar.gz", hash = "sha256:0cea48d173cc12fa28ecabc3b837ea3cf6f38c6d1136f85cbaaf598984861466"}, ] +markers = {dev = "python_version == \"3.10\""} [[package]] name = "typing-inspect" From 508165e883aa9dc3232d8021daa94765439aded1 Mon Sep 17 00:00:00 2001 From: hajdul88 <52442977+hajdul88@users.noreply.github.com> Date: Wed, 26 Nov 2025 15:18:53 +0100 Subject: [PATCH 24/24] feature: Introduces wide subgraph search in graph completion and improves QA speed (#1736) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This PR introduces wide vector and graph structure filtering capabilities. With these changes, the graph completion retriever and all retrievers that inherit from it will now filter relevant vector elements and subgraphs based on the query. This improvement significantly increases search speed for large graphs while maintaining—and in some cases slightly improving—accuracy. Changes in This PR: -Introduced new wide_search_top_k parameter: Controls the initial search space size -Added graph adapter level filtering method: Enables relevant subgraph filtering while maintaining backward compatibility. For community or custom graph adapters that don't implement this method, the system gracefully falls back to the original search behavior. -Updated modal dashboard and evaluation framework: Fixed compatibility issues. Added comprehensive unit tests: Introduced unit tests for brute_force_triplet_search (previously untested) and expanded the CogneeGraph test suite. Integration tests: Existing integration tests verify end-to-end search functionality (no changes required). Acceptance Criteria and Testing To verify the new search behavior, run search queries with different wide_search_top_k parameters while logging is enabled: None: Triggers a full graph search (default behavior) 1: Projects a minimal subgraph (demonstrates maximum filtering) Custom values: Test intermediate levels of filtering Internal Testing and results: Performance and accuracy benchmarks are available upon request. The implementation demonstrates measurable improvements in query latency for large graphs without sacrificing result quality. 
## Type of Change - [ ] Bug fix (non-breaking change that fixes an issue) - [ ] New feature (non-breaking change that adds functionality) - [ ] Breaking change (fix or feature that would cause existing functionality to change) - [ ] Documentation update - [x] Code refactoring - [x] Performance improvement - [ ] Other (please specify): ## Screenshots/Videos (if applicable) None ## Pre-submission Checklist - [x] **I have tested my changes thoroughly before submitting this PR** - [x] **This PR contains minimal changes necessary to address the issue/feature** - [x] My code follows the project's coding standards and style guidelines - [x] I have added tests that prove my fix is effective or that my feature works - [x] I have added necessary documentation (if applicable) - [x] All new and existing tests pass - [x] I have searched existing PRs to ensure this change hasn't been submitted already - [x] I have linked any relevant issues in the description - [x] My commits have clear and descriptive messages ## DCO Affirmation I affirm that all code in every commit of this pull request conforms to the terms of the Topoteretes Developer Certificate of Origin. --------- Co-authored-by: Pavel Zorin --- cognee/api/v1/search/search.py | 4 + cognee/eval_framework/Dockerfile | 29 + .../answer_generation_executor.py | 10 + .../run_question_answering_module.py | 2 +- cognee/eval_framework/eval_config.py | 4 +- cognee/eval_framework/modal_run_eval.py | 46 +- .../databases/graph/graph_db_interface.py | 15 + .../databases/graph/kuzu/adapter.py | 97 +++ .../databases/graph/neo4j_driver/adapter.py | 57 ++ .../modules/graph/cognee_graph/CogneeGraph.py | 110 +++- .../graph/cognee_graph/CogneeGraphElements.py | 11 +- ..._completion_context_extension_retriever.py | 4 + .../graph_completion_cot_retriever.py | 4 + .../retrieval/graph_completion_retriever.py | 10 + .../graph_summary_completion_retriever.py | 4 + .../modules/retrieval/temporal_retriever.py | 4 + .../utils/brute_force_triplet_search.py | 45 +- .../search/methods/get_search_type_tools.py | 30 +- .../methods/no_access_control_search.py | 4 + cognee/modules/search/methods/search.py | 21 + .../graph/cognee_graph_elements_test.py | 4 +- .../unit/modules/graph/cognee_graph_test.py | 455 ++++++++++++++ .../test_brute_force_triplet_search.py | 582 ++++++++++++++++++ 23 files changed, 1482 insertions(+), 70 deletions(-) create mode 100644 cognee/eval_framework/Dockerfile create mode 100644 cognee/tests/unit/modules/retrieval/test_brute_force_triplet_search.py diff --git a/cognee/api/v1/search/search.py b/cognee/api/v1/search/search.py index d4e5fbbe6..354331c57 100644 --- a/cognee/api/v1/search/search.py +++ b/cognee/api/v1/search/search.py @@ -31,6 +31,8 @@ async def search( only_context: bool = False, use_combined_context: bool = False, session_id: Optional[str] = None, + wide_search_top_k: Optional[int] = 100, + triplet_distance_penalty: Optional[float] = 3.5, ) -> Union[List[SearchResult], CombinedSearchResult]: """ Search and query the knowledge graph for insights, information, and connections. 
@@ -200,6 +202,8 @@ async def search( only_context=only_context, use_combined_context=use_combined_context, session_id=session_id, + wide_search_top_k=wide_search_top_k, + triplet_distance_penalty=triplet_distance_penalty, ) return filtered_search_results diff --git a/cognee/eval_framework/Dockerfile b/cognee/eval_framework/Dockerfile new file mode 100644 index 000000000..e83be3da4 --- /dev/null +++ b/cognee/eval_framework/Dockerfile @@ -0,0 +1,29 @@ +FROM python:3.11-slim + +# Set environment variables +ENV PIP_NO_CACHE_DIR=true +ENV PATH="${PATH}:/root/.poetry/bin" +ENV PYTHONPATH=/app +ENV SKIP_MIGRATIONS=true + +# System dependencies +RUN apt-get update && apt-get install -y \ + gcc \ + libpq-dev \ + git \ + curl \ + build-essential \ + && rm -rf /var/lib/apt/lists/* + +WORKDIR /app + +COPY pyproject.toml poetry.lock README.md /app/ + +RUN pip install poetry + +RUN poetry config virtualenvs.create false + +RUN poetry install --extras distributed --extras evals --extras deepeval --no-root + +COPY cognee/ /app/cognee +COPY distributed/ /app/distributed diff --git a/cognee/eval_framework/answer_generation/answer_generation_executor.py b/cognee/eval_framework/answer_generation/answer_generation_executor.py index 6f166657e..29b3ede68 100644 --- a/cognee/eval_framework/answer_generation/answer_generation_executor.py +++ b/cognee/eval_framework/answer_generation/answer_generation_executor.py @@ -35,6 +35,16 @@ class AnswerGeneratorExecutor: retrieval_context = await retriever.get_context(query_text) search_results = await retriever.get_completion(query_text, retrieval_context) + ############ + # TODO: Quick fix: the combined retriever now returns structured objects instead of rendered text, so normalize them here until retriever results are restructured properly; remove this block once that is done.
+ if isinstance(retrieval_context, list): + retrieval_context = await retriever.convert_retrieved_objects_to_context( + triplets=retrieval_context + ) + + if isinstance(search_results, str): + search_results = [search_results] + ############# answer = { "question": query_text, "answer": search_results[0], diff --git a/cognee/eval_framework/answer_generation/run_question_answering_module.py b/cognee/eval_framework/answer_generation/run_question_answering_module.py index d0a2ebe1e..6b55d84b2 100644 --- a/cognee/eval_framework/answer_generation/run_question_answering_module.py +++ b/cognee/eval_framework/answer_generation/run_question_answering_module.py @@ -35,7 +35,7 @@ async def create_and_insert_answers_table(questions_payload): async def run_question_answering( - params: dict, system_prompt="answer_simple_question.txt", top_k: Optional[int] = None + params: dict, system_prompt="answer_simple_question_benchmark.txt", top_k: Optional[int] = None ) -> List[dict]: if params.get("answering_questions"): logger.info("Question answering started...") diff --git a/cognee/eval_framework/eval_config.py b/cognee/eval_framework/eval_config.py index 6edcc0454..9e6f26688 100644 --- a/cognee/eval_framework/eval_config.py +++ b/cognee/eval_framework/eval_config.py @@ -14,7 +14,7 @@ class EvalConfig(BaseSettings): # Question answering params answering_questions: bool = True - qa_engine: str = "cognee_completion" # Options: 'cognee_completion' or 'cognee_graph_completion' or 'cognee_graph_completion_cot' or 'cognee_graph_completion_context_extension' + qa_engine: str = "cognee_graph_completion" # Options: 'cognee_completion' or 'cognee_graph_completion' or 'cognee_graph_completion_cot' or 'cognee_graph_completion_context_extension' # Evaluation params evaluating_answers: bool = True @@ -25,7 +25,7 @@ class EvalConfig(BaseSettings): "EM", "f1", ] # Use only 'correctness' for DirectLLM - deepeval_model: str = "gpt-5-mini" + deepeval_model: str = "gpt-4o-mini" # Metrics params calculate_metrics: bool = True diff --git a/cognee/eval_framework/modal_run_eval.py b/cognee/eval_framework/modal_run_eval.py index aca2686a5..bc2ff77c5 100644 --- a/cognee/eval_framework/modal_run_eval.py +++ b/cognee/eval_framework/modal_run_eval.py @@ -2,7 +2,6 @@ import modal import os import asyncio import datetime -import hashlib import json from cognee.shared.logging_utils import get_logger from cognee.eval_framework.eval_config import EvalConfig @@ -10,6 +9,9 @@ from cognee.eval_framework.corpus_builder.run_corpus_builder import run_corpus_b from cognee.eval_framework.answer_generation.run_question_answering_module import ( run_question_answering, ) +import pathlib +from os import path +from modal import Image from cognee.eval_framework.evaluation.run_evaluation_module import run_evaluation from cognee.eval_framework.metrics_dashboard import create_dashboard @@ -38,22 +40,19 @@ def read_and_combine_metrics(eval_params: dict) -> dict: app = modal.App("modal-run-eval") -image = ( - modal.Image.from_dockerfile(path="Dockerfile_modal", force_build=False) - .copy_local_file("pyproject.toml", "pyproject.toml") - .copy_local_file("poetry.lock", "poetry.lock") - .env( - { - "ENV": os.getenv("ENV"), - "LLM_API_KEY": os.getenv("LLM_API_KEY"), - "OPENAI_API_KEY": os.getenv("OPENAI_API_KEY"), - } - ) - .pip_install("protobuf", "h2", "deepeval", "gdown", "plotly") +image = Image.from_dockerfile( + path=pathlib.Path(path.join(path.dirname(__file__), "Dockerfile")).resolve(), + force_build=False, +).add_local_python_source("cognee") + + 
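+# Run each evaluation inside the Dockerfile-based image: up to 10 parallel
+# containers, a 24-hour timeout to accommodate full corpus building and QA,
+# and credentials supplied through the "eval_secrets" Modal secret.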
+@app.function( + image=image, + max_containers=10, + timeout=86400, + volumes={"/data": vol}, + secrets=[modal.Secret.from_name("eval_secrets")], ) - - -@app.function(image=image, concurrency_limit=10, timeout=86400, volumes={"/data": vol}) async def modal_run_eval(eval_params=None): """Runs evaluation pipeline and returns combined metrics results.""" if eval_params is None: @@ -105,18 +104,7 @@ async def main(): configs = [ EvalConfig( task_getter_type="Default", - number_of_samples_in_corpus=10, - benchmark="HotPotQA", - qa_engine="cognee_graph_completion", - building_corpus_from_scratch=True, - answering_questions=True, - evaluating_answers=True, - calculate_metrics=True, - dashboard=True, - ), - EvalConfig( - task_getter_type="Default", - number_of_samples_in_corpus=10, + number_of_samples_in_corpus=25, benchmark="TwoWikiMultiHop", qa_engine="cognee_graph_completion", building_corpus_from_scratch=True, @@ -127,7 +115,7 @@ async def main(): ), EvalConfig( task_getter_type="Default", - number_of_samples_in_corpus=10, + number_of_samples_in_corpus=25, benchmark="Musique", qa_engine="cognee_graph_completion", building_corpus_from_scratch=True, diff --git a/cognee/infrastructure/databases/graph/graph_db_interface.py b/cognee/infrastructure/databases/graph/graph_db_interface.py index 67df1a27c..8f8c96e79 100644 --- a/cognee/infrastructure/databases/graph/graph_db_interface.py +++ b/cognee/infrastructure/databases/graph/graph_db_interface.py @@ -398,3 +398,18 @@ class GraphDBInterface(ABC): - node_id (Union[str, UUID]): Unique identifier of the node for which to retrieve connections. """ raise NotImplementedError + + @abstractmethod + async def get_filtered_graph_data( + self, attribute_filters: List[Dict[str, List[Union[str, int]]]] + ) -> Tuple[List[Node], List[EdgeData]]: + """ + Retrieve nodes and edges filtered by the provided attribute criteria. + + Parameters: + ----------- + + - attribute_filters: A list of dictionaries where keys are attribute names and values + are lists of attribute values to filter by. + """ + raise NotImplementedError diff --git a/cognee/infrastructure/databases/graph/kuzu/adapter.py b/cognee/infrastructure/databases/graph/kuzu/adapter.py index 8dd160665..9dbc9c1bc 100644 --- a/cognee/infrastructure/databases/graph/kuzu/adapter.py +++ b/cognee/infrastructure/databases/graph/kuzu/adapter.py @@ -12,6 +12,7 @@ from contextlib import asynccontextmanager from concurrent.futures import ThreadPoolExecutor from typing import Dict, Any, List, Union, Optional, Tuple, Type +from cognee.exceptions import CogneeValidationError from cognee.shared.logging_utils import get_logger from cognee.infrastructure.utils.run_sync import run_sync from cognee.infrastructure.files.storage import get_file_storage @@ -1186,6 +1187,11 @@ class KuzuAdapter(GraphDBInterface): A tuple with two elements: a list of tuples of (node_id, properties) and a list of tuples of (source_id, target_id, relationship_name, properties). 
""" + + import time + + start_time = time.time() + try: nodes_query = """ MATCH (n:Node) @@ -1249,6 +1255,11 @@ class KuzuAdapter(GraphDBInterface): }, ) ) + + retrieval_time = time.time() - start_time + logger.info( + f"Retrieved {len(nodes)} nodes and {len(edges)} edges in {retrieval_time:.2f} seconds" + ) return formatted_nodes, formatted_edges except Exception as e: logger.error(f"Failed to get graph data: {e}") @@ -1417,6 +1428,92 @@ class KuzuAdapter(GraphDBInterface): formatted_edges.append((source_id, target_id, rel_type, props)) return formatted_nodes, formatted_edges + async def get_id_filtered_graph_data(self, target_ids: list[str]): + """ + Retrieve graph data filtered by specific node IDs, including their direct neighbors + and only edges where one endpoint matches those IDs. + + Returns: + nodes: List[dict] -> Each dict includes "id" and all node properties + edges: List[dict] -> Each dict includes "source", "target", "type", "properties" + """ + import time + + start_time = time.time() + + try: + if not target_ids: + logger.warning("No target IDs provided for ID-filtered graph retrieval.") + return [], [] + + if not all(isinstance(x, str) for x in target_ids): + raise CogneeValidationError("target_ids must be a list of strings") + + query = """ + MATCH (n:Node)-[r]->(m:Node) + WHERE n.id IN $target_ids OR m.id IN $target_ids + RETURN n.id, { + name: n.name, + type: n.type, + properties: n.properties + }, m.id, { + name: m.name, + type: m.type, + properties: m.properties + }, r.relationship_name, r.properties + """ + + result = await self.query(query, {"target_ids": target_ids}) + + if not result: + logger.info("No data returned for the supplied IDs") + return [], [] + + nodes_dict = {} + edges = [] + + for n_id, n_props, m_id, m_props, r_type, r_props_raw in result: + if n_props.get("properties"): + try: + additional_props = json.loads(n_props["properties"]) + n_props.update(additional_props) + del n_props["properties"] + except json.JSONDecodeError: + logger.warning(f"Failed to parse properties JSON for node {n_id}") + + if m_props.get("properties"): + try: + additional_props = json.loads(m_props["properties"]) + m_props.update(additional_props) + del m_props["properties"] + except json.JSONDecodeError: + logger.warning(f"Failed to parse properties JSON for node {m_id}") + + nodes_dict[n_id] = (n_id, n_props) + nodes_dict[m_id] = (m_id, m_props) + + edge_props = {} + if r_props_raw: + try: + edge_props = json.loads(r_props_raw) + except (json.JSONDecodeError, TypeError): + logger.warning(f"Failed to parse edge properties for {n_id}->{m_id}") + + source_id = edge_props.get("source_node_id", n_id) + target_id = edge_props.get("target_node_id", m_id) + edges.append((source_id, target_id, r_type, edge_props)) + + retrieval_time = time.time() - start_time + logger.info( + f"ID-filtered retrieval: {len(nodes_dict)} nodes and {len(edges)} edges in {retrieval_time:.2f}s" + ) + + return list(nodes_dict.values()), edges + + except Exception as e: + logger.error(f"Error during ID-filtered graph data retrieval: {str(e)}") + raise + async def get_graph_metrics(self, include_optional=False) -> Dict[str, Any]: """ Get metrics on graph structure and connectivity. 
diff --git a/cognee/infrastructure/databases/graph/neo4j_driver/adapter.py b/cognee/infrastructure/databases/graph/neo4j_driver/adapter.py index 6216e107e..f3bb8e173 100644 --- a/cognee/infrastructure/databases/graph/neo4j_driver/adapter.py +++ b/cognee/infrastructure/databases/graph/neo4j_driver/adapter.py @@ -964,6 +964,63 @@ class Neo4jAdapter(GraphDBInterface): logger.error(f"Error during graph data retrieval: {str(e)}") raise + async def get_id_filtered_graph_data(self, target_ids: list[str]): + """ + Retrieve graph data filtered by specific node IDs, including their direct neighbors + and only edges where one endpoint matches those IDs. + + This version uses a single Cypher query for efficiency. + """ + import time + + start_time = time.time() + + try: + if not target_ids: + logger.warning("No target IDs provided for ID-filtered graph retrieval.") + return [], [] + + query = """ + MATCH ()-[r]-() + WHERE startNode(r).id IN $target_ids + OR endNode(r).id IN $target_ids + WITH DISTINCT r, startNode(r) AS a, endNode(r) AS b + RETURN + properties(a) AS n_properties, + properties(b) AS m_properties, + type(r) AS type, + properties(r) AS properties + """ + + result = await self.query(query, {"target_ids": target_ids}) + + nodes_dict = {} + edges = [] + + for record in result: + n_props = record["n_properties"] + m_props = record["m_properties"] + r_props = record["properties"] + r_type = record["type"] + + nodes_dict[n_props["id"]] = (n_props["id"], n_props) + nodes_dict[m_props["id"]] = (m_props["id"], m_props) + + source_id = r_props.get("source_node_id", n_props["id"]) + target_id = r_props.get("target_node_id", m_props["id"]) + edges.append((source_id, target_id, r_type, r_props)) + + retrieval_time = time.time() - start_time + logger.info( + f"ID-filtered retrieval: {len(nodes_dict)} nodes and {len(edges)} edges in {retrieval_time:.2f}s" + ) + + return list(nodes_dict.values()), edges + + except Exception as e: + logger.error(f"Error during ID-filtered graph data retrieval: {str(e)}") + raise + async def get_nodeset_subgraph( self, node_type: Type[Any], node_name: List[str] ) -> Tuple[List[Tuple[int, dict]], List[Tuple[int, int, str, dict]]]: diff --git a/cognee/modules/graph/cognee_graph/CogneeGraph.py b/cognee/modules/graph/cognee_graph/CogneeGraph.py index cb7562422..2e0b82e8d 100644 --- a/cognee/modules/graph/cognee_graph/CogneeGraph.py +++ b/cognee/modules/graph/cognee_graph/CogneeGraph.py @@ -56,6 +56,68 @@ class CogneeGraph(CogneeAbstractGraph): def get_edges(self) -> List[Edge]: return self.edges + async def _get_nodeset_subgraph( + self, + adapter, + node_type, + node_name, + ): + """Retrieve subgraph based on node type and name.""" + logger.info("Retrieving graph filtered by node type and node name (NodeSet).") + nodes_data, edges_data = await adapter.get_nodeset_subgraph( + node_type=node_type, node_name=node_name + ) + if not nodes_data or not edges_data: + raise EntityNotFoundError( + message="Nodeset does not exist, or empty nodeset projected from the database." 
+ ) + return nodes_data, edges_data + + async def _get_full_or_id_filtered_graph( + self, + adapter, + relevant_ids_to_filter, + ): + """Retrieve full or ID-filtered graph with fallback.""" + if relevant_ids_to_filter is None: + logger.info("Retrieving full graph.") + nodes_data, edges_data = await adapter.get_graph_data() + if not nodes_data or not edges_data: + raise EntityNotFoundError(message="Empty graph projected from the database.") + return nodes_data, edges_data + + get_graph_data_fn = getattr(adapter, "get_id_filtered_graph_data", adapter.get_graph_data) + if getattr(adapter.__class__, "get_id_filtered_graph_data", None): + logger.info("Retrieving ID-filtered graph from database.") + nodes_data, edges_data = await get_graph_data_fn(target_ids=relevant_ids_to_filter) + else: + logger.info("Retrieving full graph from database.") + nodes_data, edges_data = await get_graph_data_fn() + if hasattr(adapter, "get_id_filtered_graph_data") and (not nodes_data or not edges_data): + logger.warning( + "Id filtered graph returned empty, falling back to full graph retrieval." + ) + logger.info("Retrieving full graph") + nodes_data, edges_data = await adapter.get_graph_data() + + if not nodes_data or not edges_data: + raise EntityNotFoundError("Empty graph projected from the database.") + return nodes_data, edges_data + + async def _get_filtered_graph( + self, + adapter, + memory_fragment_filter, + ): + """Retrieve graph filtered by attributes.""" + logger.info("Retrieving graph filtered by memory fragment") + nodes_data, edges_data = await adapter.get_filtered_graph_data( + attribute_filters=memory_fragment_filter + ) + if not nodes_data or not edges_data: + raise EntityNotFoundError(message="Empty filtered graph projected from the database.") + return nodes_data, edges_data + async def project_graph_from_db( self, adapter: Union[GraphDBInterface], @@ -67,40 +129,39 @@ class CogneeGraph(CogneeAbstractGraph): memory_fragment_filter=[], node_type: Optional[Type] = None, node_name: Optional[List[str]] = None, + relevant_ids_to_filter: Optional[List[str]] = None, + triplet_distance_penalty: float = 3.5, ) -> None: if node_dimension < 1 or edge_dimension < 1: raise InvalidDimensionsError() try: + if node_type is not None and node_name not in [None, [], ""]: + nodes_data, edges_data = await self._get_nodeset_subgraph( + adapter, node_type, node_name + ) + elif len(memory_fragment_filter) == 0: + nodes_data, edges_data = await self._get_full_or_id_filtered_graph( + adapter, relevant_ids_to_filter + ) + else: + nodes_data, edges_data = await self._get_filtered_graph( + adapter, memory_fragment_filter + ) + import time start_time = time.time() - - # Determine projection strategy - if node_type is not None and node_name not in [None, [], ""]: - nodes_data, edges_data = await adapter.get_nodeset_subgraph( - node_type=node_type, node_name=node_name - ) - if not nodes_data or not edges_data: - raise EntityNotFoundError( - message="Nodeset does not exist, or empty nodetes projected from the database." - ) - elif len(memory_fragment_filter) == 0: - nodes_data, edges_data = await adapter.get_graph_data() - if not nodes_data or not edges_data: - raise EntityNotFoundError(message="Empty graph projected from the database.") - else: - nodes_data, edges_data = await adapter.get_filtered_graph_data( - attribute_filters=memory_fragment_filter - ) - if not nodes_data or not edges_data: - raise EntityNotFoundError( - message="Empty filtered graph projected from the database." 
- ) - # Process nodes for node_id, properties in nodes_data: node_attributes = {key: properties.get(key) for key in node_properties_to_project} - self.add_node(Node(str(node_id), node_attributes, dimension=node_dimension)) + self.add_node( + Node( + str(node_id), + node_attributes, + dimension=node_dimension, + node_penalty=triplet_distance_penalty, + ) + ) # Process edges for source_id, target_id, relationship_type, properties in edges_data: @@ -118,6 +179,7 @@ class CogneeGraph(CogneeAbstractGraph): attributes=edge_attributes, directed=directed, dimension=edge_dimension, + edge_penalty=triplet_distance_penalty, ) self.add_edge(edge) diff --git a/cognee/modules/graph/cognee_graph/CogneeGraphElements.py b/cognee/modules/graph/cognee_graph/CogneeGraphElements.py index 0ca9c4fb9..62ef8d9fd 100644 --- a/cognee/modules/graph/cognee_graph/CogneeGraphElements.py +++ b/cognee/modules/graph/cognee_graph/CogneeGraphElements.py @@ -20,13 +20,17 @@ class Node: status: np.ndarray def __init__( - self, node_id: str, attributes: Optional[Dict[str, Any]] = None, dimension: int = 1 + self, + node_id: str, + attributes: Optional[Dict[str, Any]] = None, + dimension: int = 1, + node_penalty: float = 3.5, ): if dimension <= 0: raise InvalidDimensionsError() self.id = node_id self.attributes = attributes if attributes is not None else {} - self.attributes["vector_distance"] = float("inf") + self.attributes["vector_distance"] = node_penalty self.skeleton_neighbours = [] self.skeleton_edges = [] self.status = np.ones(dimension, dtype=int) @@ -105,13 +109,14 @@ class Edge: attributes: Optional[Dict[str, Any]] = None, directed: bool = True, dimension: int = 1, + edge_penalty: float = 3.5, ): if dimension <= 0: raise InvalidDimensionsError() self.node1 = node1 self.node2 = node2 self.attributes = attributes if attributes is not None else {} - self.attributes["vector_distance"] = float("inf") + self.attributes["vector_distance"] = edge_penalty self.directed = directed self.status = np.ones(dimension, dtype=int) diff --git a/cognee/modules/retrieval/graph_completion_context_extension_retriever.py b/cognee/modules/retrieval/graph_completion_context_extension_retriever.py index b07d11fd2..fc49a139b 100644 --- a/cognee/modules/retrieval/graph_completion_context_extension_retriever.py +++ b/cognee/modules/retrieval/graph_completion_context_extension_retriever.py @@ -39,6 +39,8 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever): node_type: Optional[Type] = None, node_name: Optional[List[str]] = None, save_interaction: bool = False, + wide_search_top_k: Optional[int] = 100, + triplet_distance_penalty: Optional[float] = 3.5, ): super().__init__( user_prompt_path=user_prompt_path, @@ -48,6 +50,8 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever): node_name=node_name, save_interaction=save_interaction, system_prompt=system_prompt, + wide_search_top_k=wide_search_top_k, + triplet_distance_penalty=triplet_distance_penalty, ) async def get_completion( diff --git a/cognee/modules/retrieval/graph_completion_cot_retriever.py b/cognee/modules/retrieval/graph_completion_cot_retriever.py index eb8f502cb..70fcb6cdb 100644 --- a/cognee/modules/retrieval/graph_completion_cot_retriever.py +++ b/cognee/modules/retrieval/graph_completion_cot_retriever.py @@ -65,6 +65,8 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever): node_type: Optional[Type] = None, node_name: Optional[List[str]] = None, save_interaction: bool = False, + wide_search_top_k: Optional[int] = 100, + 
triplet_distance_penalty: Optional[float] = 3.5, ): super().__init__( user_prompt_path=user_prompt_path, @@ -74,6 +76,8 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever): node_type=node_type, node_name=node_name, save_interaction=save_interaction, + wide_search_top_k=wide_search_top_k, + triplet_distance_penalty=triplet_distance_penalty, ) self.validation_system_prompt_path = validation_system_prompt_path self.validation_user_prompt_path = validation_user_prompt_path diff --git a/cognee/modules/retrieval/graph_completion_retriever.py b/cognee/modules/retrieval/graph_completion_retriever.py index df77a11ac..89e9e47ce 100644 --- a/cognee/modules/retrieval/graph_completion_retriever.py +++ b/cognee/modules/retrieval/graph_completion_retriever.py @@ -47,6 +47,8 @@ class GraphCompletionRetriever(BaseGraphRetriever): node_type: Optional[Type] = None, node_name: Optional[List[str]] = None, save_interaction: bool = False, + wide_search_top_k: Optional[int] = 100, + triplet_distance_penalty: Optional[float] = 3.5, ): """Initialize retriever with prompt paths and search parameters.""" self.save_interaction = save_interaction @@ -54,8 +56,10 @@ class GraphCompletionRetriever(BaseGraphRetriever): self.system_prompt_path = system_prompt_path self.system_prompt = system_prompt self.top_k = top_k if top_k is not None else 5 + self.wide_search_top_k = wide_search_top_k self.node_type = node_type self.node_name = node_name + self.triplet_distance_penalty = triplet_distance_penalty async def resolve_edges_to_text(self, retrieved_edges: list) -> str: """ @@ -105,6 +109,8 @@ class GraphCompletionRetriever(BaseGraphRetriever): collections=vector_index_collections or None, node_type=self.node_type, node_name=self.node_name, + wide_search_top_k=self.wide_search_top_k, + triplet_distance_penalty=self.triplet_distance_penalty, ) return found_triplets @@ -141,6 +147,10 @@ class GraphCompletionRetriever(BaseGraphRetriever): return triplets + async def convert_retrieved_objects_to_context(self, triplets: List[Edge]): + context = await self.resolve_edges_to_text(triplets) + return context + async def get_completion( self, query: str, diff --git a/cognee/modules/retrieval/graph_summary_completion_retriever.py b/cognee/modules/retrieval/graph_summary_completion_retriever.py index 051f39b22..e31ad126e 100644 --- a/cognee/modules/retrieval/graph_summary_completion_retriever.py +++ b/cognee/modules/retrieval/graph_summary_completion_retriever.py @@ -26,6 +26,8 @@ class GraphSummaryCompletionRetriever(GraphCompletionRetriever): node_type: Optional[Type] = None, node_name: Optional[List[str]] = None, save_interaction: bool = False, + wide_search_top_k: Optional[int] = 100, + triplet_distance_penalty: Optional[float] = 3.5, ): """Initialize retriever with default prompt paths and search parameters.""" super().__init__( @@ -36,6 +38,8 @@ class GraphSummaryCompletionRetriever(GraphCompletionRetriever): node_name=node_name, save_interaction=save_interaction, system_prompt=system_prompt, + wide_search_top_k=wide_search_top_k, + triplet_distance_penalty=triplet_distance_penalty, ) self.summarize_prompt_path = summarize_prompt_path diff --git a/cognee/modules/retrieval/temporal_retriever.py b/cognee/modules/retrieval/temporal_retriever.py index f3da02c15..87d2ab009 100644 --- a/cognee/modules/retrieval/temporal_retriever.py +++ b/cognee/modules/retrieval/temporal_retriever.py @@ -47,6 +47,8 @@ class TemporalRetriever(GraphCompletionRetriever): top_k: Optional[int] = 5, node_type: Optional[Type] = None, node_name: 
Optional[List[str]] = None, + wide_search_top_k: Optional[int] = 100, + triplet_distance_penalty: Optional[float] = 3.5, ): super().__init__( user_prompt_path=user_prompt_path, @@ -54,6 +56,8 @@ class TemporalRetriever(GraphCompletionRetriever): top_k=top_k, node_type=node_type, node_name=node_name, + wide_search_top_k=wide_search_top_k, + triplet_distance_penalty=triplet_distance_penalty, ) self.user_prompt_path = user_prompt_path self.system_prompt_path = system_prompt_path diff --git a/cognee/modules/retrieval/utils/brute_force_triplet_search.py b/cognee/modules/retrieval/utils/brute_force_triplet_search.py index f8bdbb97d..2f8a545f7 100644 --- a/cognee/modules/retrieval/utils/brute_force_triplet_search.py +++ b/cognee/modules/retrieval/utils/brute_force_triplet_search.py @@ -58,6 +58,8 @@ async def get_memory_fragment( properties_to_project: Optional[List[str]] = None, node_type: Optional[Type] = None, node_name: Optional[List[str]] = None, + relevant_ids_to_filter: Optional[List[str]] = None, + triplet_distance_penalty: Optional[float] = 3.5, ) -> CogneeGraph: """Creates and initializes a CogneeGraph memory fragment with optional property projections.""" if properties_to_project is None: @@ -74,6 +76,8 @@ async def get_memory_fragment( edge_properties_to_project=["relationship_name", "edge_text"], node_type=node_type, node_name=node_name, + relevant_ids_to_filter=relevant_ids_to_filter, + triplet_distance_penalty=triplet_distance_penalty, ) except EntityNotFoundError: @@ -95,6 +99,8 @@ async def brute_force_triplet_search( memory_fragment: Optional[CogneeGraph] = None, node_type: Optional[Type] = None, node_name: Optional[List[str]] = None, + wide_search_top_k: Optional[int] = 100, + triplet_distance_penalty: Optional[float] = 3.5, ) -> List[Edge]: """ Performs a brute force search to retrieve the top triplets from the graph. @@ -107,6 +113,8 @@ async def brute_force_triplet_search( memory_fragment (Optional[CogneeGraph]): Existing memory fragment to reuse. node_type: node type to filter node_name: node name to filter + wide_search_top_k (Optional[int]): Number of initial elements to retrieve from collections + triplet_distance_penalty (Optional[float]): Default distance penalty in graph projection Returns: list: The top triplet results. 
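To make the gating logic in the hunks below easier to follow, here is a self-contained sketch of the wide-search limit and candidate-ID extraction (function names are local to this sketch; scored results are assumed to expose an `id` attribute, as in the patched code):

```python
from typing import Any, Dict, List, Optional


def compute_wide_search_limit(
    node_name: Optional[List[str]], wide_search_top_k: Optional[int]
) -> Optional[int]:
    # The wide limit only applies to global searches; nodeset-filtered
    # searches (node_name given) keep unlimited vector retrieval.
    return wide_search_top_k if node_name is None else None


def extract_relevant_ids(node_distances: Dict[str, Any]) -> List[str]:
    # Deduplicate node IDs across node collections; the edge collection
    # ("EdgeType_relationship_name") is excluded and handled separately.
    return list(
        {
            str(scored.id)
            for name, results in node_distances.items()
            if name != "EdgeType_relationship_name" and isinstance(results, (list, tuple))
            for scored in results
            if getattr(scored, "id", None)
        }
    )
```

These IDs are then passed to get_memory_fragment as relevant_ids_to_filter, so only the matching subgraph is projected from the database.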
@@ -116,10 +124,10 @@ async def brute_force_triplet_search( if top_k <= 0: raise ValueError("top_k must be a positive integer.") - if memory_fragment is None: - memory_fragment = await get_memory_fragment( - properties_to_project, node_type=node_type, node_name=node_name - ) + # Setting wide search limit based on the parameters + non_global_search = node_name is None + + wide_search_limit = wide_search_top_k if non_global_search else None if collections is None: collections = [ @@ -140,7 +148,7 @@ async def brute_force_triplet_search( async def search_in_collection(collection_name: str): try: return await vector_engine.search( - collection_name=collection_name, query_vector=query_vector, limit=None + collection_name=collection_name, query_vector=query_vector, limit=wide_search_limit ) except CollectionNotFoundError: return [] @@ -156,15 +164,38 @@ async def brute_force_triplet_search( return [] # Final statistics - projection_time = time.time() - start_time + vector_collection_search_time = time.time() - start_time logger.info( - f"Vector collection retrieval completed: Retrieved distances from {sum(1 for res in results if res)} collections in {projection_time:.2f}s" + f"Vector collection retrieval completed: Retrieved distances from {sum(1 for res in results if res)} collections in {vector_collection_search_time:.2f}s" ) node_distances = {collection: result for collection, result in zip(collections, results)} edge_distances = node_distances.get("EdgeType_relationship_name", None) + if wide_search_limit is not None: + relevant_ids_to_filter = list( + { + str(getattr(scored_node, "id")) + for collection_name, score_collection in node_distances.items() + if collection_name != "EdgeType_relationship_name" + and isinstance(score_collection, (list, tuple)) + for scored_node in score_collection + if getattr(scored_node, "id", None) + } + ) + else: + relevant_ids_to_filter = None + + if memory_fragment is None: + memory_fragment = await get_memory_fragment( + properties_to_project=properties_to_project, + node_type=node_type, + node_name=node_name, + relevant_ids_to_filter=relevant_ids_to_filter, + triplet_distance_penalty=triplet_distance_penalty, + ) + await memory_fragment.map_vector_distances_to_graph_nodes(node_distances=node_distances) await memory_fragment.map_vector_distances_to_graph_edges( vector_engine=vector_engine, query_vector=query_vector, edge_distances=edge_distances diff --git a/cognee/modules/search/methods/get_search_type_tools.py b/cognee/modules/search/methods/get_search_type_tools.py index 72e2db89a..165ec379b 100644 --- a/cognee/modules/search/methods/get_search_type_tools.py +++ b/cognee/modules/search/methods/get_search_type_tools.py @@ -37,6 +37,8 @@ async def get_search_type_tools( node_name: Optional[List[str]] = None, save_interaction: bool = False, last_k: Optional[int] = None, + wide_search_top_k: Optional[int] = 100, + triplet_distance_penalty: Optional[float] = 3.5, ) -> list: search_tasks: dict[SearchType, List[Callable]] = { SearchType.SUMMARIES: [ @@ -67,6 +69,8 @@ async def get_search_type_tools( node_name=node_name, save_interaction=save_interaction, system_prompt=system_prompt, + wide_search_top_k=wide_search_top_k, + triplet_distance_penalty=triplet_distance_penalty, ).get_completion, GraphCompletionRetriever( system_prompt_path=system_prompt_path, @@ -75,6 +79,8 @@ async def get_search_type_tools( node_name=node_name, save_interaction=save_interaction, system_prompt=system_prompt, + wide_search_top_k=wide_search_top_k, + 
triplet_distance_penalty=triplet_distance_penalty, ).get_context, ], SearchType.GRAPH_COMPLETION_COT: [ @@ -85,6 +91,8 @@ async def get_search_type_tools( node_name=node_name, save_interaction=save_interaction, system_prompt=system_prompt, + wide_search_top_k=wide_search_top_k, + triplet_distance_penalty=triplet_distance_penalty, ).get_completion, GraphCompletionCotRetriever( system_prompt_path=system_prompt_path, @@ -93,6 +101,8 @@ async def get_search_type_tools( node_name=node_name, save_interaction=save_interaction, system_prompt=system_prompt, + wide_search_top_k=wide_search_top_k, + triplet_distance_penalty=triplet_distance_penalty, ).get_context, ], SearchType.GRAPH_COMPLETION_CONTEXT_EXTENSION: [ @@ -103,6 +113,8 @@ async def get_search_type_tools( node_name=node_name, save_interaction=save_interaction, system_prompt=system_prompt, + wide_search_top_k=wide_search_top_k, + triplet_distance_penalty=triplet_distance_penalty, ).get_completion, GraphCompletionContextExtensionRetriever( system_prompt_path=system_prompt_path, @@ -111,6 +123,8 @@ async def get_search_type_tools( node_name=node_name, save_interaction=save_interaction, system_prompt=system_prompt, + wide_search_top_k=wide_search_top_k, + triplet_distance_penalty=triplet_distance_penalty, ).get_context, ], SearchType.GRAPH_SUMMARY_COMPLETION: [ @@ -121,6 +135,8 @@ async def get_search_type_tools( node_name=node_name, save_interaction=save_interaction, system_prompt=system_prompt, + wide_search_top_k=wide_search_top_k, + triplet_distance_penalty=triplet_distance_penalty, ).get_completion, GraphSummaryCompletionRetriever( system_prompt_path=system_prompt_path, @@ -129,6 +145,8 @@ async def get_search_type_tools( node_name=node_name, save_interaction=save_interaction, system_prompt=system_prompt, + wide_search_top_k=wide_search_top_k, + triplet_distance_penalty=triplet_distance_penalty, ).get_context, ], SearchType.CODE: [ @@ -145,8 +163,16 @@ async def get_search_type_tools( ], SearchType.FEEDBACK: [UserQAFeedback(last_k=last_k).add_feedback], SearchType.TEMPORAL: [ - TemporalRetriever(top_k=top_k).get_completion, - TemporalRetriever(top_k=top_k).get_context, + TemporalRetriever( + top_k=top_k, + wide_search_top_k=wide_search_top_k, + triplet_distance_penalty=triplet_distance_penalty, + ).get_completion, + TemporalRetriever( + top_k=top_k, + wide_search_top_k=wide_search_top_k, + triplet_distance_penalty=triplet_distance_penalty, + ).get_context, ], SearchType.CHUNKS_LEXICAL: ( lambda _r=JaccardChunksRetriever(top_k=top_k): [ diff --git a/cognee/modules/search/methods/no_access_control_search.py b/cognee/modules/search/methods/no_access_control_search.py index fcb02da46..3a703bbc9 100644 --- a/cognee/modules/search/methods/no_access_control_search.py +++ b/cognee/modules/search/methods/no_access_control_search.py @@ -24,6 +24,8 @@ async def no_access_control_search( last_k: Optional[int] = None, only_context: bool = False, session_id: Optional[str] = None, + wide_search_top_k: Optional[int] = 100, + triplet_distance_penalty: Optional[float] = 3.5, ) -> Tuple[Any, Union[str, List[Edge]], List[Dataset]]: search_tools = await get_search_type_tools( query_type=query_type, @@ -35,6 +37,8 @@ async def no_access_control_search( node_name=node_name, save_interaction=save_interaction, last_k=last_k, + wide_search_top_k=wide_search_top_k, + triplet_distance_penalty=triplet_distance_penalty, ) graph_engine = await get_graph_engine() is_empty = await graph_engine.is_empty() diff --git a/cognee/modules/search/methods/search.py 
b/cognee/modules/search/methods/search.py index b4278424b..9f180d607 100644 --- a/cognee/modules/search/methods/search.py +++ b/cognee/modules/search/methods/search.py @@ -47,6 +47,8 @@ async def search( only_context: bool = False, use_combined_context: bool = False, session_id: Optional[str] = None, + wide_search_top_k: Optional[int] = 100, + triplet_distance_penalty: Optional[float] = 3.5, ) -> Union[CombinedSearchResult, List[SearchResult]]: """ @@ -90,6 +92,8 @@ async def search( only_context=only_context, use_combined_context=use_combined_context, session_id=session_id, + wide_search_top_k=wide_search_top_k, + triplet_distance_penalty=triplet_distance_penalty, ) else: search_results = [ @@ -105,6 +109,8 @@ async def search( last_k=last_k, only_context=only_context, session_id=session_id, + wide_search_top_k=wide_search_top_k, + triplet_distance_penalty=triplet_distance_penalty, ) ] @@ -219,6 +225,8 @@ async def authorized_search( only_context: bool = False, use_combined_context: bool = False, session_id: Optional[str] = None, + wide_search_top_k: Optional[int] = 100, + triplet_distance_penalty: Optional[float] = 3.5, ) -> Union[ Tuple[Any, Union[List[Edge], str], List[Dataset]], List[Tuple[Any, Union[List[Edge], str], List[Dataset]]], @@ -246,6 +254,8 @@ async def authorized_search( last_k=last_k, only_context=True, session_id=session_id, + wide_search_top_k=wide_search_top_k, + triplet_distance_penalty=triplet_distance_penalty, ) context = {} @@ -267,6 +277,8 @@ async def authorized_search( node_name=node_name, save_interaction=save_interaction, last_k=last_k, + wide_search_top_k=wide_search_top_k, + triplet_distance_penalty=triplet_distance_penalty, ) search_tools = specific_search_tools if len(search_tools) == 2: @@ -306,6 +318,7 @@ async def authorized_search( last_k=last_k, only_context=only_context, session_id=session_id, + wide_search_top_k=wide_search_top_k, ) return search_results @@ -325,6 +338,8 @@ async def search_in_datasets_context( only_context: bool = False, context: Optional[Any] = None, session_id: Optional[str] = None, + wide_search_top_k: Optional[int] = 100, + triplet_distance_penalty: Optional[float] = 3.5, ) -> List[Tuple[Any, Union[str, List[Edge]], List[Dataset]]]: """ Searches all provided datasets and handles setting up of appropriate database context based on permissions. 
@@ -345,6 +360,8 @@ async def search_in_datasets_context( only_context: bool = False, context: Optional[Any] = None, session_id: Optional[str] = None, + wide_search_top_k: Optional[int] = 100, + triplet_distance_penalty: Optional[float] = 3.5, ) -> Tuple[Any, Union[str, List[Edge]], List[Dataset]]: # Set database configuration in async context for each dataset user has access for await set_database_global_context_variables(dataset.id, dataset.owner_id) @@ -378,6 +395,8 @@ async def search_in_datasets_context( node_name=node_name, save_interaction=save_interaction, last_k=last_k, + wide_search_top_k=wide_search_top_k, + triplet_distance_penalty=triplet_distance_penalty, ) search_tools = specific_search_tools if len(search_tools) == 2: @@ -413,6 +432,8 @@ async def search_in_datasets_context( only_context=only_context, context=context, session_id=session_id, + wide_search_top_k=wide_search_top_k, + triplet_distance_penalty=triplet_distance_penalty, ) ) diff --git a/cognee/tests/unit/modules/graph/cognee_graph_elements_test.py b/cognee/tests/unit/modules/graph/cognee_graph_elements_test.py index 37ba113b5..1d2b79cf9 100644 --- a/cognee/tests/unit/modules/graph/cognee_graph_elements_test.py +++ b/cognee/tests/unit/modules/graph/cognee_graph_elements_test.py @@ -9,7 +9,7 @@ def test_node_initialization(): """Test that a Node is initialized correctly.""" node = Node("node1", {"attr1": "value1"}, dimension=2) assert node.id == "node1" - assert node.attributes == {"attr1": "value1", "vector_distance": np.inf} + assert node.attributes == {"attr1": "value1", "vector_distance": 3.5} assert len(node.status) == 2 assert np.all(node.status == 1) @@ -96,7 +96,7 @@ def test_edge_initialization(): edge = Edge(node1, node2, {"weight": 10}, directed=False, dimension=2) assert edge.node1 == node1 assert edge.node2 == node2 - assert edge.attributes == {"vector_distance": np.inf, "weight": 10} + assert edge.attributes == {"vector_distance": 3.5, "weight": 10} assert edge.directed is False assert len(edge.status) == 2 assert np.all(edge.status == 1) diff --git a/cognee/tests/unit/modules/graph/cognee_graph_test.py b/cognee/tests/unit/modules/graph/cognee_graph_test.py index 6888648c3..711479387 100644 --- a/cognee/tests/unit/modules/graph/cognee_graph_test.py +++ b/cognee/tests/unit/modules/graph/cognee_graph_test.py @@ -1,4 +1,5 @@ import pytest +from unittest.mock import AsyncMock from cognee.modules.graph.exceptions import EntityNotFoundError, EntityAlreadyExistsError from cognee.modules.graph.cognee_graph.CogneeGraph import CogneeGraph @@ -11,6 +12,30 @@ def setup_graph(): return CogneeGraph() +@pytest.fixture +def mock_adapter(): + """Fixture to create a mock adapter for database operations.""" + adapter = AsyncMock() + return adapter + + +@pytest.fixture +def mock_vector_engine(): + """Fixture to create a mock vector engine.""" + engine = AsyncMock() + engine.search = AsyncMock() + return engine + + +class MockScoredResult: + """Mock class for vector search results.""" + + def __init__(self, id, score, payload=None): + self.id = id + self.score = score + self.payload = payload or {} + + def test_add_node_success(setup_graph): """Test successful addition of a node.""" graph = setup_graph @@ -73,3 +98,433 @@ def test_get_edges_nonexistent_node(setup_graph): graph = setup_graph with pytest.raises(EntityNotFoundError, match="Node with id nonexistent does not exist."): graph.get_edges_from_node("nonexistent") + + +@pytest.mark.asyncio +async def test_project_graph_from_db_full_graph(setup_graph, mock_adapter): 
+ """Test projecting a full graph from database.""" + graph = setup_graph + + nodes_data = [ + ("1", {"name": "Node1", "description": "First node"}), + ("2", {"name": "Node2", "description": "Second node"}), + ] + edges_data = [ + ("1", "2", "CONNECTS_TO", {"relationship_name": "connects"}), + ] + + mock_adapter.get_graph_data = AsyncMock(return_value=(nodes_data, edges_data)) + + await graph.project_graph_from_db( + adapter=mock_adapter, + node_properties_to_project=["name", "description"], + edge_properties_to_project=["relationship_name"], + ) + + assert len(graph.nodes) == 2 + assert len(graph.edges) == 1 + assert graph.get_node("1") is not None + assert graph.get_node("2") is not None + assert graph.edges[0].node1.id == "1" + assert graph.edges[0].node2.id == "2" + + +@pytest.mark.asyncio +async def test_project_graph_from_db_id_filtered(setup_graph, mock_adapter): + """Test projecting an ID-filtered graph from database.""" + graph = setup_graph + + nodes_data = [ + ("1", {"name": "Node1"}), + ("2", {"name": "Node2"}), + ] + edges_data = [ + ("1", "2", "CONNECTS_TO", {"relationship_name": "connects"}), + ] + + mock_adapter.get_id_filtered_graph_data = AsyncMock(return_value=(nodes_data, edges_data)) + + await graph.project_graph_from_db( + adapter=mock_adapter, + node_properties_to_project=["name"], + edge_properties_to_project=["relationship_name"], + relevant_ids_to_filter=["1", "2"], + ) + + assert len(graph.nodes) == 2 + assert len(graph.edges) == 1 + mock_adapter.get_id_filtered_graph_data.assert_called_once() + + +@pytest.mark.asyncio +async def test_project_graph_from_db_nodeset_subgraph(setup_graph, mock_adapter): + """Test projecting a nodeset subgraph filtered by node type and name.""" + graph = setup_graph + + nodes_data = [ + ("1", {"name": "Alice", "type": "Person"}), + ("2", {"name": "Bob", "type": "Person"}), + ] + edges_data = [ + ("1", "2", "KNOWS", {"relationship_name": "knows"}), + ] + + mock_adapter.get_nodeset_subgraph = AsyncMock(return_value=(nodes_data, edges_data)) + + await graph.project_graph_from_db( + adapter=mock_adapter, + node_properties_to_project=["name", "type"], + edge_properties_to_project=["relationship_name"], + node_type="Person", + node_name=["Alice"], + ) + + assert len(graph.nodes) == 2 + assert graph.get_node("1") is not None + assert len(graph.edges) == 1 + mock_adapter.get_nodeset_subgraph.assert_called_once() + + +@pytest.mark.asyncio +async def test_project_graph_from_db_empty_graph(setup_graph, mock_adapter): + """Test projecting empty graph raises EntityNotFoundError.""" + graph = setup_graph + + mock_adapter.get_graph_data = AsyncMock(return_value=([], [])) + + with pytest.raises(EntityNotFoundError, match="Empty graph projected from the database."): + await graph.project_graph_from_db( + adapter=mock_adapter, + node_properties_to_project=["name"], + edge_properties_to_project=[], + ) + + +@pytest.mark.asyncio +async def test_project_graph_from_db_missing_nodes(setup_graph, mock_adapter): + """Test that edges referencing missing nodes raise error.""" + graph = setup_graph + + nodes_data = [ + ("1", {"name": "Node1"}), + ] + edges_data = [ + ("1", "999", "CONNECTS_TO", {"relationship_name": "connects"}), + ] + + mock_adapter.get_graph_data = AsyncMock(return_value=(nodes_data, edges_data)) + + with pytest.raises(EntityNotFoundError, match="Edge references nonexistent nodes"): + await graph.project_graph_from_db( + adapter=mock_adapter, + node_properties_to_project=["name"], + edge_properties_to_project=["relationship_name"], + ) + + 
+@pytest.mark.asyncio +async def test_map_vector_distances_to_graph_nodes(setup_graph): + """Test mapping vector distances to graph nodes.""" + graph = setup_graph + + node1 = Node("1", {"name": "Node1"}) + node2 = Node("2", {"name": "Node2"}) + graph.add_node(node1) + graph.add_node(node2) + + node_distances = { + "Entity_name": [ + MockScoredResult("1", 0.95), + MockScoredResult("2", 0.87), + ] + } + + await graph.map_vector_distances_to_graph_nodes(node_distances) + + assert graph.get_node("1").attributes.get("vector_distance") == 0.95 + assert graph.get_node("2").attributes.get("vector_distance") == 0.87 + + +@pytest.mark.asyncio +async def test_map_vector_distances_partial_node_coverage(setup_graph): + """Test mapping vector distances when only some nodes have results.""" + graph = setup_graph + + node1 = Node("1", {"name": "Node1"}) + node2 = Node("2", {"name": "Node2"}) + node3 = Node("3", {"name": "Node3"}) + graph.add_node(node1) + graph.add_node(node2) + graph.add_node(node3) + + node_distances = { + "Entity_name": [ + MockScoredResult("1", 0.95), + MockScoredResult("2", 0.87), + ] + } + + await graph.map_vector_distances_to_graph_nodes(node_distances) + + assert graph.get_node("1").attributes.get("vector_distance") == 0.95 + assert graph.get_node("2").attributes.get("vector_distance") == 0.87 + assert graph.get_node("3").attributes.get("vector_distance") == 3.5 + + +@pytest.mark.asyncio +async def test_map_vector_distances_multiple_categories(setup_graph): + """Test mapping vector distances from multiple collection categories.""" + graph = setup_graph + + # Create nodes + node1 = Node("1") + node2 = Node("2") + node3 = Node("3") + node4 = Node("4") + graph.add_node(node1) + graph.add_node(node2) + graph.add_node(node3) + graph.add_node(node4) + + node_distances = { + "Entity_name": [ + MockScoredResult("1", 0.95), + MockScoredResult("2", 0.87), + ], + "TextSummary_text": [ + MockScoredResult("3", 0.92), + ], + } + + await graph.map_vector_distances_to_graph_nodes(node_distances) + + assert graph.get_node("1").attributes.get("vector_distance") == 0.95 + assert graph.get_node("2").attributes.get("vector_distance") == 0.87 + assert graph.get_node("3").attributes.get("vector_distance") == 0.92 + assert graph.get_node("4").attributes.get("vector_distance") == 3.5 + + +@pytest.mark.asyncio +async def test_map_vector_distances_to_graph_edges_with_payload(setup_graph, mock_vector_engine): + """Test mapping vector distances to edges when edge_distances provided.""" + graph = setup_graph + + node1 = Node("1") + node2 = Node("2") + graph.add_node(node1) + graph.add_node(node2) + + edge = Edge( + node1, + node2, + attributes={"edge_text": "CONNECTS_TO", "relationship_type": "connects"}, + ) + graph.add_edge(edge) + + edge_distances = [ + MockScoredResult("e1", 0.92, payload={"text": "CONNECTS_TO"}), + ] + + await graph.map_vector_distances_to_graph_edges( + vector_engine=mock_vector_engine, + query_vector=[0.1, 0.2, 0.3], + edge_distances=edge_distances, + ) + + assert graph.edges[0].attributes.get("vector_distance") == 0.92 + + +@pytest.mark.asyncio +async def test_map_vector_distances_to_graph_edges_search(setup_graph, mock_vector_engine): + """Test mapping edge distances when searching for them.""" + graph = setup_graph + + node1 = Node("1") + node2 = Node("2") + graph.add_node(node1) + graph.add_node(node2) + + edge = Edge( + node1, + node2, + attributes={"edge_text": "CONNECTS_TO", "relationship_type": "connects"}, + ) + graph.add_edge(edge) + + mock_vector_engine.search.return_value = 
[ + MockScoredResult("e1", 0.88, payload={"text": "CONNECTS_TO"}), + ] + + await graph.map_vector_distances_to_graph_edges( + vector_engine=mock_vector_engine, + query_vector=[0.1, 0.2, 0.3], + edge_distances=None, + ) + + mock_vector_engine.search.assert_called_once() + assert graph.edges[0].attributes.get("vector_distance") == 0.88 + + +@pytest.mark.asyncio +async def test_map_vector_distances_partial_edge_coverage(setup_graph, mock_vector_engine): + """Test mapping edge distances when only some edges have results.""" + graph = setup_graph + + node1 = Node("1") + node2 = Node("2") + node3 = Node("3") + graph.add_node(node1) + graph.add_node(node2) + graph.add_node(node3) + + edge1 = Edge(node1, node2, attributes={"edge_text": "CONNECTS_TO"}) + edge2 = Edge(node2, node3, attributes={"edge_text": "DEPENDS_ON"}) + graph.add_edge(edge1) + graph.add_edge(edge2) + + edge_distances = [ + MockScoredResult("e1", 0.92, payload={"text": "CONNECTS_TO"}), + ] + + await graph.map_vector_distances_to_graph_edges( + vector_engine=mock_vector_engine, + query_vector=[0.1, 0.2, 0.3], + edge_distances=edge_distances, + ) + + assert graph.edges[0].attributes.get("vector_distance") == 0.92 + assert graph.edges[1].attributes.get("vector_distance") == 3.5 + + +@pytest.mark.asyncio +async def test_map_vector_distances_edges_fallback_to_relationship_type( + setup_graph, mock_vector_engine +): + """Test that edge mapping falls back to relationship_type when edge_text is missing.""" + graph = setup_graph + + node1 = Node("1") + node2 = Node("2") + graph.add_node(node1) + graph.add_node(node2) + + edge = Edge( + node1, + node2, + attributes={"relationship_type": "KNOWS"}, + ) + graph.add_edge(edge) + + edge_distances = [ + MockScoredResult("e1", 0.85, payload={"text": "KNOWS"}), + ] + + await graph.map_vector_distances_to_graph_edges( + vector_engine=mock_vector_engine, + query_vector=[0.1, 0.2, 0.3], + edge_distances=edge_distances, + ) + + assert graph.edges[0].attributes.get("vector_distance") == 0.85 + + +@pytest.mark.asyncio +async def test_map_vector_distances_no_edge_matches(setup_graph, mock_vector_engine): + """Test edge mapping when no edges match the distance results.""" + graph = setup_graph + + node1 = Node("1") + node2 = Node("2") + graph.add_node(node1) + graph.add_node(node2) + + edge = Edge( + node1, + node2, + attributes={"edge_text": "CONNECTS_TO", "relationship_type": "connects"}, + ) + graph.add_edge(edge) + + edge_distances = [ + MockScoredResult("e1", 0.92, payload={"text": "SOME_OTHER_EDGE"}), + ] + + await graph.map_vector_distances_to_graph_edges( + vector_engine=mock_vector_engine, + query_vector=[0.1, 0.2, 0.3], + edge_distances=edge_distances, + ) + + assert graph.edges[0].attributes.get("vector_distance") == 3.5 + + +@pytest.mark.asyncio +async def test_map_vector_distances_invalid_query_vector(setup_graph, mock_vector_engine): + """Test that invalid query vector raises error.""" + graph = setup_graph + + with pytest.raises(ValueError, match="Failed to generate query embedding"): + await graph.map_vector_distances_to_graph_edges( + vector_engine=mock_vector_engine, + query_vector=[], + edge_distances=None, + ) + + +@pytest.mark.asyncio +async def test_calculate_top_triplet_importances(setup_graph): + """Test calculating top triplet importances by score.""" + graph = setup_graph + + node1 = Node("1") + node2 = Node("2") + node3 = Node("3") + node4 = Node("4") + + node1.add_attribute("vector_distance", 0.9) + node2.add_attribute("vector_distance", 0.8) + 
node3.add_attribute("vector_distance", 0.7) + node4.add_attribute("vector_distance", 0.6) + + graph.add_node(node1) + graph.add_node(node2) + graph.add_node(node3) + graph.add_node(node4) + + edge1 = Edge(node1, node2) + edge2 = Edge(node2, node3) + edge3 = Edge(node3, node4) + + edge1.add_attribute("vector_distance", 0.85) + edge2.add_attribute("vector_distance", 0.75) + edge3.add_attribute("vector_distance", 0.65) + + graph.add_edge(edge1) + graph.add_edge(edge2) + graph.add_edge(edge3) + + top_triplets = await graph.calculate_top_triplet_importances(k=2) + + assert len(top_triplets) == 2 + + assert top_triplets[0] == edge3 + assert top_triplets[1] == edge2 + + +@pytest.mark.asyncio +async def test_calculate_top_triplet_importances_default_distances(setup_graph): + """Test calculating importances when nodes/edges have no vector distances.""" + graph = setup_graph + + node1 = Node("1") + node2 = Node("2") + graph.add_node(node1) + graph.add_node(node2) + + edge = Edge(node1, node2) + graph.add_edge(edge) + + top_triplets = await graph.calculate_top_triplet_importances(k=1) + + assert len(top_triplets) == 1 + assert top_triplets[0] == edge diff --git a/cognee/tests/unit/modules/retrieval/test_brute_force_triplet_search.py b/cognee/tests/unit/modules/retrieval/test_brute_force_triplet_search.py new file mode 100644 index 000000000..5eb6fb105 --- /dev/null +++ b/cognee/tests/unit/modules/retrieval/test_brute_force_triplet_search.py @@ -0,0 +1,582 @@ +import pytest +from unittest.mock import AsyncMock, patch + +from cognee.modules.retrieval.utils.brute_force_triplet_search import ( + brute_force_triplet_search, + get_memory_fragment, +) +from cognee.modules.graph.cognee_graph.CogneeGraph import CogneeGraph +from cognee.modules.graph.exceptions.exceptions import EntityNotFoundError + + +class MockScoredResult: + """Mock class for vector search results.""" + + def __init__(self, id, score, payload=None): + self.id = id + self.score = score + self.payload = payload or {} + + +@pytest.mark.asyncio +async def test_brute_force_triplet_search_empty_query(): + """Test that empty query raises ValueError.""" + with pytest.raises(ValueError, match="The query must be a non-empty string."): + await brute_force_triplet_search(query="") + + +@pytest.mark.asyncio +async def test_brute_force_triplet_search_none_query(): + """Test that None query raises ValueError.""" + with pytest.raises(ValueError, match="The query must be a non-empty string."): + await brute_force_triplet_search(query=None) + + +@pytest.mark.asyncio +async def test_brute_force_triplet_search_negative_top_k(): + """Test that negative top_k raises ValueError.""" + with pytest.raises(ValueError, match="top_k must be a positive integer."): + await brute_force_triplet_search(query="test query", top_k=-1) + + +@pytest.mark.asyncio +async def test_brute_force_triplet_search_zero_top_k(): + """Test that zero top_k raises ValueError.""" + with pytest.raises(ValueError, match="top_k must be a positive integer."): + await brute_force_triplet_search(query="test query", top_k=0) + + +@pytest.mark.asyncio +async def test_brute_force_triplet_search_wide_search_limit_global_search(): + """Test that wide_search_limit is applied for global search (node_name=None).""" + mock_vector_engine = AsyncMock() + mock_vector_engine.embedding_engine = AsyncMock() + mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]]) + mock_vector_engine.search = AsyncMock(return_value=[]) + + with patch( + 
"cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine", + return_value=mock_vector_engine, + ): + await brute_force_triplet_search( + query="test", + node_name=None, # Global search + wide_search_top_k=75, + ) + + for call in mock_vector_engine.search.call_args_list: + assert call[1]["limit"] == 75 + + +@pytest.mark.asyncio +async def test_brute_force_triplet_search_wide_search_limit_filtered_search(): + """Test that wide_search_limit is None for filtered search (node_name provided).""" + mock_vector_engine = AsyncMock() + mock_vector_engine.embedding_engine = AsyncMock() + mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]]) + mock_vector_engine.search = AsyncMock(return_value=[]) + + with patch( + "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine", + return_value=mock_vector_engine, + ): + await brute_force_triplet_search( + query="test", + node_name=["Node1"], + wide_search_top_k=50, + ) + + for call in mock_vector_engine.search.call_args_list: + assert call[1]["limit"] is None + + +@pytest.mark.asyncio +async def test_brute_force_triplet_search_wide_search_default(): + """Test that wide_search_top_k defaults to 100.""" + mock_vector_engine = AsyncMock() + mock_vector_engine.embedding_engine = AsyncMock() + mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]]) + mock_vector_engine.search = AsyncMock(return_value=[]) + + with patch( + "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine", + return_value=mock_vector_engine, + ): + await brute_force_triplet_search(query="test", node_name=None) + + for call in mock_vector_engine.search.call_args_list: + assert call[1]["limit"] == 100 + + +@pytest.mark.asyncio +async def test_brute_force_triplet_search_default_collections(): + """Test that default collections are used when none provided.""" + mock_vector_engine = AsyncMock() + mock_vector_engine.embedding_engine = AsyncMock() + mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]]) + mock_vector_engine.search = AsyncMock(return_value=[]) + + with patch( + "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine", + return_value=mock_vector_engine, + ): + await brute_force_triplet_search(query="test") + + expected_collections = [ + "Entity_name", + "TextSummary_text", + "EntityType_name", + "DocumentChunk_text", + ] + + call_collections = [ + call[1]["collection_name"] for call in mock_vector_engine.search.call_args_list + ] + assert call_collections == expected_collections + + +@pytest.mark.asyncio +async def test_brute_force_triplet_search_custom_collections(): + """Test that custom collections are used when provided.""" + mock_vector_engine = AsyncMock() + mock_vector_engine.embedding_engine = AsyncMock() + mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]]) + mock_vector_engine.search = AsyncMock(return_value=[]) + + custom_collections = ["CustomCol1", "CustomCol2"] + + with patch( + "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine", + return_value=mock_vector_engine, + ): + await brute_force_triplet_search(query="test", collections=custom_collections) + + call_collections = [ + call[1]["collection_name"] for call in mock_vector_engine.search.call_args_list + ] + assert call_collections == custom_collections + + +@pytest.mark.asyncio +async def test_brute_force_triplet_search_all_collections_empty(): + """Test that empty 
+
+
+@pytest.mark.asyncio
+async def test_brute_force_triplet_search_all_collections_empty():
+    """Test that empty list is returned when all collections return no results."""
+    mock_vector_engine = AsyncMock()
+    mock_vector_engine.embedding_engine = AsyncMock()
+    mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
+    mock_vector_engine.search = AsyncMock(return_value=[])
+
+    with patch(
+        "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
+        return_value=mock_vector_engine,
+    ):
+        results = await brute_force_triplet_search(query="test")
+        assert results == []
+
+
+# Tests for query embedding
+
+
+@pytest.mark.asyncio
+async def test_brute_force_triplet_search_embeds_query():
+    """Test that query is embedded before searching."""
+    query_text = "test query"
+    expected_vector = [0.1, 0.2, 0.3]
+
+    mock_vector_engine = AsyncMock()
+    mock_vector_engine.embedding_engine = AsyncMock()
+    mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[expected_vector])
+    mock_vector_engine.search = AsyncMock(return_value=[])
+
+    with patch(
+        "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
+        return_value=mock_vector_engine,
+    ):
+        await brute_force_triplet_search(query=query_text)
+
+    mock_vector_engine.embedding_engine.embed_text.assert_called_once_with([query_text])
+
+    for call in mock_vector_engine.search.call_args_list:
+        assert call[1]["query_vector"] == expected_vector
+
+
+@pytest.mark.asyncio
+async def test_brute_force_triplet_search_extracts_node_ids_global_search():
+    """Test that node IDs are extracted from search results for global search."""
+    scored_results = [
+        MockScoredResult("node1", 0.95),
+        MockScoredResult("node2", 0.87),
+        MockScoredResult("node3", 0.92),
+    ]
+
+    mock_vector_engine = AsyncMock()
+    mock_vector_engine.embedding_engine = AsyncMock()
+    mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
+    mock_vector_engine.search = AsyncMock(return_value=scored_results)
+
+    mock_fragment = AsyncMock(
+        map_vector_distances_to_graph_nodes=AsyncMock(),
+        map_vector_distances_to_graph_edges=AsyncMock(),
+        calculate_top_triplet_importances=AsyncMock(return_value=[]),
+    )
+
+    with (
+        patch(
+            "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
+            return_value=mock_vector_engine,
+        ),
+        patch(
+            "cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment",
+            return_value=mock_fragment,
+        ) as mock_get_fragment_fn,
+    ):
+        await brute_force_triplet_search(query="test", node_name=None)
+
+    call_kwargs = mock_get_fragment_fn.call_args[1]
+    assert set(call_kwargs["relevant_ids_to_filter"]) == {"node1", "node2", "node3"}
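+
+
+# The next two tests pin down the fragment lifecycle: a caller-supplied
+# memory_fragment must be reused as-is, while omitting it must trigger exactly
+# one get_memory_fragment call. The mock fragment only needs the three async
+# methods that brute_force_triplet_search is expected to invoke.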
+
+
+@pytest.mark.asyncio
+async def test_brute_force_triplet_search_reuses_provided_fragment():
+    """Test that provided memory fragment is reused instead of creating new one."""
+    provided_fragment = AsyncMock(
+        map_vector_distances_to_graph_nodes=AsyncMock(),
+        map_vector_distances_to_graph_edges=AsyncMock(),
+        calculate_top_triplet_importances=AsyncMock(return_value=[]),
+    )
+
+    mock_vector_engine = AsyncMock()
+    mock_vector_engine.embedding_engine = AsyncMock()
+    mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
+    mock_vector_engine.search = AsyncMock(return_value=[MockScoredResult("n1", 0.95)])
+
+    with (
+        patch(
+            "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
+            return_value=mock_vector_engine,
+        ),
+        patch(
+            "cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment"
+        ) as mock_get_fragment,
+    ):
+        await brute_force_triplet_search(
+            query="test",
+            memory_fragment=provided_fragment,
+            node_name=["node"],
+        )
+
+    mock_get_fragment.assert_not_called()
+
+
+@pytest.mark.asyncio
+async def test_brute_force_triplet_search_creates_fragment_when_not_provided():
+    """Test that memory fragment is created when not provided."""
+    mock_vector_engine = AsyncMock()
+    mock_vector_engine.embedding_engine = AsyncMock()
+    mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
+    mock_vector_engine.search = AsyncMock(return_value=[MockScoredResult("n1", 0.95)])
+
+    mock_fragment = AsyncMock(
+        map_vector_distances_to_graph_nodes=AsyncMock(),
+        map_vector_distances_to_graph_edges=AsyncMock(),
+        calculate_top_triplet_importances=AsyncMock(return_value=[]),
+    )
+
+    with (
+        patch(
+            "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
+            return_value=mock_vector_engine,
+        ),
+        patch(
+            "cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment",
+            return_value=mock_fragment,
+        ) as mock_get_fragment,
+    ):
+        await brute_force_triplet_search(query="test", node_name=["node"])
+
+    mock_get_fragment.assert_called_once()
+
+
+@pytest.mark.asyncio
+async def test_brute_force_triplet_search_passes_top_k_to_importance_calculation():
+    """Test that custom top_k is passed to importance calculation."""
+    mock_vector_engine = AsyncMock()
+    mock_vector_engine.embedding_engine = AsyncMock()
+    mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
+    mock_vector_engine.search = AsyncMock(return_value=[MockScoredResult("n1", 0.95)])
+
+    mock_fragment = AsyncMock(
+        map_vector_distances_to_graph_nodes=AsyncMock(),
+        map_vector_distances_to_graph_edges=AsyncMock(),
+        calculate_top_triplet_importances=AsyncMock(return_value=[]),
+    )
+
+    with (
+        patch(
+            "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
+            return_value=mock_vector_engine,
+        ),
+        patch(
+            "cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment",
+            return_value=mock_fragment,
+        ),
+    ):
+        custom_top_k = 15
+        await brute_force_triplet_search(query="test", top_k=custom_top_k, node_name=["n"])
+
+    mock_fragment.calculate_top_triplet_importances.assert_called_once_with(k=custom_top_k)
+
+
+@pytest.mark.asyncio
+async def test_get_memory_fragment_returns_empty_graph_on_entity_not_found():
+    """Test that get_memory_fragment returns empty graph when entity not found."""
+    mock_graph_engine = AsyncMock()
+    mock_graph_engine.project_graph_from_db = AsyncMock(
+        side_effect=EntityNotFoundError("Entity not found")
+    )
+
+    with patch(
+        "cognee.modules.retrieval.utils.brute_force_triplet_search.get_graph_engine",
+        return_value=mock_graph_engine,
+    ):
+        fragment = await get_memory_fragment()
+
+    assert isinstance(fragment, CogneeGraph)
+    assert len(fragment.nodes) == 0
+
+
+@pytest.mark.asyncio
+async def test_get_memory_fragment_returns_empty_graph_on_error():
+    """Test that get_memory_fragment returns empty graph on generic error."""
+    mock_graph_engine = AsyncMock()
+    mock_graph_engine.project_graph_from_db = AsyncMock(side_effect=Exception("Generic error"))
+
+    with patch(
+        "cognee.modules.retrieval.utils.brute_force_triplet_search.get_graph_engine",
+        return_value=mock_graph_engine,
+    ):
+        fragment = await get_memory_fragment()
+
+    assert isinstance(fragment, CogneeGraph)
+    assert len(fragment.nodes) == 0
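+
+
+# Deduplication sketch (an assumption about the implementation, not verified
+# here): an order-preserving uniqueness pass over the IDs gathered from all
+# collections, e.g.
+#
+#     unique_ids = list(dict.fromkeys(all_ids))
+#
+# would satisfy both assertions in the next test.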
"""Test that duplicate node IDs across collections are deduplicated.""" + + def search_side_effect(*args, **kwargs): + collection_name = kwargs.get("collection_name") + if collection_name == "Entity_name": + return [ + MockScoredResult("node1", 0.95), + MockScoredResult("node2", 0.87), + ] + elif collection_name == "TextSummary_text": + return [ + MockScoredResult("node1", 0.90), + MockScoredResult("node3", 0.92), + ] + else: + return [] + + mock_vector_engine = AsyncMock() + mock_vector_engine.embedding_engine = AsyncMock() + mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]]) + mock_vector_engine.search = AsyncMock(side_effect=search_side_effect) + + mock_fragment = AsyncMock( + map_vector_distances_to_graph_nodes=AsyncMock(), + map_vector_distances_to_graph_edges=AsyncMock(), + calculate_top_triplet_importances=AsyncMock(return_value=[]), + ) + + with ( + patch( + "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine", + return_value=mock_vector_engine, + ), + patch( + "cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment", + return_value=mock_fragment, + ) as mock_get_fragment_fn, + ): + await brute_force_triplet_search(query="test", node_name=None) + + call_kwargs = mock_get_fragment_fn.call_args[1] + assert set(call_kwargs["relevant_ids_to_filter"]) == {"node1", "node2", "node3"} + assert len(call_kwargs["relevant_ids_to_filter"]) == 3 + + +@pytest.mark.asyncio +async def test_brute_force_triplet_search_excludes_edge_collection(): + """Test that EdgeType_relationship_name collection is excluded from ID extraction.""" + + def search_side_effect(*args, **kwargs): + collection_name = kwargs.get("collection_name") + if collection_name == "Entity_name": + return [MockScoredResult("node1", 0.95)] + elif collection_name == "EdgeType_relationship_name": + return [MockScoredResult("edge1", 0.88)] + else: + return [] + + mock_vector_engine = AsyncMock() + mock_vector_engine.embedding_engine = AsyncMock() + mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]]) + mock_vector_engine.search = AsyncMock(side_effect=search_side_effect) + + mock_fragment = AsyncMock( + map_vector_distances_to_graph_nodes=AsyncMock(), + map_vector_distances_to_graph_edges=AsyncMock(), + calculate_top_triplet_importances=AsyncMock(return_value=[]), + ) + + with ( + patch( + "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine", + return_value=mock_vector_engine, + ), + patch( + "cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment", + return_value=mock_fragment, + ) as mock_get_fragment_fn, + ): + await brute_force_triplet_search( + query="test", + node_name=None, + collections=["Entity_name", "EdgeType_relationship_name"], + ) + + call_kwargs = mock_get_fragment_fn.call_args[1] + assert call_kwargs["relevant_ids_to_filter"] == ["node1"] + + +@pytest.mark.asyncio +async def test_brute_force_triplet_search_skips_nodes_without_ids(): + """Test that nodes without ID attribute are skipped.""" + + class ScoredResultNoId: + """Mock result without id attribute.""" + + def __init__(self, score): + self.score = score + + def search_side_effect(*args, **kwargs): + collection_name = kwargs.get("collection_name") + if collection_name == "Entity_name": + return [ + MockScoredResult("node1", 0.95), + ScoredResultNoId(0.90), + MockScoredResult("node2", 0.87), + ] + else: + return [] + + mock_vector_engine = AsyncMock() + mock_vector_engine.embedding_engine = 
+
+
+@pytest.mark.asyncio
+async def test_brute_force_triplet_search_skips_nodes_without_ids():
+    """Test that nodes without ID attribute are skipped."""
+
+    class ScoredResultNoId:
+        """Mock result without id attribute."""
+
+        def __init__(self, score):
+            self.score = score
+
+    def search_side_effect(*args, **kwargs):
+        collection_name = kwargs.get("collection_name")
+        if collection_name == "Entity_name":
+            return [
+                MockScoredResult("node1", 0.95),
+                ScoredResultNoId(0.90),
+                MockScoredResult("node2", 0.87),
+            ]
+        else:
+            return []
+
+    mock_vector_engine = AsyncMock()
+    mock_vector_engine.embedding_engine = AsyncMock()
+    mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
+    mock_vector_engine.search = AsyncMock(side_effect=search_side_effect)
+
+    mock_fragment = AsyncMock(
+        map_vector_distances_to_graph_nodes=AsyncMock(),
+        map_vector_distances_to_graph_edges=AsyncMock(),
+        calculate_top_triplet_importances=AsyncMock(return_value=[]),
+    )
+
+    with (
+        patch(
+            "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
+            return_value=mock_vector_engine,
+        ),
+        patch(
+            "cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment",
+            return_value=mock_fragment,
+        ) as mock_get_fragment_fn,
+    ):
+        await brute_force_triplet_search(query="test", node_name=None)
+
+    call_kwargs = mock_get_fragment_fn.call_args[1]
+    assert set(call_kwargs["relevant_ids_to_filter"]) == {"node1", "node2"}
+
+
+@pytest.mark.asyncio
+async def test_brute_force_triplet_search_handles_tuple_results():
+    """Test that both list and tuple results are handled correctly."""
+
+    def search_side_effect(*args, **kwargs):
+        collection_name = kwargs.get("collection_name")
+        if collection_name == "Entity_name":
+            return (
+                MockScoredResult("node1", 0.95),
+                MockScoredResult("node2", 0.87),
+            )
+        else:
+            return []
+
+    mock_vector_engine = AsyncMock()
+    mock_vector_engine.embedding_engine = AsyncMock()
+    mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
+    mock_vector_engine.search = AsyncMock(side_effect=search_side_effect)
+
+    mock_fragment = AsyncMock(
+        map_vector_distances_to_graph_nodes=AsyncMock(),
+        map_vector_distances_to_graph_edges=AsyncMock(),
+        calculate_top_triplet_importances=AsyncMock(return_value=[]),
+    )
+
+    with (
+        patch(
+            "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
+            return_value=mock_vector_engine,
+        ),
+        patch(
+            "cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment",
+            return_value=mock_fragment,
+        ) as mock_get_fragment_fn,
+    ):
+        await brute_force_triplet_search(query="test", node_name=None)
+
+    call_kwargs = mock_get_fragment_fn.call_args[1]
+    assert set(call_kwargs["relevant_ids_to_filter"]) == {"node1", "node2"}
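+
+
+# Empty collections must contribute no IDs without short-circuiting extraction
+# from the remaining collections; the final test covers that mixed case.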
+@pytest.mark.asyncio
+async def test_brute_force_triplet_search_mixed_empty_collections():
+    """Test ID extraction with mixed empty and non-empty collections."""
+
+    def search_side_effect(*args, **kwargs):
+        collection_name = kwargs.get("collection_name")
+        if collection_name == "Entity_name":
+            return [MockScoredResult("node1", 0.95)]
+        elif collection_name == "TextSummary_text":
+            return []
+        elif collection_name == "EntityType_name":
+            return [MockScoredResult("node2", 0.92)]
+        else:
+            return []
+
+    mock_vector_engine = AsyncMock()
+    mock_vector_engine.embedding_engine = AsyncMock()
+    mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
+    mock_vector_engine.search = AsyncMock(side_effect=search_side_effect)
+
+    mock_fragment = AsyncMock(
+        map_vector_distances_to_graph_nodes=AsyncMock(),
+        map_vector_distances_to_graph_edges=AsyncMock(),
+        calculate_top_triplet_importances=AsyncMock(return_value=[]),
+    )
+
+    with (
+        patch(
+            "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine",
+            return_value=mock_vector_engine,
+        ),
+        patch(
+            "cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment",
+            return_value=mock_fragment,
+        ) as mock_get_fragment_fn,
+    ):
+        await brute_force_triplet_search(query="test", node_name=None)
+
+    call_kwargs = mock_get_fragment_fn.call_args[1]
+    assert set(call_kwargs["relevant_ids_to_filter"]) == {"node1", "node2"}