From 97814e334f282b344cb0357df387b70cbf801397 Mon Sep 17 00:00:00 2001 From: Rita Aleksziev Date: Wed, 8 Jan 2025 13:45:04 +0100 Subject: [PATCH] Get embedding engine instead of passing it in code chunking. --- cognee/api/v1/cognify/code_graph_pipeline.py | 6 +----- cognee/tasks/chunks/chunk_by_paragraph.py | 4 ++-- cognee/tasks/repo_processor/get_source_code_chunks.py | 9 ++++++--- 3 files changed, 9 insertions(+), 10 deletions(-) diff --git a/cognee/api/v1/cognify/code_graph_pipeline.py b/cognee/api/v1/cognify/code_graph_pipeline.py index 7ba461f88..6e06edfa3 100644 --- a/cognee/api/v1/cognify/code_graph_pipeline.py +++ b/cognee/api/v1/cognify/code_graph_pipeline.py @@ -3,8 +3,6 @@ import logging from pathlib import Path from cognee.base_config import get_base_config -from cognee.infrastructure.databases.vector.embeddings import \ - get_embedding_engine from cognee.modules.cognify.config import get_cognify_config from cognee.modules.pipelines import run_tasks from cognee.modules.pipelines.tasks.Task import Task @@ -51,8 +49,6 @@ async def run_code_graph_pipeline(repo_path, include_docs=True): await cognee.prune.prune_system(metadata=True) await create_db_and_tables() - embedding_engine = get_embedding_engine() - cognee_config = get_cognify_config() user = await get_default_user() @@ -60,7 +56,7 @@ async def run_code_graph_pipeline(repo_path, include_docs=True): Task(get_repo_file_dependencies), Task(enrich_dependency_graph), Task(expand_dependency_graph, task_config={"batch_size": 50}), - Task(get_source_code_chunks, embedding_model=embedding_engine.model, task_config={"batch_size": 50}), + Task(get_source_code_chunks, task_config={"batch_size": 50}), Task(summarize_code, task_config={"batch_size": 50}), Task(add_data_points, task_config={"batch_size": 50}), ] diff --git a/cognee/tasks/chunks/chunk_by_paragraph.py b/cognee/tasks/chunks/chunk_by_paragraph.py index 8ab66bd7f..44355a1ad 100644 --- a/cognee/tasks/chunks/chunk_by_paragraph.py +++ b/cognee/tasks/chunks/chunk_by_paragraph.py @@ -29,11 +29,11 @@ def chunk_by_paragraph( vector_engine = get_vector_engine() embedding_model = vector_engine.embedding_engine.model - + embedding_model = embedding_model.split("/")[-1] + for paragraph_id, sentence, word_count, end_type in chunk_by_sentence(data, maximum_length=paragraph_length): # Check if this sentence would exceed length limit - embedding_model = embedding_model.split("/")[-1] tokenizer = tiktoken.encoding_for_model(embedding_model) token_count = len(tokenizer.encode(sentence)) diff --git a/cognee/tasks/repo_processor/get_source_code_chunks.py b/cognee/tasks/repo_processor/get_source_code_chunks.py index 4d0ce3200..0bf7ebe32 100644 --- a/cognee/tasks/repo_processor/get_source_code_chunks.py +++ b/cognee/tasks/repo_processor/get_source_code_chunks.py @@ -5,6 +5,7 @@ from uuid import NAMESPACE_OID, uuid5 import parso import tiktoken +from cognee.infrastructure.databases.vector import get_vector_engine from cognee.infrastructure.engine import DataPoint from cognee.shared.CodeGraphEntities import CodeFile, CodePart, SourceCodeChunk @@ -115,13 +116,15 @@ def get_source_code_chunks_from_code_part( max_tokens: int = 8192, overlap: float = 0.25, granularity: float = 0.1, - model_name: str = "text-embedding-3-large" ) -> Generator[SourceCodeChunk, None, None]: """Yields source code chunks from a CodePart object, with configurable token limits and overlap.""" if not code_file_part.source_code: logger.error(f"No source code in CodeFile {code_file_part.id}") return + vector_engine = get_vector_engine() + embedding_model = vector_engine.embedding_engine.model + model_name = embedding_model.split("/")[-1] tokenizer = tiktoken.encoding_for_model(model_name) max_subchunk_tokens = max(1, int(granularity * max_tokens)) subchunk_token_counts = _get_subchunk_token_counts(tokenizer, code_file_part.source_code, max_subchunk_tokens) @@ -141,7 +144,7 @@ def get_source_code_chunks_from_code_part( previous_chunk = current_chunk -async def get_source_code_chunks(data_points: list[DataPoint], embedding_model="text-embedding-3-large") -> \ +async def get_source_code_chunks(data_points: list[DataPoint]) -> \ AsyncGenerator[list[DataPoint], None]: """Processes code graph datapoints, create SourceCodeChink datapoints.""" # TODO: Add support for other embedding models, with max_token mapping @@ -156,7 +159,7 @@ async def get_source_code_chunks(data_points: list[DataPoint], embedding_model=" for code_part in data_point.contains: try: yield code_part - for source_code_chunk in get_source_code_chunks_from_code_part(code_part, model_name=embedding_model): + for source_code_chunk in get_source_code_chunks_from_code_part(code_part): yield source_code_chunk except Exception as e: logger.error(f"Error processing code part: {e}")