feat: Add max_chunk_tokens value to chunkers
Add formula and forwarding of max_chunk_tokens value through Cognee
This commit is contained in:
parent
49f60971bb
commit
3db7f85c9c
15 changed files with 71 additions and 32 deletions
|
|
@ -7,6 +7,7 @@ LLM_MODEL="openai/gpt-4o-mini"
|
||||||
LLM_PROVIDER="openai"
|
LLM_PROVIDER="openai"
|
||||||
LLM_ENDPOINT=""
|
LLM_ENDPOINT=""
|
||||||
LLM_API_VERSION=""
|
LLM_API_VERSION=""
|
||||||
|
LLM_MAX_TOKENS="128000"
|
||||||
|
|
||||||
GRAPHISTRY_USERNAME=
|
GRAPHISTRY_USERNAME=
|
||||||
GRAPHISTRY_PASSWORD=
|
GRAPHISTRY_PASSWORD=
|
||||||
|
|
|
||||||
|
|
@ -4,6 +4,8 @@ from typing import Union
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from cognee.infrastructure.databases.vector import get_vector_engine
|
||||||
|
from cognee.infrastructure.llm.get_llm_client import get_llm_client
|
||||||
from cognee.modules.cognify.config import get_cognify_config
|
from cognee.modules.cognify.config import get_cognify_config
|
||||||
from cognee.modules.data.methods import get_datasets, get_datasets_by_name
|
from cognee.modules.data.methods import get_datasets, get_datasets_by_name
|
||||||
from cognee.modules.data.methods.get_dataset_data import get_dataset_data
|
from cognee.modules.data.methods.get_dataset_data import get_dataset_data
|
||||||
|
|
@ -146,12 +148,23 @@ async def get_default_tasks(
|
||||||
if user is None:
|
if user is None:
|
||||||
user = await get_default_user()
|
user = await get_default_user()
|
||||||
|
|
||||||
|
# Calculate max chunk size based on the following formula
|
||||||
|
embedding_engine = get_vector_engine().embedding_engine
|
||||||
|
llm_client = get_llm_client()
|
||||||
|
|
||||||
|
# We need to make sure chunk size won't take more than half of LLM max context token size
|
||||||
|
# but it also can't be bigger than the embedding engine max token size
|
||||||
|
llm_cutoff_point = llm_client.max_tokens // 2 # Round down the division
|
||||||
|
max_chunk_tokens = min(embedding_engine.max_tokens, llm_cutoff_point)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
cognee_config = get_cognify_config()
|
cognee_config = get_cognify_config()
|
||||||
default_tasks = [
|
default_tasks = [
|
||||||
Task(classify_documents),
|
Task(classify_documents),
|
||||||
Task(check_permissions_on_documents, user=user, permissions=["write"]),
|
Task(check_permissions_on_documents, user=user, permissions=["write"]),
|
||||||
Task(extract_chunks_from_documents), # Extract text chunks based on the document type.
|
Task(
|
||||||
|
extract_chunks_from_documents, max_chunk_tokens=max_chunk_tokens
|
||||||
|
), # Extract text chunks based on the document type.
|
||||||
Task(
|
Task(
|
||||||
extract_graph_from_data, graph_model=graph_model, task_config={"batch_size": 10}
|
extract_graph_from_data, graph_model=graph_model, task_config={"batch_size": 10}
|
||||||
), # Generate knowledge graphs from the document chunks.
|
), # Generate knowledge graphs from the document chunks.
|
||||||
|
|
|
||||||
|
|
@ -14,11 +14,12 @@ class AnthropicAdapter(LLMInterface):
|
||||||
name = "Anthropic"
|
name = "Anthropic"
|
||||||
model: str
|
model: str
|
||||||
|
|
||||||
def __init__(self, model: str = None):
|
def __init__(self, max_tokens: int, model: str = None):
|
||||||
self.aclient = instructor.patch(
|
self.aclient = instructor.patch(
|
||||||
create=anthropic.Anthropic().messages.create, mode=instructor.Mode.ANTHROPIC_TOOLS
|
create=anthropic.Anthropic().messages.create, mode=instructor.Mode.ANTHROPIC_TOOLS
|
||||||
)
|
)
|
||||||
self.model = model
|
self.model = model
|
||||||
|
self.max_tokens = max_tokens
|
||||||
|
|
||||||
async def acreate_structured_output(
|
async def acreate_structured_output(
|
||||||
self, text_input: str, system_prompt: str, response_model: Type[BaseModel]
|
self, text_input: str, system_prompt: str, response_model: Type[BaseModel]
|
||||||
|
|
|
||||||
|
|
@ -11,6 +11,7 @@ class LLMConfig(BaseSettings):
|
||||||
llm_api_version: Optional[str] = None
|
llm_api_version: Optional[str] = None
|
||||||
llm_temperature: float = 0.0
|
llm_temperature: float = 0.0
|
||||||
llm_streaming: bool = False
|
llm_streaming: bool = False
|
||||||
|
llm_max_tokens: int = 128000
|
||||||
transcription_model: str = "whisper-1"
|
transcription_model: str = "whisper-1"
|
||||||
|
|
||||||
model_config = SettingsConfigDict(env_file=".env", extra="allow")
|
model_config = SettingsConfigDict(env_file=".env", extra="allow")
|
||||||
|
|
@ -24,6 +25,7 @@ class LLMConfig(BaseSettings):
|
||||||
"api_version": self.llm_api_version,
|
"api_version": self.llm_api_version,
|
||||||
"temperature": self.llm_temperature,
|
"temperature": self.llm_temperature,
|
||||||
"streaming": self.llm_streaming,
|
"streaming": self.llm_streaming,
|
||||||
|
"max_tokens": self.llm_max_tokens,
|
||||||
"transcription_model": self.transcription_model,
|
"transcription_model": self.transcription_model,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -2,6 +2,7 @@
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
from typing import List, Type
|
from typing import List, Type
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
import instructor
|
import instructor
|
||||||
from cognee.infrastructure.llm.llm_interface import LLMInterface
|
from cognee.infrastructure.llm.llm_interface import LLMInterface
|
||||||
|
|
@ -16,11 +17,12 @@ class GenericAPIAdapter(LLMInterface):
|
||||||
model: str
|
model: str
|
||||||
api_key: str
|
api_key: str
|
||||||
|
|
||||||
def __init__(self, endpoint, api_key: str, model: str, name: str):
|
def __init__(self, endpoint, api_key: str, model: str, name: str, max_tokens: int):
|
||||||
self.name = name
|
self.name = name
|
||||||
self.model = model
|
self.model = model
|
||||||
self.api_key = api_key
|
self.api_key = api_key
|
||||||
self.endpoint = endpoint
|
self.endpoint = endpoint
|
||||||
|
self.max_tokens = max_tokens
|
||||||
|
|
||||||
llm_config = get_llm_config()
|
llm_config = get_llm_config()
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -32,6 +32,7 @@ def get_llm_client():
|
||||||
api_version=llm_config.llm_api_version,
|
api_version=llm_config.llm_api_version,
|
||||||
model=llm_config.llm_model,
|
model=llm_config.llm_model,
|
||||||
transcription_model=llm_config.transcription_model,
|
transcription_model=llm_config.transcription_model,
|
||||||
|
max_tokens=llm_config.llm_max_tokens,
|
||||||
streaming=llm_config.llm_streaming,
|
streaming=llm_config.llm_streaming,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -42,13 +43,17 @@ def get_llm_client():
|
||||||
from .generic_llm_api.adapter import GenericAPIAdapter
|
from .generic_llm_api.adapter import GenericAPIAdapter
|
||||||
|
|
||||||
return GenericAPIAdapter(
|
return GenericAPIAdapter(
|
||||||
llm_config.llm_endpoint, llm_config.llm_api_key, llm_config.llm_model, "Ollama"
|
llm_config.llm_endpoint,
|
||||||
|
llm_config.llm_api_key,
|
||||||
|
llm_config.llm_model,
|
||||||
|
"Ollama",
|
||||||
|
max_tokens=llm_config.llm_max_tokens,
|
||||||
)
|
)
|
||||||
|
|
||||||
elif provider == LLMProvider.ANTHROPIC:
|
elif provider == LLMProvider.ANTHROPIC:
|
||||||
from .anthropic.adapter import AnthropicAdapter
|
from .anthropic.adapter import AnthropicAdapter
|
||||||
|
|
||||||
return AnthropicAdapter(llm_config.llm_model)
|
return AnthropicAdapter(max_tokens=llm_config.llm_max_tokens, model=llm_config.llm_model)
|
||||||
|
|
||||||
elif provider == LLMProvider.CUSTOM:
|
elif provider == LLMProvider.CUSTOM:
|
||||||
if llm_config.llm_api_key is None:
|
if llm_config.llm_api_key is None:
|
||||||
|
|
@ -57,7 +62,11 @@ def get_llm_client():
|
||||||
from .generic_llm_api.adapter import GenericAPIAdapter
|
from .generic_llm_api.adapter import GenericAPIAdapter
|
||||||
|
|
||||||
return GenericAPIAdapter(
|
return GenericAPIAdapter(
|
||||||
llm_config.llm_endpoint, llm_config.llm_api_key, llm_config.llm_model, "Custom"
|
llm_config.llm_endpoint,
|
||||||
|
llm_config.llm_api_key,
|
||||||
|
llm_config.llm_model,
|
||||||
|
"Custom",
|
||||||
|
max_tokens=llm_config.llm_max_tokens,
|
||||||
)
|
)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
|
|
|
||||||
|
|
@ -32,6 +32,7 @@ class OpenAIAdapter(LLMInterface):
|
||||||
api_version: str,
|
api_version: str,
|
||||||
model: str,
|
model: str,
|
||||||
transcription_model: str,
|
transcription_model: str,
|
||||||
|
max_tokens: int,
|
||||||
streaming: bool = False,
|
streaming: bool = False,
|
||||||
):
|
):
|
||||||
self.aclient = instructor.from_litellm(litellm.acompletion)
|
self.aclient = instructor.from_litellm(litellm.acompletion)
|
||||||
|
|
@ -41,6 +42,7 @@ class OpenAIAdapter(LLMInterface):
|
||||||
self.api_key = api_key
|
self.api_key = api_key
|
||||||
self.endpoint = endpoint
|
self.endpoint = endpoint
|
||||||
self.api_version = api_version
|
self.api_version = api_version
|
||||||
|
self.max_tokens = max_tokens
|
||||||
self.streaming = streaming
|
self.streaming = streaming
|
||||||
|
|
||||||
@observe(as_type="generation")
|
@observe(as_type="generation")
|
||||||
|
|
|
||||||
|
|
@ -14,22 +14,15 @@ class TextChunker:
|
||||||
chunk_size = 0
|
chunk_size = 0
|
||||||
token_count = 0
|
token_count = 0
|
||||||
|
|
||||||
def __init__(self, document, get_text: callable, chunk_size: int = 1024):
|
def __init__(self, document, get_text: callable, max_chunk_tokens: int, chunk_size: int = 1024):
|
||||||
self.document = document
|
self.document = document
|
||||||
self.max_chunk_size = chunk_size
|
self.max_chunk_size = chunk_size
|
||||||
self.get_text = get_text
|
self.get_text = get_text
|
||||||
|
self.max_chunk_tokens = max_chunk_tokens
|
||||||
|
|
||||||
def check_word_count_and_token_count(self, word_count_before, token_count_before, chunk_data):
|
def check_word_count_and_token_count(self, word_count_before, token_count_before, chunk_data):
|
||||||
word_count_fits = word_count_before + chunk_data["word_count"] <= self.max_chunk_size
|
word_count_fits = word_count_before + chunk_data["word_count"] <= self.max_chunk_size
|
||||||
|
token_count_fits = token_count_before + chunk_data["token_count"] <= self.max_chunk_tokens
|
||||||
# Get embedding engine related to vector database
|
|
||||||
from cognee.infrastructure.databases.vector.get_vector_engine import get_vector_engine
|
|
||||||
|
|
||||||
embedding_engine = get_vector_engine().embedding_engine
|
|
||||||
|
|
||||||
token_count_fits = (
|
|
||||||
token_count_before + chunk_data["token_count"] <= embedding_engine.max_tokens
|
|
||||||
)
|
|
||||||
return word_count_fits and token_count_fits
|
return word_count_fits and token_count_fits
|
||||||
|
|
||||||
def read(self):
|
def read(self):
|
||||||
|
|
@ -37,6 +30,7 @@ class TextChunker:
|
||||||
for content_text in self.get_text():
|
for content_text in self.get_text():
|
||||||
for chunk_data in chunk_by_paragraph(
|
for chunk_data in chunk_by_paragraph(
|
||||||
content_text,
|
content_text,
|
||||||
|
self.max_chunk_tokens,
|
||||||
self.max_chunk_size,
|
self.max_chunk_size,
|
||||||
batch_paragraphs=True,
|
batch_paragraphs=True,
|
||||||
):
|
):
|
||||||
|
|
|
||||||
|
|
@ -13,12 +13,14 @@ class AudioDocument(Document):
|
||||||
result = get_llm_client().create_transcript(self.raw_data_location)
|
result = get_llm_client().create_transcript(self.raw_data_location)
|
||||||
return result.text
|
return result.text
|
||||||
|
|
||||||
def read(self, chunk_size: int, chunker: str):
|
def read(self, chunk_size: int, chunker: str, max_chunk_tokens: int):
|
||||||
# Transcribe the audio file
|
# Transcribe the audio file
|
||||||
|
|
||||||
text = self.create_transcript()
|
text = self.create_transcript()
|
||||||
|
|
||||||
chunker_func = ChunkerConfig.get_chunker(chunker)
|
chunker_func = ChunkerConfig.get_chunker(chunker)
|
||||||
chunker = chunker_func(self, chunk_size=chunk_size, get_text=lambda: [text])
|
chunker = chunker_func(
|
||||||
|
self, chunk_size=chunk_size, get_text=lambda: [text], max_chunk_tokens=max_chunk_tokens
|
||||||
|
)
|
||||||
|
|
||||||
yield from chunker.read()
|
yield from chunker.read()
|
||||||
|
|
|
||||||
|
|
@ -13,11 +13,13 @@ class ImageDocument(Document):
|
||||||
result = get_llm_client().transcribe_image(self.raw_data_location)
|
result = get_llm_client().transcribe_image(self.raw_data_location)
|
||||||
return result.choices[0].message.content
|
return result.choices[0].message.content
|
||||||
|
|
||||||
def read(self, chunk_size: int, chunker: str):
|
def read(self, chunk_size: int, chunker: str, max_chunk_tokens: int):
|
||||||
# Transcribe the image file
|
# Transcribe the image file
|
||||||
text = self.transcribe_image()
|
text = self.transcribe_image()
|
||||||
|
|
||||||
chunker_func = ChunkerConfig.get_chunker(chunker)
|
chunker_func = ChunkerConfig.get_chunker(chunker)
|
||||||
chunker = chunker_func(self, chunk_size=chunk_size, get_text=lambda: [text])
|
chunker = chunker_func(
|
||||||
|
self, chunk_size=chunk_size, get_text=lambda: [text], max_chunk_tokens=max_chunk_tokens
|
||||||
|
)
|
||||||
|
|
||||||
yield from chunker.read()
|
yield from chunker.read()
|
||||||
|
|
|
||||||
|
|
@ -9,7 +9,7 @@ from .Document import Document
|
||||||
class PdfDocument(Document):
|
class PdfDocument(Document):
|
||||||
type: str = "pdf"
|
type: str = "pdf"
|
||||||
|
|
||||||
def read(self, chunk_size: int, chunker: str):
|
def read(self, chunk_size: int, chunker: str, max_chunk_tokens: int):
|
||||||
file = PdfReader(self.raw_data_location)
|
file = PdfReader(self.raw_data_location)
|
||||||
|
|
||||||
def get_text():
|
def get_text():
|
||||||
|
|
@ -18,7 +18,9 @@ class PdfDocument(Document):
|
||||||
yield page_text
|
yield page_text
|
||||||
|
|
||||||
chunker_func = ChunkerConfig.get_chunker(chunker)
|
chunker_func = ChunkerConfig.get_chunker(chunker)
|
||||||
chunker = chunker_func(self, chunk_size=chunk_size, get_text=get_text)
|
chunker = chunker_func(
|
||||||
|
self, chunk_size=chunk_size, get_text=get_text, max_chunk_tokens=max_chunk_tokens
|
||||||
|
)
|
||||||
|
|
||||||
yield from chunker.read()
|
yield from chunker.read()
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -7,7 +7,7 @@ from .Document import Document
|
||||||
class TextDocument(Document):
|
class TextDocument(Document):
|
||||||
type: str = "text"
|
type: str = "text"
|
||||||
|
|
||||||
def read(self, chunk_size: int, chunker: str):
|
def read(self, chunk_size: int, chunker: str, max_chunk_tokens: int):
|
||||||
def get_text():
|
def get_text():
|
||||||
with open(self.raw_data_location, mode="r", encoding="utf-8") as file:
|
with open(self.raw_data_location, mode="r", encoding="utf-8") as file:
|
||||||
while True:
|
while True:
|
||||||
|
|
@ -20,6 +20,8 @@ class TextDocument(Document):
|
||||||
|
|
||||||
chunker_func = ChunkerConfig.get_chunker(chunker)
|
chunker_func = ChunkerConfig.get_chunker(chunker)
|
||||||
|
|
||||||
chunker = chunker_func(self, chunk_size=chunk_size, get_text=get_text)
|
chunker = chunker_func(
|
||||||
|
self, chunk_size=chunk_size, get_text=get_text, max_chunk_tokens=max_chunk_tokens
|
||||||
|
)
|
||||||
|
|
||||||
yield from chunker.read()
|
yield from chunker.read()
|
||||||
|
|
|
||||||
|
|
@ -10,7 +10,7 @@ from .Document import Document
|
||||||
class UnstructuredDocument(Document):
|
class UnstructuredDocument(Document):
|
||||||
type: str = "unstructured"
|
type: str = "unstructured"
|
||||||
|
|
||||||
def read(self, chunk_size: int, chunker: str) -> str:
|
def read(self, chunk_size: int, chunker: str, max_chunk_tokens: int) -> str:
|
||||||
def get_text():
|
def get_text():
|
||||||
try:
|
try:
|
||||||
from unstructured.partition.auto import partition
|
from unstructured.partition.auto import partition
|
||||||
|
|
@ -29,6 +29,8 @@ class UnstructuredDocument(Document):
|
||||||
|
|
||||||
yield text
|
yield text
|
||||||
|
|
||||||
chunker = TextChunker(self, chunk_size=chunk_size, get_text=get_text)
|
chunker = TextChunker(
|
||||||
|
self, chunk_size=chunk_size, get_text=get_text, max_chunk_tokens=max_chunk_tokens
|
||||||
|
)
|
||||||
|
|
||||||
yield from chunker.read()
|
yield from chunker.read()
|
||||||
|
|
|
||||||
|
|
@ -4,13 +4,13 @@ from uuid import NAMESPACE_OID, uuid5
|
||||||
import tiktoken
|
import tiktoken
|
||||||
|
|
||||||
from cognee.infrastructure.databases.vector import get_vector_engine
|
from cognee.infrastructure.databases.vector import get_vector_engine
|
||||||
from cognee.infrastructure.databases.vector.embeddings import get_embedding_engine
|
|
||||||
|
|
||||||
from .chunk_by_sentence import chunk_by_sentence
|
from .chunk_by_sentence import chunk_by_sentence
|
||||||
|
|
||||||
|
|
||||||
def chunk_by_paragraph(
|
def chunk_by_paragraph(
|
||||||
data: str,
|
data: str,
|
||||||
|
max_chunk_tokens,
|
||||||
paragraph_length: int = 1024,
|
paragraph_length: int = 1024,
|
||||||
batch_paragraphs: bool = True,
|
batch_paragraphs: bool = True,
|
||||||
) -> Iterator[Dict[str, Any]]:
|
) -> Iterator[Dict[str, Any]]:
|
||||||
|
|
@ -31,19 +31,21 @@ def chunk_by_paragraph(
|
||||||
last_cut_type = None
|
last_cut_type = None
|
||||||
current_token_count = 0
|
current_token_count = 0
|
||||||
|
|
||||||
# Get vector and embedding engine
|
|
||||||
vector_engine = get_vector_engine()
|
vector_engine = get_vector_engine()
|
||||||
embedding_engine = vector_engine.embedding_engine
|
embedding_model = vector_engine.embedding_engine.model
|
||||||
|
embedding_model = embedding_model.split("/")[-1]
|
||||||
|
|
||||||
for paragraph_id, sentence, word_count, end_type in chunk_by_sentence(
|
for paragraph_id, sentence, word_count, end_type in chunk_by_sentence(
|
||||||
data, maximum_length=paragraph_length
|
data, maximum_length=paragraph_length
|
||||||
):
|
):
|
||||||
# Check if this sentence would exceed length limit
|
# Check if this sentence would exceed length limit
|
||||||
token_count = embedding_engine.tokenizer.count_tokens(sentence)
|
|
||||||
|
tokenizer = tiktoken.encoding_for_model(embedding_model)
|
||||||
|
token_count = len(tokenizer.encode(sentence))
|
||||||
|
|
||||||
if current_word_count > 0 and (
|
if current_word_count > 0 and (
|
||||||
current_word_count + word_count > paragraph_length
|
current_word_count + word_count > paragraph_length
|
||||||
or current_token_count + token_count > embedding_engine.max_tokens
|
or current_token_count + token_count > max_chunk_tokens
|
||||||
):
|
):
|
||||||
# Yield current chunk
|
# Yield current chunk
|
||||||
chunk_dict = {
|
chunk_dict = {
|
||||||
|
|
|
||||||
|
|
@ -5,6 +5,7 @@ from cognee.modules.data.processing.document_types.Document import Document
|
||||||
|
|
||||||
async def extract_chunks_from_documents(
|
async def extract_chunks_from_documents(
|
||||||
documents: list[Document],
|
documents: list[Document],
|
||||||
|
max_chunk_tokens: int,
|
||||||
chunk_size: int = 1024,
|
chunk_size: int = 1024,
|
||||||
chunker="text_chunker",
|
chunker="text_chunker",
|
||||||
) -> AsyncGenerator:
|
) -> AsyncGenerator:
|
||||||
|
|
@ -16,5 +17,7 @@ async def extract_chunks_from_documents(
|
||||||
- The `chunker` parameter determines the chunking logic and should align with the document type.
|
- The `chunker` parameter determines the chunking logic and should align with the document type.
|
||||||
"""
|
"""
|
||||||
for document in documents:
|
for document in documents:
|
||||||
for document_chunk in document.read(chunk_size=chunk_size, chunker=chunker):
|
for document_chunk in document.read(
|
||||||
|
chunk_size=chunk_size, chunker=chunker, max_chunk_tokens=max_chunk_tokens
|
||||||
|
):
|
||||||
yield document_chunk
|
yield document_chunk
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue