diff --git a/.env.template b/.env.template index 75a57de4d..df8408518 100644 --- a/.env.template +++ b/.env.template @@ -1,12 +1,28 @@ ENV="local" TOKENIZERS_PARALLELISM="false" -LLM_API_KEY= + +# LLM settings +LLM_API_KEY="" +LLM_MODEL="openai/gpt-4o-mini" +LLM_PROVIDER="openai" +LLM_ENDPOINT="" +LLM_API_VERSION="" +LLM_MAX_TOKENS="16384" GRAPHISTRY_USERNAME= GRAPHISTRY_PASSWORD= SENTRY_REPORTING_URL= +# Embedding settings +EMBEDDING_PROVIDER="openai" +EMBEDDING_API_KEY="" +EMBEDDING_MODEL="openai/text-embedding-3-large" +EMBEDDING_ENDPOINT="" +EMBEDDING_API_VERSION="" +EMBEDDING_DIMENSIONS=3072 +EMBEDDING_MAX_TOKENS=8191 + # "neo4j" or "networkx" GRAPH_DATABASE_PROVIDER="networkx" # Not needed if using networkx diff --git a/cognee/api/v1/cognify/code_graph_pipeline.py b/cognee/api/v1/cognify/code_graph_pipeline.py index c08e906ee..c73e90c19 100644 --- a/cognee/api/v1/cognify/code_graph_pipeline.py +++ b/cognee/api/v1/cognify/code_graph_pipeline.py @@ -20,6 +20,7 @@ from cognee.tasks.repo_processor import ( from cognee.tasks.repo_processor.get_source_code_chunks import get_source_code_chunks from cognee.tasks.storage import add_data_points from cognee.tasks.summarization import summarize_code, summarize_text +from cognee.infrastructure.llm import get_max_chunk_tokens monitoring = get_base_config().monitoring_tool if monitoring == MonitoringTool.LANGFUSE: @@ -57,7 +58,7 @@ async def run_code_graph_pipeline(repo_path, include_docs=True): Task(ingest_data, dataset_name="repo_docs", user=user), Task(get_data_list_for_user, dataset_name="repo_docs", user=user), Task(classify_documents), - Task(extract_chunks_from_documents, max_tokens=cognee_config.max_tokens), + Task(extract_chunks_from_documents, max_chunk_tokens=get_max_chunk_tokens()), Task( extract_graph_from_data, graph_model=KnowledgeGraph, task_config={"batch_size": 50} ), diff --git a/cognee/api/v1/cognify/cognify_v2.py b/cognee/api/v1/cognify/cognify_v2.py index 738f77c52..12a84030d 100644 --- a/cognee/api/v1/cognify/cognify_v2.py +++ b/cognee/api/v1/cognify/cognify_v2.py @@ -4,6 +4,7 @@ from typing import Union from pydantic import BaseModel +from cognee.infrastructure.llm import get_max_chunk_tokens from cognee.modules.cognify.config import get_cognify_config from cognee.modules.data.methods import get_datasets, get_datasets_by_name from cognee.modules.data.methods.get_dataset_data import get_dataset_data @@ -151,7 +152,9 @@ async def get_default_tasks( default_tasks = [ Task(classify_documents), Task(check_permissions_on_documents, user=user, permissions=["write"]), - Task(extract_chunks_from_documents), # Extract text chunks based on the document type. + Task( + extract_chunks_from_documents, max_chunk_tokens=get_max_chunk_tokens() + ), # Extract text chunks based on the document type. Task( extract_graph_from_data, graph_model=graph_model, task_config={"batch_size": 10} ), # Generate knowledge graphs from the document chunks. 
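Editorial note on the wiring above (not part of the patch): the cognify pipelines no longer read a MAX_TOKENS setting; they pass max_chunk_tokens=get_max_chunk_tokens() to the chunk-extraction task, with the budget derived from the configured LLM and embedding models (the helper itself is added later in this patch, in cognee/infrastructure/llm/utils.py). A minimal sketch of how that budget could be inspected, assuming a fully configured environment using the .env defaults shown above:

```python
# Illustration only, not part of the patch. Assumes cognee's LLM client and
# vector engine can be constructed from the .env defaults shown above.
from cognee.infrastructure.llm import get_max_chunk_tokens

# The budget is min(embedding max tokens, LLM max tokens // 2). With
# EMBEDDING_MAX_TOKENS=8191 and LLM_MAX_TOKENS=16384 that works out to
# min(8191, 8192) = 8191, unless litellm reports a different limit for the model.
max_chunk_tokens = get_max_chunk_tokens()
print(max_chunk_tokens)
```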
diff --git a/cognee/infrastructure/databases/vector/embeddings/LiteLLMEmbeddingEngine.py b/cognee/infrastructure/databases/vector/embeddings/LiteLLMEmbeddingEngine.py index f0a40ca36..f81bc8515 100644 --- a/cognee/infrastructure/databases/vector/embeddings/LiteLLMEmbeddingEngine.py +++ b/cognee/infrastructure/databases/vector/embeddings/LiteLLMEmbeddingEngine.py @@ -6,6 +6,9 @@ import litellm import os from cognee.infrastructure.databases.vector.embeddings.EmbeddingEngine import EmbeddingEngine from cognee.infrastructure.databases.exceptions.EmbeddingException import EmbeddingException +from cognee.infrastructure.llm.tokenizer.Gemini import GeminiTokenizer +from cognee.infrastructure.llm.tokenizer.HuggingFace import HuggingFaceTokenizer +from cognee.infrastructure.llm.tokenizer.TikToken import TikTokenTokenizer litellm.set_verbose = False logger = logging.getLogger("LiteLLMEmbeddingEngine") @@ -15,23 +18,29 @@ class LiteLLMEmbeddingEngine(EmbeddingEngine): api_key: str endpoint: str api_version: str + provider: str model: str dimensions: int mock: bool def __init__( self, + provider: str = "openai", model: Optional[str] = "text-embedding-3-large", dimensions: Optional[int] = 3072, api_key: str = None, endpoint: str = None, api_version: str = None, + max_tokens: int = 512, ): self.api_key = api_key self.endpoint = endpoint self.api_version = api_version + self.provider = provider self.model = model self.dimensions = dimensions + self.max_tokens = max_tokens + self.tokenizer = self.get_tokenizer() enable_mocking = os.getenv("MOCK_EMBEDDING", "false") if isinstance(enable_mocking, bool): @@ -104,3 +113,18 @@ class LiteLLMEmbeddingEngine(EmbeddingEngine): def get_vector_size(self) -> int: return self.dimensions + + def get_tokenizer(self): + logger.debug(f"Loading tokenizer for model {self.model}...") + # If model also contains provider information, extract only model information + model = self.model.split("/")[-1] + + if "openai" in self.provider.lower(): + tokenizer = TikTokenTokenizer(model=model, max_tokens=self.max_tokens) + elif "gemini" in self.provider.lower(): + tokenizer = GeminiTokenizer(model=model, max_tokens=self.max_tokens) + else: + tokenizer = HuggingFaceTokenizer(model=self.model, max_tokens=self.max_tokens) + + logger.debug(f"Tokenizer loaded for model: {self.model}") + return tokenizer diff --git a/cognee/infrastructure/databases/vector/embeddings/config.py b/cognee/infrastructure/databases/vector/embeddings/config.py index 042c063f8..315caf7ef 100644 --- a/cognee/infrastructure/databases/vector/embeddings/config.py +++ b/cognee/infrastructure/databases/vector/embeddings/config.py @@ -4,12 +4,13 @@ from pydantic_settings import BaseSettings, SettingsConfigDict class EmbeddingConfig(BaseSettings): - embedding_model: Optional[str] = "text-embedding-3-large" + embedding_provider: Optional[str] = "openai" + embedding_model: Optional[str] = "openai/text-embedding-3-large" embedding_dimensions: Optional[int] = 3072 embedding_endpoint: Optional[str] = None embedding_api_key: Optional[str] = None embedding_api_version: Optional[str] = None - + embedding_max_tokens: Optional[int] = 8191 model_config = SettingsConfigDict(env_file=".env", extra="allow") diff --git a/cognee/infrastructure/databases/vector/embeddings/get_embedding_engine.py b/cognee/infrastructure/databases/vector/embeddings/get_embedding_engine.py index 6bfb4dd15..d3011f059 100644 --- a/cognee/infrastructure/databases/vector/embeddings/get_embedding_engine.py +++ 
b/cognee/infrastructure/databases/vector/embeddings/get_embedding_engine.py @@ -10,9 +10,11 @@ def get_embedding_engine() -> EmbeddingEngine: return LiteLLMEmbeddingEngine( # If OpenAI API is used for embeddings, litellm needs only the api_key. + provider=config.embedding_provider, api_key=config.embedding_api_key or llm_config.llm_api_key, endpoint=config.embedding_endpoint, api_version=config.embedding_api_version, model=config.embedding_model, dimensions=config.embedding_dimensions, + max_tokens=config.embedding_max_tokens, ) diff --git a/cognee/infrastructure/llm/__init__.py b/cognee/infrastructure/llm/__init__.py index 7fb3be736..36d7e56ad 100644 --- a/cognee/infrastructure/llm/__init__.py +++ b/cognee/infrastructure/llm/__init__.py @@ -1 +1,2 @@ from .config import get_llm_config +from .utils import get_max_chunk_tokens diff --git a/cognee/infrastructure/llm/anthropic/adapter.py b/cognee/infrastructure/llm/anthropic/adapter.py index 1fba732a0..cfeb15d25 100644 --- a/cognee/infrastructure/llm/anthropic/adapter.py +++ b/cognee/infrastructure/llm/anthropic/adapter.py @@ -14,11 +14,12 @@ class AnthropicAdapter(LLMInterface): name = "Anthropic" model: str - def __init__(self, model: str = None): + def __init__(self, max_tokens: int, model: str = None): self.aclient = instructor.patch( create=anthropic.Anthropic().messages.create, mode=instructor.Mode.ANTHROPIC_TOOLS ) self.model = model + self.max_tokens = max_tokens async def acreate_structured_output( self, text_input: str, system_prompt: str, response_model: Type[BaseModel] diff --git a/cognee/infrastructure/llm/config.py b/cognee/infrastructure/llm/config.py index 67fc82683..48c94423e 100644 --- a/cognee/infrastructure/llm/config.py +++ b/cognee/infrastructure/llm/config.py @@ -11,6 +11,7 @@ class LLMConfig(BaseSettings): llm_api_version: Optional[str] = None llm_temperature: float = 0.0 llm_streaming: bool = False + llm_max_tokens: int = 16384 transcription_model: str = "whisper-1" model_config = SettingsConfigDict(env_file=".env", extra="allow") @@ -24,6 +25,7 @@ class LLMConfig(BaseSettings): "api_version": self.llm_api_version, "temperature": self.llm_temperature, "streaming": self.llm_streaming, + "max_tokens": self.llm_max_tokens, "transcription_model": self.transcription_model, } diff --git a/cognee/infrastructure/llm/generic_llm_api/adapter.py b/cognee/infrastructure/llm/generic_llm_api/adapter.py index a910c0780..98002b2cb 100644 --- a/cognee/infrastructure/llm/generic_llm_api/adapter.py +++ b/cognee/infrastructure/llm/generic_llm_api/adapter.py @@ -2,6 +2,7 @@ import asyncio from typing import List, Type + from pydantic import BaseModel import instructor from cognee.infrastructure.llm.llm_interface import LLMInterface @@ -16,11 +17,12 @@ class GenericAPIAdapter(LLMInterface): model: str api_key: str - def __init__(self, endpoint, api_key: str, model: str, name: str): + def __init__(self, endpoint, api_key: str, model: str, name: str, max_tokens: int): self.name = name self.model = model self.api_key = api_key self.endpoint = endpoint + self.max_tokens = max_tokens llm_config = get_llm_config() diff --git a/cognee/infrastructure/llm/get_llm_client.py b/cognee/infrastructure/llm/get_llm_client.py index 0f64014e3..ede8bd330 100644 --- a/cognee/infrastructure/llm/get_llm_client.py +++ b/cognee/infrastructure/llm/get_llm_client.py @@ -20,6 +20,15 @@ def get_llm_client(): provider = LLMProvider(llm_config.llm_provider) + # Check if max_token value is defined in liteLLM for given model + # if not use value from cognee configuration 
+ from cognee.infrastructure.llm.utils import ( + get_model_max_tokens, + ) # imported here to avoid circular imports + + model_max_tokens = get_model_max_tokens(llm_config.llm_model) + max_tokens = model_max_tokens if model_max_tokens else llm_config.llm_max_tokens + if provider == LLMProvider.OPENAI: if llm_config.llm_api_key is None: raise InvalidValueError(message="LLM API key is not set.") @@ -32,6 +41,7 @@ def get_llm_client(): api_version=llm_config.llm_api_version, model=llm_config.llm_model, transcription_model=llm_config.transcription_model, + max_tokens=max_tokens, streaming=llm_config.llm_streaming, ) @@ -42,13 +52,17 @@ def get_llm_client(): from .generic_llm_api.adapter import GenericAPIAdapter return GenericAPIAdapter( - llm_config.llm_endpoint, llm_config.llm_api_key, llm_config.llm_model, "Ollama" + llm_config.llm_endpoint, + llm_config.llm_api_key, + llm_config.llm_model, + "Ollama", + max_tokens=max_tokens, ) elif provider == LLMProvider.ANTHROPIC: from .anthropic.adapter import AnthropicAdapter - return AnthropicAdapter(llm_config.llm_model) + return AnthropicAdapter(max_tokens=max_tokens, model=llm_config.llm_model) elif provider == LLMProvider.CUSTOM: if llm_config.llm_api_key is None: @@ -57,7 +71,11 @@ def get_llm_client(): from .generic_llm_api.adapter import GenericAPIAdapter return GenericAPIAdapter( - llm_config.llm_endpoint, llm_config.llm_api_key, llm_config.llm_model, "Custom" + llm_config.llm_endpoint, + llm_config.llm_api_key, + llm_config.llm_model, + "Custom", + max_tokens=max_tokens, ) else: diff --git a/cognee/infrastructure/llm/openai/adapter.py b/cognee/infrastructure/llm/openai/adapter.py index d45662380..d6939e323 100644 --- a/cognee/infrastructure/llm/openai/adapter.py +++ b/cognee/infrastructure/llm/openai/adapter.py @@ -32,6 +32,7 @@ class OpenAIAdapter(LLMInterface): api_version: str, model: str, transcription_model: str, + max_tokens: int, streaming: bool = False, ): self.aclient = instructor.from_litellm(litellm.acompletion) @@ -41,6 +42,7 @@ class OpenAIAdapter(LLMInterface): self.api_key = api_key self.endpoint = endpoint self.api_version = api_version + self.max_tokens = max_tokens self.streaming = streaming @observe(as_type="generation") diff --git a/cognee/infrastructure/llm/tokenizer/Gemini/__init__.py b/cognee/infrastructure/llm/tokenizer/Gemini/__init__.py new file mode 100644 index 000000000..4ed4ad4d5 --- /dev/null +++ b/cognee/infrastructure/llm/tokenizer/Gemini/__init__.py @@ -0,0 +1 @@ +from .adapter import GeminiTokenizer diff --git a/cognee/infrastructure/llm/tokenizer/Gemini/adapter.py b/cognee/infrastructure/llm/tokenizer/Gemini/adapter.py new file mode 100644 index 000000000..e4cc4f145 --- /dev/null +++ b/cognee/infrastructure/llm/tokenizer/Gemini/adapter.py @@ -0,0 +1,44 @@ +from typing import List, Any + +from ..tokenizer_interface import TokenizerInterface + + +class GeminiTokenizer(TokenizerInterface): + def __init__( + self, + model: str, + max_tokens: int = 3072, + ): + self.model = model + self.max_tokens = max_tokens + + # Get LLM API key from config + from cognee.infrastructure.databases.vector.embeddings.config import get_embedding_config + from cognee.infrastructure.llm.config import get_llm_config + + config = get_embedding_config() + llm_config = get_llm_config() + + import google.generativeai as genai + + genai.configure(api_key=config.embedding_api_key or llm_config.llm_api_key) + + def extract_tokens(self, text: str) -> List[Any]: + raise NotImplementedError + + def count_tokens(self, text: str) -> int: + """ 
+        Returns the number of tokens in the given text. + Args: + text: str + + Returns: + number of tokens in the given text + + """ + import google.generativeai as genai + + return len(genai.embed_content(model=f"models/{self.model}", content=text)) + + def trim_text_to_max_tokens(self, text: str) -> str: + raise NotImplementedError diff --git a/cognee/infrastructure/llm/tokenizer/HuggingFace/__init__.py b/cognee/infrastructure/llm/tokenizer/HuggingFace/__init__.py new file mode 100644 index 000000000..7cdfb9aa0 --- /dev/null +++ b/cognee/infrastructure/llm/tokenizer/HuggingFace/__init__.py @@ -0,0 +1 @@ +from .adapter import HuggingFaceTokenizer diff --git a/cognee/infrastructure/llm/tokenizer/HuggingFace/adapter.py b/cognee/infrastructure/llm/tokenizer/HuggingFace/adapter.py new file mode 100644 index 000000000..878458414 --- /dev/null +++ b/cognee/infrastructure/llm/tokenizer/HuggingFace/adapter.py @@ -0,0 +1,36 @@ +from typing import List, Any + +from transformers import AutoTokenizer + +from ..tokenizer_interface import TokenizerInterface + + +class HuggingFaceTokenizer(TokenizerInterface): + def __init__( + self, + model: str, + max_tokens: int = 512, + ): + self.model = model + self.max_tokens = max_tokens + + self.tokenizer = AutoTokenizer.from_pretrained(model) + + def extract_tokens(self, text: str) -> List[Any]: + tokens = self.tokenizer.tokenize(text) + return tokens + + def count_tokens(self, text: str) -> int: + """ + Returns the number of tokens in the given text. + Args: + text: str + + Returns: + number of tokens in the given text + + """ + return len(self.tokenizer.tokenize(text)) + + def trim_text_to_max_tokens(self, text: str) -> str: + raise NotImplementedError diff --git a/cognee/infrastructure/llm/tokenizer/TikToken/__init__.py b/cognee/infrastructure/llm/tokenizer/TikToken/__init__.py new file mode 100644 index 000000000..4c7a39401 --- /dev/null +++ b/cognee/infrastructure/llm/tokenizer/TikToken/__init__.py @@ -0,0 +1 @@ +from .adapter import TikTokenTokenizer diff --git a/cognee/infrastructure/llm/tokenizer/TikToken/adapter.py b/cognee/infrastructure/llm/tokenizer/TikToken/adapter.py new file mode 100644 index 000000000..3d649ef38 --- /dev/null +++ b/cognee/infrastructure/llm/tokenizer/TikToken/adapter.py @@ -0,0 +1,69 @@ +from typing import List, Any +import tiktoken + +from ..tokenizer_interface import TokenizerInterface + + +class TikTokenTokenizer(TokenizerInterface): + """ + Tokenizer adapter for OpenAI. + Intended to be used as part of LLM Embedding and LLM Adapter classes + """ + + def __init__( + self, + model: str, + max_tokens: int = 8191, + ): + self.model = model + self.max_tokens = max_tokens + # Initialize TikToken for GPT based on model + self.tokenizer = tiktoken.encoding_for_model(self.model) + + def extract_tokens(self, text: str) -> List[Any]: + tokens = [] + # Using TikToken's method to tokenize text + token_ids = self.tokenizer.encode(text) + # Go through tokens and decode them to text value + for token_id in token_ids: + token = self.tokenizer.decode([token_id]) + tokens.append(token) + return tokens + + def count_tokens(self, text: str) -> int: + """ + Returns the number of tokens in the given text. + Args: + text: str + + Returns: + number of tokens in the given text + + """ + num_tokens = len(self.tokenizer.encode(text)) + return num_tokens + + def trim_text_to_max_tokens(self, text: str) -> str: + """ + Trims the text so that the number of tokens does not exceed max_tokens. + + Args: + text (str): Original text string to be trimmed.
+ + Returns: + str: Trimmed version of text or original text if under the limit. + """ + # First check the number of tokens + num_tokens = self.count_tokens(text) + + # If the number of tokens is within the limit, return the text as is + if num_tokens <= self.max_tokens: + return text + + # If the number exceeds the limit, trim the text + # This is a simple trim, it may cut words in half; consider using word boundaries for a cleaner cut + encoded_text = self.tokenizer.encode(text) + trimmed_encoded_text = encoded_text[: self.max_tokens] + # Decoding the trimmed text + trimmed_text = self.tokenizer.decode(trimmed_encoded_text) + return trimmed_text diff --git a/cognee/infrastructure/llm/tokenizer/__init__.py b/cognee/infrastructure/llm/tokenizer/__init__.py new file mode 100644 index 000000000..0cc895c13 --- /dev/null +++ b/cognee/infrastructure/llm/tokenizer/__init__.py @@ -0,0 +1 @@ +from .tokenizer_interface import TokenizerInterface diff --git a/cognee/infrastructure/llm/tokenizer/tokenizer_interface.py b/cognee/infrastructure/llm/tokenizer/tokenizer_interface.py new file mode 100644 index 000000000..c533f0cf9 --- /dev/null +++ b/cognee/infrastructure/llm/tokenizer/tokenizer_interface.py @@ -0,0 +1,18 @@ +from typing import List, Protocol, Any +from abc import abstractmethod + + +class TokenizerInterface(Protocol): + """Tokenizer interface""" + + @abstractmethod + def extract_tokens(self, text: str) -> List[Any]: + raise NotImplementedError + + @abstractmethod + def count_tokens(self, text: str) -> int: + raise NotImplementedError + + @abstractmethod + def trim_text_to_max_tokens(self, text: str) -> str: + raise NotImplementedError diff --git a/cognee/infrastructure/llm/utils.py b/cognee/infrastructure/llm/utils.py new file mode 100644 index 000000000..e0aa8945a --- /dev/null +++ b/cognee/infrastructure/llm/utils.py @@ -0,0 +1,38 @@ +import logging +import litellm + +from cognee.infrastructure.databases.vector import get_vector_engine +from cognee.infrastructure.llm.get_llm_client import get_llm_client + +logger = logging.getLogger(__name__) + + +def get_max_chunk_tokens(): + # Calculate max chunk size based on the following formula + embedding_engine = get_vector_engine().embedding_engine + llm_client = get_llm_client() + + # We need to make sure chunk size won't take more than half of LLM max context token size + # but it also can't be bigger than the embedding engine max token size + llm_cutoff_point = llm_client.max_tokens // 2 # Round down the division + max_chunk_tokens = min(embedding_engine.max_tokens, llm_cutoff_point) + + return max_chunk_tokens + + +def get_model_max_tokens(model_name: str): + """ + Args: + model_name: name of LLM or embedding model + + Returns: Number of max tokens of model, or None if model is unknown + """ + max_tokens = None + + if model_name in litellm.model_cost: + max_tokens = litellm.model_cost[model_name]["max_tokens"] + logger.debug(f"Max input tokens for {model_name}: {max_tokens}") + else: + logger.info("Model not found in LiteLLM's model_cost.") + + return max_tokens diff --git a/cognee/modules/chunking/TextChunker.py b/cognee/modules/chunking/TextChunker.py index cd71dd128..f9a664d8b 100644 --- a/cognee/modules/chunking/TextChunker.py +++ b/cognee/modules/chunking/TextChunker.py @@ -14,17 +14,15 @@ class TextChunker: chunk_size = 0 token_count = 0 - def __init__( - self, document, get_text: callable, max_tokens: Optional[int] = None, chunk_size: int = 1024 - ): + def __init__(self, document, get_text: callable, max_chunk_tokens: int, chunk_size: 
int = 1024): self.document = document self.max_chunk_size = chunk_size self.get_text = get_text - self.max_tokens = max_tokens if max_tokens else float("inf") + self.max_chunk_tokens = max_chunk_tokens def check_word_count_and_token_count(self, word_count_before, token_count_before, chunk_data): word_count_fits = word_count_before + chunk_data["word_count"] <= self.max_chunk_size - token_count_fits = token_count_before + chunk_data["token_count"] <= self.max_tokens + token_count_fits = token_count_before + chunk_data["token_count"] <= self.max_chunk_tokens return word_count_fits and token_count_fits def read(self): @@ -32,7 +30,7 @@ class TextChunker: for content_text in self.get_text(): for chunk_data in chunk_by_paragraph( content_text, - self.max_tokens, + self.max_chunk_tokens, self.max_chunk_size, batch_paragraphs=True, ): diff --git a/cognee/modules/cognify/config.py b/cognee/modules/cognify/config.py index dd94d8b41..4ba0f4bd6 100644 --- a/cognee/modules/cognify/config.py +++ b/cognee/modules/cognify/config.py @@ -8,7 +8,6 @@ import os class CognifyConfig(BaseSettings): classification_model: object = DefaultContentPrediction summarization_model: object = SummarizedContent - max_tokens: Optional[int] = os.getenv("MAX_TOKENS") model_config = SettingsConfigDict(env_file=".env", extra="allow") def to_dict(self) -> dict: diff --git a/cognee/modules/data/processing/document_types/AudioDocument.py b/cognee/modules/data/processing/document_types/AudioDocument.py index b7d2476b4..75152fd3d 100644 --- a/cognee/modules/data/processing/document_types/AudioDocument.py +++ b/cognee/modules/data/processing/document_types/AudioDocument.py @@ -13,14 +13,14 @@ class AudioDocument(Document): result = get_llm_client().create_transcript(self.raw_data_location) return result.text - def read(self, chunk_size: int, chunker: str, max_tokens: Optional[int] = None): + def read(self, chunk_size: int, chunker: str, max_chunk_tokens: int): # Transcribe the audio file text = self.create_transcript() chunker_func = ChunkerConfig.get_chunker(chunker) chunker = chunker_func( - self, chunk_size=chunk_size, get_text=lambda: [text], max_tokens=max_tokens + self, chunk_size=chunk_size, get_text=lambda: [text], max_chunk_tokens=max_chunk_tokens ) yield from chunker.read() diff --git a/cognee/modules/data/processing/document_types/Document.py b/cognee/modules/data/processing/document_types/Document.py index 4d9f3bf72..76ff1e045 100644 --- a/cognee/modules/data/processing/document_types/Document.py +++ b/cognee/modules/data/processing/document_types/Document.py @@ -11,5 +11,5 @@ class Document(DataPoint): mime_type: str _metadata: dict = {"index_fields": ["name"], "type": "Document"} - def read(self, chunk_size: int, chunker=str, max_tokens: Optional[int] = None) -> str: + def read(self, chunk_size: int, chunker=str, max_chunk_tokens: Optional[int] = None) -> str: pass diff --git a/cognee/modules/data/processing/document_types/ImageDocument.py b/cognee/modules/data/processing/document_types/ImageDocument.py index c055b8253..5f4cb287c 100644 --- a/cognee/modules/data/processing/document_types/ImageDocument.py +++ b/cognee/modules/data/processing/document_types/ImageDocument.py @@ -13,13 +13,13 @@ class ImageDocument(Document): result = get_llm_client().transcribe_image(self.raw_data_location) return result.choices[0].message.content - def read(self, chunk_size: int, chunker: str, max_tokens: Optional[int] = None): + def read(self, chunk_size: int, chunker: str, max_chunk_tokens: int): # Transcribe the image file text = 
self.transcribe_image() chunker_func = ChunkerConfig.get_chunker(chunker) chunker = chunker_func( - self, chunk_size=chunk_size, get_text=lambda: [text], max_tokens=max_tokens + self, chunk_size=chunk_size, get_text=lambda: [text], max_chunk_tokens=max_chunk_tokens ) yield from chunker.read() diff --git a/cognee/modules/data/processing/document_types/PdfDocument.py b/cognee/modules/data/processing/document_types/PdfDocument.py index 768f91264..8273e0177 100644 --- a/cognee/modules/data/processing/document_types/PdfDocument.py +++ b/cognee/modules/data/processing/document_types/PdfDocument.py @@ -9,7 +9,7 @@ from .Document import Document class PdfDocument(Document): type: str = "pdf" - def read(self, chunk_size: int, chunker: str, max_tokens: Optional[int] = None): + def read(self, chunk_size: int, chunker: str, max_chunk_tokens: int): file = PdfReader(self.raw_data_location) def get_text(): @@ -19,7 +19,7 @@ class PdfDocument(Document): chunker_func = ChunkerConfig.get_chunker(chunker) chunker = chunker_func( - self, chunk_size=chunk_size, get_text=get_text, max_tokens=max_tokens + self, chunk_size=chunk_size, get_text=get_text, max_chunk_tokens=max_chunk_tokens ) yield from chunker.read() diff --git a/cognee/modules/data/processing/document_types/TextDocument.py b/cognee/modules/data/processing/document_types/TextDocument.py index b62ccd56e..6bef959de 100644 --- a/cognee/modules/data/processing/document_types/TextDocument.py +++ b/cognee/modules/data/processing/document_types/TextDocument.py @@ -7,7 +7,7 @@ from .Document import Document class TextDocument(Document): type: str = "text" - def read(self, chunk_size: int, chunker: str, max_tokens: Optional[int] = None): + def read(self, chunk_size: int, chunker: str, max_chunk_tokens: int): def get_text(): with open(self.raw_data_location, mode="r", encoding="utf-8") as file: while True: @@ -21,7 +21,7 @@ class TextDocument(Document): chunker_func = ChunkerConfig.get_chunker(chunker) chunker = chunker_func( - self, chunk_size=chunk_size, get_text=get_text, max_tokens=max_tokens + self, chunk_size=chunk_size, get_text=get_text, max_chunk_tokens=max_chunk_tokens ) yield from chunker.read() diff --git a/cognee/modules/data/processing/document_types/UnstructuredDocument.py b/cognee/modules/data/processing/document_types/UnstructuredDocument.py index 1c291d0dc..254958d14 100644 --- a/cognee/modules/data/processing/document_types/UnstructuredDocument.py +++ b/cognee/modules/data/processing/document_types/UnstructuredDocument.py @@ -10,7 +10,7 @@ from .Document import Document class UnstructuredDocument(Document): type: str = "unstructured" - def read(self, chunk_size: int, chunker: str, max_tokens: Optional[int] = None) -> str: + def read(self, chunk_size: int, chunker: str, max_chunk_tokens: int) -> str: def get_text(): try: from unstructured.partition.auto import partition @@ -29,6 +29,8 @@ class UnstructuredDocument(Document): yield text - chunker = TextChunker(self, chunk_size=chunk_size, get_text=get_text, max_tokens=max_tokens) + chunker = TextChunker( + self, chunk_size=chunk_size, get_text=get_text, max_chunk_tokens=max_chunk_tokens + ) yield from chunker.read() diff --git a/cognee/shared/utils.py b/cognee/shared/utils.py index 6b1ca7f8f..4e5523fd2 100644 --- a/cognee/shared/utils.py +++ b/cognee/shared/utils.py @@ -10,8 +10,6 @@ import graphistry import networkx as nx import pandas as pd import matplotlib.pyplot as plt -import tiktoken -import time import logging import sys @@ -100,15 +98,6 @@ def send_telemetry(event_name: str, user_id, 
additional_properties: dict = {}): print(f"Error sending telemetry through proxy: {response.status_code}") -def num_tokens_from_string(string: str, encoding_name: str) -> int: - """Returns the number of tokens in a text string.""" - - # tiktoken.get_encoding("cl100k_base") - encoding = tiktoken.encoding_for_model(encoding_name) - num_tokens = len(encoding.encode(string)) - return num_tokens - - def get_file_content_hash(file_obj: Union[str, BinaryIO]) -> str: h = hashlib.md5() @@ -134,34 +123,6 @@ def get_file_content_hash(file_obj: Union[str, BinaryIO]) -> str: raise IngestionError(message=f"Failed to load data from {file}: {e}") -def trim_text_to_max_tokens(text: str, max_tokens: int, encoding_name: str) -> str: - """ - Trims the text so that the number of tokens does not exceed max_tokens. - - Args: - text (str): Original text string to be trimmed. - max_tokens (int): Maximum number of tokens allowed. - encoding_name (str): The name of the token encoding to use. - - Returns: - str: Trimmed version of text or original text if under the limit. - """ - # First check the number of tokens - num_tokens = num_tokens_from_string(text, encoding_name) - - # If the number of tokens is within the limit, return the text as is - if num_tokens <= max_tokens: - return text - - # If the number exceeds the limit, trim the text - # This is a simple trim, it may cut words in half; consider using word boundaries for a cleaner cut - encoded_text = tiktoken.get_encoding(encoding_name).encode(text) - trimmed_encoded_text = encoded_text[:max_tokens] - # Decoding the trimmed text - trimmed_text = tiktoken.get_encoding(encoding_name).decode(trimmed_encoded_text) - return trimmed_text - - def generate_color_palette(unique_layers): colormap = plt.cm.get_cmap("viridis", len(unique_layers)) colors = [colormap(i) for i in range(len(unique_layers))] diff --git a/cognee/tasks/chunks/chunk_by_paragraph.py b/cognee/tasks/chunks/chunk_by_paragraph.py index 9cb1f5b54..34205d9f6 100644 --- a/cognee/tasks/chunks/chunk_by_paragraph.py +++ b/cognee/tasks/chunks/chunk_by_paragraph.py @@ -10,7 +10,7 @@ from .chunk_by_sentence import chunk_by_sentence def chunk_by_paragraph( data: str, - max_tokens: Optional[Union[int, float]] = None, + max_chunk_tokens, paragraph_length: int = 1024, batch_paragraphs: bool = True, ) -> Iterator[Dict[str, Any]]: @@ -30,8 +30,6 @@ def chunk_by_paragraph( paragraph_ids = [] last_cut_type = None current_token_count = 0 - if not max_tokens: - max_tokens = float("inf") vector_engine = get_vector_engine() embedding_model = vector_engine.embedding_engine.model @@ -47,7 +45,7 @@ def chunk_by_paragraph( if current_word_count > 0 and ( current_word_count + word_count > paragraph_length - or current_token_count + token_count > max_tokens + or current_token_count + token_count > max_chunk_tokens ): # Yield current chunk chunk_dict = { diff --git a/cognee/tasks/documents/extract_chunks_from_documents.py b/cognee/tasks/documents/extract_chunks_from_documents.py index ecdd6817d..4a089c7bc 100644 --- a/cognee/tasks/documents/extract_chunks_from_documents.py +++ b/cognee/tasks/documents/extract_chunks_from_documents.py @@ -5,9 +5,9 @@ from cognee.modules.data.processing.document_types.Document import Document async def extract_chunks_from_documents( documents: list[Document], + max_chunk_tokens: int, chunk_size: int = 1024, chunker="text_chunker", - max_tokens: Optional[int] = None, ) -> AsyncGenerator: """ Extracts chunks of data from a list of documents based on the specified chunking parameters. 
@@ -18,6 +18,6 @@ async def extract_chunks_from_documents( """ for document in documents: for document_chunk in document.read( - chunk_size=chunk_size, chunker=chunker, max_tokens=max_tokens + chunk_size=chunk_size, chunker=chunker, max_chunk_tokens=max_chunk_tokens ): yield document_chunk diff --git a/cognee/tasks/repo_processor/get_source_code_chunks.py b/cognee/tasks/repo_processor/get_source_code_chunks.py index 82fa46cf0..ca1c76e46 100644 --- a/cognee/tasks/repo_processor/get_source_code_chunks.py +++ b/cognee/tasks/repo_processor/get_source_code_chunks.py @@ -89,26 +89,29 @@ def _get_subchunk_token_counts( def _get_chunk_source_code( - code_token_counts: list[tuple[str, int]], overlap: float, max_tokens: int + code_token_counts: list[tuple[str, int]], overlap: float ) -> tuple[list[tuple[str, int]], str]: """Generates a chunk of source code from tokenized subchunks with overlap handling.""" current_count = 0 cumulative_counts = [] current_source_code = "" + # Get embedding engine used in vector database + embedding_engine = get_vector_engine().embedding_engine + for i, (child_code, token_count) in enumerate(code_token_counts): current_count += token_count cumulative_counts.append(current_count) - if current_count > max_tokens: + if current_count > embedding_engine.max_tokens: break current_source_code += f"\n{child_code}" - if current_count <= max_tokens: + if current_count <= embedding_engine.max_tokens: return [], current_source_code.strip() cutoff = 1 for i, cum_count in enumerate(cumulative_counts): - if cum_count > (1 - overlap) * max_tokens: + if cum_count > (1 - overlap) * embedding_engine.max_tokens: break cutoff = i @@ -117,21 +120,18 @@ def _get_chunk_source_code( def get_source_code_chunks_from_code_part( code_file_part: CodePart, - max_tokens: int = 8192, overlap: float = 0.25, granularity: float = 0.1, - model_name: str = "text-embedding-3-large", ) -> Generator[SourceCodeChunk, None, None]: """Yields source code chunks from a CodePart object, with configurable token limits and overlap.""" if not code_file_part.source_code: logger.error(f"No source code in CodeFile {code_file_part.id}") return - vector_engine = get_vector_engine() - embedding_model = vector_engine.embedding_engine.model - model_name = embedding_model.split("/")[-1] - tokenizer = tiktoken.encoding_for_model(model_name) - max_subchunk_tokens = max(1, int(granularity * max_tokens)) + embedding_engine = get_vector_engine().embedding_engine + tokenizer = embedding_engine.tokenizer + + max_subchunk_tokens = max(1, int(granularity * embedding_engine.max_tokens)) subchunk_token_counts = _get_subchunk_token_counts( tokenizer, code_file_part.source_code, max_subchunk_tokens ) @@ -139,7 +139,7 @@ def get_source_code_chunks_from_code_part( previous_chunk = None while subchunk_token_counts: subchunk_token_counts, chunk_source_code = _get_chunk_source_code( - subchunk_token_counts, overlap, max_tokens + subchunk_token_counts, overlap ) if not chunk_source_code: continue diff --git a/cognee/tests/integration/documents/AudioDocument_test.py b/cognee/tests/integration/documents/AudioDocument_test.py index 9719d90fc..38b547140 100644 --- a/cognee/tests/integration/documents/AudioDocument_test.py +++ b/cognee/tests/integration/documents/AudioDocument_test.py @@ -34,7 +34,7 @@ def test_AudioDocument(): ) with patch.object(AudioDocument, "create_transcript", return_value=TEST_TEXT): for ground_truth, paragraph_data in zip( - GROUND_TRUTH, document.read(chunk_size=64, chunker="text_chunker") + GROUND_TRUTH, 
document.read(chunk_size=64, chunker="text_chunker", max_chunk_tokens=512) ): assert ground_truth["word_count"] == paragraph_data.word_count, ( f'{ground_truth["word_count"] = } != {paragraph_data.word_count = }' diff --git a/cognee/tests/integration/documents/ImageDocument_test.py b/cognee/tests/integration/documents/ImageDocument_test.py index bd15961ee..faa54fa27 100644 --- a/cognee/tests/integration/documents/ImageDocument_test.py +++ b/cognee/tests/integration/documents/ImageDocument_test.py @@ -23,7 +23,7 @@ def test_ImageDocument(): ) with patch.object(ImageDocument, "transcribe_image", return_value=TEST_TEXT): for ground_truth, paragraph_data in zip( - GROUND_TRUTH, document.read(chunk_size=64, chunker="text_chunker") + GROUND_TRUTH, document.read(chunk_size=64, chunker="text_chunker", max_chunk_tokens=512) ): assert ground_truth["word_count"] == paragraph_data.word_count, ( f'{ground_truth["word_count"] = } != {paragraph_data.word_count = }' diff --git a/cognee/tests/integration/documents/PdfDocument_test.py b/cognee/tests/integration/documents/PdfDocument_test.py index 82d304b6c..e9530fc12 100644 --- a/cognee/tests/integration/documents/PdfDocument_test.py +++ b/cognee/tests/integration/documents/PdfDocument_test.py @@ -25,7 +25,7 @@ def test_PdfDocument(): ) for ground_truth, paragraph_data in zip( - GROUND_TRUTH, document.read(chunk_size=1024, chunker="text_chunker") + GROUND_TRUTH, document.read(chunk_size=1024, chunker="text_chunker", max_chunk_tokens=2048) ): assert ground_truth["word_count"] == paragraph_data.word_count, ( f'{ground_truth["word_count"] = } != {paragraph_data.word_count = }' diff --git a/cognee/tests/integration/documents/TextDocument_test.py b/cognee/tests/integration/documents/TextDocument_test.py index 17db39be8..99e28a3ac 100644 --- a/cognee/tests/integration/documents/TextDocument_test.py +++ b/cognee/tests/integration/documents/TextDocument_test.py @@ -37,7 +37,8 @@ def test_TextDocument(input_file, chunk_size): ) for ground_truth, paragraph_data in zip( - GROUND_TRUTH[input_file], document.read(chunk_size=chunk_size, chunker="text_chunker") + GROUND_TRUTH[input_file], + document.read(chunk_size=chunk_size, chunker="text_chunker", max_chunk_tokens=1024), ): assert ground_truth["word_count"] == paragraph_data.word_count, ( f'{ground_truth["word_count"] = } != {paragraph_data.word_count = }' diff --git a/cognee/tests/integration/documents/UnstructuredDocument_test.py b/cognee/tests/integration/documents/UnstructuredDocument_test.py index 81e804f07..d76843c0a 100644 --- a/cognee/tests/integration/documents/UnstructuredDocument_test.py +++ b/cognee/tests/integration/documents/UnstructuredDocument_test.py @@ -68,7 +68,9 @@ def test_UnstructuredDocument(): ) # Test PPTX - for paragraph_data in pptx_document.read(chunk_size=1024, chunker="text_chunker"): + for paragraph_data in pptx_document.read( + chunk_size=1024, chunker="text_chunker", max_chunk_tokens=1024 + ): assert 19 == paragraph_data.word_count, f" 19 != {paragraph_data.word_count = }" assert 104 == len(paragraph_data.text), f" 104 != {len(paragraph_data.text) = }" assert "sentence_cut" == paragraph_data.cut_type, ( @@ -76,7 +78,9 @@ def test_UnstructuredDocument(): ) # Test DOCX - for paragraph_data in docx_document.read(chunk_size=1024, chunker="text_chunker"): + for paragraph_data in docx_document.read( + chunk_size=1024, chunker="text_chunker", max_chunk_tokens=1024 + ): assert 16 == paragraph_data.word_count, f" 16 != {paragraph_data.word_count = }" assert 145 == len(paragraph_data.text), f" 
145 != {len(paragraph_data.text) = }" assert "sentence_end" == paragraph_data.cut_type, ( @@ -84,7 +88,9 @@ def test_UnstructuredDocument(): ) # TEST CSV - for paragraph_data in csv_document.read(chunk_size=1024, chunker="text_chunker"): + for paragraph_data in csv_document.read( + chunk_size=1024, chunker="text_chunker", max_chunk_tokens=1024 + ): assert 15 == paragraph_data.word_count, f" 15 != {paragraph_data.word_count = }" assert "A A A A A A A A A,A A A A A A,A A" == paragraph_data.text, ( f"Read text doesn't match expected text: {paragraph_data.text}" @@ -94,7 +100,9 @@ def test_UnstructuredDocument(): ) # Test XLSX - for paragraph_data in xlsx_document.read(chunk_size=1024, chunker="text_chunker"): + for paragraph_data in xlsx_document.read( + chunk_size=1024, chunker="text_chunker", max_chunk_tokens=1024 + ): assert 36 == paragraph_data.word_count, f" 36 != {paragraph_data.word_count = }" assert 171 == len(paragraph_data.text), f" 171 != {len(paragraph_data.text) = }" assert "sentence_cut" == paragraph_data.cut_type, ( diff --git a/cognee/tests/unit/processing/chunks/chunk_by_paragraph_2_test.py b/cognee/tests/unit/processing/chunks/chunk_by_paragraph_2_test.py index d8680a604..5555a7dc9 100644 --- a/cognee/tests/unit/processing/chunks/chunk_by_paragraph_2_test.py +++ b/cognee/tests/unit/processing/chunks/chunk_by_paragraph_2_test.py @@ -8,14 +8,24 @@ from cognee.tests.unit.processing.chunks.test_input import INPUT_TEXTS paragraph_lengths = [64, 256, 1024] batch_paragraphs_vals = [True, False] +max_chunk_tokens_vals = [512, 1024, 4096] @pytest.mark.parametrize( - "input_text,paragraph_length,batch_paragraphs", - list(product(list(INPUT_TEXTS.values()), paragraph_lengths, batch_paragraphs_vals)), + "input_text,max_chunk_tokens,paragraph_length,batch_paragraphs", + list( + product( + list(INPUT_TEXTS.values()), + max_chunk_tokens_vals, + paragraph_lengths, + batch_paragraphs_vals, + ) + ), ) -def test_chunk_by_paragraph_isomorphism(input_text, paragraph_length, batch_paragraphs): - chunks = chunk_by_paragraph(input_text, paragraph_length, batch_paragraphs) +def test_chunk_by_paragraph_isomorphism( + input_text, max_chunk_tokens, paragraph_length, batch_paragraphs +): + chunks = chunk_by_paragraph(input_text, max_chunk_tokens, paragraph_length, batch_paragraphs) reconstructed_text = "".join([chunk["text"] for chunk in chunks]) assert reconstructed_text == input_text, ( f"texts are not identical: {len(input_text) = }, {len(reconstructed_text) = }" @@ -23,13 +33,23 @@ def test_chunk_by_paragraph_isomorphism(input_text, paragraph_length, batch_para @pytest.mark.parametrize( - "input_text,paragraph_length,batch_paragraphs", - list(product(list(INPUT_TEXTS.values()), paragraph_lengths, batch_paragraphs_vals)), + "input_text,max_chunk_tokens,paragraph_length,batch_paragraphs", + list( + product( + list(INPUT_TEXTS.values()), + max_chunk_tokens_vals, + paragraph_lengths, + batch_paragraphs_vals, + ) + ), ) -def test_paragraph_chunk_length(input_text, paragraph_length, batch_paragraphs): +def test_paragraph_chunk_length(input_text, max_chunk_tokens, paragraph_length, batch_paragraphs): chunks = list( chunk_by_paragraph( - data=input_text, paragraph_length=paragraph_length, batch_paragraphs=batch_paragraphs + data=input_text, + max_chunk_tokens=max_chunk_tokens, + paragraph_length=paragraph_length, + batch_paragraphs=batch_paragraphs, ) ) @@ -42,12 +62,24 @@ def test_paragraph_chunk_length(input_text, paragraph_length, batch_paragraphs): @pytest.mark.parametrize( - 
"input_text,paragraph_length,batch_paragraphs", - list(product(list(INPUT_TEXTS.values()), paragraph_lengths, batch_paragraphs_vals)), + "input_text,max_chunk_tokens,paragraph_length,batch_paragraphs", + list( + product( + list(INPUT_TEXTS.values()), + max_chunk_tokens_vals, + paragraph_lengths, + batch_paragraphs_vals, + ) + ), ) -def test_chunk_by_paragraph_chunk_numbering(input_text, paragraph_length, batch_paragraphs): +def test_chunk_by_paragraph_chunk_numbering( + input_text, max_chunk_tokens, paragraph_length, batch_paragraphs +): chunks = chunk_by_paragraph( - data=input_text, paragraph_length=paragraph_length, batch_paragraphs=batch_paragraphs + data=input_text, + max_chunk_tokens=max_chunk_tokens, + paragraph_length=paragraph_length, + batch_paragraphs=batch_paragraphs, ) chunk_indices = np.array([chunk["chunk_index"] for chunk in chunks]) assert np.all(chunk_indices == np.arange(len(chunk_indices))), ( diff --git a/cognee/tests/unit/processing/chunks/chunk_by_paragraph_test.py b/cognee/tests/unit/processing/chunks/chunk_by_paragraph_test.py index e420b2e9f..ed706830e 100644 --- a/cognee/tests/unit/processing/chunks/chunk_by_paragraph_test.py +++ b/cognee/tests/unit/processing/chunks/chunk_by_paragraph_test.py @@ -50,7 +50,7 @@ Third paragraph is cut and is missing the dot at the end""", def run_chunking_test(test_text, expected_chunks): chunks = [] for chunk_data in chunk_by_paragraph( - data=test_text, paragraph_length=12, batch_paragraphs=False + data=test_text, paragraph_length=12, batch_paragraphs=False, max_chunk_tokens=512 ): chunks.append(chunk_data) diff --git a/cognee/tests/unit/processing/utils/utils_test.py b/cognee/tests/unit/processing/utils/utils_test.py index f8c325100..067ab6ea7 100644 --- a/cognee/tests/unit/processing/utils/utils_test.py +++ b/cognee/tests/unit/processing/utils/utils_test.py @@ -11,9 +11,7 @@ from cognee.shared.exceptions import IngestionError from cognee.shared.utils import ( get_anonymous_id, send_telemetry, - num_tokens_from_string, get_file_content_hash, - trim_text_to_max_tokens, prepare_edges, prepare_nodes, create_cognee_style_network_with_logo, @@ -45,15 +43,6 @@ def test_get_anonymous_id(mock_open_file, mock_makedirs, temp_dir): # args, kwargs = mock_post.call_args # assert kwargs["json"]["event_name"] == "test_event" -# -# @patch("tiktoken.encoding_for_model") -# def test_num_tokens_from_string(mock_encoding): -# mock_encoding.return_value.encode = lambda x: list(x) -# -# assert num_tokens_from_string("hello", "test_encoding") == 5 -# assert num_tokens_from_string("world", "test_encoding") == 5 -# - @patch("builtins.open", new_callable=mock_open, read_data=b"test_data") def test_get_file_content_hash_file(mock_open_file): @@ -73,18 +62,6 @@ def test_get_file_content_hash_stream(): assert result == expected_hash -# def test_trim_text_to_max_tokens(): -# text = "This is a test string with multiple words." 
-# encoding_name = "test_encoding" -# -# with patch("tiktoken.get_encoding") as mock_get_encoding: -# mock_get_encoding.return_value.encode = lambda x: list(x) -# mock_get_encoding.return_value.decode = lambda x: "".join(x) -# -# result = trim_text_to_max_tokens(text, 5, encoding_name) -# assert result == text[:5] - - def test_prepare_edges(): graph = nx.MultiDiGraph() graph.add_edge("A", "B", key="AB", weight=1) diff --git a/notebooks/cognee_demo.ipynb b/notebooks/cognee_demo.ipynb index 6540d19c5..a90c02958 100644 --- a/notebooks/cognee_demo.ipynb +++ b/notebooks/cognee_demo.ipynb @@ -650,6 +650,7 @@ "from cognee.modules.pipelines import run_tasks\n", "from cognee.modules.users.models import User\n", "from cognee.tasks.documents import check_permissions_on_documents, classify_documents, extract_chunks_from_documents\n", + "from cognee.infrastructure.llm import get_max_chunk_tokens\n", "from cognee.tasks.graph import extract_graph_from_data\n", "from cognee.tasks.storage import add_data_points\n", "from cognee.tasks.summarization import summarize_text\n", @@ -663,7 +664,7 @@ " tasks = [\n", " Task(classify_documents),\n", " Task(check_permissions_on_documents, user = user, permissions = [\"write\"]),\n", - " Task(extract_chunks_from_documents), # Extract text chunks based on the document type.\n", + " Task(extract_chunks_from_documents, max_chunk_tokens=get_max_chunk_tokens()), # Extract text chunks based on the document type.\n", " Task(extract_graph_from_data, graph_model = KnowledgeGraph, task_config = { \"batch_size\": 10 }), # Generate knowledge graphs from the document chunks.\n", " Task(\n", " summarize_text,\n", diff --git a/poetry.lock b/poetry.lock index 45d3f8210..542e0377b 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.8.4 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.8.5 and should not be changed by hand. 
[[package]] name = "aiofiles" @@ -645,6 +645,17 @@ urllib3 = {version = ">=1.25.4,<2.2.0 || >2.2.0,<3", markers = "python_version > [package.extras] crt = ["awscrt (==0.23.4)"] +[[package]] +name = "cachetools" +version = "5.5.1" +description = "Extensible memoizing collections and decorators" +optional = true +python-versions = ">=3.7" +files = [ + {file = "cachetools-5.5.1-py3-none-any.whl", hash = "sha256:b76651fdc3b24ead3c648bbdeeb940c1b04d365b38b4af66788f9ec4a81d42bb"}, + {file = "cachetools-5.5.1.tar.gz", hash = "sha256:70f238fbba50383ef62e55c6aff6d9673175fe59f7c6782c7a0b9e38f4a9df95"}, +] + [[package]] name = "certifi" version = "2024.12.14" @@ -1995,6 +2006,135 @@ files = [ {file = "giturlparse-0.12.0.tar.gz", hash = "sha256:c0fff7c21acc435491b1779566e038757a205c1ffdcb47e4f81ea52ad8c3859a"}, ] +[[package]] +name = "google-ai-generativelanguage" +version = "0.6.15" +description = "Google Ai Generativelanguage API client library" +optional = true +python-versions = ">=3.7" +files = [ + {file = "google_ai_generativelanguage-0.6.15-py3-none-any.whl", hash = "sha256:5a03ef86377aa184ffef3662ca28f19eeee158733e45d7947982eb953c6ebb6c"}, + {file = "google_ai_generativelanguage-0.6.15.tar.gz", hash = "sha256:8f6d9dc4c12b065fe2d0289026171acea5183ebf2d0b11cefe12f3821e159ec3"}, +] + +[package.dependencies] +google-api-core = {version = ">=1.34.1,<2.0.dev0 || >=2.11.dev0,<3.0.0dev", extras = ["grpc"]} +google-auth = ">=2.14.1,<2.24.0 || >2.24.0,<2.25.0 || >2.25.0,<3.0.0dev" +proto-plus = ">=1.22.3,<2.0.0dev" +protobuf = ">=3.20.2,<4.21.0 || >4.21.0,<4.21.1 || >4.21.1,<4.21.2 || >4.21.2,<4.21.3 || >4.21.3,<4.21.4 || >4.21.4,<4.21.5 || >4.21.5,<6.0.0dev" + +[[package]] +name = "google-api-core" +version = "2.24.0" +description = "Google API client core library" +optional = true +python-versions = ">=3.7" +files = [ + {file = "google_api_core-2.24.0-py3-none-any.whl", hash = "sha256:10d82ac0fca69c82a25b3efdeefccf6f28e02ebb97925a8cce8edbfe379929d9"}, + {file = "google_api_core-2.24.0.tar.gz", hash = "sha256:e255640547a597a4da010876d333208ddac417d60add22b6851a0c66a831fcaf"}, +] + +[package.dependencies] +google-auth = ">=2.14.1,<3.0.dev0" +googleapis-common-protos = ">=1.56.2,<2.0.dev0" +grpcio = [ + {version = ">=1.33.2,<2.0dev", optional = true, markers = "python_version < \"3.11\" and extra == \"grpc\""}, + {version = ">=1.49.1,<2.0dev", optional = true, markers = "python_version >= \"3.11\" and extra == \"grpc\""}, +] +grpcio-status = [ + {version = ">=1.33.2,<2.0.dev0", optional = true, markers = "python_version < \"3.11\" and extra == \"grpc\""}, + {version = ">=1.49.1,<2.0.dev0", optional = true, markers = "python_version >= \"3.11\" and extra == \"grpc\""}, +] +proto-plus = ">=1.22.3,<2.0.0dev" +protobuf = ">=3.19.5,<3.20.0 || >3.20.0,<3.20.1 || >3.20.1,<4.21.0 || >4.21.0,<4.21.1 || >4.21.1,<4.21.2 || >4.21.2,<4.21.3 || >4.21.3,<4.21.4 || >4.21.4,<4.21.5 || >4.21.5,<6.0.0.dev0" +requests = ">=2.18.0,<3.0.0.dev0" + +[package.extras] +async-rest = ["google-auth[aiohttp] (>=2.35.0,<3.0.dev0)"] +grpc = ["grpcio (>=1.33.2,<2.0dev)", "grpcio (>=1.49.1,<2.0dev)", "grpcio-status (>=1.33.2,<2.0.dev0)", "grpcio-status (>=1.49.1,<2.0.dev0)"] +grpcgcp = ["grpcio-gcp (>=0.2.2,<1.0.dev0)"] +grpcio-gcp = ["grpcio-gcp (>=0.2.2,<1.0.dev0)"] + +[[package]] +name = "google-api-python-client" +version = "2.159.0" +description = "Google API Client Library for Python" +optional = true +python-versions = ">=3.7" +files = [ + {file = "google_api_python_client-2.159.0-py2.py3-none-any.whl", hash = 
"sha256:baef0bb631a60a0bd7c0bf12a5499e3a40cd4388484de7ee55c1950bf820a0cf"}, + {file = "google_api_python_client-2.159.0.tar.gz", hash = "sha256:55197f430f25c907394b44fa078545ffef89d33fd4dca501b7db9f0d8e224bd6"}, +] + +[package.dependencies] +google-api-core = ">=1.31.5,<2.0.dev0 || >2.3.0,<3.0.0.dev0" +google-auth = ">=1.32.0,<2.24.0 || >2.24.0,<2.25.0 || >2.25.0,<3.0.0.dev0" +google-auth-httplib2 = ">=0.2.0,<1.0.0" +httplib2 = ">=0.19.0,<1.dev0" +uritemplate = ">=3.0.1,<5" + +[[package]] +name = "google-auth" +version = "2.38.0" +description = "Google Authentication Library" +optional = true +python-versions = ">=3.7" +files = [ + {file = "google_auth-2.38.0-py2.py3-none-any.whl", hash = "sha256:e7dae6694313f434a2727bf2906f27ad259bae090d7aa896590d86feec3d9d4a"}, + {file = "google_auth-2.38.0.tar.gz", hash = "sha256:8285113607d3b80a3f1543b75962447ba8a09fe85783432a784fdeef6ac094c4"}, +] + +[package.dependencies] +cachetools = ">=2.0.0,<6.0" +pyasn1-modules = ">=0.2.1" +rsa = ">=3.1.4,<5" + +[package.extras] +aiohttp = ["aiohttp (>=3.6.2,<4.0.0.dev0)", "requests (>=2.20.0,<3.0.0.dev0)"] +enterprise-cert = ["cryptography", "pyopenssl"] +pyjwt = ["cryptography (>=38.0.3)", "pyjwt (>=2.0)"] +pyopenssl = ["cryptography (>=38.0.3)", "pyopenssl (>=20.0.0)"] +reauth = ["pyu2f (>=0.1.5)"] +requests = ["requests (>=2.20.0,<3.0.0.dev0)"] + +[[package]] +name = "google-auth-httplib2" +version = "0.2.0" +description = "Google Authentication Library: httplib2 transport" +optional = true +python-versions = "*" +files = [ + {file = "google-auth-httplib2-0.2.0.tar.gz", hash = "sha256:38aa7badf48f974f1eb9861794e9c0cb2a0511a4ec0679b1f886d108f5640e05"}, + {file = "google_auth_httplib2-0.2.0-py2.py3-none-any.whl", hash = "sha256:b65a0a2123300dd71281a7bf6e64d65a0759287df52729bdd1ae2e47dc311a3d"}, +] + +[package.dependencies] +google-auth = "*" +httplib2 = ">=0.19.0" + +[[package]] +name = "google-generativeai" +version = "0.8.4" +description = "Google Generative AI High level API client library and tools." +optional = true +python-versions = ">=3.9" +files = [ + {file = "google_generativeai-0.8.4-py3-none-any.whl", hash = "sha256:e987b33ea6decde1e69191ddcaec6ef974458864d243de7191db50c21a7c5b82"}, +] + +[package.dependencies] +google-ai-generativelanguage = "0.6.15" +google-api-core = "*" +google-api-python-client = "*" +google-auth = ">=2.15.0" +protobuf = "*" +pydantic = "*" +tqdm = "*" +typing-extensions = "*" + +[package.extras] +dev = ["Pillow", "absl-py", "black", "ipython", "nose2", "pandas", "pytype", "pyyaml"] + [[package]] name = "googleapis-common-protos" version = "1.66.0" @@ -2251,6 +2391,22 @@ files = [ grpcio = ">=1.67.1" protobuf = ">=5.26.1,<6.0dev" +[[package]] +name = "grpcio-status" +version = "1.67.1" +description = "Status proto mapping for gRPC" +optional = true +python-versions = ">=3.8" +files = [ + {file = "grpcio_status-1.67.1-py3-none-any.whl", hash = "sha256:16e6c085950bdacac97c779e6a502ea671232385e6e37f258884d6883392c2bd"}, + {file = "grpcio_status-1.67.1.tar.gz", hash = "sha256:2bf38395e028ceeecfd8866b081f61628114b384da7d51ae064ddc8d766a5d11"}, +] + +[package.dependencies] +googleapis-common-protos = ">=1.5.5" +grpcio = ">=1.67.1" +protobuf = ">=5.26.1,<6.0dev" + [[package]] name = "grpcio-tools" version = "1.67.1" @@ -2445,6 +2601,20 @@ http2 = ["h2 (>=3,<5)"] socks = ["socksio (==1.*)"] trio = ["trio (>=0.22.0,<1.0)"] +[[package]] +name = "httplib2" +version = "0.22.0" +description = "A comprehensive HTTP client library." 
+optional = true +python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" +files = [ + {file = "httplib2-0.22.0-py3-none-any.whl", hash = "sha256:14ae0a53c1ba8f3d37e9e27cf37eabb0fb9980f435ba405d546948b009dd64dc"}, + {file = "httplib2-0.22.0.tar.gz", hash = "sha256:d7a10bc5ef5ab08322488bde8c726eeee5c8618723fdb399597ec58f3d82df81"}, +] + +[package.dependencies] +pyparsing = {version = ">=2.4.2,<3.0.0 || >3.0.0,<3.0.1 || >3.0.1,<3.0.2 || >3.0.2,<3.0.3 || >3.0.3,<4", markers = "python_version > \"3.0\""} + [[package]] name = "httpx" version = "0.27.0" @@ -4998,8 +5168,8 @@ files = [ [package.dependencies] numpy = [ {version = ">=1.22.4", markers = "python_version < \"3.11\""}, - {version = ">=1.23.2", markers = "python_version == \"3.11\""}, {version = ">=1.26.0", markers = "python_version >= \"3.12\""}, + {version = ">=1.23.2", markers = "python_version == \"3.11\""}, ] python-dateutil = ">=2.8.2" pytz = ">=2020.1" @@ -5520,6 +5690,23 @@ files = [ {file = "propcache-0.2.1.tar.gz", hash = "sha256:3f77ce728b19cb537714499928fe800c3dda29e8d9428778fc7c186da4c09a64"}, ] +[[package]] +name = "proto-plus" +version = "1.25.0" +description = "Beautiful, Pythonic protocol buffers." +optional = true +python-versions = ">=3.7" +files = [ + {file = "proto_plus-1.25.0-py3-none-any.whl", hash = "sha256:c91fc4a65074ade8e458e95ef8bac34d4008daa7cce4a12d6707066fca648961"}, + {file = "proto_plus-1.25.0.tar.gz", hash = "sha256:fbb17f57f7bd05a68b7707e745e26528b0b3c34e378db91eef93912c54982d91"}, +] + +[package.dependencies] +protobuf = ">=3.19.0,<6.0.0dev" + +[package.extras] +testing = ["google-api-core (>=1.31.5)"] + [[package]] name = "protobuf" version = "5.29.3" @@ -5686,6 +5873,31 @@ files = [ [package.extras] test = ["cffi", "hypothesis", "pandas", "pytest", "pytz"] +[[package]] +name = "pyasn1" +version = "0.6.1" +description = "Pure-Python implementation of ASN.1 types and DER/BER/CER codecs (X.208)" +optional = true +python-versions = ">=3.8" +files = [ + {file = "pyasn1-0.6.1-py3-none-any.whl", hash = "sha256:0d632f46f2ba09143da3a8afe9e33fb6f92fa2320ab7e886e2d0f7672af84629"}, + {file = "pyasn1-0.6.1.tar.gz", hash = "sha256:6f580d2bdd84365380830acf45550f2511469f673cb4a5ae3857a3170128b034"}, +] + +[[package]] +name = "pyasn1-modules" +version = "0.4.1" +description = "A collection of ASN.1-based protocols modules" +optional = true +python-versions = ">=3.8" +files = [ + {file = "pyasn1_modules-0.4.1-py3-none-any.whl", hash = "sha256:49bfa96b45a292b711e986f222502c1c9a5e1f4e568fc30e2574a6c7d07838fd"}, + {file = "pyasn1_modules-0.4.1.tar.gz", hash = "sha256:c28e2dbf9c06ad61c71a075c7e0f9fd0f1b0bb2d2ad4377f240d33ac2ab60a7c"}, +] + +[package.dependencies] +pyasn1 = ">=0.4.6,<0.7.0" + [[package]] name = "pycparser" version = "2.22" @@ -5927,8 +6139,8 @@ astroid = ">=3.3.8,<=3.4.0-dev0" colorama = {version = ">=0.4.5", markers = "sys_platform == \"win32\""} dill = [ {version = ">=0.2", markers = "python_version < \"3.11\""}, - {version = ">=0.3.6", markers = "python_version >= \"3.11\" and python_version < \"3.12\""}, {version = ">=0.3.7", markers = "python_version >= \"3.12\""}, + {version = ">=0.3.6", markers = "python_version >= \"3.11\" and python_version < \"3.12\""}, ] isort = ">=4.2.5,<5.13.0 || >5.13.0,<6" mccabe = ">=0.6,<0.8" @@ -6967,6 +7179,20 @@ files = [ {file = "rpds_py-0.22.3.tar.gz", hash = "sha256:e32fee8ab45d3c2db6da19a5323bc3362237c8b653c70194414b892fd06a080d"}, ] +[[package]] +name = "rsa" +version = "4.9" +description = "Pure-Python RSA implementation" +optional = true +python-versions = 
">=3.6,<4" +files = [ + {file = "rsa-4.9-py3-none-any.whl", hash = "sha256:90260d9058e514786967344d0ef75fa8727eed8a7d2e43ce9f4bcf1b536174f7"}, + {file = "rsa-4.9.tar.gz", hash = "sha256:e38464a49c6c85d7f1351b0126661487a7e0a14a50f1675ec50eb34d4f20ef21"}, +] + +[package.dependencies] +pyasn1 = ">=0.1.3" + [[package]] name = "ruff" version = "0.9.3" @@ -8218,6 +8444,17 @@ files = [ [package.extras] dev = ["flake8", "flake8-annotations", "flake8-bandit", "flake8-bugbear", "flake8-commas", "flake8-comprehensions", "flake8-continuation", "flake8-datetimez", "flake8-docstrings", "flake8-import-order", "flake8-literal", "flake8-modern-annotations", "flake8-noqa", "flake8-pyproject", "flake8-requirements", "flake8-typechecking-import", "flake8-use-fstring", "mypy", "pep8-naming", "types-PyYAML"] +[[package]] +name = "uritemplate" +version = "4.1.1" +description = "Implementation of RFC 6570 URI Templates" +optional = true +python-versions = ">=3.6" +files = [ + {file = "uritemplate-4.1.1-py2.py3-none-any.whl", hash = "sha256:830c08b8d99bdd312ea4ead05994a38e8936266f84b9a7878232db50b044e02e"}, + {file = "uritemplate-4.1.1.tar.gz", hash = "sha256:4346edfc5c3b79f694bccd6d6099a322bbeb628dbf2cd86eea55a456ce5124f0"}, +] + [[package]] name = "urllib3" version = "2.3.0" @@ -8801,6 +9038,7 @@ deepeval = ["deepeval"] docs = ["unstructured"] falkordb = ["falkordb"] filesystem = ["botocore"] +gemini = ["google-generativeai"] groq = ["groq"] langchain = ["langchain_text_splitters", "langsmith"] llama-index = ["llama-index-core"] @@ -8815,4 +9053,4 @@ weaviate = ["weaviate-client"] [metadata] lock-version = "2.0" python-versions = ">=3.10.0,<3.13" -content-hash = "1cc352109264d0e3add524cdc15c9b2e6153e1bab20d968b40e42a4d5138967f" +content-hash = "480675c274cd85a76a95bf03af865b1a0b462f25bbc21d7427b0a0b8e21c13db" diff --git a/pyproject.toml b/pyproject.toml index 034b4889d..98497cded 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -77,6 +77,7 @@ pre-commit = "^4.0.1" httpx = "0.27.0" bokeh="^3.6.2" nltk = "3.9.1" +google-generativeai = {version = "^0.8.4", optional = true} parso = {version = "^0.8.4", optional = true} jedi = {version = "^0.19.2", optional = true} @@ -90,6 +91,7 @@ postgres = ["psycopg2", "pgvector", "asyncpg"] notebook = ["notebook", "ipykernel", "overrides", "ipywidgets", "jupyterlab", "jupyterlab_widgets", "jupyterlab-server", "jupyterlab-git"] langchain = ["langsmith", "langchain_text_splitters"] llama-index = ["llama-index-core"] +gemini = ["google-generativeai"] deepeval = ["deepeval"] posthog = ["posthog"] falkordb = ["falkordb"]