Get the embedding engine in code chunking instead of passing it in.

Rita Aleksziev 2025-01-08 13:45:04 +01:00
parent 34a9267f41
commit 97814e334f
3 changed files with 9 additions and 10 deletions
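In short: rather than looking up the embedding engine in the pipeline and threading its model name into the chunking tasks, the chunking code now asks the configured vector engine for it directly. A minimal sketch of the new lookup, written as a hypothetical standalone helper (the provider-prefixed example value is an assumption; the diff only shows that everything before the last "/" is discarded):

    from cognee.infrastructure.databases.vector import get_vector_engine

    def resolve_tokenizer_model_name() -> str:
        # Hypothetical helper mirroring the lookup the chunking tasks now perform.
        vector_engine = get_vector_engine()
        embedding_model = vector_engine.embedding_engine.model  # e.g. "openai/text-embedding-3-large" (assumed format)
        return embedding_model.split("/")[-1]  # tiktoken expects the bare model name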

View file

@@ -3,8 +3,6 @@ import logging
 from pathlib import Path
 from cognee.base_config import get_base_config
-from cognee.infrastructure.databases.vector.embeddings import \
-    get_embedding_engine
 from cognee.modules.cognify.config import get_cognify_config
 from cognee.modules.pipelines import run_tasks
 from cognee.modules.pipelines.tasks.Task import Task
@@ -51,8 +49,6 @@ async def run_code_graph_pipeline(repo_path, include_docs=True):
     await cognee.prune.prune_system(metadata=True)
     await create_db_and_tables()
-    embedding_engine = get_embedding_engine()
     cognee_config = get_cognify_config()
     user = await get_default_user()
@@ -60,7 +56,7 @@ async def run_code_graph_pipeline(repo_path, include_docs=True):
         Task(get_repo_file_dependencies),
         Task(enrich_dependency_graph),
         Task(expand_dependency_graph, task_config={"batch_size": 50}),
-        Task(get_source_code_chunks, embedding_model=embedding_engine.model, task_config={"batch_size": 50}),
+        Task(get_source_code_chunks, task_config={"batch_size": 50}),
         Task(summarize_code, task_config={"batch_size": 50}),
         Task(add_data_points, task_config={"batch_size": 50}),
     ]
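The removed keyword argument worked because Task forwards construction-time arguments into the wrapped coroutine on every run; dropping it means get_source_code_chunks must now be self-sufficient. A simplified stand-in for that forwarding behavior (an illustration, not cognee's actual Task implementation):

    from typing import Any, Callable

    class Task:  # simplified stand-in for cognee.modules.pipelines.tasks.Task
        def __init__(self, executable: Callable, *args: Any, task_config: dict | None = None, **kwargs: Any):
            self.executable = executable
            self.args = args
            self.kwargs = kwargs  # the old embedding_model=... would have landed here
            self.task_config = task_config or {}

        def run(self, *runtime_args: Any):
            # Construction-time args/kwargs are merged into every invocation.
            return self.executable(*runtime_args, *self.args, **self.kwargs)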

View file

@@ -29,11 +29,11 @@ def chunk_by_paragraph(
     vector_engine = get_vector_engine()
     embedding_model = vector_engine.embedding_engine.model
+    embedding_model = embedding_model.split("/")[-1]
     for paragraph_id, sentence, word_count, end_type in chunk_by_sentence(data, maximum_length=paragraph_length):
         # Check if this sentence would exceed length limit
-        embedding_model = embedding_model.split("/")[-1]
         tokenizer = tiktoken.encoding_for_model(embedding_model)
         token_count = len(tokenizer.encode(sentence))
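Two things happen in this hunk: the split("/") is hoisted out of the sentence loop (it is idempotent, but was being recomputed for every sentence), and it exists at all because tiktoken.encoding_for_model only accepts bare OpenAI model names and raises KeyError for provider-prefixed identifiers (the "openai/" prefix below is an assumed example of such a prefix):

    import tiktoken

    tokenizer = tiktoken.encoding_for_model("text-embedding-3-large")  # bare name: resolves
    print(len(tokenizer.encode("Some sentence to count tokens for.")))

    try:
        tiktoken.encoding_for_model("openai/text-embedding-3-large")  # prefixed name: fails
    except KeyError as error:
        print(f"could not map model to tokenizer: {error}")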

View file

@@ -5,6 +5,7 @@ from uuid import NAMESPACE_OID, uuid5
 import parso
 import tiktoken
+from cognee.infrastructure.databases.vector import get_vector_engine
 from cognee.infrastructure.engine import DataPoint
 from cognee.shared.CodeGraphEntities import CodeFile, CodePart, SourceCodeChunk
@@ -115,13 +116,15 @@ def get_source_code_chunks_from_code_part(
     max_tokens: int = 8192,
     overlap: float = 0.25,
     granularity: float = 0.1,
-    model_name: str = "text-embedding-3-large"
 ) -> Generator[SourceCodeChunk, None, None]:
     """Yields source code chunks from a CodePart object, with configurable token limits and overlap."""
     if not code_file_part.source_code:
         logger.error(f"No source code in CodeFile {code_file_part.id}")
         return
+    vector_engine = get_vector_engine()
+    embedding_model = vector_engine.embedding_engine.model
+    model_name = embedding_model.split("/")[-1]
     tokenizer = tiktoken.encoding_for_model(model_name)
     max_subchunk_tokens = max(1, int(granularity * max_tokens))
     subchunk_token_counts = _get_subchunk_token_counts(tokenizer, code_file_part.source_code, max_subchunk_tokens)
@@ -141,7 +144,7 @@ def get_source_code_chunks_from_code_part(
         previous_chunk = current_chunk

-async def get_source_code_chunks(data_points: list[DataPoint], embedding_model="text-embedding-3-large") -> \
+async def get_source_code_chunks(data_points: list[DataPoint]) -> \
         AsyncGenerator[list[DataPoint], None]:
     """Processes code graph datapoints, creates SourceCodeChunk datapoints."""
     # TODO: Add support for other embedding models, with max_token mapping
@@ -156,7 +159,7 @@ async def get_source_code_chunks(data_points: list[DataPoint], embedding_model="
         for code_part in data_point.contains:
             try:
                 yield code_part
-                for source_code_chunk in get_source_code_chunks_from_code_part(code_part, model_name=embedding_model):
+                for source_code_chunk in get_source_code_chunks_from_code_part(code_part):
                     yield source_code_chunk
             except Exception as e:
                 logger.error(f"Error processing code part: {e}")
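For orientation, the arithmetic that now consumes the internally resolved model_name: granularity caps each sub-chunk at a fraction of max_tokens (0.1 × 8192 gives 819 tokens by default). A self-contained sketch of a greedy token-budget splitter in that spirit (simplified; the real _get_subchunk_token_counts is parso-aware and is not shown in this diff):

    import tiktoken

    def split_source_by_token_budget(source_code: str, max_tokens: int = 8192,
                                     granularity: float = 0.1,
                                     model_name: str = "text-embedding-3-large") -> list[str]:
        # Same budget formula as the diff: at most granularity * max_tokens tokens per sub-chunk.
        tokenizer = tiktoken.encoding_for_model(model_name)
        max_subchunk_tokens = max(1, int(granularity * max_tokens))
        tokens = tokenizer.encode(source_code)
        return [tokenizer.decode(tokens[start:start + max_subchunk_tokens])
                for start in range(0, len(tokens), max_subchunk_tokens)]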