From 93249c72c514c84face1aa9faf57f94db7605487 Mon Sep 17 00:00:00 2001 From: Igor Ilic Date: Tue, 21 Jan 2025 19:53:22 +0100 Subject: [PATCH 01/16] fix: Initial commit to resolve issue with using tokenizer based on LLMs Currently TikToken is used for tokenizing by default which is only supported by OpenAI, this is an initial commit in an attempt to add Cognee tokenizing support for multiple LLMs --- cognee/api/v1/cognify/code_graph_pipeline.py | 2 +- .../embeddings/LiteLLMEmbeddingEngine.py | 23 +++++++ .../databases/vector/embeddings/config.py | 2 +- .../vector/embeddings/get_embedding_engine.py | 1 + .../llm/tokenizer/HuggingFace/__init__.py | 1 + .../llm/tokenizer/HuggingFace/adapter.py | 22 ++++++ .../llm/tokenizer/TikToken/__init__.py | 1 + .../llm/tokenizer/TikToken/adapter.py | 69 +++++++++++++++++++ .../infrastructure/llm/tokenizer/__init__.py | 1 + .../llm/tokenizer/tokenizer_interface.py | 18 +++++ cognee/modules/chunking/TextChunker.py | 16 +++-- cognee/modules/cognify/config.py | 1 - .../document_types/AudioDocument.py | 6 +- .../processing/document_types/Document.py | 2 +- .../document_types/ImageDocument.py | 6 +- .../processing/document_types/PdfDocument.py | 6 +- .../processing/document_types/TextDocument.py | 6 +- .../document_types/UnstructuredDocument.py | 4 +- cognee/shared/utils.py | 39 ----------- cognee/tasks/chunks/chunk_by_paragraph.py | 16 ++--- .../extract_chunks_from_documents.py | 5 +- .../repo_processor/get_source_code_chunks.py | 13 ++-- 22 files changed, 176 insertions(+), 84 deletions(-) create mode 100644 cognee/infrastructure/llm/tokenizer/HuggingFace/__init__.py create mode 100644 cognee/infrastructure/llm/tokenizer/HuggingFace/adapter.py create mode 100644 cognee/infrastructure/llm/tokenizer/TikToken/__init__.py create mode 100644 cognee/infrastructure/llm/tokenizer/TikToken/adapter.py create mode 100644 cognee/infrastructure/llm/tokenizer/__init__.py create mode 100644 cognee/infrastructure/llm/tokenizer/tokenizer_interface.py diff --git a/cognee/api/v1/cognify/code_graph_pipeline.py b/cognee/api/v1/cognify/code_graph_pipeline.py index dc2af0cd5..4a864eb0e 100644 --- a/cognee/api/v1/cognify/code_graph_pipeline.py +++ b/cognee/api/v1/cognify/code_graph_pipeline.py @@ -71,7 +71,7 @@ async def run_code_graph_pipeline(repo_path, include_docs=True): Task(ingest_data, dataset_name="repo_docs", user=user), Task(get_data_list_for_user, dataset_name="repo_docs", user=user), Task(classify_documents), - Task(extract_chunks_from_documents, max_tokens=cognee_config.max_tokens), + Task(extract_chunks_from_documents), Task( extract_graph_from_data, graph_model=KnowledgeGraph, task_config={"batch_size": 50} ), diff --git a/cognee/infrastructure/databases/vector/embeddings/LiteLLMEmbeddingEngine.py b/cognee/infrastructure/databases/vector/embeddings/LiteLLMEmbeddingEngine.py index f0a40ca36..842256659 100644 --- a/cognee/infrastructure/databases/vector/embeddings/LiteLLMEmbeddingEngine.py +++ b/cognee/infrastructure/databases/vector/embeddings/LiteLLMEmbeddingEngine.py @@ -6,6 +6,9 @@ import litellm import os from cognee.infrastructure.databases.vector.embeddings.EmbeddingEngine import EmbeddingEngine from cognee.infrastructure.databases.exceptions.EmbeddingException import EmbeddingException +from cognee.infrastructure.llm.tokenizer.TikToken import TikTokenTokenizer +from transformers import AutoTokenizer +import tiktoken # Assuming this is how you import TikToken litellm.set_verbose = False logger = logging.getLogger("LiteLLMEmbeddingEngine") @@ -15,23 +18,30 @@ 
class LiteLLMEmbeddingEngine(EmbeddingEngine): api_key: str endpoint: str api_version: str + provider: str model: str dimensions: int mock: bool def __init__( self, + provider: str = "openai", model: Optional[str] = "text-embedding-3-large", dimensions: Optional[int] = 3072, api_key: str = None, endpoint: str = None, api_version: str = None, + max_tokens: int = float("inf"), ): self.api_key = api_key self.endpoint = endpoint self.api_version = api_version + # TODO: Add or remove provider info + self.provider = provider self.model = model self.dimensions = dimensions + self.max_tokens = max_tokens + self.tokenizer = self.set_tokenizer() enable_mocking = os.getenv("MOCK_EMBEDDING", "false") if isinstance(enable_mocking, bool): @@ -104,3 +114,16 @@ class LiteLLMEmbeddingEngine(EmbeddingEngine): def get_vector_size(self) -> int: return self.dimensions + + def set_tokenizer(self): + logger.debug(f"Loading tokenizer for model {self.model}...") + # If model also contains provider information, extract only model information + model = self.model.split("/")[-1] + + if "openai" in self.provider.lower() or "gpt" in self.model: + tokenizer = TikTokenTokenizer(model=model, max_tokens=self.max_tokens) + else: + tokenizer = AutoTokenizer.from_pretrained(self.model) + + logger.debug(f"Tokenizer loaded for model: {self.model}") + return tokenizer diff --git a/cognee/infrastructure/databases/vector/embeddings/config.py b/cognee/infrastructure/databases/vector/embeddings/config.py index 042c063f8..62335ea41 100644 --- a/cognee/infrastructure/databases/vector/embeddings/config.py +++ b/cognee/infrastructure/databases/vector/embeddings/config.py @@ -9,7 +9,7 @@ class EmbeddingConfig(BaseSettings): embedding_endpoint: Optional[str] = None embedding_api_key: Optional[str] = None embedding_api_version: Optional[str] = None - + embedding_max_tokens: Optional[int] = float("inf") model_config = SettingsConfigDict(env_file=".env", extra="allow") diff --git a/cognee/infrastructure/databases/vector/embeddings/get_embedding_engine.py b/cognee/infrastructure/databases/vector/embeddings/get_embedding_engine.py index 6bfb4dd15..e894da892 100644 --- a/cognee/infrastructure/databases/vector/embeddings/get_embedding_engine.py +++ b/cognee/infrastructure/databases/vector/embeddings/get_embedding_engine.py @@ -15,4 +15,5 @@ def get_embedding_engine() -> EmbeddingEngine: api_version=config.embedding_api_version, model=config.embedding_model, dimensions=config.embedding_dimensions, + max_tokens=config.embedding_max_tokens, ) diff --git a/cognee/infrastructure/llm/tokenizer/HuggingFace/__init__.py b/cognee/infrastructure/llm/tokenizer/HuggingFace/__init__.py new file mode 100644 index 000000000..7cdfb9aa0 --- /dev/null +++ b/cognee/infrastructure/llm/tokenizer/HuggingFace/__init__.py @@ -0,0 +1 @@ +from .adapter import HuggingFaceTokenizer diff --git a/cognee/infrastructure/llm/tokenizer/HuggingFace/adapter.py b/cognee/infrastructure/llm/tokenizer/HuggingFace/adapter.py new file mode 100644 index 000000000..19238b62e --- /dev/null +++ b/cognee/infrastructure/llm/tokenizer/HuggingFace/adapter.py @@ -0,0 +1,22 @@ +from typing import List, Any + +from ..tokenizer_interface import TokenizerInterface + + +class HuggingFaceTokenizer(TokenizerInterface): + def __init__( + self, + model: str, + max_tokens: int = float("inf"), + ): + self.model = model + self.max_tokens = max_tokens + + def extract_tokens(self, text: str) -> List[Any]: + raise NotImplementedError + + def num_tokens_from_text(self, text: str) -> int: + raise 
NotImplementedError + + def trim_text_to_max_tokens(self, text: str) -> str: + raise NotImplementedError diff --git a/cognee/infrastructure/llm/tokenizer/TikToken/__init__.py b/cognee/infrastructure/llm/tokenizer/TikToken/__init__.py new file mode 100644 index 000000000..4c7a39401 --- /dev/null +++ b/cognee/infrastructure/llm/tokenizer/TikToken/__init__.py @@ -0,0 +1 @@ +from .adapter import TikTokenTokenizer diff --git a/cognee/infrastructure/llm/tokenizer/TikToken/adapter.py b/cognee/infrastructure/llm/tokenizer/TikToken/adapter.py new file mode 100644 index 000000000..6ba1e0027 --- /dev/null +++ b/cognee/infrastructure/llm/tokenizer/TikToken/adapter.py @@ -0,0 +1,69 @@ +from typing import List, Any +import tiktoken + +from ..tokenizer_interface import TokenizerInterface + + +class TikTokenTokenizer(TokenizerInterface): + """ + Tokenizer adapter for OpenAI. + Intended to be used as part of LLM Embedding and LLM Adapter classes + """ + + def __init__( + self, + model: str, + max_tokens: int = float("inf"), + ): + self.model = model + self.max_tokens = max_tokens + # Initialize TikToken for GPT based on model + self.tokenizer = tiktoken.encoding_for_model(self.model) + + def extract_tokens(self, text: str) -> List[Any]: + tokens = [] + # Using TikToken's method to tokenize text + token_ids = self.tokenizer.encode(text) + # Go through tokens and decode them to text value + for token_id in token_ids: + token = self.tokenizer.decode([token_id]) + tokens.append(token) + return tokens + + def num_tokens_from_text(self, text: str) -> int: + """ + Returns the number of tokens in the given text. + Args: + text: str + + Returns: + number of tokens in the given text + + """ + num_tokens = len(self.tokenizer.encode(text)) + return num_tokens + + def trim_text_to_max_tokens(self, text: str) -> str: + """ + Trims the text so that the number of tokens does not exceed max_tokens. + + Args: + text (str): Original text string to be trimmed. + + Returns: + str: Trimmed version of text or original text if under the limit. 
+ """ + # First check the number of tokens + num_tokens = self.num_tokens_from_string(text) + + # If the number of tokens is within the limit, return the text as is + if num_tokens <= self.max_tokens: + return text + + # If the number exceeds the limit, trim the text + # This is a simple trim, it may cut words in half; consider using word boundaries for a cleaner cut + encoded_text = self.tokenizer.encode(text) + trimmed_encoded_text = encoded_text[: self.max_tokens] + # Decoding the trimmed text + trimmed_text = self.tokenizer.decode(trimmed_encoded_text) + return trimmed_text diff --git a/cognee/infrastructure/llm/tokenizer/__init__.py b/cognee/infrastructure/llm/tokenizer/__init__.py new file mode 100644 index 000000000..0cc895c13 --- /dev/null +++ b/cognee/infrastructure/llm/tokenizer/__init__.py @@ -0,0 +1 @@ +from .tokenizer_interface import TokenizerInterface diff --git a/cognee/infrastructure/llm/tokenizer/tokenizer_interface.py b/cognee/infrastructure/llm/tokenizer/tokenizer_interface.py new file mode 100644 index 000000000..abd111f12 --- /dev/null +++ b/cognee/infrastructure/llm/tokenizer/tokenizer_interface.py @@ -0,0 +1,18 @@ +from typing import List, Protocol, Any +from abc import abstractmethod + + +class TokenizerInterface(Protocol): + """Tokenizer interface""" + + @abstractmethod + def extract_tokens(self, text: str) -> List[Any]: + raise NotImplementedError + + @abstractmethod + def num_tokens_from_text(self, text: str) -> int: + raise NotImplementedError + + @abstractmethod + def trim_text_to_max_tokens(self, text: str) -> str: + raise NotImplementedError diff --git a/cognee/modules/chunking/TextChunker.py b/cognee/modules/chunking/TextChunker.py index cd71dd128..4fa246272 100644 --- a/cognee/modules/chunking/TextChunker.py +++ b/cognee/modules/chunking/TextChunker.py @@ -14,17 +14,22 @@ class TextChunker: chunk_size = 0 token_count = 0 - def __init__( - self, document, get_text: callable, max_tokens: Optional[int] = None, chunk_size: int = 1024 - ): + def __init__(self, document, get_text: callable, chunk_size: int = 1024): self.document = document self.max_chunk_size = chunk_size self.get_text = get_text - self.max_tokens = max_tokens if max_tokens else float("inf") def check_word_count_and_token_count(self, word_count_before, token_count_before, chunk_data): word_count_fits = word_count_before + chunk_data["word_count"] <= self.max_chunk_size - token_count_fits = token_count_before + chunk_data["token_count"] <= self.max_tokens + + # Get embedding engine related to vector database + from cognee.infrastructure.databases.vector.get_vector_engine import get_vector_engine + + embedding_engine = get_vector_engine().embedding_engine + + token_count_fits = ( + token_count_before + chunk_data["token_count"] <= embedding_engine.max_tokens + ) return word_count_fits and token_count_fits def read(self): @@ -32,7 +37,6 @@ class TextChunker: for content_text in self.get_text(): for chunk_data in chunk_by_paragraph( content_text, - self.max_tokens, self.max_chunk_size, batch_paragraphs=True, ): diff --git a/cognee/modules/cognify/config.py b/cognee/modules/cognify/config.py index dd94d8b41..4ba0f4bd6 100644 --- a/cognee/modules/cognify/config.py +++ b/cognee/modules/cognify/config.py @@ -8,7 +8,6 @@ import os class CognifyConfig(BaseSettings): classification_model: object = DefaultContentPrediction summarization_model: object = SummarizedContent - max_tokens: Optional[int] = os.getenv("MAX_TOKENS") model_config = SettingsConfigDict(env_file=".env", extra="allow") def to_dict(self) 
-> dict: diff --git a/cognee/modules/data/processing/document_types/AudioDocument.py b/cognee/modules/data/processing/document_types/AudioDocument.py index b7d2476b4..77c700482 100644 --- a/cognee/modules/data/processing/document_types/AudioDocument.py +++ b/cognee/modules/data/processing/document_types/AudioDocument.py @@ -13,14 +13,12 @@ class AudioDocument(Document): result = get_llm_client().create_transcript(self.raw_data_location) return result.text - def read(self, chunk_size: int, chunker: str, max_tokens: Optional[int] = None): + def read(self, chunk_size: int, chunker: str): # Transcribe the audio file text = self.create_transcript() chunker_func = ChunkerConfig.get_chunker(chunker) - chunker = chunker_func( - self, chunk_size=chunk_size, get_text=lambda: [text], max_tokens=max_tokens - ) + chunker = chunker_func(self, chunk_size=chunk_size, get_text=lambda: [text]) yield from chunker.read() diff --git a/cognee/modules/data/processing/document_types/Document.py b/cognee/modules/data/processing/document_types/Document.py index e1bbb70ba..e3a56a0a5 100644 --- a/cognee/modules/data/processing/document_types/Document.py +++ b/cognee/modules/data/processing/document_types/Document.py @@ -11,5 +11,5 @@ class Document(DataPoint): mime_type: str _metadata: dict = {"index_fields": ["name"], "type": "Document"} - def read(self, chunk_size: int, chunker=str, max_tokens: Optional[int] = None) -> str: + def read(self, chunk_size: int, chunker=str) -> str: pass diff --git a/cognee/modules/data/processing/document_types/ImageDocument.py b/cognee/modules/data/processing/document_types/ImageDocument.py index c055b8253..3b5c27d75 100644 --- a/cognee/modules/data/processing/document_types/ImageDocument.py +++ b/cognee/modules/data/processing/document_types/ImageDocument.py @@ -13,13 +13,11 @@ class ImageDocument(Document): result = get_llm_client().transcribe_image(self.raw_data_location) return result.choices[0].message.content - def read(self, chunk_size: int, chunker: str, max_tokens: Optional[int] = None): + def read(self, chunk_size: int, chunker: str): # Transcribe the image file text = self.transcribe_image() chunker_func = ChunkerConfig.get_chunker(chunker) - chunker = chunker_func( - self, chunk_size=chunk_size, get_text=lambda: [text], max_tokens=max_tokens - ) + chunker = chunker_func(self, chunk_size=chunk_size, get_text=lambda: [text]) yield from chunker.read() diff --git a/cognee/modules/data/processing/document_types/PdfDocument.py b/cognee/modules/data/processing/document_types/PdfDocument.py index 768f91264..0142610bf 100644 --- a/cognee/modules/data/processing/document_types/PdfDocument.py +++ b/cognee/modules/data/processing/document_types/PdfDocument.py @@ -9,7 +9,7 @@ from .Document import Document class PdfDocument(Document): type: str = "pdf" - def read(self, chunk_size: int, chunker: str, max_tokens: Optional[int] = None): + def read(self, chunk_size: int, chunker: str): file = PdfReader(self.raw_data_location) def get_text(): @@ -18,9 +18,7 @@ class PdfDocument(Document): yield page_text chunker_func = ChunkerConfig.get_chunker(chunker) - chunker = chunker_func( - self, chunk_size=chunk_size, get_text=get_text, max_tokens=max_tokens - ) + chunker = chunker_func(self, chunk_size=chunk_size, get_text=get_text) yield from chunker.read() diff --git a/cognee/modules/data/processing/document_types/TextDocument.py b/cognee/modules/data/processing/document_types/TextDocument.py index b62ccd56e..692edcb89 100644 --- a/cognee/modules/data/processing/document_types/TextDocument.py +++ 
b/cognee/modules/data/processing/document_types/TextDocument.py @@ -7,7 +7,7 @@ from .Document import Document class TextDocument(Document): type: str = "text" - def read(self, chunk_size: int, chunker: str, max_tokens: Optional[int] = None): + def read(self, chunk_size: int, chunker: str): def get_text(): with open(self.raw_data_location, mode="r", encoding="utf-8") as file: while True: @@ -20,8 +20,6 @@ class TextDocument(Document): chunker_func = ChunkerConfig.get_chunker(chunker) - chunker = chunker_func( - self, chunk_size=chunk_size, get_text=get_text, max_tokens=max_tokens - ) + chunker = chunker_func(self, chunk_size=chunk_size, get_text=get_text) yield from chunker.read() diff --git a/cognee/modules/data/processing/document_types/UnstructuredDocument.py b/cognee/modules/data/processing/document_types/UnstructuredDocument.py index 1c291d0dc..b3849616f 100644 --- a/cognee/modules/data/processing/document_types/UnstructuredDocument.py +++ b/cognee/modules/data/processing/document_types/UnstructuredDocument.py @@ -10,7 +10,7 @@ from .Document import Document class UnstructuredDocument(Document): type: str = "unstructured" - def read(self, chunk_size: int, chunker: str, max_tokens: Optional[int] = None) -> str: + def read(self, chunk_size: int, chunker: str) -> str: def get_text(): try: from unstructured.partition.auto import partition @@ -29,6 +29,6 @@ class UnstructuredDocument(Document): yield text - chunker = TextChunker(self, chunk_size=chunk_size, get_text=get_text, max_tokens=max_tokens) + chunker = TextChunker(self, chunk_size=chunk_size, get_text=get_text) yield from chunker.read() diff --git a/cognee/shared/utils.py b/cognee/shared/utils.py index 6b1ca7f8f..4e5523fd2 100644 --- a/cognee/shared/utils.py +++ b/cognee/shared/utils.py @@ -10,8 +10,6 @@ import graphistry import networkx as nx import pandas as pd import matplotlib.pyplot as plt -import tiktoken -import time import logging import sys @@ -100,15 +98,6 @@ def send_telemetry(event_name: str, user_id, additional_properties: dict = {}): print(f"Error sending telemetry through proxy: {response.status_code}") -def num_tokens_from_string(string: str, encoding_name: str) -> int: - """Returns the number of tokens in a text string.""" - - # tiktoken.get_encoding("cl100k_base") - encoding = tiktoken.encoding_for_model(encoding_name) - num_tokens = len(encoding.encode(string)) - return num_tokens - - def get_file_content_hash(file_obj: Union[str, BinaryIO]) -> str: h = hashlib.md5() @@ -134,34 +123,6 @@ def get_file_content_hash(file_obj: Union[str, BinaryIO]) -> str: raise IngestionError(message=f"Failed to load data from {file}: {e}") -def trim_text_to_max_tokens(text: str, max_tokens: int, encoding_name: str) -> str: - """ - Trims the text so that the number of tokens does not exceed max_tokens. - - Args: - text (str): Original text string to be trimmed. - max_tokens (int): Maximum number of tokens allowed. - encoding_name (str): The name of the token encoding to use. - - Returns: - str: Trimmed version of text or original text if under the limit. 
- """ - # First check the number of tokens - num_tokens = num_tokens_from_string(text, encoding_name) - - # If the number of tokens is within the limit, return the text as is - if num_tokens <= max_tokens: - return text - - # If the number exceeds the limit, trim the text - # This is a simple trim, it may cut words in half; consider using word boundaries for a cleaner cut - encoded_text = tiktoken.get_encoding(encoding_name).encode(text) - trimmed_encoded_text = encoded_text[:max_tokens] - # Decoding the trimmed text - trimmed_text = tiktoken.get_encoding(encoding_name).decode(trimmed_encoded_text) - return trimmed_text - - def generate_color_palette(unique_layers): colormap = plt.cm.get_cmap("viridis", len(unique_layers)) colors = [colormap(i) for i in range(len(unique_layers))] diff --git a/cognee/tasks/chunks/chunk_by_paragraph.py b/cognee/tasks/chunks/chunk_by_paragraph.py index 24d566074..9cb88935e 100644 --- a/cognee/tasks/chunks/chunk_by_paragraph.py +++ b/cognee/tasks/chunks/chunk_by_paragraph.py @@ -4,13 +4,13 @@ from uuid import NAMESPACE_OID, uuid5 import tiktoken from cognee.infrastructure.databases.vector import get_vector_engine +from cognee.infrastructure.databases.vector.embeddings import get_embedding_engine from .chunk_by_sentence import chunk_by_sentence def chunk_by_paragraph( data: str, - max_tokens: Optional[Union[int, float]] = None, paragraph_length: int = 1024, batch_paragraphs: bool = True, ) -> Iterator[Dict[str, Any]]: @@ -24,24 +24,22 @@ def chunk_by_paragraph( paragraph_ids = [] last_cut_type = None current_token_count = 0 - if not max_tokens: - max_tokens = float("inf") + # Get vector and embedding engine vector_engine = get_vector_engine() - embedding_model = vector_engine.embedding_engine.model - embedding_model = embedding_model.split("/")[-1] + embedding_engine = vector_engine.embedding_engine + + # embedding_model = embedding_engine.model.split("/")[-1] for paragraph_id, sentence, word_count, end_type in chunk_by_sentence( data, maximum_length=paragraph_length ): # Check if this sentence would exceed length limit - - tokenizer = tiktoken.encoding_for_model(embedding_model) - token_count = len(tokenizer.encode(sentence)) + token_count = embedding_engine.tokenizer.num_tokens_from_text(sentence) if current_word_count > 0 and ( current_word_count + word_count > paragraph_length - or current_token_count + token_count > max_tokens + or current_token_count + token_count > embedding_engine.max_tokens ): # Yield current chunk chunk_dict = { diff --git a/cognee/tasks/documents/extract_chunks_from_documents.py b/cognee/tasks/documents/extract_chunks_from_documents.py index 5ce224002..6b239975c 100644 --- a/cognee/tasks/documents/extract_chunks_from_documents.py +++ b/cognee/tasks/documents/extract_chunks_from_documents.py @@ -7,10 +7,7 @@ async def extract_chunks_from_documents( documents: list[Document], chunk_size: int = 1024, chunker="text_chunker", - max_tokens: Optional[int] = None, ): for document in documents: - for document_chunk in document.read( - chunk_size=chunk_size, chunker=chunker, max_tokens=max_tokens - ): + for document_chunk in document.read(chunk_size=chunk_size, chunker=chunker): yield document_chunk diff --git a/cognee/tasks/repo_processor/get_source_code_chunks.py b/cognee/tasks/repo_processor/get_source_code_chunks.py index 82fa46cf0..ada71e596 100644 --- a/cognee/tasks/repo_processor/get_source_code_chunks.py +++ b/cognee/tasks/repo_processor/get_source_code_chunks.py @@ -89,26 +89,31 @@ def _get_subchunk_token_counts( def 
_get_chunk_source_code( - code_token_counts: list[tuple[str, int]], overlap: float, max_tokens: int + code_token_counts: list[tuple[str, int]], overlap: float ) -> tuple[list[tuple[str, int]], str]: """Generates a chunk of source code from tokenized subchunks with overlap handling.""" current_count = 0 cumulative_counts = [] current_source_code = "" + # Get embedding engine used in vector database + from cognee.infrastructure.databases.vector.get_vector_engine import get_vector_engine + + embedding_engine = get_vector_engine().embedding_engine + for i, (child_code, token_count) in enumerate(code_token_counts): current_count += token_count cumulative_counts.append(current_count) - if current_count > max_tokens: + if current_count > embedding_engine.max_tokens: break current_source_code += f"\n{child_code}" - if current_count <= max_tokens: + if current_count <= embedding_engine.max_tokens: return [], current_source_code.strip() cutoff = 1 for i, cum_count in enumerate(cumulative_counts): - if cum_count > (1 - overlap) * max_tokens: + if cum_count > (1 - overlap) * embedding_engine.max_tokens: break cutoff = i From 294ed1d960cd4285046b5bfe0168b31b67ce8812 Mon Sep 17 00:00:00 2001 From: Igor Ilic Date: Thu, 23 Jan 2025 16:52:35 +0100 Subject: [PATCH 02/16] feat: Add HuggingFace Tokenizer support --- .../embeddings/LiteLLMEmbeddingEngine.py | 3 ++- .../llm/tokenizer/HuggingFace/adapter.py | 18 +++++++++++++-- .../tests/unit/processing/utils/utils_test.py | 23 ------------------- 3 files changed, 18 insertions(+), 26 deletions(-) diff --git a/cognee/infrastructure/databases/vector/embeddings/LiteLLMEmbeddingEngine.py b/cognee/infrastructure/databases/vector/embeddings/LiteLLMEmbeddingEngine.py index 842256659..c037b45e0 100644 --- a/cognee/infrastructure/databases/vector/embeddings/LiteLLMEmbeddingEngine.py +++ b/cognee/infrastructure/databases/vector/embeddings/LiteLLMEmbeddingEngine.py @@ -6,6 +6,7 @@ import litellm import os from cognee.infrastructure.databases.vector.embeddings.EmbeddingEngine import EmbeddingEngine from cognee.infrastructure.databases.exceptions.EmbeddingException import EmbeddingException +from cognee.infrastructure.llm.tokenizer.HuggingFace import HuggingFaceTokenizer from cognee.infrastructure.llm.tokenizer.TikToken import TikTokenTokenizer from transformers import AutoTokenizer import tiktoken # Assuming this is how you import TikToken @@ -123,7 +124,7 @@ class LiteLLMEmbeddingEngine(EmbeddingEngine): if "openai" in self.provider.lower() or "gpt" in self.model: tokenizer = TikTokenTokenizer(model=model, max_tokens=self.max_tokens) else: - tokenizer = AutoTokenizer.from_pretrained(self.model) + tokenizer = HuggingFaceTokenizer(model=self.model, max_tokens=self.max_tokens) logger.debug(f"Tokenizer loaded for model: {self.model}") return tokenizer diff --git a/cognee/infrastructure/llm/tokenizer/HuggingFace/adapter.py b/cognee/infrastructure/llm/tokenizer/HuggingFace/adapter.py index 19238b62e..7b92fb76b 100644 --- a/cognee/infrastructure/llm/tokenizer/HuggingFace/adapter.py +++ b/cognee/infrastructure/llm/tokenizer/HuggingFace/adapter.py @@ -1,5 +1,7 @@ from typing import List, Any +from transformers import AutoTokenizer + from ..tokenizer_interface import TokenizerInterface @@ -12,11 +14,23 @@ class HuggingFaceTokenizer(TokenizerInterface): self.model = model self.max_tokens = max_tokens + self.tokenizer = AutoTokenizer.from_pretrained(model) + def extract_tokens(self, text: str) -> List[Any]: - raise NotImplementedError + tokens = self.tokenizer.tokenize(text) + return 
tokens def num_tokens_from_text(self, text: str) -> int: - raise NotImplementedError + """ + Returns the number of tokens in the given text. + Args: + text: str + + Returns: + number of tokens in the given text + + """ + return len(self.tokenizer.tokenize(text)) def trim_text_to_max_tokens(self, text: str) -> str: raise NotImplementedError diff --git a/cognee/tests/unit/processing/utils/utils_test.py b/cognee/tests/unit/processing/utils/utils_test.py index f8c325100..067ab6ea7 100644 --- a/cognee/tests/unit/processing/utils/utils_test.py +++ b/cognee/tests/unit/processing/utils/utils_test.py @@ -11,9 +11,7 @@ from cognee.shared.exceptions import IngestionError from cognee.shared.utils import ( get_anonymous_id, send_telemetry, - num_tokens_from_string, get_file_content_hash, - trim_text_to_max_tokens, prepare_edges, prepare_nodes, create_cognee_style_network_with_logo, @@ -45,15 +43,6 @@ def test_get_anonymous_id(mock_open_file, mock_makedirs, temp_dir): # args, kwargs = mock_post.call_args # assert kwargs["json"]["event_name"] == "test_event" -# -# @patch("tiktoken.encoding_for_model") -# def test_num_tokens_from_string(mock_encoding): -# mock_encoding.return_value.encode = lambda x: list(x) -# -# assert num_tokens_from_string("hello", "test_encoding") == 5 -# assert num_tokens_from_string("world", "test_encoding") == 5 -# - @patch("builtins.open", new_callable=mock_open, read_data=b"test_data") def test_get_file_content_hash_file(mock_open_file): @@ -73,18 +62,6 @@ def test_get_file_content_hash_stream(): assert result == expected_hash -# def test_trim_text_to_max_tokens(): -# text = "This is a test string with multiple words." -# encoding_name = "test_encoding" -# -# with patch("tiktoken.get_encoding") as mock_get_encoding: -# mock_get_encoding.return_value.encode = lambda x: list(x) -# mock_get_encoding.return_value.decode = lambda x: "".join(x) -# -# result = trim_text_to_max_tokens(text, 5, encoding_name) -# assert result == text[:5] - - def test_prepare_edges(): graph = nx.MultiDiGraph() graph.add_edge("A", "B", key="AB", weight=1) From b686376c5415ebeb7b9b89d86e66635e8e0b7d9f Mon Sep 17 00:00:00 2001 From: Igor Ilic Date: Thu, 23 Jan 2025 17:55:04 +0100 Subject: [PATCH 03/16] feat: Add gemini tokenizer to cognee --- .../embeddings/LiteLLMEmbeddingEngine.py | 5 ++- .../databases/vector/embeddings/config.py | 3 +- .../vector/embeddings/get_embedding_engine.py | 1 + .../llm/tokenizer/Gemini/__init__.py | 1 + .../llm/tokenizer/Gemini/adapter.py | 44 +++++++++++++++++++ 5 files changed, 52 insertions(+), 2 deletions(-) create mode 100644 cognee/infrastructure/llm/tokenizer/Gemini/__init__.py create mode 100644 cognee/infrastructure/llm/tokenizer/Gemini/adapter.py diff --git a/cognee/infrastructure/databases/vector/embeddings/LiteLLMEmbeddingEngine.py b/cognee/infrastructure/databases/vector/embeddings/LiteLLMEmbeddingEngine.py index c037b45e0..50dde8e89 100644 --- a/cognee/infrastructure/databases/vector/embeddings/LiteLLMEmbeddingEngine.py +++ b/cognee/infrastructure/databases/vector/embeddings/LiteLLMEmbeddingEngine.py @@ -6,6 +6,7 @@ import litellm import os from cognee.infrastructure.databases.vector.embeddings.EmbeddingEngine import EmbeddingEngine from cognee.infrastructure.databases.exceptions.EmbeddingException import EmbeddingException +from cognee.infrastructure.llm.tokenizer.Gemini import GeminiTokenizer from cognee.infrastructure.llm.tokenizer.HuggingFace import HuggingFaceTokenizer from cognee.infrastructure.llm.tokenizer.TikToken import TikTokenTokenizer from transformers 
import AutoTokenizer @@ -121,8 +122,10 @@ class LiteLLMEmbeddingEngine(EmbeddingEngine): # If model also contains provider information, extract only model information model = self.model.split("/")[-1] - if "openai" in self.provider.lower() or "gpt" in self.model: + if "openai" in self.provider.lower(): tokenizer = TikTokenTokenizer(model=model, max_tokens=self.max_tokens) + elif "gemini" in self.provider.lower(): + tokenizer = GeminiTokenizer(model=model, max_tokens=self.max_tokens) else: tokenizer = HuggingFaceTokenizer(model=self.model, max_tokens=self.max_tokens) diff --git a/cognee/infrastructure/databases/vector/embeddings/config.py b/cognee/infrastructure/databases/vector/embeddings/config.py index 62335ea41..cb72a46f4 100644 --- a/cognee/infrastructure/databases/vector/embeddings/config.py +++ b/cognee/infrastructure/databases/vector/embeddings/config.py @@ -4,7 +4,8 @@ from pydantic_settings import BaseSettings, SettingsConfigDict class EmbeddingConfig(BaseSettings): - embedding_model: Optional[str] = "text-embedding-3-large" + embedding_provider: Optional[str] = "openai" + embedding_model: Optional[str] = "openai/text-embedding-3-large" embedding_dimensions: Optional[int] = 3072 embedding_endpoint: Optional[str] = None embedding_api_key: Optional[str] = None diff --git a/cognee/infrastructure/databases/vector/embeddings/get_embedding_engine.py b/cognee/infrastructure/databases/vector/embeddings/get_embedding_engine.py index e894da892..d3011f059 100644 --- a/cognee/infrastructure/databases/vector/embeddings/get_embedding_engine.py +++ b/cognee/infrastructure/databases/vector/embeddings/get_embedding_engine.py @@ -10,6 +10,7 @@ def get_embedding_engine() -> EmbeddingEngine: return LiteLLMEmbeddingEngine( # If OpenAI API is used for embeddings, litellm needs only the api_key. + provider=config.embedding_provider, api_key=config.embedding_api_key or llm_config.llm_api_key, endpoint=config.embedding_endpoint, api_version=config.embedding_api_version, diff --git a/cognee/infrastructure/llm/tokenizer/Gemini/__init__.py b/cognee/infrastructure/llm/tokenizer/Gemini/__init__.py new file mode 100644 index 000000000..4ed4ad4d5 --- /dev/null +++ b/cognee/infrastructure/llm/tokenizer/Gemini/__init__.py @@ -0,0 +1 @@ +from .adapter import GeminiTokenizer diff --git a/cognee/infrastructure/llm/tokenizer/Gemini/adapter.py b/cognee/infrastructure/llm/tokenizer/Gemini/adapter.py new file mode 100644 index 000000000..697bc9577 --- /dev/null +++ b/cognee/infrastructure/llm/tokenizer/Gemini/adapter.py @@ -0,0 +1,44 @@ +from typing import List, Any + +from ..tokenizer_interface import TokenizerInterface + + +class GeminiTokenizer(TokenizerInterface): + def __init__( + self, + model: str, + max_tokens: int = float("inf"), + ): + self.model = model + self.max_tokens = max_tokens + + # Get LLM API key from config + from cognee.infrastructure.databases.vector.embeddings.config import get_embedding_config + from cognee.infrastructure.llm.config import get_llm_config + + config = get_embedding_config() + llm_config = get_llm_config() + + import google.generativeai as genai + + genai.configure(api_key=config.embedding_api_key or llm_config.llm_api_key) + + def extract_tokens(self, text: str) -> List[Any]: + raise NotImplementedError + + def num_tokens_from_text(self, text: str) -> int: + """ + Returns the number of tokens in the given text. 
+ Args: + text: str + + Returns: + number of tokens in the given text + + """ + import google.generativeai as genai + + return len(genai.embed_content(model=f"models/{self.model}", content=text)) + + def trim_text_to_max_tokens(self, text: str) -> str: + raise NotImplementedError From b25a82e206bbdce98ca7a1a8dbec14fd6a620c0e Mon Sep 17 00:00:00 2001 From: Igor Ilic Date: Thu, 23 Jan 2025 17:56:56 +0100 Subject: [PATCH 04/16] chore: Add google-generativeai as gemini optional dependency to Cognee --- poetry.lock | 240 ++++++++++++++++++++++++++++++++++++++++++++++++- pyproject.toml | 5 +- 2 files changed, 241 insertions(+), 4 deletions(-) diff --git a/poetry.lock b/poetry.lock index 48a7b9d10..64ee8f073 100644 --- a/poetry.lock +++ b/poetry.lock @@ -645,6 +645,17 @@ urllib3 = {version = ">=1.25.4,<2.2.0 || >2.2.0,<3", markers = "python_version > [package.extras] crt = ["awscrt (==0.23.4)"] +[[package]] +name = "cachetools" +version = "5.5.1" +description = "Extensible memoizing collections and decorators" +optional = true +python-versions = ">=3.7" +files = [ + {file = "cachetools-5.5.1-py3-none-any.whl", hash = "sha256:b76651fdc3b24ead3c648bbdeeb940c1b04d365b38b4af66788f9ec4a81d42bb"}, + {file = "cachetools-5.5.1.tar.gz", hash = "sha256:70f238fbba50383ef62e55c6aff6d9673175fe59f7c6782c7a0b9e38f4a9df95"}, +] + [[package]] name = "certifi" version = "2024.12.14" @@ -1995,6 +2006,135 @@ files = [ {file = "giturlparse-0.12.0.tar.gz", hash = "sha256:c0fff7c21acc435491b1779566e038757a205c1ffdcb47e4f81ea52ad8c3859a"}, ] +[[package]] +name = "google-ai-generativelanguage" +version = "0.6.15" +description = "Google Ai Generativelanguage API client library" +optional = true +python-versions = ">=3.7" +files = [ + {file = "google_ai_generativelanguage-0.6.15-py3-none-any.whl", hash = "sha256:5a03ef86377aa184ffef3662ca28f19eeee158733e45d7947982eb953c6ebb6c"}, + {file = "google_ai_generativelanguage-0.6.15.tar.gz", hash = "sha256:8f6d9dc4c12b065fe2d0289026171acea5183ebf2d0b11cefe12f3821e159ec3"}, +] + +[package.dependencies] +google-api-core = {version = ">=1.34.1,<2.0.dev0 || >=2.11.dev0,<3.0.0dev", extras = ["grpc"]} +google-auth = ">=2.14.1,<2.24.0 || >2.24.0,<2.25.0 || >2.25.0,<3.0.0dev" +proto-plus = ">=1.22.3,<2.0.0dev" +protobuf = ">=3.20.2,<4.21.0 || >4.21.0,<4.21.1 || >4.21.1,<4.21.2 || >4.21.2,<4.21.3 || >4.21.3,<4.21.4 || >4.21.4,<4.21.5 || >4.21.5,<6.0.0dev" + +[[package]] +name = "google-api-core" +version = "2.24.0" +description = "Google API client core library" +optional = true +python-versions = ">=3.7" +files = [ + {file = "google_api_core-2.24.0-py3-none-any.whl", hash = "sha256:10d82ac0fca69c82a25b3efdeefccf6f28e02ebb97925a8cce8edbfe379929d9"}, + {file = "google_api_core-2.24.0.tar.gz", hash = "sha256:e255640547a597a4da010876d333208ddac417d60add22b6851a0c66a831fcaf"}, +] + +[package.dependencies] +google-auth = ">=2.14.1,<3.0.dev0" +googleapis-common-protos = ">=1.56.2,<2.0.dev0" +grpcio = [ + {version = ">=1.33.2,<2.0dev", optional = true, markers = "python_version < \"3.11\" and extra == \"grpc\""}, + {version = ">=1.49.1,<2.0dev", optional = true, markers = "python_version >= \"3.11\" and extra == \"grpc\""}, +] +grpcio-status = [ + {version = ">=1.33.2,<2.0.dev0", optional = true, markers = "python_version < \"3.11\" and extra == \"grpc\""}, + {version = ">=1.49.1,<2.0.dev0", optional = true, markers = "python_version >= \"3.11\" and extra == \"grpc\""}, +] +proto-plus = ">=1.22.3,<2.0.0dev" +protobuf = ">=3.19.5,<3.20.0 || >3.20.0,<3.20.1 || >3.20.1,<4.21.0 || >4.21.0,<4.21.1 || 
>4.21.1,<4.21.2 || >4.21.2,<4.21.3 || >4.21.3,<4.21.4 || >4.21.4,<4.21.5 || >4.21.5,<6.0.0.dev0" +requests = ">=2.18.0,<3.0.0.dev0" + +[package.extras] +async-rest = ["google-auth[aiohttp] (>=2.35.0,<3.0.dev0)"] +grpc = ["grpcio (>=1.33.2,<2.0dev)", "grpcio (>=1.49.1,<2.0dev)", "grpcio-status (>=1.33.2,<2.0.dev0)", "grpcio-status (>=1.49.1,<2.0.dev0)"] +grpcgcp = ["grpcio-gcp (>=0.2.2,<1.0.dev0)"] +grpcio-gcp = ["grpcio-gcp (>=0.2.2,<1.0.dev0)"] + +[[package]] +name = "google-api-python-client" +version = "2.159.0" +description = "Google API Client Library for Python" +optional = true +python-versions = ">=3.7" +files = [ + {file = "google_api_python_client-2.159.0-py2.py3-none-any.whl", hash = "sha256:baef0bb631a60a0bd7c0bf12a5499e3a40cd4388484de7ee55c1950bf820a0cf"}, + {file = "google_api_python_client-2.159.0.tar.gz", hash = "sha256:55197f430f25c907394b44fa078545ffef89d33fd4dca501b7db9f0d8e224bd6"}, +] + +[package.dependencies] +google-api-core = ">=1.31.5,<2.0.dev0 || >2.3.0,<3.0.0.dev0" +google-auth = ">=1.32.0,<2.24.0 || >2.24.0,<2.25.0 || >2.25.0,<3.0.0.dev0" +google-auth-httplib2 = ">=0.2.0,<1.0.0" +httplib2 = ">=0.19.0,<1.dev0" +uritemplate = ">=3.0.1,<5" + +[[package]] +name = "google-auth" +version = "2.38.0" +description = "Google Authentication Library" +optional = true +python-versions = ">=3.7" +files = [ + {file = "google_auth-2.38.0-py2.py3-none-any.whl", hash = "sha256:e7dae6694313f434a2727bf2906f27ad259bae090d7aa896590d86feec3d9d4a"}, + {file = "google_auth-2.38.0.tar.gz", hash = "sha256:8285113607d3b80a3f1543b75962447ba8a09fe85783432a784fdeef6ac094c4"}, +] + +[package.dependencies] +cachetools = ">=2.0.0,<6.0" +pyasn1-modules = ">=0.2.1" +rsa = ">=3.1.4,<5" + +[package.extras] +aiohttp = ["aiohttp (>=3.6.2,<4.0.0.dev0)", "requests (>=2.20.0,<3.0.0.dev0)"] +enterprise-cert = ["cryptography", "pyopenssl"] +pyjwt = ["cryptography (>=38.0.3)", "pyjwt (>=2.0)"] +pyopenssl = ["cryptography (>=38.0.3)", "pyopenssl (>=20.0.0)"] +reauth = ["pyu2f (>=0.1.5)"] +requests = ["requests (>=2.20.0,<3.0.0.dev0)"] + +[[package]] +name = "google-auth-httplib2" +version = "0.2.0" +description = "Google Authentication Library: httplib2 transport" +optional = true +python-versions = "*" +files = [ + {file = "google-auth-httplib2-0.2.0.tar.gz", hash = "sha256:38aa7badf48f974f1eb9861794e9c0cb2a0511a4ec0679b1f886d108f5640e05"}, + {file = "google_auth_httplib2-0.2.0-py2.py3-none-any.whl", hash = "sha256:b65a0a2123300dd71281a7bf6e64d65a0759287df52729bdd1ae2e47dc311a3d"}, +] + +[package.dependencies] +google-auth = "*" +httplib2 = ">=0.19.0" + +[[package]] +name = "google-generativeai" +version = "0.8.4" +description = "Google Generative AI High level API client library and tools." 
+optional = true +python-versions = ">=3.9" +files = [ + {file = "google_generativeai-0.8.4-py3-none-any.whl", hash = "sha256:e987b33ea6decde1e69191ddcaec6ef974458864d243de7191db50c21a7c5b82"}, +] + +[package.dependencies] +google-ai-generativelanguage = "0.6.15" +google-api-core = "*" +google-api-python-client = "*" +google-auth = ">=2.15.0" +protobuf = "*" +pydantic = "*" +tqdm = "*" +typing-extensions = "*" + +[package.extras] +dev = ["Pillow", "absl-py", "black", "ipython", "nose2", "pandas", "pytype", "pyyaml"] + [[package]] name = "googleapis-common-protos" version = "1.66.0" @@ -2251,6 +2391,22 @@ files = [ grpcio = ">=1.67.1" protobuf = ">=5.26.1,<6.0dev" +[[package]] +name = "grpcio-status" +version = "1.67.1" +description = "Status proto mapping for gRPC" +optional = true +python-versions = ">=3.8" +files = [ + {file = "grpcio_status-1.67.1-py3-none-any.whl", hash = "sha256:16e6c085950bdacac97c779e6a502ea671232385e6e37f258884d6883392c2bd"}, + {file = "grpcio_status-1.67.1.tar.gz", hash = "sha256:2bf38395e028ceeecfd8866b081f61628114b384da7d51ae064ddc8d766a5d11"}, +] + +[package.dependencies] +googleapis-common-protos = ">=1.5.5" +grpcio = ">=1.67.1" +protobuf = ">=5.26.1,<6.0dev" + [[package]] name = "grpcio-tools" version = "1.67.1" @@ -2445,6 +2601,20 @@ http2 = ["h2 (>=3,<5)"] socks = ["socksio (==1.*)"] trio = ["trio (>=0.22.0,<1.0)"] +[[package]] +name = "httplib2" +version = "0.22.0" +description = "A comprehensive HTTP client library." +optional = true +python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" +files = [ + {file = "httplib2-0.22.0-py3-none-any.whl", hash = "sha256:14ae0a53c1ba8f3d37e9e27cf37eabb0fb9980f435ba405d546948b009dd64dc"}, + {file = "httplib2-0.22.0.tar.gz", hash = "sha256:d7a10bc5ef5ab08322488bde8c726eeee5c8618723fdb399597ec58f3d82df81"}, +] + +[package.dependencies] +pyparsing = {version = ">=2.4.2,<3.0.0 || >3.0.0,<3.0.1 || >3.0.1,<3.0.2 || >3.0.2,<3.0.3 || >3.0.3,<4", markers = "python_version > \"3.0\""} + [[package]] name = "httpx" version = "0.27.0" @@ -5516,6 +5686,23 @@ files = [ {file = "propcache-0.2.1.tar.gz", hash = "sha256:3f77ce728b19cb537714499928fe800c3dda29e8d9428778fc7c186da4c09a64"}, ] +[[package]] +name = "proto-plus" +version = "1.25.0" +description = "Beautiful, Pythonic protocol buffers." 
+optional = true +python-versions = ">=3.7" +files = [ + {file = "proto_plus-1.25.0-py3-none-any.whl", hash = "sha256:c91fc4a65074ade8e458e95ef8bac34d4008daa7cce4a12d6707066fca648961"}, + {file = "proto_plus-1.25.0.tar.gz", hash = "sha256:fbb17f57f7bd05a68b7707e745e26528b0b3c34e378db91eef93912c54982d91"}, +] + +[package.dependencies] +protobuf = ">=3.19.0,<6.0.0dev" + +[package.extras] +testing = ["google-api-core (>=1.31.5)"] + [[package]] name = "protobuf" version = "5.29.3" @@ -5682,6 +5869,31 @@ files = [ [package.extras] test = ["cffi", "hypothesis", "pandas", "pytest", "pytz"] +[[package]] +name = "pyasn1" +version = "0.6.1" +description = "Pure-Python implementation of ASN.1 types and DER/BER/CER codecs (X.208)" +optional = true +python-versions = ">=3.8" +files = [ + {file = "pyasn1-0.6.1-py3-none-any.whl", hash = "sha256:0d632f46f2ba09143da3a8afe9e33fb6f92fa2320ab7e886e2d0f7672af84629"}, + {file = "pyasn1-0.6.1.tar.gz", hash = "sha256:6f580d2bdd84365380830acf45550f2511469f673cb4a5ae3857a3170128b034"}, +] + +[[package]] +name = "pyasn1-modules" +version = "0.4.1" +description = "A collection of ASN.1-based protocols modules" +optional = true +python-versions = ">=3.8" +files = [ + {file = "pyasn1_modules-0.4.1-py3-none-any.whl", hash = "sha256:49bfa96b45a292b711e986f222502c1c9a5e1f4e568fc30e2574a6c7d07838fd"}, + {file = "pyasn1_modules-0.4.1.tar.gz", hash = "sha256:c28e2dbf9c06ad61c71a075c7e0f9fd0f1b0bb2d2ad4377f240d33ac2ab60a7c"}, +] + +[package.dependencies] +pyasn1 = ">=0.4.6,<0.7.0" + [[package]] name = "pycparser" version = "2.22" @@ -6951,6 +7163,20 @@ files = [ {file = "rpds_py-0.22.3.tar.gz", hash = "sha256:e32fee8ab45d3c2db6da19a5323bc3362237c8b653c70194414b892fd06a080d"}, ] +[[package]] +name = "rsa" +version = "4.9" +description = "Pure-Python RSA implementation" +optional = true +python-versions = ">=3.6,<4" +files = [ + {file = "rsa-4.9-py3-none-any.whl", hash = "sha256:90260d9058e514786967344d0ef75fa8727eed8a7d2e43ce9f4bcf1b536174f7"}, + {file = "rsa-4.9.tar.gz", hash = "sha256:e38464a49c6c85d7f1351b0126661487a7e0a14a50f1675ec50eb34d4f20ef21"}, +] + +[package.dependencies] +pyasn1 = ">=0.1.3" + [[package]] name = "ruff" version = "0.9.2" @@ -8202,6 +8428,17 @@ files = [ [package.extras] dev = ["flake8", "flake8-annotations", "flake8-bandit", "flake8-bugbear", "flake8-commas", "flake8-comprehensions", "flake8-continuation", "flake8-datetimez", "flake8-docstrings", "flake8-import-order", "flake8-literal", "flake8-modern-annotations", "flake8-noqa", "flake8-pyproject", "flake8-requirements", "flake8-typechecking-import", "flake8-use-fstring", "mypy", "pep8-naming", "types-PyYAML"] +[[package]] +name = "uritemplate" +version = "4.1.1" +description = "Implementation of RFC 6570 URI Templates" +optional = true +python-versions = ">=3.6" +files = [ + {file = "uritemplate-4.1.1-py2.py3-none-any.whl", hash = "sha256:830c08b8d99bdd312ea4ead05994a38e8936266f84b9a7878232db50b044e02e"}, + {file = "uritemplate-4.1.1.tar.gz", hash = "sha256:4346edfc5c3b79f694bccd6d6099a322bbeb628dbf2cd86eea55a456ce5124f0"}, +] + [[package]] name = "urllib3" version = "2.3.0" @@ -8783,6 +9020,7 @@ deepeval = ["deepeval"] docs = ["unstructured"] falkordb = ["falkordb"] filesystem = ["botocore"] +gemini = ["google-generativeai"] groq = ["groq"] langchain = ["langchain_text_splitters", "langsmith"] llama-index = ["llama-index-core"] @@ -8797,4 +9035,4 @@ weaviate = ["weaviate-client"] [metadata] lock-version = "2.0" python-versions = ">=3.10.0,<3.13" -content-hash = 
"097955773827cdf96b42e54328f66b79e2b92e5a7f221a06afe1a71fea2c33bc" +content-hash = "b74880407c173a0b15631d3f2197e2f66cc72fbe1d80f93b20140c8779e2bcc2" diff --git a/pyproject.toml b/pyproject.toml index ea0b64404..15189ff1b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -77,9 +77,7 @@ pre-commit = "^4.0.1" httpx = "0.27.0" bokeh="^3.6.2" nltk = "3.9.1" - - - +google-generativeai = {version = "^0.8.4", optional = true} [tool.poetry.extras] @@ -91,6 +89,7 @@ postgres = ["psycopg2", "pgvector", "asyncpg"] notebook = ["notebook", "ipykernel", "overrides", "ipywidgets", "jupyterlab", "jupyterlab_widgets", "jupyterlab-server", "jupyterlab-git"] langchain = ["langsmith", "langchain_text_splitters"] llama-index = ["llama-index-core"] +gemini = ["google-generativeai"] deepeval = ["deepeval"] posthog = ["posthog"] falkordb = ["falkordb"] From 1319944dcd6e4a325289b9474d976dab49df25fb Mon Sep 17 00:00:00 2001 From: Igor Ilic Date: Thu, 23 Jan 2025 18:05:45 +0100 Subject: [PATCH 05/16] docs: Update .env.template to include llm and embedding options --- .env.template | 17 ++++++++++++++++- .../vector/embeddings/LiteLLMEmbeddingEngine.py | 2 -- .../databases/vector/embeddings/config.py | 2 +- 3 files changed, 17 insertions(+), 4 deletions(-) diff --git a/.env.template b/.env.template index 75a57de4d..6ce9fedf9 100644 --- a/.env.template +++ b/.env.template @@ -1,12 +1,27 @@ ENV="local" TOKENIZERS_PARALLELISM="false" -LLM_API_KEY= + +# LLM settings +LLM_API_KEY="" +LLM_MODEL="openai/gpt-4o-mini" +LLM_PROVIDER="openai" +LLM_ENDPOINT="" +LLM_API_VERSION="" GRAPHISTRY_USERNAME= GRAPHISTRY_PASSWORD= SENTRY_REPORTING_URL= +# Embedding settings +EMBEDDING_PROVIDER="openai" +EMBEDDING_API_KEY="" +EMBEDDING_MODEL="openai/text-embedding-3-large" +EMBEDDING_ENDPOINT="" +EMBEDDING_API_VERSION="" +EMBEDDING_DIMENSIONS=3072 +EMBEDDING_MAX_TOKENS=8191 + # "neo4j" or "networkx" GRAPH_DATABASE_PROVIDER="networkx" # Not needed if using networkx diff --git a/cognee/infrastructure/databases/vector/embeddings/LiteLLMEmbeddingEngine.py b/cognee/infrastructure/databases/vector/embeddings/LiteLLMEmbeddingEngine.py index 50dde8e89..cb84337c2 100644 --- a/cognee/infrastructure/databases/vector/embeddings/LiteLLMEmbeddingEngine.py +++ b/cognee/infrastructure/databases/vector/embeddings/LiteLLMEmbeddingEngine.py @@ -9,8 +9,6 @@ from cognee.infrastructure.databases.exceptions.EmbeddingException import Embedd from cognee.infrastructure.llm.tokenizer.Gemini import GeminiTokenizer from cognee.infrastructure.llm.tokenizer.HuggingFace import HuggingFaceTokenizer from cognee.infrastructure.llm.tokenizer.TikToken import TikTokenTokenizer -from transformers import AutoTokenizer -import tiktoken # Assuming this is how you import TikToken litellm.set_verbose = False logger = logging.getLogger("LiteLLMEmbeddingEngine") diff --git a/cognee/infrastructure/databases/vector/embeddings/config.py b/cognee/infrastructure/databases/vector/embeddings/config.py index cb72a46f4..315caf7ef 100644 --- a/cognee/infrastructure/databases/vector/embeddings/config.py +++ b/cognee/infrastructure/databases/vector/embeddings/config.py @@ -10,7 +10,7 @@ class EmbeddingConfig(BaseSettings): embedding_endpoint: Optional[str] = None embedding_api_key: Optional[str] = None embedding_api_version: Optional[str] = None - embedding_max_tokens: Optional[int] = float("inf") + embedding_max_tokens: Optional[int] = 8191 model_config = SettingsConfigDict(env_file=".env", extra="allow") From 7dea1d54d7f0f97012c4e4be4a29dd661fa80d19 Mon Sep 17 00:00:00 2001 From: Igor Ilic 
Date: Thu, 23 Jan 2025 18:18:45 +0100 Subject: [PATCH 06/16] refactor: Add specific max token values to embedding models --- .../databases/vector/embeddings/LiteLLMEmbeddingEngine.py | 3 +-- cognee/infrastructure/llm/tokenizer/Gemini/adapter.py | 2 +- cognee/infrastructure/llm/tokenizer/HuggingFace/adapter.py | 2 +- cognee/infrastructure/llm/tokenizer/TikToken/adapter.py | 2 +- 4 files changed, 4 insertions(+), 5 deletions(-) diff --git a/cognee/infrastructure/databases/vector/embeddings/LiteLLMEmbeddingEngine.py b/cognee/infrastructure/databases/vector/embeddings/LiteLLMEmbeddingEngine.py index cb84337c2..10992b22c 100644 --- a/cognee/infrastructure/databases/vector/embeddings/LiteLLMEmbeddingEngine.py +++ b/cognee/infrastructure/databases/vector/embeddings/LiteLLMEmbeddingEngine.py @@ -31,12 +31,11 @@ class LiteLLMEmbeddingEngine(EmbeddingEngine): api_key: str = None, endpoint: str = None, api_version: str = None, - max_tokens: int = float("inf"), + max_tokens: int = 512, ): self.api_key = api_key self.endpoint = endpoint self.api_version = api_version - # TODO: Add or remove provider info self.provider = provider self.model = model self.dimensions = dimensions diff --git a/cognee/infrastructure/llm/tokenizer/Gemini/adapter.py b/cognee/infrastructure/llm/tokenizer/Gemini/adapter.py index 697bc9577..f3131ea08 100644 --- a/cognee/infrastructure/llm/tokenizer/Gemini/adapter.py +++ b/cognee/infrastructure/llm/tokenizer/Gemini/adapter.py @@ -7,7 +7,7 @@ class GeminiTokenizer(TokenizerInterface): def __init__( self, model: str, - max_tokens: int = float("inf"), + max_tokens: int = 3072, ): self.model = model self.max_tokens = max_tokens diff --git a/cognee/infrastructure/llm/tokenizer/HuggingFace/adapter.py b/cognee/infrastructure/llm/tokenizer/HuggingFace/adapter.py index 7b92fb76b..a8eac29d9 100644 --- a/cognee/infrastructure/llm/tokenizer/HuggingFace/adapter.py +++ b/cognee/infrastructure/llm/tokenizer/HuggingFace/adapter.py @@ -9,7 +9,7 @@ class HuggingFaceTokenizer(TokenizerInterface): def __init__( self, model: str, - max_tokens: int = float("inf"), + max_tokens: int = 512, ): self.model = model self.max_tokens = max_tokens diff --git a/cognee/infrastructure/llm/tokenizer/TikToken/adapter.py b/cognee/infrastructure/llm/tokenizer/TikToken/adapter.py index 6ba1e0027..862a79296 100644 --- a/cognee/infrastructure/llm/tokenizer/TikToken/adapter.py +++ b/cognee/infrastructure/llm/tokenizer/TikToken/adapter.py @@ -13,7 +13,7 @@ class TikTokenTokenizer(TokenizerInterface): def __init__( self, model: str, - max_tokens: int = float("inf"), + max_tokens: int = 8191, ): self.model = model self.max_tokens = max_tokens From 844d99cb720eaeba81d29c66586c55a4cdd4d93d Mon Sep 17 00:00:00 2001 From: Igor Ilic Date: Thu, 23 Jan 2025 18:24:26 +0100 Subject: [PATCH 07/16] docs: Remove commented code --- cognee/tasks/chunks/chunk_by_paragraph.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/cognee/tasks/chunks/chunk_by_paragraph.py b/cognee/tasks/chunks/chunk_by_paragraph.py index 6504fd434..077db1cd4 100644 --- a/cognee/tasks/chunks/chunk_by_paragraph.py +++ b/cognee/tasks/chunks/chunk_by_paragraph.py @@ -35,8 +35,6 @@ def chunk_by_paragraph( vector_engine = get_vector_engine() embedding_engine = vector_engine.embedding_engine - # embedding_model = embedding_engine.model.split("/")[-1] - for paragraph_id, sentence, word_count, end_type in chunk_by_sentence( data, maximum_length=paragraph_length ): From 902979c1de76c86b58d64fc5d1c05dd497b0a96f Mon Sep 17 00:00:00 2001 From: Igor Ilic Date: Fri, 24 Jan 
2025 13:40:10 +0100 Subject: [PATCH 08/16] refactor: Refactor get source code chunks based on tokenizer rework --- .../tasks/repo_processor/get_source_code_chunks.py | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/cognee/tasks/repo_processor/get_source_code_chunks.py b/cognee/tasks/repo_processor/get_source_code_chunks.py index ada71e596..358d5218d 100644 --- a/cognee/tasks/repo_processor/get_source_code_chunks.py +++ b/cognee/tasks/repo_processor/get_source_code_chunks.py @@ -122,21 +122,18 @@ def _get_chunk_source_code( def get_source_code_chunks_from_code_part( code_file_part: CodePart, - max_tokens: int = 8192, overlap: float = 0.25, granularity: float = 0.1, - model_name: str = "text-embedding-3-large", ) -> Generator[SourceCodeChunk, None, None]: """Yields source code chunks from a CodePart object, with configurable token limits and overlap.""" if not code_file_part.source_code: logger.error(f"No source code in CodeFile {code_file_part.id}") return - vector_engine = get_vector_engine() - embedding_model = vector_engine.embedding_engine.model - model_name = embedding_model.split("/")[-1] - tokenizer = tiktoken.encoding_for_model(model_name) - max_subchunk_tokens = max(1, int(granularity * max_tokens)) + embedding_engine = get_vector_engine().embedding_engine + tokenizer = embedding_engine.tokenizer + + max_subchunk_tokens = max(1, int(granularity * embedding_engine.max_tokens)) subchunk_token_counts = _get_subchunk_token_counts( tokenizer, code_file_part.source_code, max_subchunk_tokens ) @@ -144,7 +141,7 @@ def get_source_code_chunks_from_code_part( previous_chunk = None while subchunk_token_counts: subchunk_token_counts, chunk_source_code = _get_chunk_source_code( - subchunk_token_counts, overlap, max_tokens + subchunk_token_counts, overlap ) if not chunk_source_code: continue From 0a9f1349f21b624d351be35bf2ac2b8bcf6ad076 Mon Sep 17 00:00:00 2001 From: Igor Ilic Date: Tue, 28 Jan 2025 10:10:29 +0100 Subject: [PATCH 09/16] refactor: Change variable and function names based on PR comments Change variable and function names based on PR comments --- .../databases/vector/embeddings/LiteLLMEmbeddingEngine.py | 4 ++-- cognee/infrastructure/llm/tokenizer/Gemini/adapter.py | 2 +- cognee/infrastructure/llm/tokenizer/HuggingFace/adapter.py | 2 +- cognee/infrastructure/llm/tokenizer/TikToken/adapter.py | 4 ++-- cognee/infrastructure/llm/tokenizer/tokenizer_interface.py | 2 +- cognee/tasks/chunks/chunk_by_paragraph.py | 2 +- 6 files changed, 8 insertions(+), 8 deletions(-) diff --git a/cognee/infrastructure/databases/vector/embeddings/LiteLLMEmbeddingEngine.py b/cognee/infrastructure/databases/vector/embeddings/LiteLLMEmbeddingEngine.py index 10992b22c..f81bc8515 100644 --- a/cognee/infrastructure/databases/vector/embeddings/LiteLLMEmbeddingEngine.py +++ b/cognee/infrastructure/databases/vector/embeddings/LiteLLMEmbeddingEngine.py @@ -40,7 +40,7 @@ class LiteLLMEmbeddingEngine(EmbeddingEngine): self.model = model self.dimensions = dimensions self.max_tokens = max_tokens - self.tokenizer = self.set_tokenizer() + self.tokenizer = self.get_tokenizer() enable_mocking = os.getenv("MOCK_EMBEDDING", "false") if isinstance(enable_mocking, bool): @@ -114,7 +114,7 @@ class LiteLLMEmbeddingEngine(EmbeddingEngine): def get_vector_size(self) -> int: return self.dimensions - def set_tokenizer(self): + def get_tokenizer(self): logger.debug(f"Loading tokenizer for model {self.model}...") # If model also contains provider information, extract only model information model = 
self.model.split("/")[-1] diff --git a/cognee/infrastructure/llm/tokenizer/Gemini/adapter.py b/cognee/infrastructure/llm/tokenizer/Gemini/adapter.py index f3131ea08..e4cc4f145 100644 --- a/cognee/infrastructure/llm/tokenizer/Gemini/adapter.py +++ b/cognee/infrastructure/llm/tokenizer/Gemini/adapter.py @@ -26,7 +26,7 @@ class GeminiTokenizer(TokenizerInterface): def extract_tokens(self, text: str) -> List[Any]: raise NotImplementedError - def num_tokens_from_text(self, text: str) -> int: + def count_tokens(self, text: str) -> int: """ Returns the number of tokens in the given text. Args: diff --git a/cognee/infrastructure/llm/tokenizer/HuggingFace/adapter.py b/cognee/infrastructure/llm/tokenizer/HuggingFace/adapter.py index a8eac29d9..878458414 100644 --- a/cognee/infrastructure/llm/tokenizer/HuggingFace/adapter.py +++ b/cognee/infrastructure/llm/tokenizer/HuggingFace/adapter.py @@ -20,7 +20,7 @@ class HuggingFaceTokenizer(TokenizerInterface): tokens = self.tokenizer.tokenize(text) return tokens - def num_tokens_from_text(self, text: str) -> int: + def count_tokens(self, text: str) -> int: """ Returns the number of tokens in the given text. Args: diff --git a/cognee/infrastructure/llm/tokenizer/TikToken/adapter.py b/cognee/infrastructure/llm/tokenizer/TikToken/adapter.py index 862a79296..3d649ef38 100644 --- a/cognee/infrastructure/llm/tokenizer/TikToken/adapter.py +++ b/cognee/infrastructure/llm/tokenizer/TikToken/adapter.py @@ -30,7 +30,7 @@ class TikTokenTokenizer(TokenizerInterface): tokens.append(token) return tokens - def num_tokens_from_text(self, text: str) -> int: + def count_tokens(self, text: str) -> int: """ Returns the number of tokens in the given text. Args: @@ -54,7 +54,7 @@ class TikTokenTokenizer(TokenizerInterface): str: Trimmed version of text or original text if under the limit. 
""" # First check the number of tokens - num_tokens = self.num_tokens_from_string(text) + num_tokens = self.count_tokens(text) # If the number of tokens is within the limit, return the text as is if num_tokens <= self.max_tokens: diff --git a/cognee/infrastructure/llm/tokenizer/tokenizer_interface.py b/cognee/infrastructure/llm/tokenizer/tokenizer_interface.py index abd111f12..c533f0cf9 100644 --- a/cognee/infrastructure/llm/tokenizer/tokenizer_interface.py +++ b/cognee/infrastructure/llm/tokenizer/tokenizer_interface.py @@ -10,7 +10,7 @@ class TokenizerInterface(Protocol): raise NotImplementedError @abstractmethod - def num_tokens_from_text(self, text: str) -> int: + def count_tokens(self, text: str) -> int: raise NotImplementedError @abstractmethod diff --git a/cognee/tasks/chunks/chunk_by_paragraph.py b/cognee/tasks/chunks/chunk_by_paragraph.py index 077db1cd4..7d7221b87 100644 --- a/cognee/tasks/chunks/chunk_by_paragraph.py +++ b/cognee/tasks/chunks/chunk_by_paragraph.py @@ -39,7 +39,7 @@ def chunk_by_paragraph( data, maximum_length=paragraph_length ): # Check if this sentence would exceed length limit - token_count = embedding_engine.tokenizer.num_tokens_from_text(sentence) + token_count = embedding_engine.tokenizer.count_tokens(sentence) if current_word_count > 0 and ( current_word_count + word_count > paragraph_length From 3db7f85c9cd3b36d717caaa346167ccfa6c0d0b5 Mon Sep 17 00:00:00 2001 From: Igor Ilic Date: Tue, 28 Jan 2025 14:32:00 +0100 Subject: [PATCH 10/16] feat: Add max_chunk_tokens value to chunkers Add formula and forwarding of max_chunk_tokens value through Cognee --- .env.template | 1 + cognee/api/v1/cognify/cognify_v2.py | 15 ++++++++++++++- cognee/infrastructure/llm/anthropic/adapter.py | 3 ++- cognee/infrastructure/llm/config.py | 2 ++ .../infrastructure/llm/generic_llm_api/adapter.py | 4 +++- cognee/infrastructure/llm/get_llm_client.py | 15 ++++++++++++--- cognee/infrastructure/llm/openai/adapter.py | 2 ++ cognee/modules/chunking/TextChunker.py | 14 ++++---------- .../processing/document_types/AudioDocument.py | 6 ++++-- .../processing/document_types/ImageDocument.py | 6 ++++-- .../data/processing/document_types/PdfDocument.py | 6 ++++-- .../processing/document_types/TextDocument.py | 6 ++++-- .../document_types/UnstructuredDocument.py | 6 ++++-- cognee/tasks/chunks/chunk_by_paragraph.py | 12 +++++++----- .../documents/extract_chunks_from_documents.py | 5 ++++- 15 files changed, 71 insertions(+), 32 deletions(-) diff --git a/.env.template b/.env.template index 6ce9fedf9..ec6d01596 100644 --- a/.env.template +++ b/.env.template @@ -7,6 +7,7 @@ LLM_MODEL="openai/gpt-4o-mini" LLM_PROVIDER="openai" LLM_ENDPOINT="" LLM_API_VERSION="" +LLM_MAX_TOKENS="128000" GRAPHISTRY_USERNAME= GRAPHISTRY_PASSWORD= diff --git a/cognee/api/v1/cognify/cognify_v2.py b/cognee/api/v1/cognify/cognify_v2.py index 738f77c52..73504f057 100644 --- a/cognee/api/v1/cognify/cognify_v2.py +++ b/cognee/api/v1/cognify/cognify_v2.py @@ -4,6 +4,8 @@ from typing import Union from pydantic import BaseModel +from cognee.infrastructure.databases.vector import get_vector_engine +from cognee.infrastructure.llm.get_llm_client import get_llm_client from cognee.modules.cognify.config import get_cognify_config from cognee.modules.data.methods import get_datasets, get_datasets_by_name from cognee.modules.data.methods.get_dataset_data import get_dataset_data @@ -146,12 +148,23 @@ async def get_default_tasks( if user is None: user = await get_default_user() + # Calculate max chunk size based on the following formula + 
embedding_engine = get_vector_engine().embedding_engine + llm_client = get_llm_client() + + # We need to make sure chunk size won't take more than half of LLM max context token size + # but it also can't be bigger than the embedding engine max token size + llm_cutoff_point = llm_client.max_tokens // 2 # Round down the division + max_chunk_tokens = min(embedding_engine.max_tokens, llm_cutoff_point) + try: cognee_config = get_cognify_config() default_tasks = [ Task(classify_documents), Task(check_permissions_on_documents, user=user, permissions=["write"]), - Task(extract_chunks_from_documents), # Extract text chunks based on the document type. + Task( + extract_chunks_from_documents, max_chunk_tokens=max_chunk_tokens + ), # Extract text chunks based on the document type. Task( extract_graph_from_data, graph_model=graph_model, task_config={"batch_size": 10} ), # Generate knowledge graphs from the document chunks. diff --git a/cognee/infrastructure/llm/anthropic/adapter.py b/cognee/infrastructure/llm/anthropic/adapter.py index 1fba732a0..cfeb15d25 100644 --- a/cognee/infrastructure/llm/anthropic/adapter.py +++ b/cognee/infrastructure/llm/anthropic/adapter.py @@ -14,11 +14,12 @@ class AnthropicAdapter(LLMInterface): name = "Anthropic" model: str - def __init__(self, model: str = None): + def __init__(self, max_tokens: int, model: str = None): self.aclient = instructor.patch( create=anthropic.Anthropic().messages.create, mode=instructor.Mode.ANTHROPIC_TOOLS ) self.model = model + self.max_tokens = max_tokens async def acreate_structured_output( self, text_input: str, system_prompt: str, response_model: Type[BaseModel] diff --git a/cognee/infrastructure/llm/config.py b/cognee/infrastructure/llm/config.py index 67fc82683..00dff82b9 100644 --- a/cognee/infrastructure/llm/config.py +++ b/cognee/infrastructure/llm/config.py @@ -11,6 +11,7 @@ class LLMConfig(BaseSettings): llm_api_version: Optional[str] = None llm_temperature: float = 0.0 llm_streaming: bool = False + llm_max_tokens: int = 128000 transcription_model: str = "whisper-1" model_config = SettingsConfigDict(env_file=".env", extra="allow") @@ -24,6 +25,7 @@ class LLMConfig(BaseSettings): "api_version": self.llm_api_version, "temperature": self.llm_temperature, "streaming": self.llm_streaming, + "max_tokens": self.llm_max_tokens, "transcription_model": self.transcription_model, } diff --git a/cognee/infrastructure/llm/generic_llm_api/adapter.py b/cognee/infrastructure/llm/generic_llm_api/adapter.py index a910c0780..98002b2cb 100644 --- a/cognee/infrastructure/llm/generic_llm_api/adapter.py +++ b/cognee/infrastructure/llm/generic_llm_api/adapter.py @@ -2,6 +2,7 @@ import asyncio from typing import List, Type + from pydantic import BaseModel import instructor from cognee.infrastructure.llm.llm_interface import LLMInterface @@ -16,11 +17,12 @@ class GenericAPIAdapter(LLMInterface): model: str api_key: str - def __init__(self, endpoint, api_key: str, model: str, name: str): + def __init__(self, endpoint, api_key: str, model: str, name: str, max_tokens: int): self.name = name self.model = model self.api_key = api_key self.endpoint = endpoint + self.max_tokens = max_tokens llm_config = get_llm_config() diff --git a/cognee/infrastructure/llm/get_llm_client.py b/cognee/infrastructure/llm/get_llm_client.py index 0f64014e3..f601f48b2 100644 --- a/cognee/infrastructure/llm/get_llm_client.py +++ b/cognee/infrastructure/llm/get_llm_client.py @@ -32,6 +32,7 @@ def get_llm_client(): api_version=llm_config.llm_api_version, model=llm_config.llm_model, 
transcription_model=llm_config.transcription_model, + max_tokens=llm_config.llm_max_tokens, streaming=llm_config.llm_streaming, ) @@ -42,13 +43,17 @@ def get_llm_client(): from .generic_llm_api.adapter import GenericAPIAdapter return GenericAPIAdapter( - llm_config.llm_endpoint, llm_config.llm_api_key, llm_config.llm_model, "Ollama" + llm_config.llm_endpoint, + llm_config.llm_api_key, + llm_config.llm_model, + "Ollama", + max_tokens=llm_config.llm_max_tokens, ) elif provider == LLMProvider.ANTHROPIC: from .anthropic.adapter import AnthropicAdapter - return AnthropicAdapter(llm_config.llm_model) + return AnthropicAdapter(max_tokens=llm_config.llm_max_tokens, model=llm_config.llm_model) elif provider == LLMProvider.CUSTOM: if llm_config.llm_api_key is None: @@ -57,7 +62,11 @@ def get_llm_client(): from .generic_llm_api.adapter import GenericAPIAdapter return GenericAPIAdapter( - llm_config.llm_endpoint, llm_config.llm_api_key, llm_config.llm_model, "Custom" + llm_config.llm_endpoint, + llm_config.llm_api_key, + llm_config.llm_model, + "Custom", + max_tokens=llm_config.llm_max_tokens, ) else: diff --git a/cognee/infrastructure/llm/openai/adapter.py b/cognee/infrastructure/llm/openai/adapter.py index d45662380..d6939e323 100644 --- a/cognee/infrastructure/llm/openai/adapter.py +++ b/cognee/infrastructure/llm/openai/adapter.py @@ -32,6 +32,7 @@ class OpenAIAdapter(LLMInterface): api_version: str, model: str, transcription_model: str, + max_tokens: int, streaming: bool = False, ): self.aclient = instructor.from_litellm(litellm.acompletion) @@ -41,6 +42,7 @@ class OpenAIAdapter(LLMInterface): self.api_key = api_key self.endpoint = endpoint self.api_version = api_version + self.max_tokens = max_tokens self.streaming = streaming @observe(as_type="generation") diff --git a/cognee/modules/chunking/TextChunker.py b/cognee/modules/chunking/TextChunker.py index 4fa246272..f9a664d8b 100644 --- a/cognee/modules/chunking/TextChunker.py +++ b/cognee/modules/chunking/TextChunker.py @@ -14,22 +14,15 @@ class TextChunker: chunk_size = 0 token_count = 0 - def __init__(self, document, get_text: callable, chunk_size: int = 1024): + def __init__(self, document, get_text: callable, max_chunk_tokens: int, chunk_size: int = 1024): self.document = document self.max_chunk_size = chunk_size self.get_text = get_text + self.max_chunk_tokens = max_chunk_tokens def check_word_count_and_token_count(self, word_count_before, token_count_before, chunk_data): word_count_fits = word_count_before + chunk_data["word_count"] <= self.max_chunk_size - - # Get embedding engine related to vector database - from cognee.infrastructure.databases.vector.get_vector_engine import get_vector_engine - - embedding_engine = get_vector_engine().embedding_engine - - token_count_fits = ( - token_count_before + chunk_data["token_count"] <= embedding_engine.max_tokens - ) + token_count_fits = token_count_before + chunk_data["token_count"] <= self.max_chunk_tokens return word_count_fits and token_count_fits def read(self): @@ -37,6 +30,7 @@ class TextChunker: for content_text in self.get_text(): for chunk_data in chunk_by_paragraph( content_text, + self.max_chunk_tokens, self.max_chunk_size, batch_paragraphs=True, ): diff --git a/cognee/modules/data/processing/document_types/AudioDocument.py b/cognee/modules/data/processing/document_types/AudioDocument.py index 77c700482..75152fd3d 100644 --- a/cognee/modules/data/processing/document_types/AudioDocument.py +++ b/cognee/modules/data/processing/document_types/AudioDocument.py @@ -13,12 +13,14 @@ class 
AudioDocument(Document): result = get_llm_client().create_transcript(self.raw_data_location) return result.text - def read(self, chunk_size: int, chunker: str): + def read(self, chunk_size: int, chunker: str, max_chunk_tokens: int): # Transcribe the audio file text = self.create_transcript() chunker_func = ChunkerConfig.get_chunker(chunker) - chunker = chunker_func(self, chunk_size=chunk_size, get_text=lambda: [text]) + chunker = chunker_func( + self, chunk_size=chunk_size, get_text=lambda: [text], max_chunk_tokens=max_chunk_tokens + ) yield from chunker.read() diff --git a/cognee/modules/data/processing/document_types/ImageDocument.py b/cognee/modules/data/processing/document_types/ImageDocument.py index 3b5c27d75..5f4cb287c 100644 --- a/cognee/modules/data/processing/document_types/ImageDocument.py +++ b/cognee/modules/data/processing/document_types/ImageDocument.py @@ -13,11 +13,13 @@ class ImageDocument(Document): result = get_llm_client().transcribe_image(self.raw_data_location) return result.choices[0].message.content - def read(self, chunk_size: int, chunker: str): + def read(self, chunk_size: int, chunker: str, max_chunk_tokens: int): # Transcribe the image file text = self.transcribe_image() chunker_func = ChunkerConfig.get_chunker(chunker) - chunker = chunker_func(self, chunk_size=chunk_size, get_text=lambda: [text]) + chunker = chunker_func( + self, chunk_size=chunk_size, get_text=lambda: [text], max_chunk_tokens=max_chunk_tokens + ) yield from chunker.read() diff --git a/cognee/modules/data/processing/document_types/PdfDocument.py b/cognee/modules/data/processing/document_types/PdfDocument.py index 0142610bf..8273e0177 100644 --- a/cognee/modules/data/processing/document_types/PdfDocument.py +++ b/cognee/modules/data/processing/document_types/PdfDocument.py @@ -9,7 +9,7 @@ from .Document import Document class PdfDocument(Document): type: str = "pdf" - def read(self, chunk_size: int, chunker: str): + def read(self, chunk_size: int, chunker: str, max_chunk_tokens: int): file = PdfReader(self.raw_data_location) def get_text(): @@ -18,7 +18,9 @@ class PdfDocument(Document): yield page_text chunker_func = ChunkerConfig.get_chunker(chunker) - chunker = chunker_func(self, chunk_size=chunk_size, get_text=get_text) + chunker = chunker_func( + self, chunk_size=chunk_size, get_text=get_text, max_chunk_tokens=max_chunk_tokens + ) yield from chunker.read() diff --git a/cognee/modules/data/processing/document_types/TextDocument.py b/cognee/modules/data/processing/document_types/TextDocument.py index 692edcb89..6bef959de 100644 --- a/cognee/modules/data/processing/document_types/TextDocument.py +++ b/cognee/modules/data/processing/document_types/TextDocument.py @@ -7,7 +7,7 @@ from .Document import Document class TextDocument(Document): type: str = "text" - def read(self, chunk_size: int, chunker: str): + def read(self, chunk_size: int, chunker: str, max_chunk_tokens: int): def get_text(): with open(self.raw_data_location, mode="r", encoding="utf-8") as file: while True: @@ -20,6 +20,8 @@ class TextDocument(Document): chunker_func = ChunkerConfig.get_chunker(chunker) - chunker = chunker_func(self, chunk_size=chunk_size, get_text=get_text) + chunker = chunker_func( + self, chunk_size=chunk_size, get_text=get_text, max_chunk_tokens=max_chunk_tokens + ) yield from chunker.read() diff --git a/cognee/modules/data/processing/document_types/UnstructuredDocument.py b/cognee/modules/data/processing/document_types/UnstructuredDocument.py index b3849616f..254958d14 100644 --- 
a/cognee/modules/data/processing/document_types/UnstructuredDocument.py +++ b/cognee/modules/data/processing/document_types/UnstructuredDocument.py @@ -10,7 +10,7 @@ from .Document import Document class UnstructuredDocument(Document): type: str = "unstructured" - def read(self, chunk_size: int, chunker: str) -> str: + def read(self, chunk_size: int, chunker: str, max_chunk_tokens: int) -> str: def get_text(): try: from unstructured.partition.auto import partition @@ -29,6 +29,8 @@ class UnstructuredDocument(Document): yield text - chunker = TextChunker(self, chunk_size=chunk_size, get_text=get_text) + chunker = TextChunker( + self, chunk_size=chunk_size, get_text=get_text, max_chunk_tokens=max_chunk_tokens + ) yield from chunker.read() diff --git a/cognee/tasks/chunks/chunk_by_paragraph.py b/cognee/tasks/chunks/chunk_by_paragraph.py index 7d7221b87..34205d9f6 100644 --- a/cognee/tasks/chunks/chunk_by_paragraph.py +++ b/cognee/tasks/chunks/chunk_by_paragraph.py @@ -4,13 +4,13 @@ from uuid import NAMESPACE_OID, uuid5 import tiktoken from cognee.infrastructure.databases.vector import get_vector_engine -from cognee.infrastructure.databases.vector.embeddings import get_embedding_engine from .chunk_by_sentence import chunk_by_sentence def chunk_by_paragraph( data: str, + max_chunk_tokens, paragraph_length: int = 1024, batch_paragraphs: bool = True, ) -> Iterator[Dict[str, Any]]: @@ -31,19 +31,21 @@ def chunk_by_paragraph( last_cut_type = None current_token_count = 0 - # Get vector and embedding engine vector_engine = get_vector_engine() - embedding_engine = vector_engine.embedding_engine + embedding_model = vector_engine.embedding_engine.model + embedding_model = embedding_model.split("/")[-1] for paragraph_id, sentence, word_count, end_type in chunk_by_sentence( data, maximum_length=paragraph_length ): # Check if this sentence would exceed length limit - token_count = embedding_engine.tokenizer.count_tokens(sentence) + + tokenizer = tiktoken.encoding_for_model(embedding_model) + token_count = len(tokenizer.encode(sentence)) if current_word_count > 0 and ( current_word_count + word_count > paragraph_length - or current_token_count + token_count > embedding_engine.max_tokens + or current_token_count + token_count > max_chunk_tokens ): # Yield current chunk chunk_dict = { diff --git a/cognee/tasks/documents/extract_chunks_from_documents.py b/cognee/tasks/documents/extract_chunks_from_documents.py index 3dc35ce57..4a089c7bc 100644 --- a/cognee/tasks/documents/extract_chunks_from_documents.py +++ b/cognee/tasks/documents/extract_chunks_from_documents.py @@ -5,6 +5,7 @@ from cognee.modules.data.processing.document_types.Document import Document async def extract_chunks_from_documents( documents: list[Document], + max_chunk_tokens: int, chunk_size: int = 1024, chunker="text_chunker", ) -> AsyncGenerator: @@ -16,5 +17,7 @@ async def extract_chunks_from_documents( - The `chunker` parameter determines the chunking logic and should align with the document type. 
""" for document in documents: - for document_chunk in document.read(chunk_size=chunk_size, chunker=chunker): + for document_chunk in document.read( + chunk_size=chunk_size, chunker=chunker, max_chunk_tokens=max_chunk_tokens + ): yield document_chunk From 41544369af10756a3a76715ebb28206afdfcaab0 Mon Sep 17 00:00:00 2001 From: Igor Ilic Date: Tue, 28 Jan 2025 14:47:17 +0100 Subject: [PATCH 11/16] test: Change test_by_paragraph tests to accomodate to change --- .../chunks/chunk_by_paragraph_2_test.py | 56 +++++++++++++++---- .../chunks/chunk_by_paragraph_test.py | 2 +- 2 files changed, 45 insertions(+), 13 deletions(-) diff --git a/cognee/tests/unit/processing/chunks/chunk_by_paragraph_2_test.py b/cognee/tests/unit/processing/chunks/chunk_by_paragraph_2_test.py index d8680a604..5555a7dc9 100644 --- a/cognee/tests/unit/processing/chunks/chunk_by_paragraph_2_test.py +++ b/cognee/tests/unit/processing/chunks/chunk_by_paragraph_2_test.py @@ -8,14 +8,24 @@ from cognee.tests.unit.processing.chunks.test_input import INPUT_TEXTS paragraph_lengths = [64, 256, 1024] batch_paragraphs_vals = [True, False] +max_chunk_tokens_vals = [512, 1024, 4096] @pytest.mark.parametrize( - "input_text,paragraph_length,batch_paragraphs", - list(product(list(INPUT_TEXTS.values()), paragraph_lengths, batch_paragraphs_vals)), + "input_text,max_chunk_tokens,paragraph_length,batch_paragraphs", + list( + product( + list(INPUT_TEXTS.values()), + max_chunk_tokens_vals, + paragraph_lengths, + batch_paragraphs_vals, + ) + ), ) -def test_chunk_by_paragraph_isomorphism(input_text, paragraph_length, batch_paragraphs): - chunks = chunk_by_paragraph(input_text, paragraph_length, batch_paragraphs) +def test_chunk_by_paragraph_isomorphism( + input_text, max_chunk_tokens, paragraph_length, batch_paragraphs +): + chunks = chunk_by_paragraph(input_text, max_chunk_tokens, paragraph_length, batch_paragraphs) reconstructed_text = "".join([chunk["text"] for chunk in chunks]) assert reconstructed_text == input_text, ( f"texts are not identical: {len(input_text) = }, {len(reconstructed_text) = }" @@ -23,13 +33,23 @@ def test_chunk_by_paragraph_isomorphism(input_text, paragraph_length, batch_para @pytest.mark.parametrize( - "input_text,paragraph_length,batch_paragraphs", - list(product(list(INPUT_TEXTS.values()), paragraph_lengths, batch_paragraphs_vals)), + "input_text,max_chunk_tokens,paragraph_length,batch_paragraphs", + list( + product( + list(INPUT_TEXTS.values()), + max_chunk_tokens_vals, + paragraph_lengths, + batch_paragraphs_vals, + ) + ), ) -def test_paragraph_chunk_length(input_text, paragraph_length, batch_paragraphs): +def test_paragraph_chunk_length(input_text, max_chunk_tokens, paragraph_length, batch_paragraphs): chunks = list( chunk_by_paragraph( - data=input_text, paragraph_length=paragraph_length, batch_paragraphs=batch_paragraphs + data=input_text, + max_chunk_tokens=max_chunk_tokens, + paragraph_length=paragraph_length, + batch_paragraphs=batch_paragraphs, ) ) @@ -42,12 +62,24 @@ def test_paragraph_chunk_length(input_text, paragraph_length, batch_paragraphs): @pytest.mark.parametrize( - "input_text,paragraph_length,batch_paragraphs", - list(product(list(INPUT_TEXTS.values()), paragraph_lengths, batch_paragraphs_vals)), + "input_text,max_chunk_tokens,paragraph_length,batch_paragraphs", + list( + product( + list(INPUT_TEXTS.values()), + max_chunk_tokens_vals, + paragraph_lengths, + batch_paragraphs_vals, + ) + ), ) -def test_chunk_by_paragraph_chunk_numbering(input_text, paragraph_length, batch_paragraphs): +def 
test_chunk_by_paragraph_chunk_numbering( + input_text, max_chunk_tokens, paragraph_length, batch_paragraphs +): chunks = chunk_by_paragraph( - data=input_text, paragraph_length=paragraph_length, batch_paragraphs=batch_paragraphs + data=input_text, + max_chunk_tokens=max_chunk_tokens, + paragraph_length=paragraph_length, + batch_paragraphs=batch_paragraphs, ) chunk_indices = np.array([chunk["chunk_index"] for chunk in chunks]) assert np.all(chunk_indices == np.arange(len(chunk_indices))), ( diff --git a/cognee/tests/unit/processing/chunks/chunk_by_paragraph_test.py b/cognee/tests/unit/processing/chunks/chunk_by_paragraph_test.py index e420b2e9f..ed706830e 100644 --- a/cognee/tests/unit/processing/chunks/chunk_by_paragraph_test.py +++ b/cognee/tests/unit/processing/chunks/chunk_by_paragraph_test.py @@ -50,7 +50,7 @@ Third paragraph is cut and is missing the dot at the end""", def run_chunking_test(test_text, expected_chunks): chunks = [] for chunk_data in chunk_by_paragraph( - data=test_text, paragraph_length=12, batch_paragraphs=False + data=test_text, paragraph_length=12, batch_paragraphs=False, max_chunk_tokens=512 ): chunks.append(chunk_data) From dc0450d30e0fae34e8065a256ef72c331a349341 Mon Sep 17 00:00:00 2001 From: Igor Ilic Date: Tue, 28 Jan 2025 15:21:43 +0100 Subject: [PATCH 12/16] test: Update document tests regarding max chunk tokens --- .../integration/documents/AudioDocument_test.py | 2 +- .../integration/documents/ImageDocument_test.py | 2 +- .../integration/documents/PdfDocument_test.py | 2 +- .../integration/documents/TextDocument_test.py | 3 ++- .../documents/UnstructuredDocument_test.py | 16 ++++++++++++---- 5 files changed, 17 insertions(+), 8 deletions(-) diff --git a/cognee/tests/integration/documents/AudioDocument_test.py b/cognee/tests/integration/documents/AudioDocument_test.py index 9719d90fc..38b547140 100644 --- a/cognee/tests/integration/documents/AudioDocument_test.py +++ b/cognee/tests/integration/documents/AudioDocument_test.py @@ -34,7 +34,7 @@ def test_AudioDocument(): ) with patch.object(AudioDocument, "create_transcript", return_value=TEST_TEXT): for ground_truth, paragraph_data in zip( - GROUND_TRUTH, document.read(chunk_size=64, chunker="text_chunker") + GROUND_TRUTH, document.read(chunk_size=64, chunker="text_chunker", max_chunk_tokens=512) ): assert ground_truth["word_count"] == paragraph_data.word_count, ( f'{ground_truth["word_count"] = } != {paragraph_data.word_count = }' ) diff --git a/cognee/tests/integration/documents/ImageDocument_test.py b/cognee/tests/integration/documents/ImageDocument_test.py index bd15961ee..faa54fa27 100644 --- a/cognee/tests/integration/documents/ImageDocument_test.py +++ b/cognee/tests/integration/documents/ImageDocument_test.py @@ -23,7 +23,7 @@ def test_ImageDocument(): ) with patch.object(ImageDocument, "transcribe_image", return_value=TEST_TEXT): for ground_truth, paragraph_data in zip( - GROUND_TRUTH, document.read(chunk_size=64, chunker="text_chunker") + GROUND_TRUTH, document.read(chunk_size=64, chunker="text_chunker", max_chunk_tokens=512) ): assert ground_truth["word_count"] == paragraph_data.word_count, ( f'{ground_truth["word_count"] = } != {paragraph_data.word_count = }' ) diff --git a/cognee/tests/integration/documents/PdfDocument_test.py b/cognee/tests/integration/documents/PdfDocument_test.py index 82d304b6c..e9530fc12 100644 --- a/cognee/tests/integration/documents/PdfDocument_test.py +++ b/cognee/tests/integration/documents/PdfDocument_test.py @@ -25,7 +25,7 @@ def test_PdfDocument(): ) for ground_truth,
paragraph_data in zip( - GROUND_TRUTH, document.read(chunk_size=1024, chunker="text_chunker") + GROUND_TRUTH, document.read(chunk_size=1024, chunker="text_chunker", max_chunk_tokens=2048) ): assert ground_truth["word_count"] == paragraph_data.word_count, ( f'{ground_truth["word_count"] = } != {paragraph_data.word_count = }' diff --git a/cognee/tests/integration/documents/TextDocument_test.py b/cognee/tests/integration/documents/TextDocument_test.py index 17db39be8..99e28a3ac 100644 --- a/cognee/tests/integration/documents/TextDocument_test.py +++ b/cognee/tests/integration/documents/TextDocument_test.py @@ -37,7 +37,8 @@ def test_TextDocument(input_file, chunk_size): ) for ground_truth, paragraph_data in zip( - GROUND_TRUTH[input_file], document.read(chunk_size=chunk_size, chunker="text_chunker") + GROUND_TRUTH[input_file], + document.read(chunk_size=chunk_size, chunker="text_chunker", max_chunk_tokens=1024), ): assert ground_truth["word_count"] == paragraph_data.word_count, ( f'{ground_truth["word_count"] = } != {paragraph_data.word_count = }' diff --git a/cognee/tests/integration/documents/UnstructuredDocument_test.py b/cognee/tests/integration/documents/UnstructuredDocument_test.py index 81e804f07..d76843c0a 100644 --- a/cognee/tests/integration/documents/UnstructuredDocument_test.py +++ b/cognee/tests/integration/documents/UnstructuredDocument_test.py @@ -68,7 +68,9 @@ def test_UnstructuredDocument(): ) # Test PPTX - for paragraph_data in pptx_document.read(chunk_size=1024, chunker="text_chunker"): + for paragraph_data in pptx_document.read( + chunk_size=1024, chunker="text_chunker", max_chunk_tokens=1024 + ): assert 19 == paragraph_data.word_count, f" 19 != {paragraph_data.word_count = }" assert 104 == len(paragraph_data.text), f" 104 != {len(paragraph_data.text) = }" assert "sentence_cut" == paragraph_data.cut_type, ( @@ -76,7 +78,9 @@ def test_UnstructuredDocument(): ) # Test DOCX - for paragraph_data in docx_document.read(chunk_size=1024, chunker="text_chunker"): + for paragraph_data in docx_document.read( + chunk_size=1024, chunker="text_chunker", max_chunk_tokens=1024 + ): assert 16 == paragraph_data.word_count, f" 16 != {paragraph_data.word_count = }" assert 145 == len(paragraph_data.text), f" 145 != {len(paragraph_data.text) = }" assert "sentence_end" == paragraph_data.cut_type, ( @@ -84,7 +88,9 @@ def test_UnstructuredDocument(): ) # TEST CSV - for paragraph_data in csv_document.read(chunk_size=1024, chunker="text_chunker"): + for paragraph_data in csv_document.read( + chunk_size=1024, chunker="text_chunker", max_chunk_tokens=1024 + ): assert 15 == paragraph_data.word_count, f" 15 != {paragraph_data.word_count = }" assert "A A A A A A A A A,A A A A A A,A A" == paragraph_data.text, ( f"Read text doesn't match expected text: {paragraph_data.text}" @@ -94,7 +100,9 @@ def test_UnstructuredDocument(): ) # Test XLSX - for paragraph_data in xlsx_document.read(chunk_size=1024, chunker="text_chunker"): + for paragraph_data in xlsx_document.read( + chunk_size=1024, chunker="text_chunker", max_chunk_tokens=1024 + ): assert 36 == paragraph_data.word_count, f" 36 != {paragraph_data.word_count = }" assert 171 == len(paragraph_data.text), f" 171 != {len(paragraph_data.text) = }" assert "sentence_cut" == paragraph_data.cut_type, ( From 4e56cd64a1ab6cef90e20c8f2fd20f5b25d098ce Mon Sep 17 00:00:00 2001 From: Igor Ilic Date: Tue, 28 Jan 2025 15:33:34 +0100 Subject: [PATCH 13/16] refactor: Add max chunk tokens to code graph pipeline --- cognee/api/v1/cognify/code_graph_pipeline.py | 3 ++- 
cognee/api/v1/cognify/cognify_v2.py | 14 ++------------ cognee/infrastructure/llm/__init__.py | 1 + cognee/infrastructure/llm/utils.py | 15 +++++++++++++++ 4 files changed, 20 insertions(+), 13 deletions(-) create mode 100644 cognee/infrastructure/llm/utils.py diff --git a/cognee/api/v1/cognify/code_graph_pipeline.py b/cognee/api/v1/cognify/code_graph_pipeline.py index 4a864eb0e..125245f40 100644 --- a/cognee/api/v1/cognify/code_graph_pipeline.py +++ b/cognee/api/v1/cognify/code_graph_pipeline.py @@ -21,6 +21,7 @@ from cognee.tasks.repo_processor import ( from cognee.tasks.repo_processor.get_source_code_chunks import get_source_code_chunks from cognee.tasks.storage import add_data_points from cognee.tasks.summarization import summarize_code, summarize_text +from cognee.infrastructure.llm import get_max_chunk_tokens monitoring = get_base_config().monitoring_tool if monitoring == MonitoringTool.LANGFUSE: @@ -71,7 +72,7 @@ async def run_code_graph_pipeline(repo_path, include_docs=True): Task(ingest_data, dataset_name="repo_docs", user=user), Task(get_data_list_for_user, dataset_name="repo_docs", user=user), Task(classify_documents), - Task(extract_chunks_from_documents), + Task(extract_chunks_from_documents, max_chunk_tokens=get_max_chunk_tokens()), Task( extract_graph_from_data, graph_model=KnowledgeGraph, task_config={"batch_size": 50} ), diff --git a/cognee/api/v1/cognify/cognify_v2.py b/cognee/api/v1/cognify/cognify_v2.py index 73504f057..12a84030d 100644 --- a/cognee/api/v1/cognify/cognify_v2.py +++ b/cognee/api/v1/cognify/cognify_v2.py @@ -4,8 +4,7 @@ from typing import Union from pydantic import BaseModel -from cognee.infrastructure.databases.vector import get_vector_engine -from cognee.infrastructure.llm.get_llm_client import get_llm_client +from cognee.infrastructure.llm import get_max_chunk_tokens from cognee.modules.cognify.config import get_cognify_config from cognee.modules.data.methods import get_datasets, get_datasets_by_name from cognee.modules.data.methods.get_dataset_data import get_dataset_data @@ -148,22 +147,13 @@ async def get_default_tasks( if user is None: user = await get_default_user() - # Calculate max chunk size based on the following formula - embedding_engine = get_vector_engine().embedding_engine - llm_client = get_llm_client() - - # We need to make sure chunk size won't take more than half of LLM max context token size - # but it also can't be bigger than the embedding engine max token size - llm_cutoff_point = llm_client.max_tokens // 2 # Round down the division - max_chunk_tokens = min(embedding_engine.max_tokens, llm_cutoff_point) - try: cognee_config = get_cognify_config() default_tasks = [ Task(classify_documents), Task(check_permissions_on_documents, user=user, permissions=["write"]), Task( - extract_chunks_from_documents, max_chunk_tokens=max_chunk_tokens + extract_chunks_from_documents, max_chunk_tokens=get_max_chunk_tokens() ), # Extract text chunks based on the document type. 
Task( extract_graph_from_data, graph_model=graph_model, task_config={"batch_size": 10} diff --git a/cognee/infrastructure/llm/__init__.py b/cognee/infrastructure/llm/__init__.py index 7fb3be736..36d7e56ad 100644 --- a/cognee/infrastructure/llm/__init__.py +++ b/cognee/infrastructure/llm/__init__.py @@ -1 +1,2 @@ from .config import get_llm_config +from .utils import get_max_chunk_tokens diff --git a/cognee/infrastructure/llm/utils.py b/cognee/infrastructure/llm/utils.py new file mode 100644 index 000000000..816eaf285 --- /dev/null +++ b/cognee/infrastructure/llm/utils.py @@ -0,0 +1,15 @@ +from cognee.infrastructure.databases.vector import get_vector_engine +from cognee.infrastructure.llm.get_llm_client import get_llm_client + + +def get_max_chunk_tokens(): + # Calculate max chunk size based on the following formula + embedding_engine = get_vector_engine().embedding_engine + llm_client = get_llm_client() + + # We need to make sure chunk size won't take more than half of LLM max context token size + # but it also can't be bigger than the embedding engine max token size + llm_cutoff_point = llm_client.max_tokens // 2 # Round down the division + max_chunk_tokens = min(embedding_engine.max_tokens, llm_cutoff_point) + + return max_chunk_tokens From 3e29c3d8f2369cc168209a54803634c734db5503 Mon Sep 17 00:00:00 2001 From: Igor Ilic Date: Tue, 28 Jan 2025 15:38:38 +0100 Subject: [PATCH 14/16] docs: Update notebook to work with changes to max chunk tokens --- notebooks/cognee_demo.ipynb | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/notebooks/cognee_demo.ipynb b/notebooks/cognee_demo.ipynb index 6540d19c5..a90c02958 100644 --- a/notebooks/cognee_demo.ipynb +++ b/notebooks/cognee_demo.ipynb @@ -650,6 +650,7 @@ "from cognee.modules.pipelines import run_tasks\n", "from cognee.modules.users.models import User\n", "from cognee.tasks.documents import check_permissions_on_documents, classify_documents, extract_chunks_from_documents\n", + "from cognee.infrastructure.llm import get_max_chunk_tokens\n", "from cognee.tasks.graph import extract_graph_from_data\n", "from cognee.tasks.storage import add_data_points\n", "from cognee.tasks.summarization import summarize_text\n", @@ -663,7 +664,7 @@ " tasks = [\n", " Task(classify_documents),\n", " Task(check_permissions_on_documents, user = user, permissions = [\"write\"]),\n", - " Task(extract_chunks_from_documents), # Extract text chunks based on the document type.\n", + " Task(extract_chunks_from_documents, max_chunk_tokens=get_max_chunk_tokens()), # Extract text chunks based on the document type.\n", " Task(extract_graph_from_data, graph_model = KnowledgeGraph, task_config = { \"batch_size\": 10 }), # Generate knowledge graphs from the document chunks.\n", " Task(\n", " summarize_text,\n", From a8644e0bd75d8b6f51e291ff6e05bfd5b12dccbe Mon Sep 17 00:00:00 2001 From: Igor Ilic Date: Tue, 28 Jan 2025 17:00:47 +0100 Subject: [PATCH 15/16] feat: Use litellm max token size as default for model, if model exists in litellm --- .env.template | 2 +- cognee/infrastructure/llm/config.py | 2 +- cognee/infrastructure/llm/get_llm_client.py | 18 ++++++++++++---- cognee/infrastructure/llm/utils.py | 23 +++++++++++++++++++++ 4 files changed, 39 insertions(+), 6 deletions(-) diff --git a/.env.template b/.env.template index ec6d01596..df8408518 100644 --- a/.env.template +++ b/.env.template @@ -7,7 +7,7 @@ LLM_MODEL="openai/gpt-4o-mini" LLM_PROVIDER="openai" LLM_ENDPOINT="" LLM_API_VERSION="" -LLM_MAX_TOKENS="128000" +LLM_MAX_TOKENS="16384" GRAPHISTRY_USERNAME= 
GRAPHISTRY_PASSWORD= diff --git a/cognee/infrastructure/llm/config.py b/cognee/infrastructure/llm/config.py index 00dff82b9..48c94423e 100644 --- a/cognee/infrastructure/llm/config.py +++ b/cognee/infrastructure/llm/config.py @@ -11,7 +11,7 @@ class LLMConfig(BaseSettings): llm_api_version: Optional[str] = None llm_temperature: float = 0.0 llm_streaming: bool = False - llm_max_tokens: int = 128000 + llm_max_tokens: int = 16384 transcription_model: str = "whisper-1" model_config = SettingsConfigDict(env_file=".env", extra="allow") diff --git a/cognee/infrastructure/llm/get_llm_client.py b/cognee/infrastructure/llm/get_llm_client.py index f601f48b2..383355fd2 100644 --- a/cognee/infrastructure/llm/get_llm_client.py +++ b/cognee/infrastructure/llm/get_llm_client.py @@ -20,6 +20,16 @@ def get_llm_client(): provider = LLMProvider(llm_config.llm_provider) + # Check if max_token value is defined in liteLLM for given model + # if not use value from cognee configuration + from cognee.infrastructure.llm.utils import get_model_max_tokens + + max_tokens = ( + get_model_max_tokens(llm_config.llm_model) + if get_model_max_tokens(llm_config.llm_model) + else llm_config.llm_max_tokens + ) + if provider == LLMProvider.OPENAI: if llm_config.llm_api_key is None: raise InvalidValueError(message="LLM API key is not set.") @@ -32,7 +42,7 @@ def get_llm_client(): api_version=llm_config.llm_api_version, model=llm_config.llm_model, transcription_model=llm_config.transcription_model, - max_tokens=llm_config.llm_max_tokens, + max_tokens=max_tokens, streaming=llm_config.llm_streaming, ) @@ -47,13 +57,13 @@ def get_llm_client(): llm_config.llm_api_key, llm_config.llm_model, "Ollama", - max_tokens=llm_config.llm_max_tokens, + max_tokens=max_tokens, ) elif provider == LLMProvider.ANTHROPIC: from .anthropic.adapter import AnthropicAdapter - return AnthropicAdapter(max_tokens=llm_config.llm_max_tokens, model=llm_config.llm_model) + return AnthropicAdapter(max_tokens=max_tokens, model=llm_config.llm_model) elif provider == LLMProvider.CUSTOM: if llm_config.llm_api_key is None: @@ -66,7 +76,7 @@ def get_llm_client(): llm_config.llm_api_key, llm_config.llm_model, "Custom", - max_tokens=llm_config.llm_max_tokens, + max_tokens=max_tokens, ) else: diff --git a/cognee/infrastructure/llm/utils.py b/cognee/infrastructure/llm/utils.py index 816eaf285..e0aa8945a 100644 --- a/cognee/infrastructure/llm/utils.py +++ b/cognee/infrastructure/llm/utils.py @@ -1,6 +1,11 @@ +import logging +import litellm + from cognee.infrastructure.databases.vector import get_vector_engine from cognee.infrastructure.llm.get_llm_client import get_llm_client +logger = logging.getLogger(__name__) + def get_max_chunk_tokens(): # Calculate max chunk size based on the following formula @@ -13,3 +18,21 @@ def get_max_chunk_tokens(): max_chunk_tokens = min(embedding_engine.max_tokens, llm_cutoff_point) return max_chunk_tokens + + +def get_model_max_tokens(model_name: str): + """ + Args: + model_name: name of LLM or embedding model + + Returns: Number of max tokens of model, or None if model is unknown + """ + max_tokens = None + + if model_name in litellm.model_cost: + max_tokens = litellm.model_cost[model_name]["max_tokens"] + logger.debug(f"Max input tokens for {model_name}: {max_tokens}") + else: + logger.info("Model not found in LiteLLM's model_cost.") + + return max_tokens From 860218632ffeee3b950f0c669f8c04285d33ac49 Mon Sep 17 00:00:00 2001 From: Igor Ilic Date: Tue, 28 Jan 2025 17:15:25 +0100 Subject: [PATCH 16/16] refactor: add suggestions from PR Add 
suggestions made by CodeRabbit on the pull request --- cognee/infrastructure/llm/get_llm_client.py | 11 +++++------ .../data/processing/document_types/Document.py | 2 +- cognee/tasks/repo_processor/get_source_code_chunks.py | 2 -- 3 files changed, 6 insertions(+), 9 deletions(-) diff --git a/cognee/infrastructure/llm/get_llm_client.py b/cognee/infrastructure/llm/get_llm_client.py index 383355fd2..ede8bd330 100644 --- a/cognee/infrastructure/llm/get_llm_client.py +++ b/cognee/infrastructure/llm/get_llm_client.py @@ -22,13 +22,12 @@ def get_llm_client(): # Check if max_token value is defined in liteLLM for given model # if not use value from cognee configuration - from cognee.infrastructure.llm.utils import get_model_max_tokens + from cognee.infrastructure.llm.utils import ( + get_model_max_tokens, + ) # imported here to avoid circular imports - max_tokens = ( - get_model_max_tokens(llm_config.llm_model) - if get_model_max_tokens(llm_config.llm_model) - else llm_config.llm_max_tokens - ) + model_max_tokens = get_model_max_tokens(llm_config.llm_model) + max_tokens = model_max_tokens if model_max_tokens else llm_config.llm_max_tokens if provider == LLMProvider.OPENAI: if llm_config.llm_api_key is None: raise InvalidValueError(message="LLM API key is not set.") diff --git a/cognee/modules/data/processing/document_types/Document.py b/cognee/modules/data/processing/document_types/Document.py index 80ba8e428..76ff1e045 100644 --- a/cognee/modules/data/processing/document_types/Document.py +++ b/cognee/modules/data/processing/document_types/Document.py @@ -11,5 +11,5 @@ class Document(DataPoint): mime_type: str _metadata: dict = {"index_fields": ["name"], "type": "Document"} - def read(self, chunk_size: int, chunker=str) -> str: + def read(self, chunk_size: int, chunker=str, max_chunk_tokens: Optional[int] = None) -> str: pass diff --git a/cognee/tasks/repo_processor/get_source_code_chunks.py b/cognee/tasks/repo_processor/get_source_code_chunks.py index 358d5218d..ca1c76e46 100644 --- a/cognee/tasks/repo_processor/get_source_code_chunks.py +++ b/cognee/tasks/repo_processor/get_source_code_chunks.py @@ -97,8 +97,6 @@ def _get_chunk_source_code( current_source_code = "" # Get embedding engine used in vector database - from cognee.infrastructure.databases.vector.get_vector_engine import get_vector_engine - embedding_engine = get_vector_engine().embedding_engine for i, (child_code, token_count) in enumerate(code_token_counts):
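Editor's note (not part of the patch series above): the series converges on a single sizing rule for text chunks. A chunk must fit within the embedding model's token limit and must also leave at least half of the LLM's context window free, which is what get_max_chunk_tokens() derives from llm_max_tokens and the embedding engine's max_tokens. The sketch below restates that rule in isolation as a plain function; the helper name and the numbers in the checks (an 8,191-token embedding limit, 128,000-token and 16,384-token LLM contexts) are illustrative assumptions, not values taken from any particular configuration.

    def max_chunk_tokens(embedding_max_tokens: int, llm_max_tokens: int) -> int:
        # Keep at least half of the LLM context window free for the prompt and
        # the model's answer, and never exceed what the embedding model accepts.
        llm_cutoff_point = llm_max_tokens // 2  # floor division, as in the patches
        return min(embedding_max_tokens, llm_cutoff_point)

    # Illustrative checks with assumed limits (not taken from the patches):
    assert max_chunk_tokens(8191, 128000) == 8191       # embedding limit binds
    assert max_chunk_tokens(1_000_000, 16384) == 8192   # half of the LLM context binds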