Get embedding engine instead of passing it in code chunking.
This commit is contained in:
parent
34a9267f41
commit
97814e334f
3 changed files with 9 additions and 10 deletions
|
|
@ -3,8 +3,6 @@ import logging
|
|||
from pathlib import Path
|
||||
|
||||
from cognee.base_config import get_base_config
|
||||
from cognee.infrastructure.databases.vector.embeddings import \
|
||||
get_embedding_engine
|
||||
from cognee.modules.cognify.config import get_cognify_config
|
||||
from cognee.modules.pipelines import run_tasks
|
||||
from cognee.modules.pipelines.tasks.Task import Task
|
||||
|
|
@ -51,8 +49,6 @@ async def run_code_graph_pipeline(repo_path, include_docs=True):
|
|||
await cognee.prune.prune_system(metadata=True)
|
||||
await create_db_and_tables()
|
||||
|
||||
embedding_engine = get_embedding_engine()
|
||||
|
||||
cognee_config = get_cognify_config()
|
||||
user = await get_default_user()
|
||||
|
||||
|
|
@ -60,7 +56,7 @@ async def run_code_graph_pipeline(repo_path, include_docs=True):
|
|||
Task(get_repo_file_dependencies),
|
||||
Task(enrich_dependency_graph),
|
||||
Task(expand_dependency_graph, task_config={"batch_size": 50}),
|
||||
Task(get_source_code_chunks, embedding_model=embedding_engine.model, task_config={"batch_size": 50}),
|
||||
Task(get_source_code_chunks, task_config={"batch_size": 50}),
|
||||
Task(summarize_code, task_config={"batch_size": 50}),
|
||||
Task(add_data_points, task_config={"batch_size": 50}),
|
||||
]
|
||||
|
|
|
|||
|
|
@ -29,11 +29,11 @@ def chunk_by_paragraph(
|
|||
|
||||
vector_engine = get_vector_engine()
|
||||
embedding_model = vector_engine.embedding_engine.model
|
||||
|
||||
embedding_model = embedding_model.split("/")[-1]
|
||||
|
||||
for paragraph_id, sentence, word_count, end_type in chunk_by_sentence(data, maximum_length=paragraph_length):
|
||||
# Check if this sentence would exceed length limit
|
||||
|
||||
embedding_model = embedding_model.split("/")[-1]
|
||||
tokenizer = tiktoken.encoding_for_model(embedding_model)
|
||||
token_count = len(tokenizer.encode(sentence))
|
||||
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@ from uuid import NAMESPACE_OID, uuid5
|
|||
import parso
|
||||
import tiktoken
|
||||
|
||||
from cognee.infrastructure.databases.vector import get_vector_engine
|
||||
from cognee.infrastructure.engine import DataPoint
|
||||
from cognee.shared.CodeGraphEntities import CodeFile, CodePart, SourceCodeChunk
|
||||
|
||||
|
|
@ -115,13 +116,15 @@ def get_source_code_chunks_from_code_part(
|
|||
max_tokens: int = 8192,
|
||||
overlap: float = 0.25,
|
||||
granularity: float = 0.1,
|
||||
model_name: str = "text-embedding-3-large"
|
||||
) -> Generator[SourceCodeChunk, None, None]:
|
||||
"""Yields source code chunks from a CodePart object, with configurable token limits and overlap."""
|
||||
if not code_file_part.source_code:
|
||||
logger.error(f"No source code in CodeFile {code_file_part.id}")
|
||||
return
|
||||
|
||||
vector_engine = get_vector_engine()
|
||||
embedding_model = vector_engine.embedding_engine.model
|
||||
model_name = embedding_model.split("/")[-1]
|
||||
tokenizer = tiktoken.encoding_for_model(model_name)
|
||||
max_subchunk_tokens = max(1, int(granularity * max_tokens))
|
||||
subchunk_token_counts = _get_subchunk_token_counts(tokenizer, code_file_part.source_code, max_subchunk_tokens)
|
||||
|
|
@ -141,7 +144,7 @@ def get_source_code_chunks_from_code_part(
|
|||
previous_chunk = current_chunk
|
||||
|
||||
|
||||
async def get_source_code_chunks(data_points: list[DataPoint], embedding_model="text-embedding-3-large") -> \
|
||||
async def get_source_code_chunks(data_points: list[DataPoint]) -> \
|
||||
AsyncGenerator[list[DataPoint], None]:
|
||||
"""Processes code graph datapoints, create SourceCodeChink datapoints."""
|
||||
# TODO: Add support for other embedding models, with max_token mapping
|
||||
|
|
@ -156,7 +159,7 @@ async def get_source_code_chunks(data_points: list[DataPoint], embedding_model="
|
|||
for code_part in data_point.contains:
|
||||
try:
|
||||
yield code_part
|
||||
for source_code_chunk in get_source_code_chunks_from_code_part(code_part, model_name=embedding_model):
|
||||
for source_code_chunk in get_source_code_chunks_from_code_part(code_part):
|
||||
yield source_code_chunk
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing code part: {e}")
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue