From 742866b4c9f1d4aa53ab60fb54b79474fbfea0d2 Mon Sep 17 00:00:00 2001 From: EricXiao Date: Wed, 22 Oct 2025 16:56:46 +0800 Subject: [PATCH 01/16] feat: csv ingestion loader & chunk Signed-off-by: EricXiao --- cognee/cli/commands/cognify_command.py | 9 +- cognee/cli/config.py | 2 +- .../files/utils/guess_file_type.py | 43 +++++ .../files/utils/is_csv_content.py | 181 ++++++++++++++++++ cognee/infrastructure/loaders/LoaderEngine.py | 1 + .../infrastructure/loaders/core/__init__.py | 3 +- .../infrastructure/loaders/core/csv_loader.py | 93 +++++++++ .../loaders/core/text_loader.py | 3 +- .../loaders/supported_loaders.py | 3 +- cognee/modules/chunking/CsvChunker.py | 35 ++++ .../processing/document_types/CsvDocument.py | 33 ++++ .../processing/document_types/__init__.py | 1 + cognee/tasks/chunks/__init__.py | 1 + cognee/tasks/chunks/chunk_by_row.py | 94 +++++++++ cognee/tasks/documents/classify_documents.py | 2 + .../integration/documents/CsvDocument_test.py | 70 +++++++ .../tests/test_data/example_with_header.csv | 3 + .../processing/chunks/chunk_by_row_test.py | 52 +++++ 18 files changed, 623 insertions(+), 6 deletions(-) create mode 100644 cognee/infrastructure/files/utils/is_csv_content.py create mode 100644 cognee/infrastructure/loaders/core/csv_loader.py create mode 100644 cognee/modules/chunking/CsvChunker.py create mode 100644 cognee/modules/data/processing/document_types/CsvDocument.py create mode 100644 cognee/tasks/chunks/chunk_by_row.py create mode 100644 cognee/tests/integration/documents/CsvDocument_test.py create mode 100644 cognee/tests/test_data/example_with_header.csv create mode 100644 cognee/tests/unit/processing/chunks/chunk_by_row_test.py diff --git a/cognee/cli/commands/cognify_command.py b/cognee/cli/commands/cognify_command.py index 16eaf0454..b89c1f70e 100644 --- a/cognee/cli/commands/cognify_command.py +++ b/cognee/cli/commands/cognify_command.py @@ -22,7 +22,7 @@ relationships, and creates semantic connections for enhanced search and reasonin Processing Pipeline: 1. **Document Classification**: Identifies document types and structures -2. **Permission Validation**: Ensures user has processing rights +2. **Permission Validation**: Ensures user has processing rights 3. **Text Chunking**: Breaks content into semantically meaningful segments 4. **Entity Extraction**: Identifies key concepts, people, places, organizations 5. **Relationship Detection**: Discovers connections between entities @@ -97,6 +97,13 @@ After successful cognify processing, use `cognee search` to query the knowledge chunker_class = LangchainChunker except ImportError: fmt.warning("LangchainChunker not available, using TextChunker") + elif args.chunker == "CsvChunker": + try: + from cognee.modules.chunking.CsvChunker import CsvChunker + + chunker_class = CsvChunker + except ImportError: + fmt.warning("CsvChunker not available, using TextChunker") result = await cognee.cognify( datasets=datasets, diff --git a/cognee/cli/config.py b/cognee/cli/config.py index d016608c1..082adbaec 100644 --- a/cognee/cli/config.py +++ b/cognee/cli/config.py @@ -26,7 +26,7 @@ SEARCH_TYPE_CHOICES = [ ] # Chunker choices -CHUNKER_CHOICES = ["TextChunker", "LangchainChunker"] +CHUNKER_CHOICES = ["TextChunker", "LangchainChunker", "CsvChunker"] # Output format choices OUTPUT_FORMAT_CHOICES = ["json", "pretty", "simple"] diff --git a/cognee/infrastructure/files/utils/guess_file_type.py b/cognee/infrastructure/files/utils/guess_file_type.py index dcdd68cad..10f59a400 100644 --- a/cognee/infrastructure/files/utils/guess_file_type.py +++ b/cognee/infrastructure/files/utils/guess_file_type.py @@ -1,6 +1,8 @@ from typing import BinaryIO import filetype + from .is_text_content import is_text_content +from .is_csv_content import is_csv_content class FileTypeException(Exception): @@ -134,3 +136,44 @@ def guess_file_type(file: BinaryIO) -> filetype.Type: raise FileTypeException(f"Unknown file detected: {file.name}.") return file_type + + +class CsvFileType(filetype.Type): + """ + Match CSV file types based on MIME type and extension. + + Public methods: + - match + + Instance variables: + - MIME: The MIME type of the CSV. + - EXTENSION: The file extension of the CSV. + """ + + MIME = "text/csv" + EXTENSION = "csv" + + def __init__(self): + super().__init__(mime=self.MIME, extension=self.EXTENSION) + + def match(self, buf): + """ + Determine if the given buffer contains csv content. + + Parameters: + ----------- + + - buf: The buffer to check for csv content. + + Returns: + -------- + + Returns True if the buffer is identified as csv content, otherwise False. + """ + + return is_csv_content(buf) + + +csv_file_type = CsvFileType() + +filetype.add_type(csv_file_type) diff --git a/cognee/infrastructure/files/utils/is_csv_content.py b/cognee/infrastructure/files/utils/is_csv_content.py new file mode 100644 index 000000000..07b7ea69b --- /dev/null +++ b/cognee/infrastructure/files/utils/is_csv_content.py @@ -0,0 +1,181 @@ +import csv +from collections import Counter + + +def is_csv_content(content): + """ + Heuristically determine whether a bytes-like object is CSV text. + + Strategy (fail-fast and cheap to expensive): + 1) Decode: Try a small ordered list of common encodings with strict errors. + 2) Line sampling: require >= 2 non-empty lines; sample up to 50 lines. + 3) Delimiter detection: + - Prefer csv.Sniffer() with common delimiters. + - Fallback to a lightweight consistency heuristic. + 4) Lightweight parse check: + - Parse a few lines with the delimiter. + - Ensure at least 2 valid rows and relatively stable column counts. + + Returns: + bool: True if the buffer looks like CSV; False otherwise. + """ + try: + encoding_list = [ + "utf-8", + "utf-8-sig", + "utf-32-le", + "utf-32-be", + "utf-16-le", + "utf-16-be", + "gb18030", + "shift_jis", + "cp949", + "cp1252", + "iso-8859-1", + ] + + # Try to decode strictly—if decoding fails for all encodings, it's not text/CSV. + text = None + for enc in encoding_list: + try: + text = content.decode(enc, errors="strict") + break + except UnicodeDecodeError: + continue + if text is None: + return False + + # Reject empty/whitespace-only payloads. + stripped = text.strip() + if not stripped: + return False + + # Split into logical lines and drop empty ones. Require at least two lines. + lines = [ln for ln in text.splitlines() if ln.strip()] + if len(lines) < 2: + return False + + # Take a small sample to keep sniffing cheap and predictable. + sample_lines = lines[:50] + + # Detect delimiter using csv.Sniffer first; if that fails, use our heuristic. + delimiter = _sniff_delimiter(sample_lines) or _heuristic_delimiter(sample_lines) + if not delimiter: + return False + + # Finally, do a lightweight parse sanity check with the chosen delimiter. + return _lightweight_parse_check(sample_lines, delimiter) + except Exception: + return False + + +def _sniff_delimiter(lines): + """ + Try Python's built-in csv.Sniffer on a sample. + + Args: + lines (list[str]): Sample lines (already decoded). + + Returns: + str | None: The detected delimiter if sniffing succeeds; otherwise None. + """ + # Join up to 50 lines to form the sample string Sniffer will inspect. + sample = "\n".join(lines[:50]) + try: + dialect = csv.Sniffer().sniff(sample, delimiters=",\t;|") + return dialect.delimiter + except Exception: + # Sniffer is known to be brittle on small/dirty samples—silently fallback. + return None + + +def _heuristic_delimiter(lines): + """ + Fallback delimiter detection based on count consistency per line. + + Heuristic: + - For each candidate delimiter, count occurrences per line. + - Keep only lines with count > 0 (line must contain the delimiter). + - Require at least half of lines to contain the delimiter (min 2). + - Compute the mode (most common count). If the proportion of lines that + exhibit the modal count is >= 80%, accept that delimiter. + + Args: + lines (list[str]): Sample lines. + + Returns: + str | None: Best delimiter if one meets the consistency threshold; else None. + """ + candidates = [",", "\t", ";", "|"] + best = None + best_score = 0.0 + + for d in candidates: + # Count how many times the delimiter appears in each line. + counts = [ln.count(d) for ln in lines] + # Consider only lines that actually contain the delimiter at least once. + nonzero = [c for c in counts if c > 0] + + # Require that more than half of lines (and at least 2) contain the delimiter. + if len(nonzero) < max(2, int(0.5 * len(lines))): + continue + + # Find the modal count and its frequency. + cnt = Counter(nonzero) + pairs = cnt.most_common(1) + if not pairs: + continue + + mode, mode_freq = pairs[0] + # Consistency ratio: lines with the modal count / total lines in the sample. + consistency = mode_freq / len(lines) + # Accept if consistent enough and better than any previous candidate. + if mode >= 1 and consistency >= 0.80 and consistency > best_score: + best = d + best_score = consistency + + return best + + +def _lightweight_parse_check(lines, delimiter): + """ + Parse a few lines with csv.reader and check structural stability. + + Heuristic: + - Parse up to 5 lines with the given delimiter. + - Count column widths per parsed row. + - Require at least 2 non-empty rows. + - Allow at most 1 row whose width deviates by >2 columns from the first row. + + Args: + lines (list[str]): Sample lines (decoded). + delimiter (str): Delimiter chosen by sniffing/heuristics. + + Returns: + bool: True if parsing looks stable; False otherwise. + """ + try: + # csv.reader accepts any iterable of strings; feeding the first 10 lines is fine. + reader = csv.reader(lines[:10], delimiter=delimiter) + widths = [] + valid_rows = 0 + for row in reader: + if not row: + continue + + widths.append(len(row)) + valid_rows += 1 + + # Need at least two meaningful rows to make a judgment. + if valid_rows < 2: + return False + + if widths: + first = widths[0] + # Count rows whose width deviates significantly (>2) from the first row. + unstable = sum(1 for w in widths if abs(w - first) > 2) + # Permit at most 1 unstable row among the parsed sample. + return unstable <= 1 + return False + except Exception: + return False diff --git a/cognee/infrastructure/loaders/LoaderEngine.py b/cognee/infrastructure/loaders/LoaderEngine.py index 6b62f7641..37e63c9fc 100644 --- a/cognee/infrastructure/loaders/LoaderEngine.py +++ b/cognee/infrastructure/loaders/LoaderEngine.py @@ -30,6 +30,7 @@ class LoaderEngine: "pypdf_loader", "image_loader", "audio_loader", + "csv_loader", "unstructured_loader", "advanced_pdf_loader", ] diff --git a/cognee/infrastructure/loaders/core/__init__.py b/cognee/infrastructure/loaders/core/__init__.py index 8a2df80f9..09819fbd2 100644 --- a/cognee/infrastructure/loaders/core/__init__.py +++ b/cognee/infrastructure/loaders/core/__init__.py @@ -3,5 +3,6 @@ from .text_loader import TextLoader from .audio_loader import AudioLoader from .image_loader import ImageLoader +from .csv_loader import CsvLoader -__all__ = ["TextLoader", "AudioLoader", "ImageLoader"] +__all__ = ["TextLoader", "AudioLoader", "ImageLoader", "CsvLoader"] diff --git a/cognee/infrastructure/loaders/core/csv_loader.py b/cognee/infrastructure/loaders/core/csv_loader.py new file mode 100644 index 000000000..a314a7a24 --- /dev/null +++ b/cognee/infrastructure/loaders/core/csv_loader.py @@ -0,0 +1,93 @@ +import os +from typing import List +import csv +from cognee.infrastructure.loaders.LoaderInterface import LoaderInterface +from cognee.infrastructure.files.storage import get_file_storage, get_storage_config +from cognee.infrastructure.files.utils.get_file_metadata import get_file_metadata + + +class CsvLoader(LoaderInterface): + """ + Core CSV file loader that handles basic CSV file formats. + """ + + @property + def supported_extensions(self) -> List[str]: + """Supported text file extensions.""" + return [ + "csv", + ] + + @property + def supported_mime_types(self) -> List[str]: + """Supported MIME types for text content.""" + return [ + "text/csv", + ] + + @property + def loader_name(self) -> str: + """Unique identifier for this loader.""" + return "csv_loader" + + def can_handle(self, extension: str, mime_type: str) -> bool: + """ + Check if this loader can handle the given file. + + Args: + extension: File extension + mime_type: Optional MIME type + + Returns: + True if file can be handled, False otherwise + """ + if extension in self.supported_extensions and mime_type in self.supported_mime_types: + return True + + return False + + async def load(self, file_path: str, encoding: str = "utf-8", **kwargs): + """ + Load and process the csv file. + + Args: + file_path: Path to the file to load + encoding: Text encoding to use (default: utf-8) + **kwargs: Additional configuration (unused) + + Returns: + LoaderResult containing the file content and metadata + + Raises: + FileNotFoundError: If file doesn't exist + UnicodeDecodeError: If file cannot be decoded with specified encoding + OSError: If file cannot be read + """ + if not os.path.exists(file_path): + raise FileNotFoundError(f"File not found: {file_path}") + + with open(file_path, "rb") as f: + file_metadata = await get_file_metadata(f) + # Name ingested file of current loader based on original file content hash + storage_file_name = "text_" + file_metadata["content_hash"] + ".txt" + + row_texts = [] + row_index = 1 + + with open(file_path, "r", encoding=encoding, newline="") as file: + reader = csv.DictReader(file) + for row in reader: + pairs = [f"{str(k)}: {str(v)}" for k, v in row.items()] + row_text = ", ".join(pairs) + row_texts.append(f"Row {row_index}:\n{row_text}\n") + row_index += 1 + + content = "\n".join(row_texts) + + storage_config = get_storage_config() + data_root_directory = storage_config["data_root_directory"] + storage = get_file_storage(data_root_directory) + + full_file_path = await storage.store(storage_file_name, content) + + return full_file_path diff --git a/cognee/infrastructure/loaders/core/text_loader.py b/cognee/infrastructure/loaders/core/text_loader.py index a6f94be9b..e478edb22 100644 --- a/cognee/infrastructure/loaders/core/text_loader.py +++ b/cognee/infrastructure/loaders/core/text_loader.py @@ -16,7 +16,7 @@ class TextLoader(LoaderInterface): @property def supported_extensions(self) -> List[str]: """Supported text file extensions.""" - return ["txt", "md", "csv", "json", "xml", "yaml", "yml", "log"] + return ["txt", "md", "json", "xml", "yaml", "yml", "log"] @property def supported_mime_types(self) -> List[str]: @@ -24,7 +24,6 @@ class TextLoader(LoaderInterface): return [ "text/plain", "text/markdown", - "text/csv", "application/json", "text/xml", "application/xml", diff --git a/cognee/infrastructure/loaders/supported_loaders.py b/cognee/infrastructure/loaders/supported_loaders.py index d103babe3..b506df5f3 100644 --- a/cognee/infrastructure/loaders/supported_loaders.py +++ b/cognee/infrastructure/loaders/supported_loaders.py @@ -1,5 +1,5 @@ from cognee.infrastructure.loaders.external import PyPdfLoader -from cognee.infrastructure.loaders.core import TextLoader, AudioLoader, ImageLoader +from cognee.infrastructure.loaders.core import TextLoader, AudioLoader, ImageLoader, CsvLoader # Registry for loader implementations supported_loaders = { @@ -7,6 +7,7 @@ supported_loaders = { TextLoader.loader_name: TextLoader, ImageLoader.loader_name: ImageLoader, AudioLoader.loader_name: AudioLoader, + CsvLoader.loader_name: CsvLoader, } # Try adding optional loaders diff --git a/cognee/modules/chunking/CsvChunker.py b/cognee/modules/chunking/CsvChunker.py new file mode 100644 index 000000000..4ba4a969e --- /dev/null +++ b/cognee/modules/chunking/CsvChunker.py @@ -0,0 +1,35 @@ +from cognee.shared.logging_utils import get_logger + + +from cognee.tasks.chunks import chunk_by_row +from cognee.modules.chunking.Chunker import Chunker +from .models.DocumentChunk import DocumentChunk + +logger = get_logger() + + +class CsvChunker(Chunker): + async def read(self): + async for content_text in self.get_text(): + if content_text is None: + continue + + for chunk_data in chunk_by_row(content_text, self.max_chunk_size): + if chunk_data["chunk_size"] <= self.max_chunk_size: + yield DocumentChunk( + id=chunk_data["chunk_id"], + text=chunk_data["text"], + chunk_size=chunk_data["chunk_size"], + is_part_of=self.document, + chunk_index=self.chunk_index, + cut_type=chunk_data["cut_type"], + contains=[], + metadata={ + "index_fields": ["text"], + }, + ) + self.chunk_index += 1 + else: + raise ValueError( + f"Chunk size is larger than the maximum chunk size {self.max_chunk_size}" + ) diff --git a/cognee/modules/data/processing/document_types/CsvDocument.py b/cognee/modules/data/processing/document_types/CsvDocument.py new file mode 100644 index 000000000..3381275bd --- /dev/null +++ b/cognee/modules/data/processing/document_types/CsvDocument.py @@ -0,0 +1,33 @@ +import io +import csv +from typing import Type + +from cognee.modules.chunking.Chunker import Chunker +from cognee.infrastructure.files.utils.open_data_file import open_data_file +from .Document import Document + + +class CsvDocument(Document): + type: str = "csv" + mime_type: str = "text/csv" + + async def read(self, chunker_cls: Type[Chunker], max_chunk_size: int): + async def get_text(): + async with open_data_file( + self.raw_data_location, mode="r", encoding="utf-8", newline="" + ) as file: + content = file.read() + file_like_obj = io.StringIO(content) + reader = csv.DictReader(file_like_obj) + + for row in reader: + pairs = [f"{str(k)}: {str(v)}" for k, v in row.items()] + row_text = ", ".join(pairs) + if not row_text.strip(): + break + yield row_text + + chunker = chunker_cls(self, max_chunk_size=max_chunk_size, get_text=get_text) + + async for chunk in chunker.read(): + yield chunk diff --git a/cognee/modules/data/processing/document_types/__init__.py b/cognee/modules/data/processing/document_types/__init__.py index 2e862f4ba..133dd53f8 100644 --- a/cognee/modules/data/processing/document_types/__init__.py +++ b/cognee/modules/data/processing/document_types/__init__.py @@ -4,3 +4,4 @@ from .TextDocument import TextDocument from .ImageDocument import ImageDocument from .AudioDocument import AudioDocument from .UnstructuredDocument import UnstructuredDocument +from .CsvDocument import CsvDocument diff --git a/cognee/tasks/chunks/__init__.py b/cognee/tasks/chunks/__init__.py index 22ce96be8..37d4de73e 100644 --- a/cognee/tasks/chunks/__init__.py +++ b/cognee/tasks/chunks/__init__.py @@ -1,4 +1,5 @@ from .chunk_by_word import chunk_by_word from .chunk_by_sentence import chunk_by_sentence from .chunk_by_paragraph import chunk_by_paragraph +from .chunk_by_row import chunk_by_row from .remove_disconnected_chunks import remove_disconnected_chunks diff --git a/cognee/tasks/chunks/chunk_by_row.py b/cognee/tasks/chunks/chunk_by_row.py new file mode 100644 index 000000000..8daf13689 --- /dev/null +++ b/cognee/tasks/chunks/chunk_by_row.py @@ -0,0 +1,94 @@ +from typing import Any, Dict, Iterator +from uuid import NAMESPACE_OID, uuid5 + +from cognee.infrastructure.databases.vector.embeddings import get_embedding_engine + + +def _get_pair_size(pair_text: str) -> int: + """ + Calculate the size of a given text in terms of tokens. + + If an embedding engine's tokenizer is available, count the tokens for the provided word. + If the tokenizer is not available, assume the word counts as one token. + + Parameters: + ----------- + + - pair_text (str): The key:value pair text for which the token size is to be calculated. + + Returns: + -------- + + - int: The number of tokens representing the text, typically an integer, depending + on the tokenizer's output. + """ + embedding_engine = get_embedding_engine() + if embedding_engine.tokenizer: + return embedding_engine.tokenizer.count_tokens(pair_text) + else: + return 3 + + +def chunk_by_row( + data: str, + max_chunk_size, +) -> Iterator[Dict[str, Any]]: + """ + Chunk the input text by row while enabling exact text reconstruction. + + This function divides the given text data into smaller chunks on a line-by-line basis, + ensuring that the size of each chunk is less than or equal to the specified maximum + chunk size. It guarantees that when the generated chunks are concatenated, they + reproduce the original text accurately. The tokenization process is handled by + adapters compatible with the vector engine's embedding model. + + Parameters: + ----------- + + - data (str): The input text to be chunked. + - max_chunk_size: The maximum allowed size for each chunk, in terms of tokens or + words. + """ + current_chunk_list = [] + chunk_index = 0 + current_chunk_size = 0 + + lines = data.split("\n\n") + for line in lines: + pairs_text = line.split(", ") + + for pair_text in pairs_text: + pair_size = _get_pair_size(pair_text) + if current_chunk_size > 0 and (current_chunk_size + pair_size > max_chunk_size): + # Yield current cut chunk + current_chunk = ", ".join(current_chunk_list) + chunk_dict = { + "text": current_chunk, + "chunk_size": current_chunk_size, + "chunk_id": uuid5(NAMESPACE_OID, current_chunk), + "chunk_index": chunk_index, + "cut_type": "row_cut", + } + + yield chunk_dict + + # Start new chunk with current pair text + current_chunk_list = [] + current_chunk_size = 0 + chunk_index += 1 + + current_chunk_list.append(pair_text) + current_chunk_size += pair_size + + # Yield row chunk + current_chunk = ", ".join(current_chunk_list) + if current_chunk: + chunk_dict = { + "text": current_chunk, + "chunk_size": current_chunk_size, + "chunk_id": uuid5(NAMESPACE_OID, current_chunk), + "chunk_index": chunk_index, + "cut_type": "row_end", + } + + yield chunk_dict diff --git a/cognee/tasks/documents/classify_documents.py b/cognee/tasks/documents/classify_documents.py index 9fa512906..e4f13ebd1 100644 --- a/cognee/tasks/documents/classify_documents.py +++ b/cognee/tasks/documents/classify_documents.py @@ -7,6 +7,7 @@ from cognee.modules.data.processing.document_types import ( ImageDocument, TextDocument, UnstructuredDocument, + CsvDocument, ) from cognee.modules.engine.models.node_set import NodeSet from cognee.modules.engine.utils.generate_node_id import generate_node_id @@ -15,6 +16,7 @@ from cognee.tasks.documents.exceptions import WrongDataDocumentInputError EXTENSION_TO_DOCUMENT_CLASS = { "pdf": PdfDocument, # Text documents "txt": TextDocument, + "csv": CsvDocument, "docx": UnstructuredDocument, "doc": UnstructuredDocument, "odt": UnstructuredDocument, diff --git a/cognee/tests/integration/documents/CsvDocument_test.py b/cognee/tests/integration/documents/CsvDocument_test.py new file mode 100644 index 000000000..421bb81bd --- /dev/null +++ b/cognee/tests/integration/documents/CsvDocument_test.py @@ -0,0 +1,70 @@ +import os +import sys +import uuid +import pytest +import pathlib +from unittest.mock import patch + +from cognee.modules.chunking.CsvChunker import CsvChunker +from cognee.modules.data.processing.document_types.CsvDocument import CsvDocument +from cognee.tests.integration.documents.AudioDocument_test import mock_get_embedding_engine +from cognee.tests.integration.documents.async_gen_zip import async_gen_zip + +chunk_by_row_module = sys.modules.get("cognee.tasks.chunks.chunk_by_row") + + +GROUND_TRUTH = { + "chunk_size_10": [ + {"token_count": 9, "len_text": 26, "cut_type": "row_cut", "chunk_index": 0}, + {"token_count": 6, "len_text": 29, "cut_type": "row_end", "chunk_index": 1}, + {"token_count": 9, "len_text": 25, "cut_type": "row_cut", "chunk_index": 2}, + {"token_count": 6, "len_text": 30, "cut_type": "row_end", "chunk_index": 3}, + ], + "chunk_size_128": [ + {"token_count": 15, "len_text": 57, "cut_type": "row_end", "chunk_index": 0}, + {"token_count": 15, "len_text": 57, "cut_type": "row_end", "chunk_index": 1}, + ], +} + + +@pytest.mark.parametrize( + "input_file,chunk_size", + [("example_with_header.csv", 10), ("example_with_header.csv", 128)], +) +@patch.object(chunk_by_row_module, "get_embedding_engine", side_effect=mock_get_embedding_engine) +@pytest.mark.asyncio +async def test_CsvDocument(mock_engine, input_file, chunk_size): + # Define file paths of test data + csv_file_path = os.path.join( + pathlib.Path(__file__).parent.parent.parent, + "test_data", + input_file, + ) + + # Define test documents + csv_document = CsvDocument( + id=uuid.uuid4(), + name="example_with_header.csv", + raw_data_location=csv_file_path, + external_metadata="", + mime_type="text/csv", + ) + + # TEST CSV + ground_truth_key = f"chunk_size_{chunk_size}" + async for ground_truth, row_data in async_gen_zip( + GROUND_TRUTH[ground_truth_key], + csv_document.read(chunker_cls=CsvChunker, max_chunk_size=chunk_size), + ): + assert ground_truth["token_count"] == row_data.chunk_size, ( + f'{ground_truth["token_count"] = } != {row_data.chunk_size = }' + ) + assert ground_truth["len_text"] == len(row_data.text), ( + f'{ground_truth["len_text"] = } != {len(row_data.text) = }' + ) + assert ground_truth["cut_type"] == row_data.cut_type, ( + f'{ground_truth["cut_type"] = } != {row_data.cut_type = }' + ) + assert ground_truth["chunk_index"] == row_data.chunk_index, ( + f'{ground_truth["chunk_index"] = } != {row_data.chunk_index = }' + ) diff --git a/cognee/tests/test_data/example_with_header.csv b/cognee/tests/test_data/example_with_header.csv new file mode 100644 index 000000000..dc900e5ef --- /dev/null +++ b/cognee/tests/test_data/example_with_header.csv @@ -0,0 +1,3 @@ +id,name,age,city,country +1,Eric,30,Beijing,China +2,Joe,35,Berlin,Germany diff --git a/cognee/tests/unit/processing/chunks/chunk_by_row_test.py b/cognee/tests/unit/processing/chunks/chunk_by_row_test.py new file mode 100644 index 000000000..7d6a73a06 --- /dev/null +++ b/cognee/tests/unit/processing/chunks/chunk_by_row_test.py @@ -0,0 +1,52 @@ +from itertools import product + +import numpy as np +import pytest + +from cognee.infrastructure.databases.vector.embeddings import get_embedding_engine +from cognee.tasks.chunks import chunk_by_row + +INPUT_TEXTS = "name: John, age: 30, city: New York, country: USA" +max_chunk_size_vals = [8, 32] + + +@pytest.mark.parametrize( + "input_text,max_chunk_size", + list(product([INPUT_TEXTS], max_chunk_size_vals)), +) +def test_chunk_by_row_isomorphism(input_text, max_chunk_size): + chunks = chunk_by_row(input_text, max_chunk_size) + reconstructed_text = ", ".join([chunk["text"] for chunk in chunks]) + assert reconstructed_text == input_text, ( + f"texts are not identical: {len(input_text) = }, {len(reconstructed_text) = }" + ) + + +@pytest.mark.parametrize( + "input_text,max_chunk_size", + list(product([INPUT_TEXTS], max_chunk_size_vals)), +) +def test_row_chunk_length(input_text, max_chunk_size): + chunks = list(chunk_by_row(data=input_text, max_chunk_size=max_chunk_size)) + embedding_engine = get_embedding_engine() + + chunk_lengths = np.array( + [embedding_engine.tokenizer.count_tokens(chunk["text"]) for chunk in chunks] + ) + + larger_chunks = chunk_lengths[chunk_lengths > max_chunk_size] + assert np.all(chunk_lengths <= max_chunk_size), ( + f"{max_chunk_size = }: {larger_chunks} are too large" + ) + + +@pytest.mark.parametrize( + "input_text,max_chunk_size", + list(product([INPUT_TEXTS], max_chunk_size_vals)), +) +def test_chunk_by_row_chunk_numbering(input_text, max_chunk_size): + chunks = chunk_by_row(data=input_text, max_chunk_size=max_chunk_size) + chunk_indices = np.array([chunk["chunk_index"] for chunk in chunks]) + assert np.all(chunk_indices == np.arange(len(chunk_indices))), ( + f"{chunk_indices = } are not monotonically increasing" + ) From 8566516ceca89a0e85db6aa5ba967f5d8070b2c7 Mon Sep 17 00:00:00 2001 From: EricXiao Date: Wed, 22 Oct 2025 16:59:07 +0800 Subject: [PATCH 02/16] chore: Remove local test code Signed-off-by: EricXiao --- .../loaders/external/advanced_pdf_loader.py | 9 --------- 1 file changed, 9 deletions(-) diff --git a/cognee/infrastructure/loaders/external/advanced_pdf_loader.py b/cognee/infrastructure/loaders/external/advanced_pdf_loader.py index 6d1412b77..4b3ba296a 100644 --- a/cognee/infrastructure/loaders/external/advanced_pdf_loader.py +++ b/cognee/infrastructure/loaders/external/advanced_pdf_loader.py @@ -227,12 +227,3 @@ class AdvancedPdfLoader(LoaderInterface): if value is None: return "" return str(value).replace("\xa0", " ").strip() - - -if __name__ == "__main__": - loader = AdvancedPdfLoader() - asyncio.run( - loader.load( - "/Users/xiaotao/work/cognee/cognee/infrastructure/loaders/external/attention_is_all_you_need.pdf" - ) - ) From a4a9e762465ecaf0dcdb9b0132db7951d11b437c Mon Sep 17 00:00:00 2001 From: Fahad Shoaib Date: Sun, 2 Nov 2025 17:05:03 +0500 Subject: [PATCH 03/16] feat: add ontology endpoint in REST API - Add POST /api/v1/ontologies endpoint for file upload - Add GET /api/v1/ontologies endpoint for listing ontologies - Implement OntologyService for file management and metadata - Integrate ontology_key parameter in cognify endpoint - Update RDFLibOntologyResolver to support file-like objects - Add essential test suite for ontology endpoints --- cognee/api/client.py | 3 + .../v1/cognify/routers/get_cognify_router.py | 31 +++++- cognee/api/v1/ontologies/__init__.py | 4 + cognee/api/v1/ontologies/ontologies.py | 101 ++++++++++++++++++ cognee/api/v1/ontologies/routers/__init__.py | 0 .../ontologies/routers/get_ontology_router.py | 89 +++++++++++++++ .../rdf_xml/RDFLibOntologyResolver.py | 69 +++++++----- cognee/tests/test_ontology_endpoint.py | 89 +++++++++++++++ 8 files changed, 356 insertions(+), 30 deletions(-) create mode 100644 cognee/api/v1/ontologies/__init__.py create mode 100644 cognee/api/v1/ontologies/ontologies.py create mode 100644 cognee/api/v1/ontologies/routers/__init__.py create mode 100644 cognee/api/v1/ontologies/routers/get_ontology_router.py create mode 100644 cognee/tests/test_ontology_endpoint.py diff --git a/cognee/api/client.py b/cognee/api/client.py index 6766c12de..89e9eb2f5 100644 --- a/cognee/api/client.py +++ b/cognee/api/client.py @@ -23,6 +23,7 @@ from cognee.api.v1.settings.routers import get_settings_router from cognee.api.v1.datasets.routers import get_datasets_router from cognee.api.v1.cognify.routers import get_code_pipeline_router, get_cognify_router from cognee.api.v1.search.routers import get_search_router +from cognee.api.v1.ontologies.routers.get_ontology_router import get_ontology_router from cognee.api.v1.memify.routers import get_memify_router from cognee.api.v1.add.routers import get_add_router from cognee.api.v1.delete.routers import get_delete_router @@ -258,6 +259,8 @@ app.include_router( app.include_router(get_datasets_router(), prefix="/api/v1/datasets", tags=["datasets"]) +app.include_router(get_ontology_router(), prefix="/api/v1/ontologies", tags=["ontologies"]) + app.include_router(get_settings_router(), prefix="/api/v1/settings", tags=["settings"]) app.include_router(get_visualize_router(), prefix="/api/v1/visualize", tags=["visualize"]) diff --git a/cognee/api/v1/cognify/routers/get_cognify_router.py b/cognee/api/v1/cognify/routers/get_cognify_router.py index 231bbcd11..246cc6c56 100644 --- a/cognee/api/v1/cognify/routers/get_cognify_router.py +++ b/cognee/api/v1/cognify/routers/get_cognify_router.py @@ -41,6 +41,10 @@ class CognifyPayloadDTO(InDTO): custom_prompt: Optional[str] = Field( default="", description="Custom prompt for entity extraction and graph generation" ) + ontology_key: Optional[str] = Field( + default=None, + description="Reference to previously uploaded ontology" + ) def get_cognify_router() -> APIRouter: @@ -68,6 +72,7 @@ def get_cognify_router() -> APIRouter: - **dataset_ids** (Optional[List[UUID]]): List of existing dataset UUIDs to process. UUIDs allow processing of datasets not owned by the user (if permitted). - **run_in_background** (Optional[bool]): Whether to execute processing asynchronously. Defaults to False (blocking). - **custom_prompt** (Optional[str]): Custom prompt for entity extraction and graph generation. If provided, this prompt will be used instead of the default prompts for knowledge graph extraction. + - **ontology_key** (Optional[str]): Reference to a previously uploaded ontology file to use for knowledge graph construction. ## Response - **Blocking execution**: Complete pipeline run information with entity counts, processing duration, and success/failure status @@ -82,7 +87,8 @@ def get_cognify_router() -> APIRouter: { "datasets": ["research_papers", "documentation"], "run_in_background": false, - "custom_prompt": "Extract entities focusing on technical concepts and their relationships. Identify key technologies, methodologies, and their interconnections." + "custom_prompt": "Extract entities focusing on technical concepts and their relationships. Identify key technologies, methodologies, and their interconnections.", + "ontology_key": "medical_ontology_v1" } ``` @@ -108,13 +114,36 @@ def get_cognify_router() -> APIRouter: ) from cognee.api.v1.cognify import cognify as cognee_cognify + from cognee.api.v1.ontologies.ontologies import OntologyService try: datasets = payload.dataset_ids if payload.dataset_ids else payload.datasets + config_to_use = None + + if payload.ontology_key: + ontology_service = OntologyService() + try: + ontology_content = ontology_service.get_ontology_content(payload.ontology_key, user) + + from cognee.modules.ontology.ontology_config import Config + from cognee.modules.ontology.rdf_xml.RDFLibOntologyResolver import RDFLibOntologyResolver + from io import StringIO + + ontology_stream = StringIO(ontology_content) + config_to_use: Config = { + "ontology_config": { + "ontology_resolver": RDFLibOntologyResolver(ontology_file=ontology_stream) + } + } + except ValueError as e: + return JSONResponse( + status_code=400, content={"error": f"Ontology error: {str(e)}"} + ) cognify_run = await cognee_cognify( datasets, user, + config=config_to_use, run_in_background=payload.run_in_background, custom_prompt=payload.custom_prompt, ) diff --git a/cognee/api/v1/ontologies/__init__.py b/cognee/api/v1/ontologies/__init__.py new file mode 100644 index 000000000..c25064edc --- /dev/null +++ b/cognee/api/v1/ontologies/__init__.py @@ -0,0 +1,4 @@ +from .ontologies import OntologyService +from .routers.get_ontology_router import get_ontology_router + +__all__ = ["OntologyService", "get_ontology_router"] \ No newline at end of file diff --git a/cognee/api/v1/ontologies/ontologies.py b/cognee/api/v1/ontologies/ontologies.py new file mode 100644 index 000000000..fb7f3cd9a --- /dev/null +++ b/cognee/api/v1/ontologies/ontologies.py @@ -0,0 +1,101 @@ +import os +import json +import tempfile +from pathlib import Path +from datetime import datetime, timezone +from typing import Optional +from dataclasses import dataclass + +@dataclass +class OntologyMetadata: + ontology_key: str + filename: str + size_bytes: int + uploaded_at: str + description: Optional[str] = None + +class OntologyService: + def __init__(self): + pass + + @property + def base_dir(self) -> Path: + return Path(tempfile.gettempdir()) / "ontologies" + + def _get_user_dir(self, user_id: str) -> Path: + user_dir = self.base_dir / str(user_id) + user_dir.mkdir(parents=True, exist_ok=True) + return user_dir + + def _get_metadata_path(self, user_dir: Path) -> Path: + return user_dir / "metadata.json" + + def _load_metadata(self, user_dir: Path) -> dict: + metadata_path = self._get_metadata_path(user_dir) + if metadata_path.exists(): + with open(metadata_path, 'r') as f: + return json.load(f) + return {} + + def _save_metadata(self, user_dir: Path, metadata: dict): + metadata_path = self._get_metadata_path(user_dir) + with open(metadata_path, 'w') as f: + json.dump(metadata, f, indent=2) + + async def upload_ontology(self, ontology_key: str, file, user, description: Optional[str] = None) -> OntologyMetadata: + # Validate file format + if not file.filename.lower().endswith('.owl'): + raise ValueError("File must be in .owl format") + + user_dir = self._get_user_dir(str(user.id)) + metadata = self._load_metadata(user_dir) + + # Check for duplicate key + if ontology_key in metadata: + raise ValueError(f"Ontology key '{ontology_key}' already exists") + + # Read file content + content = await file.read() + if len(content) > 10 * 1024 * 1024: # 10MB limit + raise ValueError("File size exceeds 10MB limit") + + # Save file + file_path = user_dir / f"{ontology_key}.owl" + with open(file_path, 'wb') as f: + f.write(content) + + # Update metadata + ontology_metadata = { + "filename": file.filename, + "size_bytes": len(content), + "uploaded_at": datetime.now(timezone.utc).isoformat(), + "description": description + } + metadata[ontology_key] = ontology_metadata + self._save_metadata(user_dir, metadata) + + return OntologyMetadata( + ontology_key=ontology_key, + filename=file.filename, + size_bytes=len(content), + uploaded_at=ontology_metadata["uploaded_at"], + description=description + ) + + def get_ontology_content(self, ontology_key: str, user) -> str: + user_dir = self._get_user_dir(str(user.id)) + metadata = self._load_metadata(user_dir) + + if ontology_key not in metadata: + raise ValueError(f"Ontology key '{ontology_key}' not found") + + file_path = user_dir / f"{ontology_key}.owl" + if not file_path.exists(): + raise ValueError(f"Ontology file for key '{ontology_key}' not found") + + with open(file_path, 'r', encoding='utf-8') as f: + return f.read() + + def list_ontologies(self, user) -> dict: + user_dir = self._get_user_dir(str(user.id)) + return self._load_metadata(user_dir) \ No newline at end of file diff --git a/cognee/api/v1/ontologies/routers/__init__.py b/cognee/api/v1/ontologies/routers/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/cognee/api/v1/ontologies/routers/get_ontology_router.py b/cognee/api/v1/ontologies/routers/get_ontology_router.py new file mode 100644 index 000000000..c171fa7bb --- /dev/null +++ b/cognee/api/v1/ontologies/routers/get_ontology_router.py @@ -0,0 +1,89 @@ +from fastapi import APIRouter, File, Form, UploadFile, Depends, HTTPException +from fastapi.responses import JSONResponse +from typing import Optional + +from cognee.modules.users.models import User +from cognee.modules.users.methods import get_authenticated_user +from cognee.shared.utils import send_telemetry +from cognee import __version__ as cognee_version +from ..ontologies import OntologyService + +def get_ontology_router() -> APIRouter: + router = APIRouter() + ontology_service = OntologyService() + + @router.post("", response_model=dict) + async def upload_ontology( + ontology_key: str = Form(...), + ontology_file: UploadFile = File(...), + description: Optional[str] = Form(None), + user: User = Depends(get_authenticated_user) + ): + """ + Upload an ontology file with a named key for later use in cognify operations. + + ## Request Parameters + - **ontology_key** (str): User-defined identifier for the ontology + - **ontology_file** (UploadFile): OWL format ontology file + - **description** (Optional[str]): Optional description of the ontology + + ## Response + Returns metadata about the uploaded ontology including key, filename, size, and upload timestamp. + + ## Error Codes + - **400 Bad Request**: Invalid file format, duplicate key, file size exceeded + - **500 Internal Server Error**: File system or processing errors + """ + send_telemetry( + "Ontology Upload API Endpoint Invoked", + user.id, + additional_properties={ + "endpoint": "POST /api/v1/ontologies", + "cognee_version": cognee_version, + }, + ) + + try: + result = await ontology_service.upload_ontology( + ontology_key, ontology_file, user, description + ) + return { + "ontology_key": result.ontology_key, + "filename": result.filename, + "size_bytes": result.size_bytes, + "uploaded_at": result.uploaded_at + } + except ValueError as e: + return JSONResponse(status_code=400, content={"error": str(e)}) + except Exception as e: + return JSONResponse(status_code=500, content={"error": str(e)}) + + @router.get("", response_model=dict) + async def list_ontologies( + user: User = Depends(get_authenticated_user) + ): + """ + List all uploaded ontologies for the authenticated user. + + ## Response + Returns a dictionary mapping ontology keys to their metadata including filename, size, and upload timestamp. + + ## Error Codes + - **500 Internal Server Error**: File system or processing errors + """ + send_telemetry( + "Ontology List API Endpoint Invoked", + user.id, + additional_properties={ + "endpoint": "GET /api/v1/ontologies", + "cognee_version": cognee_version, + }, + ) + + try: + metadata = ontology_service.list_ontologies(user) + return metadata + except Exception as e: + return JSONResponse(status_code=500, content={"error": str(e)}) + + return router \ No newline at end of file diff --git a/cognee/modules/ontology/rdf_xml/RDFLibOntologyResolver.py b/cognee/modules/ontology/rdf_xml/RDFLibOntologyResolver.py index 45e32936a..4acc8861b 100644 --- a/cognee/modules/ontology/rdf_xml/RDFLibOntologyResolver.py +++ b/cognee/modules/ontology/rdf_xml/RDFLibOntologyResolver.py @@ -2,7 +2,7 @@ import os import difflib from cognee.shared.logging_utils import get_logger from collections import deque -from typing import List, Tuple, Dict, Optional, Any, Union +from typing import List, Tuple, Dict, Optional, Any, Union, IO from rdflib import Graph, URIRef, RDF, RDFS, OWL from cognee.modules.ontology.exceptions import ( @@ -26,44 +26,55 @@ class RDFLibOntologyResolver(BaseOntologyResolver): def __init__( self, - ontology_file: Optional[Union[str, List[str]]] = None, + ontology_file: Optional[Union[str, List[str], IO]] = None, matching_strategy: Optional[MatchingStrategy] = None, ) -> None: super().__init__(matching_strategy) self.ontology_file = ontology_file try: - files_to_load = [] + self.graph = None if ontology_file is not None: - if isinstance(ontology_file, str): - files_to_load = [ontology_file] - elif isinstance(ontology_file, list): - files_to_load = ontology_file + if hasattr(ontology_file, "read"): + self.graph = Graph() + content = ontology_file.read() + self.graph.parse(data=content, format="xml") + logger.info("Ontology loaded successfully from file object") else: - raise ValueError( - f"ontology_file must be a string, list of strings, or None. Got: {type(ontology_file)}" - ) - - if files_to_load: - self.graph = Graph() - loaded_files = [] - for file_path in files_to_load: - if os.path.exists(file_path): - self.graph.parse(file_path) - loaded_files.append(file_path) - logger.info("Ontology loaded successfully from file: %s", file_path) + files_to_load = [] + if isinstance(ontology_file, str): + files_to_load = [ontology_file] + elif isinstance(ontology_file, list): + files_to_load = ontology_file else: - logger.warning( - "Ontology file '%s' not found. Skipping this file.", - file_path, + raise ValueError( + f"ontology_file must be a string, list of strings, file-like object, or None. Got: {type(ontology_file)}" ) - if not loaded_files: - logger.info( - "No valid ontology files found. No owl ontology will be attached to the graph." - ) - self.graph = None - else: - logger.info("Total ontology files loaded: %d", len(loaded_files)) + if files_to_load: + self.graph = Graph() + loaded_files = [] + for file_path in files_to_load: + if os.path.exists(file_path): + self.graph.parse(file_path) + loaded_files.append(file_path) + logger.info("Ontology loaded successfully from file: %s", file_path) + else: + logger.warning( + "Ontology file '%s' not found. Skipping this file.", + file_path, + ) + + if not loaded_files: + logger.info( + "No valid ontology files found. No owl ontology will be attached to the graph." + ) + self.graph = None + else: + logger.info("Total ontology files loaded: %d", len(loaded_files)) + else: + logger.info( + "No ontology file provided. No owl ontology will be attached to the graph." + ) else: logger.info( "No ontology file provided. No owl ontology will be attached to the graph." diff --git a/cognee/tests/test_ontology_endpoint.py b/cognee/tests/test_ontology_endpoint.py new file mode 100644 index 000000000..4849f8649 --- /dev/null +++ b/cognee/tests/test_ontology_endpoint.py @@ -0,0 +1,89 @@ +import pytest +import uuid +from fastapi.testclient import TestClient +from unittest.mock import patch, Mock, AsyncMock +from types import SimpleNamespace +import importlib +from cognee.api.client import app + +gau_mod = importlib.import_module("cognee.modules.users.methods.get_authenticated_user") + +@pytest.fixture +def client(): + return TestClient(app) + +@pytest.fixture +def mock_user(): + user = Mock() + user.id = "test-user-123" + return user + +@pytest.fixture +def mock_default_user(): + """Mock default user for testing.""" + return SimpleNamespace( + id=uuid.uuid4(), + email="default@example.com", + is_active=True, + tenant_id=uuid.uuid4() + ) + +@patch.object(gau_mod, "get_default_user", new_callable=AsyncMock) +def test_upload_ontology_success(mock_get_default_user, client, mock_default_user): + """Test successful ontology upload""" + mock_get_default_user.return_value = mock_default_user + ontology_content = b"" + unique_key = f"test_ontology_{uuid.uuid4().hex[:8]}" + + response = client.post( + "/api/v1/ontologies", + files={"ontology_file": ("test.owl", ontology_content)}, + data={"ontology_key": unique_key, "description": "Test"} + ) + + assert response.status_code == 200 + data = response.json() + assert data["ontology_key"] == unique_key + assert "uploaded_at" in data + +@patch.object(gau_mod, "get_default_user", new_callable=AsyncMock) +def test_upload_ontology_invalid_file(mock_get_default_user, client, mock_default_user): + """Test 400 response for non-.owl files""" + mock_get_default_user.return_value = mock_default_user + unique_key = f"test_ontology_{uuid.uuid4().hex[:8]}" + response = client.post( + "/api/v1/ontologies", + files={"ontology_file": ("test.txt", b"not xml")}, + data={"ontology_key": unique_key} + ) + assert response.status_code == 400 + +@patch.object(gau_mod, "get_default_user", new_callable=AsyncMock) +def test_upload_ontology_missing_data(mock_get_default_user, client, mock_default_user): + """Test 400 response for missing file or key""" + mock_get_default_user.return_value = mock_default_user + # Missing file + response = client.post("/api/v1/ontologies", data={"ontology_key": "test"}) + assert response.status_code == 400 + + # Missing key + response = client.post("/api/v1/ontologies", files={"ontology_file": ("test.owl", b"xml")}) + assert response.status_code == 400 + +@patch.object(gau_mod, "get_default_user", new_callable=AsyncMock) +def test_upload_ontology_unauthorized(mock_get_default_user, client, mock_default_user): + """Test behavior when default user is provided (no explicit authentication)""" + unique_key = f"test_ontology_{uuid.uuid4().hex[:8]}" + mock_get_default_user.return_value = mock_default_user + response = client.post( + "/api/v1/ontologies", + files={"ontology_file": ("test.owl", b"")}, + data={"ontology_key": unique_key} + ) + + # The current system provides a default user when no explicit authentication is given + # This test verifies the system works with conditional authentication + assert response.status_code == 200 + data = response.json() + assert data["ontology_key"] == unique_key + assert "uploaded_at" in data \ No newline at end of file From 79bd2b2576b913528feb92ca6832242133bf9822 Mon Sep 17 00:00:00 2001 From: hajdul88 <52442977+hajdul88@users.noreply.github.com> Date: Thu, 6 Nov 2025 09:48:01 +0100 Subject: [PATCH 04/16] chore: fixes ruff formatting --- .../v1/cognify/routers/get_cognify_router.py | 15 ++++++++---- cognee/api/v1/ontologies/__init__.py | 2 +- cognee/api/v1/ontologies/ontologies.py | 22 ++++++++++------- .../ontologies/routers/get_ontology_router.py | 11 ++++----- cognee/tests/test_ontology_endpoint.py | 24 ++++++++++++------- 5 files changed, 44 insertions(+), 30 deletions(-) diff --git a/cognee/api/v1/cognify/routers/get_cognify_router.py b/cognee/api/v1/cognify/routers/get_cognify_router.py index 246cc6c56..252ffe7bf 100644 --- a/cognee/api/v1/cognify/routers/get_cognify_router.py +++ b/cognee/api/v1/cognify/routers/get_cognify_router.py @@ -42,8 +42,7 @@ class CognifyPayloadDTO(InDTO): default="", description="Custom prompt for entity extraction and graph generation" ) ontology_key: Optional[str] = Field( - default=None, - description="Reference to previously uploaded ontology" + default=None, description="Reference to previously uploaded ontology" ) @@ -123,16 +122,22 @@ def get_cognify_router() -> APIRouter: if payload.ontology_key: ontology_service = OntologyService() try: - ontology_content = ontology_service.get_ontology_content(payload.ontology_key, user) + ontology_content = ontology_service.get_ontology_content( + payload.ontology_key, user + ) from cognee.modules.ontology.ontology_config import Config - from cognee.modules.ontology.rdf_xml.RDFLibOntologyResolver import RDFLibOntologyResolver + from cognee.modules.ontology.rdf_xml.RDFLibOntologyResolver import ( + RDFLibOntologyResolver, + ) from io import StringIO ontology_stream = StringIO(ontology_content) config_to_use: Config = { "ontology_config": { - "ontology_resolver": RDFLibOntologyResolver(ontology_file=ontology_stream) + "ontology_resolver": RDFLibOntologyResolver( + ontology_file=ontology_stream + ) } } except ValueError as e: diff --git a/cognee/api/v1/ontologies/__init__.py b/cognee/api/v1/ontologies/__init__.py index c25064edc..b90d46c3d 100644 --- a/cognee/api/v1/ontologies/__init__.py +++ b/cognee/api/v1/ontologies/__init__.py @@ -1,4 +1,4 @@ from .ontologies import OntologyService from .routers.get_ontology_router import get_ontology_router -__all__ = ["OntologyService", "get_ontology_router"] \ No newline at end of file +__all__ = ["OntologyService", "get_ontology_router"] diff --git a/cognee/api/v1/ontologies/ontologies.py b/cognee/api/v1/ontologies/ontologies.py index fb7f3cd9a..6bfb7658e 100644 --- a/cognee/api/v1/ontologies/ontologies.py +++ b/cognee/api/v1/ontologies/ontologies.py @@ -6,6 +6,7 @@ from datetime import datetime, timezone from typing import Optional from dataclasses import dataclass + @dataclass class OntologyMetadata: ontology_key: str @@ -14,6 +15,7 @@ class OntologyMetadata: uploaded_at: str description: Optional[str] = None + class OntologyService: def __init__(self): pass @@ -33,18 +35,20 @@ class OntologyService: def _load_metadata(self, user_dir: Path) -> dict: metadata_path = self._get_metadata_path(user_dir) if metadata_path.exists(): - with open(metadata_path, 'r') as f: + with open(metadata_path, "r") as f: return json.load(f) return {} def _save_metadata(self, user_dir: Path, metadata: dict): metadata_path = self._get_metadata_path(user_dir) - with open(metadata_path, 'w') as f: + with open(metadata_path, "w") as f: json.dump(metadata, f, indent=2) - async def upload_ontology(self, ontology_key: str, file, user, description: Optional[str] = None) -> OntologyMetadata: + async def upload_ontology( + self, ontology_key: str, file, user, description: Optional[str] = None + ) -> OntologyMetadata: # Validate file format - if not file.filename.lower().endswith('.owl'): + if not file.filename.lower().endswith(".owl"): raise ValueError("File must be in .owl format") user_dir = self._get_user_dir(str(user.id)) @@ -61,7 +65,7 @@ class OntologyService: # Save file file_path = user_dir / f"{ontology_key}.owl" - with open(file_path, 'wb') as f: + with open(file_path, "wb") as f: f.write(content) # Update metadata @@ -69,7 +73,7 @@ class OntologyService: "filename": file.filename, "size_bytes": len(content), "uploaded_at": datetime.now(timezone.utc).isoformat(), - "description": description + "description": description, } metadata[ontology_key] = ontology_metadata self._save_metadata(user_dir, metadata) @@ -79,7 +83,7 @@ class OntologyService: filename=file.filename, size_bytes=len(content), uploaded_at=ontology_metadata["uploaded_at"], - description=description + description=description, ) def get_ontology_content(self, ontology_key: str, user) -> str: @@ -93,9 +97,9 @@ class OntologyService: if not file_path.exists(): raise ValueError(f"Ontology file for key '{ontology_key}' not found") - with open(file_path, 'r', encoding='utf-8') as f: + with open(file_path, "r", encoding="utf-8") as f: return f.read() def list_ontologies(self, user) -> dict: user_dir = self._get_user_dir(str(user.id)) - return self._load_metadata(user_dir) \ No newline at end of file + return self._load_metadata(user_dir) diff --git a/cognee/api/v1/ontologies/routers/get_ontology_router.py b/cognee/api/v1/ontologies/routers/get_ontology_router.py index c171fa7bb..f5c51ba21 100644 --- a/cognee/api/v1/ontologies/routers/get_ontology_router.py +++ b/cognee/api/v1/ontologies/routers/get_ontology_router.py @@ -8,6 +8,7 @@ from cognee.shared.utils import send_telemetry from cognee import __version__ as cognee_version from ..ontologies import OntologyService + def get_ontology_router() -> APIRouter: router = APIRouter() ontology_service = OntologyService() @@ -17,7 +18,7 @@ def get_ontology_router() -> APIRouter: ontology_key: str = Form(...), ontology_file: UploadFile = File(...), description: Optional[str] = Form(None), - user: User = Depends(get_authenticated_user) + user: User = Depends(get_authenticated_user), ): """ Upload an ontology file with a named key for later use in cognify operations. @@ -51,7 +52,7 @@ def get_ontology_router() -> APIRouter: "ontology_key": result.ontology_key, "filename": result.filename, "size_bytes": result.size_bytes, - "uploaded_at": result.uploaded_at + "uploaded_at": result.uploaded_at, } except ValueError as e: return JSONResponse(status_code=400, content={"error": str(e)}) @@ -59,9 +60,7 @@ def get_ontology_router() -> APIRouter: return JSONResponse(status_code=500, content={"error": str(e)}) @router.get("", response_model=dict) - async def list_ontologies( - user: User = Depends(get_authenticated_user) - ): + async def list_ontologies(user: User = Depends(get_authenticated_user)): """ List all uploaded ontologies for the authenticated user. @@ -86,4 +85,4 @@ def get_ontology_router() -> APIRouter: except Exception as e: return JSONResponse(status_code=500, content={"error": str(e)}) - return router \ No newline at end of file + return router diff --git a/cognee/tests/test_ontology_endpoint.py b/cognee/tests/test_ontology_endpoint.py index 4849f8649..b5cedfafe 100644 --- a/cognee/tests/test_ontology_endpoint.py +++ b/cognee/tests/test_ontology_endpoint.py @@ -8,37 +8,40 @@ from cognee.api.client import app gau_mod = importlib.import_module("cognee.modules.users.methods.get_authenticated_user") + @pytest.fixture def client(): return TestClient(app) + @pytest.fixture def mock_user(): user = Mock() user.id = "test-user-123" return user + @pytest.fixture def mock_default_user(): """Mock default user for testing.""" return SimpleNamespace( - id=uuid.uuid4(), - email="default@example.com", - is_active=True, - tenant_id=uuid.uuid4() + id=uuid.uuid4(), email="default@example.com", is_active=True, tenant_id=uuid.uuid4() ) + @patch.object(gau_mod, "get_default_user", new_callable=AsyncMock) def test_upload_ontology_success(mock_get_default_user, client, mock_default_user): """Test successful ontology upload""" mock_get_default_user.return_value = mock_default_user - ontology_content = b"" + ontology_content = ( + b"" + ) unique_key = f"test_ontology_{uuid.uuid4().hex[:8]}" response = client.post( "/api/v1/ontologies", files={"ontology_file": ("test.owl", ontology_content)}, - data={"ontology_key": unique_key, "description": "Test"} + data={"ontology_key": unique_key, "description": "Test"}, ) assert response.status_code == 200 @@ -46,6 +49,7 @@ def test_upload_ontology_success(mock_get_default_user, client, mock_default_use assert data["ontology_key"] == unique_key assert "uploaded_at" in data + @patch.object(gau_mod, "get_default_user", new_callable=AsyncMock) def test_upload_ontology_invalid_file(mock_get_default_user, client, mock_default_user): """Test 400 response for non-.owl files""" @@ -54,10 +58,11 @@ def test_upload_ontology_invalid_file(mock_get_default_user, client, mock_defaul response = client.post( "/api/v1/ontologies", files={"ontology_file": ("test.txt", b"not xml")}, - data={"ontology_key": unique_key} + data={"ontology_key": unique_key}, ) assert response.status_code == 400 + @patch.object(gau_mod, "get_default_user", new_callable=AsyncMock) def test_upload_ontology_missing_data(mock_get_default_user, client, mock_default_user): """Test 400 response for missing file or key""" @@ -70,6 +75,7 @@ def test_upload_ontology_missing_data(mock_get_default_user, client, mock_defaul response = client.post("/api/v1/ontologies", files={"ontology_file": ("test.owl", b"xml")}) assert response.status_code == 400 + @patch.object(gau_mod, "get_default_user", new_callable=AsyncMock) def test_upload_ontology_unauthorized(mock_get_default_user, client, mock_default_user): """Test behavior when default user is provided (no explicit authentication)""" @@ -78,7 +84,7 @@ def test_upload_ontology_unauthorized(mock_get_default_user, client, mock_defaul response = client.post( "/api/v1/ontologies", files={"ontology_file": ("test.owl", b"")}, - data={"ontology_key": unique_key} + data={"ontology_key": unique_key}, ) # The current system provides a default user when no explicit authentication is given @@ -86,4 +92,4 @@ def test_upload_ontology_unauthorized(mock_get_default_user, client, mock_defaul assert response.status_code == 200 data = response.json() assert data["ontology_key"] == unique_key - assert "uploaded_at" in data \ No newline at end of file + assert "uploaded_at" in data From a058250c95390df84f1c44419cff8adf3a4e8269 Mon Sep 17 00:00:00 2001 From: Boris Arzentar Date: Thu, 6 Nov 2025 13:03:11 +0100 Subject: [PATCH 05/16] fix: add cognee to the local run environment --- cognee/modules/notebooks/operations/run_in_local_sandbox.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/cognee/modules/notebooks/operations/run_in_local_sandbox.py b/cognee/modules/notebooks/operations/run_in_local_sandbox.py index 071deafb7..46499186e 100644 --- a/cognee/modules/notebooks/operations/run_in_local_sandbox.py +++ b/cognee/modules/notebooks/operations/run_in_local_sandbox.py @@ -2,6 +2,8 @@ import io import sys import traceback +import cognee + def wrap_in_async_handler(user_code: str) -> str: return ( @@ -34,6 +36,7 @@ def run_in_local_sandbox(code, environment=None, loop=None): environment["print"] = customPrintFunction environment["running_loop"] = loop + environment["cognee"] = cognee try: exec(code, environment) From 2337d36f7b3968cfeff06b00613f7464c8d0ca93 Mon Sep 17 00:00:00 2001 From: Andrej Milicevic Date: Thu, 13 Nov 2025 18:25:07 +0100 Subject: [PATCH 06/16] feat: add variable to control instructor mode --- cognee/infrastructure/llm/config.py | 2 ++ .../litellm_instructor/llm/anthropic/adapter.py | 8 +++++++- .../litellm_instructor/llm/gemini/adapter.py | 12 +++++++++++- .../llm/generic_llm_api/adapter.py | 12 +++++++++++- .../litellm_instructor/llm/mistral/adapter.py | 8 +++++++- .../litellm_instructor/llm/ollama/adapter.py | 12 +++++++++++- .../litellm_instructor/llm/openai/adapter.py | 11 +++++++++-- 7 files changed, 58 insertions(+), 7 deletions(-) diff --git a/cognee/infrastructure/llm/config.py b/cognee/infrastructure/llm/config.py index 8fd196eaf..c87054ff6 100644 --- a/cognee/infrastructure/llm/config.py +++ b/cognee/infrastructure/llm/config.py @@ -38,6 +38,7 @@ class LLMConfig(BaseSettings): """ structured_output_framework: str = "instructor" + llm_instructor_mode: Optional[str] = None llm_provider: str = "openai" llm_model: str = "openai/gpt-5-mini" llm_endpoint: str = "" @@ -181,6 +182,7 @@ class LLMConfig(BaseSettings): instance. """ return { + "llm_instructor_mode": self.llm_instructor_mode, "provider": self.llm_provider, "model": self.llm_model, "endpoint": self.llm_endpoint, diff --git a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/anthropic/adapter.py b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/anthropic/adapter.py index bf19d6e86..6fb78718e 100644 --- a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/anthropic/adapter.py +++ b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/anthropic/adapter.py @@ -28,13 +28,19 @@ class AnthropicAdapter(LLMInterface): name = "Anthropic" model: str + default_instructor_mode = "anthropic_tools" def __init__(self, max_completion_tokens: int, model: str = None): import anthropic + config_instructor_mode = get_llm_config().llm_instructor_mode + instructor_mode = ( + config_instructor_mode if config_instructor_mode else self.default_instructor_mode + ) + self.aclient = instructor.patch( create=anthropic.AsyncAnthropic(api_key=get_llm_config().llm_api_key).messages.create, - mode=instructor.Mode.ANTHROPIC_TOOLS, + mode=instructor.Mode(instructor_mode), ) self.model = model diff --git a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/gemini/adapter.py b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/gemini/adapter.py index 1187e0cad..68dddc7b7 100644 --- a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/gemini/adapter.py +++ b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/gemini/adapter.py @@ -41,6 +41,7 @@ class GeminiAdapter(LLMInterface): name: str model: str api_key: str + default_instructor_mode = "json_mode" def __init__( self, @@ -63,7 +64,16 @@ class GeminiAdapter(LLMInterface): self.fallback_api_key = fallback_api_key self.fallback_endpoint = fallback_endpoint - self.aclient = instructor.from_litellm(litellm.acompletion, mode=instructor.Mode.JSON) + from cognee.infrastructure.llm.config import get_llm_config + + config_instructor_mode = get_llm_config().llm_instructor_mode + instructor_mode = ( + config_instructor_mode if config_instructor_mode else self.default_instructor_mode + ) + + self.aclient = instructor.from_litellm( + litellm.acompletion, mode=instructor.Mode(instructor_mode) + ) @retry( stop=stop_after_delay(128), diff --git a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/generic_llm_api/adapter.py b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/generic_llm_api/adapter.py index 8bbbaa2cc..ea32dced1 100644 --- a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/generic_llm_api/adapter.py +++ b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/generic_llm_api/adapter.py @@ -41,6 +41,7 @@ class GenericAPIAdapter(LLMInterface): name: str model: str api_key: str + default_instructor_mode = "json_mode" def __init__( self, @@ -63,7 +64,16 @@ class GenericAPIAdapter(LLMInterface): self.fallback_api_key = fallback_api_key self.fallback_endpoint = fallback_endpoint - self.aclient = instructor.from_litellm(litellm.acompletion, mode=instructor.Mode.JSON) + from cognee.infrastructure.llm.config import get_llm_config + + config_instructor_mode = get_llm_config().llm_instructor_mode + instructor_mode = ( + config_instructor_mode if config_instructor_mode else self.default_instructor_mode + ) + + self.aclient = instructor.from_litellm( + litellm.acompletion, mode=instructor.Mode(instructor_mode) + ) @retry( stop=stop_after_delay(128), diff --git a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/mistral/adapter.py b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/mistral/adapter.py index 78a3cbff5..bed88ce3c 100644 --- a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/mistral/adapter.py +++ b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/mistral/adapter.py @@ -37,6 +37,7 @@ class MistralAdapter(LLMInterface): model: str api_key: str max_completion_tokens: int + default_instructor_mode = "mistral_tools" def __init__(self, api_key: str, model: str, max_completion_tokens: int, endpoint: str = None): from mistralai import Mistral @@ -44,9 +45,14 @@ class MistralAdapter(LLMInterface): self.model = model self.max_completion_tokens = max_completion_tokens + config_instructor_mode = get_llm_config().llm_instructor_mode + instructor_mode = ( + config_instructor_mode if config_instructor_mode else self.default_instructor_mode + ) + self.aclient = instructor.from_litellm( litellm.acompletion, - mode=instructor.Mode.MISTRAL_TOOLS, + mode=instructor.Mode(instructor_mode), api_key=get_llm_config().llm_api_key, ) diff --git a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/ollama/adapter.py b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/ollama/adapter.py index 9c3d185aa..aa24a7911 100644 --- a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/ollama/adapter.py +++ b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/ollama/adapter.py @@ -42,6 +42,8 @@ class OllamaAPIAdapter(LLMInterface): - aclient """ + default_instructor_mode = "json_mode" + def __init__( self, endpoint: str, api_key: str, model: str, name: str, max_completion_tokens: int ): @@ -51,8 +53,16 @@ class OllamaAPIAdapter(LLMInterface): self.endpoint = endpoint self.max_completion_tokens = max_completion_tokens + from cognee.infrastructure.llm.config import get_llm_config + + config_instructor_mode = get_llm_config().llm_instructor_mode + instructor_mode = ( + config_instructor_mode if config_instructor_mode else self.default_instructor_mode + ) + self.aclient = instructor.from_openai( - OpenAI(base_url=self.endpoint, api_key=self.api_key), mode=instructor.Mode.JSON + OpenAI(base_url=self.endpoint, api_key=self.api_key), + mode=instructor.Mode(instructor_mode), ) @retry( diff --git a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/openai/adapter.py b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/openai/adapter.py index 305b426b8..69367602d 100644 --- a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/openai/adapter.py +++ b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/openai/adapter.py @@ -56,6 +56,7 @@ class OpenAIAdapter(LLMInterface): model: str api_key: str api_version: str + default_instructor_mode = "json_schema_mode" MAX_RETRIES = 5 @@ -74,14 +75,20 @@ class OpenAIAdapter(LLMInterface): fallback_api_key: str = None, fallback_endpoint: str = None, ): + from cognee.infrastructure.llm.config import get_llm_config + + config_instructor_mode = get_llm_config().llm_instructor_mode + instructor_mode = ( + config_instructor_mode if config_instructor_mode else self.default_instructor_mode + ) # TODO: With gpt5 series models OpenAI expects JSON_SCHEMA as a mode for structured outputs. # Make sure all new gpt models will work with this mode as well. if "gpt-5" in model: self.aclient = instructor.from_litellm( - litellm.acompletion, mode=instructor.Mode.JSON_SCHEMA + litellm.acompletion, mode=instructor.Mode(instructor_mode) ) self.client = instructor.from_litellm( - litellm.completion, mode=instructor.Mode.JSON_SCHEMA + litellm.completion, mode=instructor.Mode(instructor_mode) ) else: self.aclient = instructor.from_litellm(litellm.acompletion) From 661c194f97df5053f70a52d1638c77c23e0d50e3 Mon Sep 17 00:00:00 2001 From: EricXiao Date: Fri, 14 Nov 2025 15:21:47 +0800 Subject: [PATCH 07/16] fix: Resolve issue with csv suffix classification Signed-off-by: EricXiao --- cognee/infrastructure/files/utils/guess_file_type.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/cognee/infrastructure/files/utils/guess_file_type.py b/cognee/infrastructure/files/utils/guess_file_type.py index 78b20c93d..4bc96fe80 100644 --- a/cognee/infrastructure/files/utils/guess_file_type.py +++ b/cognee/infrastructure/files/utils/guess_file_type.py @@ -55,6 +55,10 @@ def guess_file_type(file: BinaryIO, name: Optional[str] = None) -> filetype.Type file_type = Type("text/plain", "txt") return file_type + if ext in [".csv"]: + file_type = Type("text/csv", "csv") + return file_type + file_type = filetype.guess(file) # If file type could not be determined consider it a plain text file as they don't have magic number encoding From 205f5a9e0c6d0fb72cc94fa20ce2ba814ebef0d5 Mon Sep 17 00:00:00 2001 From: Andrej Milicevic Date: Fri, 14 Nov 2025 11:05:39 +0100 Subject: [PATCH 08/16] fix: Fix based on PR comments --- .../litellm_instructor/llm/anthropic/adapter.py | 9 +++------ .../litellm_instructor/llm/gemini/adapter.py | 10 +++------- .../llm/generic_llm_api/adapter.py | 10 +++------- .../litellm_instructor/llm/get_llm_client.py | 9 ++++++++- .../litellm_instructor/llm/mistral/adapter.py | 16 ++++++++++------ .../litellm_instructor/llm/ollama/adapter.py | 17 +++++++++-------- .../litellm_instructor/llm/openai/adapter.py | 12 ++++-------- 7 files changed, 40 insertions(+), 43 deletions(-) diff --git a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/anthropic/adapter.py b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/anthropic/adapter.py index 6fb78718e..dbf0dfbea 100644 --- a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/anthropic/adapter.py +++ b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/anthropic/adapter.py @@ -30,17 +30,14 @@ class AnthropicAdapter(LLMInterface): model: str default_instructor_mode = "anthropic_tools" - def __init__(self, max_completion_tokens: int, model: str = None): + def __init__(self, max_completion_tokens: int, model: str = None, instructor_mode: str = None): import anthropic - config_instructor_mode = get_llm_config().llm_instructor_mode - instructor_mode = ( - config_instructor_mode if config_instructor_mode else self.default_instructor_mode - ) + self.instructor_mode = instructor_mode if instructor_mode else self.default_instructor_mode self.aclient = instructor.patch( create=anthropic.AsyncAnthropic(api_key=get_llm_config().llm_api_key).messages.create, - mode=instructor.Mode(instructor_mode), + mode=instructor.Mode(self.instructor_mode), ) self.model = model diff --git a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/gemini/adapter.py b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/gemini/adapter.py index 68dddc7b7..226f291d7 100644 --- a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/gemini/adapter.py +++ b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/gemini/adapter.py @@ -50,6 +50,7 @@ class GeminiAdapter(LLMInterface): model: str, api_version: str, max_completion_tokens: int, + instructor_mode: str = None, fallback_model: str = None, fallback_api_key: str = None, fallback_endpoint: str = None, @@ -64,15 +65,10 @@ class GeminiAdapter(LLMInterface): self.fallback_api_key = fallback_api_key self.fallback_endpoint = fallback_endpoint - from cognee.infrastructure.llm.config import get_llm_config - - config_instructor_mode = get_llm_config().llm_instructor_mode - instructor_mode = ( - config_instructor_mode if config_instructor_mode else self.default_instructor_mode - ) + self.instructor_mode = instructor_mode if instructor_mode else self.default_instructor_mode self.aclient = instructor.from_litellm( - litellm.acompletion, mode=instructor.Mode(instructor_mode) + litellm.acompletion, mode=instructor.Mode(self.instructor_mode) ) @retry( diff --git a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/generic_llm_api/adapter.py b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/generic_llm_api/adapter.py index ea32dced1..9d7f25fc5 100644 --- a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/generic_llm_api/adapter.py +++ b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/generic_llm_api/adapter.py @@ -50,6 +50,7 @@ class GenericAPIAdapter(LLMInterface): model: str, name: str, max_completion_tokens: int, + instructor_mode: str = None, fallback_model: str = None, fallback_api_key: str = None, fallback_endpoint: str = None, @@ -64,15 +65,10 @@ class GenericAPIAdapter(LLMInterface): self.fallback_api_key = fallback_api_key self.fallback_endpoint = fallback_endpoint - from cognee.infrastructure.llm.config import get_llm_config - - config_instructor_mode = get_llm_config().llm_instructor_mode - instructor_mode = ( - config_instructor_mode if config_instructor_mode else self.default_instructor_mode - ) + self.instructor_mode = instructor_mode if instructor_mode else self.default_instructor_mode self.aclient = instructor.from_litellm( - litellm.acompletion, mode=instructor.Mode(instructor_mode) + litellm.acompletion, mode=instructor.Mode(self.instructor_mode) ) @retry( diff --git a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/get_llm_client.py b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/get_llm_client.py index c7dcecc56..537eda1b2 100644 --- a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/get_llm_client.py +++ b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/get_llm_client.py @@ -81,6 +81,7 @@ def get_llm_client(raise_api_key_error: bool = True): model=llm_config.llm_model, transcription_model=llm_config.transcription_model, max_completion_tokens=max_completion_tokens, + instructor_mode=llm_config.llm_instructor_mode, streaming=llm_config.llm_streaming, fallback_api_key=llm_config.fallback_api_key, fallback_endpoint=llm_config.fallback_endpoint, @@ -101,6 +102,7 @@ def get_llm_client(raise_api_key_error: bool = True): llm_config.llm_model, "Ollama", max_completion_tokens=max_completion_tokens, + instructor_mode=llm_config.llm_instructor_mode, ) elif provider == LLMProvider.ANTHROPIC: @@ -109,7 +111,9 @@ def get_llm_client(raise_api_key_error: bool = True): ) return AnthropicAdapter( - max_completion_tokens=max_completion_tokens, model=llm_config.llm_model + max_completion_tokens=max_completion_tokens, + model=llm_config.llm_model, + instructor_mode=llm_config.llm_instructor_mode, ) elif provider == LLMProvider.CUSTOM: @@ -126,6 +130,7 @@ def get_llm_client(raise_api_key_error: bool = True): llm_config.llm_model, "Custom", max_completion_tokens=max_completion_tokens, + instructor_mode=llm_config.llm_instructor_mode, fallback_api_key=llm_config.fallback_api_key, fallback_endpoint=llm_config.fallback_endpoint, fallback_model=llm_config.fallback_model, @@ -145,6 +150,7 @@ def get_llm_client(raise_api_key_error: bool = True): max_completion_tokens=max_completion_tokens, endpoint=llm_config.llm_endpoint, api_version=llm_config.llm_api_version, + instructor_mode=llm_config.llm_instructor_mode, ) elif provider == LLMProvider.MISTRAL: @@ -160,6 +166,7 @@ def get_llm_client(raise_api_key_error: bool = True): model=llm_config.llm_model, max_completion_tokens=max_completion_tokens, endpoint=llm_config.llm_endpoint, + instructor_mode=llm_config.llm_instructor_mode, ) elif provider == LLMProvider.MISTRAL: diff --git a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/mistral/adapter.py b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/mistral/adapter.py index bed88ce3c..355cdae0b 100644 --- a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/mistral/adapter.py +++ b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/mistral/adapter.py @@ -39,20 +39,24 @@ class MistralAdapter(LLMInterface): max_completion_tokens: int default_instructor_mode = "mistral_tools" - def __init__(self, api_key: str, model: str, max_completion_tokens: int, endpoint: str = None): + def __init__( + self, + api_key: str, + model: str, + max_completion_tokens: int, + endpoint: str = None, + instructor_mode: str = None, + ): from mistralai import Mistral self.model = model self.max_completion_tokens = max_completion_tokens - config_instructor_mode = get_llm_config().llm_instructor_mode - instructor_mode = ( - config_instructor_mode if config_instructor_mode else self.default_instructor_mode - ) + self.instructor_mode = instructor_mode if instructor_mode else self.default_instructor_mode self.aclient = instructor.from_litellm( litellm.acompletion, - mode=instructor.Mode(instructor_mode), + mode=instructor.Mode(self.instructor_mode), api_key=get_llm_config().llm_api_key, ) diff --git a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/ollama/adapter.py b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/ollama/adapter.py index aa24a7911..aabd19867 100644 --- a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/ollama/adapter.py +++ b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/ollama/adapter.py @@ -45,7 +45,13 @@ class OllamaAPIAdapter(LLMInterface): default_instructor_mode = "json_mode" def __init__( - self, endpoint: str, api_key: str, model: str, name: str, max_completion_tokens: int + self, + endpoint: str, + api_key: str, + model: str, + name: str, + max_completion_tokens: int, + instructor_mode: str = None, ): self.name = name self.model = model @@ -53,16 +59,11 @@ class OllamaAPIAdapter(LLMInterface): self.endpoint = endpoint self.max_completion_tokens = max_completion_tokens - from cognee.infrastructure.llm.config import get_llm_config - - config_instructor_mode = get_llm_config().llm_instructor_mode - instructor_mode = ( - config_instructor_mode if config_instructor_mode else self.default_instructor_mode - ) + self.instructor_mode = instructor_mode if instructor_mode else self.default_instructor_mode self.aclient = instructor.from_openai( OpenAI(base_url=self.endpoint, api_key=self.api_key), - mode=instructor.Mode(instructor_mode), + mode=instructor.Mode(self.instructor_mode), ) @retry( diff --git a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/openai/adapter.py b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/openai/adapter.py index 69367602d..778c8eec7 100644 --- a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/openai/adapter.py +++ b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/openai/adapter.py @@ -70,25 +70,21 @@ class OpenAIAdapter(LLMInterface): model: str, transcription_model: str, max_completion_tokens: int, + instructor_mode: str = None, streaming: bool = False, fallback_model: str = None, fallback_api_key: str = None, fallback_endpoint: str = None, ): - from cognee.infrastructure.llm.config import get_llm_config - - config_instructor_mode = get_llm_config().llm_instructor_mode - instructor_mode = ( - config_instructor_mode if config_instructor_mode else self.default_instructor_mode - ) + self.instructor_mode = instructor_mode if instructor_mode else self.default_instructor_mode # TODO: With gpt5 series models OpenAI expects JSON_SCHEMA as a mode for structured outputs. # Make sure all new gpt models will work with this mode as well. if "gpt-5" in model: self.aclient = instructor.from_litellm( - litellm.acompletion, mode=instructor.Mode(instructor_mode) + litellm.acompletion, mode=instructor.Mode(self.instructor_mode) ) self.client = instructor.from_litellm( - litellm.completion, mode=instructor.Mode(instructor_mode) + litellm.completion, mode=instructor.Mode(self.instructor_mode) ) else: self.aclient = instructor.from_litellm(litellm.acompletion) From 844b8d635a7646750dc63dd4be13de07f2996940 Mon Sep 17 00:00:00 2001 From: Fahad Shoaib Date: Fri, 14 Nov 2025 22:13:00 +0500 Subject: [PATCH 09/16] feat: enhance ontology handling to support multiple uploads and retrievals --- .../v1/cognify/routers/get_cognify_router.py | 41 +++---- cognee/api/v1/ontologies/ontologies.py | 108 +++++++++++++++--- .../ontologies/routers/get_ontology_router.py | 51 ++++++--- .../rdf_xml/RDFLibOntologyResolver.py | 93 +++++++++------ 4 files changed, 202 insertions(+), 91 deletions(-) diff --git a/cognee/api/v1/cognify/routers/get_cognify_router.py b/cognee/api/v1/cognify/routers/get_cognify_router.py index 252ffe7bf..4f1497e3c 100644 --- a/cognee/api/v1/cognify/routers/get_cognify_router.py +++ b/cognee/api/v1/cognify/routers/get_cognify_router.py @@ -41,8 +41,8 @@ class CognifyPayloadDTO(InDTO): custom_prompt: Optional[str] = Field( default="", description="Custom prompt for entity extraction and graph generation" ) - ontology_key: Optional[str] = Field( - default=None, description="Reference to previously uploaded ontology" + ontology_key: Optional[List[str]] = Field( + default=None, description="Reference to one or more previously uploaded ontologies" ) @@ -71,7 +71,7 @@ def get_cognify_router() -> APIRouter: - **dataset_ids** (Optional[List[UUID]]): List of existing dataset UUIDs to process. UUIDs allow processing of datasets not owned by the user (if permitted). - **run_in_background** (Optional[bool]): Whether to execute processing asynchronously. Defaults to False (blocking). - **custom_prompt** (Optional[str]): Custom prompt for entity extraction and graph generation. If provided, this prompt will be used instead of the default prompts for knowledge graph extraction. - - **ontology_key** (Optional[str]): Reference to a previously uploaded ontology file to use for knowledge graph construction. + - **ontology_key** (Optional[List[str]]): Reference to one or more previously uploaded ontology files to use for knowledge graph construction. ## Response - **Blocking execution**: Complete pipeline run information with entity counts, processing duration, and success/failure status @@ -87,7 +87,7 @@ def get_cognify_router() -> APIRouter: "datasets": ["research_papers", "documentation"], "run_in_background": false, "custom_prompt": "Extract entities focusing on technical concepts and their relationships. Identify key technologies, methodologies, and their interconnections.", - "ontology_key": "medical_ontology_v1" + "ontology_key": ["medical_ontology_v1"] } ``` @@ -121,29 +121,22 @@ def get_cognify_router() -> APIRouter: if payload.ontology_key: ontology_service = OntologyService() - try: - ontology_content = ontology_service.get_ontology_content( - payload.ontology_key, user - ) + ontology_contents = ontology_service.get_ontology_contents( + payload.ontology_key, user + ) - from cognee.modules.ontology.ontology_config import Config - from cognee.modules.ontology.rdf_xml.RDFLibOntologyResolver import ( - RDFLibOntologyResolver, - ) - from io import StringIO + from cognee.modules.ontology.ontology_config import Config + from cognee.modules.ontology.rdf_xml.RDFLibOntologyResolver import ( + RDFLibOntologyResolver, + ) + from io import StringIO - ontology_stream = StringIO(ontology_content) - config_to_use: Config = { - "ontology_config": { - "ontology_resolver": RDFLibOntologyResolver( - ontology_file=ontology_stream - ) - } + ontology_streams = [StringIO(content) for content in ontology_contents] + config_to_use: Config = { + "ontology_config": { + "ontology_resolver": RDFLibOntologyResolver(ontology_file=ontology_streams) } - except ValueError as e: - return JSONResponse( - status_code=400, content={"error": f"Ontology error: {str(e)}"} - ) + } cognify_run = await cognee_cognify( datasets, diff --git a/cognee/api/v1/ontologies/ontologies.py b/cognee/api/v1/ontologies/ontologies.py index 6bfb7658e..130b4a862 100644 --- a/cognee/api/v1/ontologies/ontologies.py +++ b/cognee/api/v1/ontologies/ontologies.py @@ -3,7 +3,7 @@ import json import tempfile from pathlib import Path from datetime import datetime, timezone -from typing import Optional +from typing import Optional, List from dataclasses import dataclass @@ -47,28 +47,23 @@ class OntologyService: async def upload_ontology( self, ontology_key: str, file, user, description: Optional[str] = None ) -> OntologyMetadata: - # Validate file format if not file.filename.lower().endswith(".owl"): raise ValueError("File must be in .owl format") user_dir = self._get_user_dir(str(user.id)) metadata = self._load_metadata(user_dir) - # Check for duplicate key if ontology_key in metadata: raise ValueError(f"Ontology key '{ontology_key}' already exists") - # Read file content content = await file.read() - if len(content) > 10 * 1024 * 1024: # 10MB limit + if len(content) > 10 * 1024 * 1024: raise ValueError("File size exceeds 10MB limit") - # Save file file_path = user_dir / f"{ontology_key}.owl" with open(file_path, "wb") as f: f.write(content) - # Update metadata ontology_metadata = { "filename": file.filename, "size_bytes": len(content), @@ -86,19 +81,102 @@ class OntologyService: description=description, ) - def get_ontology_content(self, ontology_key: str, user) -> str: + async def upload_ontologies( + self, ontology_key: List[str], files: List, user, descriptions: Optional[List[str]] = None + ) -> List[OntologyMetadata]: + """ + Upload ontology files with their respective keys. + + Args: + ontology_key: List of unique keys for each ontology + files: List of UploadFile objects (same length as keys) + user: Authenticated user + descriptions: Optional list of descriptions for each file + + Returns: + List of OntologyMetadata objects for uploaded files + + Raises: + ValueError: If keys duplicate, file format invalid, or array lengths don't match + """ + if len(ontology_key) != len(files): + raise ValueError("Number of keys must match number of files") + + if len(set(ontology_key)) != len(ontology_key): + raise ValueError("Duplicate ontology keys not allowed") + + if descriptions and len(descriptions) != len(files): + raise ValueError("Number of descriptions must match number of files") + + results = [] user_dir = self._get_user_dir(str(user.id)) metadata = self._load_metadata(user_dir) - if ontology_key not in metadata: - raise ValueError(f"Ontology key '{ontology_key}' not found") + for i, (key, file) in enumerate(zip(ontology_key, files)): + if key in metadata: + raise ValueError(f"Ontology key '{key}' already exists") - file_path = user_dir / f"{ontology_key}.owl" - if not file_path.exists(): - raise ValueError(f"Ontology file for key '{ontology_key}' not found") + if not file.filename.lower().endswith(".owl"): + raise ValueError(f"File '{file.filename}' must be in .owl format") - with open(file_path, "r", encoding="utf-8") as f: - return f.read() + content = await file.read() + if len(content) > 10 * 1024 * 1024: + raise ValueError(f"File '{file.filename}' exceeds 10MB limit") + + file_path = user_dir / f"{key}.owl" + with open(file_path, "wb") as f: + f.write(content) + + ontology_metadata = { + "filename": file.filename, + "size_bytes": len(content), + "uploaded_at": datetime.now(timezone.utc).isoformat(), + "description": descriptions[i] if descriptions else None, + } + metadata[key] = ontology_metadata + + results.append( + OntologyMetadata( + ontology_key=key, + filename=file.filename, + size_bytes=len(content), + uploaded_at=ontology_metadata["uploaded_at"], + description=descriptions[i] if descriptions else None, + ) + ) + + self._save_metadata(user_dir, metadata) + return results + + def get_ontology_contents(self, ontology_key: List[str], user) -> List[str]: + """ + Retrieve ontology content for one or more keys. + + Args: + ontology_key: List of ontology keys to retrieve (can contain single item) + user: Authenticated user + + Returns: + List of ontology content strings + + Raises: + ValueError: If any ontology key not found + """ + user_dir = self._get_user_dir(str(user.id)) + metadata = self._load_metadata(user_dir) + + contents = [] + for key in ontology_key: + if key not in metadata: + raise ValueError(f"Ontology key '{key}' not found") + + file_path = user_dir / f"{key}.owl" + if not file_path.exists(): + raise ValueError(f"Ontology file for key '{key}' not found") + + with open(file_path, "r", encoding="utf-8") as f: + contents.append(f.read()) + return contents def list_ontologies(self, user) -> dict: user_dir = self._get_user_dir(str(user.id)) diff --git a/cognee/api/v1/ontologies/routers/get_ontology_router.py b/cognee/api/v1/ontologies/routers/get_ontology_router.py index f5c51ba21..ee31c683f 100644 --- a/cognee/api/v1/ontologies/routers/get_ontology_router.py +++ b/cognee/api/v1/ontologies/routers/get_ontology_router.py @@ -1,6 +1,6 @@ from fastapi import APIRouter, File, Form, UploadFile, Depends, HTTPException from fastapi.responses import JSONResponse -from typing import Optional +from typing import Optional, List from cognee.modules.users.models import User from cognee.modules.users.methods import get_authenticated_user @@ -16,23 +16,27 @@ def get_ontology_router() -> APIRouter: @router.post("", response_model=dict) async def upload_ontology( ontology_key: str = Form(...), - ontology_file: UploadFile = File(...), - description: Optional[str] = Form(None), + ontology_file: List[UploadFile] = File(...), + descriptions: Optional[str] = Form(None), user: User = Depends(get_authenticated_user), ): """ - Upload an ontology file with a named key for later use in cognify operations. + Upload ontology files with their respective keys for later use in cognify operations. + + Supports both single and multiple file uploads: + - Single file: ontology_key=["key"], ontology_file=[file] + - Multiple files: ontology_key=["key1", "key2"], ontology_file=[file1, file2] ## Request Parameters - - **ontology_key** (str): User-defined identifier for the ontology - - **ontology_file** (UploadFile): OWL format ontology file - - **description** (Optional[str]): Optional description of the ontology + - **ontology_key** (str): JSON array string of user-defined identifiers for the ontologies + - **ontology_file** (List[UploadFile]): OWL format ontology files + - **descriptions** (Optional[str]): JSON array string of optional descriptions ## Response - Returns metadata about the uploaded ontology including key, filename, size, and upload timestamp. + Returns metadata about uploaded ontologies including keys, filenames, sizes, and upload timestamps. ## Error Codes - - **400 Bad Request**: Invalid file format, duplicate key, file size exceeded + - **400 Bad Request**: Invalid file format, duplicate keys, array length mismatches, file size exceeded - **500 Internal Server Error**: File system or processing errors """ send_telemetry( @@ -45,16 +49,31 @@ def get_ontology_router() -> APIRouter: ) try: - result = await ontology_service.upload_ontology( - ontology_key, ontology_file, user, description + import json + + ontology_keys = json.loads(ontology_key) + description_list = json.loads(descriptions) if descriptions else None + + if not isinstance(ontology_keys, list): + raise ValueError("ontology_key must be a JSON array") + + results = await ontology_service.upload_ontologies( + ontology_keys, ontology_file, user, description_list ) + return { - "ontology_key": result.ontology_key, - "filename": result.filename, - "size_bytes": result.size_bytes, - "uploaded_at": result.uploaded_at, + "uploaded_ontologies": [ + { + "ontology_key": result.ontology_key, + "filename": result.filename, + "size_bytes": result.size_bytes, + "uploaded_at": result.uploaded_at, + "description": result.description, + } + for result in results + ] } - except ValueError as e: + except (json.JSONDecodeError, ValueError) as e: return JSONResponse(status_code=400, content={"error": str(e)}) except Exception as e: return JSONResponse(status_code=500, content={"error": str(e)}) diff --git a/cognee/modules/ontology/rdf_xml/RDFLibOntologyResolver.py b/cognee/modules/ontology/rdf_xml/RDFLibOntologyResolver.py index 4acc8861b..34d7a946a 100644 --- a/cognee/modules/ontology/rdf_xml/RDFLibOntologyResolver.py +++ b/cognee/modules/ontology/rdf_xml/RDFLibOntologyResolver.py @@ -26,7 +26,7 @@ class RDFLibOntologyResolver(BaseOntologyResolver): def __init__( self, - ontology_file: Optional[Union[str, List[str], IO]] = None, + ontology_file: Optional[Union[str, List[str], IO, List[IO]]] = None, matching_strategy: Optional[MatchingStrategy] = None, ) -> None: super().__init__(matching_strategy) @@ -34,47 +34,68 @@ class RDFLibOntologyResolver(BaseOntologyResolver): try: self.graph = None if ontology_file is not None: + files_to_load = [] + file_objects = [] + if hasattr(ontology_file, "read"): - self.graph = Graph() - content = ontology_file.read() - self.graph.parse(data=content, format="xml") - logger.info("Ontology loaded successfully from file object") - else: - files_to_load = [] - if isinstance(ontology_file, str): - files_to_load = [ontology_file] - elif isinstance(ontology_file, list): + file_objects = [ontology_file] + elif isinstance(ontology_file, str): + files_to_load = [ontology_file] + elif isinstance(ontology_file, list): + if all(hasattr(item, "read") for item in ontology_file): + file_objects = ontology_file + else: files_to_load = ontology_file - else: - raise ValueError( - f"ontology_file must be a string, list of strings, file-like object, or None. Got: {type(ontology_file)}" - ) + else: + raise ValueError( + f"ontology_file must be a string, list of strings, file-like object, list of file-like objects, or None. Got: {type(ontology_file)}" + ) - if files_to_load: - self.graph = Graph() - loaded_files = [] - for file_path in files_to_load: - if os.path.exists(file_path): - self.graph.parse(file_path) - loaded_files.append(file_path) - logger.info("Ontology loaded successfully from file: %s", file_path) - else: - logger.warning( - "Ontology file '%s' not found. Skipping this file.", - file_path, - ) + if file_objects: + self.graph = Graph() + loaded_objects = [] + for file_obj in file_objects: + try: + content = file_obj.read() + self.graph.parse(data=content, format="xml") + loaded_objects.append(file_obj) + logger.info("Ontology loaded successfully from file object") + except Exception as e: + logger.warning("Failed to parse ontology file object: %s", str(e)) - if not loaded_files: - logger.info( - "No valid ontology files found. No owl ontology will be attached to the graph." - ) - self.graph = None - else: - logger.info("Total ontology files loaded: %d", len(loaded_files)) - else: + if not loaded_objects: logger.info( - "No ontology file provided. No owl ontology will be attached to the graph." + "No valid ontology file objects found. No owl ontology will be attached to the graph." ) + self.graph = None + else: + logger.info("Total ontology file objects loaded: %d", len(loaded_objects)) + + elif files_to_load: + self.graph = Graph() + loaded_files = [] + for file_path in files_to_load: + if os.path.exists(file_path): + self.graph.parse(file_path) + loaded_files.append(file_path) + logger.info("Ontology loaded successfully from file: %s", file_path) + else: + logger.warning( + "Ontology file '%s' not found. Skipping this file.", + file_path, + ) + + if not loaded_files: + logger.info( + "No valid ontology files found. No owl ontology will be attached to the graph." + ) + self.graph = None + else: + logger.info("Total ontology files loaded: %d", len(loaded_files)) + else: + logger.info( + "No ontology file provided. No owl ontology will be attached to the graph." + ) else: logger.info( "No ontology file provided. No owl ontology will be attached to the graph." From 01f1c099cc972e2222f1174c515e1baf87fbb9d6 Mon Sep 17 00:00:00 2001 From: Fahad Shoaib Date: Fri, 14 Nov 2025 22:20:54 +0500 Subject: [PATCH 10/16] test: enhance server start test with ontology upload verification - Extend test_cognee_server_start to upload ontology and verify integration - Move test_ontology_endpoint from tests/ to tests/unit/api/ --- cognee/tests/test_cognee_server_start.py | 45 ++++++++++++++++++- .../{ => unit/api}/test_ontology_endpoint.py | 0 2 files changed, 44 insertions(+), 1 deletion(-) rename cognee/tests/{ => unit/api}/test_ontology_endpoint.py (100%) diff --git a/cognee/tests/test_cognee_server_start.py b/cognee/tests/test_cognee_server_start.py index ab68a8ef1..d6aa55a98 100644 --- a/cognee/tests/test_cognee_server_start.py +++ b/cognee/tests/test_cognee_server_start.py @@ -7,6 +7,7 @@ import requests from pathlib import Path import sys import uuid +import json class TestCogneeServerStart(unittest.TestCase): @@ -90,12 +91,31 @@ class TestCogneeServerStart(unittest.TestCase): ) } - payload = {"datasets": [dataset_name]} + ontology_key = f"test_ontology_{uuid.uuid4().hex[:8]}" + payload = {"datasets": [dataset_name], "ontology_key": [ontology_key]} add_response = requests.post(url, headers=headers, data=form_data, files=file, timeout=50) if add_response.status_code not in [200, 201]: add_response.raise_for_status() + ontology_content = b""" + + + + + """ + + ontology_response = requests.post( + "http://127.0.0.1:8000/api/v1/ontologies", + files=[("ontology_file", ("test.owl", ontology_content, "application/xml"))], + data={ + "ontology_key": json.dumps([ontology_key]), + "description": json.dumps(["Test ontology"]), + }, + ) + self.assertEqual(ontology_response.status_code, 200) + # Cognify request url = "http://127.0.0.1:8000/api/v1/cognify" headers = { @@ -107,6 +127,29 @@ class TestCogneeServerStart(unittest.TestCase): if cognify_response.status_code not in [200, 201]: cognify_response.raise_for_status() + datasets_response = requests.get("http://127.0.0.1:8000/api/v1/datasets", headers=headers) + + datasets = datasets_response.json() + dataset_id = None + for dataset in datasets: + if dataset["name"] == dataset_name: + dataset_id = dataset["id"] + break + + graph_response = requests.get( + f"http://127.0.0.1:8000/api/v1/datasets/{dataset_id}/graph", headers=headers + ) + self.assertEqual(graph_response.status_code, 200) + + graph_data = graph_response.json() + ontology_nodes = [ + node for node in graph_data.get("nodes") if node.get("properties").get("ontology_valid") + ] + + self.assertGreater( + len(ontology_nodes), 0, "No ontology nodes found - ontology was not integrated" + ) + # TODO: Add test to verify cognify pipeline is complete before testing search # Search request diff --git a/cognee/tests/test_ontology_endpoint.py b/cognee/tests/unit/api/test_ontology_endpoint.py similarity index 100% rename from cognee/tests/test_ontology_endpoint.py rename to cognee/tests/unit/api/test_ontology_endpoint.py From 1ded09d0f995fa57c9eaa3feafbe64089525a92f Mon Sep 17 00:00:00 2001 From: Fahad Shoaib Date: Sat, 15 Nov 2025 00:06:55 +0500 Subject: [PATCH 11/16] fix: fixed test ontology file content. Added tests to support multiple files and improved validation. --- cognee/tests/test_cognee_server_start.py | 13 +- .../tests/unit/api/test_ontology_endpoint.py | 189 +++++++++++++++++- 2 files changed, 185 insertions(+), 17 deletions(-) diff --git a/cognee/tests/test_cognee_server_start.py b/cognee/tests/test_cognee_server_start.py index d6aa55a98..b266fc7bf 100644 --- a/cognee/tests/test_cognee_server_start.py +++ b/cognee/tests/test_cognee_server_start.py @@ -98,13 +98,12 @@ class TestCogneeServerStart(unittest.TestCase): if add_response.status_code not in [200, 201]: add_response.raise_for_status() - ontology_content = b""" - - - - - """ + ontology_content = b""" + + + + + """ ontology_response = requests.post( "http://127.0.0.1:8000/api/v1/ontologies", diff --git a/cognee/tests/unit/api/test_ontology_endpoint.py b/cognee/tests/unit/api/test_ontology_endpoint.py index b5cedfafe..c04959998 100644 --- a/cognee/tests/unit/api/test_ontology_endpoint.py +++ b/cognee/tests/unit/api/test_ontology_endpoint.py @@ -32,6 +32,8 @@ def mock_default_user(): @patch.object(gau_mod, "get_default_user", new_callable=AsyncMock) def test_upload_ontology_success(mock_get_default_user, client, mock_default_user): """Test successful ontology upload""" + import json + mock_get_default_user.return_value = mock_default_user ontology_content = ( b"" @@ -40,14 +42,14 @@ def test_upload_ontology_success(mock_get_default_user, client, mock_default_use response = client.post( "/api/v1/ontologies", - files={"ontology_file": ("test.owl", ontology_content)}, - data={"ontology_key": unique_key, "description": "Test"}, + files=[("ontology_file", ("test.owl", ontology_content, "application/xml"))], + data={"ontology_key": json.dumps([unique_key]), "description": json.dumps(["Test"])}, ) assert response.status_code == 200 data = response.json() - assert data["ontology_key"] == unique_key - assert "uploaded_at" in data + assert data["uploaded_ontologies"][0]["ontology_key"] == unique_key + assert "uploaded_at" in data["uploaded_ontologies"][0] @patch.object(gau_mod, "get_default_user", new_callable=AsyncMock) @@ -66,30 +68,197 @@ def test_upload_ontology_invalid_file(mock_get_default_user, client, mock_defaul @patch.object(gau_mod, "get_default_user", new_callable=AsyncMock) def test_upload_ontology_missing_data(mock_get_default_user, client, mock_default_user): """Test 400 response for missing file or key""" + import json + mock_get_default_user.return_value = mock_default_user # Missing file - response = client.post("/api/v1/ontologies", data={"ontology_key": "test"}) + response = client.post("/api/v1/ontologies", data={"ontology_key": json.dumps(["test"])}) assert response.status_code == 400 # Missing key - response = client.post("/api/v1/ontologies", files={"ontology_file": ("test.owl", b"xml")}) + response = client.post( + "/api/v1/ontologies", files=[("ontology_file", ("test.owl", b"xml", "application/xml"))] + ) assert response.status_code == 400 @patch.object(gau_mod, "get_default_user", new_callable=AsyncMock) def test_upload_ontology_unauthorized(mock_get_default_user, client, mock_default_user): """Test behavior when default user is provided (no explicit authentication)""" + import json + unique_key = f"test_ontology_{uuid.uuid4().hex[:8]}" mock_get_default_user.return_value = mock_default_user response = client.post( "/api/v1/ontologies", - files={"ontology_file": ("test.owl", b"")}, - data={"ontology_key": unique_key}, + files=[("ontology_file", ("test.owl", b"", "application/xml"))], + data={"ontology_key": json.dumps([unique_key])}, ) # The current system provides a default user when no explicit authentication is given # This test verifies the system works with conditional authentication assert response.status_code == 200 data = response.json() - assert data["ontology_key"] == unique_key - assert "uploaded_at" in data + assert data["uploaded_ontologies"][0]["ontology_key"] == unique_key + assert "uploaded_at" in data["uploaded_ontologies"][0] + + +@patch.object(gau_mod, "get_default_user", new_callable=AsyncMock) +async def test_upload_multiple_ontologies(mock_get_default_user, client, mock_default_user): + """Test uploading multiple ontology files in single request""" + import io + + # Create mock files + file1_content = b"" + file2_content = b"" + + files = [ + ("ontology_file", ("vehicles.owl", io.BytesIO(file1_content), "application/xml")), + ("ontology_file", ("manufacturers.owl", io.BytesIO(file2_content), "application/xml")), + ] + data = { + "ontology_key": '["vehicles", "manufacturers"]', + "descriptions": '["Base vehicles", "Car manufacturers"]', + } + + response = client.post("/api/v1/ontologies", files=files, data=data) + + assert response.status_code == 200 + result = response.json() + assert "uploaded_ontologies" in result + assert len(result["uploaded_ontologies"]) == 2 + assert result["uploaded_ontologies"][0]["ontology_key"] == "vehicles" + assert result["uploaded_ontologies"][1]["ontology_key"] == "manufacturers" + + +@patch.object(gau_mod, "get_default_user", new_callable=AsyncMock) +async def test_upload_endpoint_accepts_arrays(mock_get_default_user, client, mock_default_user): + """Test that upload endpoint accepts array parameters""" + import io + import json + + file_content = b"" + + files = [("ontology_file", ("single.owl", io.BytesIO(file_content), "application/xml"))] + data = { + "ontology_key": json.dumps(["single_key"]), + "descriptions": json.dumps(["Single ontology"]), + } + + response = client.post("/api/v1/ontologies", files=files, data=data) + + assert response.status_code == 200 + result = response.json() + assert result["uploaded_ontologies"][0]["ontology_key"] == "single_key" + + +@patch.object(gau_mod, "get_default_user", new_callable=AsyncMock) +async def test_cognify_with_multiple_ontologies(mock_get_default_user, client, mock_default_user): + """Test cognify endpoint accepts multiple ontology keys""" + payload = { + "datasets": ["test_dataset"], + "ontology_key": ["ontology1", "ontology2"], # Array instead of string + "run_in_background": False, + } + + response = client.post("/api/v1/cognify", json=payload) + + # Should not fail due to ontology_key type + assert response.status_code in [200, 400, 409] # May fail for other reasons, not type + + +@patch.object(gau_mod, "get_default_user", new_callable=AsyncMock) +async def test_complete_multifile_workflow(mock_get_default_user, client, mock_default_user): + """Test complete workflow: upload multiple ontologies → cognify with multiple keys""" + import io + import json + + # Step 1: Upload multiple ontologies + file1_content = b""" + + + """ + + file2_content = b""" + + + """ + + files = [ + ("ontology_file", ("vehicles.owl", io.BytesIO(file1_content), "application/xml")), + ("ontology_file", ("manufacturers.owl", io.BytesIO(file2_content), "application/xml")), + ] + data = { + "ontology_key": json.dumps(["vehicles", "manufacturers"]), + "descriptions": json.dumps(["Vehicle ontology", "Manufacturer ontology"]), + } + + upload_response = client.post("/api/v1/ontologies", files=files, data=data) + assert upload_response.status_code == 200 + + # Step 2: Verify ontologies are listed + list_response = client.get("/api/v1/ontologies") + assert list_response.status_code == 200 + ontologies = list_response.json() + assert "vehicles" in ontologies + assert "manufacturers" in ontologies + + # Step 3: Test cognify with multiple ontologies + cognify_payload = { + "datasets": ["test_dataset"], + "ontology_key": ["vehicles", "manufacturers"], + "run_in_background": False, + } + + cognify_response = client.post("/api/v1/cognify", json=cognify_payload) + # Should not fail due to ontology handling (may fail for dataset reasons) + assert cognify_response.status_code != 400 # Not a validation error + + +@patch.object(gau_mod, "get_default_user", new_callable=AsyncMock) +async def test_multifile_error_handling(mock_get_default_user, client, mock_default_user): + """Test error handling for invalid multifile uploads""" + import io + import json + + # Test mismatched array lengths + file_content = b"" + files = [("ontology_file", ("test.owl", io.BytesIO(file_content), "application/xml"))] + data = { + "ontology_key": json.dumps(["key1", "key2"]), # 2 keys, 1 file + "descriptions": json.dumps(["desc1"]), + } + + response = client.post("/api/v1/ontologies", files=files, data=data) + assert response.status_code == 400 + assert "Number of keys must match number of files" in response.json()["error"] + + # Test duplicate keys + files = [ + ("ontology_file", ("test1.owl", io.BytesIO(file_content), "application/xml")), + ("ontology_file", ("test2.owl", io.BytesIO(file_content), "application/xml")), + ] + data = { + "ontology_key": json.dumps(["duplicate", "duplicate"]), + "descriptions": json.dumps(["desc1", "desc2"]), + } + + response = client.post("/api/v1/ontologies", files=files, data=data) + assert response.status_code == 400 + assert "Duplicate ontology keys not allowed" in response.json()["error"] + + +@patch.object(gau_mod, "get_default_user", new_callable=AsyncMock) +async def test_cognify_missing_ontology_key(mock_get_default_user, client, mock_default_user): + """Test cognify with non-existent ontology key""" + payload = { + "datasets": ["test_dataset"], + "ontology_key": ["nonexistent_key"], + "run_in_background": False, + } + + response = client.post("/api/v1/cognify", json=payload) + assert response.status_code == 409 + assert "Ontology key 'nonexistent_key' not found" in response.json()["error"] From 983bfae4fcc9046fd520f0c24733e491679d54cf Mon Sep 17 00:00:00 2001 From: EricXiao Date: Mon, 17 Nov 2025 14:41:55 +0800 Subject: [PATCH 12/16] chore: remove unnecessary csv file type Signed-off-by: EricXiao --- .../files/utils/is_csv_content.py | 181 ------------------ 1 file changed, 181 deletions(-) delete mode 100644 cognee/infrastructure/files/utils/is_csv_content.py diff --git a/cognee/infrastructure/files/utils/is_csv_content.py b/cognee/infrastructure/files/utils/is_csv_content.py deleted file mode 100644 index 07b7ea69b..000000000 --- a/cognee/infrastructure/files/utils/is_csv_content.py +++ /dev/null @@ -1,181 +0,0 @@ -import csv -from collections import Counter - - -def is_csv_content(content): - """ - Heuristically determine whether a bytes-like object is CSV text. - - Strategy (fail-fast and cheap to expensive): - 1) Decode: Try a small ordered list of common encodings with strict errors. - 2) Line sampling: require >= 2 non-empty lines; sample up to 50 lines. - 3) Delimiter detection: - - Prefer csv.Sniffer() with common delimiters. - - Fallback to a lightweight consistency heuristic. - 4) Lightweight parse check: - - Parse a few lines with the delimiter. - - Ensure at least 2 valid rows and relatively stable column counts. - - Returns: - bool: True if the buffer looks like CSV; False otherwise. - """ - try: - encoding_list = [ - "utf-8", - "utf-8-sig", - "utf-32-le", - "utf-32-be", - "utf-16-le", - "utf-16-be", - "gb18030", - "shift_jis", - "cp949", - "cp1252", - "iso-8859-1", - ] - - # Try to decode strictly—if decoding fails for all encodings, it's not text/CSV. - text = None - for enc in encoding_list: - try: - text = content.decode(enc, errors="strict") - break - except UnicodeDecodeError: - continue - if text is None: - return False - - # Reject empty/whitespace-only payloads. - stripped = text.strip() - if not stripped: - return False - - # Split into logical lines and drop empty ones. Require at least two lines. - lines = [ln for ln in text.splitlines() if ln.strip()] - if len(lines) < 2: - return False - - # Take a small sample to keep sniffing cheap and predictable. - sample_lines = lines[:50] - - # Detect delimiter using csv.Sniffer first; if that fails, use our heuristic. - delimiter = _sniff_delimiter(sample_lines) or _heuristic_delimiter(sample_lines) - if not delimiter: - return False - - # Finally, do a lightweight parse sanity check with the chosen delimiter. - return _lightweight_parse_check(sample_lines, delimiter) - except Exception: - return False - - -def _sniff_delimiter(lines): - """ - Try Python's built-in csv.Sniffer on a sample. - - Args: - lines (list[str]): Sample lines (already decoded). - - Returns: - str | None: The detected delimiter if sniffing succeeds; otherwise None. - """ - # Join up to 50 lines to form the sample string Sniffer will inspect. - sample = "\n".join(lines[:50]) - try: - dialect = csv.Sniffer().sniff(sample, delimiters=",\t;|") - return dialect.delimiter - except Exception: - # Sniffer is known to be brittle on small/dirty samples—silently fallback. - return None - - -def _heuristic_delimiter(lines): - """ - Fallback delimiter detection based on count consistency per line. - - Heuristic: - - For each candidate delimiter, count occurrences per line. - - Keep only lines with count > 0 (line must contain the delimiter). - - Require at least half of lines to contain the delimiter (min 2). - - Compute the mode (most common count). If the proportion of lines that - exhibit the modal count is >= 80%, accept that delimiter. - - Args: - lines (list[str]): Sample lines. - - Returns: - str | None: Best delimiter if one meets the consistency threshold; else None. - """ - candidates = [",", "\t", ";", "|"] - best = None - best_score = 0.0 - - for d in candidates: - # Count how many times the delimiter appears in each line. - counts = [ln.count(d) for ln in lines] - # Consider only lines that actually contain the delimiter at least once. - nonzero = [c for c in counts if c > 0] - - # Require that more than half of lines (and at least 2) contain the delimiter. - if len(nonzero) < max(2, int(0.5 * len(lines))): - continue - - # Find the modal count and its frequency. - cnt = Counter(nonzero) - pairs = cnt.most_common(1) - if not pairs: - continue - - mode, mode_freq = pairs[0] - # Consistency ratio: lines with the modal count / total lines in the sample. - consistency = mode_freq / len(lines) - # Accept if consistent enough and better than any previous candidate. - if mode >= 1 and consistency >= 0.80 and consistency > best_score: - best = d - best_score = consistency - - return best - - -def _lightweight_parse_check(lines, delimiter): - """ - Parse a few lines with csv.reader and check structural stability. - - Heuristic: - - Parse up to 5 lines with the given delimiter. - - Count column widths per parsed row. - - Require at least 2 non-empty rows. - - Allow at most 1 row whose width deviates by >2 columns from the first row. - - Args: - lines (list[str]): Sample lines (decoded). - delimiter (str): Delimiter chosen by sniffing/heuristics. - - Returns: - bool: True if parsing looks stable; False otherwise. - """ - try: - # csv.reader accepts any iterable of strings; feeding the first 10 lines is fine. - reader = csv.reader(lines[:10], delimiter=delimiter) - widths = [] - valid_rows = 0 - for row in reader: - if not row: - continue - - widths.append(len(row)) - valid_rows += 1 - - # Need at least two meaningful rows to make a judgment. - if valid_rows < 2: - return False - - if widths: - first = widths[0] - # Count rows whose width deviates significantly (>2) from the first row. - unstable = sum(1 for w in widths if abs(w - first) > 2) - # Permit at most 1 unstable row among the parsed sample. - return unstable <= 1 - return False - except Exception: - return False From 30e3971d44816db50d9e83eee39bf6d69b98a328 Mon Sep 17 00:00:00 2001 From: Fahad Shoaib Date: Thu, 20 Nov 2025 15:36:15 +0500 Subject: [PATCH 13/16] fix: add auth headers to ontology upload request and enhance ontology content --- cognee/tests/test_cognee_server_start.py | 53 +++++++++++++++++++++--- 1 file changed, 47 insertions(+), 6 deletions(-) diff --git a/cognee/tests/test_cognee_server_start.py b/cognee/tests/test_cognee_server_start.py index b266fc7bf..ddffe53a4 100644 --- a/cognee/tests/test_cognee_server_start.py +++ b/cognee/tests/test_cognee_server_start.py @@ -98,15 +98,56 @@ class TestCogneeServerStart(unittest.TestCase): if add_response.status_code not in [200, 201]: add_response.raise_for_status() - ontology_content = b""" - - - - - """ + ontology_content = b""" + + + + + + + + + + + + + + + + A failure caused by physical components. + + + + + An error caused by software logic or configuration. + + + + A human being or individual. + + + + + Programmers + + + + Light Bulb + + + + Hardware Problem + + + """ ontology_response = requests.post( "http://127.0.0.1:8000/api/v1/ontologies", + headers=headers, files=[("ontology_file", ("test.owl", ontology_content, "application/xml"))], data={ "ontology_key": json.dumps([ontology_key]), From 8cfb6c41eeca3b2ad0e34fb4b6043f4b5f6a8c00 Mon Sep 17 00:00:00 2001 From: Fahad Shoaib Date: Thu, 20 Nov 2025 15:54:09 +0500 Subject: [PATCH 14/16] fix: remove async from ontology endpoint test functions --- cognee/tests/unit/api/test_ontology_endpoint.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/cognee/tests/unit/api/test_ontology_endpoint.py b/cognee/tests/unit/api/test_ontology_endpoint.py index c04959998..d53c5ab44 100644 --- a/cognee/tests/unit/api/test_ontology_endpoint.py +++ b/cognee/tests/unit/api/test_ontology_endpoint.py @@ -104,7 +104,7 @@ def test_upload_ontology_unauthorized(mock_get_default_user, client, mock_defaul @patch.object(gau_mod, "get_default_user", new_callable=AsyncMock) -async def test_upload_multiple_ontologies(mock_get_default_user, client, mock_default_user): +def test_upload_multiple_ontologies(mock_get_default_user, client, mock_default_user): """Test uploading multiple ontology files in single request""" import io @@ -132,7 +132,7 @@ async def test_upload_multiple_ontologies(mock_get_default_user, client, mock_de @patch.object(gau_mod, "get_default_user", new_callable=AsyncMock) -async def test_upload_endpoint_accepts_arrays(mock_get_default_user, client, mock_default_user): +def test_upload_endpoint_accepts_arrays(mock_get_default_user, client, mock_default_user): """Test that upload endpoint accepts array parameters""" import io import json @@ -153,7 +153,7 @@ async def test_upload_endpoint_accepts_arrays(mock_get_default_user, client, moc @patch.object(gau_mod, "get_default_user", new_callable=AsyncMock) -async def test_cognify_with_multiple_ontologies(mock_get_default_user, client, mock_default_user): +def test_cognify_with_multiple_ontologies(mock_get_default_user, client, mock_default_user): """Test cognify endpoint accepts multiple ontology keys""" payload = { "datasets": ["test_dataset"], @@ -168,7 +168,7 @@ async def test_cognify_with_multiple_ontologies(mock_get_default_user, client, m @patch.object(gau_mod, "get_default_user", new_callable=AsyncMock) -async def test_complete_multifile_workflow(mock_get_default_user, client, mock_default_user): +def test_complete_multifile_workflow(mock_get_default_user, client, mock_default_user): """Test complete workflow: upload multiple ontologies → cognify with multiple keys""" import io import json @@ -218,7 +218,7 @@ async def test_complete_multifile_workflow(mock_get_default_user, client, mock_d @patch.object(gau_mod, "get_default_user", new_callable=AsyncMock) -async def test_multifile_error_handling(mock_get_default_user, client, mock_default_user): +def test_multifile_error_handling(mock_get_default_user, client, mock_default_user): """Test error handling for invalid multifile uploads""" import io import json @@ -251,7 +251,7 @@ async def test_multifile_error_handling(mock_get_default_user, client, mock_defa @patch.object(gau_mod, "get_default_user", new_callable=AsyncMock) -async def test_cognify_missing_ontology_key(mock_get_default_user, client, mock_default_user): +def test_cognify_missing_ontology_key(mock_get_default_user, client, mock_default_user): """Test cognify with non-existent ontology key""" payload = { "datasets": ["test_dataset"], From 4e880eca8422872b26a568b18b9f1339ce362c18 Mon Sep 17 00:00:00 2001 From: Andrej Milicevic Date: Thu, 20 Nov 2025 15:47:22 +0100 Subject: [PATCH 15/16] chore: update env template --- .env.template | 1 + 1 file changed, 1 insertion(+) diff --git a/.env.template b/.env.template index ae2cb1338..376233b1f 100644 --- a/.env.template +++ b/.env.template @@ -21,6 +21,7 @@ LLM_PROVIDER="openai" LLM_ENDPOINT="" LLM_API_VERSION="" LLM_MAX_TOKENS="16384" +LLM_INSTRUCTOR_MODE="json_schema_mode" # this mode is used for gpt-5 models EMBEDDING_PROVIDER="openai" EMBEDDING_MODEL="openai/text-embedding-3-large" From 204f9c2e4ad6dc706c03deed68adfbc4744ae6df Mon Sep 17 00:00:00 2001 From: Andrej Milicevic Date: Fri, 21 Nov 2025 16:20:19 +0100 Subject: [PATCH 16/16] fix: PR comment changes --- .env.template | 5 ++++- cognee/infrastructure/llm/config.py | 4 ++-- .../litellm_instructor/llm/get_llm_client.py | 12 ++++++------ 3 files changed, 12 insertions(+), 9 deletions(-) diff --git a/.env.template b/.env.template index 376233b1f..61853b983 100644 --- a/.env.template +++ b/.env.template @@ -21,7 +21,10 @@ LLM_PROVIDER="openai" LLM_ENDPOINT="" LLM_API_VERSION="" LLM_MAX_TOKENS="16384" -LLM_INSTRUCTOR_MODE="json_schema_mode" # this mode is used for gpt-5 models +# Instructor's modes determine how structured data is requested from and extracted from LLM responses +# You can change this type (i.e. mode) via this env variable +# Each LLM has its own default value, e.g. gpt-5 models have "json_schema_mode" +LLM_INSTRUCTOR_MODE="" EMBEDDING_PROVIDER="openai" EMBEDDING_MODEL="openai/text-embedding-3-large" diff --git a/cognee/infrastructure/llm/config.py b/cognee/infrastructure/llm/config.py index c87054ff6..2e300dc0c 100644 --- a/cognee/infrastructure/llm/config.py +++ b/cognee/infrastructure/llm/config.py @@ -38,7 +38,7 @@ class LLMConfig(BaseSettings): """ structured_output_framework: str = "instructor" - llm_instructor_mode: Optional[str] = None + llm_instructor_mode: str = "" llm_provider: str = "openai" llm_model: str = "openai/gpt-5-mini" llm_endpoint: str = "" @@ -182,7 +182,7 @@ class LLMConfig(BaseSettings): instance. """ return { - "llm_instructor_mode": self.llm_instructor_mode, + "llm_instructor_mode": self.llm_instructor_mode.lower(), "provider": self.llm_provider, "model": self.llm_model, "endpoint": self.llm_endpoint, diff --git a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/get_llm_client.py b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/get_llm_client.py index 537eda1b2..6ab3b91ad 100644 --- a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/get_llm_client.py +++ b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/get_llm_client.py @@ -81,7 +81,7 @@ def get_llm_client(raise_api_key_error: bool = True): model=llm_config.llm_model, transcription_model=llm_config.transcription_model, max_completion_tokens=max_completion_tokens, - instructor_mode=llm_config.llm_instructor_mode, + instructor_mode=llm_config.llm_instructor_mode.lower(), streaming=llm_config.llm_streaming, fallback_api_key=llm_config.fallback_api_key, fallback_endpoint=llm_config.fallback_endpoint, @@ -102,7 +102,7 @@ def get_llm_client(raise_api_key_error: bool = True): llm_config.llm_model, "Ollama", max_completion_tokens=max_completion_tokens, - instructor_mode=llm_config.llm_instructor_mode, + instructor_mode=llm_config.llm_instructor_mode.lower(), ) elif provider == LLMProvider.ANTHROPIC: @@ -113,7 +113,7 @@ def get_llm_client(raise_api_key_error: bool = True): return AnthropicAdapter( max_completion_tokens=max_completion_tokens, model=llm_config.llm_model, - instructor_mode=llm_config.llm_instructor_mode, + instructor_mode=llm_config.llm_instructor_mode.lower(), ) elif provider == LLMProvider.CUSTOM: @@ -130,7 +130,7 @@ def get_llm_client(raise_api_key_error: bool = True): llm_config.llm_model, "Custom", max_completion_tokens=max_completion_tokens, - instructor_mode=llm_config.llm_instructor_mode, + instructor_mode=llm_config.llm_instructor_mode.lower(), fallback_api_key=llm_config.fallback_api_key, fallback_endpoint=llm_config.fallback_endpoint, fallback_model=llm_config.fallback_model, @@ -150,7 +150,7 @@ def get_llm_client(raise_api_key_error: bool = True): max_completion_tokens=max_completion_tokens, endpoint=llm_config.llm_endpoint, api_version=llm_config.llm_api_version, - instructor_mode=llm_config.llm_instructor_mode, + instructor_mode=llm_config.llm_instructor_mode.lower(), ) elif provider == LLMProvider.MISTRAL: @@ -166,7 +166,7 @@ def get_llm_client(raise_api_key_error: bool = True): model=llm_config.llm_model, max_completion_tokens=max_completion_tokens, endpoint=llm_config.llm_endpoint, - instructor_mode=llm_config.llm_instructor_mode, + instructor_mode=llm_config.llm_instructor_mode.lower(), ) elif provider == LLMProvider.MISTRAL: