From d9e558e8857fcc260433fdeff5964a1ab9364462 Mon Sep 17 00:00:00 2001 From: Igor Ilic Date: Mon, 16 Dec 2024 11:02:50 +0100 Subject: [PATCH 1/8] fix: Resolve reflection issue when running cognee a second time after pruning data When running cognee a second time after pruning data some metadata doesn't get pruned. This makes cognee believe some tables exist that have been deleted Fix --- .../vector/pgvector/PGVectorAdapter.py | 20 +++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/cognee/infrastructure/databases/vector/pgvector/PGVectorAdapter.py b/cognee/infrastructure/databases/vector/pgvector/PGVectorAdapter.py index 8faf1cd6d..a6b458cbd 100644 --- a/cognee/infrastructure/databases/vector/pgvector/PGVectorAdapter.py +++ b/cognee/infrastructure/databases/vector/pgvector/PGVectorAdapter.py @@ -2,7 +2,7 @@ import asyncio from uuid import UUID from typing import List, Optional, get_type_hints from sqlalchemy.orm import Mapped, mapped_column -from sqlalchemy import JSON, Column, Table, select, delete +from sqlalchemy import JSON, Column, Table, select, delete, MetaData from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker from cognee.exceptions import InvalidValueError @@ -48,10 +48,12 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface): async def has_collection(self, collection_name: str) -> bool: async with self.engine.begin() as connection: - # Load the schema information into the MetaData object - await connection.run_sync(Base.metadata.reflect) + # Create a MetaData instance to load table information + metadata = MetaData() + # Load table information from schema into MetaData + await connection.run_sync(metadata.reflect) - if collection_name in Base.metadata.tables: + if collection_name in metadata.tables: return True else: return False @@ -145,10 +147,12 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface): with an async engine. """ async with self.engine.begin() as connection: - # Load the schema information into the MetaData object - await connection.run_sync(Base.metadata.reflect) - if collection_name in Base.metadata.tables: - return Base.metadata.tables[collection_name] + # Create a MetaData instance to load table information + metadata = MetaData() + # Load table information from schema into MetaData + await connection.run_sync(metadata.reflect) + if collection_name in metadata.tables: + return metadata.tables[collection_name] else: raise EntityNotFoundError(message=f"Table '{collection_name}' not found.") From 394a0b2dfb9645e58ed31835e8eaec7c90970358 Mon Sep 17 00:00:00 2001 From: Igor Ilic Date: Mon, 16 Dec 2024 11:26:33 +0100 Subject: [PATCH 2/8] fix: Add metadata reflection fix to sqlite as well Added fix when reflecting metadata to sqlite as well Fix --- .../relational/sqlalchemy/SqlAlchemyAdapter.py | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/cognee/infrastructure/databases/relational/sqlalchemy/SqlAlchemyAdapter.py b/cognee/infrastructure/databases/relational/sqlalchemy/SqlAlchemyAdapter.py index 8041aeaea..b1e608059 100644 --- a/cognee/infrastructure/databases/relational/sqlalchemy/SqlAlchemyAdapter.py +++ b/cognee/infrastructure/databases/relational/sqlalchemy/SqlAlchemyAdapter.py @@ -113,10 +113,12 @@ class SQLAlchemyAdapter(): """ async with self.engine.begin() as connection: if self.engine.dialect.name == "sqlite": - # Load the schema information into the MetaData object - await connection.run_sync(Base.metadata.reflect) - if table_name in Base.metadata.tables: - return Base.metadata.tables[table_name] + # Create a MetaData instance to load table information + metadata = MetaData() + # Load table information from schema into MetaData + await connection.run_sync(metadata.reflect) + if table_name in metadata.tables: + return metadata.tables[table_name] else: raise EntityNotFoundError(message=f"Table '{table_name}' not found.") else: @@ -138,8 +140,11 @@ class SQLAlchemyAdapter(): table_names = [] async with self.engine.begin() as connection: if self.engine.dialect.name == "sqlite": - await connection.run_sync(Base.metadata.reflect) - for table in Base.metadata.tables: + # Create a MetaData instance to load table information + metadata = MetaData() + # Load table information from schema into MetaData + await connection.run_sync(metadata.reflect) + for table in metadata.tables: table_names.append(str(table)) else: schema_list = await self.get_schema_list() From 34b139af2665ab5274de6484980e77cfda2985c5 Mon Sep 17 00:00:00 2001 From: Igor Ilic Date: Mon, 16 Dec 2024 13:19:21 +0100 Subject: [PATCH 3/8] Revert "fix: Add metadata reflection fix to sqlite as well" This reverts commit 394a0b2dfb9645e58ed31835e8eaec7c90970358. --- .../relational/sqlalchemy/SqlAlchemyAdapter.py | 17 ++++++----------- 1 file changed, 6 insertions(+), 11 deletions(-) diff --git a/cognee/infrastructure/databases/relational/sqlalchemy/SqlAlchemyAdapter.py b/cognee/infrastructure/databases/relational/sqlalchemy/SqlAlchemyAdapter.py index b1e608059..8041aeaea 100644 --- a/cognee/infrastructure/databases/relational/sqlalchemy/SqlAlchemyAdapter.py +++ b/cognee/infrastructure/databases/relational/sqlalchemy/SqlAlchemyAdapter.py @@ -113,12 +113,10 @@ class SQLAlchemyAdapter(): """ async with self.engine.begin() as connection: if self.engine.dialect.name == "sqlite": - # Create a MetaData instance to load table information - metadata = MetaData() - # Load table information from schema into MetaData - await connection.run_sync(metadata.reflect) - if table_name in metadata.tables: - return metadata.tables[table_name] + # Load the schema information into the MetaData object + await connection.run_sync(Base.metadata.reflect) + if table_name in Base.metadata.tables: + return Base.metadata.tables[table_name] else: raise EntityNotFoundError(message=f"Table '{table_name}' not found.") else: @@ -140,11 +138,8 @@ class SQLAlchemyAdapter(): table_names = [] async with self.engine.begin() as connection: if self.engine.dialect.name == "sqlite": - # Create a MetaData instance to load table information - metadata = MetaData() - # Load table information from schema into MetaData - await connection.run_sync(metadata.reflect) - for table in metadata.tables: + await connection.run_sync(Base.metadata.reflect) + for table in Base.metadata.tables: table_names.append(str(table)) else: schema_list = await self.get_schema_list() From 5360093097ded87f324a8b8813802cd29ec1bc60 Mon Sep 17 00:00:00 2001 From: lxobr <122801072+lxobr@users.noreply.github.com> Date: Mon, 16 Dec 2024 16:02:39 +0100 Subject: [PATCH 4/8] COG-810 Implement a top-down dependency graph builder tool (#268) * feat: parse repo to call graph * Update/repo_processor/top_down_repo_parse.py task * fix: minor improvements * feat: file parsing jedi script optimisation --------- --- .../repo_processor/top_down_repo_parse.py | 171 ++++++++++++++++++ 1 file changed, 171 insertions(+) create mode 100644 cognee/tasks/repo_processor/top_down_repo_parse.py diff --git a/cognee/tasks/repo_processor/top_down_repo_parse.py b/cognee/tasks/repo_processor/top_down_repo_parse.py new file mode 100644 index 000000000..52f58f811 --- /dev/null +++ b/cognee/tasks/repo_processor/top_down_repo_parse.py @@ -0,0 +1,171 @@ +import os + +import jedi +import parso +from tqdm import tqdm + +from . import logger + +_NODE_TYPE_MAP = { + "funcdef": "func_def", + "classdef": "class_def", + "async_funcdef": "async_func_def", + "async_stmt": "async_func_def", + "simple_stmt": "var_def", +} + +def _create_object_dict(name_node, type_name=None): + return { + "name": name_node.value, + "line": name_node.start_pos[0], + "column": name_node.start_pos[1], + "type": type_name, + } + + +def _parse_node(node): + """Parse a node to extract importable object details, including async functions and classes.""" + node_type = _NODE_TYPE_MAP.get(node.type) + + if node.type in {"funcdef", "classdef", "async_funcdef"}: + return [_create_object_dict(node.name, type_name=node_type)] + if node.type == "async_stmt" and len(node.children) > 1: + function_node = node.children[1] + if function_node.type == "funcdef": + return [_create_object_dict(function_node.name, type_name=_NODE_TYPE_MAP.get(function_node.type))] + if node.type == "simple_stmt": + # TODO: Handle multi-level/nested unpacking variable definitions in the future + expr_child = node.children[0] + if expr_child.type != "expr_stmt": + return [] + if expr_child.children[0].type == "testlist_star_expr": + name_targets = expr_child.children[0].children + else: + name_targets = expr_child.children + return [ + _create_object_dict(target, type_name=_NODE_TYPE_MAP.get(target.type)) + for target in name_targets + if target.type == "name" + ] + return [] + + + +def extract_importable_objects_with_positions_from_source_code(source_code): + """Extract top-level objects in a Python source code string with their positions (line/column).""" + try: + tree = parso.parse(source_code) + except Exception as e: + logger.error(f"Error parsing source code: {e}") + return [] + + importable_objects = [] + try: + for node in tree.children: + importable_objects.extend(_parse_node(node)) + except Exception as e: + logger.error(f"Error extracting nodes from parsed tree: {e}") + return [] + + return importable_objects + + +def extract_importable_objects_with_positions(file_path): + """Extract top-level objects in a Python file with their positions (line/column).""" + try: + with open(file_path, "r") as file: + source_code = file.read() + except Exception as e: + logger.error(f"Error reading file {file_path}: {e}") + return [] + + return extract_importable_objects_with_positions_from_source_code(source_code) + + + +def find_entity_usages(script, line, column): + """ + Return a list of files in the repo where the entity at module_path:line,column is used. + """ + usages = set() + + + try: + inferred = script.infer(line, column) + except Exception as e: + logger.error(f"Error inferring entity at {script.path}:{line},{column}: {e}") + return [] + + if not inferred or not inferred[0]: + logger.info(f"No entity inferred at {script.path}:{line},{column}") + return [] + + logger.debug(f"Inferred entity: {inferred[0].name}, type: {inferred[0].type}") + + try: + references = script.get_references(line=line, column=column, scope="project", include_builtins=False) + except Exception as e: + logger.error(f"Error retrieving references for entity at {script.path}:{line},{column}: {e}") + references = [] + + for ref in references: + if ref.module_path: # Collect unique module paths + usages.add(ref.module_path) + logger.info(f"Entity used in: {ref.module_path}") + + return list(usages) + +def parse_file_with_references(project, file_path): + """Parse a file to extract object names and their references within a project.""" + try: + importable_objects = extract_importable_objects_with_positions(file_path) + except Exception as e: + logger.error(f"Error extracting objects from {file_path}: {e}") + return [] + + if not os.path.isfile(file_path): + logger.warning(f"Module file does not exist: {file_path}") + return [] + + try: + script = jedi.Script(path=file_path, project=project) + except Exception as e: + logger.error(f"Error initializing Jedi Script: {e}") + return [] + + parsed_results = [ + { + "name": obj["name"], + "type": obj["type"], + "references": find_entity_usages(script, obj["line"], obj["column"]), + } + for obj in importable_objects + ] + return parsed_results + + +def parse_repo(repo_path): + """Parse a repository to extract object names, types, and references for all Python files.""" + try: + project = jedi.Project(path=repo_path) + except Exception as e: + logger.error(f"Error creating Jedi project for repository at {repo_path}: {e}") + return {} + + EXCLUDE_DIRS = {'venv', '.git', '__pycache__', 'build'} + + python_files = [ + os.path.join(directory, file) + for directory, _, filenames in os.walk(repo_path) + if not any(excluded in directory for excluded in EXCLUDE_DIRS) + for file in filenames + if file.endswith(".py") and os.path.getsize(os.path.join(directory, file)) > 0 + ] + + results = { + file_path: parse_file_with_references(project, file_path) + for file_path in tqdm(python_files) + } + + return results + From bfa0f06fb4421a5cbb09fd6e4a556db1905eea0d Mon Sep 17 00:00:00 2001 From: alekszievr <44192193+alekszievr@users.noreply.github.com> Date: Mon, 16 Dec 2024 16:27:03 +0100 Subject: [PATCH 5/8] Add type to DataPoint metadata (#364) * Add type to DataPoint metadata * Add missing index_fields * Use DataPoint UUID type in pgvector create_data_points * Make _metadata mandatory everywhere --- .../hybrid/falkordb/FalkorDBAdapter.py | 15 ++++--- .../vector/lancedb/LanceDBAdapter.py | 16 ++++--- .../databases/vector/milvus/MilvusAdapter.py | 9 ++-- .../vector/pgvector/PGVectorAdapter.py | 22 ++++++---- .../databases/vector/qdrant/QDrantAdapter.py | 12 ++++-- .../vector/weaviate_db/WeaviateAdapter.py | 8 ++-- .../infrastructure/engine/models/DataPoint.py | 13 +++--- .../modules/chunking/models/DocumentChunk.py | 5 ++- .../processing/document_types/Document.py | 11 +++-- cognee/modules/engine/models/Entity.py | 1 + cognee/modules/engine/models/EntityType.py | 3 +- cognee/modules/graph/models/EdgeType.py | 5 ++- .../graph/utils/convert_node_to_data_point.py | 2 +- cognee/shared/CodeGraphEntities.py | 21 ++++++---- cognee/shared/SourceCodeGraph.py | 42 +++++++++++++------ cognee/tasks/storage/index_data_points.py | 12 +++++- cognee/tasks/summarization/models.py | 2 + .../graph/get_graph_from_huge_model_test.py | 14 ++++++- .../graph/get_graph_from_model_test.py | 19 ++++++++- 19 files changed, 167 insertions(+), 65 deletions(-) diff --git a/cognee/infrastructure/databases/hybrid/falkordb/FalkorDBAdapter.py b/cognee/infrastructure/databases/hybrid/falkordb/FalkorDBAdapter.py index fdc7db069..324ee7bcd 100644 --- a/cognee/infrastructure/databases/hybrid/falkordb/FalkorDBAdapter.py +++ b/cognee/infrastructure/databases/hybrid/falkordb/FalkorDBAdapter.py @@ -1,21 +1,26 @@ import asyncio # from datetime import datetime import json -from uuid import UUID from textwrap import dedent +from uuid import UUID + from falkordb import FalkorDB from cognee.exceptions import InvalidValueError -from cognee.infrastructure.engine import DataPoint -from cognee.infrastructure.databases.graph.graph_db_interface import GraphDBInterface +from cognee.infrastructure.databases.graph.graph_db_interface import \ + GraphDBInterface from cognee.infrastructure.databases.vector.embeddings import EmbeddingEngine -from cognee.infrastructure.databases.vector.vector_db_interface import VectorDBInterface +from cognee.infrastructure.databases.vector.vector_db_interface import \ + VectorDBInterface +from cognee.infrastructure.engine import DataPoint + class IndexSchema(DataPoint): text: str _metadata: dict = { - "index_fields": ["text"] + "index_fields": ["text"], + "type": "IndexSchema" } class FalkorDBAdapter(VectorDBInterface, GraphDBInterface): diff --git a/cognee/infrastructure/databases/vector/lancedb/LanceDBAdapter.py b/cognee/infrastructure/databases/vector/lancedb/LanceDBAdapter.py index 37d340004..1b3fc55c3 100644 --- a/cognee/infrastructure/databases/vector/lancedb/LanceDBAdapter.py +++ b/cognee/infrastructure/databases/vector/lancedb/LanceDBAdapter.py @@ -1,25 +1,29 @@ -from typing import List, Optional, get_type_hints, Generic, TypeVar import asyncio +from typing import Generic, List, Optional, TypeVar, get_type_hints from uuid import UUID + import lancedb +from lancedb.pydantic import LanceModel, Vector from pydantic import BaseModel -from lancedb.pydantic import Vector, LanceModel from cognee.exceptions import InvalidValueError from cognee.infrastructure.engine import DataPoint from cognee.infrastructure.files.storage import LocalStorage from cognee.modules.storage.utils import copy_model, get_own_properties -from ..models.ScoredResult import ScoredResult -from ..vector_db_interface import VectorDBInterface -from ..utils import normalize_distances + from ..embeddings.EmbeddingEngine import EmbeddingEngine +from ..models.ScoredResult import ScoredResult +from ..utils import normalize_distances +from ..vector_db_interface import VectorDBInterface + class IndexSchema(DataPoint): id: str text: str _metadata: dict = { - "index_fields": ["text"] + "index_fields": ["text"], + "type": "IndexSchema" } class LanceDBAdapter(VectorDBInterface): diff --git a/cognee/infrastructure/databases/vector/milvus/MilvusAdapter.py b/cognee/infrastructure/databases/vector/milvus/MilvusAdapter.py index 4e5290dd1..0d4ea05d3 100644 --- a/cognee/infrastructure/databases/vector/milvus/MilvusAdapter.py +++ b/cognee/infrastructure/databases/vector/milvus/MilvusAdapter.py @@ -4,10 +4,12 @@ import asyncio import logging from typing import List, Optional from uuid import UUID + from cognee.infrastructure.engine import DataPoint -from ..vector_db_interface import VectorDBInterface -from ..models.ScoredResult import ScoredResult + from ..embeddings.EmbeddingEngine import EmbeddingEngine +from ..models.ScoredResult import ScoredResult +from ..vector_db_interface import VectorDBInterface logger = logging.getLogger("MilvusAdapter") @@ -16,7 +18,8 @@ class IndexSchema(DataPoint): text: str _metadata: dict = { - "index_fields": ["text"] + "index_fields": ["text"], + "type": "IndexSchema" } diff --git a/cognee/infrastructure/databases/vector/pgvector/PGVectorAdapter.py b/cognee/infrastructure/databases/vector/pgvector/PGVectorAdapter.py index a6b458cbd..3f0565253 100644 --- a/cognee/infrastructure/databases/vector/pgvector/PGVectorAdapter.py +++ b/cognee/infrastructure/databases/vector/pgvector/PGVectorAdapter.py @@ -1,6 +1,7 @@ import asyncio -from uuid import UUID from typing import List, Optional, get_type_hints +from uuid import UUID + from sqlalchemy.orm import Mapped, mapped_column from sqlalchemy import JSON, Column, Table, select, delete, MetaData from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker @@ -9,19 +10,21 @@ from cognee.exceptions import InvalidValueError from cognee.infrastructure.databases.exceptions import EntityNotFoundError from cognee.infrastructure.engine import DataPoint -from .serialize_data import serialize_data -from ..models.ScoredResult import ScoredResult -from ..vector_db_interface import VectorDBInterface -from ..utils import normalize_distances -from ..embeddings.EmbeddingEngine import EmbeddingEngine -from ...relational.sqlalchemy.SqlAlchemyAdapter import SQLAlchemyAdapter from ...relational.ModelBase import Base +from ...relational.sqlalchemy.SqlAlchemyAdapter import SQLAlchemyAdapter +from ..embeddings.EmbeddingEngine import EmbeddingEngine +from ..models.ScoredResult import ScoredResult +from ..utils import normalize_distances +from ..vector_db_interface import VectorDBInterface +from .serialize_data import serialize_data + class IndexSchema(DataPoint): text: str _metadata: dict = { - "index_fields": ["text"] + "index_fields": ["text"], + "type": "IndexSchema" } class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface): @@ -89,6 +92,7 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface): async def create_data_points( self, collection_name: str, data_points: List[DataPoint] ): + data_point_types = get_type_hints(DataPoint) if not await self.has_collection(collection_name): await self.create_collection( collection_name = collection_name, @@ -108,7 +112,7 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface): primary_key: Mapped[int] = mapped_column( primary_key=True, autoincrement=True ) - id: Mapped[type(data_points[0].id)] + id: Mapped[data_point_types["id"]] payload = Column(JSON) vector = Column(self.Vector(vector_size)) diff --git a/cognee/infrastructure/databases/vector/qdrant/QDrantAdapter.py b/cognee/infrastructure/databases/vector/qdrant/QDrantAdapter.py index d5d2a1a5c..b63139bc5 100644 --- a/cognee/infrastructure/databases/vector/qdrant/QDrantAdapter.py +++ b/cognee/infrastructure/databases/vector/qdrant/QDrantAdapter.py @@ -1,13 +1,16 @@ import logging +from typing import Dict, List, Optional from uuid import UUID -from typing import List, Dict, Optional + from qdrant_client import AsyncQdrantClient, models from cognee.exceptions import InvalidValueError -from cognee.infrastructure.databases.vector.models.ScoredResult import ScoredResult +from cognee.infrastructure.databases.vector.models.ScoredResult import \ + ScoredResult from cognee.infrastructure.engine import DataPoint -from ..vector_db_interface import VectorDBInterface + from ..embeddings.EmbeddingEngine import EmbeddingEngine +from ..vector_db_interface import VectorDBInterface logger = logging.getLogger("QDrantAdapter") @@ -15,7 +18,8 @@ class IndexSchema(DataPoint): text: str _metadata: dict = { - "index_fields": ["text"] + "index_fields": ["text"], + "type": "IndexSchema" } # class CollectionConfig(BaseModel, extra = "forbid"): diff --git a/cognee/infrastructure/databases/vector/weaviate_db/WeaviateAdapter.py b/cognee/infrastructure/databases/vector/weaviate_db/WeaviateAdapter.py index c16f765b0..31162b1b5 100644 --- a/cognee/infrastructure/databases/vector/weaviate_db/WeaviateAdapter.py +++ b/cognee/infrastructure/databases/vector/weaviate_db/WeaviateAdapter.py @@ -5,9 +5,10 @@ from uuid import UUID from cognee.exceptions import InvalidValueError from cognee.infrastructure.engine import DataPoint -from ..vector_db_interface import VectorDBInterface -from ..models.ScoredResult import ScoredResult + from ..embeddings.EmbeddingEngine import EmbeddingEngine +from ..models.ScoredResult import ScoredResult +from ..vector_db_interface import VectorDBInterface logger = logging.getLogger("WeaviateAdapter") @@ -15,7 +16,8 @@ class IndexSchema(DataPoint): text: str _metadata: dict = { - "index_fields": ["text"] + "index_fields": ["text"], + "type": "IndexSchema" } class WeaviateAdapter(VectorDBInterface): diff --git a/cognee/infrastructure/engine/models/DataPoint.py b/cognee/infrastructure/engine/models/DataPoint.py index abb924f2f..e08041146 100644 --- a/cognee/infrastructure/engine/models/DataPoint.py +++ b/cognee/infrastructure/engine/models/DataPoint.py @@ -1,8 +1,10 @@ -from typing_extensions import TypedDict -from uuid import UUID, uuid4 -from typing import Optional from datetime import datetime, timezone +from typing import Optional +from uuid import UUID, uuid4 + from pydantic import BaseModel, Field +from typing_extensions import TypedDict + class MetaData(TypedDict): index_fields: list[str] @@ -13,7 +15,8 @@ class DataPoint(BaseModel): updated_at: Optional[datetime] = datetime.now(timezone.utc) topological_rank: Optional[int] = 0 _metadata: Optional[MetaData] = { - "index_fields": [] + "index_fields": [], + "type": "DataPoint" } # class Config: @@ -39,4 +42,4 @@ class DataPoint(BaseModel): @classmethod def get_embeddable_property_names(self, data_point): - return data_point._metadata["index_fields"] or [] + return data_point._metadata["index_fields"] or [] \ No newline at end of file diff --git a/cognee/modules/chunking/models/DocumentChunk.py b/cognee/modules/chunking/models/DocumentChunk.py index b5faea560..8729596df 100644 --- a/cognee/modules/chunking/models/DocumentChunk.py +++ b/cognee/modules/chunking/models/DocumentChunk.py @@ -1,8 +1,10 @@ from typing import List, Optional + from cognee.infrastructure.engine import DataPoint from cognee.modules.data.processing.document_types import Document from cognee.modules.engine.models import Entity + class DocumentChunk(DataPoint): __tablename__ = "document_chunk" text: str @@ -12,6 +14,7 @@ class DocumentChunk(DataPoint): is_part_of: Document contains: List[Entity] = None - _metadata: Optional[dict] = { + _metadata: dict = { "index_fields": ["text"], + "type": "DocumentChunk" } diff --git a/cognee/modules/data/processing/document_types/Document.py b/cognee/modules/data/processing/document_types/Document.py index 45441dcce..924ffabac 100644 --- a/cognee/modules/data/processing/document_types/Document.py +++ b/cognee/modules/data/processing/document_types/Document.py @@ -1,12 +1,17 @@ -from cognee.infrastructure.engine import DataPoint from uuid import UUID +from cognee.infrastructure.engine import DataPoint + + class Document(DataPoint): - type: str name: str raw_data_location: str metadata_id: UUID mime_type: str + _metadata: dict = { + "index_fields": ["name"], + "type": "Document" + } def read(self, chunk_size: int) -> str: - pass + pass \ No newline at end of file diff --git a/cognee/modules/engine/models/Entity.py b/cognee/modules/engine/models/Entity.py index b805d3d11..16e0ca3d8 100644 --- a/cognee/modules/engine/models/Entity.py +++ b/cognee/modules/engine/models/Entity.py @@ -10,4 +10,5 @@ class Entity(DataPoint): _metadata: dict = { "index_fields": ["name"], + "type": "Entity" } diff --git a/cognee/modules/engine/models/EntityType.py b/cognee/modules/engine/models/EntityType.py index 1c7843cfd..d3cc54311 100644 --- a/cognee/modules/engine/models/EntityType.py +++ b/cognee/modules/engine/models/EntityType.py @@ -1,11 +1,12 @@ from cognee.infrastructure.engine import DataPoint + class EntityType(DataPoint): __tablename__ = "entity_type" name: str - type: str description: str _metadata: dict = { "index_fields": ["name"], + "type": "EntityType" } diff --git a/cognee/modules/graph/models/EdgeType.py b/cognee/modules/graph/models/EdgeType.py index f9554d25d..998f08d8d 100644 --- a/cognee/modules/graph/models/EdgeType.py +++ b/cognee/modules/graph/models/EdgeType.py @@ -1,11 +1,14 @@ from typing import Optional + from cognee.infrastructure.engine import DataPoint + class EdgeType(DataPoint): __tablename__ = "edge_type" relationship_name: str number_of_edges: int - _metadata: Optional[dict] = { + _metadata: dict = { "index_fields": ["relationship_name"], + "type": "EdgeType" } \ No newline at end of file diff --git a/cognee/modules/graph/utils/convert_node_to_data_point.py b/cognee/modules/graph/utils/convert_node_to_data_point.py index 292f53733..602a7ffa3 100644 --- a/cognee/modules/graph/utils/convert_node_to_data_point.py +++ b/cognee/modules/graph/utils/convert_node_to_data_point.py @@ -2,7 +2,7 @@ from cognee.infrastructure.engine import DataPoint def convert_node_to_data_point(node_data: dict) -> DataPoint: - subclass = find_subclass_by_name(DataPoint, node_data["type"]) + subclass = find_subclass_by_name(DataPoint, node_data._metadata["type"]) return subclass(**node_data) diff --git a/cognee/shared/CodeGraphEntities.py b/cognee/shared/CodeGraphEntities.py index 8859fd0d6..23b8879c2 100644 --- a/cognee/shared/CodeGraphEntities.py +++ b/cognee/shared/CodeGraphEntities.py @@ -1,15 +1,19 @@ from typing import List, Optional + from cognee.infrastructure.engine import DataPoint + class Repository(DataPoint): __tablename__ = "Repository" path: str - type: Optional[str] = "Repository" + _metadata: dict = { + "index_fields": ["source_code"], + "type": "Repository" + } class CodeFile(DataPoint): __tablename__ = "codefile" extracted_id: str # actually file path - type: Optional[str] = "CodeFile" source_code: Optional[str] = None part_of: Optional[Repository] = None depends_on: Optional[List["CodeFile"]] = None @@ -17,24 +21,27 @@ class CodeFile(DataPoint): contains: Optional[List["CodePart"]] = None _metadata: dict = { - "index_fields": ["source_code"] + "index_fields": ["source_code"], + "type": "CodeFile" } class CodePart(DataPoint): __tablename__ = "codepart" # part_of: Optional[CodeFile] source_code: str - type: Optional[str] = "CodePart" - + _metadata: dict = { - "index_fields": ["source_code"] + "index_fields": ["source_code"], + "type": "CodePart" } class CodeRelationship(DataPoint): source_id: str target_id: str - type: str # between files relation: str # depends on or depends directly + _metadata: dict = { + "type": "CodeRelationship" + } CodeFile.model_rebuild() CodePart.model_rebuild() diff --git a/cognee/shared/SourceCodeGraph.py b/cognee/shared/SourceCodeGraph.py index 0fc8f9487..3de72c5fd 100644 --- a/cognee/shared/SourceCodeGraph.py +++ b/cognee/shared/SourceCodeGraph.py @@ -1,79 +1,90 @@ -from typing import Any, List, Union, Literal, Optional +from typing import Any, List, Literal, Optional, Union + from cognee.infrastructure.engine import DataPoint + class Variable(DataPoint): id: str name: str - type: Literal["Variable"] = "Variable" description: str is_static: Optional[bool] = False default_value: Optional[str] = None data_type: str _metadata = { - "index_fields": ["name"] + "index_fields": ["name"], + "type": "Variable" } class Operator(DataPoint): id: str name: str - type: Literal["Operator"] = "Operator" description: str return_type: str + _metadata = { + "index_fields": ["name"], + "type": "Operator" + } class Class(DataPoint): id: str name: str - type: Literal["Class"] = "Class" description: str constructor_parameters: List[Variable] extended_from_class: Optional["Class"] = None has_methods: List["Function"] _metadata = { - "index_fields": ["name"] + "index_fields": ["name"], + "type": "Class" } class ClassInstance(DataPoint): id: str name: str - type: Literal["ClassInstance"] = "ClassInstance" description: str from_class: Class instantiated_by: Union["Function"] instantiation_arguments: List[Variable] _metadata = { - "index_fields": ["name"] + "index_fields": ["name"], + "type": "ClassInstance" } class Function(DataPoint): id: str name: str - type: Literal["Function"] = "Function" description: str parameters: List[Variable] return_type: str is_static: Optional[bool] = False _metadata = { - "index_fields": ["name"] + "index_fields": ["name"], + "type": "Function" } class FunctionCall(DataPoint): id: str - type: Literal["FunctionCall"] = "FunctionCall" called_by: Union[Function, Literal["main"]] function_called: Function function_arguments: List[Any] + _metadata = { + "index_fields": [], + "type": "FunctionCall" + } class Expression(DataPoint): id: str name: str - type: Literal["Expression"] = "Expression" description: str expression: str members: List[Union[Variable, Function, Operator, "Expression"]] + _metadata = { + "index_fields": ["name"], + "type": "Expression" + } class SourceCodeGraph(DataPoint): id: str @@ -89,8 +100,13 @@ class SourceCodeGraph(DataPoint): Operator, Expression, ]] + _metadata = { + "index_fields": ["name"], + "type": "SourceCodeGraph" + } + Class.model_rebuild() ClassInstance.model_rebuild() Expression.model_rebuild() FunctionCall.model_rebuild() -SourceCodeGraph.model_rebuild() +SourceCodeGraph.model_rebuild() \ No newline at end of file diff --git a/cognee/tasks/storage/index_data_points.py b/cognee/tasks/storage/index_data_points.py index 786168b58..857e4d777 100644 --- a/cognee/tasks/storage/index_data_points.py +++ b/cognee/tasks/storage/index_data_points.py @@ -1,6 +1,7 @@ from cognee.infrastructure.databases.vector import get_vector_engine from cognee.infrastructure.engine import DataPoint + async def index_data_points(data_points: list[DataPoint]): created_indexes = {} index_points = {} @@ -80,11 +81,20 @@ if __name__ == "__main__": class Car(DataPoint): model: str color: str + _metadata = { + "index_fields": ["name"], + "type": "Car" + } + class Person(DataPoint): name: str age: int owns_car: list[Car] + _metadata = { + "index_fields": ["name"], + "type": "Person" + } car1 = Car(model = "Tesla Model S", color = "Blue") car2 = Car(model = "Toyota Camry", color = "Red") @@ -92,4 +102,4 @@ if __name__ == "__main__": data_points = get_data_points_from_model(person) - print(data_points) + print(data_points) \ No newline at end of file diff --git a/cognee/tasks/summarization/models.py b/cognee/tasks/summarization/models.py index 6fef4fb02..add448155 100644 --- a/cognee/tasks/summarization/models.py +++ b/cognee/tasks/summarization/models.py @@ -10,6 +10,7 @@ class TextSummary(DataPoint): _metadata: dict = { "index_fields": ["text"], + "type": "TextSummary" } @@ -20,4 +21,5 @@ class CodeSummary(DataPoint): _metadata: dict = { "index_fields": ["text"], + "type": "CodeSummary" } diff --git a/cognee/tests/unit/interfaces/graph/get_graph_from_huge_model_test.py b/cognee/tests/unit/interfaces/graph/get_graph_from_huge_model_test.py index 016f2be33..06c74c854 100644 --- a/cognee/tests/unit/interfaces/graph/get_graph_from_huge_model_test.py +++ b/cognee/tests/unit/interfaces/graph/get_graph_from_huge_model_test.py @@ -2,7 +2,7 @@ import asyncio import random import time from typing import List -from uuid import uuid5, NAMESPACE_OID +from uuid import NAMESPACE_OID, uuid5 from cognee.infrastructure.engine import DataPoint from cognee.modules.graph.utils import get_graph_from_model @@ -11,16 +11,28 @@ random.seed(1500) class Repository(DataPoint): path: str + _metadata = { + "index_fields": [], + "type": "Repository" + } class CodeFile(DataPoint): part_of: Repository contains: List["CodePart"] = [] depends_on: List["CodeFile"] = [] source_code: str + _metadata = { + "index_fields": [], + "type": "CodeFile" + } class CodePart(DataPoint): part_of: CodeFile source_code: str + _metadata = { + "index_fields": [], + "type": "CodePart" + } CodeFile.model_rebuild() CodePart.model_rebuild() diff --git a/cognee/tests/unit/interfaces/graph/get_graph_from_model_test.py b/cognee/tests/unit/interfaces/graph/get_graph_from_model_test.py index 000d45c15..499dc9f3f 100644 --- a/cognee/tests/unit/interfaces/graph/get_graph_from_model_test.py +++ b/cognee/tests/unit/interfaces/graph/get_graph_from_model_test.py @@ -1,25 +1,42 @@ import asyncio import random from typing import List -from uuid import uuid5, NAMESPACE_OID +from uuid import NAMESPACE_OID, uuid5 from cognee.infrastructure.engine import DataPoint from cognee.modules.graph.utils import get_graph_from_model + class Document(DataPoint): path: str + _metadata = { + "index_fields": [], + "type": "Document" + } class DocumentChunk(DataPoint): part_of: Document text: str contains: List["Entity"] = None + _metadata = { + "index_fields": ["text"], + "type": "DocumentChunk" + } class EntityType(DataPoint): name: str + _metadata = { + "index_fields": ["name"], + "type": "EntityType" + } class Entity(DataPoint): name: str is_type: EntityType + _metadata = { + "index_fields": ["name"], + "type": "Entity" + } DocumentChunk.model_rebuild() From 9e7ab6492a87f18126ccc9ac5a76219c78a19003 Mon Sep 17 00:00:00 2001 From: hajdul88 <52442977+hajdul88@users.noreply.github.com> Date: Tue, 17 Dec 2024 11:31:31 +0100 Subject: [PATCH 6/8] =?UTF-8?q?feat:=20outsources=20chunking=20parameters?= =?UTF-8?q?=20to=20extract=20chunk=20from=20documents=20=E2=80=A6=20(#289)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * feat: outsources chunking parameters to extract chunk from documents task --- .../processing/document_types/AudioDocument.py | 7 ++++--- .../processing/document_types/ChunkerMapping.py | 15 +++++++++++++++ .../data/processing/document_types/Document.py | 4 ++-- .../processing/document_types/ImageDocument.py | 7 ++++--- .../data/processing/document_types/PdfDocument.py | 7 ++++--- .../processing/document_types/TextDocument.py | 8 +++++--- .../documents/extract_chunks_from_documents.py | 4 ++-- .../integration/documents/AudioDocument_test.py | 2 +- .../integration/documents/ImageDocument_test.py | 2 +- .../integration/documents/PdfDocument_test.py | 2 +- .../integration/documents/TextDocument_test.py | 2 +- 11 files changed, 40 insertions(+), 20 deletions(-) create mode 100644 cognee/modules/data/processing/document_types/ChunkerMapping.py diff --git a/cognee/modules/data/processing/document_types/AudioDocument.py b/cognee/modules/data/processing/document_types/AudioDocument.py index 0d2cddd3d..268338703 100644 --- a/cognee/modules/data/processing/document_types/AudioDocument.py +++ b/cognee/modules/data/processing/document_types/AudioDocument.py @@ -1,6 +1,6 @@ from cognee.infrastructure.llm.get_llm_client import get_llm_client -from cognee.modules.chunking.TextChunker import TextChunker from .Document import Document +from .ChunkerMapping import ChunkerConfig class AudioDocument(Document): type: str = "audio" @@ -9,11 +9,12 @@ class AudioDocument(Document): result = get_llm_client().create_transcript(self.raw_data_location) return(result.text) - def read(self, chunk_size: int): + def read(self, chunk_size: int, chunker: str): # Transcribe the audio file text = self.create_transcript() - chunker = TextChunker(self, chunk_size = chunk_size, get_text = lambda: [text]) + chunker_func = ChunkerConfig.get_chunker(chunker) + chunker = chunker_func(self, chunk_size = chunk_size, get_text = lambda: [text]) yield from chunker.read() diff --git a/cognee/modules/data/processing/document_types/ChunkerMapping.py b/cognee/modules/data/processing/document_types/ChunkerMapping.py new file mode 100644 index 000000000..14dbb8bb7 --- /dev/null +++ b/cognee/modules/data/processing/document_types/ChunkerMapping.py @@ -0,0 +1,15 @@ +from cognee.modules.chunking.TextChunker import TextChunker + +class ChunkerConfig: + chunker_mapping = { + "text_chunker": TextChunker + } + + @classmethod + def get_chunker(cls, chunker_name: str): + chunker_class = cls.chunker_mapping.get(chunker_name) + if chunker_class is None: + raise NotImplementedError( + f"Chunker '{chunker_name}' is not implemented. Available options: {list(cls.chunker_mapping.keys())}" + ) + return chunker_class \ No newline at end of file diff --git a/cognee/modules/data/processing/document_types/Document.py b/cognee/modules/data/processing/document_types/Document.py index 924ffabac..8d6a3dafb 100644 --- a/cognee/modules/data/processing/document_types/Document.py +++ b/cognee/modules/data/processing/document_types/Document.py @@ -13,5 +13,5 @@ class Document(DataPoint): "type": "Document" } - def read(self, chunk_size: int) -> str: - pass \ No newline at end of file + def read(self, chunk_size: int, chunker = str) -> str: + pass diff --git a/cognee/modules/data/processing/document_types/ImageDocument.py b/cognee/modules/data/processing/document_types/ImageDocument.py index e8f0dd8ee..352486bd8 100644 --- a/cognee/modules/data/processing/document_types/ImageDocument.py +++ b/cognee/modules/data/processing/document_types/ImageDocument.py @@ -1,6 +1,6 @@ from cognee.infrastructure.llm.get_llm_client import get_llm_client -from cognee.modules.chunking.TextChunker import TextChunker from .Document import Document +from .ChunkerMapping import ChunkerConfig class ImageDocument(Document): type: str = "image" @@ -10,10 +10,11 @@ class ImageDocument(Document): result = get_llm_client().transcribe_image(self.raw_data_location) return(result.choices[0].message.content) - def read(self, chunk_size: int): + def read(self, chunk_size: int, chunker: str): # Transcribe the image file text = self.transcribe_image() - chunker = TextChunker(self, chunk_size = chunk_size, get_text = lambda: [text]) + chunker_func = ChunkerConfig.get_chunker(chunker) + chunker = chunker_func(self, chunk_size = chunk_size, get_text = lambda: [text]) yield from chunker.read() diff --git a/cognee/modules/data/processing/document_types/PdfDocument.py b/cognee/modules/data/processing/document_types/PdfDocument.py index 2d1941996..361214718 100644 --- a/cognee/modules/data/processing/document_types/PdfDocument.py +++ b/cognee/modules/data/processing/document_types/PdfDocument.py @@ -1,11 +1,11 @@ from pypdf import PdfReader -from cognee.modules.chunking.TextChunker import TextChunker from .Document import Document +from .ChunkerMapping import ChunkerConfig class PdfDocument(Document): type: str = "pdf" - def read(self, chunk_size: int): + def read(self, chunk_size: int, chunker: str): file = PdfReader(self.raw_data_location) def get_text(): @@ -13,7 +13,8 @@ class PdfDocument(Document): page_text = page.extract_text() yield page_text - chunker = TextChunker(self, chunk_size = chunk_size, get_text = get_text) + chunker_func = ChunkerConfig.get_chunker(chunker) + chunker = chunker_func(self, chunk_size = chunk_size, get_text = get_text) yield from chunker.read() diff --git a/cognee/modules/data/processing/document_types/TextDocument.py b/cognee/modules/data/processing/document_types/TextDocument.py index 32d3416b9..3952d9845 100644 --- a/cognee/modules/data/processing/document_types/TextDocument.py +++ b/cognee/modules/data/processing/document_types/TextDocument.py @@ -1,10 +1,10 @@ -from cognee.modules.chunking.TextChunker import TextChunker from .Document import Document +from .ChunkerMapping import ChunkerConfig class TextDocument(Document): type: str = "text" - def read(self, chunk_size: int): + def read(self, chunk_size: int, chunker: str): def get_text(): with open(self.raw_data_location, mode = "r", encoding = "utf-8") as file: while True: @@ -15,6 +15,8 @@ class TextDocument(Document): yield text - chunker = TextChunker(self, chunk_size = chunk_size, get_text = get_text) + chunker_func = ChunkerConfig.get_chunker(chunker) + + chunker = chunker_func(self, chunk_size = chunk_size, get_text = get_text) yield from chunker.read() diff --git a/cognee/tasks/documents/extract_chunks_from_documents.py b/cognee/tasks/documents/extract_chunks_from_documents.py index ec19a786d..423b87b69 100644 --- a/cognee/tasks/documents/extract_chunks_from_documents.py +++ b/cognee/tasks/documents/extract_chunks_from_documents.py @@ -1,7 +1,7 @@ from cognee.modules.data.processing.document_types.Document import Document -async def extract_chunks_from_documents(documents: list[Document], chunk_size: int = 1024): +async def extract_chunks_from_documents(documents: list[Document], chunk_size: int = 1024, chunker = 'text_chunker'): for document in documents: - for document_chunk in document.read(chunk_size = chunk_size): + for document_chunk in document.read(chunk_size = chunk_size, chunker = chunker): yield document_chunk diff --git a/cognee/tests/integration/documents/AudioDocument_test.py b/cognee/tests/integration/documents/AudioDocument_test.py index da8b85d0b..151f4c0b2 100644 --- a/cognee/tests/integration/documents/AudioDocument_test.py +++ b/cognee/tests/integration/documents/AudioDocument_test.py @@ -31,7 +31,7 @@ def test_AudioDocument(): ) with patch.object(AudioDocument, "create_transcript", return_value=TEST_TEXT): for ground_truth, paragraph_data in zip( - GROUND_TRUTH, document.read(chunk_size=64) + GROUND_TRUTH, document.read(chunk_size=64, chunker='text_chunker') ): assert ( ground_truth["word_count"] == paragraph_data.word_count diff --git a/cognee/tests/integration/documents/ImageDocument_test.py b/cognee/tests/integration/documents/ImageDocument_test.py index 8a8ee8ef3..40e0155af 100644 --- a/cognee/tests/integration/documents/ImageDocument_test.py +++ b/cognee/tests/integration/documents/ImageDocument_test.py @@ -21,7 +21,7 @@ def test_ImageDocument(): with patch.object(ImageDocument, "transcribe_image", return_value=TEST_TEXT): for ground_truth, paragraph_data in zip( - GROUND_TRUTH, document.read(chunk_size=64) + GROUND_TRUTH, document.read(chunk_size=64, chunker='text_chunker') ): assert ( ground_truth["word_count"] == paragraph_data.word_count diff --git a/cognee/tests/integration/documents/PdfDocument_test.py b/cognee/tests/integration/documents/PdfDocument_test.py index ac57eaf75..25d4cf6c6 100644 --- a/cognee/tests/integration/documents/PdfDocument_test.py +++ b/cognee/tests/integration/documents/PdfDocument_test.py @@ -22,7 +22,7 @@ def test_PdfDocument(): ) for ground_truth, paragraph_data in zip( - GROUND_TRUTH, document.read(chunk_size=1024) + GROUND_TRUTH, document.read(chunk_size=1024, chunker='text_chunker') ): assert ( ground_truth["word_count"] == paragraph_data.word_count diff --git a/cognee/tests/integration/documents/TextDocument_test.py b/cognee/tests/integration/documents/TextDocument_test.py index f663418f5..91f38968e 100644 --- a/cognee/tests/integration/documents/TextDocument_test.py +++ b/cognee/tests/integration/documents/TextDocument_test.py @@ -33,7 +33,7 @@ def test_TextDocument(input_file, chunk_size): ) for ground_truth, paragraph_data in zip( - GROUND_TRUTH[input_file], document.read(chunk_size=chunk_size) + GROUND_TRUTH[input_file], document.read(chunk_size=chunk_size, chunker='text_chunker') ): assert ( ground_truth["word_count"] == paragraph_data.word_count From da5e3ab24de6c07f583dec15cd67f6dbcec25e8e Mon Sep 17 00:00:00 2001 From: lxobr <122801072+lxobr@users.noreply.github.com> Date: Tue, 17 Dec 2024 12:02:25 +0100 Subject: [PATCH 7/8] COG 870 Remove duplicate edges from the code graph (#293) * feat: turn summarize_code into generator * feat: extract run_code_graph_pipeline, update the pipeline * feat: minimal code graph example * refactor: update argument * refactor: move run_code_graph_pipeline to cognify/code_graph_pipeline * refactor: indentation and whitespace nits * refactor: add deprecated use comments and warnings --------- Co-authored-by: Vasilije <8619304+Vasilije1990@users.noreply.github.com> Co-authored-by: Igor Ilic <30923996+dexters1@users.noreply.github.com> Co-authored-by: Boris --- cognee/api/v1/cognify/code_graph_pipeline.py | 36 +++++++++++++++ cognee/tasks/summarization/summarize_code.py | 40 +++++++++-------- evals/eval_swe_bench.py | 47 +++----------------- examples/python/code_graph_example.py | 15 +++++++ 4 files changed, 80 insertions(+), 58 deletions(-) create mode 100644 examples/python/code_graph_example.py diff --git a/cognee/api/v1/cognify/code_graph_pipeline.py b/cognee/api/v1/cognify/code_graph_pipeline.py index 59c658300..3c72e0793 100644 --- a/cognee/api/v1/cognify/code_graph_pipeline.py +++ b/cognee/api/v1/cognify/code_graph_pipeline.py @@ -1,8 +1,14 @@ +# NOTICE: This module contains deprecated functions. +# Use only the run_code_graph_pipeline function; all other functions are deprecated. +# Related issue: COG-906 + import asyncio import logging +from pathlib import Path from typing import Union from cognee.shared.SourceCodeGraph import SourceCodeGraph +from cognee.shared.data_models import SummarizedContent from cognee.shared.utils import send_telemetry from cognee.modules.data.models import Dataset, Data from cognee.modules.data.methods.get_dataset_data import get_dataset_data @@ -16,7 +22,9 @@ from cognee.modules.pipelines.operations.get_pipeline_status import get_pipeline from cognee.modules.pipelines.operations.log_pipeline_status import log_pipeline_status from cognee.tasks.documents import classify_documents, check_permissions_on_documents, extract_chunks_from_documents from cognee.tasks.graph import extract_graph_from_code +from cognee.tasks.repo_processor import get_repo_file_dependencies, enrich_dependency_graph, expand_dependency_graph from cognee.tasks.storage import add_data_points +from cognee.tasks.summarization import summarize_code logger = logging.getLogger("code_graph_pipeline") @@ -51,6 +59,7 @@ async def code_graph_pipeline(datasets: Union[str, list[str]] = None, user: User async def run_pipeline(dataset: Dataset, user: User): + '''DEPRECATED: Use `run_code_graph_pipeline` instead. This function will be removed.''' data_documents: list[Data] = await get_dataset_data(dataset_id = dataset.id) document_ids_str = [str(document.id) for document in data_documents] @@ -103,3 +112,30 @@ async def run_pipeline(dataset: Dataset, user: User): def generate_dataset_name(dataset_name: str) -> str: return dataset_name.replace(".", "_").replace(" ", "_") + + +async def run_code_graph_pipeline(repo_path): + import os + import pathlib + import cognee + from cognee.infrastructure.databases.relational import create_db_and_tables + + file_path = Path(__file__).parent + data_directory_path = str(pathlib.Path(os.path.join(file_path, ".data_storage/code_graph")).resolve()) + cognee.config.data_root_directory(data_directory_path) + cognee_directory_path = str(pathlib.Path(os.path.join(file_path, ".cognee_system/code_graph")).resolve()) + cognee.config.system_root_directory(cognee_directory_path) + + await cognee.prune.prune_data() + await cognee.prune.prune_system(metadata=True) + await create_db_and_tables() + + tasks = [ + Task(get_repo_file_dependencies), + Task(enrich_dependency_graph, task_config={"batch_size": 50}), + Task(expand_dependency_graph, task_config={"batch_size": 50}), + Task(summarize_code, summarization_model=SummarizedContent, task_config={"batch_size": 50}), + Task(add_data_points, task_config={"batch_size": 50}), + ] + + return run_tasks(tasks, repo_path, "cognify_code_pipeline") diff --git a/cognee/tasks/summarization/summarize_code.py b/cognee/tasks/summarization/summarize_code.py index 277081f40..76435186c 100644 --- a/cognee/tasks/summarization/summarize_code.py +++ b/cognee/tasks/summarization/summarize_code.py @@ -1,39 +1,43 @@ import asyncio -from typing import Type from uuid import uuid5 +from typing import Type from pydantic import BaseModel from cognee.infrastructure.engine import DataPoint from cognee.modules.data.extraction.extract_summary import extract_summary from cognee.shared.CodeGraphEntities import CodeFile -from cognee.tasks.storage import add_data_points - from .models import CodeSummary async def summarize_code( - code_files: list[DataPoint], + code_graph_nodes: list[DataPoint], summarization_model: Type[BaseModel], ) -> list[DataPoint]: - if len(code_files) == 0: - return code_files + if len(code_graph_nodes) == 0: + return - code_files_data_points = [file for file in code_files if isinstance(file, CodeFile)] + code_files_data_points = [file for file in code_graph_nodes if isinstance(file, CodeFile)] file_summaries = await asyncio.gather( *[extract_summary(file.source_code, summarization_model) for file in code_files_data_points] ) - summaries = [ - CodeSummary( - id = uuid5(file.id, "CodeSummary"), - made_from = file, - text = file_summaries[file_index].summary, + file_summaries_map = { + code_file_data_point.extracted_id: file_summary.summary + for code_file_data_point, file_summary in zip(code_files_data_points, file_summaries) + } + + for node in code_graph_nodes: + if not isinstance(node, DataPoint): + continue + yield node + + if not isinstance(node, CodeFile): + continue + + yield CodeSummary( + id=uuid5(node.id, "CodeSummary"), + made_from=node, + text=file_summaries_map[node.extracted_id], ) - for (file_index, file) in enumerate(code_files_data_points) - ] - - await add_data_points(summaries) - - return code_files diff --git a/evals/eval_swe_bench.py b/evals/eval_swe_bench.py index 67826fc12..6c2280d80 100644 --- a/evals/eval_swe_bench.py +++ b/evals/eval_swe_bench.py @@ -7,19 +7,13 @@ from pathlib import Path from swebench.harness.utils import load_swebench_dataset from swebench.inference.make_datasets.create_instance import PATCH_EXAMPLE +from cognee.api.v1.cognify.code_graph_pipeline import run_code_graph_pipeline from cognee.api.v1.search import SearchType from cognee.infrastructure.llm.get_llm_client import get_llm_client from cognee.infrastructure.llm.prompts import read_query_prompt -from cognee.modules.pipelines import Task, run_tasks from cognee.modules.retrieval.brute_force_triplet_search import \ brute_force_triplet_search -# from cognee.shared.data_models import SummarizedContent from cognee.shared.utils import render_graph -from cognee.tasks.repo_processor import (enrich_dependency_graph, - expand_dependency_graph, - get_repo_file_dependencies) -from cognee.tasks.storage import add_data_points -# from cognee.tasks.summarization import summarize_code from evals.eval_utils import download_github_repo, retrieved_edges_to_string @@ -42,48 +36,22 @@ def check_install_package(package_name): async def generate_patch_with_cognee(instance, llm_client, search_type=SearchType.CHUNKS): - import os - import pathlib - import cognee - from cognee.infrastructure.databases.relational import create_db_and_tables - - file_path = Path(__file__).parent - data_directory_path = str(pathlib.Path(os.path.join(file_path, ".data_storage/code_graph")).resolve()) - cognee.config.data_root_directory(data_directory_path) - cognee_directory_path = str(pathlib.Path(os.path.join(file_path, ".cognee_system/code_graph")).resolve()) - cognee.config.system_root_directory(cognee_directory_path) - - await cognee.prune.prune_data() - await cognee.prune.prune_system(metadata = True) - - await create_db_and_tables() - - # repo_path = download_github_repo(instance, '../RAW_GIT_REPOS') - - repo_path = '/Users/borisarzentar/Projects/graphrag' - - tasks = [ - Task(get_repo_file_dependencies), - Task(enrich_dependency_graph, task_config = { "batch_size": 50 }), - Task(expand_dependency_graph, task_config = { "batch_size": 50 }), - Task(add_data_points, task_config = { "batch_size": 50 }), - # Task(summarize_code, summarization_model = SummarizedContent), - ] - - pipeline = run_tasks(tasks, repo_path, "cognify_code_pipeline") + repo_path = download_github_repo(instance, '../RAW_GIT_REPOS') + pipeline = await run_code_graph_pipeline(repo_path) async for result in pipeline: print(result) print('Here we have the repo under the repo_path') - await render_graph(None, include_labels = True, include_nodes = True) + await render_graph(None, include_labels=True, include_nodes=True) problem_statement = instance['problem_statement'] instructions = read_query_prompt("patch_gen_kg_instructions.txt") - retrieved_edges = await brute_force_triplet_search(problem_statement, top_k = 3, collections = ["data_point_source_code", "data_point_text"]) - + retrieved_edges = await brute_force_triplet_search(problem_statement, top_k=3, + collections=["data_point_source_code", "data_point_text"]) + retrieved_edges_str = retrieved_edges_to_string(retrieved_edges) prompt = "\n".join([ @@ -171,7 +139,6 @@ async def main(): with open(predictions_path, "w") as file: json.dump(preds, file) - subprocess.run( [ "python", diff --git a/examples/python/code_graph_example.py b/examples/python/code_graph_example.py new file mode 100644 index 000000000..9189de46c --- /dev/null +++ b/examples/python/code_graph_example.py @@ -0,0 +1,15 @@ +import argparse +import asyncio +from cognee.api.v1.cognify.code_graph_pipeline import run_code_graph_pipeline + + +async def main(repo_path): + async for result in await run_code_graph_pipeline(repo_path): + print(result) + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--repo-path", type=str, required=True, help="Path to the repository") + args = parser.parse_args() + asyncio.run(main(args.repo_path)) + From 9afd0ece63bdcda318afc498d0643fb649ac1302 Mon Sep 17 00:00:00 2001 From: alekszievr <44192193+alekszievr@users.noreply.github.com> Date: Tue, 17 Dec 2024 13:05:47 +0100 Subject: [PATCH 8/8] Structured code summarization (#375) * feat: turn summarize_code into generator * feat: extract run_code_graph_pipeline, update the pipeline * feat: minimal code graph example * refactor: update argument * refactor: move run_code_graph_pipeline to cognify/code_graph_pipeline * refactor: indentation and whitespace nits * refactor: add deprecated use comments and warnings * Structured code summarization * add missing prompt file * Remove summarization_model argument from summarize_code and fix typehinting * minor refactors --------- Co-authored-by: lxobr <122801072+lxobr@users.noreply.github.com> Co-authored-by: Vasilije <8619304+Vasilije1990@users.noreply.github.com> Co-authored-by: Igor Ilic <30923996+dexters1@users.noreply.github.com> Co-authored-by: Boris --- cognee/api/v1/cognify/code_graph_pipeline.py | 31 +++++++++++-------- .../llm/prompts/summarize_code.txt | 10 ++++++ .../data/extraction/extract_summary.py | 10 +++++- cognee/shared/data_models.py | 27 +++++++++++++++- cognee/tasks/summarization/summarize_code.py | 19 +++++------- 5 files changed, 71 insertions(+), 26 deletions(-) create mode 100644 cognee/infrastructure/llm/prompts/summarize_code.txt diff --git a/cognee/api/v1/cognify/code_graph_pipeline.py b/cognee/api/v1/cognify/code_graph_pipeline.py index 3c72e0793..eeb10d69e 100644 --- a/cognee/api/v1/cognify/code_graph_pipeline.py +++ b/cognee/api/v1/cognify/code_graph_pipeline.py @@ -7,22 +7,27 @@ import logging from pathlib import Path from typing import Union -from cognee.shared.SourceCodeGraph import SourceCodeGraph -from cognee.shared.data_models import SummarizedContent -from cognee.shared.utils import send_telemetry -from cognee.modules.data.models import Dataset, Data -from cognee.modules.data.methods.get_dataset_data import get_dataset_data from cognee.modules.data.methods import get_datasets, get_datasets_by_name -from cognee.modules.pipelines.tasks.Task import Task +from cognee.modules.data.methods.get_dataset_data import get_dataset_data +from cognee.modules.data.models import Data, Dataset from cognee.modules.pipelines import run_tasks -from cognee.modules.users.models import User -from cognee.modules.users.methods import get_default_user from cognee.modules.pipelines.models import PipelineRunStatus -from cognee.modules.pipelines.operations.get_pipeline_status import get_pipeline_status -from cognee.modules.pipelines.operations.log_pipeline_status import log_pipeline_status -from cognee.tasks.documents import classify_documents, check_permissions_on_documents, extract_chunks_from_documents +from cognee.modules.pipelines.operations.get_pipeline_status import \ + get_pipeline_status +from cognee.modules.pipelines.operations.log_pipeline_status import \ + log_pipeline_status +from cognee.modules.pipelines.tasks.Task import Task +from cognee.modules.users.methods import get_default_user +from cognee.modules.users.models import User +from cognee.shared.SourceCodeGraph import SourceCodeGraph +from cognee.shared.utils import send_telemetry +from cognee.tasks.documents import (check_permissions_on_documents, + classify_documents, + extract_chunks_from_documents) from cognee.tasks.graph import extract_graph_from_code -from cognee.tasks.repo_processor import get_repo_file_dependencies, enrich_dependency_graph, expand_dependency_graph +from cognee.tasks.repo_processor import (enrich_dependency_graph, + expand_dependency_graph, + get_repo_file_dependencies) from cognee.tasks.storage import add_data_points from cognee.tasks.summarization import summarize_code @@ -134,7 +139,7 @@ async def run_code_graph_pipeline(repo_path): Task(get_repo_file_dependencies), Task(enrich_dependency_graph, task_config={"batch_size": 50}), Task(expand_dependency_graph, task_config={"batch_size": 50}), - Task(summarize_code, summarization_model=SummarizedContent, task_config={"batch_size": 50}), + Task(summarize_code, task_config={"batch_size": 50}), Task(add_data_points, task_config={"batch_size": 50}), ] diff --git a/cognee/infrastructure/llm/prompts/summarize_code.txt b/cognee/infrastructure/llm/prompts/summarize_code.txt new file mode 100644 index 000000000..405585617 --- /dev/null +++ b/cognee/infrastructure/llm/prompts/summarize_code.txt @@ -0,0 +1,10 @@ +You are an expert Python programmer and technical writer. Your task is to summarize the given Python code snippet or file. +The code may contain multiple imports, classes, functions, constants and logic. Provide a clear, structured explanation of its components +and their relationships. + +Instructions: +Provide an overview: Start with a high-level summary of what the code does as a whole. +Break it down: Summarize each class and function individually, explaining their purpose and how they interact. +Describe the workflow: Outline how the classes and functions work together. Mention any control flow (e.g., main functions, entry points, loops). +Key features: Highlight important elements like arguments, return values, or unique logic. +Maintain clarity: Write in plain English for someone familiar with Python but unfamiliar with this code. \ No newline at end of file diff --git a/cognee/modules/data/extraction/extract_summary.py b/cognee/modules/data/extraction/extract_summary.py index a17bf3ae6..10d429da9 100644 --- a/cognee/modules/data/extraction/extract_summary.py +++ b/cognee/modules/data/extraction/extract_summary.py @@ -1,7 +1,11 @@ from typing import Type + from pydantic import BaseModel -from cognee.infrastructure.llm.prompts import read_query_prompt + from cognee.infrastructure.llm.get_llm_client import get_llm_client +from cognee.infrastructure.llm.prompts import read_query_prompt +from cognee.shared.data_models import SummarizedCode + async def extract_summary(content: str, response_model: Type[BaseModel]): llm_client = get_llm_client() @@ -11,3 +15,7 @@ async def extract_summary(content: str, response_model: Type[BaseModel]): llm_output = await llm_client.acreate_structured_output(content, system_prompt, response_model) return llm_output + +async def extract_code_summary(content: str): + + return await extract_summary(content, response_model=SummarizedCode) diff --git a/cognee/shared/data_models.py b/cognee/shared/data_models.py index 6cb4d436a..dec53cfcb 100644 --- a/cognee/shared/data_models.py +++ b/cognee/shared/data_models.py @@ -1,9 +1,11 @@ """Data models for the cognitive architecture.""" from enum import Enum, auto -from typing import Optional, List, Union, Dict, Any +from typing import Any, Dict, List, Optional, Union + from pydantic import BaseModel, Field + class Node(BaseModel): """Node in a knowledge graph.""" id: str @@ -194,6 +196,29 @@ class SummarizedContent(BaseModel): summary: str description: str +class SummarizedFunction(BaseModel): + name: str + description: str + inputs: Optional[List[str]] = None + outputs: Optional[List[str]] = None + decorators: Optional[List[str]] = None + +class SummarizedClass(BaseModel): + name: str + description: str + methods: Optional[List[SummarizedFunction]] = None + decorators: Optional[List[str]] = None + +class SummarizedCode(BaseModel): + file_name: str + high_level_summary: str + key_features: List[str] + imports: List[str] = [] + constants: List[str] = [] + classes: List[SummarizedClass] = [] + functions: List[SummarizedFunction] = [] + workflow_description: Optional[str] = None + class GraphDBType(Enum): NETWORKX = auto() diff --git a/cognee/tasks/summarization/summarize_code.py b/cognee/tasks/summarization/summarize_code.py index 76435186c..b116e57a9 100644 --- a/cognee/tasks/summarization/summarize_code.py +++ b/cognee/tasks/summarization/summarize_code.py @@ -1,31 +1,28 @@ import asyncio +from typing import AsyncGenerator, Union from uuid import uuid5 from typing import Type -from pydantic import BaseModel - from cognee.infrastructure.engine import DataPoint -from cognee.modules.data.extraction.extract_summary import extract_summary -from cognee.shared.CodeGraphEntities import CodeFile +from cognee.modules.data.extraction.extract_summary import extract_code_summary from .models import CodeSummary async def summarize_code( code_graph_nodes: list[DataPoint], - summarization_model: Type[BaseModel], -) -> list[DataPoint]: +) -> AsyncGenerator[Union[DataPoint, CodeSummary], None]: if len(code_graph_nodes) == 0: return - code_files_data_points = [file for file in code_graph_nodes if isinstance(file, CodeFile)] + code_data_points = [file for file in code_graph_nodes if hasattr(file, "source_code")] file_summaries = await asyncio.gather( - *[extract_summary(file.source_code, summarization_model) for file in code_files_data_points] + *[extract_code_summary(file.source_code) for file in code_data_points] ) file_summaries_map = { - code_file_data_point.extracted_id: file_summary.summary - for code_file_data_point, file_summary in zip(code_files_data_points, file_summaries) + code_data_point.extracted_id: str(file_summary) + for code_data_point, file_summary in zip(code_data_points, file_summaries) } for node in code_graph_nodes: @@ -33,7 +30,7 @@ async def summarize_code( continue yield node - if not isinstance(node, CodeFile): + if not hasattr(node, "source_code"): continue yield CodeSummary(