Merge branch 'dev' into COG-650-replace-pylint

This commit is contained in:
Vasilije 2024-12-26 21:02:46 +01:00 committed by GitHub
commit c8fdbb45c4
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
13 changed files with 232 additions and 38 deletions

View file

@ -17,7 +17,9 @@ Try it in a Google Colab <a href="https://colab.research.google.com/drive/1g-Qn
If you have questions, join our <a href="https://discord.gg/NQPKmU5CCg">Discord</a> community If you have questions, join our <a href="https://discord.gg/NQPKmU5CCg">Discord</a> community
<div align="center">
<img src="assets/cognee_benefits.png" alt="why cognee" width="80%" />
</div>
## 📦 Installation ## 📦 Installation
You can install Cognee using either **pip** or **poetry**. You can install Cognee using either **pip** or **poetry**.

BIN
assets/cognee_benefits.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 353 KiB

View file

@ -3,6 +3,8 @@ import logging
from pathlib import Path from pathlib import Path
from cognee.base_config import get_base_config from cognee.base_config import get_base_config
from cognee.infrastructure.databases.vector.embeddings import \
get_embedding_engine
from cognee.modules.cognify.config import get_cognify_config from cognee.modules.cognify.config import get_cognify_config
from cognee.modules.pipelines import run_tasks from cognee.modules.pipelines import run_tasks
from cognee.modules.pipelines.tasks.Task import Task from cognee.modules.pipelines.tasks.Task import Task
@ -15,8 +17,10 @@ from cognee.tasks.ingestion import ingest_data_with_metadata
from cognee.tasks.repo_processor import (enrich_dependency_graph, from cognee.tasks.repo_processor import (enrich_dependency_graph,
expand_dependency_graph, expand_dependency_graph,
get_data_list_for_user, get_data_list_for_user,
get_non_code_files, get_non_py_files,
get_repo_file_dependencies) get_repo_file_dependencies)
from cognee.tasks.repo_processor.get_source_code_chunks import \
get_source_code_chunks
from cognee.tasks.storage import add_data_points from cognee.tasks.storage import add_data_points
monitoring = get_base_config().monitoring_tool monitoring = get_base_config().monitoring_tool
@ -28,6 +32,7 @@ from cognee.tasks.summarization import summarize_code, summarize_text
logger = logging.getLogger("code_graph_pipeline") logger = logging.getLogger("code_graph_pipeline")
update_status_lock = asyncio.Lock() update_status_lock = asyncio.Lock()
@observe @observe
async def run_code_graph_pipeline(repo_path, include_docs=True): async def run_code_graph_pipeline(repo_path, include_docs=True):
import os import os
@ -46,20 +51,23 @@ async def run_code_graph_pipeline(repo_path, include_docs=True):
await cognee.prune.prune_system(metadata=True) await cognee.prune.prune_system(metadata=True)
await create_db_and_tables() await create_db_and_tables()
embedding_engine = get_embedding_engine()
cognee_config = get_cognify_config() cognee_config = get_cognify_config()
user = await get_default_user() user = await get_default_user()
tasks = [ tasks = [
Task(get_repo_file_dependencies), Task(get_repo_file_dependencies),
Task(enrich_dependency_graph, task_config={"batch_size": 50}), Task(enrich_dependency_graph),
Task(expand_dependency_graph, task_config={"batch_size": 50}), Task(expand_dependency_graph, task_config={"batch_size": 50}),
Task(get_source_code_chunks, embedding_model=embedding_engine.model, task_config={"batch_size": 50}),
Task(summarize_code, task_config={"batch_size": 50}), Task(summarize_code, task_config={"batch_size": 50}),
Task(add_data_points, task_config={"batch_size": 50}), Task(add_data_points, task_config={"batch_size": 50}),
] ]
if include_docs: if include_docs:
non_code_tasks = [ non_code_tasks = [
Task(get_non_code_files, task_config={"batch_size": 50}), Task(get_non_py_files, task_config={"batch_size": 50}),
Task(ingest_data_with_metadata, dataset_name="repo_docs", user=user), Task(ingest_data_with_metadata, dataset_name="repo_docs", user=user),
Task(get_data_list_for_user, dataset_name="repo_docs", user=user), Task(get_data_list_for_user, dataset_name="repo_docs", user=user),
Task(classify_documents), Task(classify_documents),
@ -71,7 +79,7 @@ async def run_code_graph_pipeline(repo_path, include_docs=True):
task_config={"batch_size": 50} task_config={"batch_size": 50}
), ),
] ]
if include_docs: if include_docs:
async for result in run_tasks(non_code_tasks, repo_path): async for result in run_tasks(non_code_tasks, repo_path):
yield result yield result

View file

@ -0,0 +1,3 @@
class EmbeddingException(Exception):
"""Custom exception for handling embedding-related errors."""
pass

View file

@ -5,17 +5,19 @@ from typing import List, Optional
import litellm import litellm
import os import os
from cognee.infrastructure.databases.vector.embeddings.EmbeddingEngine import EmbeddingEngine from cognee.infrastructure.databases.vector.embeddings.EmbeddingEngine import EmbeddingEngine
from cognee.infrastructure.databases.exceptions.EmbeddingException import EmbeddingException
litellm.set_verbose = False litellm.set_verbose = False
logger = logging.getLogger("LiteLLMEmbeddingEngine") logger = logging.getLogger("LiteLLMEmbeddingEngine")
class LiteLLMEmbeddingEngine(EmbeddingEngine): class LiteLLMEmbeddingEngine(EmbeddingEngine):
api_key: str api_key: str
endpoint: str endpoint: str
api_version: str api_version: str
model: str model: str
dimensions: int dimensions: int
mock:bool mock: bool
def __init__( def __init__(
self, self,
@ -33,7 +35,7 @@ class LiteLLMEmbeddingEngine(EmbeddingEngine):
enable_mocking = os.getenv("MOCK_EMBEDDING", "false") enable_mocking = os.getenv("MOCK_EMBEDDING", "false")
if isinstance(enable_mocking, bool): if isinstance(enable_mocking, bool):
enable_mocking= str(enable_mocking).lower() enable_mocking = str(enable_mocking).lower()
self.mock = enable_mocking in ("true", "1", "yes") self.mock = enable_mocking in ("true", "1", "yes")
MAX_RETRIES = 5 MAX_RETRIES = 5
@ -43,7 +45,7 @@ class LiteLLMEmbeddingEngine(EmbeddingEngine):
async def exponential_backoff(attempt): async def exponential_backoff(attempt):
wait_time = min(10 * (2 ** attempt), 60) # Max 60 seconds wait_time = min(10 * (2 ** attempt), 60) # Max 60 seconds
await asyncio.sleep(wait_time) await asyncio.sleep(wait_time)
try: try:
if self.mock: if self.mock:
response = { response = {
@ -56,10 +58,10 @@ class LiteLLMEmbeddingEngine(EmbeddingEngine):
else: else:
response = await litellm.aembedding( response = await litellm.aembedding(
self.model, self.model,
input = text, input=text,
api_key = self.api_key, api_key=self.api_key,
api_base = self.endpoint, api_base=self.endpoint,
api_version = self.api_version api_version=self.api_version
) )
self.retry_count = 0 self.retry_count = 0
@ -71,7 +73,7 @@ class LiteLLMEmbeddingEngine(EmbeddingEngine):
if len(text) == 1: if len(text) == 1:
parts = [text] parts = [text]
else: else:
parts = [text[0:math.ceil(len(text)/2)], text[math.ceil(len(text)/2):]] parts = [text[0:math.ceil(len(text) / 2)], text[math.ceil(len(text) / 2):]]
parts_futures = [self.embed_text(part) for part in parts] parts_futures = [self.embed_text(part) for part in parts]
embeddings = await asyncio.gather(*parts_futures) embeddings = await asyncio.gather(*parts_futures)
@ -95,6 +97,9 @@ class LiteLLMEmbeddingEngine(EmbeddingEngine):
return await self.embed_text(text) return await self.embed_text(text)
except (litellm.exceptions.BadRequestError, litellm.llms.OpenAI.openai.OpenAIError):
raise EmbeddingException("Failed to index data points.")
except Exception as error: except Exception as error:
logger.error("Error embedding text: %s", str(error)) logger.error("Error embedding text: %s", str(error))
raise error raise error

View file

@ -1,5 +1,4 @@
from typing import List, Optional from typing import List, Optional
from cognee.infrastructure.engine import DataPoint from cognee.infrastructure.engine import DataPoint
@ -7,7 +6,7 @@ class Repository(DataPoint):
__tablename__ = "Repository" __tablename__ = "Repository"
path: str path: str
_metadata: dict = { _metadata: dict = {
"index_fields": ["source_code"], "index_fields": [],
"type": "Repository" "type": "Repository"
} }
@ -19,29 +18,31 @@ class CodeFile(DataPoint):
depends_on: Optional[List["CodeFile"]] = None depends_on: Optional[List["CodeFile"]] = None
depends_directly_on: Optional[List["CodeFile"]] = None depends_directly_on: Optional[List["CodeFile"]] = None
contains: Optional[List["CodePart"]] = None contains: Optional[List["CodePart"]] = None
_metadata: dict = { _metadata: dict = {
"index_fields": ["source_code"], "index_fields": [],
"type": "CodeFile" "type": "CodeFile"
} }
class CodePart(DataPoint): class CodePart(DataPoint):
__tablename__ = "codepart" __tablename__ = "codepart"
# part_of: Optional[CodeFile] # part_of: Optional[CodeFile] = None
source_code: str source_code: Optional[str] = None
_metadata: dict = { _metadata: dict = {
"index_fields": ["source_code"], "index_fields": [],
"type": "CodePart" "type": "CodePart"
} }
class CodeRelationship(DataPoint): class SourceCodeChunk(DataPoint):
source_id: str __tablename__ = "sourcecodechunk"
target_id: str code_chunk_of: Optional[CodePart] = None
relation: str # depends on or depends directly source_code: Optional[str] = None
previous_chunk: Optional["SourceCodeChunk"] = None
_metadata: dict = { _metadata: dict = {
"type": "CodeRelationship" "index_fields": ["source_code"],
"type": "SourceCodeChunk"
} }
CodeFile.model_rebuild() CodeFile.model_rebuild()
CodePart.model_rebuild() CodePart.model_rebuild()
SourceCodeChunk.model_rebuild()

View file

@ -210,7 +210,6 @@ class SummarizedClass(BaseModel):
decorators: Optional[List[str]] = None decorators: Optional[List[str]] = None
class SummarizedCode(BaseModel): class SummarizedCode(BaseModel):
file_name: str
high_level_summary: str high_level_summary: str
key_features: List[str] key_features: List[str]
imports: List[str] = [] imports: List[str] = []

View file

@ -71,7 +71,7 @@ async def get_repo_file_dependencies(repo_path: str) -> AsyncGenerator[list, Non
path = repo_path, path = repo_path,
) )
yield repo yield [repo]
with ProcessPoolExecutor(max_workers = 12) as executor: with ProcessPoolExecutor(max_workers = 12) as executor:
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
@ -90,10 +90,11 @@ async def get_repo_file_dependencies(repo_path: str) -> AsyncGenerator[list, Non
results = await asyncio.gather(*tasks) results = await asyncio.gather(*tasks)
code_files = []
for (file_path, metadata), dependencies in zip(py_files_dict.items(), results): for (file_path, metadata), dependencies in zip(py_files_dict.items(), results):
source_code = metadata.get("source_code") source_code = metadata.get("source_code")
yield CodeFile( code_files.append(CodeFile(
id = uuid5(NAMESPACE_OID, file_path), id = uuid5(NAMESPACE_OID, file_path),
source_code = source_code, source_code = source_code,
extracted_id = file_path, extracted_id = file_path,
@ -106,4 +107,6 @@ async def get_repo_file_dependencies(repo_path: str) -> AsyncGenerator[list, Non
source_code = py_files_dict.get(dependency, {}).get("source_code"), source_code = py_files_dict.get(dependency, {}).get("source_code"),
) for dependency in dependencies ) for dependency in dependencies
] if dependencies else None, ] if dependencies else None,
) ))
yield code_files

View file

@ -0,0 +1,164 @@
import logging
from typing import AsyncGenerator, Generator
from uuid import NAMESPACE_OID, uuid5
import parso
import tiktoken
from cognee.infrastructure.engine import DataPoint
from cognee.shared.CodeGraphEntities import CodeFile, CodePart, SourceCodeChunk
logger = logging.getLogger("task:get_source_code_chunks")
def _count_tokens(tokenizer: tiktoken.Encoding, source_code: str) -> int:
return len(tokenizer.encode(source_code))
def _get_naive_subchunk_token_counts(
tokenizer: tiktoken.Encoding, source_code: str, max_subchunk_tokens: int = 8000
) -> list[tuple[str, int]]:
"""Splits source code into subchunks of up to max_subchunk_tokens and counts tokens."""
token_ids = tokenizer.encode(source_code)
subchunk_token_counts = []
for start_idx in range(0, len(token_ids), max_subchunk_tokens):
subchunk_token_ids = token_ids[start_idx: start_idx + max_subchunk_tokens]
token_count = len(subchunk_token_ids)
subchunk = ''.join(
tokenizer.decode_single_token_bytes(token_id).decode('utf-8', errors='replace')
for token_id in subchunk_token_ids
)
subchunk_token_counts.append((subchunk, token_count))
return subchunk_token_counts
def _get_subchunk_token_counts(
tokenizer: tiktoken.Encoding,
source_code: str,
max_subchunk_tokens: int = 8000,
depth: int = 0,
max_depth: int = 100
) -> list[tuple[str, int]]:
"""Splits source code into subchunk and counts tokens for each subchunk."""
if depth > max_depth:
return _get_naive_subchunk_token_counts(tokenizer, source_code, max_subchunk_tokens)
try:
module = parso.parse(source_code)
except Exception as e:
logger.error(f"Error parsing source code: {e}")
return []
if not module.children:
logger.warning("Parsed module has no children (empty or invalid source code).")
return []
# Handle cases with only one real child and an EndMarker to prevent infinite recursion.
if len(module.children) <= 2:
module = module.children[0]
subchunk_token_counts = []
for child in module.children:
subchunk = child.get_code()
token_count = _count_tokens(tokenizer, subchunk)
if token_count == 0:
continue
if token_count <= max_subchunk_tokens:
subchunk_token_counts.append((subchunk, token_count))
continue
if child.type == 'string':
subchunk_token_counts.extend(_get_naive_subchunk_token_counts(tokenizer, subchunk, max_subchunk_tokens))
continue
subchunk_token_counts.extend(
_get_subchunk_token_counts(tokenizer, subchunk, max_subchunk_tokens, depth=depth + 1, max_depth=max_depth)
)
return subchunk_token_counts
def _get_chunk_source_code(
code_token_counts: list[tuple[str, int]], overlap: float, max_tokens: int
) -> tuple[list[tuple[str, int]], str]:
"""Generates a chunk of source code from tokenized subchunks with overlap handling."""
current_count = 0
cumulative_counts = []
current_source_code = ''
for i, (child_code, token_count) in enumerate(code_token_counts):
current_count += token_count
cumulative_counts.append(current_count)
if current_count > max_tokens:
break
current_source_code += f"\n{child_code}"
if current_count <= max_tokens:
return [], current_source_code.strip()
cutoff = 1
for i, cum_count in enumerate(cumulative_counts):
if cum_count > (1 - overlap) * max_tokens:
break
cutoff = i
return code_token_counts[cutoff:], current_source_code.strip()
def get_source_code_chunks_from_code_part(
code_file_part: CodePart,
max_tokens: int = 8192,
overlap: float = 0.25,
granularity: float = 0.1,
model_name: str = "text-embedding-3-large"
) -> Generator[SourceCodeChunk, None, None]:
"""Yields source code chunks from a CodePart object, with configurable token limits and overlap."""
if not code_file_part.source_code:
logger.error(f"No source code in CodeFile {code_file_part.id}")
return
tokenizer = tiktoken.encoding_for_model(model_name)
max_subchunk_tokens = max(1, int(granularity * max_tokens))
subchunk_token_counts = _get_subchunk_token_counts(tokenizer, code_file_part.source_code, max_subchunk_tokens)
previous_chunk = None
while subchunk_token_counts:
subchunk_token_counts, chunk_source_code = _get_chunk_source_code(subchunk_token_counts, overlap, max_tokens)
if not chunk_source_code:
continue
current_chunk = SourceCodeChunk(
id=uuid5(NAMESPACE_OID, chunk_source_code),
code_chunk_of=code_file_part,
source_code=chunk_source_code,
previous_chunk=previous_chunk
)
yield current_chunk
previous_chunk = current_chunk
async def get_source_code_chunks(data_points: list[DataPoint], embedding_model="text-embedding-3-large") -> \
AsyncGenerator[list[DataPoint], None]:
"""Processes code graph datapoints, create SourceCodeChink datapoints."""
# TODO: Add support for other embedding models, with max_token mapping
for data_point in data_points:
try:
yield data_point
if not isinstance(data_point, CodeFile):
continue
if not data_point.contains:
logger.warning(f"CodeFile {data_point.id} contains no code parts")
continue
for code_part in data_point.contains:
try:
yield code_part
for source_code_chunk in get_source_code_chunks_from_code_part(code_part, model_name=embedding_model):
yield source_code_chunk
except Exception as e:
logger.error(f"Error processing code part: {e}")
except Exception as e:
logger.error(f"Error processing data point: {e}")

View file

@ -1,6 +1,10 @@
import logging
from cognee.infrastructure.databases.exceptions.EmbeddingException import EmbeddingException
from cognee.infrastructure.databases.vector import get_vector_engine from cognee.infrastructure.databases.vector import get_vector_engine
from cognee.infrastructure.engine import DataPoint from cognee.infrastructure.engine import DataPoint
logger = logging.getLogger("index_data_points")
async def index_data_points(data_points: list[DataPoint]): async def index_data_points(data_points: list[DataPoint]):
created_indexes = {} created_indexes = {}
@ -30,7 +34,10 @@ async def index_data_points(data_points: list[DataPoint]):
for index_name, indexable_points in index_points.items(): for index_name, indexable_points in index_points.items():
index_name, field_name = index_name.split(".") index_name, field_name = index_name.split(".")
await vector_engine.index_data_points(index_name, field_name, indexable_points) try:
await vector_engine.index_data_points(index_name, field_name, indexable_points)
except EmbeddingException as e:
logger.warning(f"Failed to index data points for {index_name}.{field_name}: {e}")
return data_points return data_points

View file

@ -1,6 +1,8 @@
from typing import Union
from cognee.infrastructure.engine import DataPoint from cognee.infrastructure.engine import DataPoint
from cognee.modules.chunking.models import DocumentChunk from cognee.modules.chunking.models import DocumentChunk
from cognee.shared.CodeGraphEntities import CodeFile from cognee.shared.CodeGraphEntities import CodeFile, CodePart, SourceCodeChunk
class TextSummary(DataPoint): class TextSummary(DataPoint):
@ -17,7 +19,7 @@ class TextSummary(DataPoint):
class CodeSummary(DataPoint): class CodeSummary(DataPoint):
__tablename__ = "code_summary" __tablename__ = "code_summary"
text: str text: str
made_from: CodeFile summarizes: Union[CodeFile, CodePart, SourceCodeChunk]
_metadata: dict = { _metadata: dict = {
"index_fields": ["text"], "index_fields": ["text"],

View file

@ -1,10 +1,10 @@
import asyncio import asyncio
from typing import AsyncGenerator, Union from typing import AsyncGenerator, Union
from uuid import uuid5 from uuid import uuid5
from typing import Type
from cognee.infrastructure.engine import DataPoint from cognee.infrastructure.engine import DataPoint
from cognee.modules.data.extraction.extract_summary import extract_code_summary from cognee.modules.data.extraction.extract_summary import extract_code_summary
from .models import CodeSummary from .models import CodeSummary
@ -21,7 +21,7 @@ async def summarize_code(
) )
file_summaries_map = { file_summaries_map = {
code_data_point.extracted_id: str(file_summary) code_data_point.id: str(file_summary)
for code_data_point, file_summary in zip(code_data_points, file_summaries) for code_data_point, file_summary in zip(code_data_points, file_summaries)
} }
@ -35,6 +35,6 @@ async def summarize_code(
yield CodeSummary( yield CodeSummary(
id=uuid5(node.id, "CodeSummary"), id=uuid5(node.id, "CodeSummary"),
made_from=node, summarizes=node,
text=file_summaries_map[node.extracted_id], text=file_summaries_map[node.id],
) )

View file

@ -11,6 +11,6 @@ async def main(repo_path, include_docs):
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("--repo_path", type=str, required=True, help="Path to the repository") parser.add_argument("--repo_path", type=str, required=True, help="Path to the repository")
parser.add_argument("--include_docs", type=bool, default=True, help="Whether or not to process non-code files") parser.add_argument("--include_docs", type=lambda x: x.lower() in ("true", "1"), default=True, help="Whether or not to process non-code files")
args = parser.parse_args() args = parser.parse_args()
asyncio.run(main(args.repo_path, args.include_docs)) asyncio.run(main(args.repo_path, args.include_docs))