diff --git a/README.md b/README.md
index 40005fc09..727678b9f 100644
--- a/README.md
+++ b/README.md
@@ -17,7 +17,9 @@ Try it in a Google Colab Discord community
-
+
+

+
## 📦 Installation
You can install Cognee using either **pip** or **poetry**.
diff --git a/assets/cognee_benefits.png b/assets/cognee_benefits.png
new file mode 100644
index 000000000..d435bed05
Binary files /dev/null and b/assets/cognee_benefits.png differ
diff --git a/cognee/api/v1/cognify/code_graph_pipeline.py b/cognee/api/v1/cognify/code_graph_pipeline.py
index 8e92d08e0..3d31b4000 100644
--- a/cognee/api/v1/cognify/code_graph_pipeline.py
+++ b/cognee/api/v1/cognify/code_graph_pipeline.py
@@ -3,6 +3,8 @@ import logging
from pathlib import Path
from cognee.base_config import get_base_config
+from cognee.infrastructure.databases.vector.embeddings import \
+ get_embedding_engine
from cognee.modules.cognify.config import get_cognify_config
from cognee.modules.pipelines import run_tasks
from cognee.modules.pipelines.tasks.Task import Task
@@ -15,8 +17,10 @@ from cognee.tasks.ingestion import ingest_data_with_metadata
from cognee.tasks.repo_processor import (enrich_dependency_graph,
expand_dependency_graph,
get_data_list_for_user,
- get_non_code_files,
+ get_non_py_files,
get_repo_file_dependencies)
+from cognee.tasks.repo_processor.get_source_code_chunks import \
+ get_source_code_chunks
from cognee.tasks.storage import add_data_points
monitoring = get_base_config().monitoring_tool
@@ -28,6 +32,7 @@ from cognee.tasks.summarization import summarize_code, summarize_text
logger = logging.getLogger("code_graph_pipeline")
update_status_lock = asyncio.Lock()
+
@observe
async def run_code_graph_pipeline(repo_path, include_docs=True):
import os
@@ -46,20 +51,23 @@ async def run_code_graph_pipeline(repo_path, include_docs=True):
await cognee.prune.prune_system(metadata=True)
await create_db_and_tables()
+ embedding_engine = get_embedding_engine()
+
cognee_config = get_cognify_config()
user = await get_default_user()
tasks = [
Task(get_repo_file_dependencies),
- Task(enrich_dependency_graph, task_config={"batch_size": 50}),
+ Task(enrich_dependency_graph),
Task(expand_dependency_graph, task_config={"batch_size": 50}),
+ Task(get_source_code_chunks, embedding_model=embedding_engine.model, task_config={"batch_size": 50}),
Task(summarize_code, task_config={"batch_size": 50}),
Task(add_data_points, task_config={"batch_size": 50}),
]
if include_docs:
non_code_tasks = [
- Task(get_non_code_files, task_config={"batch_size": 50}),
+ Task(get_non_py_files, task_config={"batch_size": 50}),
Task(ingest_data_with_metadata, dataset_name="repo_docs", user=user),
Task(get_data_list_for_user, dataset_name="repo_docs", user=user),
Task(classify_documents),
@@ -71,7 +79,7 @@ async def run_code_graph_pipeline(repo_path, include_docs=True):
task_config={"batch_size": 50}
),
]
-
+
if include_docs:
async for result in run_tasks(non_code_tasks, repo_path):
yield result
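For reference, a minimal sketch of driving the updated pipeline end to end (the `./repo` path is a placeholder; `run_code_graph_pipeline` is the async generator patched above):

```python
import asyncio

from cognee.api.v1.cognify.code_graph_pipeline import run_code_graph_pipeline

async def main():
    # Each yielded item is a result from the task pipeline, in execution order.
    async for result in run_code_graph_pipeline("./repo", include_docs=False):
        print(result)

asyncio.run(main())
```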
diff --git a/cognee/infrastructure/databases/exceptions/EmbeddingException.py b/cognee/infrastructure/databases/exceptions/EmbeddingException.py
new file mode 100644
index 000000000..ba7c70d80
--- /dev/null
+++ b/cognee/infrastructure/databases/exceptions/EmbeddingException.py
@@ -0,0 +1,3 @@
+class EmbeddingException(Exception):
+ """Custom exception for handling embedding-related errors."""
+ pass
\ No newline at end of file
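The new exception gives callers a narrow type to catch. A hedged sketch of the intended pattern (the `safe_embed` helper is hypothetical; the real consumer is `index_data_points.py` below):

```python
from cognee.infrastructure.databases.exceptions.EmbeddingException import EmbeddingException

async def safe_embed(engine, texts: list[str]):
    # Hypothetical helper: degrade gracefully instead of aborting the pipeline,
    # mirroring the handling added in index_data_points.py.
    try:
        return await engine.embed_text(texts)
    except EmbeddingException as error:
        print(f"Skipping batch: {error}")
        return None
```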
diff --git a/cognee/infrastructure/databases/vector/embeddings/LiteLLMEmbeddingEngine.py b/cognee/infrastructure/databases/vector/embeddings/LiteLLMEmbeddingEngine.py
index dce12b318..93f59cc77 100644
--- a/cognee/infrastructure/databases/vector/embeddings/LiteLLMEmbeddingEngine.py
+++ b/cognee/infrastructure/databases/vector/embeddings/LiteLLMEmbeddingEngine.py
@@ -5,17 +5,19 @@ from typing import List, Optional
import litellm
import os
from cognee.infrastructure.databases.vector.embeddings.EmbeddingEngine import EmbeddingEngine
+from cognee.infrastructure.databases.exceptions.EmbeddingException import EmbeddingException
litellm.set_verbose = False
logger = logging.getLogger("LiteLLMEmbeddingEngine")
+
class LiteLLMEmbeddingEngine(EmbeddingEngine):
api_key: str
endpoint: str
api_version: str
model: str
dimensions: int
- mock:bool
+ mock: bool
def __init__(
self,
@@ -33,7 +35,7 @@ class LiteLLMEmbeddingEngine(EmbeddingEngine):
enable_mocking = os.getenv("MOCK_EMBEDDING", "false")
if isinstance(enable_mocking, bool):
- enable_mocking= str(enable_mocking).lower()
+ enable_mocking = str(enable_mocking).lower()
self.mock = enable_mocking in ("true", "1", "yes")
MAX_RETRIES = 5
@@ -43,7 +45,7 @@ class LiteLLMEmbeddingEngine(EmbeddingEngine):
async def exponential_backoff(attempt):
wait_time = min(10 * (2 ** attempt), 60) # Max 60 seconds
await asyncio.sleep(wait_time)
-
+
try:
if self.mock:
response = {
@@ -56,10 +58,10 @@ class LiteLLMEmbeddingEngine(EmbeddingEngine):
else:
response = await litellm.aembedding(
self.model,
- input = text,
- api_key = self.api_key,
- api_base = self.endpoint,
- api_version = self.api_version
+ input=text,
+ api_key=self.api_key,
+ api_base=self.endpoint,
+ api_version=self.api_version
)
self.retry_count = 0
@@ -71,7 +73,7 @@ class LiteLLMEmbeddingEngine(EmbeddingEngine):
if len(text) == 1:
parts = [text]
else:
- parts = [text[0:math.ceil(len(text)/2)], text[math.ceil(len(text)/2):]]
+ parts = [text[0:math.ceil(len(text) / 2)], text[math.ceil(len(text) / 2):]]
parts_futures = [self.embed_text(part) for part in parts]
embeddings = await asyncio.gather(*parts_futures)
@@ -95,6 +97,9 @@ class LiteLLMEmbeddingEngine(EmbeddingEngine):
return await self.embed_text(text)
+ except (litellm.exceptions.BadRequestError, litellm.llms.OpenAI.openai.OpenAIError):
+ raise EmbeddingException("Failed to index data points.")
+
except Exception as error:
logger.error("Error embedding text: %s", str(error))
raise error
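For context on the retry path left unchanged above: `exponential_backoff` sleeps `min(10 * 2**attempt, 60)` seconds, so the five allowed retries (`MAX_RETRIES = 5`) wait 10 s, 20 s, 40 s, 60 s, 60 s. A one-liner to verify the schedule:

```python
# Prints: [10, 20, 40, 60, 60] -- doubling from 10 s, capped at 60 s.
print([min(10 * (2 ** attempt), 60) for attempt in range(5)])
```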
diff --git a/cognee/shared/CodeGraphEntities.py b/cognee/shared/CodeGraphEntities.py
index 23b8879c2..27289493d 100644
--- a/cognee/shared/CodeGraphEntities.py
+++ b/cognee/shared/CodeGraphEntities.py
@@ -1,5 +1,4 @@
from typing import List, Optional
-
from cognee.infrastructure.engine import DataPoint
@@ -7,7 +6,7 @@ class Repository(DataPoint):
__tablename__ = "Repository"
path: str
_metadata: dict = {
- "index_fields": ["source_code"],
+ "index_fields": [],
"type": "Repository"
}
@@ -19,29 +18,31 @@ class CodeFile(DataPoint):
depends_on: Optional[List["CodeFile"]] = None
depends_directly_on: Optional[List["CodeFile"]] = None
contains: Optional[List["CodePart"]] = None
-
_metadata: dict = {
- "index_fields": ["source_code"],
+ "index_fields": [],
"type": "CodeFile"
}
class CodePart(DataPoint):
__tablename__ = "codepart"
- # part_of: Optional[CodeFile]
- source_code: str
-
+ # part_of: Optional[CodeFile] = None
+ source_code: Optional[str] = None
_metadata: dict = {
- "index_fields": ["source_code"],
+ "index_fields": [],
"type": "CodePart"
}
-class CodeRelationship(DataPoint):
- source_id: str
- target_id: str
- relation: str # depends on or depends directly
+class SourceCodeChunk(DataPoint):
+ __tablename__ = "sourcecodechunk"
+ code_chunk_of: Optional[CodePart] = None
+ source_code: Optional[str] = None
+ previous_chunk: Optional["SourceCodeChunk"] = None
+
_metadata: dict = {
- "type": "CodeRelationship"
+ "index_fields": ["source_code"],
+ "type": "SourceCodeChunk"
}
CodeFile.model_rebuild()
CodePart.model_rebuild()
+SourceCodeChunk.model_rebuild()
\ No newline at end of file
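The `index_fields` moves above are the crux of this file: `Repository`, `CodeFile`, and `CodePart` are no longer embedded wholesale; only the new `SourceCodeChunk` keeps `source_code` in its index fields, so embedding happens per chunk. A hedged sketch of how chunks chain together (field values are illustrative; `id` is assumed to default from `DataPoint`):

```python
from cognee.shared.CodeGraphEntities import CodePart, SourceCodeChunk

part = CodePart(source_code="def f():\n    return 1\n")
first = SourceCodeChunk(code_chunk_of=part, source_code="def f():", previous_chunk=None)
second = SourceCodeChunk(code_chunk_of=part, source_code="    return 1", previous_chunk=first)
# previous_chunk forms a singly linked list, so chunk order can be recovered later.
```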
diff --git a/cognee/shared/data_models.py b/cognee/shared/data_models.py
index dec53cfcb..2a8bc8c91 100644
--- a/cognee/shared/data_models.py
+++ b/cognee/shared/data_models.py
@@ -210,7 +210,6 @@ class SummarizedClass(BaseModel):
decorators: Optional[List[str]] = None
class SummarizedCode(BaseModel):
- file_name: str
high_level_summary: str
key_features: List[str]
imports: List[str] = []
diff --git a/cognee/tasks/repo_processor/get_repo_file_dependencies.py b/cognee/tasks/repo_processor/get_repo_file_dependencies.py
index 221af6cf6..b54c1f152 100644
--- a/cognee/tasks/repo_processor/get_repo_file_dependencies.py
+++ b/cognee/tasks/repo_processor/get_repo_file_dependencies.py
@@ -71,7 +71,7 @@ async def get_repo_file_dependencies(repo_path: str) -> AsyncGenerator[list, Non
path = repo_path,
)
- yield repo
+ yield [repo]
with ProcessPoolExecutor(max_workers = 12) as executor:
loop = asyncio.get_event_loop()
@@ -90,10 +90,11 @@ async def get_repo_file_dependencies(repo_path: str) -> AsyncGenerator[list, Non
results = await asyncio.gather(*tasks)
+ code_files = []
for (file_path, metadata), dependencies in zip(py_files_dict.items(), results):
source_code = metadata.get("source_code")
- yield CodeFile(
+ code_files.append(CodeFile(
id = uuid5(NAMESPACE_OID, file_path),
source_code = source_code,
extracted_id = file_path,
@@ -106,4 +107,6 @@ async def get_repo_file_dependencies(repo_path: str) -> AsyncGenerator[list, Non
source_code = py_files_dict.get(dependency, {}).get("source_code"),
) for dependency in dependencies
] if dependencies else None,
- )
+ ))
+
+ yield code_files
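Yielding `[repo]` and the collected `code_files` list (rather than individual objects) matches the batch-oriented contract of the downstream tasks, which declare `task_config={"batch_size": 50}`. A small sketch for poking at the generator in isolation (the repository path is a placeholder):

```python
import asyncio

from cognee.tasks.repo_processor.get_repo_file_dependencies import get_repo_file_dependencies

async def main():
    # First batch is [repo]; the second is the full list of CodeFile data points.
    async for batch in get_repo_file_dependencies("/path/to/repo"):
        print(type(batch).__name__, len(batch))

asyncio.run(main())
```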
diff --git a/cognee/tasks/repo_processor/get_source_code_chunks.py b/cognee/tasks/repo_processor/get_source_code_chunks.py
new file mode 100644
index 000000000..4d0ce3200
--- /dev/null
+++ b/cognee/tasks/repo_processor/get_source_code_chunks.py
@@ -0,0 +1,164 @@
+import logging
+from typing import AsyncGenerator, Generator
+from uuid import NAMESPACE_OID, uuid5
+
+import parso
+import tiktoken
+
+from cognee.infrastructure.engine import DataPoint
+from cognee.shared.CodeGraphEntities import CodeFile, CodePart, SourceCodeChunk
+
+logger = logging.getLogger("task:get_source_code_chunks")
+
+
+def _count_tokens(tokenizer: tiktoken.Encoding, source_code: str) -> int:
+ return len(tokenizer.encode(source_code))
+
+
+def _get_naive_subchunk_token_counts(
+ tokenizer: tiktoken.Encoding, source_code: str, max_subchunk_tokens: int = 8000
+) -> list[tuple[str, int]]:
+ """Splits source code into subchunks of up to max_subchunk_tokens and counts tokens."""
+
+ token_ids = tokenizer.encode(source_code)
+ subchunk_token_counts = []
+
+ for start_idx in range(0, len(token_ids), max_subchunk_tokens):
+ subchunk_token_ids = token_ids[start_idx: start_idx + max_subchunk_tokens]
+ token_count = len(subchunk_token_ids)
+ subchunk = ''.join(
+ tokenizer.decode_single_token_bytes(token_id).decode('utf-8', errors='replace')
+ for token_id in subchunk_token_ids
+ )
+ subchunk_token_counts.append((subchunk, token_count))
+
+ return subchunk_token_counts
+
+
+def _get_subchunk_token_counts(
+ tokenizer: tiktoken.Encoding,
+ source_code: str,
+ max_subchunk_tokens: int = 8000,
+ depth: int = 0,
+ max_depth: int = 100
+) -> list[tuple[str, int]]:
+ """Splits source code into subchunk and counts tokens for each subchunk."""
+ if depth > max_depth:
+ return _get_naive_subchunk_token_counts(tokenizer, source_code, max_subchunk_tokens)
+
+ try:
+ module = parso.parse(source_code)
+ except Exception as e:
+ logger.error(f"Error parsing source code: {e}")
+ return []
+
+ if not module.children:
+ logger.warning("Parsed module has no children (empty or invalid source code).")
+ return []
+
+ # Handle cases with only one real child and an EndMarker to prevent infinite recursion.
+ if len(module.children) <= 2:
+ module = module.children[0]
+
+ subchunk_token_counts = []
+ for child in module.children:
+ subchunk = child.get_code()
+ token_count = _count_tokens(tokenizer, subchunk)
+
+ if token_count == 0:
+ continue
+
+ if token_count <= max_subchunk_tokens:
+ subchunk_token_counts.append((subchunk, token_count))
+ continue
+
+ if child.type == 'string':
+ subchunk_token_counts.extend(_get_naive_subchunk_token_counts(tokenizer, subchunk, max_subchunk_tokens))
+ continue
+
+ subchunk_token_counts.extend(
+ _get_subchunk_token_counts(tokenizer, subchunk, max_subchunk_tokens, depth=depth + 1, max_depth=max_depth)
+ )
+
+ return subchunk_token_counts
+
+
+def _get_chunk_source_code(
+ code_token_counts: list[tuple[str, int]], overlap: float, max_tokens: int
+) -> tuple[list[tuple[str, int]], str]:
+ """Generates a chunk of source code from tokenized subchunks with overlap handling."""
+ current_count = 0
+ cumulative_counts = []
+ current_source_code = ''
+
+    for child_code, token_count in code_token_counts:
+ current_count += token_count
+ cumulative_counts.append(current_count)
+ if current_count > max_tokens:
+ break
+ current_source_code += f"\n{child_code}"
+
+ if current_count <= max_tokens:
+ return [], current_source_code.strip()
+
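+    # Overlap handling: find the last subchunk whose cumulative count stays within
+    # (1 - overlap) * max_tokens, then return the tail from there so the next chunk
+    # re-includes roughly overlap * max_tokens tokens of trailing context.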
+ cutoff = 1
+ for i, cum_count in enumerate(cumulative_counts):
+ if cum_count > (1 - overlap) * max_tokens:
+ break
+ cutoff = i
+
+ return code_token_counts[cutoff:], current_source_code.strip()
+
+
+def get_source_code_chunks_from_code_part(
+ code_file_part: CodePart,
+ max_tokens: int = 8192,
+ overlap: float = 0.25,
+ granularity: float = 0.1,
+ model_name: str = "text-embedding-3-large"
+) -> Generator[SourceCodeChunk, None, None]:
+ """Yields source code chunks from a CodePart object, with configurable token limits and overlap."""
+ if not code_file_part.source_code:
+ logger.error(f"No source code in CodeFile {code_file_part.id}")
+ return
+
+ tokenizer = tiktoken.encoding_for_model(model_name)
+ max_subchunk_tokens = max(1, int(granularity * max_tokens))
+ subchunk_token_counts = _get_subchunk_token_counts(tokenizer, code_file_part.source_code, max_subchunk_tokens)
+
+ previous_chunk = None
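+    # Chunks are linked via previous_chunk, so the original order of overlapping
+    # chunks can be reconstructed downstream.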
+ while subchunk_token_counts:
+ subchunk_token_counts, chunk_source_code = _get_chunk_source_code(subchunk_token_counts, overlap, max_tokens)
+ if not chunk_source_code:
+ continue
+ current_chunk = SourceCodeChunk(
+ id=uuid5(NAMESPACE_OID, chunk_source_code),
+ code_chunk_of=code_file_part,
+ source_code=chunk_source_code,
+ previous_chunk=previous_chunk
+ )
+ yield current_chunk
+ previous_chunk = current_chunk
+
+
+async def get_source_code_chunks(data_points: list[DataPoint], embedding_model="text-embedding-3-large") -> \
+ AsyncGenerator[list[DataPoint], None]:
+ """Processes code graph datapoints, create SourceCodeChink datapoints."""
+ # TODO: Add support for other embedding models, with max_token mapping
+ for data_point in data_points:
+ try:
+ yield data_point
+ if not isinstance(data_point, CodeFile):
+ continue
+ if not data_point.contains:
+ logger.warning(f"CodeFile {data_point.id} contains no code parts")
+ continue
+ for code_part in data_point.contains:
+ try:
+ yield code_part
+ for source_code_chunk in get_source_code_chunks_from_code_part(code_part, model_name=embedding_model):
+ yield source_code_chunk
+ except Exception as e:
+ logger.error(f"Error processing code part: {e}")
+ except Exception as e:
+ logger.error(f"Error processing data point: {e}")
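A quick, hedged way to exercise the chunker standalone (illustrative values; `CodePart` is assumed to default its `id` via `DataPoint`):

```python
import tiktoken

from cognee.shared.CodeGraphEntities import CodePart
from cognee.tasks.repo_processor.get_source_code_chunks import get_source_code_chunks_from_code_part

source = "\n".join(f"def f{i}():\n    return {i}" for i in range(200))
part = CodePart(source_code=source)

tokenizer = tiktoken.encoding_for_model("text-embedding-3-large")
for chunk in get_source_code_chunks_from_code_part(part, max_tokens=128, overlap=0.25):
    # Each chunk should stay roughly within max_tokens; neighbors share ~25% of tokens.
    print(len(tokenizer.encode(chunk.source_code)), repr(chunk.source_code[:40]))
```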
diff --git a/cognee/tasks/storage/index_data_points.py b/cognee/tasks/storage/index_data_points.py
index 857e4d777..12af2d2ef 100644
--- a/cognee/tasks/storage/index_data_points.py
+++ b/cognee/tasks/storage/index_data_points.py
@@ -1,6 +1,10 @@
+import logging
+
+from cognee.infrastructure.databases.exceptions.EmbeddingException import EmbeddingException
from cognee.infrastructure.databases.vector import get_vector_engine
from cognee.infrastructure.engine import DataPoint
+logger = logging.getLogger("index_data_points")
async def index_data_points(data_points: list[DataPoint]):
created_indexes = {}
@@ -30,7 +34,10 @@ async def index_data_points(data_points: list[DataPoint]):
for index_name, indexable_points in index_points.items():
index_name, field_name = index_name.split(".")
- await vector_engine.index_data_points(index_name, field_name, indexable_points)
+ try:
+ await vector_engine.index_data_points(index_name, field_name, indexable_points)
+ except EmbeddingException as e:
+ logger.warning(f"Failed to index data points for {index_name}.{field_name}: {e}")
return data_points
diff --git a/cognee/tasks/summarization/models.py b/cognee/tasks/summarization/models.py
index add448155..5b0345015 100644
--- a/cognee/tasks/summarization/models.py
+++ b/cognee/tasks/summarization/models.py
@@ -1,6 +1,8 @@
+from typing import Union
+
from cognee.infrastructure.engine import DataPoint
from cognee.modules.chunking.models import DocumentChunk
-from cognee.shared.CodeGraphEntities import CodeFile
+from cognee.shared.CodeGraphEntities import CodeFile, CodePart, SourceCodeChunk
class TextSummary(DataPoint):
@@ -17,7 +19,7 @@ class TextSummary(DataPoint):
class CodeSummary(DataPoint):
__tablename__ = "code_summary"
text: str
- made_from: CodeFile
+ summarizes: Union[CodeFile, CodePart, SourceCodeChunk]
_metadata: dict = {
"index_fields": ["text"],
diff --git a/cognee/tasks/summarization/summarize_code.py b/cognee/tasks/summarization/summarize_code.py
index b116e57a9..9efc5b6ca 100644
--- a/cognee/tasks/summarization/summarize_code.py
+++ b/cognee/tasks/summarization/summarize_code.py
@@ -1,10 +1,10 @@
import asyncio
from typing import AsyncGenerator, Union
from uuid import uuid5
-from typing import Type
from cognee.infrastructure.engine import DataPoint
from cognee.modules.data.extraction.extract_summary import extract_code_summary
+
from .models import CodeSummary
@@ -21,7 +21,7 @@ async def summarize_code(
)
file_summaries_map = {
- code_data_point.extracted_id: str(file_summary)
+ code_data_point.id: str(file_summary)
for code_data_point, file_summary in zip(code_data_points, file_summaries)
}
@@ -35,6 +35,6 @@ async def summarize_code(
yield CodeSummary(
id=uuid5(node.id, "CodeSummary"),
- made_from=node,
- text=file_summaries_map[node.extracted_id],
+ summarizes=node,
+ text=file_summaries_map[node.id],
)
diff --git a/examples/python/code_graph_example.py b/examples/python/code_graph_example.py
index c0b91972b..44ab33aad 100644
--- a/examples/python/code_graph_example.py
+++ b/examples/python/code_graph_example.py
@@ -11,6 +11,6 @@ async def main(repo_path, include_docs):
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--repo_path", type=str, required=True, help="Path to the repository")
- parser.add_argument("--include_docs", type=bool, default=True, help="Whether or not to process non-code files")
+ parser.add_argument("--include_docs", type=lambda x: x.lower() in ("true", "1"), default=True, help="Whether or not to process non-code files")
args = parser.parse_args()
asyncio.run(main(args.repo_path, args.include_docs))
\ No newline at end of file
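The `--include_docs` fix addresses a classic argparse pitfall: `type=bool` calls `bool()` on the raw string, and any non-empty string, including `"False"`, is truthy, so the flag could never actually be turned off. A minimal demonstration:

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--old", type=bool, default=True)
parser.add_argument("--new", type=lambda x: x.lower() in ("true", "1"), default=True)

args = parser.parse_args(["--old", "False", "--new", "False"])
print(args.old)  # True  -- bool("False") is truthy, so the old flag was broken
print(args.new)  # False -- the lambda parses the string as intended
```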