From 4e56cd64a1ab6cef90e20c8f2fd20f5b25d098ce Mon Sep 17 00:00:00 2001
From: Igor Ilic
Date: Tue, 28 Jan 2025 15:33:34 +0100
Subject: [PATCH] refactor: Add max chunk tokens to code graph pipeline

---
 cognee/api/v1/cognify/code_graph_pipeline.py |  3 ++-
 cognee/api/v1/cognify/cognify_v2.py          | 14 ++------------
 cognee/infrastructure/llm/__init__.py        |  1 +
 cognee/infrastructure/llm/utils.py           | 15 +++++++++++++++
 4 files changed, 20 insertions(+), 13 deletions(-)
 create mode 100644 cognee/infrastructure/llm/utils.py

diff --git a/cognee/api/v1/cognify/code_graph_pipeline.py b/cognee/api/v1/cognify/code_graph_pipeline.py
index 4a864eb0e..125245f40 100644
--- a/cognee/api/v1/cognify/code_graph_pipeline.py
+++ b/cognee/api/v1/cognify/code_graph_pipeline.py
@@ -21,6 +21,7 @@ from cognee.tasks.repo_processor import (
 from cognee.tasks.repo_processor.get_source_code_chunks import get_source_code_chunks
 from cognee.tasks.storage import add_data_points
 from cognee.tasks.summarization import summarize_code, summarize_text
+from cognee.infrastructure.llm import get_max_chunk_tokens
 
 monitoring = get_base_config().monitoring_tool
 if monitoring == MonitoringTool.LANGFUSE:
@@ -71,7 +72,7 @@ async def run_code_graph_pipeline(repo_path, include_docs=True):
             Task(ingest_data, dataset_name="repo_docs", user=user),
             Task(get_data_list_for_user, dataset_name="repo_docs", user=user),
             Task(classify_documents),
-            Task(extract_chunks_from_documents),
+            Task(extract_chunks_from_documents, max_chunk_tokens=get_max_chunk_tokens()),
             Task(
                 extract_graph_from_data, graph_model=KnowledgeGraph, task_config={"batch_size": 50}
             ),
diff --git a/cognee/api/v1/cognify/cognify_v2.py b/cognee/api/v1/cognify/cognify_v2.py
index 73504f057..12a84030d 100644
--- a/cognee/api/v1/cognify/cognify_v2.py
+++ b/cognee/api/v1/cognify/cognify_v2.py
@@ -4,8 +4,7 @@ from typing import Union
 
 from pydantic import BaseModel
 
-from cognee.infrastructure.databases.vector import get_vector_engine
-from cognee.infrastructure.llm.get_llm_client import get_llm_client
+from cognee.infrastructure.llm import get_max_chunk_tokens
 from cognee.modules.cognify.config import get_cognify_config
 from cognee.modules.data.methods import get_datasets, get_datasets_by_name
 from cognee.modules.data.methods.get_dataset_data import get_dataset_data
@@ -148,22 +147,13 @@ async def get_default_tasks(
     if user is None:
         user = await get_default_user()
 
-    # Calculate max chunk size based on the following formula
-    embedding_engine = get_vector_engine().embedding_engine
-    llm_client = get_llm_client()
-
-    # We need to make sure chunk size won't take more than half of LLM max context token size
-    # but it also can't be bigger than the embedding engine max token size
-    llm_cutoff_point = llm_client.max_tokens // 2  # Round down the division
-    max_chunk_tokens = min(embedding_engine.max_tokens, llm_cutoff_point)
-
     try:
         cognee_config = get_cognify_config()
         default_tasks = [
             Task(classify_documents),
             Task(check_permissions_on_documents, user=user, permissions=["write"]),
             Task(
-                extract_chunks_from_documents, max_chunk_tokens=max_chunk_tokens
+                extract_chunks_from_documents, max_chunk_tokens=get_max_chunk_tokens()
             ),  # Extract text chunks based on the document type.
             Task(
                 extract_graph_from_data, graph_model=graph_model, task_config={"batch_size": 10}
diff --git a/cognee/infrastructure/llm/__init__.py b/cognee/infrastructure/llm/__init__.py
index 7fb3be736..36d7e56ad 100644
--- a/cognee/infrastructure/llm/__init__.py
+++ b/cognee/infrastructure/llm/__init__.py
@@ -1 +1,2 @@
 from .config import get_llm_config
+from .utils import get_max_chunk_tokens
diff --git a/cognee/infrastructure/llm/utils.py b/cognee/infrastructure/llm/utils.py
new file mode 100644
index 000000000..816eaf285
--- /dev/null
+++ b/cognee/infrastructure/llm/utils.py
@@ -0,0 +1,15 @@
+from cognee.infrastructure.databases.vector import get_vector_engine
+from cognee.infrastructure.llm.get_llm_client import get_llm_client
+
+
+def get_max_chunk_tokens():
+    # Calculate max chunk size based on the following formula
+    embedding_engine = get_vector_engine().embedding_engine
+    llm_client = get_llm_client()
+
+    # We need to make sure chunk size won't take more than half of LLM max context token size
+    # but it also can't be bigger than the embedding engine max token size
+    llm_cutoff_point = llm_client.max_tokens // 2  # Round down the division
+    max_chunk_tokens = min(embedding_engine.max_tokens, llm_cutoff_point)
+
+    return max_chunk_tokens
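
Note: for illustration, a minimal sketch of the limit the new get_max_chunk_tokens() helper resolves, using hypothetical token limits in place of the real LLM client and embedding engine objects (the numbers below are assumptions, not values from the patch):

    # Illustrative only: mirrors the formula in cognee/infrastructure/llm/utils.py
    # with hypothetical limits instead of the configured client/engine.
    llm_max_tokens = 128_000       # hypothetical LLM context window
    embedding_max_tokens = 8_191   # hypothetical embedding engine token limit

    llm_cutoff_point = llm_max_tokens // 2                          # 64_000, half the LLM context
    max_chunk_tokens = min(embedding_max_tokens, llm_cutoff_point)  # 8_191, embedding limit wins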