fix: Initial commit to resolve issue with using tokenizer based on LLMs

Currently TikToken is used for tokenizing by default, which is only supported by OpenAI;
this is an initial commit in an attempt to add Cognee tokenizing support for multiple LLMs.
This commit is contained in:
Igor Ilic 2025-01-21 19:53:22 +01:00
parent 77f0b45a0d
commit 93249c72c5
22 changed files with 176 additions and 84 deletions

View file

@ -71,7 +71,7 @@ async def run_code_graph_pipeline(repo_path, include_docs=True):
Task(ingest_data, dataset_name="repo_docs", user=user), Task(ingest_data, dataset_name="repo_docs", user=user),
Task(get_data_list_for_user, dataset_name="repo_docs", user=user), Task(get_data_list_for_user, dataset_name="repo_docs", user=user),
Task(classify_documents), Task(classify_documents),
Task(extract_chunks_from_documents, max_tokens=cognee_config.max_tokens), Task(extract_chunks_from_documents),
Task( Task(
extract_graph_from_data, graph_model=KnowledgeGraph, task_config={"batch_size": 50} extract_graph_from_data, graph_model=KnowledgeGraph, task_config={"batch_size": 50}
), ),

View file

@ -6,6 +6,9 @@ import litellm
import os import os
from cognee.infrastructure.databases.vector.embeddings.EmbeddingEngine import EmbeddingEngine from cognee.infrastructure.databases.vector.embeddings.EmbeddingEngine import EmbeddingEngine
from cognee.infrastructure.databases.exceptions.EmbeddingException import EmbeddingException from cognee.infrastructure.databases.exceptions.EmbeddingException import EmbeddingException
from cognee.infrastructure.llm.tokenizer.TikToken import TikTokenTokenizer
from transformers import AutoTokenizer
import tiktoken # Assuming this is how you import TikToken
litellm.set_verbose = False litellm.set_verbose = False
logger = logging.getLogger("LiteLLMEmbeddingEngine") logger = logging.getLogger("LiteLLMEmbeddingEngine")
@ -15,23 +18,30 @@ class LiteLLMEmbeddingEngine(EmbeddingEngine):
api_key: str api_key: str
endpoint: str endpoint: str
api_version: str api_version: str
provider: str
model: str model: str
dimensions: int dimensions: int
mock: bool mock: bool
def __init__( def __init__(
self, self,
provider: str = "openai",
model: Optional[str] = "text-embedding-3-large", model: Optional[str] = "text-embedding-3-large",
dimensions: Optional[int] = 3072, dimensions: Optional[int] = 3072,
api_key: str = None, api_key: str = None,
endpoint: str = None, endpoint: str = None,
api_version: str = None, api_version: str = None,
max_tokens: int = float("inf"),
): ):
self.api_key = api_key self.api_key = api_key
self.endpoint = endpoint self.endpoint = endpoint
self.api_version = api_version self.api_version = api_version
# TODO: Add or remove provider info
self.provider = provider
self.model = model self.model = model
self.dimensions = dimensions self.dimensions = dimensions
self.max_tokens = max_tokens
self.tokenizer = self.set_tokenizer()
enable_mocking = os.getenv("MOCK_EMBEDDING", "false") enable_mocking = os.getenv("MOCK_EMBEDDING", "false")
if isinstance(enable_mocking, bool): if isinstance(enable_mocking, bool):
@ -104,3 +114,16 @@ class LiteLLMEmbeddingEngine(EmbeddingEngine):
def get_vector_size(self) -> int: def get_vector_size(self) -> int:
return self.dimensions return self.dimensions
def set_tokenizer(self):
logger.debug(f"Loading tokenizer for model {self.model}...")
# If model also contains provider information, extract only model information
model = self.model.split("/")[-1]
if "openai" in self.provider.lower() or "gpt" in self.model:
tokenizer = TikTokenTokenizer(model=model, max_tokens=self.max_tokens)
else:
tokenizer = AutoTokenizer.from_pretrained(self.model)
logger.debug(f"Tokenizer loaded for model: {self.model}")
return tokenizer

View file

@ -9,7 +9,7 @@ class EmbeddingConfig(BaseSettings):
embedding_endpoint: Optional[str] = None embedding_endpoint: Optional[str] = None
embedding_api_key: Optional[str] = None embedding_api_key: Optional[str] = None
embedding_api_version: Optional[str] = None embedding_api_version: Optional[str] = None
embedding_max_tokens: Optional[int] = float("inf")
model_config = SettingsConfigDict(env_file=".env", extra="allow") model_config = SettingsConfigDict(env_file=".env", extra="allow")

View file

@ -15,4 +15,5 @@ def get_embedding_engine() -> EmbeddingEngine:
api_version=config.embedding_api_version, api_version=config.embedding_api_version,
model=config.embedding_model, model=config.embedding_model,
dimensions=config.embedding_dimensions, dimensions=config.embedding_dimensions,
max_tokens=config.embedding_max_tokens,
) )

View file

@ -0,0 +1 @@
from .adapter import HuggingFaceTokenizer

View file

@ -0,0 +1,22 @@
from typing import List, Any
from ..tokenizer_interface import TokenizerInterface
class HuggingFaceTokenizer(TokenizerInterface):
    """
    Placeholder tokenizer adapter for HuggingFace models.

    Stores the model identifier and token budget; every tokenization
    method is currently an unimplemented stub.
    """

    def __init__(
        self,
        model: str,
        max_tokens: int = float("inf"),
    ):
        # Keep the model name and the maximum token budget for later use
        # once the HuggingFace-backed implementation lands.
        self.model = model
        self.max_tokens = max_tokens

    def extract_tokens(self, text: str) -> List[Any]:
        # Not implemented yet for HuggingFace models.
        raise NotImplementedError

    def num_tokens_from_text(self, text: str) -> int:
        # Not implemented yet for HuggingFace models.
        raise NotImplementedError

    def trim_text_to_max_tokens(self, text: str) -> str:
        # Not implemented yet for HuggingFace models.
        raise NotImplementedError

View file

@ -0,0 +1 @@
from .adapter import TikTokenTokenizer

View file

@ -0,0 +1,69 @@
from typing import List, Any
import tiktoken
from ..tokenizer_interface import TokenizerInterface
class TikTokenTokenizer(TokenizerInterface):
    """
    Tokenizer adapter for OpenAI models, backed by TikToken.

    Intended to be used as part of LLM Embedding and LLM Adapter classes.
    """

    def __init__(
        self,
        model: str,
        max_tokens: int = float("inf"),
    ):
        self.model = model
        self.max_tokens = max_tokens
        # Resolve the TikToken encoding matching the given model name.
        self.tokenizer = tiktoken.encoding_for_model(self.model)

    def extract_tokens(self, text: str) -> List[Any]:
        """
        Returns the text split into its individual token strings.

        Args:
            text: str

        Returns:
            list of decoded token strings, in order of appearance
        """
        token_ids = self.tokenizer.encode(text)
        # Decode each token id separately so callers see the textual pieces.
        return [self.tokenizer.decode([token_id]) for token_id in token_ids]

    def num_tokens_from_text(self, text: str) -> int:
        """
        Returns the number of tokens in the given text.

        Args:
            text: str

        Returns:
            number of tokens in the given text
        """
        return len(self.tokenizer.encode(text))

    def trim_text_to_max_tokens(self, text: str) -> str:
        """
        Trims the text so that the number of tokens does not exceed max_tokens.

        Args:
            text (str): Original text string to be trimmed.

        Returns:
            str: Trimmed version of text or original text if under the limit.
        """
        # Bug fix: the original called self.num_tokens_from_string, which
        # does not exist on this class (the method is num_tokens_from_text).
        num_tokens = self.num_tokens_from_text(text)

        # If the number of tokens is within the limit, return the text as is.
        if num_tokens <= self.max_tokens:
            return text

        # If the number exceeds the limit, trim the text.
        # This is a simple trim, it may cut words in half; consider using
        # word boundaries for a cleaner cut.
        encoded_text = self.tokenizer.encode(text)
        # int() guards against a float max_tokens (e.g. read from config),
        # since list slicing requires integer indices.
        trimmed_encoded_text = encoded_text[: int(self.max_tokens)]

        # Decode the trimmed token ids back into text.
        return self.tokenizer.decode(trimmed_encoded_text)

View file

@ -0,0 +1 @@
from .tokenizer_interface import TokenizerInterface

View file

@ -0,0 +1,18 @@
from typing import List, Protocol, Any
from abc import abstractmethod
class TokenizerInterface(Protocol):
    """Structural interface that every tokenizer adapter must satisfy."""

    @abstractmethod
    def extract_tokens(self, text: str) -> List[Any]:
        """Split text into its individual tokens."""
        raise NotImplementedError

    @abstractmethod
    def num_tokens_from_text(self, text: str) -> int:
        """Count the tokens contained in text."""
        raise NotImplementedError

    @abstractmethod
    def trim_text_to_max_tokens(self, text: str) -> str:
        """Shorten text so it fits within the adapter's token budget."""
        raise NotImplementedError

View file

@ -14,17 +14,22 @@ class TextChunker:
chunk_size = 0 chunk_size = 0
token_count = 0 token_count = 0
def __init__( def __init__(self, document, get_text: callable, chunk_size: int = 1024):
self, document, get_text: callable, max_tokens: Optional[int] = None, chunk_size: int = 1024
):
self.document = document self.document = document
self.max_chunk_size = chunk_size self.max_chunk_size = chunk_size
self.get_text = get_text self.get_text = get_text
self.max_tokens = max_tokens if max_tokens else float("inf")
def check_word_count_and_token_count(self, word_count_before, token_count_before, chunk_data): def check_word_count_and_token_count(self, word_count_before, token_count_before, chunk_data):
word_count_fits = word_count_before + chunk_data["word_count"] <= self.max_chunk_size word_count_fits = word_count_before + chunk_data["word_count"] <= self.max_chunk_size
token_count_fits = token_count_before + chunk_data["token_count"] <= self.max_tokens
# Get embedding engine related to vector database
from cognee.infrastructure.databases.vector.get_vector_engine import get_vector_engine
embedding_engine = get_vector_engine().embedding_engine
token_count_fits = (
token_count_before + chunk_data["token_count"] <= embedding_engine.max_tokens
)
return word_count_fits and token_count_fits return word_count_fits and token_count_fits
def read(self): def read(self):
@ -32,7 +37,6 @@ class TextChunker:
for content_text in self.get_text(): for content_text in self.get_text():
for chunk_data in chunk_by_paragraph( for chunk_data in chunk_by_paragraph(
content_text, content_text,
self.max_tokens,
self.max_chunk_size, self.max_chunk_size,
batch_paragraphs=True, batch_paragraphs=True,
): ):

View file

@ -8,7 +8,6 @@ import os
class CognifyConfig(BaseSettings): class CognifyConfig(BaseSettings):
classification_model: object = DefaultContentPrediction classification_model: object = DefaultContentPrediction
summarization_model: object = SummarizedContent summarization_model: object = SummarizedContent
max_tokens: Optional[int] = os.getenv("MAX_TOKENS")
model_config = SettingsConfigDict(env_file=".env", extra="allow") model_config = SettingsConfigDict(env_file=".env", extra="allow")
def to_dict(self) -> dict: def to_dict(self) -> dict:

View file

@ -13,14 +13,12 @@ class AudioDocument(Document):
result = get_llm_client().create_transcript(self.raw_data_location) result = get_llm_client().create_transcript(self.raw_data_location)
return result.text return result.text
def read(self, chunk_size: int, chunker: str, max_tokens: Optional[int] = None): def read(self, chunk_size: int, chunker: str):
# Transcribe the audio file # Transcribe the audio file
text = self.create_transcript() text = self.create_transcript()
chunker_func = ChunkerConfig.get_chunker(chunker) chunker_func = ChunkerConfig.get_chunker(chunker)
chunker = chunker_func( chunker = chunker_func(self, chunk_size=chunk_size, get_text=lambda: [text])
self, chunk_size=chunk_size, get_text=lambda: [text], max_tokens=max_tokens
)
yield from chunker.read() yield from chunker.read()

View file

@ -11,5 +11,5 @@ class Document(DataPoint):
mime_type: str mime_type: str
_metadata: dict = {"index_fields": ["name"], "type": "Document"} _metadata: dict = {"index_fields": ["name"], "type": "Document"}
def read(self, chunk_size: int, chunker=str, max_tokens: Optional[int] = None) -> str: def read(self, chunk_size: int, chunker=str) -> str:
pass pass

View file

@ -13,13 +13,11 @@ class ImageDocument(Document):
result = get_llm_client().transcribe_image(self.raw_data_location) result = get_llm_client().transcribe_image(self.raw_data_location)
return result.choices[0].message.content return result.choices[0].message.content
def read(self, chunk_size: int, chunker: str, max_tokens: Optional[int] = None): def read(self, chunk_size: int, chunker: str):
# Transcribe the image file # Transcribe the image file
text = self.transcribe_image() text = self.transcribe_image()
chunker_func = ChunkerConfig.get_chunker(chunker) chunker_func = ChunkerConfig.get_chunker(chunker)
chunker = chunker_func( chunker = chunker_func(self, chunk_size=chunk_size, get_text=lambda: [text])
self, chunk_size=chunk_size, get_text=lambda: [text], max_tokens=max_tokens
)
yield from chunker.read() yield from chunker.read()

View file

@ -9,7 +9,7 @@ from .Document import Document
class PdfDocument(Document): class PdfDocument(Document):
type: str = "pdf" type: str = "pdf"
def read(self, chunk_size: int, chunker: str, max_tokens: Optional[int] = None): def read(self, chunk_size: int, chunker: str):
file = PdfReader(self.raw_data_location) file = PdfReader(self.raw_data_location)
def get_text(): def get_text():
@ -18,9 +18,7 @@ class PdfDocument(Document):
yield page_text yield page_text
chunker_func = ChunkerConfig.get_chunker(chunker) chunker_func = ChunkerConfig.get_chunker(chunker)
chunker = chunker_func( chunker = chunker_func(self, chunk_size=chunk_size, get_text=get_text)
self, chunk_size=chunk_size, get_text=get_text, max_tokens=max_tokens
)
yield from chunker.read() yield from chunker.read()

View file

@ -7,7 +7,7 @@ from .Document import Document
class TextDocument(Document): class TextDocument(Document):
type: str = "text" type: str = "text"
def read(self, chunk_size: int, chunker: str, max_tokens: Optional[int] = None): def read(self, chunk_size: int, chunker: str):
def get_text(): def get_text():
with open(self.raw_data_location, mode="r", encoding="utf-8") as file: with open(self.raw_data_location, mode="r", encoding="utf-8") as file:
while True: while True:
@ -20,8 +20,6 @@ class TextDocument(Document):
chunker_func = ChunkerConfig.get_chunker(chunker) chunker_func = ChunkerConfig.get_chunker(chunker)
chunker = chunker_func( chunker = chunker_func(self, chunk_size=chunk_size, get_text=get_text)
self, chunk_size=chunk_size, get_text=get_text, max_tokens=max_tokens
)
yield from chunker.read() yield from chunker.read()

View file

@ -10,7 +10,7 @@ from .Document import Document
class UnstructuredDocument(Document): class UnstructuredDocument(Document):
type: str = "unstructured" type: str = "unstructured"
def read(self, chunk_size: int, chunker: str, max_tokens: Optional[int] = None) -> str: def read(self, chunk_size: int, chunker: str) -> str:
def get_text(): def get_text():
try: try:
from unstructured.partition.auto import partition from unstructured.partition.auto import partition
@ -29,6 +29,6 @@ class UnstructuredDocument(Document):
yield text yield text
chunker = TextChunker(self, chunk_size=chunk_size, get_text=get_text, max_tokens=max_tokens) chunker = TextChunker(self, chunk_size=chunk_size, get_text=get_text)
yield from chunker.read() yield from chunker.read()

View file

@ -10,8 +10,6 @@ import graphistry
import networkx as nx import networkx as nx
import pandas as pd import pandas as pd
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import tiktoken
import time
import logging import logging
import sys import sys
@ -100,15 +98,6 @@ def send_telemetry(event_name: str, user_id, additional_properties: dict = {}):
print(f"Error sending telemetry through proxy: {response.status_code}") print(f"Error sending telemetry through proxy: {response.status_code}")
def num_tokens_from_string(string: str, encoding_name: str) -> int:
"""Returns the number of tokens in a text string."""
# tiktoken.get_encoding("cl100k_base")
encoding = tiktoken.encoding_for_model(encoding_name)
num_tokens = len(encoding.encode(string))
return num_tokens
def get_file_content_hash(file_obj: Union[str, BinaryIO]) -> str: def get_file_content_hash(file_obj: Union[str, BinaryIO]) -> str:
h = hashlib.md5() h = hashlib.md5()
@ -134,34 +123,6 @@ def get_file_content_hash(file_obj: Union[str, BinaryIO]) -> str:
raise IngestionError(message=f"Failed to load data from {file}: {e}") raise IngestionError(message=f"Failed to load data from {file}: {e}")
def trim_text_to_max_tokens(text: str, max_tokens: int, encoding_name: str) -> str:
"""
Trims the text so that the number of tokens does not exceed max_tokens.
Args:
text (str): Original text string to be trimmed.
max_tokens (int): Maximum number of tokens allowed.
encoding_name (str): The name of the token encoding to use.
Returns:
str: Trimmed version of text or original text if under the limit.
"""
# First check the number of tokens
num_tokens = num_tokens_from_string(text, encoding_name)
# If the number of tokens is within the limit, return the text as is
if num_tokens <= max_tokens:
return text
# If the number exceeds the limit, trim the text
# This is a simple trim, it may cut words in half; consider using word boundaries for a cleaner cut
encoded_text = tiktoken.get_encoding(encoding_name).encode(text)
trimmed_encoded_text = encoded_text[:max_tokens]
# Decoding the trimmed text
trimmed_text = tiktoken.get_encoding(encoding_name).decode(trimmed_encoded_text)
return trimmed_text
def generate_color_palette(unique_layers): def generate_color_palette(unique_layers):
colormap = plt.cm.get_cmap("viridis", len(unique_layers)) colormap = plt.cm.get_cmap("viridis", len(unique_layers))
colors = [colormap(i) for i in range(len(unique_layers))] colors = [colormap(i) for i in range(len(unique_layers))]

View file

@ -4,13 +4,13 @@ from uuid import NAMESPACE_OID, uuid5
import tiktoken import tiktoken
from cognee.infrastructure.databases.vector import get_vector_engine from cognee.infrastructure.databases.vector import get_vector_engine
from cognee.infrastructure.databases.vector.embeddings import get_embedding_engine
from .chunk_by_sentence import chunk_by_sentence from .chunk_by_sentence import chunk_by_sentence
def chunk_by_paragraph( def chunk_by_paragraph(
data: str, data: str,
max_tokens: Optional[Union[int, float]] = None,
paragraph_length: int = 1024, paragraph_length: int = 1024,
batch_paragraphs: bool = True, batch_paragraphs: bool = True,
) -> Iterator[Dict[str, Any]]: ) -> Iterator[Dict[str, Any]]:
@ -24,24 +24,22 @@ def chunk_by_paragraph(
paragraph_ids = [] paragraph_ids = []
last_cut_type = None last_cut_type = None
current_token_count = 0 current_token_count = 0
if not max_tokens:
max_tokens = float("inf")
# Get vector and embedding engine
vector_engine = get_vector_engine() vector_engine = get_vector_engine()
embedding_model = vector_engine.embedding_engine.model embedding_engine = vector_engine.embedding_engine
embedding_model = embedding_model.split("/")[-1]
# embedding_model = embedding_engine.model.split("/")[-1]
for paragraph_id, sentence, word_count, end_type in chunk_by_sentence( for paragraph_id, sentence, word_count, end_type in chunk_by_sentence(
data, maximum_length=paragraph_length data, maximum_length=paragraph_length
): ):
# Check if this sentence would exceed length limit # Check if this sentence would exceed length limit
token_count = embedding_engine.tokenizer.num_tokens_from_text(sentence)
tokenizer = tiktoken.encoding_for_model(embedding_model)
token_count = len(tokenizer.encode(sentence))
if current_word_count > 0 and ( if current_word_count > 0 and (
current_word_count + word_count > paragraph_length current_word_count + word_count > paragraph_length
or current_token_count + token_count > max_tokens or current_token_count + token_count > embedding_engine.max_tokens
): ):
# Yield current chunk # Yield current chunk
chunk_dict = { chunk_dict = {

View file

@ -7,10 +7,7 @@ async def extract_chunks_from_documents(
documents: list[Document], documents: list[Document],
chunk_size: int = 1024, chunk_size: int = 1024,
chunker="text_chunker", chunker="text_chunker",
max_tokens: Optional[int] = None,
): ):
for document in documents: for document in documents:
for document_chunk in document.read( for document_chunk in document.read(chunk_size=chunk_size, chunker=chunker):
chunk_size=chunk_size, chunker=chunker, max_tokens=max_tokens
):
yield document_chunk yield document_chunk

View file

@ -89,26 +89,31 @@ def _get_subchunk_token_counts(
def _get_chunk_source_code( def _get_chunk_source_code(
code_token_counts: list[tuple[str, int]], overlap: float, max_tokens: int code_token_counts: list[tuple[str, int]], overlap: float
) -> tuple[list[tuple[str, int]], str]: ) -> tuple[list[tuple[str, int]], str]:
"""Generates a chunk of source code from tokenized subchunks with overlap handling.""" """Generates a chunk of source code from tokenized subchunks with overlap handling."""
current_count = 0 current_count = 0
cumulative_counts = [] cumulative_counts = []
current_source_code = "" current_source_code = ""
# Get embedding engine used in vector database
from cognee.infrastructure.databases.vector.get_vector_engine import get_vector_engine
embedding_engine = get_vector_engine().embedding_engine
for i, (child_code, token_count) in enumerate(code_token_counts): for i, (child_code, token_count) in enumerate(code_token_counts):
current_count += token_count current_count += token_count
cumulative_counts.append(current_count) cumulative_counts.append(current_count)
if current_count > max_tokens: if current_count > embedding_engine.max_tokens:
break break
current_source_code += f"\n{child_code}" current_source_code += f"\n{child_code}"
if current_count <= max_tokens: if current_count <= embedding_engine.max_tokens:
return [], current_source_code.strip() return [], current_source_code.strip()
cutoff = 1 cutoff = 1
for i, cum_count in enumerate(cumulative_counts): for i, cum_count in enumerate(cumulative_counts):
if cum_count > (1 - overlap) * max_tokens: if cum_count > (1 - overlap) * embedding_engine.max_tokens:
break break
cutoff = i cutoff = i