Merge branch 'dev' into COG-650-replace-pylint
This commit is contained in:
commit
c8fdbb45c4
13 changed files with 232 additions and 38 deletions
|
|
@ -17,7 +17,9 @@ Try it in a Google Colab <a href="https://colab.research.google.com/drive/1g-Qn
|
||||||
|
|
||||||
If you have questions, join our <a href="https://discord.gg/NQPKmU5CCg">Discord</a> community
|
If you have questions, join our <a href="https://discord.gg/NQPKmU5CCg">Discord</a> community
|
||||||
|
|
||||||
|
<div align="center">
|
||||||
|
<img src="assets/cognee_benefits.png" alt="why cognee" width="80%" />
|
||||||
|
</div>
|
||||||
## 📦 Installation
|
## 📦 Installation
|
||||||
|
|
||||||
You can install Cognee using either **pip** or **poetry**.
|
You can install Cognee using either **pip** or **poetry**.
|
||||||
|
|
|
||||||
BIN
assets/cognee_benefits.png
Normal file
BIN
assets/cognee_benefits.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 353 KiB |
|
|
@ -3,6 +3,8 @@ import logging
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
from cognee.base_config import get_base_config
|
from cognee.base_config import get_base_config
|
||||||
|
from cognee.infrastructure.databases.vector.embeddings import \
|
||||||
|
get_embedding_engine
|
||||||
from cognee.modules.cognify.config import get_cognify_config
|
from cognee.modules.cognify.config import get_cognify_config
|
||||||
from cognee.modules.pipelines import run_tasks
|
from cognee.modules.pipelines import run_tasks
|
||||||
from cognee.modules.pipelines.tasks.Task import Task
|
from cognee.modules.pipelines.tasks.Task import Task
|
||||||
|
|
@ -15,8 +17,10 @@ from cognee.tasks.ingestion import ingest_data_with_metadata
|
||||||
from cognee.tasks.repo_processor import (enrich_dependency_graph,
|
from cognee.tasks.repo_processor import (enrich_dependency_graph,
|
||||||
expand_dependency_graph,
|
expand_dependency_graph,
|
||||||
get_data_list_for_user,
|
get_data_list_for_user,
|
||||||
get_non_code_files,
|
get_non_py_files,
|
||||||
get_repo_file_dependencies)
|
get_repo_file_dependencies)
|
||||||
|
from cognee.tasks.repo_processor.get_source_code_chunks import \
|
||||||
|
get_source_code_chunks
|
||||||
from cognee.tasks.storage import add_data_points
|
from cognee.tasks.storage import add_data_points
|
||||||
|
|
||||||
monitoring = get_base_config().monitoring_tool
|
monitoring = get_base_config().monitoring_tool
|
||||||
|
|
@ -28,6 +32,7 @@ from cognee.tasks.summarization import summarize_code, summarize_text
|
||||||
logger = logging.getLogger("code_graph_pipeline")
|
logger = logging.getLogger("code_graph_pipeline")
|
||||||
update_status_lock = asyncio.Lock()
|
update_status_lock = asyncio.Lock()
|
||||||
|
|
||||||
|
|
||||||
@observe
|
@observe
|
||||||
async def run_code_graph_pipeline(repo_path, include_docs=True):
|
async def run_code_graph_pipeline(repo_path, include_docs=True):
|
||||||
import os
|
import os
|
||||||
|
|
@ -46,20 +51,23 @@ async def run_code_graph_pipeline(repo_path, include_docs=True):
|
||||||
await cognee.prune.prune_system(metadata=True)
|
await cognee.prune.prune_system(metadata=True)
|
||||||
await create_db_and_tables()
|
await create_db_and_tables()
|
||||||
|
|
||||||
|
embedding_engine = get_embedding_engine()
|
||||||
|
|
||||||
cognee_config = get_cognify_config()
|
cognee_config = get_cognify_config()
|
||||||
user = await get_default_user()
|
user = await get_default_user()
|
||||||
|
|
||||||
tasks = [
|
tasks = [
|
||||||
Task(get_repo_file_dependencies),
|
Task(get_repo_file_dependencies),
|
||||||
Task(enrich_dependency_graph, task_config={"batch_size": 50}),
|
Task(enrich_dependency_graph),
|
||||||
Task(expand_dependency_graph, task_config={"batch_size": 50}),
|
Task(expand_dependency_graph, task_config={"batch_size": 50}),
|
||||||
|
Task(get_source_code_chunks, embedding_model=embedding_engine.model, task_config={"batch_size": 50}),
|
||||||
Task(summarize_code, task_config={"batch_size": 50}),
|
Task(summarize_code, task_config={"batch_size": 50}),
|
||||||
Task(add_data_points, task_config={"batch_size": 50}),
|
Task(add_data_points, task_config={"batch_size": 50}),
|
||||||
]
|
]
|
||||||
|
|
||||||
if include_docs:
|
if include_docs:
|
||||||
non_code_tasks = [
|
non_code_tasks = [
|
||||||
Task(get_non_code_files, task_config={"batch_size": 50}),
|
Task(get_non_py_files, task_config={"batch_size": 50}),
|
||||||
Task(ingest_data_with_metadata, dataset_name="repo_docs", user=user),
|
Task(ingest_data_with_metadata, dataset_name="repo_docs", user=user),
|
||||||
Task(get_data_list_for_user, dataset_name="repo_docs", user=user),
|
Task(get_data_list_for_user, dataset_name="repo_docs", user=user),
|
||||||
Task(classify_documents),
|
Task(classify_documents),
|
||||||
|
|
@ -71,7 +79,7 @@ async def run_code_graph_pipeline(repo_path, include_docs=True):
|
||||||
task_config={"batch_size": 50}
|
task_config={"batch_size": 50}
|
||||||
),
|
),
|
||||||
]
|
]
|
||||||
|
|
||||||
if include_docs:
|
if include_docs:
|
||||||
async for result in run_tasks(non_code_tasks, repo_path):
|
async for result in run_tasks(non_code_tasks, repo_path):
|
||||||
yield result
|
yield result
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,3 @@
|
||||||
|
class EmbeddingException(Exception):
    """Error type used to signal embedding-related failures."""
|
||||||
|
|
@ -5,17 +5,19 @@ from typing import List, Optional
|
||||||
import litellm
|
import litellm
|
||||||
import os
|
import os
|
||||||
from cognee.infrastructure.databases.vector.embeddings.EmbeddingEngine import EmbeddingEngine
|
from cognee.infrastructure.databases.vector.embeddings.EmbeddingEngine import EmbeddingEngine
|
||||||
|
from cognee.infrastructure.databases.exceptions.EmbeddingException import EmbeddingException
|
||||||
|
|
||||||
litellm.set_verbose = False
|
litellm.set_verbose = False
|
||||||
logger = logging.getLogger("LiteLLMEmbeddingEngine")
|
logger = logging.getLogger("LiteLLMEmbeddingEngine")
|
||||||
|
|
||||||
|
|
||||||
class LiteLLMEmbeddingEngine(EmbeddingEngine):
|
class LiteLLMEmbeddingEngine(EmbeddingEngine):
|
||||||
api_key: str
|
api_key: str
|
||||||
endpoint: str
|
endpoint: str
|
||||||
api_version: str
|
api_version: str
|
||||||
model: str
|
model: str
|
||||||
dimensions: int
|
dimensions: int
|
||||||
mock:bool
|
mock: bool
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
|
@ -33,7 +35,7 @@ class LiteLLMEmbeddingEngine(EmbeddingEngine):
|
||||||
|
|
||||||
enable_mocking = os.getenv("MOCK_EMBEDDING", "false")
|
enable_mocking = os.getenv("MOCK_EMBEDDING", "false")
|
||||||
if isinstance(enable_mocking, bool):
|
if isinstance(enable_mocking, bool):
|
||||||
enable_mocking= str(enable_mocking).lower()
|
enable_mocking = str(enable_mocking).lower()
|
||||||
self.mock = enable_mocking in ("true", "1", "yes")
|
self.mock = enable_mocking in ("true", "1", "yes")
|
||||||
|
|
||||||
MAX_RETRIES = 5
|
MAX_RETRIES = 5
|
||||||
|
|
@ -43,7 +45,7 @@ class LiteLLMEmbeddingEngine(EmbeddingEngine):
|
||||||
async def exponential_backoff(attempt):
|
async def exponential_backoff(attempt):
|
||||||
wait_time = min(10 * (2 ** attempt), 60) # Max 60 seconds
|
wait_time = min(10 * (2 ** attempt), 60) # Max 60 seconds
|
||||||
await asyncio.sleep(wait_time)
|
await asyncio.sleep(wait_time)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if self.mock:
|
if self.mock:
|
||||||
response = {
|
response = {
|
||||||
|
|
@ -56,10 +58,10 @@ class LiteLLMEmbeddingEngine(EmbeddingEngine):
|
||||||
else:
|
else:
|
||||||
response = await litellm.aembedding(
|
response = await litellm.aembedding(
|
||||||
self.model,
|
self.model,
|
||||||
input = text,
|
input=text,
|
||||||
api_key = self.api_key,
|
api_key=self.api_key,
|
||||||
api_base = self.endpoint,
|
api_base=self.endpoint,
|
||||||
api_version = self.api_version
|
api_version=self.api_version
|
||||||
)
|
)
|
||||||
|
|
||||||
self.retry_count = 0
|
self.retry_count = 0
|
||||||
|
|
@ -71,7 +73,7 @@ class LiteLLMEmbeddingEngine(EmbeddingEngine):
|
||||||
if len(text) == 1:
|
if len(text) == 1:
|
||||||
parts = [text]
|
parts = [text]
|
||||||
else:
|
else:
|
||||||
parts = [text[0:math.ceil(len(text)/2)], text[math.ceil(len(text)/2):]]
|
parts = [text[0:math.ceil(len(text) / 2)], text[math.ceil(len(text) / 2):]]
|
||||||
|
|
||||||
parts_futures = [self.embed_text(part) for part in parts]
|
parts_futures = [self.embed_text(part) for part in parts]
|
||||||
embeddings = await asyncio.gather(*parts_futures)
|
embeddings = await asyncio.gather(*parts_futures)
|
||||||
|
|
@ -95,6 +97,9 @@ class LiteLLMEmbeddingEngine(EmbeddingEngine):
|
||||||
|
|
||||||
return await self.embed_text(text)
|
return await self.embed_text(text)
|
||||||
|
|
||||||
|
except (litellm.exceptions.BadRequestError, litellm.llms.OpenAI.openai.OpenAIError):
|
||||||
|
raise EmbeddingException("Failed to index data points.")
|
||||||
|
|
||||||
except Exception as error:
|
except Exception as error:
|
||||||
logger.error("Error embedding text: %s", str(error))
|
logger.error("Error embedding text: %s", str(error))
|
||||||
raise error
|
raise error
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,4 @@
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
|
|
||||||
from cognee.infrastructure.engine import DataPoint
|
from cognee.infrastructure.engine import DataPoint
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -7,7 +6,7 @@ class Repository(DataPoint):
|
||||||
__tablename__ = "Repository"
|
__tablename__ = "Repository"
|
||||||
path: str
|
path: str
|
||||||
_metadata: dict = {
|
_metadata: dict = {
|
||||||
"index_fields": ["source_code"],
|
"index_fields": [],
|
||||||
"type": "Repository"
|
"type": "Repository"
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -19,29 +18,31 @@ class CodeFile(DataPoint):
|
||||||
depends_on: Optional[List["CodeFile"]] = None
|
depends_on: Optional[List["CodeFile"]] = None
|
||||||
depends_directly_on: Optional[List["CodeFile"]] = None
|
depends_directly_on: Optional[List["CodeFile"]] = None
|
||||||
contains: Optional[List["CodePart"]] = None
|
contains: Optional[List["CodePart"]] = None
|
||||||
|
|
||||||
_metadata: dict = {
|
_metadata: dict = {
|
||||||
"index_fields": ["source_code"],
|
"index_fields": [],
|
||||||
"type": "CodeFile"
|
"type": "CodeFile"
|
||||||
}
|
}
|
||||||
|
|
||||||
class CodePart(DataPoint):
|
class CodePart(DataPoint):
|
||||||
__tablename__ = "codepart"
|
__tablename__ = "codepart"
|
||||||
# part_of: Optional[CodeFile]
|
# part_of: Optional[CodeFile] = None
|
||||||
source_code: str
|
source_code: Optional[str] = None
|
||||||
|
|
||||||
_metadata: dict = {
|
_metadata: dict = {
|
||||||
"index_fields": ["source_code"],
|
"index_fields": [],
|
||||||
"type": "CodePart"
|
"type": "CodePart"
|
||||||
}
|
}
|
||||||
|
|
||||||
class CodeRelationship(DataPoint):
|
class SourceCodeChunk(DataPoint):
|
||||||
source_id: str
|
__tablename__ = "sourcecodechunk"
|
||||||
target_id: str
|
code_chunk_of: Optional[CodePart] = None
|
||||||
relation: str # depends on or depends directly
|
source_code: Optional[str] = None
|
||||||
|
previous_chunk: Optional["SourceCodeChunk"] = None
|
||||||
|
|
||||||
_metadata: dict = {
|
_metadata: dict = {
|
||||||
"type": "CodeRelationship"
|
"index_fields": ["source_code"],
|
||||||
|
"type": "SourceCodeChunk"
|
||||||
}
|
}
|
||||||
|
|
||||||
CodeFile.model_rebuild()
|
CodeFile.model_rebuild()
|
||||||
CodePart.model_rebuild()
|
CodePart.model_rebuild()
|
||||||
|
SourceCodeChunk.model_rebuild()
|
||||||
|
|
@ -210,7 +210,6 @@ class SummarizedClass(BaseModel):
|
||||||
decorators: Optional[List[str]] = None
|
decorators: Optional[List[str]] = None
|
||||||
|
|
||||||
class SummarizedCode(BaseModel):
|
class SummarizedCode(BaseModel):
|
||||||
file_name: str
|
|
||||||
high_level_summary: str
|
high_level_summary: str
|
||||||
key_features: List[str]
|
key_features: List[str]
|
||||||
imports: List[str] = []
|
imports: List[str] = []
|
||||||
|
|
|
||||||
|
|
@ -71,7 +71,7 @@ async def get_repo_file_dependencies(repo_path: str) -> AsyncGenerator[list, Non
|
||||||
path = repo_path,
|
path = repo_path,
|
||||||
)
|
)
|
||||||
|
|
||||||
yield repo
|
yield [repo]
|
||||||
|
|
||||||
with ProcessPoolExecutor(max_workers = 12) as executor:
|
with ProcessPoolExecutor(max_workers = 12) as executor:
|
||||||
loop = asyncio.get_event_loop()
|
loop = asyncio.get_event_loop()
|
||||||
|
|
@ -90,10 +90,11 @@ async def get_repo_file_dependencies(repo_path: str) -> AsyncGenerator[list, Non
|
||||||
|
|
||||||
results = await asyncio.gather(*tasks)
|
results = await asyncio.gather(*tasks)
|
||||||
|
|
||||||
|
code_files = []
|
||||||
for (file_path, metadata), dependencies in zip(py_files_dict.items(), results):
|
for (file_path, metadata), dependencies in zip(py_files_dict.items(), results):
|
||||||
source_code = metadata.get("source_code")
|
source_code = metadata.get("source_code")
|
||||||
|
|
||||||
yield CodeFile(
|
code_files.append(CodeFile(
|
||||||
id = uuid5(NAMESPACE_OID, file_path),
|
id = uuid5(NAMESPACE_OID, file_path),
|
||||||
source_code = source_code,
|
source_code = source_code,
|
||||||
extracted_id = file_path,
|
extracted_id = file_path,
|
||||||
|
|
@ -106,4 +107,6 @@ async def get_repo_file_dependencies(repo_path: str) -> AsyncGenerator[list, Non
|
||||||
source_code = py_files_dict.get(dependency, {}).get("source_code"),
|
source_code = py_files_dict.get(dependency, {}).get("source_code"),
|
||||||
) for dependency in dependencies
|
) for dependency in dependencies
|
||||||
] if dependencies else None,
|
] if dependencies else None,
|
||||||
)
|
))
|
||||||
|
|
||||||
|
yield code_files
|
||||||
|
|
|
||||||
164
cognee/tasks/repo_processor/get_source_code_chunks.py
Normal file
164
cognee/tasks/repo_processor/get_source_code_chunks.py
Normal file
|
|
@ -0,0 +1,164 @@
|
||||||
|
import logging
|
||||||
|
from typing import AsyncGenerator, Generator
|
||||||
|
from uuid import NAMESPACE_OID, uuid5
|
||||||
|
|
||||||
|
import parso
|
||||||
|
import tiktoken
|
||||||
|
|
||||||
|
from cognee.infrastructure.engine import DataPoint
|
||||||
|
from cognee.shared.CodeGraphEntities import CodeFile, CodePart, SourceCodeChunk
|
||||||
|
|
||||||
|
logger = logging.getLogger("task:get_source_code_chunks")
|
||||||
|
|
||||||
|
|
||||||
|
def _count_tokens(tokenizer: tiktoken.Encoding, source_code: str) -> int:
|
||||||
|
return len(tokenizer.encode(source_code))
|
||||||
|
|
||||||
|
|
||||||
|
def _get_naive_subchunk_token_counts(
|
||||||
|
tokenizer: tiktoken.Encoding, source_code: str, max_subchunk_tokens: int = 8000
|
||||||
|
) -> list[tuple[str, int]]:
|
||||||
|
"""Splits source code into subchunks of up to max_subchunk_tokens and counts tokens."""
|
||||||
|
|
||||||
|
token_ids = tokenizer.encode(source_code)
|
||||||
|
subchunk_token_counts = []
|
||||||
|
|
||||||
|
for start_idx in range(0, len(token_ids), max_subchunk_tokens):
|
||||||
|
subchunk_token_ids = token_ids[start_idx: start_idx + max_subchunk_tokens]
|
||||||
|
token_count = len(subchunk_token_ids)
|
||||||
|
subchunk = ''.join(
|
||||||
|
tokenizer.decode_single_token_bytes(token_id).decode('utf-8', errors='replace')
|
||||||
|
for token_id in subchunk_token_ids
|
||||||
|
)
|
||||||
|
subchunk_token_counts.append((subchunk, token_count))
|
||||||
|
|
||||||
|
return subchunk_token_counts
|
||||||
|
|
||||||
|
|
||||||
|
def _get_subchunk_token_counts(
    tokenizer: tiktoken.Encoding,
    source_code: str,
    max_subchunk_tokens: int = 8000,
    depth: int = 0,
    max_depth: int = 100
) -> list[tuple[str, int]]:
    """Split source code into syntax-aligned subchunks and count tokens for each.

    The code is parsed with parso and split along the children of the parsed
    tree. A child that is still larger than `max_subchunk_tokens` is split
    again recursively; oversized string nodes, unsplittable leaves and inputs
    nested deeper than `max_depth` are split positionally instead.

    Args:
        tokenizer: Tokenizer used for counting (and for positional splits).
        source_code: Python source text to split.
        max_subchunk_tokens: Maximum number of tokens per returned subchunk.
        depth: Current recursion depth (callers normally leave this at 0).
        max_depth: Recursion depth beyond which positional splitting is used.

    Returns:
        A list of `(subchunk_text, token_count)` tuples in source order; empty
        when parsing fails or the parsed module has no children.
    """
    if depth > max_depth:
        return _get_naive_subchunk_token_counts(tokenizer, source_code, max_subchunk_tokens)

    try:
        module = parso.parse(source_code)
    except Exception as e:
        logger.error(f"Error parsing source code: {e}")
        return []

    if not module.children:
        logger.warning("Parsed module has no children (empty or invalid source code).")
        return []

    # Handle cases with only one real child and an EndMarker to prevent infinite recursion.
    if len(module.children) <= 2:
        module = module.children[0]

    # The node we descended into can be a leaf (e.g. just the end marker of a
    # comment-only source). Leaves have no `children`, so split positionally
    # instead of raising AttributeError in the loop below.
    children = getattr(module, "children", None)
    if children is None:
        return _get_naive_subchunk_token_counts(tokenizer, module.get_code(), max_subchunk_tokens)

    subchunk_token_counts = []
    for child in children:
        subchunk = child.get_code()
        token_count = _count_tokens(tokenizer, subchunk)

        if token_count == 0:
            continue

        if token_count <= max_subchunk_tokens:
            subchunk_token_counts.append((subchunk, token_count))
            continue

        # An oversized string literal has no inner structure to recurse into.
        if child.type == 'string':
            subchunk_token_counts.extend(
                _get_naive_subchunk_token_counts(tokenizer, subchunk, max_subchunk_tokens)
            )
            continue

        subchunk_token_counts.extend(
            _get_subchunk_token_counts(
                tokenizer, subchunk, max_subchunk_tokens, depth=depth + 1, max_depth=max_depth
            )
        )

    return subchunk_token_counts
|
||||||
|
|
||||||
|
|
||||||
|
def _get_chunk_source_code(
|
||||||
|
code_token_counts: list[tuple[str, int]], overlap: float, max_tokens: int
|
||||||
|
) -> tuple[list[tuple[str, int]], str]:
|
||||||
|
"""Generates a chunk of source code from tokenized subchunks with overlap handling."""
|
||||||
|
current_count = 0
|
||||||
|
cumulative_counts = []
|
||||||
|
current_source_code = ''
|
||||||
|
|
||||||
|
for i, (child_code, token_count) in enumerate(code_token_counts):
|
||||||
|
current_count += token_count
|
||||||
|
cumulative_counts.append(current_count)
|
||||||
|
if current_count > max_tokens:
|
||||||
|
break
|
||||||
|
current_source_code += f"\n{child_code}"
|
||||||
|
|
||||||
|
if current_count <= max_tokens:
|
||||||
|
return [], current_source_code.strip()
|
||||||
|
|
||||||
|
cutoff = 1
|
||||||
|
for i, cum_count in enumerate(cumulative_counts):
|
||||||
|
if cum_count > (1 - overlap) * max_tokens:
|
||||||
|
break
|
||||||
|
cutoff = i
|
||||||
|
|
||||||
|
return code_token_counts[cutoff:], current_source_code.strip()
|
||||||
|
|
||||||
|
|
||||||
|
def get_source_code_chunks_from_code_part(
    code_file_part: CodePart,
    max_tokens: int = 8192,
    overlap: float = 0.25,
    granularity: float = 0.1,
    model_name: str = "text-embedding-3-large"
) -> Generator[SourceCodeChunk, None, None]:
    """Yield `SourceCodeChunk` objects covering the source code of a `CodePart`.

    The source is first split into subchunks of at most
    `granularity * max_tokens` tokens, which are then packed into chunks of at
    most `max_tokens` tokens. Each yielded chunk links to the one before it
    via `previous_chunk`.

    Args:
        code_file_part: Code part whose `source_code` is chunked.
        max_tokens: Token budget per chunk.
        overlap: Fraction of `max_tokens` consecutive chunks may share.
        granularity: Fraction of `max_tokens` used as the subchunk size.
        model_name: Model name passed to `tiktoken.encoding_for_model`.

    Yields:
        `SourceCodeChunk` instances in source order. Nothing is yielded when
        the code part has no source code.
    """
    if not code_file_part.source_code:
        # The argument is a CodePart, so name it as such in the log.
        logger.error(f"No source code in CodePart {code_file_part.id}")
        return

    tokenizer = tiktoken.encoding_for_model(model_name)
    max_subchunk_tokens = max(1, int(granularity * max_tokens))
    subchunk_token_counts = _get_subchunk_token_counts(
        tokenizer, code_file_part.source_code, max_subchunk_tokens
    )

    previous_chunk = None
    while subchunk_token_counts:
        remaining, chunk_source_code = _get_chunk_source_code(
            subchunk_token_counts, overlap, max_tokens
        )
        # Guarantee progress: if the helper hands back a list that is not
        # shorter than the input, drop the first subchunk so the loop ends.
        if len(remaining) >= len(subchunk_token_counts):
            remaining = subchunk_token_counts[1:]
        subchunk_token_counts = remaining

        if not chunk_source_code:
            continue

        current_chunk = SourceCodeChunk(
            id=uuid5(NAMESPACE_OID, chunk_source_code),
            code_chunk_of=code_file_part,
            source_code=chunk_source_code,
            previous_chunk=previous_chunk
        )
        yield current_chunk
        previous_chunk = current_chunk
|
||||||
|
|
||||||
|
|
||||||
|
async def get_source_code_chunks(
    data_points: list[DataPoint], embedding_model="text-embedding-3-large"
) -> AsyncGenerator[DataPoint, None]:
    """Process code graph data points and create SourceCodeChunk data points.

    Every incoming data point is yielded unchanged. For each `CodeFile` that
    contains code parts, each part is yielded as well, followed by the
    `SourceCodeChunk` objects generated from it.

    Args:
        data_points: Data points to process.
        embedding_model: Model name forwarded to
            `get_source_code_chunks_from_code_part` as `model_name`.

    Yields:
        Individual data points (not lists): the originals, their code parts,
        and the generated source code chunks. Errors raised while processing a
        data point or code part are logged and processing continues.
    """
    # TODO: Add support for other embedding models, with max_token mapping
    for data_point in data_points:
        try:
            yield data_point

            if not isinstance(data_point, CodeFile):
                continue

            if not data_point.contains:
                logger.warning(f"CodeFile {data_point.id} contains no code parts")
                continue

            for code_part in data_point.contains:
                # Per-part handling so one bad part does not stop the others.
                try:
                    yield code_part
                    for source_code_chunk in get_source_code_chunks_from_code_part(
                        code_part, model_name=embedding_model
                    ):
                        yield source_code_chunk
                except Exception as e:
                    logger.error(f"Error processing code part: {e}")
        except Exception as e:
            logger.error(f"Error processing data point: {e}")
|
||||||
|
|
@ -1,6 +1,10 @@
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from cognee.infrastructure.databases.exceptions.EmbeddingException import EmbeddingException
|
||||||
from cognee.infrastructure.databases.vector import get_vector_engine
|
from cognee.infrastructure.databases.vector import get_vector_engine
|
||||||
from cognee.infrastructure.engine import DataPoint
|
from cognee.infrastructure.engine import DataPoint
|
||||||
|
|
||||||
|
logger = logging.getLogger("index_data_points")
|
||||||
|
|
||||||
async def index_data_points(data_points: list[DataPoint]):
|
async def index_data_points(data_points: list[DataPoint]):
|
||||||
created_indexes = {}
|
created_indexes = {}
|
||||||
|
|
@ -30,7 +34,10 @@ async def index_data_points(data_points: list[DataPoint]):
|
||||||
|
|
||||||
for index_name, indexable_points in index_points.items():
|
for index_name, indexable_points in index_points.items():
|
||||||
index_name, field_name = index_name.split(".")
|
index_name, field_name = index_name.split(".")
|
||||||
await vector_engine.index_data_points(index_name, field_name, indexable_points)
|
try:
|
||||||
|
await vector_engine.index_data_points(index_name, field_name, indexable_points)
|
||||||
|
except EmbeddingException as e:
|
||||||
|
logger.warning(f"Failed to index data points for {index_name}.{field_name}: {e}")
|
||||||
|
|
||||||
return data_points
|
return data_points
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,8 @@
|
||||||
|
from typing import Union
|
||||||
|
|
||||||
from cognee.infrastructure.engine import DataPoint
|
from cognee.infrastructure.engine import DataPoint
|
||||||
from cognee.modules.chunking.models import DocumentChunk
|
from cognee.modules.chunking.models import DocumentChunk
|
||||||
from cognee.shared.CodeGraphEntities import CodeFile
|
from cognee.shared.CodeGraphEntities import CodeFile, CodePart, SourceCodeChunk
|
||||||
|
|
||||||
|
|
||||||
class TextSummary(DataPoint):
|
class TextSummary(DataPoint):
|
||||||
|
|
@ -17,7 +19,7 @@ class TextSummary(DataPoint):
|
||||||
class CodeSummary(DataPoint):
|
class CodeSummary(DataPoint):
|
||||||
__tablename__ = "code_summary"
|
__tablename__ = "code_summary"
|
||||||
text: str
|
text: str
|
||||||
made_from: CodeFile
|
summarizes: Union[CodeFile, CodePart, SourceCodeChunk]
|
||||||
|
|
||||||
_metadata: dict = {
|
_metadata: dict = {
|
||||||
"index_fields": ["text"],
|
"index_fields": ["text"],
|
||||||
|
|
|
||||||
|
|
@ -1,10 +1,10 @@
|
||||||
import asyncio
|
import asyncio
|
||||||
from typing import AsyncGenerator, Union
|
from typing import AsyncGenerator, Union
|
||||||
from uuid import uuid5
|
from uuid import uuid5
|
||||||
from typing import Type
|
|
||||||
|
|
||||||
from cognee.infrastructure.engine import DataPoint
|
from cognee.infrastructure.engine import DataPoint
|
||||||
from cognee.modules.data.extraction.extract_summary import extract_code_summary
|
from cognee.modules.data.extraction.extract_summary import extract_code_summary
|
||||||
|
|
||||||
from .models import CodeSummary
|
from .models import CodeSummary
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -21,7 +21,7 @@ async def summarize_code(
|
||||||
)
|
)
|
||||||
|
|
||||||
file_summaries_map = {
|
file_summaries_map = {
|
||||||
code_data_point.extracted_id: str(file_summary)
|
code_data_point.id: str(file_summary)
|
||||||
for code_data_point, file_summary in zip(code_data_points, file_summaries)
|
for code_data_point, file_summary in zip(code_data_points, file_summaries)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -35,6 +35,6 @@ async def summarize_code(
|
||||||
|
|
||||||
yield CodeSummary(
|
yield CodeSummary(
|
||||||
id=uuid5(node.id, "CodeSummary"),
|
id=uuid5(node.id, "CodeSummary"),
|
||||||
made_from=node,
|
summarizes=node,
|
||||||
text=file_summaries_map[node.extracted_id],
|
text=file_summaries_map[node.id],
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -11,6 +11,6 @@ async def main(repo_path, include_docs):
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument("--repo_path", type=str, required=True, help="Path to the repository")
|
parser.add_argument("--repo_path", type=str, required=True, help="Path to the repository")
|
||||||
parser.add_argument("--include_docs", type=bool, default=True, help="Whether or not to process non-code files")
|
parser.add_argument("--include_docs", type=lambda x: x.lower() in ("true", "1"), default=True, help="Whether or not to process non-code files")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
asyncio.run(main(args.repo_path, args.include_docs))
|
asyncio.run(main(args.repo_path, args.include_docs))
|
||||||
Loading…
Add table
Reference in a new issue