diff --git a/.env.template b/.env.template
index 6ce9fedf9..ec6d01596 100644
--- a/.env.template
+++ b/.env.template
@@ -7,6 +7,7 @@
 LLM_MODEL="openai/gpt-4o-mini"
 LLM_PROVIDER="openai"
 LLM_ENDPOINT=""
 LLM_API_VERSION=""
+LLM_MAX_TOKENS="128000"
 GRAPHISTRY_USERNAME=
 GRAPHISTRY_PASSWORD=
diff --git a/cognee/api/v1/cognify/cognify_v2.py b/cognee/api/v1/cognify/cognify_v2.py
index 738f77c52..73504f057 100644
--- a/cognee/api/v1/cognify/cognify_v2.py
+++ b/cognee/api/v1/cognify/cognify_v2.py
@@ -4,6 +4,8 @@ from typing import Union
 
 from pydantic import BaseModel
 
+from cognee.infrastructure.databases.vector import get_vector_engine
+from cognee.infrastructure.llm.get_llm_client import get_llm_client
 from cognee.modules.cognify.config import get_cognify_config
 from cognee.modules.data.methods import get_datasets, get_datasets_by_name
 from cognee.modules.data.methods.get_dataset_data import get_dataset_data
@@ -146,12 +148,23 @@ async def get_default_tasks(
     if user is None:
         user = await get_default_user()
 
+    # Calculate max chunk size based on the following formula
+    embedding_engine = get_vector_engine().embedding_engine
+    llm_client = get_llm_client()
+
+    # We need to make sure chunk size won't take more than half of LLM max context token size
+    # but it also can't be bigger than the embedding engine max token size
+    llm_cutoff_point = llm_client.max_tokens // 2  # Round down the division
+    max_chunk_tokens = min(embedding_engine.max_tokens, llm_cutoff_point)
+
     try:
         cognee_config = get_cognify_config()
         default_tasks = [
             Task(classify_documents),
             Task(check_permissions_on_documents, user=user, permissions=["write"]),
-            Task(extract_chunks_from_documents),  # Extract text chunks based on the document type.
+            Task(
+                extract_chunks_from_documents, max_chunk_tokens=max_chunk_tokens
+            ),  # Extract text chunks based on the document type.
             Task(
                 extract_graph_from_data, graph_model=graph_model, task_config={"batch_size": 10}
             ),  # Generate knowledge graphs from the document chunks.
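Reviewer note: the sizing rule added above reduces to `min(embedding_max_tokens, llm_max_tokens // 2)`. A standalone sketch of the arithmetic, with hypothetical engine limits (the `8191` and `2048` values are illustrative, not taken from this repo):

```python
# Sketch of the chunk-size rule from get_default_tasks; the limits are hypothetical.
def compute_max_chunk_tokens(embedding_max_tokens: int, llm_max_tokens: int) -> int:
    # A chunk must fit the embedding engine, and must leave at least half of
    # the LLM context window free for prompt scaffolding and model output.
    llm_cutoff_point = llm_max_tokens // 2  # floor division rounds down
    return min(embedding_max_tokens, llm_cutoff_point)

# With the default LLM_MAX_TOKENS="128000", the embedding engine is the binding cap:
assert compute_max_chunk_tokens(8191, 128000) == 8191
# With a small local model context, the LLM half-window dominates instead:
assert compute_max_chunk_tokens(8191, 2048) == 1024
```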
diff --git a/cognee/infrastructure/llm/anthropic/adapter.py b/cognee/infrastructure/llm/anthropic/adapter.py
index 1fba732a0..cfeb15d25 100644
--- a/cognee/infrastructure/llm/anthropic/adapter.py
+++ b/cognee/infrastructure/llm/anthropic/adapter.py
@@ -14,11 +14,12 @@ class AnthropicAdapter(LLMInterface):
     name = "Anthropic"
     model: str
 
-    def __init__(self, model: str = None):
+    def __init__(self, max_tokens: int, model: str = None):
         self.aclient = instructor.patch(
             create=anthropic.Anthropic().messages.create, mode=instructor.Mode.ANTHROPIC_TOOLS
         )
         self.model = model
+        self.max_tokens = max_tokens
 
     async def acreate_structured_output(
         self, text_input: str, system_prompt: str, response_model: Type[BaseModel]
diff --git a/cognee/infrastructure/llm/config.py b/cognee/infrastructure/llm/config.py
index 67fc82683..00dff82b9 100644
--- a/cognee/infrastructure/llm/config.py
+++ b/cognee/infrastructure/llm/config.py
@@ -11,6 +11,7 @@ class LLMConfig(BaseSettings):
     llm_api_version: Optional[str] = None
     llm_temperature: float = 0.0
     llm_streaming: bool = False
+    llm_max_tokens: int = 128000
     transcription_model: str = "whisper-1"
 
     model_config = SettingsConfigDict(env_file=".env", extra="allow")
@@ -24,6 +25,7 @@ class LLMConfig(BaseSettings):
             "api_version": self.llm_api_version,
             "temperature": self.llm_temperature,
             "streaming": self.llm_streaming,
+            "max_tokens": self.llm_max_tokens,
             "transcription_model": self.transcription_model,
         }
 
diff --git a/cognee/infrastructure/llm/generic_llm_api/adapter.py b/cognee/infrastructure/llm/generic_llm_api/adapter.py
index a910c0780..98002b2cb 100644
--- a/cognee/infrastructure/llm/generic_llm_api/adapter.py
+++ b/cognee/infrastructure/llm/generic_llm_api/adapter.py
@@ -2,6 +2,7 @@
 import asyncio
 from typing import List, Type
+
 from pydantic import BaseModel
 import instructor
 
 from cognee.infrastructure.llm.llm_interface import LLMInterface
@@ -16,11 +17,12 @@ class GenericAPIAdapter(LLMInterface):
     model: str
     api_key: str
 
-    def __init__(self, endpoint, api_key: str, model: str, name: str):
+    def __init__(self, endpoint, api_key: str, model: str, name: str, max_tokens: int):
         self.name = name
         self.model = model
         self.api_key = api_key
         self.endpoint = endpoint
+        self.max_tokens = max_tokens
 
         llm_config = get_llm_config()
 
diff --git a/cognee/infrastructure/llm/get_llm_client.py b/cognee/infrastructure/llm/get_llm_client.py
index 0f64014e3..f601f48b2 100644
--- a/cognee/infrastructure/llm/get_llm_client.py
+++ b/cognee/infrastructure/llm/get_llm_client.py
@@ -32,6 +32,7 @@ def get_llm_client():
             api_version=llm_config.llm_api_version,
             model=llm_config.llm_model,
             transcription_model=llm_config.transcription_model,
+            max_tokens=llm_config.llm_max_tokens,
             streaming=llm_config.llm_streaming,
         )
 
@@ -42,13 +43,17 @@ def get_llm_client():
         from .generic_llm_api.adapter import GenericAPIAdapter
 
         return GenericAPIAdapter(
-            llm_config.llm_endpoint, llm_config.llm_api_key, llm_config.llm_model, "Ollama"
+            llm_config.llm_endpoint,
+            llm_config.llm_api_key,
+            llm_config.llm_model,
+            "Ollama",
+            max_tokens=llm_config.llm_max_tokens,
         )
 
     elif provider == LLMProvider.ANTHROPIC:
         from .anthropic.adapter import AnthropicAdapter
 
-        return AnthropicAdapter(llm_config.llm_model)
+        return AnthropicAdapter(max_tokens=llm_config.llm_max_tokens, model=llm_config.llm_model)
 
     elif provider == LLMProvider.CUSTOM:
         if llm_config.llm_api_key is None:
@@ -57,7 +62,11 @@ def get_llm_client():
         from .generic_llm_api.adapter import GenericAPIAdapter
 
         return GenericAPIAdapter(
-            llm_config.llm_endpoint, llm_config.llm_api_key, llm_config.llm_model, "Custom"
+            llm_config.llm_endpoint,
+            llm_config.llm_api_key,
+            llm_config.llm_model,
+            "Custom",
+            max_tokens=llm_config.llm_max_tokens,
         )
 
     else:
diff --git a/cognee/infrastructure/llm/openai/adapter.py b/cognee/infrastructure/llm/openai/adapter.py
index d45662380..d6939e323 100644
--- a/cognee/infrastructure/llm/openai/adapter.py
+++ b/cognee/infrastructure/llm/openai/adapter.py
@@ -32,6 +32,7 @@ class OpenAIAdapter(LLMInterface):
         api_version: str,
         model: str,
         transcription_model: str,
+        max_tokens: int,
         streaming: bool = False,
     ):
         self.aclient = instructor.from_litellm(litellm.acompletion)
@@ -41,6 +42,7 @@ class OpenAIAdapter(LLMInterface):
         self.api_key = api_key
         self.endpoint = endpoint
         self.api_version = api_version
+        self.max_tokens = max_tokens
         self.streaming = streaming
 
     @observe(as_type="generation")
diff --git a/cognee/modules/chunking/TextChunker.py b/cognee/modules/chunking/TextChunker.py
index 4fa246272..f9a664d8b 100644
--- a/cognee/modules/chunking/TextChunker.py
+++ b/cognee/modules/chunking/TextChunker.py
@@ -14,22 +14,15 @@ class TextChunker:
     chunk_size = 0
     token_count = 0
 
-    def __init__(self, document, get_text: callable, chunk_size: int = 1024):
+    def __init__(self, document, get_text: callable, max_chunk_tokens: int, chunk_size: int = 1024):
        self.document = document
         self.max_chunk_size = chunk_size
         self.get_text = get_text
+        self.max_chunk_tokens = max_chunk_tokens
 
     def check_word_count_and_token_count(self, word_count_before, token_count_before, chunk_data):
         word_count_fits = word_count_before + chunk_data["word_count"] <= self.max_chunk_size
-
-        # Get embedding engine related to vector database
-        from cognee.infrastructure.databases.vector.get_vector_engine import get_vector_engine
-
-        embedding_engine = get_vector_engine().embedding_engine
-
-        token_count_fits = (
-            token_count_before + chunk_data["token_count"] <= embedding_engine.max_tokens
-        )
+        token_count_fits = token_count_before + chunk_data["token_count"] <= self.max_chunk_tokens
 
         return word_count_fits and token_count_fits
 
@@ -37,6 +30,7 @@ def read(self):
         for content_text in self.get_text():
             for chunk_data in chunk_by_paragraph(
                 content_text,
+                self.max_chunk_tokens,
                 self.max_chunk_size,
                 batch_paragraphs=True,
             ):
diff --git a/cognee/modules/data/processing/document_types/AudioDocument.py b/cognee/modules/data/processing/document_types/AudioDocument.py
index 77c700482..75152fd3d 100644
--- a/cognee/modules/data/processing/document_types/AudioDocument.py
+++ b/cognee/modules/data/processing/document_types/AudioDocument.py
@@ -13,12 +13,14 @@ class AudioDocument(Document):
         result = get_llm_client().create_transcript(self.raw_data_location)
         return result.text
 
-    def read(self, chunk_size: int, chunker: str):
+    def read(self, chunk_size: int, chunker: str, max_chunk_tokens: int):
         # Transcribe the audio file
         text = self.create_transcript()
 
         chunker_func = ChunkerConfig.get_chunker(chunker)
-        chunker = chunker_func(self, chunk_size=chunk_size, get_text=lambda: [text])
+        chunker = chunker_func(
+            self, chunk_size=chunk_size, get_text=lambda: [text], max_chunk_tokens=max_chunk_tokens
+        )
 
         yield from chunker.read()
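With this change, `TextChunker` enforces two independent budgets: words against `chunk_size` and tokens against the new `max_chunk_tokens`. A minimal sketch of that fit check (standalone, with illustrative numbers; not the production class):

```python
# Minimal sketch of TextChunker's dual fit check; names and numbers are illustrative.
def chunk_fits(
    word_count_before: int,
    token_count_before: int,
    chunk_words: int,
    chunk_tokens: int,
    max_chunk_size: int,
    max_chunk_tokens: int,
) -> bool:
    word_count_fits = word_count_before + chunk_words <= max_chunk_size
    token_count_fits = token_count_before + chunk_tokens <= max_chunk_tokens
    # Both limits must hold; whichever budget is tighter closes the chunk.
    return word_count_fits and token_count_fits

# 900 words / 7000 tokens accumulated, and the next paragraph adds 100 words / 1100 tokens:
# words still fit (1000 <= 1024) but tokens overflow (8100 > 8000), so a new chunk starts.
assert chunk_fits(900, 7000, 100, 1100, 1024, 8000) is False
```

The same check previously reached into the vector engine on every call; passing `max_chunk_tokens` in makes the limit explicit at construction time.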
diff --git a/cognee/modules/data/processing/document_types/ImageDocument.py b/cognee/modules/data/processing/document_types/ImageDocument.py
index 3b5c27d75..5f4cb287c 100644
--- a/cognee/modules/data/processing/document_types/ImageDocument.py
+++ b/cognee/modules/data/processing/document_types/ImageDocument.py
@@ -13,11 +13,13 @@ class ImageDocument(Document):
         result = get_llm_client().transcribe_image(self.raw_data_location)
         return result.choices[0].message.content
 
-    def read(self, chunk_size: int, chunker: str):
+    def read(self, chunk_size: int, chunker: str, max_chunk_tokens: int):
         # Transcribe the image file
         text = self.transcribe_image()
 
         chunker_func = ChunkerConfig.get_chunker(chunker)
-        chunker = chunker_func(self, chunk_size=chunk_size, get_text=lambda: [text])
+        chunker = chunker_func(
+            self, chunk_size=chunk_size, get_text=lambda: [text], max_chunk_tokens=max_chunk_tokens
+        )
 
         yield from chunker.read()
diff --git a/cognee/modules/data/processing/document_types/PdfDocument.py b/cognee/modules/data/processing/document_types/PdfDocument.py
index 0142610bf..8273e0177 100644
--- a/cognee/modules/data/processing/document_types/PdfDocument.py
+++ b/cognee/modules/data/processing/document_types/PdfDocument.py
@@ -9,7 +9,7 @@ from .Document import Document
 class PdfDocument(Document):
     type: str = "pdf"
 
-    def read(self, chunk_size: int, chunker: str):
+    def read(self, chunk_size: int, chunker: str, max_chunk_tokens: int):
         file = PdfReader(self.raw_data_location)
 
         def get_text():
@@ -18,7 +18,9 @@ class PdfDocument(Document):
                 yield page_text
 
         chunker_func = ChunkerConfig.get_chunker(chunker)
-        chunker = chunker_func(self, chunk_size=chunk_size, get_text=get_text)
+        chunker = chunker_func(
+            self, chunk_size=chunk_size, get_text=get_text, max_chunk_tokens=max_chunk_tokens
+        )
 
         yield from chunker.read()
diff --git a/cognee/modules/data/processing/document_types/TextDocument.py b/cognee/modules/data/processing/document_types/TextDocument.py
index 692edcb89..6bef959de 100644
--- a/cognee/modules/data/processing/document_types/TextDocument.py
+++ b/cognee/modules/data/processing/document_types/TextDocument.py
@@ -7,7 +7,7 @@ from .Document import Document
 class TextDocument(Document):
     type: str = "text"
 
-    def read(self, chunk_size: int, chunker: str):
+    def read(self, chunk_size: int, chunker: str, max_chunk_tokens: int):
         def get_text():
             with open(self.raw_data_location, mode="r", encoding="utf-8") as file:
                 while True:
@@ -20,6 +20,8 @@ class TextDocument(Document):
 
         chunker_func = ChunkerConfig.get_chunker(chunker)
 
-        chunker = chunker_func(self, chunk_size=chunk_size, get_text=get_text)
+        chunker = chunker_func(
+            self, chunk_size=chunk_size, get_text=get_text, max_chunk_tokens=max_chunk_tokens
+        )
 
         yield from chunker.read()
diff --git a/cognee/modules/data/processing/document_types/UnstructuredDocument.py b/cognee/modules/data/processing/document_types/UnstructuredDocument.py
index b3849616f..254958d14 100644
--- a/cognee/modules/data/processing/document_types/UnstructuredDocument.py
+++ b/cognee/modules/data/processing/document_types/UnstructuredDocument.py
@@ -10,7 +10,7 @@ from .Document import Document
 class UnstructuredDocument(Document):
     type: str = "unstructured"
 
-    def read(self, chunk_size: int, chunker: str) -> str:
+    def read(self, chunk_size: int, chunker: str, max_chunk_tokens: int) -> str:
         def get_text():
             try:
                 from unstructured.partition.auto import partition
@@ -29,6 +29,8 @@ class UnstructuredDocument(Document):
 
             yield text
 
-        chunker = TextChunker(self, chunk_size=chunk_size, get_text=get_text)
+        chunker = TextChunker(
+            self, chunk_size=chunk_size, get_text=get_text, max_chunk_tokens=max_chunk_tokens
+        )
 
         yield from chunker.read()
diff --git a/cognee/tasks/chunks/chunk_by_paragraph.py b/cognee/tasks/chunks/chunk_by_paragraph.py
index 7d7221b87..34205d9f6 100644
--- a/cognee/tasks/chunks/chunk_by_paragraph.py
+++ b/cognee/tasks/chunks/chunk_by_paragraph.py
@@ -4,13 +4,13 @@ from uuid import NAMESPACE_OID, uuid5
 import tiktoken
 
 from cognee.infrastructure.databases.vector import get_vector_engine
-from cognee.infrastructure.databases.vector.embeddings import get_embedding_engine
 
 from .chunk_by_sentence import chunk_by_sentence
 
 
 def chunk_by_paragraph(
     data: str,
+    max_chunk_tokens,
     paragraph_length: int = 1024,
     batch_paragraphs: bool = True,
 ) -> Iterator[Dict[str, Any]]:
@@ -31,19 +31,21 @@ def chunk_by_paragraph(
     last_cut_type = None
     current_token_count = 0
 
-    # Get vector and embedding engine
     vector_engine = get_vector_engine()
-    embedding_engine = vector_engine.embedding_engine
+    embedding_model = vector_engine.embedding_engine.model
+    embedding_model = embedding_model.split("/")[-1]
 
     for paragraph_id, sentence, word_count, end_type in chunk_by_sentence(
         data, maximum_length=paragraph_length
     ):
         # Check if this sentence would exceed length limit
-        token_count = embedding_engine.tokenizer.count_tokens(sentence)
+
+        tokenizer = tiktoken.encoding_for_model(embedding_model)
+        token_count = len(tokenizer.encode(sentence))
 
         if current_word_count > 0 and (
             current_word_count + word_count > paragraph_length
-            or current_token_count + token_count > embedding_engine.max_tokens
+            or current_token_count + token_count > max_chunk_tokens
         ):
             # Yield current chunk
             chunk_dict = {
diff --git a/cognee/tasks/documents/extract_chunks_from_documents.py b/cognee/tasks/documents/extract_chunks_from_documents.py
index 3dc35ce57..4a089c7bc 100644
--- a/cognee/tasks/documents/extract_chunks_from_documents.py
+++ b/cognee/tasks/documents/extract_chunks_from_documents.py
@@ -5,6 +5,7 @@ from cognee.modules.data.processing.document_types.Document import Document
 
 async def extract_chunks_from_documents(
     documents: list[Document],
+    max_chunk_tokens: int,
     chunk_size: int = 1024,
     chunker="text_chunker",
 ) -> AsyncGenerator:
@@ -16,5 +17,7 @@ async def extract_chunks_from_documents(
     - The `chunker` parameter determines the chunking logic and should align with the document type.
     """
     for document in documents:
-        for document_chunk in document.read(chunk_size=chunk_size, chunker=chunker):
+        for document_chunk in document.read(
+            chunk_size=chunk_size, chunker=chunker, max_chunk_tokens=max_chunk_tokens
+        ):
             yield document_chunk
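Taken together, every `Document.read` caller must now supply `max_chunk_tokens`. A hypothetical driver showing the updated call chain end to end (the empty document list and the sizing values are illustrative; the imports match the modules touched in this diff):

```python
# Hypothetical driver for the updated signatures; not part of this diff.
import asyncio

from cognee.infrastructure.databases.vector import get_vector_engine
from cognee.infrastructure.llm.get_llm_client import get_llm_client
from cognee.tasks.documents.extract_chunks_from_documents import (
    extract_chunks_from_documents,
)


async def main():
    # Same sizing rule as get_default_tasks in cognify_v2.py.
    embedding_engine = get_vector_engine().embedding_engine
    llm_client = get_llm_client()
    max_chunk_tokens = min(embedding_engine.max_tokens, llm_client.max_tokens // 2)

    documents = []  # previously classified Document instances go here
    async for chunk in extract_chunks_from_documents(
        documents, max_chunk_tokens=max_chunk_tokens, chunk_size=1024
    ):
        print(chunk)


asyncio.run(main())
```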