From 9e7ab6492a87f18126ccc9ac5a76219c78a19003 Mon Sep 17 00:00:00 2001
From: hajdul88 <52442977+hajdul88@users.noreply.github.com>
Date: Tue, 17 Dec 2024 11:31:31 +0100
Subject: [PATCH 1/3] =?UTF-8?q?feat:=20outsources=20chunking=20parameters?=
 =?UTF-8?q?=20to=20extract=20chunk=20from=20documents=20=E2=80=A6=20(#289)?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* feat: outsources chunking parameters to extract chunk from documents task
---
 .../processing/document_types/AudioDocument.py     |  7 ++++---
 .../processing/document_types/ChunkerMapping.py    | 15 +++++++++++++++
 .../data/processing/document_types/Document.py     |  4 ++--
 .../processing/document_types/ImageDocument.py     |  7 ++++---
 .../data/processing/document_types/PdfDocument.py  |  7 ++++---
 .../processing/document_types/TextDocument.py      |  8 +++++---
 .../documents/extract_chunks_from_documents.py     |  4 ++--
 .../integration/documents/AudioDocument_test.py    |  2 +-
 .../integration/documents/ImageDocument_test.py    |  2 +-
 .../integration/documents/PdfDocument_test.py      |  2 +-
 .../integration/documents/TextDocument_test.py     |  2 +-
 11 files changed, 40 insertions(+), 20 deletions(-)
 create mode 100644 cognee/modules/data/processing/document_types/ChunkerMapping.py

diff --git a/cognee/modules/data/processing/document_types/AudioDocument.py b/cognee/modules/data/processing/document_types/AudioDocument.py
index 0d2cddd3d..268338703 100644
--- a/cognee/modules/data/processing/document_types/AudioDocument.py
+++ b/cognee/modules/data/processing/document_types/AudioDocument.py
@@ -1,6 +1,6 @@
 from cognee.infrastructure.llm.get_llm_client import get_llm_client
-from cognee.modules.chunking.TextChunker import TextChunker
 from .Document import Document
+from .ChunkerMapping import ChunkerConfig
 
 class AudioDocument(Document):
     type: str = "audio"
@@ -9,11 +9,12 @@ class AudioDocument(Document):
         result = get_llm_client().create_transcript(self.raw_data_location)
         return(result.text)
 
-    def read(self, chunk_size: int):
+    def read(self, chunk_size: int, chunker: str):
         # Transcribe the audio file
         text = self.create_transcript()
 
-        chunker = TextChunker(self, chunk_size = chunk_size, get_text = lambda: [text])
+        chunker_func = ChunkerConfig.get_chunker(chunker)
+        chunker = chunker_func(self, chunk_size = chunk_size, get_text = lambda: [text])
 
         yield from chunker.read()
 
diff --git a/cognee/modules/data/processing/document_types/ChunkerMapping.py b/cognee/modules/data/processing/document_types/ChunkerMapping.py
new file mode 100644
index 000000000..14dbb8bb7
--- /dev/null
+++ b/cognee/modules/data/processing/document_types/ChunkerMapping.py
@@ -0,0 +1,15 @@
+from cognee.modules.chunking.TextChunker import TextChunker
+
+class ChunkerConfig:
+    chunker_mapping = {
+        "text_chunker": TextChunker
+    }
+
+    @classmethod
+    def get_chunker(cls, chunker_name: str):
+        chunker_class = cls.chunker_mapping.get(chunker_name)
+        if chunker_class is None:
+            raise NotImplementedError(
+                f"Chunker '{chunker_name}' is not implemented. Available options: {list(cls.chunker_mapping.keys())}"
+            )
+        return chunker_class
\ No newline at end of file
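
A minimal sketch of how the new registry is meant to be consumed; the lookup calls mirror the code above, while the custom-chunker registration at the end is hypothetical and not part of this patch:

    from cognee.modules.data.processing.document_types.ChunkerMapping import ChunkerConfig

    # Resolve the only built-in chunker by its registered name.
    chunker_class = ChunkerConfig.get_chunker("text_chunker")

    # Unknown names fail fast and report what is available.
    try:
        ChunkerConfig.get_chunker("semantic_chunker")  # hypothetical name
    except NotImplementedError as error:
        print(error)  # lists ['text_chunker']

    # A project-specific chunker could be added with one line (MyChunker is hypothetical):
    # ChunkerConfig.chunker_mapping["my_chunker"] = MyChunker
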
diff --git a/cognee/modules/data/processing/document_types/Document.py b/cognee/modules/data/processing/document_types/Document.py
index 924ffabac..8d6a3dafb 100644
--- a/cognee/modules/data/processing/document_types/Document.py
+++ b/cognee/modules/data/processing/document_types/Document.py
@@ -13,5 +13,5 @@ class Document(DataPoint):
         "type": "Document"
     }
 
-    def read(self, chunk_size: int) -> str:
-        pass
\ No newline at end of file
+    def read(self, chunk_size: int, chunker: str) -> str:
+        pass

diff --git a/cognee/modules/data/processing/document_types/ImageDocument.py b/cognee/modules/data/processing/document_types/ImageDocument.py
index e8f0dd8ee..352486bd8 100644
--- a/cognee/modules/data/processing/document_types/ImageDocument.py
+++ b/cognee/modules/data/processing/document_types/ImageDocument.py
@@ -1,6 +1,6 @@
 from cognee.infrastructure.llm.get_llm_client import get_llm_client
-from cognee.modules.chunking.TextChunker import TextChunker
 from .Document import Document
+from .ChunkerMapping import ChunkerConfig
 
 class ImageDocument(Document):
     type: str = "image"
@@ -10,10 +10,11 @@ class ImageDocument(Document):
         result = get_llm_client().transcribe_image(self.raw_data_location)
         return(result.choices[0].message.content)
 
-    def read(self, chunk_size: int):
+    def read(self, chunk_size: int, chunker: str):
         # Transcribe the image file
         text = self.transcribe_image()
 
-        chunker = TextChunker(self, chunk_size = chunk_size, get_text = lambda: [text])
+        chunker_func = ChunkerConfig.get_chunker(chunker)
+        chunker = chunker_func(self, chunk_size = chunk_size, get_text = lambda: [text])
 
         yield from chunker.read()
 
diff --git a/cognee/modules/data/processing/document_types/PdfDocument.py b/cognee/modules/data/processing/document_types/PdfDocument.py
index 2d1941996..361214718 100644
--- a/cognee/modules/data/processing/document_types/PdfDocument.py
+++ b/cognee/modules/data/processing/document_types/PdfDocument.py
@@ -1,11 +1,11 @@
 from pypdf import PdfReader
-from cognee.modules.chunking.TextChunker import TextChunker
 from .Document import Document
+from .ChunkerMapping import ChunkerConfig
 
 class PdfDocument(Document):
     type: str = "pdf"
 
-    def read(self, chunk_size: int):
+    def read(self, chunk_size: int, chunker: str):
         file = PdfReader(self.raw_data_location)
 
         def get_text():
@@ -13,7 +13,8 @@ class PdfDocument(Document):
             page_text = page.extract_text()
             yield page_text
 
-        chunker = TextChunker(self, chunk_size = chunk_size, get_text = get_text)
+        chunker_func = ChunkerConfig.get_chunker(chunker)
+        chunker = chunker_func(self, chunk_size = chunk_size, get_text = get_text)
 
         yield from chunker.read()
 
diff --git a/cognee/modules/data/processing/document_types/TextDocument.py b/cognee/modules/data/processing/document_types/TextDocument.py
index 32d3416b9..3952d9845 100644
--- a/cognee/modules/data/processing/document_types/TextDocument.py
+++ b/cognee/modules/data/processing/document_types/TextDocument.py
@@ -1,10 +1,10 @@
-from cognee.modules.chunking.TextChunker import TextChunker
 from .Document import Document
+from .ChunkerMapping import ChunkerConfig
 
 class TextDocument(Document):
     type: str = "text"
 
-    def read(self, chunk_size: int):
+    def read(self, chunk_size: int, chunker: str):
         def get_text():
             with open(self.raw_data_location, mode = "r", encoding = "utf-8") as file:
                 while True:
@@ -15,6 +15,8 @@ class TextDocument(Document):
 
                 yield text
 
-        chunker = TextChunker(self, chunk_size = chunk_size, get_text = get_text)
+        chunker_func = ChunkerConfig.get_chunker(chunker)
+
+        chunker = chunker_func(self, chunk_size = chunk_size, get_text = get_text)
 
         yield from chunker.read()
 
diff --git a/cognee/tasks/documents/extract_chunks_from_documents.py b/cognee/tasks/documents/extract_chunks_from_documents.py
index ec19a786d..423b87b69 100644
--- a/cognee/tasks/documents/extract_chunks_from_documents.py
+++ b/cognee/tasks/documents/extract_chunks_from_documents.py
@@ -1,7 +1,7 @@
 from cognee.modules.data.processing.document_types.Document import Document
 
 
-async def extract_chunks_from_documents(documents: list[Document], chunk_size: int = 1024):
+async def extract_chunks_from_documents(documents: list[Document], chunk_size: int = 1024, chunker: str = 'text_chunker'):
     for document in documents:
-        for document_chunk in document.read(chunk_size = chunk_size):
+        for document_chunk in document.read(chunk_size = chunk_size, chunker = chunker):
             yield document_chunk

diff --git a/cognee/tests/integration/documents/AudioDocument_test.py b/cognee/tests/integration/documents/AudioDocument_test.py
index da8b85d0b..151f4c0b2 100644
--- a/cognee/tests/integration/documents/AudioDocument_test.py
+++ b/cognee/tests/integration/documents/AudioDocument_test.py
@@ -31,7 +31,7 @@ def test_AudioDocument():
     )
     with patch.object(AudioDocument, "create_transcript", return_value=TEST_TEXT):
         for ground_truth, paragraph_data in zip(
-            GROUND_TRUTH, document.read(chunk_size=64)
+            GROUND_TRUTH, document.read(chunk_size=64, chunker='text_chunker')
         ):
             assert (
                 ground_truth["word_count"] == paragraph_data.word_count

diff --git a/cognee/tests/integration/documents/ImageDocument_test.py b/cognee/tests/integration/documents/ImageDocument_test.py
index 8a8ee8ef3..40e0155af 100644
--- a/cognee/tests/integration/documents/ImageDocument_test.py
+++ b/cognee/tests/integration/documents/ImageDocument_test.py
@@ -21,7 +21,7 @@ def test_ImageDocument():
 
     with patch.object(ImageDocument, "transcribe_image", return_value=TEST_TEXT):
         for ground_truth, paragraph_data in zip(
-            GROUND_TRUTH, document.read(chunk_size=64)
+            GROUND_TRUTH, document.read(chunk_size=64, chunker='text_chunker')
         ):
             assert (
                 ground_truth["word_count"] == paragraph_data.word_count

diff --git a/cognee/tests/integration/documents/PdfDocument_test.py b/cognee/tests/integration/documents/PdfDocument_test.py
index ac57eaf75..25d4cf6c6 100644
--- a/cognee/tests/integration/documents/PdfDocument_test.py
+++ b/cognee/tests/integration/documents/PdfDocument_test.py
@@ -22,7 +22,7 @@ def test_PdfDocument():
     )
 
     for ground_truth, paragraph_data in zip(
-        GROUND_TRUTH, document.read(chunk_size=1024)
+        GROUND_TRUTH, document.read(chunk_size=1024, chunker='text_chunker')
     ):
         assert (
             ground_truth["word_count"] == paragraph_data.word_count

diff --git a/cognee/tests/integration/documents/TextDocument_test.py b/cognee/tests/integration/documents/TextDocument_test.py
index f663418f5..91f38968e 100644
--- a/cognee/tests/integration/documents/TextDocument_test.py
+++ b/cognee/tests/integration/documents/TextDocument_test.py
@@ -33,7 +33,7 @@ def test_TextDocument(input_file, chunk_size):
     )
 
     for ground_truth, paragraph_data in zip(
-        GROUND_TRUTH[input_file], document.read(chunk_size=chunk_size)
+        GROUND_TRUTH[input_file], document.read(chunk_size=chunk_size, chunker='text_chunker')
    ):
         assert (
             ground_truth["word_count"] == paragraph_data.word_count
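
For orientation, a driver for the updated task might look like the sketch below; the documents list and the surrounding event loop are illustrative, not part of this patch:

    import asyncio

    from cognee.modules.data.processing.document_types.Document import Document
    from cognee.tasks.documents import extract_chunks_from_documents

    async def dump_chunks(documents: list[Document]):
        # chunker selects a registered chunker by name; 'text_chunker' is the only built-in.
        async for chunk in extract_chunks_from_documents(documents, chunk_size=1024, chunker='text_chunker'):
            print(chunk)

    # asyncio.run(dump_chunks(documents))  # documents prepared by earlier pipeline tasks
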
From da5e3ab24de6c07f583dec15cd67f6dbcec25e8e Mon Sep 17 00:00:00 2001
From: lxobr <122801072+lxobr@users.noreply.github.com>
Date: Tue, 17 Dec 2024 12:02:25 +0100
Subject: [PATCH 2/3] COG 870 Remove duplicate edges from the code graph (#293)

* feat: turn summarize_code into generator

* feat: extract run_code_graph_pipeline, update the pipeline

* feat: minimal code graph example

* refactor: update argument

* refactor: move run_code_graph_pipeline to cognify/code_graph_pipeline

* refactor: indentation and whitespace nits

* refactor: add deprecated use comments and warnings

---------

Co-authored-by: Vasilije <8619304+Vasilije1990@users.noreply.github.com>
Co-authored-by: Igor Ilic <30923996+dexters1@users.noreply.github.com>
Co-authored-by: Boris
---
 cognee/api/v1/cognify/code_graph_pipeline.py | 36 +++++++++++++++
 cognee/tasks/summarization/summarize_code.py | 40 +++++++++++--------
 evals/eval_swe_bench.py                      | 47 +++-----------------
 examples/python/code_graph_example.py        | 15 +++++++
 4 files changed, 80 insertions(+), 58 deletions(-)
 create mode 100644 examples/python/code_graph_example.py

diff --git a/cognee/api/v1/cognify/code_graph_pipeline.py b/cognee/api/v1/cognify/code_graph_pipeline.py
index 59c658300..3c72e0793 100644
--- a/cognee/api/v1/cognify/code_graph_pipeline.py
+++ b/cognee/api/v1/cognify/code_graph_pipeline.py
@@ -1,8 +1,14 @@
+# NOTICE: This module contains deprecated functions.
+# Use only the run_code_graph_pipeline function; all other functions are deprecated.
+# Related issue: COG-906
+
 import asyncio
 import logging
+from pathlib import Path
 from typing import Union
 
 from cognee.shared.SourceCodeGraph import SourceCodeGraph
+from cognee.shared.data_models import SummarizedContent
 from cognee.shared.utils import send_telemetry
 from cognee.modules.data.models import Dataset, Data
 from cognee.modules.data.methods.get_dataset_data import get_dataset_data
@@ -16,7 +22,9 @@ from cognee.modules.pipelines.operations.get_pipeline_status import get_pipeline
 from cognee.modules.pipelines.operations.log_pipeline_status import log_pipeline_status
 from cognee.tasks.documents import classify_documents, check_permissions_on_documents, extract_chunks_from_documents
 from cognee.tasks.graph import extract_graph_from_code
+from cognee.tasks.repo_processor import get_repo_file_dependencies, enrich_dependency_graph, expand_dependency_graph
 from cognee.tasks.storage import add_data_points
+from cognee.tasks.summarization import summarize_code
 
 logger = logging.getLogger("code_graph_pipeline")
 
@@ -51,6 +59,7 @@ async def code_graph_pipeline(datasets: Union[str, list[str]] = None, user: User
 
 
 async def run_pipeline(dataset: Dataset, user: User):
+    '''DEPRECATED: Use `run_code_graph_pipeline` instead. This function will be removed.'''
     data_documents: list[Data] = await get_dataset_data(dataset_id = dataset.id)
 
     document_ids_str = [str(document.id) for document in data_documents]
@@ -103,3 +112,30 @@ async def run_pipeline(dataset: Dataset, user: User):
 
 def generate_dataset_name(dataset_name: str) -> str:
     return dataset_name.replace(".", "_").replace(" ", "_")
+
+
+async def run_code_graph_pipeline(repo_path):
+    import os
+    import pathlib
+    import cognee
+    from cognee.infrastructure.databases.relational import create_db_and_tables
+
+    file_path = Path(__file__).parent
+    data_directory_path = str(pathlib.Path(os.path.join(file_path, ".data_storage/code_graph")).resolve())
+    cognee.config.data_root_directory(data_directory_path)
+    cognee_directory_path = str(pathlib.Path(os.path.join(file_path, ".cognee_system/code_graph")).resolve())
+    cognee.config.system_root_directory(cognee_directory_path)
+
+    await cognee.prune.prune_data()
+    await cognee.prune.prune_system(metadata=True)
+    await create_db_and_tables()
+
+    tasks = [
+        Task(get_repo_file_dependencies),
+        Task(enrich_dependency_graph, task_config={"batch_size": 50}),
+        Task(expand_dependency_graph, task_config={"batch_size": 50}),
+        Task(summarize_code, summarization_model=SummarizedContent, task_config={"batch_size": 50}),
+        Task(add_data_points, task_config={"batch_size": 50}),
+    ]
+
+    return run_tasks(tasks, repo_path, "cognify_code_pipeline")
diff --git a/cognee/tasks/summarization/summarize_code.py b/cognee/tasks/summarization/summarize_code.py
index 277081f40..76435186c 100644
--- a/cognee/tasks/summarization/summarize_code.py
+++ b/cognee/tasks/summarization/summarize_code.py
@@ -1,39 +1,43 @@
 import asyncio
-from typing import Type
 from uuid import uuid5
+from typing import Type
 
 from pydantic import BaseModel
 
 from cognee.infrastructure.engine import DataPoint
 from cognee.modules.data.extraction.extract_summary import extract_summary
 from cognee.shared.CodeGraphEntities import CodeFile
-from cognee.tasks.storage import add_data_points
-
 from .models import CodeSummary
 
 
 async def summarize_code(
-    code_files: list[DataPoint],
+    code_graph_nodes: list[DataPoint],
     summarization_model: Type[BaseModel],
 ) -> list[DataPoint]:
-    if len(code_files) == 0:
-        return code_files
+    if len(code_graph_nodes) == 0:
+        return
 
-    code_files_data_points = [file for file in code_files if isinstance(file, CodeFile)]
+    code_files_data_points = [file for file in code_graph_nodes if isinstance(file, CodeFile)]
 
     file_summaries = await asyncio.gather(
         *[extract_summary(file.source_code, summarization_model) for file in code_files_data_points]
     )
 
-    summaries = [
-        CodeSummary(
-            id = uuid5(file.id, "CodeSummary"),
-            made_from = file,
-            text = file_summaries[file_index].summary,
+    file_summaries_map = {
+        code_file_data_point.extracted_id: file_summary.summary
+        for code_file_data_point, file_summary in zip(code_files_data_points, file_summaries)
+    }
+
+    for node in code_graph_nodes:
+        if not isinstance(node, DataPoint):
+            continue
+        yield node
+
+        if not isinstance(node, CodeFile):
+            continue
+
+        yield CodeSummary(
+            id=uuid5(node.id, "CodeSummary"),
+            made_from=node,
+            text=file_summaries_map[node.extracted_id],
         )
-        for (file_index, file) in enumerate(code_files_data_points)
-    ]
-
-    await add_data_points(summaries)
-
-    return code_files
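
Since the task is now an async generator that interleaves pass-through nodes with freshly built CodeSummary points, a downstream consumer can split the stream; a rough sketch under the signature above (the nodes list is illustrative):

    from cognee.shared.data_models import SummarizedContent
    from cognee.tasks.summarization import summarize_code
    from cognee.tasks.summarization.models import CodeSummary

    async def split_stream(nodes):
        summaries, passthrough = [], []
        async for data_point in summarize_code(nodes, summarization_model=SummarizedContent):
            # CodeSummary points are newly minted; everything else is forwarded input.
            (summaries if isinstance(data_point, CodeSummary) else passthrough).append(data_point)
        return summaries, passthrough
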
diff --git a/evals/eval_swe_bench.py b/evals/eval_swe_bench.py
index 67826fc12..6c2280d80 100644
--- a/evals/eval_swe_bench.py
+++ b/evals/eval_swe_bench.py
@@ -7,19 +7,13 @@
 from pathlib import Path
 
 from swebench.harness.utils import load_swebench_dataset
 from swebench.inference.make_datasets.create_instance import PATCH_EXAMPLE
 
+from cognee.api.v1.cognify.code_graph_pipeline import run_code_graph_pipeline
 from cognee.api.v1.search import SearchType
 from cognee.infrastructure.llm.get_llm_client import get_llm_client
 from cognee.infrastructure.llm.prompts import read_query_prompt
-from cognee.modules.pipelines import Task, run_tasks
 from cognee.modules.retrieval.brute_force_triplet_search import \
     brute_force_triplet_search
-# from cognee.shared.data_models import SummarizedContent
 from cognee.shared.utils import render_graph
-from cognee.tasks.repo_processor import (enrich_dependency_graph,
-                                         expand_dependency_graph,
-                                         get_repo_file_dependencies)
-from cognee.tasks.storage import add_data_points
-# from cognee.tasks.summarization import summarize_code
 from evals.eval_utils import download_github_repo, retrieved_edges_to_string
 
 
@@ -42,48 +36,22 @@ def check_install_package(package_name):
 
 
 async def generate_patch_with_cognee(instance, llm_client, search_type=SearchType.CHUNKS):
-    import os
-    import pathlib
-    import cognee
-    from cognee.infrastructure.databases.relational import create_db_and_tables
-
-    file_path = Path(__file__).parent
-    data_directory_path = str(pathlib.Path(os.path.join(file_path, ".data_storage/code_graph")).resolve())
-    cognee.config.data_root_directory(data_directory_path)
-    cognee_directory_path = str(pathlib.Path(os.path.join(file_path, ".cognee_system/code_graph")).resolve())
-    cognee.config.system_root_directory(cognee_directory_path)
-
-    await cognee.prune.prune_data()
-    await cognee.prune.prune_system(metadata = True)
-
-    await create_db_and_tables()
-
-    # repo_path = download_github_repo(instance, '../RAW_GIT_REPOS')
-
-    repo_path = '/Users/borisarzentar/Projects/graphrag'
-
-    tasks = [
-        Task(get_repo_file_dependencies),
-        Task(enrich_dependency_graph, task_config = { "batch_size": 50 }),
-        Task(expand_dependency_graph, task_config = { "batch_size": 50 }),
-        Task(add_data_points, task_config = { "batch_size": 50 }),
-        # Task(summarize_code, summarization_model = SummarizedContent),
-    ]
-
-    pipeline = run_tasks(tasks, repo_path, "cognify_code_pipeline")
+    repo_path = download_github_repo(instance, '../RAW_GIT_REPOS')
 
+    pipeline = await run_code_graph_pipeline(repo_path)
     async for result in pipeline:
         print(result)
     print('Here we have the repo under the repo_path')
 
-    await render_graph(None, include_labels = True, include_nodes = True)
+    await render_graph(None, include_labels=True, include_nodes=True)
 
     problem_statement = instance['problem_statement']
     instructions = read_query_prompt("patch_gen_kg_instructions.txt")
 
-    retrieved_edges = await brute_force_triplet_search(problem_statement, top_k = 3, collections = ["data_point_source_code", "data_point_text"])
-
+    retrieved_edges = await brute_force_triplet_search(problem_statement, top_k=3,
+                                                       collections=["data_point_source_code", "data_point_text"])
+
     retrieved_edges_str = retrieved_edges_to_string(retrieved_edges)
 
     prompt = "\n".join([
@@ -171,7 +139,6 @@ async def main():
 
     with open(predictions_path, "w") as file:
         json.dump(preds, file)
-
     subprocess.run(
         [
             "python",
diff --git a/examples/python/code_graph_example.py b/examples/python/code_graph_example.py
new file mode 100644
index 000000000..9189de46c
--- /dev/null
+++ b/examples/python/code_graph_example.py
@@ -0,0 +1,15 @@
+import argparse
+import asyncio
+from cognee.api.v1.cognify.code_graph_pipeline import run_code_graph_pipeline
+
+
+async def main(repo_path):
+    async for result in await run_code_graph_pipeline(repo_path):
+        print(result)
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--repo-path", type=str, required=True, help="Path to the repository")
+    args = parser.parse_args()
+    asyncio.run(main(args.repo_path))
+

From 9afd0ece63bdcda318afc498d0643fb649ac1302 Mon Sep 17 00:00:00 2001
From: alekszievr <44192193+alekszievr@users.noreply.github.com>
Date: Tue, 17 Dec 2024 13:05:47 +0100
Subject: [PATCH 3/3] Structured code summarization (#375)

* feat: turn summarize_code into generator

* feat: extract run_code_graph_pipeline, update the pipeline

* feat: minimal code graph example

* refactor: update argument

* refactor: move run_code_graph_pipeline to cognify/code_graph_pipeline

* refactor: indentation and whitespace nits

* refactor: add deprecated use comments and warnings

* Structured code summarization

* add missing prompt file

* Remove summarization_model argument from summarize_code and fix typehinting

* minor refactors

---------

Co-authored-by: lxobr <122801072+lxobr@users.noreply.github.com>
Co-authored-by: Vasilije <8619304+Vasilije1990@users.noreply.github.com>
Co-authored-by: Igor Ilic <30923996+dexters1@users.noreply.github.com>
Co-authored-by: Boris
---
 cognee/api/v1/cognify/code_graph_pipeline.py | 31 +++++++++++--------
 .../llm/prompts/summarize_code.txt           | 10 ++++++
 .../data/extraction/extract_summary.py       | 10 +++++-
 cognee/shared/data_models.py                 | 27 +++++++++++++++-
 cognee/tasks/summarization/summarize_code.py | 19 +++++------
 5 files changed, 71 insertions(+), 26 deletions(-)
 create mode 100644 cognee/infrastructure/llm/prompts/summarize_code.txt

diff --git a/cognee/api/v1/cognify/code_graph_pipeline.py b/cognee/api/v1/cognify/code_graph_pipeline.py
index 3c72e0793..eeb10d69e 100644
--- a/cognee/api/v1/cognify/code_graph_pipeline.py
+++ b/cognee/api/v1/cognify/code_graph_pipeline.py
@@ -7,22 +7,27 @@ import logging
 from pathlib import Path
 from typing import Union
 
-from cognee.shared.SourceCodeGraph import SourceCodeGraph
-from cognee.shared.data_models import SummarizedContent
-from cognee.shared.utils import send_telemetry
-from cognee.modules.data.models import Dataset, Data
-from cognee.modules.data.methods.get_dataset_data import get_dataset_data
 from cognee.modules.data.methods import get_datasets, get_datasets_by_name
-from cognee.modules.pipelines.tasks.Task import Task
+from cognee.modules.data.methods.get_dataset_data import get_dataset_data
+from cognee.modules.data.models import Data, Dataset
 from cognee.modules.pipelines import run_tasks
-from cognee.modules.users.models import User
-from cognee.modules.users.methods import get_default_user
 from cognee.modules.pipelines.models import PipelineRunStatus
-from cognee.modules.pipelines.operations.get_pipeline_status import get_pipeline_status
-from cognee.modules.pipelines.operations.log_pipeline_status import log_pipeline_status
-from cognee.tasks.documents import classify_documents, check_permissions_on_documents, extract_chunks_from_documents
+from cognee.modules.pipelines.operations.get_pipeline_status import \
+    get_pipeline_status
+from cognee.modules.pipelines.operations.log_pipeline_status import \
+    log_pipeline_status
+from cognee.modules.pipelines.tasks.Task import Task
+from cognee.modules.users.methods import get_default_user
+from cognee.modules.users.models import User
+from cognee.shared.SourceCodeGraph import SourceCodeGraph
+from cognee.shared.utils import send_telemetry
+from cognee.tasks.documents import (check_permissions_on_documents,
+                                    classify_documents,
+                                    extract_chunks_from_documents)
 from cognee.tasks.graph import extract_graph_from_code
-from cognee.tasks.repo_processor import get_repo_file_dependencies, enrich_dependency_graph, expand_dependency_graph
+from cognee.tasks.repo_processor import (enrich_dependency_graph,
+                                         expand_dependency_graph,
+                                         get_repo_file_dependencies)
 from cognee.tasks.storage import add_data_points
 from cognee.tasks.summarization import summarize_code
 
@@ -134,7 +139,7 @@ async def run_code_graph_pipeline(repo_path):
         Task(get_repo_file_dependencies),
         Task(enrich_dependency_graph, task_config={"batch_size": 50}),
         Task(expand_dependency_graph, task_config={"batch_size": 50}),
-        Task(summarize_code, summarization_model=SummarizedContent, task_config={"batch_size": 50}),
+        Task(summarize_code, task_config={"batch_size": 50}),
         Task(add_data_points, task_config={"batch_size": 50}),
     ]
 
diff --git a/cognee/infrastructure/llm/prompts/summarize_code.txt b/cognee/infrastructure/llm/prompts/summarize_code.txt
new file mode 100644
index 000000000..405585617
--- /dev/null
+++ b/cognee/infrastructure/llm/prompts/summarize_code.txt
@@ -0,0 +1,10 @@
+You are an expert Python programmer and technical writer. Your task is to summarize the given Python code snippet or file.
+The code may contain multiple imports, classes, functions, constants and logic. Provide a clear, structured explanation of its components
+and their relationships.
+
+Instructions:
+Provide an overview: Start with a high-level summary of what the code does as a whole.
+Break it down: Summarize each class and function individually, explaining their purpose and how they interact.
+Describe the workflow: Outline how the classes and functions work together. Mention any control flow (e.g., main functions, entry points, loops).
+Key features: Highlight important elements like arguments, return values, or unique logic.
+Maintain clarity: Write in plain English for someone familiar with Python but unfamiliar with this code.
\ No newline at end of file

diff --git a/cognee/modules/data/extraction/extract_summary.py b/cognee/modules/data/extraction/extract_summary.py
index a17bf3ae6..10d429da9 100644
--- a/cognee/modules/data/extraction/extract_summary.py
+++ b/cognee/modules/data/extraction/extract_summary.py
@@ -1,7 +1,11 @@
 from typing import Type
+
 from pydantic import BaseModel
-from cognee.infrastructure.llm.prompts import read_query_prompt
+
 from cognee.infrastructure.llm.get_llm_client import get_llm_client
+from cognee.infrastructure.llm.prompts import read_query_prompt
+from cognee.shared.data_models import SummarizedCode
+
 
 async def extract_summary(content: str, response_model: Type[BaseModel]):
     llm_client = get_llm_client()
@@ -11,3 +15,7 @@ async def extract_summary(content: str, response_model: Type[BaseModel]):
     llm_output = await llm_client.acreate_structured_output(content, system_prompt, response_model)
 
     return llm_output
+
+async def extract_code_summary(content: str):
+
+    return await extract_summary(content, response_model=SummarizedCode)

diff --git a/cognee/shared/data_models.py b/cognee/shared/data_models.py
index 6cb4d436a..dec53cfcb 100644
--- a/cognee/shared/data_models.py
+++ b/cognee/shared/data_models.py
@@ -1,9 +1,11 @@
 """Data models for the cognitive architecture."""
 from enum import Enum, auto
-from typing import Optional, List, Union, Dict, Any
+from typing import Any, Dict, List, Optional, Union
+
 from pydantic import BaseModel, Field
 
+
 class Node(BaseModel):
     """Node in a knowledge graph."""
     id: str
@@ -194,6 +196,29 @@ class SummarizedContent(BaseModel):
     summary: str
     description: str
 
+class SummarizedFunction(BaseModel):
+    name: str
+    description: str
+    inputs: Optional[List[str]] = None
+    outputs: Optional[List[str]] = None
+    decorators: Optional[List[str]] = None
+
+class SummarizedClass(BaseModel):
+    name: str
+    description: str
+    methods: Optional[List[SummarizedFunction]] = None
+    decorators: Optional[List[str]] = None
+
+class SummarizedCode(BaseModel):
+    file_name: str
+    high_level_summary: str
+    key_features: List[str]
+    imports: List[str] = []
+    constants: List[str] = []
+    classes: List[SummarizedClass] = []
+    functions: List[SummarizedFunction] = []
+    workflow_description: Optional[str] = None
+
 class GraphDBType(Enum):
     NETWORKX = auto()
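
A rough sketch of the structured extraction path added above; the source snippet is illustrative, and a configured LLM client is assumed:

    import asyncio

    from cognee.modules.data.extraction.extract_summary import extract_code_summary

    async def demo():
        source = "def add(a, b):\n    return a + b\n"
        summary = await extract_code_summary(source)  # returns a SummarizedCode instance
        print(summary.high_level_summary)
        for function in summary.functions:
            print(f"{function.name}: {function.description}")

    # asyncio.run(demo())
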
diff --git a/cognee/tasks/summarization/summarize_code.py b/cognee/tasks/summarization/summarize_code.py
index 76435186c..b116e57a9 100644
--- a/cognee/tasks/summarization/summarize_code.py
+++ b/cognee/tasks/summarization/summarize_code.py
@@ -1,31 +1,28 @@
 import asyncio
+from typing import AsyncGenerator, Union
 from uuid import uuid5
 from typing import Type
 
-from pydantic import BaseModel
-
 from cognee.infrastructure.engine import DataPoint
-from cognee.modules.data.extraction.extract_summary import extract_summary
-from cognee.shared.CodeGraphEntities import CodeFile
+from cognee.modules.data.extraction.extract_summary import extract_code_summary
 from .models import CodeSummary
 
 
 async def summarize_code(
     code_graph_nodes: list[DataPoint],
-    summarization_model: Type[BaseModel],
-) -> list[DataPoint]:
+) -> AsyncGenerator[Union[DataPoint, CodeSummary], None]:
     if len(code_graph_nodes) == 0:
         return
 
-    code_files_data_points = [file for file in code_graph_nodes if isinstance(file, CodeFile)]
+    code_data_points = [file for file in code_graph_nodes if hasattr(file, "source_code")]
 
     file_summaries = await asyncio.gather(
-        *[extract_summary(file.source_code, summarization_model) for file in code_files_data_points]
+        *[extract_code_summary(file.source_code) for file in code_data_points]
     )
 
     file_summaries_map = {
-        code_file_data_point.extracted_id: file_summary.summary
-        for code_file_data_point, file_summary in zip(code_files_data_points, file_summaries)
+        code_data_point.extracted_id: str(file_summary)
+        for code_data_point, file_summary in zip(code_data_points, file_summaries)
     }
 
     for node in code_graph_nodes:
@@ -33,7 +30,7 @@ async def summarize_code(
             continue
         yield node
 
-        if not isinstance(node, CodeFile):
+        if not hasattr(node, "source_code"):
             continue
 
         yield CodeSummary(
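
One property worth noting from the hunk above: uuid5 is deterministic, so re-running summarization maps a given node to the same CodeSummary id instead of minting duplicates, which the dedup goal of this series relies on. A small illustration (the namespace UUID stands in for a node's DataPoint id):

    from uuid import UUID, uuid5

    node_id = UUID("12345678-1234-5678-1234-567812345678")  # illustrative node id
    assert uuid5(node_id, "CodeSummary") == uuid5(node_id, "CodeSummary")
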