fix: Initial commit to resolve issue with using tokenizer based on LLMs

Currently TikToken is used for tokenizing by default which is only supported by OpenAI,
this is an initial commit in an attempt to add Cognee tokenizing support for multiple LLMs
This commit is contained in:
Igor Ilic 2025-01-21 19:53:22 +01:00
parent 77f0b45a0d
commit 93249c72c5
22 changed files with 176 additions and 84 deletions

View file

@ -71,7 +71,7 @@ async def run_code_graph_pipeline(repo_path, include_docs=True):
Task(ingest_data, dataset_name="repo_docs", user=user),
Task(get_data_list_for_user, dataset_name="repo_docs", user=user),
Task(classify_documents),
Task(extract_chunks_from_documents, max_tokens=cognee_config.max_tokens),
Task(extract_chunks_from_documents),
Task(
extract_graph_from_data, graph_model=KnowledgeGraph, task_config={"batch_size": 50}
),

View file

@ -6,6 +6,9 @@ import litellm
import os
from cognee.infrastructure.databases.vector.embeddings.EmbeddingEngine import EmbeddingEngine
from cognee.infrastructure.databases.exceptions.EmbeddingException import EmbeddingException
from cognee.infrastructure.llm.tokenizer.TikToken import TikTokenTokenizer
from transformers import AutoTokenizer
import tiktoken # Assuming this is how you import TikToken
litellm.set_verbose = False
logger = logging.getLogger("LiteLLMEmbeddingEngine")
@ -15,23 +18,30 @@ class LiteLLMEmbeddingEngine(EmbeddingEngine):
api_key: str
endpoint: str
api_version: str
provider: str
model: str
dimensions: int
mock: bool
def __init__(
self,
provider: str = "openai",
model: Optional[str] = "text-embedding-3-large",
dimensions: Optional[int] = 3072,
api_key: str = None,
endpoint: str = None,
api_version: str = None,
max_tokens: int = float("inf"),
):
self.api_key = api_key
self.endpoint = endpoint
self.api_version = api_version
# TODO: Add or remove provider info
self.provider = provider
self.model = model
self.dimensions = dimensions
self.max_tokens = max_tokens
self.tokenizer = self.set_tokenizer()
enable_mocking = os.getenv("MOCK_EMBEDDING", "false")
if isinstance(enable_mocking, bool):
@ -104,3 +114,16 @@ class LiteLLMEmbeddingEngine(EmbeddingEngine):
def get_vector_size(self) -> int:
return self.dimensions
def set_tokenizer(self):
logger.debug(f"Loading tokenizer for model {self.model}...")
# If model also contains provider information, extract only model information
model = self.model.split("/")[-1]
if "openai" in self.provider.lower() or "gpt" in self.model:
tokenizer = TikTokenTokenizer(model=model, max_tokens=self.max_tokens)
else:
tokenizer = AutoTokenizer.from_pretrained(self.model)
logger.debug(f"Tokenizer loaded for model: {self.model}")
return tokenizer

View file

@ -9,7 +9,7 @@ class EmbeddingConfig(BaseSettings):
embedding_endpoint: Optional[str] = None
embedding_api_key: Optional[str] = None
embedding_api_version: Optional[str] = None
embedding_max_tokens: Optional[int] = float("inf")
model_config = SettingsConfigDict(env_file=".env", extra="allow")

View file

@ -15,4 +15,5 @@ def get_embedding_engine() -> EmbeddingEngine:
api_version=config.embedding_api_version,
model=config.embedding_model,
dimensions=config.embedding_dimensions,
max_tokens=config.embedding_max_tokens,
)

View file

@ -0,0 +1 @@
from .adapter import HuggingFaceTokenizer

View file

@ -0,0 +1,22 @@
from typing import List, Any
from ..tokenizer_interface import TokenizerInterface
class HuggingFaceTokenizer(TokenizerInterface):
def __init__(
self,
model: str,
max_tokens: int = float("inf"),
):
self.model = model
self.max_tokens = max_tokens
def extract_tokens(self, text: str) -> List[Any]:
raise NotImplementedError
def num_tokens_from_text(self, text: str) -> int:
raise NotImplementedError
def trim_text_to_max_tokens(self, text: str) -> str:
raise NotImplementedError

View file

@ -0,0 +1 @@
from .adapter import TikTokenTokenizer

View file

@ -0,0 +1,69 @@
from typing import List, Any
import tiktoken
from ..tokenizer_interface import TokenizerInterface
class TikTokenTokenizer(TokenizerInterface):
"""
Tokenizer adapter for OpenAI.
Inteded to be used as part of LLM Embedding and LLM Adapters classes
"""
def __init__(
self,
model: str,
max_tokens: int = float("inf"),
):
self.model = model
self.max_tokens = max_tokens
# Initialize TikToken for GPT based on model
self.tokenizer = tiktoken.encoding_for_model(self.model)
def extract_tokens(self, text: str) -> List[Any]:
tokens = []
# Using TikToken's method to tokenize text
token_ids = self.tokenizer.encode(text)
# Go through tokens and decode them to text value
for token_id in token_ids:
token = self.tokenizer.decode([token_id])
tokens.append(token)
return tokens
def num_tokens_from_text(self, text: str) -> int:
"""
Returns the number of tokens in the given text.
Args:
text: str
Returns:
number of tokens in the given text
"""
num_tokens = len(self.tokenizer.encode(text))
return num_tokens
def trim_text_to_max_tokens(self, text: str) -> str:
"""
Trims the text so that the number of tokens does not exceed max_tokens.
Args:
text (str): Original text string to be trimmed.
Returns:
str: Trimmed version of text or original text if under the limit.
"""
# First check the number of tokens
num_tokens = self.num_tokens_from_string(text)
# If the number of tokens is within the limit, return the text as is
if num_tokens <= self.max_tokens:
return text
# If the number exceeds the limit, trim the text
# This is a simple trim, it may cut words in half; consider using word boundaries for a cleaner cut
encoded_text = self.tokenizer.encode(text)
trimmed_encoded_text = encoded_text[: self.max_tokens]
# Decoding the trimmed text
trimmed_text = self.tokenizer.decode(trimmed_encoded_text)
return trimmed_text

View file

@ -0,0 +1 @@
from .tokenizer_interface import TokenizerInterface

View file

@ -0,0 +1,18 @@
from typing import List, Protocol, Any
from abc import abstractmethod
class TokenizerInterface(Protocol):
"""Tokenizer interface"""
@abstractmethod
def extract_tokens(self, text: str) -> List[Any]:
raise NotImplementedError
@abstractmethod
def num_tokens_from_text(self, text: str) -> int:
raise NotImplementedError
@abstractmethod
def trim_text_to_max_tokens(self, text: str) -> str:
raise NotImplementedError

View file

@ -14,17 +14,22 @@ class TextChunker:
chunk_size = 0
token_count = 0
def __init__(
self, document, get_text: callable, max_tokens: Optional[int] = None, chunk_size: int = 1024
):
def __init__(self, document, get_text: callable, chunk_size: int = 1024):
self.document = document
self.max_chunk_size = chunk_size
self.get_text = get_text
self.max_tokens = max_tokens if max_tokens else float("inf")
def check_word_count_and_token_count(self, word_count_before, token_count_before, chunk_data):
word_count_fits = word_count_before + chunk_data["word_count"] <= self.max_chunk_size
token_count_fits = token_count_before + chunk_data["token_count"] <= self.max_tokens
# Get embedding engine related to vector database
from cognee.infrastructure.databases.vector.get_vector_engine import get_vector_engine
embedding_engine = get_vector_engine().embedding_engine
token_count_fits = (
token_count_before + chunk_data["token_count"] <= embedding_engine.max_tokens
)
return word_count_fits and token_count_fits
def read(self):
@ -32,7 +37,6 @@ class TextChunker:
for content_text in self.get_text():
for chunk_data in chunk_by_paragraph(
content_text,
self.max_tokens,
self.max_chunk_size,
batch_paragraphs=True,
):

View file

@ -8,7 +8,6 @@ import os
class CognifyConfig(BaseSettings):
classification_model: object = DefaultContentPrediction
summarization_model: object = SummarizedContent
max_tokens: Optional[int] = os.getenv("MAX_TOKENS")
model_config = SettingsConfigDict(env_file=".env", extra="allow")
def to_dict(self) -> dict:

View file

@ -13,14 +13,12 @@ class AudioDocument(Document):
result = get_llm_client().create_transcript(self.raw_data_location)
return result.text
def read(self, chunk_size: int, chunker: str, max_tokens: Optional[int] = None):
def read(self, chunk_size: int, chunker: str):
# Transcribe the audio file
text = self.create_transcript()
chunker_func = ChunkerConfig.get_chunker(chunker)
chunker = chunker_func(
self, chunk_size=chunk_size, get_text=lambda: [text], max_tokens=max_tokens
)
chunker = chunker_func(self, chunk_size=chunk_size, get_text=lambda: [text])
yield from chunker.read()

View file

@ -11,5 +11,5 @@ class Document(DataPoint):
mime_type: str
_metadata: dict = {"index_fields": ["name"], "type": "Document"}
def read(self, chunk_size: int, chunker=str, max_tokens: Optional[int] = None) -> str:
def read(self, chunk_size: int, chunker=str) -> str:
pass

View file

@ -13,13 +13,11 @@ class ImageDocument(Document):
result = get_llm_client().transcribe_image(self.raw_data_location)
return result.choices[0].message.content
def read(self, chunk_size: int, chunker: str, max_tokens: Optional[int] = None):
def read(self, chunk_size: int, chunker: str):
# Transcribe the image file
text = self.transcribe_image()
chunker_func = ChunkerConfig.get_chunker(chunker)
chunker = chunker_func(
self, chunk_size=chunk_size, get_text=lambda: [text], max_tokens=max_tokens
)
chunker = chunker_func(self, chunk_size=chunk_size, get_text=lambda: [text])
yield from chunker.read()

View file

@ -9,7 +9,7 @@ from .Document import Document
class PdfDocument(Document):
type: str = "pdf"
def read(self, chunk_size: int, chunker: str, max_tokens: Optional[int] = None):
def read(self, chunk_size: int, chunker: str):
file = PdfReader(self.raw_data_location)
def get_text():
@ -18,9 +18,7 @@ class PdfDocument(Document):
yield page_text
chunker_func = ChunkerConfig.get_chunker(chunker)
chunker = chunker_func(
self, chunk_size=chunk_size, get_text=get_text, max_tokens=max_tokens
)
chunker = chunker_func(self, chunk_size=chunk_size, get_text=get_text)
yield from chunker.read()

View file

@ -7,7 +7,7 @@ from .Document import Document
class TextDocument(Document):
type: str = "text"
def read(self, chunk_size: int, chunker: str, max_tokens: Optional[int] = None):
def read(self, chunk_size: int, chunker: str):
def get_text():
with open(self.raw_data_location, mode="r", encoding="utf-8") as file:
while True:
@ -20,8 +20,6 @@ class TextDocument(Document):
chunker_func = ChunkerConfig.get_chunker(chunker)
chunker = chunker_func(
self, chunk_size=chunk_size, get_text=get_text, max_tokens=max_tokens
)
chunker = chunker_func(self, chunk_size=chunk_size, get_text=get_text)
yield from chunker.read()

View file

@ -10,7 +10,7 @@ from .Document import Document
class UnstructuredDocument(Document):
type: str = "unstructured"
def read(self, chunk_size: int, chunker: str, max_tokens: Optional[int] = None) -> str:
def read(self, chunk_size: int, chunker: str) -> str:
def get_text():
try:
from unstructured.partition.auto import partition
@ -29,6 +29,6 @@ class UnstructuredDocument(Document):
yield text
chunker = TextChunker(self, chunk_size=chunk_size, get_text=get_text, max_tokens=max_tokens)
chunker = TextChunker(self, chunk_size=chunk_size, get_text=get_text)
yield from chunker.read()

View file

@ -10,8 +10,6 @@ import graphistry
import networkx as nx
import pandas as pd
import matplotlib.pyplot as plt
import tiktoken
import time
import logging
import sys
@ -100,15 +98,6 @@ def send_telemetry(event_name: str, user_id, additional_properties: dict = {}):
print(f"Error sending telemetry through proxy: {response.status_code}")
def num_tokens_from_string(string: str, encoding_name: str) -> int:
"""Returns the number of tokens in a text string."""
# tiktoken.get_encoding("cl100k_base")
encoding = tiktoken.encoding_for_model(encoding_name)
num_tokens = len(encoding.encode(string))
return num_tokens
def get_file_content_hash(file_obj: Union[str, BinaryIO]) -> str:
h = hashlib.md5()
@ -134,34 +123,6 @@ def get_file_content_hash(file_obj: Union[str, BinaryIO]) -> str:
raise IngestionError(message=f"Failed to load data from {file}: {e}")
def trim_text_to_max_tokens(text: str, max_tokens: int, encoding_name: str) -> str:
"""
Trims the text so that the number of tokens does not exceed max_tokens.
Args:
text (str): Original text string to be trimmed.
max_tokens (int): Maximum number of tokens allowed.
encoding_name (str): The name of the token encoding to use.
Returns:
str: Trimmed version of text or original text if under the limit.
"""
# First check the number of tokens
num_tokens = num_tokens_from_string(text, encoding_name)
# If the number of tokens is within the limit, return the text as is
if num_tokens <= max_tokens:
return text
# If the number exceeds the limit, trim the text
# This is a simple trim, it may cut words in half; consider using word boundaries for a cleaner cut
encoded_text = tiktoken.get_encoding(encoding_name).encode(text)
trimmed_encoded_text = encoded_text[:max_tokens]
# Decoding the trimmed text
trimmed_text = tiktoken.get_encoding(encoding_name).decode(trimmed_encoded_text)
return trimmed_text
def generate_color_palette(unique_layers):
colormap = plt.cm.get_cmap("viridis", len(unique_layers))
colors = [colormap(i) for i in range(len(unique_layers))]

View file

@ -4,13 +4,13 @@ from uuid import NAMESPACE_OID, uuid5
import tiktoken
from cognee.infrastructure.databases.vector import get_vector_engine
from cognee.infrastructure.databases.vector.embeddings import get_embedding_engine
from .chunk_by_sentence import chunk_by_sentence
def chunk_by_paragraph(
data: str,
max_tokens: Optional[Union[int, float]] = None,
paragraph_length: int = 1024,
batch_paragraphs: bool = True,
) -> Iterator[Dict[str, Any]]:
@ -24,24 +24,22 @@ def chunk_by_paragraph(
paragraph_ids = []
last_cut_type = None
current_token_count = 0
if not max_tokens:
max_tokens = float("inf")
# Get vector and embedding engine
vector_engine = get_vector_engine()
embedding_model = vector_engine.embedding_engine.model
embedding_model = embedding_model.split("/")[-1]
embedding_engine = vector_engine.embedding_engine
# embedding_model = embedding_engine.model.split("/")[-1]
for paragraph_id, sentence, word_count, end_type in chunk_by_sentence(
data, maximum_length=paragraph_length
):
# Check if this sentence would exceed length limit
tokenizer = tiktoken.encoding_for_model(embedding_model)
token_count = len(tokenizer.encode(sentence))
token_count = embedding_engine.tokenizer.num_tokens_from_text(sentence)
if current_word_count > 0 and (
current_word_count + word_count > paragraph_length
or current_token_count + token_count > max_tokens
or current_token_count + token_count > embedding_engine.max_tokens
):
# Yield current chunk
chunk_dict = {

View file

@ -7,10 +7,7 @@ async def extract_chunks_from_documents(
documents: list[Document],
chunk_size: int = 1024,
chunker="text_chunker",
max_tokens: Optional[int] = None,
):
for document in documents:
for document_chunk in document.read(
chunk_size=chunk_size, chunker=chunker, max_tokens=max_tokens
):
for document_chunk in document.read(chunk_size=chunk_size, chunker=chunker):
yield document_chunk

View file

@ -89,26 +89,31 @@ def _get_subchunk_token_counts(
def _get_chunk_source_code(
code_token_counts: list[tuple[str, int]], overlap: float, max_tokens: int
code_token_counts: list[tuple[str, int]], overlap: float
) -> tuple[list[tuple[str, int]], str]:
"""Generates a chunk of source code from tokenized subchunks with overlap handling."""
current_count = 0
cumulative_counts = []
current_source_code = ""
# Get embedding engine used in vector database
from cognee.infrastructure.databases.vector.get_vector_engine import get_vector_engine
embedding_engine = get_vector_engine().embedding_engine
for i, (child_code, token_count) in enumerate(code_token_counts):
current_count += token_count
cumulative_counts.append(current_count)
if current_count > max_tokens:
if current_count > embedding_engine.max_tokens:
break
current_source_code += f"\n{child_code}"
if current_count <= max_tokens:
if current_count <= embedding_engine.max_tokens:
return [], current_source_code.strip()
cutoff = 1
for i, cum_count in enumerate(cumulative_counts):
if cum_count > (1 - overlap) * max_tokens:
if cum_count > (1 - overlap) * embedding_engine.max_tokens:
break
cutoff = i