fix: initial support for choosing a tokenizer based on the configured LLM
TikToken is currently used for tokenization by default, but it only supports OpenAI models. This initial commit begins adding tokenizer support in Cognee for multiple LLM providers.
parent 77f0b45a0d · commit 93249c72c5
22 changed files with 176 additions and 84 deletions
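At a high level, the commit routes token counting through a tokenizer object attached to the embedding engine instead of calling tiktoken directly. A minimal usage sketch under the new API (import paths omitted; the max_tokens value is illustrative):

    # The engine now builds a tokenizer at construction time (see set_tokenizer
    # below) and exposes it as engine.tokenizer.
    engine = LiteLLMEmbeddingEngine(
        provider="openai",
        model="text-embedding-3-large",
        max_tokens=8191,  # illustrative limit
    )
    token_count = engine.tokenizer.num_tokens_from_text("some input text")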
@@ -71,7 +71,7 @@ async def run_code_graph_pipeline(repo_path, include_docs=True):
         Task(ingest_data, dataset_name="repo_docs", user=user),
         Task(get_data_list_for_user, dataset_name="repo_docs", user=user),
         Task(classify_documents),
-        Task(extract_chunks_from_documents, max_tokens=cognee_config.max_tokens),
+        Task(extract_chunks_from_documents),
         Task(
             extract_graph_from_data, graph_model=KnowledgeGraph, task_config={"batch_size": 50}
         ),
@@ -6,6 +6,9 @@ import litellm
 import os
 from cognee.infrastructure.databases.vector.embeddings.EmbeddingEngine import EmbeddingEngine
 from cognee.infrastructure.databases.exceptions.EmbeddingException import EmbeddingException
+from cognee.infrastructure.llm.tokenizer.TikToken import TikTokenTokenizer
+from transformers import AutoTokenizer
+import tiktoken  # Assuming this is how you import TikToken
 
 litellm.set_verbose = False
 logger = logging.getLogger("LiteLLMEmbeddingEngine")
@@ -15,23 +18,30 @@ class LiteLLMEmbeddingEngine(EmbeddingEngine):
     api_key: str
     endpoint: str
     api_version: str
+    provider: str
     model: str
     dimensions: int
     mock: bool
 
     def __init__(
         self,
+        provider: str = "openai",
         model: Optional[str] = "text-embedding-3-large",
         dimensions: Optional[int] = 3072,
         api_key: str = None,
         endpoint: str = None,
         api_version: str = None,
+        max_tokens: int = float("inf"),
     ):
         self.api_key = api_key
         self.endpoint = endpoint
         self.api_version = api_version
+        # TODO: Add or remove provider info
+        self.provider = provider
         self.model = model
         self.dimensions = dimensions
+        self.max_tokens = max_tokens
+        self.tokenizer = self.set_tokenizer()
 
         enable_mocking = os.getenv("MOCK_EMBEDDING", "false")
         if isinstance(enable_mocking, bool):
@@ -104,3 +114,16 @@ class LiteLLMEmbeddingEngine(EmbeddingEngine):
 
     def get_vector_size(self) -> int:
         return self.dimensions
+
+    def set_tokenizer(self):
+        logger.debug(f"Loading tokenizer for model {self.model}...")
+        # If model also contains provider information, extract only model information
+        model = self.model.split("/")[-1]
+
+        if "openai" in self.provider.lower() or "gpt" in self.model:
+            tokenizer = TikTokenTokenizer(model=model, max_tokens=self.max_tokens)
+        else:
+            tokenizer = AutoTokenizer.from_pretrained(self.model)
+
+        logger.debug(f"Tokenizer loaded for model: {self.model}")
+        return tokenizer
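Note that the two branches return different interfaces: TikTokenTokenizer implements the new TokenizerInterface, while AutoTokenizer exposes the standard HuggingFace API (encode rather than num_tokens_from_text). A short sketch of counting tokens through the HuggingFace branch (the checkpoint name is only an example):

    from transformers import AutoTokenizer

    # Example checkpoint; any HuggingFace model with a tokenizer would do.
    hf_tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
    token_count = len(hf_tokenizer.encode("some input text"))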
@@ -9,7 +9,7 @@ class EmbeddingConfig(BaseSettings):
     embedding_endpoint: Optional[str] = None
     embedding_api_key: Optional[str] = None
     embedding_api_version: Optional[str] = None
-
+    embedding_max_tokens: Optional[int] = float("inf")
     model_config = SettingsConfigDict(env_file=".env", extra="allow")
 
 
@@ -15,4 +15,5 @@ def get_embedding_engine() -> EmbeddingEngine:
         api_version=config.embedding_api_version,
         model=config.embedding_model,
         dimensions=config.embedding_dimensions,
+        max_tokens=config.embedding_max_tokens,
     )
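Since EmbeddingConfig is a pydantic BaseSettings class reading from .env, the new limit can presumably also be supplied through the environment before the engine is created (the value is illustrative):

    import os

    # Assumption: pydantic-settings maps this variable to embedding_max_tokens.
    os.environ["EMBEDDING_MAX_TOKENS"] = "8191"

    from cognee.infrastructure.databases.vector.embeddings import get_embedding_engine

    engine = get_embedding_engine()
    print(engine.max_tokens)  # expected: 8191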
cognee/infrastructure/llm/tokenizer/HuggingFace/__init__.py (new file, 1 line)
@@ -0,0 +1 @@
+from .adapter import HuggingFaceTokenizer
cognee/infrastructure/llm/tokenizer/HuggingFace/adapter.py (new file, 22 lines)
@@ -0,0 +1,22 @@
+from typing import List, Any
+
+from ..tokenizer_interface import TokenizerInterface
+
+
+class HuggingFaceTokenizer(TokenizerInterface):
+    def __init__(
+        self,
+        model: str,
+        max_tokens: int = float("inf"),
+    ):
+        self.model = model
+        self.max_tokens = max_tokens
+
+    def extract_tokens(self, text: str) -> List[Any]:
+        raise NotImplementedError
+
+    def num_tokens_from_text(self, text: str) -> int:
+        raise NotImplementedError
+
+    def trim_text_to_max_tokens(self, text: str) -> str:
+        raise NotImplementedError
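The HuggingFace adapter is stubbed out for now. If it were later fleshed out with transformers' AutoTokenizer, the methods could plausibly look like this (a hedged sketch, not part of this commit):

    from typing import Any, List

    from transformers import AutoTokenizer

    class HuggingFaceTokenizer:  # hypothetical future version of the stub above
        def __init__(self, model: str, max_tokens: int = float("inf")):
            self.model = model
            self.max_tokens = max_tokens
            self.tokenizer = AutoTokenizer.from_pretrained(model)

        def extract_tokens(self, text: str) -> List[Any]:
            # tokenize() returns subword strings rather than token ids
            return self.tokenizer.tokenize(text)

        def num_tokens_from_text(self, text: str) -> int:
            # add_special_tokens=False counts only the text's own tokens
            return len(self.tokenizer.encode(text, add_special_tokens=False))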
cognee/infrastructure/llm/tokenizer/TikToken/__init__.py (new file, 1 line)
@@ -0,0 +1 @@
+from .adapter import TikTokenTokenizer
cognee/infrastructure/llm/tokenizer/TikToken/adapter.py (new file, 69 lines)
@@ -0,0 +1,69 @@
+from typing import List, Any
+import tiktoken
+
+from ..tokenizer_interface import TokenizerInterface
+
+
+class TikTokenTokenizer(TokenizerInterface):
+    """
+    Tokenizer adapter for OpenAI.
+    Intended to be used as part of the LLM Embedding and LLM Adapter classes.
+    """
+
+    def __init__(
+        self,
+        model: str,
+        max_tokens: int = float("inf"),
+    ):
+        self.model = model
+        self.max_tokens = max_tokens
+        # Initialize TikToken for GPT based on the model
+        self.tokenizer = tiktoken.encoding_for_model(self.model)
+
+    def extract_tokens(self, text: str) -> List[Any]:
+        tokens = []
+        # Use TikToken to tokenize the text into token ids
+        token_ids = self.tokenizer.encode(text)
+        # Decode each token id back to its text value
+        for token_id in token_ids:
+            token = self.tokenizer.decode([token_id])
+            tokens.append(token)
+        return tokens
+
+    def num_tokens_from_text(self, text: str) -> int:
+        """
+        Returns the number of tokens in the given text.
+        Args:
+            text: str
+
+        Returns:
+            number of tokens in the given text
+
+        """
+        num_tokens = len(self.tokenizer.encode(text))
+        return num_tokens
+
+    def trim_text_to_max_tokens(self, text: str) -> str:
+        """
+        Trims the text so that the number of tokens does not exceed max_tokens.
+
+        Args:
+            text (str): Original text string to be trimmed.
+
+        Returns:
+            str: Trimmed version of text or original text if under the limit.
+        """
+        # First check the number of tokens
+        num_tokens = self.num_tokens_from_text(text)
+
+        # If the number of tokens is within the limit, return the text as is
+        if num_tokens <= self.max_tokens:
+            return text
+
+        # If the number exceeds the limit, trim the text
+        # This is a simple trim, it may cut words in half; consider using word boundaries for a cleaner cut
+        encoded_text = self.tokenizer.encode(text)
+        trimmed_encoded_text = encoded_text[: self.max_tokens]
+        # Decoding the trimmed text
+        trimmed_text = self.tokenizer.decode(trimmed_encoded_text)
+        return trimmed_text
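A quick usage example for the new adapter (model name and limit are illustrative):

    tokenizer = TikTokenTokenizer(model="gpt-4", max_tokens=8)

    tokenizer.num_tokens_from_text("The quick brown fox jumps over the lazy dog")    # 9 with cl100k_base
    tokenizer.trim_text_to_max_tokens("The quick brown fox jumps over the lazy dog") # trimmed to 8 tokens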
cognee/infrastructure/llm/tokenizer/__init__.py (new file, 1 line)
@@ -0,0 +1 @@
+from .tokenizer_interface import TokenizerInterface
cognee/infrastructure/llm/tokenizer/tokenizer_interface.py (new file, 18 lines)
@@ -0,0 +1,18 @@
+from typing import List, Protocol, Any
+from abc import abstractmethod
+
+
+class TokenizerInterface(Protocol):
+    """Tokenizer interface"""
+
+    @abstractmethod
+    def extract_tokens(self, text: str) -> List[Any]:
+        raise NotImplementedError
+
+    @abstractmethod
+    def num_tokens_from_text(self, text: str) -> int:
+        raise NotImplementedError
+
+    @abstractmethod
+    def trim_text_to_max_tokens(self, text: str) -> str:
+        raise NotImplementedError
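Because TokenizerInterface is a typing.Protocol, any class with these three methods satisfies it structurally, so calling code can stay agnostic about the backend. For example (a sketch):

    from cognee.infrastructure.llm.tokenizer import TokenizerInterface

    def count_tokens(tokenizer: TokenizerInterface, text: str) -> int:
        # Accepts TikTokenTokenizer, HuggingFaceTokenizer, or any structural match.
        return tokenizer.num_tokens_from_text(text)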
@@ -14,17 +14,22 @@ class TextChunker:
     chunk_size = 0
     token_count = 0
 
-    def __init__(
-        self, document, get_text: callable, max_tokens: Optional[int] = None, chunk_size: int = 1024
-    ):
+    def __init__(self, document, get_text: callable, chunk_size: int = 1024):
        self.document = document
        self.max_chunk_size = chunk_size
        self.get_text = get_text
-        self.max_tokens = max_tokens if max_tokens else float("inf")
 
     def check_word_count_and_token_count(self, word_count_before, token_count_before, chunk_data):
         word_count_fits = word_count_before + chunk_data["word_count"] <= self.max_chunk_size
-        token_count_fits = token_count_before + chunk_data["token_count"] <= self.max_tokens
+
+        # Get embedding engine related to vector database
+        from cognee.infrastructure.databases.vector.get_vector_engine import get_vector_engine
+
+        embedding_engine = get_vector_engine().embedding_engine
+
+        token_count_fits = (
+            token_count_before + chunk_data["token_count"] <= embedding_engine.max_tokens
+        )
         return word_count_fits and token_count_fits
 
     def read(self):
@@ -32,7 +37,6 @@ class TextChunker:
         for content_text in self.get_text():
             for chunk_data in chunk_by_paragraph(
                 content_text,
-                self.max_tokens,
                 self.max_chunk_size,
                 batch_paragraphs=True,
             ):
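The fit check above now enforces two independent budgets: the chunker's word limit and the embedding engine's token limit. Worked through with illustrative numbers:

    # Pending chunk: 900 words / 450 tokens; incoming paragraph: 100 words / 60 tokens.
    word_count_fits = 900 + 100 <= 1024    # True  (max_chunk_size)
    token_count_fits = 450 + 60 <= 512     # True  (embedding_engine.max_tokens, example value)
    accept = word_count_fits and token_count_fits  # paragraph joins the chunk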
@@ -8,7 +8,6 @@ import os
 class CognifyConfig(BaseSettings):
     classification_model: object = DefaultContentPrediction
     summarization_model: object = SummarizedContent
-    max_tokens: Optional[int] = os.getenv("MAX_TOKENS")
     model_config = SettingsConfigDict(env_file=".env", extra="allow")
 
     def to_dict(self) -> dict:
@@ -13,14 +13,12 @@ class AudioDocument(Document):
         result = get_llm_client().create_transcript(self.raw_data_location)
         return result.text
 
-    def read(self, chunk_size: int, chunker: str, max_tokens: Optional[int] = None):
+    def read(self, chunk_size: int, chunker: str):
         # Transcribe the audio file
 
         text = self.create_transcript()
 
         chunker_func = ChunkerConfig.get_chunker(chunker)
-        chunker = chunker_func(
-            self, chunk_size=chunk_size, get_text=lambda: [text], max_tokens=max_tokens
-        )
+        chunker = chunker_func(self, chunk_size=chunk_size, get_text=lambda: [text])
 
         yield from chunker.read()
@@ -11,5 +11,5 @@ class Document(DataPoint):
     mime_type: str
     _metadata: dict = {"index_fields": ["name"], "type": "Document"}
 
-    def read(self, chunk_size: int, chunker=str, max_tokens: Optional[int] = None) -> str:
+    def read(self, chunk_size: int, chunker=str) -> str:
         pass
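With this change, read() has a uniform signature across document types and no longer threads max_tokens through; the limit is applied inside the chunker via the embedding engine. A typical call on a concrete subclass such as TextDocument ("text_chunker" is the default chunker name used elsewhere in this diff):

    for chunk in document.read(chunk_size=1024, chunker="text_chunker"):
        ...  # each chunk already respects the engine's token limit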
@@ -13,13 +13,11 @@ class ImageDocument(Document):
         result = get_llm_client().transcribe_image(self.raw_data_location)
         return result.choices[0].message.content
 
-    def read(self, chunk_size: int, chunker: str, max_tokens: Optional[int] = None):
+    def read(self, chunk_size: int, chunker: str):
         # Transcribe the image file
         text = self.transcribe_image()
 
         chunker_func = ChunkerConfig.get_chunker(chunker)
-        chunker = chunker_func(
-            self, chunk_size=chunk_size, get_text=lambda: [text], max_tokens=max_tokens
-        )
+        chunker = chunker_func(self, chunk_size=chunk_size, get_text=lambda: [text])
 
         yield from chunker.read()
@@ -9,7 +9,7 @@ from .Document import Document
 class PdfDocument(Document):
     type: str = "pdf"
 
-    def read(self, chunk_size: int, chunker: str, max_tokens: Optional[int] = None):
+    def read(self, chunk_size: int, chunker: str):
         file = PdfReader(self.raw_data_location)
 
         def get_text():
@@ -18,9 +18,7 @@ class PdfDocument(Document):
             yield page_text
 
         chunker_func = ChunkerConfig.get_chunker(chunker)
-        chunker = chunker_func(
-            self, chunk_size=chunk_size, get_text=get_text, max_tokens=max_tokens
-        )
+        chunker = chunker_func(self, chunk_size=chunk_size, get_text=get_text)
 
         yield from chunker.read()
 
@@ -7,7 +7,7 @@ from .Document import Document
 class TextDocument(Document):
     type: str = "text"
 
-    def read(self, chunk_size: int, chunker: str, max_tokens: Optional[int] = None):
+    def read(self, chunk_size: int, chunker: str):
         def get_text():
             with open(self.raw_data_location, mode="r", encoding="utf-8") as file:
                 while True:
@@ -20,8 +20,6 @@ class TextDocument(Document):
 
         chunker_func = ChunkerConfig.get_chunker(chunker)
 
-        chunker = chunker_func(
-            self, chunk_size=chunk_size, get_text=get_text, max_tokens=max_tokens
-        )
+        chunker = chunker_func(self, chunk_size=chunk_size, get_text=get_text)
 
         yield from chunker.read()
@@ -10,7 +10,7 @@ from .Document import Document
 class UnstructuredDocument(Document):
     type: str = "unstructured"
 
-    def read(self, chunk_size: int, chunker: str, max_tokens: Optional[int] = None) -> str:
+    def read(self, chunk_size: int, chunker: str) -> str:
         def get_text():
             try:
                 from unstructured.partition.auto import partition
@@ -29,6 +29,6 @@ class UnstructuredDocument(Document):
 
             yield text
 
-        chunker = TextChunker(self, chunk_size=chunk_size, get_text=get_text, max_tokens=max_tokens)
+        chunker = TextChunker(self, chunk_size=chunk_size, get_text=get_text)
 
         yield from chunker.read()
@@ -10,8 +10,6 @@ import graphistry
 import networkx as nx
 import pandas as pd
 import matplotlib.pyplot as plt
-import tiktoken
-import time
 
 import logging
 import sys
@@ -100,15 +98,6 @@ def send_telemetry(event_name: str, user_id, additional_properties: dict = {}):
         print(f"Error sending telemetry through proxy: {response.status_code}")
 
 
-def num_tokens_from_string(string: str, encoding_name: str) -> int:
-    """Returns the number of tokens in a text string."""
-
-    # tiktoken.get_encoding("cl100k_base")
-    encoding = tiktoken.encoding_for_model(encoding_name)
-    num_tokens = len(encoding.encode(string))
-    return num_tokens
-
-
 def get_file_content_hash(file_obj: Union[str, BinaryIO]) -> str:
     h = hashlib.md5()
@@ -134,34 +123,6 @@ def get_file_content_hash(file_obj: Union[str, BinaryIO]) -> str:
         raise IngestionError(message=f"Failed to load data from {file}: {e}")
 
 
-def trim_text_to_max_tokens(text: str, max_tokens: int, encoding_name: str) -> str:
-    """
-    Trims the text so that the number of tokens does not exceed max_tokens.
-
-    Args:
-        text (str): Original text string to be trimmed.
-        max_tokens (int): Maximum number of tokens allowed.
-        encoding_name (str): The name of the token encoding to use.
-
-    Returns:
-        str: Trimmed version of text or original text if under the limit.
-    """
-    # First check the number of tokens
-    num_tokens = num_tokens_from_string(text, encoding_name)
-
-    # If the number of tokens is within the limit, return the text as is
-    if num_tokens <= max_tokens:
-        return text
-
-    # If the number exceeds the limit, trim the text
-    # This is a simple trim, it may cut words in half; consider using word boundaries for a cleaner cut
-    encoded_text = tiktoken.get_encoding(encoding_name).encode(text)
-    trimmed_encoded_text = encoded_text[:max_tokens]
-    # Decoding the trimmed text
-    trimmed_text = tiktoken.get_encoding(encoding_name).decode(trimmed_encoded_text)
-    return trimmed_text
-
-
 def generate_color_palette(unique_layers):
     colormap = plt.cm.get_cmap("viridis", len(unique_layers))
     colors = [colormap(i) for i in range(len(unique_layers))]
@@ -4,13 +4,13 @@ from uuid import NAMESPACE_OID, uuid5
 import tiktoken
 
 from cognee.infrastructure.databases.vector import get_vector_engine
+from cognee.infrastructure.databases.vector.embeddings import get_embedding_engine
 
 from .chunk_by_sentence import chunk_by_sentence
 
 
 def chunk_by_paragraph(
     data: str,
-    max_tokens: Optional[Union[int, float]] = None,
     paragraph_length: int = 1024,
     batch_paragraphs: bool = True,
 ) -> Iterator[Dict[str, Any]]:
@@ -24,24 +24,22 @@ def chunk_by_paragraph(
     paragraph_ids = []
     last_cut_type = None
     current_token_count = 0
-    if not max_tokens:
-        max_tokens = float("inf")
-
+    # Get vector and embedding engine
     vector_engine = get_vector_engine()
-    embedding_model = vector_engine.embedding_engine.model
-    embedding_model = embedding_model.split("/")[-1]
+    embedding_engine = vector_engine.embedding_engine
+
+    # embedding_model = embedding_engine.model.split("/")[-1]
 
     for paragraph_id, sentence, word_count, end_type in chunk_by_sentence(
         data, maximum_length=paragraph_length
     ):
         # Check if this sentence would exceed length limit
 
-        tokenizer = tiktoken.encoding_for_model(embedding_model)
-        token_count = len(tokenizer.encode(sentence))
+        token_count = embedding_engine.tokenizer.num_tokens_from_text(sentence)
 
         if current_word_count > 0 and (
             current_word_count + word_count > paragraph_length
-            or current_token_count + token_count > max_tokens
+            or current_token_count + token_count > embedding_engine.max_tokens
         ):
             # Yield current chunk
             chunk_dict = {
@@ -7,10 +7,7 @@ async def extract_chunks_from_documents(
     documents: list[Document],
     chunk_size: int = 1024,
    chunker="text_chunker",
-    max_tokens: Optional[int] = None,
 ):
     for document in documents:
-        for document_chunk in document.read(
-            chunk_size=chunk_size, chunker=chunker, max_tokens=max_tokens
-        ):
+        for document_chunk in document.read(chunk_size=chunk_size, chunker=chunker):
             yield document_chunk
@@ -89,26 +89,31 @@ def _get_subchunk_token_counts(
 
 
 def _get_chunk_source_code(
-    code_token_counts: list[tuple[str, int]], overlap: float, max_tokens: int
+    code_token_counts: list[tuple[str, int]], overlap: float
 ) -> tuple[list[tuple[str, int]], str]:
     """Generates a chunk of source code from tokenized subchunks with overlap handling."""
     current_count = 0
     cumulative_counts = []
     current_source_code = ""
 
+    # Get embedding engine used in vector database
+    from cognee.infrastructure.databases.vector.get_vector_engine import get_vector_engine
+
+    embedding_engine = get_vector_engine().embedding_engine
+
     for i, (child_code, token_count) in enumerate(code_token_counts):
         current_count += token_count
         cumulative_counts.append(current_count)
-        if current_count > max_tokens:
+        if current_count > embedding_engine.max_tokens:
             break
         current_source_code += f"\n{child_code}"
 
-    if current_count <= max_tokens:
+    if current_count <= embedding_engine.max_tokens:
         return [], current_source_code.strip()
 
     cutoff = 1
     for i, cum_count in enumerate(cumulative_counts):
-        if cum_count > (1 - overlap) * max_tokens:
+        if cum_count > (1 - overlap) * embedding_engine.max_tokens:
             break
         cutoff = i
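The overlap arithmetic in _get_chunk_source_code is easiest to see with numbers. With embedding_engine.max_tokens = 100 and overlap = 0.25 (illustrative values):

    counts = [40, 30, 20, 25]        # token counts of consecutive subchunks
    cumulative = [40, 70, 90, 115]   # running totals
    # The first loop stops at 115 > 100, so the chunk keeps the first three subchunks.
    # The second loop scans for the first cumulative total above (1 - 0.25) * 100 = 75:
    # that is 90, so roughly the 20-token subchunk onward is carried into the
    # next chunk as overlap.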