diff --git a/.github/workflows/test_deduplication.yml b/.github/workflows/test_deduplication.yml index 77a7ddbb7..a1aab3252 100644 --- a/.github/workflows/test_deduplication.yml +++ b/.github/workflows/test_deduplication.yml @@ -17,6 +17,7 @@ jobs: run_deduplication_test: name: test runs-on: ubuntu-latest + if: ${{ github.event.label.name == 'run-checks' }} defaults: run: shell: bash diff --git a/.github/workflows/test_milvus.yml b/.github/workflows/test_milvus.yml index 5cad72378..c4214fddf 100644 --- a/.github/workflows/test_milvus.yml +++ b/.github/workflows/test_milvus.yml @@ -18,6 +18,7 @@ jobs: run_milvus: name: test runs-on: ubuntu-latest + if: ${{ github.event.label.name == 'run-checks' }} strategy: fail-fast: false defaults: diff --git a/.github/workflows/test_neo4j.yml b/.github/workflows/test_neo4j.yml index 3f3a35e4f..5f955eefc 100644 --- a/.github/workflows/test_neo4j.yml +++ b/.github/workflows/test_neo4j.yml @@ -15,6 +15,7 @@ env: jobs: run_neo4j_integration_test: name: test + if: ${{ github.event.label.name == 'run-checks' }} runs-on: ubuntu-latest defaults: diff --git a/.github/workflows/test_pgvector.yml b/.github/workflows/test_pgvector.yml index a162d2cb4..9f5ffd1da 100644 --- a/.github/workflows/test_pgvector.yml +++ b/.github/workflows/test_pgvector.yml @@ -18,6 +18,7 @@ jobs: run_pgvector_integration_test: name: test runs-on: ubuntu-latest + if: ${{ github.event.label.name == 'run-checks' }} defaults: run: shell: bash diff --git a/.github/workflows/test_python_3_10.yml b/.github/workflows/test_python_3_10.yml index 39eb4e57a..770d2fd63 100644 --- a/.github/workflows/test_python_3_10.yml +++ b/.github/workflows/test_python_3_10.yml @@ -18,6 +18,7 @@ jobs: run_common: name: test runs-on: ubuntu-latest + if: ${{ github.event.label.name == 'run-checks' }} strategy: fail-fast: false defaults: diff --git a/.github/workflows/test_python_3_11.yml b/.github/workflows/test_python_3_11.yml index 2dd704eb9..69eb875bd 100644 --- a/.github/workflows/test_python_3_11.yml +++ b/.github/workflows/test_python_3_11.yml @@ -18,6 +18,7 @@ jobs: run_common: name: test runs-on: ubuntu-latest + if: ${{ github.event.label.name == 'run-checks' }} strategy: fail-fast: false defaults: diff --git a/.github/workflows/test_python_3_9.yml b/.github/workflows/test_python_3_9.yml index 99c2b9a7a..380c894ca 100644 --- a/.github/workflows/test_python_3_9.yml +++ b/.github/workflows/test_python_3_9.yml @@ -18,6 +18,7 @@ jobs: run_common: name: test runs-on: ubuntu-latest + if: ${{ github.event.label.name == 'run-checks' }} strategy: fail-fast: false defaults: diff --git a/.github/workflows/test_qdrant.yml b/.github/workflows/test_qdrant.yml index f0a2e3d3f..17d9ac628 100644 --- a/.github/workflows/test_qdrant.yml +++ b/.github/workflows/test_qdrant.yml @@ -18,6 +18,7 @@ jobs: run_qdrant_integration_test: name: test runs-on: ubuntu-latest + if: ${{ github.event.label.name == 'run-checks' }} defaults: run: diff --git a/.github/workflows/test_weaviate.yml b/.github/workflows/test_weaviate.yml index b8eb72383..9a3651dda 100644 --- a/.github/workflows/test_weaviate.yml +++ b/.github/workflows/test_weaviate.yml @@ -18,6 +18,7 @@ jobs: run_weaviate_integration_test: name: test runs-on: ubuntu-latest + if: ${{ github.event.label.name == 'run-checks' }} defaults: run: diff --git a/cognee/api/v1/cognify/code_graph_pipeline.py b/cognee/api/v1/cognify/code_graph_pipeline.py index 59c658300..eeb10d69e 100644 --- a/cognee/api/v1/cognify/code_graph_pipeline.py +++ b/cognee/api/v1/cognify/code_graph_pipeline.py @@ -1,22 +1,35 @@ +# NOTICE: This module contains deprecated functions. +# Use only the run_code_graph_pipeline function; all other functions are deprecated. +# Related issue: COG-906 + import asyncio import logging +from pathlib import Path from typing import Union +from cognee.modules.data.methods import get_datasets, get_datasets_by_name +from cognee.modules.data.methods.get_dataset_data import get_dataset_data +from cognee.modules.data.models import Data, Dataset +from cognee.modules.pipelines import run_tasks +from cognee.modules.pipelines.models import PipelineRunStatus +from cognee.modules.pipelines.operations.get_pipeline_status import \ + get_pipeline_status +from cognee.modules.pipelines.operations.log_pipeline_status import \ + log_pipeline_status +from cognee.modules.pipelines.tasks.Task import Task +from cognee.modules.users.methods import get_default_user +from cognee.modules.users.models import User from cognee.shared.SourceCodeGraph import SourceCodeGraph from cognee.shared.utils import send_telemetry -from cognee.modules.data.models import Dataset, Data -from cognee.modules.data.methods.get_dataset_data import get_dataset_data -from cognee.modules.data.methods import get_datasets, get_datasets_by_name -from cognee.modules.pipelines.tasks.Task import Task -from cognee.modules.pipelines import run_tasks -from cognee.modules.users.models import User -from cognee.modules.users.methods import get_default_user -from cognee.modules.pipelines.models import PipelineRunStatus -from cognee.modules.pipelines.operations.get_pipeline_status import get_pipeline_status -from cognee.modules.pipelines.operations.log_pipeline_status import log_pipeline_status -from cognee.tasks.documents import classify_documents, check_permissions_on_documents, extract_chunks_from_documents +from cognee.tasks.documents import (check_permissions_on_documents, + classify_documents, + extract_chunks_from_documents) from cognee.tasks.graph import extract_graph_from_code +from cognee.tasks.repo_processor import (enrich_dependency_graph, + expand_dependency_graph, + get_repo_file_dependencies) from cognee.tasks.storage import add_data_points +from cognee.tasks.summarization import summarize_code logger = logging.getLogger("code_graph_pipeline") @@ -51,6 +64,7 @@ async def code_graph_pipeline(datasets: Union[str, list[str]] = None, user: User async def run_pipeline(dataset: Dataset, user: User): + '''DEPRECATED: Use `run_code_graph_pipeline` instead. This function will be removed.''' data_documents: list[Data] = await get_dataset_data(dataset_id = dataset.id) document_ids_str = [str(document.id) for document in data_documents] @@ -103,3 +117,30 @@ async def run_pipeline(dataset: Dataset, user: User): def generate_dataset_name(dataset_name: str) -> str: return dataset_name.replace(".", "_").replace(" ", "_") + + +async def run_code_graph_pipeline(repo_path): + import os + import pathlib + import cognee + from cognee.infrastructure.databases.relational import create_db_and_tables + + file_path = Path(__file__).parent + data_directory_path = str(pathlib.Path(os.path.join(file_path, ".data_storage/code_graph")).resolve()) + cognee.config.data_root_directory(data_directory_path) + cognee_directory_path = str(pathlib.Path(os.path.join(file_path, ".cognee_system/code_graph")).resolve()) + cognee.config.system_root_directory(cognee_directory_path) + + await cognee.prune.prune_data() + await cognee.prune.prune_system(metadata=True) + await create_db_and_tables() + + tasks = [ + Task(get_repo_file_dependencies), + Task(enrich_dependency_graph, task_config={"batch_size": 50}), + Task(expand_dependency_graph, task_config={"batch_size": 50}), + Task(summarize_code, task_config={"batch_size": 50}), + Task(add_data_points, task_config={"batch_size": 50}), + ] + + return run_tasks(tasks, repo_path, "cognify_code_pipeline") diff --git a/cognee/api/v1/cognify/cognify_v2.py b/cognee/api/v1/cognify/cognify_v2.py index 2c45774ee..c14f00978 100644 --- a/cognee/api/v1/cognify/cognify_v2.py +++ b/cognee/api/v1/cognify/cognify_v2.py @@ -69,17 +69,18 @@ async def run_cognify_pipeline(dataset: Dataset, user: User, graph_model: BaseMo send_telemetry("cognee.cognify EXECUTION STARTED", user.id) - async with update_status_lock: - task_status = await get_pipeline_status([dataset_id]) + #async with update_status_lock: TODO: Add UI lock to prevent multiple backend requests + task_status = await get_pipeline_status([dataset_id]) - if dataset_id in task_status and task_status[dataset_id] == PipelineRunStatus.DATASET_PROCESSING_STARTED: - logger.info("Dataset %s is already being processed.", dataset_name) - return + if dataset_id in task_status and task_status[dataset_id] == PipelineRunStatus.DATASET_PROCESSING_STARTED: + logger.info("Dataset %s is already being processed.", dataset_name) + return + + await log_pipeline_status(dataset_id, PipelineRunStatus.DATASET_PROCESSING_STARTED, { + "dataset_name": dataset_name, + "files": document_ids_str, + }) - await log_pipeline_status(dataset_id, PipelineRunStatus.DATASET_PROCESSING_STARTED, { - "dataset_name": dataset_name, - "files": document_ids_str, - }) try: cognee_config = get_cognify_config() diff --git a/cognee/api/v1/cognify/routers/get_cognify_router.py b/cognee/api/v1/cognify/routers/get_cognify_router.py index 9616fa71c..257ac994f 100644 --- a/cognee/api/v1/cognify/routers/get_cognify_router.py +++ b/cognee/api/v1/cognify/routers/get_cognify_router.py @@ -1,13 +1,15 @@ from fastapi import APIRouter -from typing import List +from typing import List, Optional from pydantic import BaseModel from cognee.modules.users.models import User from fastapi.responses import JSONResponse from cognee.modules.users.methods import get_authenticated_user from fastapi import Depends +from cognee.shared.data_models import KnowledgeGraph class CognifyPayloadDTO(BaseModel): datasets: List[str] + graph_model: Optional[BaseModel] = KnowledgeGraph def get_cognify_router() -> APIRouter: router = APIRouter() @@ -17,11 +19,11 @@ def get_cognify_router() -> APIRouter: """ This endpoint is responsible for the cognitive processing of the content.""" from cognee.api.v1.cognify.cognify_v2 import cognify as cognee_cognify try: - await cognee_cognify(payload.datasets, user) + await cognee_cognify(payload.datasets, user, payload.graph_model) except Exception as error: return JSONResponse( status_code=409, content={"error": str(error)} ) - return router \ No newline at end of file + return router diff --git a/cognee/api/v1/search/search_v2.py b/cognee/api/v1/search/search_v2.py index 6a5da4648..222ec6791 100644 --- a/cognee/api/v1/search/search_v2.py +++ b/cognee/api/v1/search/search_v2.py @@ -1,7 +1,7 @@ import json from uuid import UUID from enum import Enum -from typing import Callable, Dict +from typing import Callable, Dict, Union from cognee.exceptions import InvalidValueError from cognee.modules.search.operations import log_query, log_result @@ -22,7 +22,12 @@ class SearchType(Enum): CHUNKS = "CHUNKS" COMPLETION = "COMPLETION" -async def search(query_type: SearchType, query_text: str, user: User = None) -> list: +async def search(query_type: SearchType, query_text: str, user: User = None, + datasets: Union[list[str], str, None] = None) -> list: + # We use lists from now on for datasets + if isinstance(datasets, str): + datasets = [datasets] + if user is None: user = await get_default_user() @@ -31,7 +36,7 @@ async def search(query_type: SearchType, query_text: str, user: User = None) -> query = await log_query(query_text, str(query_type), user.id) - own_document_ids = await get_document_ids_for_user(user.id) + own_document_ids = await get_document_ids_for_user(user.id, datasets) search_results = await specific_search(query_type, query_text, user) filtered_search_results = [] diff --git a/cognee/infrastructure/databases/hybrid/falkordb/FalkorDBAdapter.py b/cognee/infrastructure/databases/hybrid/falkordb/FalkorDBAdapter.py index fdc7db069..324ee7bcd 100644 --- a/cognee/infrastructure/databases/hybrid/falkordb/FalkorDBAdapter.py +++ b/cognee/infrastructure/databases/hybrid/falkordb/FalkorDBAdapter.py @@ -1,21 +1,26 @@ import asyncio # from datetime import datetime import json -from uuid import UUID from textwrap import dedent +from uuid import UUID + from falkordb import FalkorDB from cognee.exceptions import InvalidValueError -from cognee.infrastructure.engine import DataPoint -from cognee.infrastructure.databases.graph.graph_db_interface import GraphDBInterface +from cognee.infrastructure.databases.graph.graph_db_interface import \ + GraphDBInterface from cognee.infrastructure.databases.vector.embeddings import EmbeddingEngine -from cognee.infrastructure.databases.vector.vector_db_interface import VectorDBInterface +from cognee.infrastructure.databases.vector.vector_db_interface import \ + VectorDBInterface +from cognee.infrastructure.engine import DataPoint + class IndexSchema(DataPoint): text: str _metadata: dict = { - "index_fields": ["text"] + "index_fields": ["text"], + "type": "IndexSchema" } class FalkorDBAdapter(VectorDBInterface, GraphDBInterface): diff --git a/cognee/infrastructure/databases/vector/lancedb/LanceDBAdapter.py b/cognee/infrastructure/databases/vector/lancedb/LanceDBAdapter.py index 37d340004..1b3fc55c3 100644 --- a/cognee/infrastructure/databases/vector/lancedb/LanceDBAdapter.py +++ b/cognee/infrastructure/databases/vector/lancedb/LanceDBAdapter.py @@ -1,25 +1,29 @@ -from typing import List, Optional, get_type_hints, Generic, TypeVar import asyncio +from typing import Generic, List, Optional, TypeVar, get_type_hints from uuid import UUID + import lancedb +from lancedb.pydantic import LanceModel, Vector from pydantic import BaseModel -from lancedb.pydantic import Vector, LanceModel from cognee.exceptions import InvalidValueError from cognee.infrastructure.engine import DataPoint from cognee.infrastructure.files.storage import LocalStorage from cognee.modules.storage.utils import copy_model, get_own_properties -from ..models.ScoredResult import ScoredResult -from ..vector_db_interface import VectorDBInterface -from ..utils import normalize_distances + from ..embeddings.EmbeddingEngine import EmbeddingEngine +from ..models.ScoredResult import ScoredResult +from ..utils import normalize_distances +from ..vector_db_interface import VectorDBInterface + class IndexSchema(DataPoint): id: str text: str _metadata: dict = { - "index_fields": ["text"] + "index_fields": ["text"], + "type": "IndexSchema" } class LanceDBAdapter(VectorDBInterface): diff --git a/cognee/infrastructure/databases/vector/milvus/MilvusAdapter.py b/cognee/infrastructure/databases/vector/milvus/MilvusAdapter.py index 4e5290dd1..0d4ea05d3 100644 --- a/cognee/infrastructure/databases/vector/milvus/MilvusAdapter.py +++ b/cognee/infrastructure/databases/vector/milvus/MilvusAdapter.py @@ -4,10 +4,12 @@ import asyncio import logging from typing import List, Optional from uuid import UUID + from cognee.infrastructure.engine import DataPoint -from ..vector_db_interface import VectorDBInterface -from ..models.ScoredResult import ScoredResult + from ..embeddings.EmbeddingEngine import EmbeddingEngine +from ..models.ScoredResult import ScoredResult +from ..vector_db_interface import VectorDBInterface logger = logging.getLogger("MilvusAdapter") @@ -16,7 +18,8 @@ class IndexSchema(DataPoint): text: str _metadata: dict = { - "index_fields": ["text"] + "index_fields": ["text"], + "type": "IndexSchema" } diff --git a/cognee/infrastructure/databases/vector/pgvector/PGVectorAdapter.py b/cognee/infrastructure/databases/vector/pgvector/PGVectorAdapter.py index 8faf1cd6d..3f0565253 100644 --- a/cognee/infrastructure/databases/vector/pgvector/PGVectorAdapter.py +++ b/cognee/infrastructure/databases/vector/pgvector/PGVectorAdapter.py @@ -1,27 +1,30 @@ import asyncio -from uuid import UUID from typing import List, Optional, get_type_hints +from uuid import UUID + from sqlalchemy.orm import Mapped, mapped_column -from sqlalchemy import JSON, Column, Table, select, delete +from sqlalchemy import JSON, Column, Table, select, delete, MetaData from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker from cognee.exceptions import InvalidValueError from cognee.infrastructure.databases.exceptions import EntityNotFoundError from cognee.infrastructure.engine import DataPoint -from .serialize_data import serialize_data -from ..models.ScoredResult import ScoredResult -from ..vector_db_interface import VectorDBInterface -from ..utils import normalize_distances -from ..embeddings.EmbeddingEngine import EmbeddingEngine -from ...relational.sqlalchemy.SqlAlchemyAdapter import SQLAlchemyAdapter from ...relational.ModelBase import Base +from ...relational.sqlalchemy.SqlAlchemyAdapter import SQLAlchemyAdapter +from ..embeddings.EmbeddingEngine import EmbeddingEngine +from ..models.ScoredResult import ScoredResult +from ..utils import normalize_distances +from ..vector_db_interface import VectorDBInterface +from .serialize_data import serialize_data + class IndexSchema(DataPoint): text: str _metadata: dict = { - "index_fields": ["text"] + "index_fields": ["text"], + "type": "IndexSchema" } class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface): @@ -48,10 +51,12 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface): async def has_collection(self, collection_name: str) -> bool: async with self.engine.begin() as connection: - # Load the schema information into the MetaData object - await connection.run_sync(Base.metadata.reflect) + # Create a MetaData instance to load table information + metadata = MetaData() + # Load table information from schema into MetaData + await connection.run_sync(metadata.reflect) - if collection_name in Base.metadata.tables: + if collection_name in metadata.tables: return True else: return False @@ -87,6 +92,7 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface): async def create_data_points( self, collection_name: str, data_points: List[DataPoint] ): + data_point_types = get_type_hints(DataPoint) if not await self.has_collection(collection_name): await self.create_collection( collection_name = collection_name, @@ -106,7 +112,7 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface): primary_key: Mapped[int] = mapped_column( primary_key=True, autoincrement=True ) - id: Mapped[type(data_points[0].id)] + id: Mapped[data_point_types["id"]] payload = Column(JSON) vector = Column(self.Vector(vector_size)) @@ -145,10 +151,12 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface): with an async engine. """ async with self.engine.begin() as connection: - # Load the schema information into the MetaData object - await connection.run_sync(Base.metadata.reflect) - if collection_name in Base.metadata.tables: - return Base.metadata.tables[collection_name] + # Create a MetaData instance to load table information + metadata = MetaData() + # Load table information from schema into MetaData + await connection.run_sync(metadata.reflect) + if collection_name in metadata.tables: + return metadata.tables[collection_name] else: raise EntityNotFoundError(message=f"Table '{collection_name}' not found.") diff --git a/cognee/infrastructure/databases/vector/qdrant/QDrantAdapter.py b/cognee/infrastructure/databases/vector/qdrant/QDrantAdapter.py index d5d2a1a5c..b63139bc5 100644 --- a/cognee/infrastructure/databases/vector/qdrant/QDrantAdapter.py +++ b/cognee/infrastructure/databases/vector/qdrant/QDrantAdapter.py @@ -1,13 +1,16 @@ import logging +from typing import Dict, List, Optional from uuid import UUID -from typing import List, Dict, Optional + from qdrant_client import AsyncQdrantClient, models from cognee.exceptions import InvalidValueError -from cognee.infrastructure.databases.vector.models.ScoredResult import ScoredResult +from cognee.infrastructure.databases.vector.models.ScoredResult import \ + ScoredResult from cognee.infrastructure.engine import DataPoint -from ..vector_db_interface import VectorDBInterface + from ..embeddings.EmbeddingEngine import EmbeddingEngine +from ..vector_db_interface import VectorDBInterface logger = logging.getLogger("QDrantAdapter") @@ -15,7 +18,8 @@ class IndexSchema(DataPoint): text: str _metadata: dict = { - "index_fields": ["text"] + "index_fields": ["text"], + "type": "IndexSchema" } # class CollectionConfig(BaseModel, extra = "forbid"): diff --git a/cognee/infrastructure/databases/vector/weaviate_db/WeaviateAdapter.py b/cognee/infrastructure/databases/vector/weaviate_db/WeaviateAdapter.py index c16f765b0..31162b1b5 100644 --- a/cognee/infrastructure/databases/vector/weaviate_db/WeaviateAdapter.py +++ b/cognee/infrastructure/databases/vector/weaviate_db/WeaviateAdapter.py @@ -5,9 +5,10 @@ from uuid import UUID from cognee.exceptions import InvalidValueError from cognee.infrastructure.engine import DataPoint -from ..vector_db_interface import VectorDBInterface -from ..models.ScoredResult import ScoredResult + from ..embeddings.EmbeddingEngine import EmbeddingEngine +from ..models.ScoredResult import ScoredResult +from ..vector_db_interface import VectorDBInterface logger = logging.getLogger("WeaviateAdapter") @@ -15,7 +16,8 @@ class IndexSchema(DataPoint): text: str _metadata: dict = { - "index_fields": ["text"] + "index_fields": ["text"], + "type": "IndexSchema" } class WeaviateAdapter(VectorDBInterface): diff --git a/cognee/infrastructure/engine/models/DataPoint.py b/cognee/infrastructure/engine/models/DataPoint.py index abb924f2f..e08041146 100644 --- a/cognee/infrastructure/engine/models/DataPoint.py +++ b/cognee/infrastructure/engine/models/DataPoint.py @@ -1,8 +1,10 @@ -from typing_extensions import TypedDict -from uuid import UUID, uuid4 -from typing import Optional from datetime import datetime, timezone +from typing import Optional +from uuid import UUID, uuid4 + from pydantic import BaseModel, Field +from typing_extensions import TypedDict + class MetaData(TypedDict): index_fields: list[str] @@ -13,7 +15,8 @@ class DataPoint(BaseModel): updated_at: Optional[datetime] = datetime.now(timezone.utc) topological_rank: Optional[int] = 0 _metadata: Optional[MetaData] = { - "index_fields": [] + "index_fields": [], + "type": "DataPoint" } # class Config: @@ -39,4 +42,4 @@ class DataPoint(BaseModel): @classmethod def get_embeddable_property_names(self, data_point): - return data_point._metadata["index_fields"] or [] + return data_point._metadata["index_fields"] or [] \ No newline at end of file diff --git a/cognee/infrastructure/llm/prompts/summarize_code.txt b/cognee/infrastructure/llm/prompts/summarize_code.txt new file mode 100644 index 000000000..405585617 --- /dev/null +++ b/cognee/infrastructure/llm/prompts/summarize_code.txt @@ -0,0 +1,10 @@ +You are an expert Python programmer and technical writer. Your task is to summarize the given Python code snippet or file. +The code may contain multiple imports, classes, functions, constants and logic. Provide a clear, structured explanation of its components +and their relationships. + +Instructions: +Provide an overview: Start with a high-level summary of what the code does as a whole. +Break it down: Summarize each class and function individually, explaining their purpose and how they interact. +Describe the workflow: Outline how the classes and functions work together. Mention any control flow (e.g., main functions, entry points, loops). +Key features: Highlight important elements like arguments, return values, or unique logic. +Maintain clarity: Write in plain English for someone familiar with Python but unfamiliar with this code. \ No newline at end of file diff --git a/cognee/modules/chunking/models/DocumentChunk.py b/cognee/modules/chunking/models/DocumentChunk.py index b5faea560..8729596df 100644 --- a/cognee/modules/chunking/models/DocumentChunk.py +++ b/cognee/modules/chunking/models/DocumentChunk.py @@ -1,8 +1,10 @@ from typing import List, Optional + from cognee.infrastructure.engine import DataPoint from cognee.modules.data.processing.document_types import Document from cognee.modules.engine.models import Entity + class DocumentChunk(DataPoint): __tablename__ = "document_chunk" text: str @@ -12,6 +14,7 @@ class DocumentChunk(DataPoint): is_part_of: Document contains: List[Entity] = None - _metadata: Optional[dict] = { + _metadata: dict = { "index_fields": ["text"], + "type": "DocumentChunk" } diff --git a/cognee/modules/data/extraction/extract_summary.py b/cognee/modules/data/extraction/extract_summary.py index a17bf3ae6..10d429da9 100644 --- a/cognee/modules/data/extraction/extract_summary.py +++ b/cognee/modules/data/extraction/extract_summary.py @@ -1,7 +1,11 @@ from typing import Type + from pydantic import BaseModel -from cognee.infrastructure.llm.prompts import read_query_prompt + from cognee.infrastructure.llm.get_llm_client import get_llm_client +from cognee.infrastructure.llm.prompts import read_query_prompt +from cognee.shared.data_models import SummarizedCode + async def extract_summary(content: str, response_model: Type[BaseModel]): llm_client = get_llm_client() @@ -11,3 +15,7 @@ async def extract_summary(content: str, response_model: Type[BaseModel]): llm_output = await llm_client.acreate_structured_output(content, system_prompt, response_model) return llm_output + +async def extract_code_summary(content: str): + + return await extract_summary(content, response_model=SummarizedCode) diff --git a/cognee/modules/data/models/__init__.py b/cognee/modules/data/models/__init__.py index 5d79dbd40..bd5774f88 100644 --- a/cognee/modules/data/models/__init__.py +++ b/cognee/modules/data/models/__init__.py @@ -1,2 +1,3 @@ from .Data import Data from .Dataset import Dataset +from .DatasetData import DatasetData diff --git a/cognee/modules/data/processing/document_types/AudioDocument.py b/cognee/modules/data/processing/document_types/AudioDocument.py index 0d2cddd3d..268338703 100644 --- a/cognee/modules/data/processing/document_types/AudioDocument.py +++ b/cognee/modules/data/processing/document_types/AudioDocument.py @@ -1,6 +1,6 @@ from cognee.infrastructure.llm.get_llm_client import get_llm_client -from cognee.modules.chunking.TextChunker import TextChunker from .Document import Document +from .ChunkerMapping import ChunkerConfig class AudioDocument(Document): type: str = "audio" @@ -9,11 +9,12 @@ class AudioDocument(Document): result = get_llm_client().create_transcript(self.raw_data_location) return(result.text) - def read(self, chunk_size: int): + def read(self, chunk_size: int, chunker: str): # Transcribe the audio file text = self.create_transcript() - chunker = TextChunker(self, chunk_size = chunk_size, get_text = lambda: [text]) + chunker_func = ChunkerConfig.get_chunker(chunker) + chunker = chunker_func(self, chunk_size = chunk_size, get_text = lambda: [text]) yield from chunker.read() diff --git a/cognee/modules/data/processing/document_types/ChunkerMapping.py b/cognee/modules/data/processing/document_types/ChunkerMapping.py new file mode 100644 index 000000000..14dbb8bb7 --- /dev/null +++ b/cognee/modules/data/processing/document_types/ChunkerMapping.py @@ -0,0 +1,15 @@ +from cognee.modules.chunking.TextChunker import TextChunker + +class ChunkerConfig: + chunker_mapping = { + "text_chunker": TextChunker + } + + @classmethod + def get_chunker(cls, chunker_name: str): + chunker_class = cls.chunker_mapping.get(chunker_name) + if chunker_class is None: + raise NotImplementedError( + f"Chunker '{chunker_name}' is not implemented. Available options: {list(cls.chunker_mapping.keys())}" + ) + return chunker_class \ No newline at end of file diff --git a/cognee/modules/data/processing/document_types/Document.py b/cognee/modules/data/processing/document_types/Document.py index 45441dcce..8d6a3dafb 100644 --- a/cognee/modules/data/processing/document_types/Document.py +++ b/cognee/modules/data/processing/document_types/Document.py @@ -1,12 +1,17 @@ -from cognee.infrastructure.engine import DataPoint from uuid import UUID +from cognee.infrastructure.engine import DataPoint + + class Document(DataPoint): - type: str name: str raw_data_location: str metadata_id: UUID mime_type: str + _metadata: dict = { + "index_fields": ["name"], + "type": "Document" + } - def read(self, chunk_size: int) -> str: + def read(self, chunk_size: int, chunker = str) -> str: pass diff --git a/cognee/modules/data/processing/document_types/ImageDocument.py b/cognee/modules/data/processing/document_types/ImageDocument.py index e8f0dd8ee..352486bd8 100644 --- a/cognee/modules/data/processing/document_types/ImageDocument.py +++ b/cognee/modules/data/processing/document_types/ImageDocument.py @@ -1,6 +1,6 @@ from cognee.infrastructure.llm.get_llm_client import get_llm_client -from cognee.modules.chunking.TextChunker import TextChunker from .Document import Document +from .ChunkerMapping import ChunkerConfig class ImageDocument(Document): type: str = "image" @@ -10,10 +10,11 @@ class ImageDocument(Document): result = get_llm_client().transcribe_image(self.raw_data_location) return(result.choices[0].message.content) - def read(self, chunk_size: int): + def read(self, chunk_size: int, chunker: str): # Transcribe the image file text = self.transcribe_image() - chunker = TextChunker(self, chunk_size = chunk_size, get_text = lambda: [text]) + chunker_func = ChunkerConfig.get_chunker(chunker) + chunker = chunker_func(self, chunk_size = chunk_size, get_text = lambda: [text]) yield from chunker.read() diff --git a/cognee/modules/data/processing/document_types/PdfDocument.py b/cognee/modules/data/processing/document_types/PdfDocument.py index 2d1941996..361214718 100644 --- a/cognee/modules/data/processing/document_types/PdfDocument.py +++ b/cognee/modules/data/processing/document_types/PdfDocument.py @@ -1,11 +1,11 @@ from pypdf import PdfReader -from cognee.modules.chunking.TextChunker import TextChunker from .Document import Document +from .ChunkerMapping import ChunkerConfig class PdfDocument(Document): type: str = "pdf" - def read(self, chunk_size: int): + def read(self, chunk_size: int, chunker: str): file = PdfReader(self.raw_data_location) def get_text(): @@ -13,7 +13,8 @@ class PdfDocument(Document): page_text = page.extract_text() yield page_text - chunker = TextChunker(self, chunk_size = chunk_size, get_text = get_text) + chunker_func = ChunkerConfig.get_chunker(chunker) + chunker = chunker_func(self, chunk_size = chunk_size, get_text = get_text) yield from chunker.read() diff --git a/cognee/modules/data/processing/document_types/TextDocument.py b/cognee/modules/data/processing/document_types/TextDocument.py index 32d3416b9..3952d9845 100644 --- a/cognee/modules/data/processing/document_types/TextDocument.py +++ b/cognee/modules/data/processing/document_types/TextDocument.py @@ -1,10 +1,10 @@ -from cognee.modules.chunking.TextChunker import TextChunker from .Document import Document +from .ChunkerMapping import ChunkerConfig class TextDocument(Document): type: str = "text" - def read(self, chunk_size: int): + def read(self, chunk_size: int, chunker: str): def get_text(): with open(self.raw_data_location, mode = "r", encoding = "utf-8") as file: while True: @@ -15,6 +15,8 @@ class TextDocument(Document): yield text - chunker = TextChunker(self, chunk_size = chunk_size, get_text = get_text) + chunker_func = ChunkerConfig.get_chunker(chunker) + + chunker = chunker_func(self, chunk_size = chunk_size, get_text = get_text) yield from chunker.read() diff --git a/cognee/modules/engine/models/Entity.py b/cognee/modules/engine/models/Entity.py index b805d3d11..16e0ca3d8 100644 --- a/cognee/modules/engine/models/Entity.py +++ b/cognee/modules/engine/models/Entity.py @@ -10,4 +10,5 @@ class Entity(DataPoint): _metadata: dict = { "index_fields": ["name"], + "type": "Entity" } diff --git a/cognee/modules/engine/models/EntityType.py b/cognee/modules/engine/models/EntityType.py index 1c7843cfd..d3cc54311 100644 --- a/cognee/modules/engine/models/EntityType.py +++ b/cognee/modules/engine/models/EntityType.py @@ -1,11 +1,12 @@ from cognee.infrastructure.engine import DataPoint + class EntityType(DataPoint): __tablename__ = "entity_type" name: str - type: str description: str _metadata: dict = { "index_fields": ["name"], + "type": "EntityType" } diff --git a/cognee/modules/graph/models/EdgeType.py b/cognee/modules/graph/models/EdgeType.py index f9554d25d..998f08d8d 100644 --- a/cognee/modules/graph/models/EdgeType.py +++ b/cognee/modules/graph/models/EdgeType.py @@ -1,11 +1,14 @@ from typing import Optional + from cognee.infrastructure.engine import DataPoint + class EdgeType(DataPoint): __tablename__ = "edge_type" relationship_name: str number_of_edges: int - _metadata: Optional[dict] = { + _metadata: dict = { "index_fields": ["relationship_name"], + "type": "EdgeType" } \ No newline at end of file diff --git a/cognee/modules/graph/utils/convert_node_to_data_point.py b/cognee/modules/graph/utils/convert_node_to_data_point.py index 292f53733..602a7ffa3 100644 --- a/cognee/modules/graph/utils/convert_node_to_data_point.py +++ b/cognee/modules/graph/utils/convert_node_to_data_point.py @@ -2,7 +2,7 @@ from cognee.infrastructure.engine import DataPoint def convert_node_to_data_point(node_data: dict) -> DataPoint: - subclass = find_subclass_by_name(DataPoint, node_data["type"]) + subclass = find_subclass_by_name(DataPoint, node_data._metadata["type"]) return subclass(**node_data) diff --git a/cognee/modules/users/permissions/methods/get_document_ids_for_user.py b/cognee/modules/users/permissions/methods/get_document_ids_for_user.py index 79736db0f..d439fb4f5 100644 --- a/cognee/modules/users/permissions/methods/get_document_ids_for_user.py +++ b/cognee/modules/users/permissions/methods/get_document_ids_for_user.py @@ -1,9 +1,11 @@ from uuid import UUID from sqlalchemy import select from cognee.infrastructure.databases.relational import get_relational_engine +from cognee.modules.data.models import Dataset, DatasetData from ...models import ACL, Resource, Permission -async def get_document_ids_for_user(user_id: UUID) -> list[str]: + +async def get_document_ids_for_user(user_id: UUID, datasets: list[str] = None) -> list[str]: db_engine = get_relational_engine() async with db_engine.get_async_session() as session: @@ -18,4 +20,31 @@ async def get_document_ids_for_user(user_id: UUID) -> list[str]: ) )).all() + if datasets: + documents_ids_in_dataset = set() + # If datasets are specified filter out documents that aren't part of the specified datasets + for dataset in datasets: + # Find dataset id for dataset element + dataset_id = (await session.scalars( + select(Dataset.id) + .where( + Dataset.name == dataset, + Dataset.owner_id == user_id, + ) + )).one_or_none() + + # Check which documents are connected to this dataset + for document_id in document_ids: + data_id = (await session.scalars( + select(DatasetData.data_id) + .where( + DatasetData.dataset_id == dataset_id, + DatasetData.data_id == document_id, + ) + )).one_or_none() + + # If document is related to dataset added it to return value + if data_id: + documents_ids_in_dataset.add(document_id) + return list(documents_ids_in_dataset) return document_ids diff --git a/cognee/shared/CodeGraphEntities.py b/cognee/shared/CodeGraphEntities.py index 8859fd0d6..23b8879c2 100644 --- a/cognee/shared/CodeGraphEntities.py +++ b/cognee/shared/CodeGraphEntities.py @@ -1,15 +1,19 @@ from typing import List, Optional + from cognee.infrastructure.engine import DataPoint + class Repository(DataPoint): __tablename__ = "Repository" path: str - type: Optional[str] = "Repository" + _metadata: dict = { + "index_fields": ["source_code"], + "type": "Repository" + } class CodeFile(DataPoint): __tablename__ = "codefile" extracted_id: str # actually file path - type: Optional[str] = "CodeFile" source_code: Optional[str] = None part_of: Optional[Repository] = None depends_on: Optional[List["CodeFile"]] = None @@ -17,24 +21,27 @@ class CodeFile(DataPoint): contains: Optional[List["CodePart"]] = None _metadata: dict = { - "index_fields": ["source_code"] + "index_fields": ["source_code"], + "type": "CodeFile" } class CodePart(DataPoint): __tablename__ = "codepart" # part_of: Optional[CodeFile] source_code: str - type: Optional[str] = "CodePart" - + _metadata: dict = { - "index_fields": ["source_code"] + "index_fields": ["source_code"], + "type": "CodePart" } class CodeRelationship(DataPoint): source_id: str target_id: str - type: str # between files relation: str # depends on or depends directly + _metadata: dict = { + "type": "CodeRelationship" + } CodeFile.model_rebuild() CodePart.model_rebuild() diff --git a/cognee/shared/SourceCodeGraph.py b/cognee/shared/SourceCodeGraph.py index 0fc8f9487..3de72c5fd 100644 --- a/cognee/shared/SourceCodeGraph.py +++ b/cognee/shared/SourceCodeGraph.py @@ -1,79 +1,90 @@ -from typing import Any, List, Union, Literal, Optional +from typing import Any, List, Literal, Optional, Union + from cognee.infrastructure.engine import DataPoint + class Variable(DataPoint): id: str name: str - type: Literal["Variable"] = "Variable" description: str is_static: Optional[bool] = False default_value: Optional[str] = None data_type: str _metadata = { - "index_fields": ["name"] + "index_fields": ["name"], + "type": "Variable" } class Operator(DataPoint): id: str name: str - type: Literal["Operator"] = "Operator" description: str return_type: str + _metadata = { + "index_fields": ["name"], + "type": "Operator" + } class Class(DataPoint): id: str name: str - type: Literal["Class"] = "Class" description: str constructor_parameters: List[Variable] extended_from_class: Optional["Class"] = None has_methods: List["Function"] _metadata = { - "index_fields": ["name"] + "index_fields": ["name"], + "type": "Class" } class ClassInstance(DataPoint): id: str name: str - type: Literal["ClassInstance"] = "ClassInstance" description: str from_class: Class instantiated_by: Union["Function"] instantiation_arguments: List[Variable] _metadata = { - "index_fields": ["name"] + "index_fields": ["name"], + "type": "ClassInstance" } class Function(DataPoint): id: str name: str - type: Literal["Function"] = "Function" description: str parameters: List[Variable] return_type: str is_static: Optional[bool] = False _metadata = { - "index_fields": ["name"] + "index_fields": ["name"], + "type": "Function" } class FunctionCall(DataPoint): id: str - type: Literal["FunctionCall"] = "FunctionCall" called_by: Union[Function, Literal["main"]] function_called: Function function_arguments: List[Any] + _metadata = { + "index_fields": [], + "type": "FunctionCall" + } class Expression(DataPoint): id: str name: str - type: Literal["Expression"] = "Expression" description: str expression: str members: List[Union[Variable, Function, Operator, "Expression"]] + _metadata = { + "index_fields": ["name"], + "type": "Expression" + } class SourceCodeGraph(DataPoint): id: str @@ -89,8 +100,13 @@ class SourceCodeGraph(DataPoint): Operator, Expression, ]] + _metadata = { + "index_fields": ["name"], + "type": "SourceCodeGraph" + } + Class.model_rebuild() ClassInstance.model_rebuild() Expression.model_rebuild() FunctionCall.model_rebuild() -SourceCodeGraph.model_rebuild() +SourceCodeGraph.model_rebuild() \ No newline at end of file diff --git a/cognee/shared/data_models.py b/cognee/shared/data_models.py index 6cb4d436a..dec53cfcb 100644 --- a/cognee/shared/data_models.py +++ b/cognee/shared/data_models.py @@ -1,9 +1,11 @@ """Data models for the cognitive architecture.""" from enum import Enum, auto -from typing import Optional, List, Union, Dict, Any +from typing import Any, Dict, List, Optional, Union + from pydantic import BaseModel, Field + class Node(BaseModel): """Node in a knowledge graph.""" id: str @@ -194,6 +196,29 @@ class SummarizedContent(BaseModel): summary: str description: str +class SummarizedFunction(BaseModel): + name: str + description: str + inputs: Optional[List[str]] = None + outputs: Optional[List[str]] = None + decorators: Optional[List[str]] = None + +class SummarizedClass(BaseModel): + name: str + description: str + methods: Optional[List[SummarizedFunction]] = None + decorators: Optional[List[str]] = None + +class SummarizedCode(BaseModel): + file_name: str + high_level_summary: str + key_features: List[str] + imports: List[str] = [] + constants: List[str] = [] + classes: List[SummarizedClass] = [] + functions: List[SummarizedFunction] = [] + workflow_description: Optional[str] = None + class GraphDBType(Enum): NETWORKX = auto() diff --git a/cognee/tasks/documents/extract_chunks_from_documents.py b/cognee/tasks/documents/extract_chunks_from_documents.py index ec19a786d..423b87b69 100644 --- a/cognee/tasks/documents/extract_chunks_from_documents.py +++ b/cognee/tasks/documents/extract_chunks_from_documents.py @@ -1,7 +1,7 @@ from cognee.modules.data.processing.document_types.Document import Document -async def extract_chunks_from_documents(documents: list[Document], chunk_size: int = 1024): +async def extract_chunks_from_documents(documents: list[Document], chunk_size: int = 1024, chunker = 'text_chunker'): for document in documents: - for document_chunk in document.read(chunk_size = chunk_size): + for document_chunk in document.read(chunk_size = chunk_size, chunker = chunker): yield document_chunk diff --git a/cognee/tasks/repo_processor/top_down_repo_parse.py b/cognee/tasks/repo_processor/top_down_repo_parse.py new file mode 100644 index 000000000..52f58f811 --- /dev/null +++ b/cognee/tasks/repo_processor/top_down_repo_parse.py @@ -0,0 +1,171 @@ +import os + +import jedi +import parso +from tqdm import tqdm + +from . import logger + +_NODE_TYPE_MAP = { + "funcdef": "func_def", + "classdef": "class_def", + "async_funcdef": "async_func_def", + "async_stmt": "async_func_def", + "simple_stmt": "var_def", +} + +def _create_object_dict(name_node, type_name=None): + return { + "name": name_node.value, + "line": name_node.start_pos[0], + "column": name_node.start_pos[1], + "type": type_name, + } + + +def _parse_node(node): + """Parse a node to extract importable object details, including async functions and classes.""" + node_type = _NODE_TYPE_MAP.get(node.type) + + if node.type in {"funcdef", "classdef", "async_funcdef"}: + return [_create_object_dict(node.name, type_name=node_type)] + if node.type == "async_stmt" and len(node.children) > 1: + function_node = node.children[1] + if function_node.type == "funcdef": + return [_create_object_dict(function_node.name, type_name=_NODE_TYPE_MAP.get(function_node.type))] + if node.type == "simple_stmt": + # TODO: Handle multi-level/nested unpacking variable definitions in the future + expr_child = node.children[0] + if expr_child.type != "expr_stmt": + return [] + if expr_child.children[0].type == "testlist_star_expr": + name_targets = expr_child.children[0].children + else: + name_targets = expr_child.children + return [ + _create_object_dict(target, type_name=_NODE_TYPE_MAP.get(target.type)) + for target in name_targets + if target.type == "name" + ] + return [] + + + +def extract_importable_objects_with_positions_from_source_code(source_code): + """Extract top-level objects in a Python source code string with their positions (line/column).""" + try: + tree = parso.parse(source_code) + except Exception as e: + logger.error(f"Error parsing source code: {e}") + return [] + + importable_objects = [] + try: + for node in tree.children: + importable_objects.extend(_parse_node(node)) + except Exception as e: + logger.error(f"Error extracting nodes from parsed tree: {e}") + return [] + + return importable_objects + + +def extract_importable_objects_with_positions(file_path): + """Extract top-level objects in a Python file with their positions (line/column).""" + try: + with open(file_path, "r") as file: + source_code = file.read() + except Exception as e: + logger.error(f"Error reading file {file_path}: {e}") + return [] + + return extract_importable_objects_with_positions_from_source_code(source_code) + + + +def find_entity_usages(script, line, column): + """ + Return a list of files in the repo where the entity at module_path:line,column is used. + """ + usages = set() + + + try: + inferred = script.infer(line, column) + except Exception as e: + logger.error(f"Error inferring entity at {script.path}:{line},{column}: {e}") + return [] + + if not inferred or not inferred[0]: + logger.info(f"No entity inferred at {script.path}:{line},{column}") + return [] + + logger.debug(f"Inferred entity: {inferred[0].name}, type: {inferred[0].type}") + + try: + references = script.get_references(line=line, column=column, scope="project", include_builtins=False) + except Exception as e: + logger.error(f"Error retrieving references for entity at {script.path}:{line},{column}: {e}") + references = [] + + for ref in references: + if ref.module_path: # Collect unique module paths + usages.add(ref.module_path) + logger.info(f"Entity used in: {ref.module_path}") + + return list(usages) + +def parse_file_with_references(project, file_path): + """Parse a file to extract object names and their references within a project.""" + try: + importable_objects = extract_importable_objects_with_positions(file_path) + except Exception as e: + logger.error(f"Error extracting objects from {file_path}: {e}") + return [] + + if not os.path.isfile(file_path): + logger.warning(f"Module file does not exist: {file_path}") + return [] + + try: + script = jedi.Script(path=file_path, project=project) + except Exception as e: + logger.error(f"Error initializing Jedi Script: {e}") + return [] + + parsed_results = [ + { + "name": obj["name"], + "type": obj["type"], + "references": find_entity_usages(script, obj["line"], obj["column"]), + } + for obj in importable_objects + ] + return parsed_results + + +def parse_repo(repo_path): + """Parse a repository to extract object names, types, and references for all Python files.""" + try: + project = jedi.Project(path=repo_path) + except Exception as e: + logger.error(f"Error creating Jedi project for repository at {repo_path}: {e}") + return {} + + EXCLUDE_DIRS = {'venv', '.git', '__pycache__', 'build'} + + python_files = [ + os.path.join(directory, file) + for directory, _, filenames in os.walk(repo_path) + if not any(excluded in directory for excluded in EXCLUDE_DIRS) + for file in filenames + if file.endswith(".py") and os.path.getsize(os.path.join(directory, file)) > 0 + ] + + results = { + file_path: parse_file_with_references(project, file_path) + for file_path in tqdm(python_files) + } + + return results + diff --git a/cognee/tasks/storage/index_data_points.py b/cognee/tasks/storage/index_data_points.py index 786168b58..857e4d777 100644 --- a/cognee/tasks/storage/index_data_points.py +++ b/cognee/tasks/storage/index_data_points.py @@ -1,6 +1,7 @@ from cognee.infrastructure.databases.vector import get_vector_engine from cognee.infrastructure.engine import DataPoint + async def index_data_points(data_points: list[DataPoint]): created_indexes = {} index_points = {} @@ -80,11 +81,20 @@ if __name__ == "__main__": class Car(DataPoint): model: str color: str + _metadata = { + "index_fields": ["name"], + "type": "Car" + } + class Person(DataPoint): name: str age: int owns_car: list[Car] + _metadata = { + "index_fields": ["name"], + "type": "Person" + } car1 = Car(model = "Tesla Model S", color = "Blue") car2 = Car(model = "Toyota Camry", color = "Red") @@ -92,4 +102,4 @@ if __name__ == "__main__": data_points = get_data_points_from_model(person) - print(data_points) + print(data_points) \ No newline at end of file diff --git a/cognee/tasks/summarization/models.py b/cognee/tasks/summarization/models.py index 6fef4fb02..add448155 100644 --- a/cognee/tasks/summarization/models.py +++ b/cognee/tasks/summarization/models.py @@ -10,6 +10,7 @@ class TextSummary(DataPoint): _metadata: dict = { "index_fields": ["text"], + "type": "TextSummary" } @@ -20,4 +21,5 @@ class CodeSummary(DataPoint): _metadata: dict = { "index_fields": ["text"], + "type": "CodeSummary" } diff --git a/cognee/tasks/summarization/summarize_code.py b/cognee/tasks/summarization/summarize_code.py index 277081f40..b116e57a9 100644 --- a/cognee/tasks/summarization/summarize_code.py +++ b/cognee/tasks/summarization/summarize_code.py @@ -1,39 +1,40 @@ import asyncio -from typing import Type +from typing import AsyncGenerator, Union from uuid import uuid5 - -from pydantic import BaseModel +from typing import Type from cognee.infrastructure.engine import DataPoint -from cognee.modules.data.extraction.extract_summary import extract_summary -from cognee.shared.CodeGraphEntities import CodeFile -from cognee.tasks.storage import add_data_points - +from cognee.modules.data.extraction.extract_summary import extract_code_summary from .models import CodeSummary async def summarize_code( - code_files: list[DataPoint], - summarization_model: Type[BaseModel], -) -> list[DataPoint]: - if len(code_files) == 0: - return code_files + code_graph_nodes: list[DataPoint], +) -> AsyncGenerator[Union[DataPoint, CodeSummary], None]: + if len(code_graph_nodes) == 0: + return - code_files_data_points = [file for file in code_files if isinstance(file, CodeFile)] + code_data_points = [file for file in code_graph_nodes if hasattr(file, "source_code")] file_summaries = await asyncio.gather( - *[extract_summary(file.source_code, summarization_model) for file in code_files_data_points] + *[extract_code_summary(file.source_code) for file in code_data_points] ) - summaries = [ - CodeSummary( - id = uuid5(file.id, "CodeSummary"), - made_from = file, - text = file_summaries[file_index].summary, + file_summaries_map = { + code_data_point.extracted_id: str(file_summary) + for code_data_point, file_summary in zip(code_data_points, file_summaries) + } + + for node in code_graph_nodes: + if not isinstance(node, DataPoint): + continue + yield node + + if not hasattr(node, "source_code"): + continue + + yield CodeSummary( + id=uuid5(node.id, "CodeSummary"), + made_from=node, + text=file_summaries_map[node.extracted_id], ) - for (file_index, file) in enumerate(code_files_data_points) - ] - - await add_data_points(summaries) - - return code_files diff --git a/cognee/tests/integration/documents/AudioDocument_test.py b/cognee/tests/integration/documents/AudioDocument_test.py index da8b85d0b..151f4c0b2 100644 --- a/cognee/tests/integration/documents/AudioDocument_test.py +++ b/cognee/tests/integration/documents/AudioDocument_test.py @@ -31,7 +31,7 @@ def test_AudioDocument(): ) with patch.object(AudioDocument, "create_transcript", return_value=TEST_TEXT): for ground_truth, paragraph_data in zip( - GROUND_TRUTH, document.read(chunk_size=64) + GROUND_TRUTH, document.read(chunk_size=64, chunker='text_chunker') ): assert ( ground_truth["word_count"] == paragraph_data.word_count diff --git a/cognee/tests/integration/documents/ImageDocument_test.py b/cognee/tests/integration/documents/ImageDocument_test.py index 8a8ee8ef3..40e0155af 100644 --- a/cognee/tests/integration/documents/ImageDocument_test.py +++ b/cognee/tests/integration/documents/ImageDocument_test.py @@ -21,7 +21,7 @@ def test_ImageDocument(): with patch.object(ImageDocument, "transcribe_image", return_value=TEST_TEXT): for ground_truth, paragraph_data in zip( - GROUND_TRUTH, document.read(chunk_size=64) + GROUND_TRUTH, document.read(chunk_size=64, chunker='text_chunker') ): assert ( ground_truth["word_count"] == paragraph_data.word_count diff --git a/cognee/tests/integration/documents/PdfDocument_test.py b/cognee/tests/integration/documents/PdfDocument_test.py index ac57eaf75..25d4cf6c6 100644 --- a/cognee/tests/integration/documents/PdfDocument_test.py +++ b/cognee/tests/integration/documents/PdfDocument_test.py @@ -22,7 +22,7 @@ def test_PdfDocument(): ) for ground_truth, paragraph_data in zip( - GROUND_TRUTH, document.read(chunk_size=1024) + GROUND_TRUTH, document.read(chunk_size=1024, chunker='text_chunker') ): assert ( ground_truth["word_count"] == paragraph_data.word_count diff --git a/cognee/tests/integration/documents/TextDocument_test.py b/cognee/tests/integration/documents/TextDocument_test.py index f663418f5..91f38968e 100644 --- a/cognee/tests/integration/documents/TextDocument_test.py +++ b/cognee/tests/integration/documents/TextDocument_test.py @@ -33,7 +33,7 @@ def test_TextDocument(input_file, chunk_size): ) for ground_truth, paragraph_data in zip( - GROUND_TRUTH[input_file], document.read(chunk_size=chunk_size) + GROUND_TRUTH[input_file], document.read(chunk_size=chunk_size, chunker='text_chunker') ): assert ( ground_truth["word_count"] == paragraph_data.word_count diff --git a/cognee/tests/test_pgvector.py b/cognee/tests/test_pgvector.py index 3b4fa19c5..9554a3f9d 100644 --- a/cognee/tests/test_pgvector.py +++ b/cognee/tests/test_pgvector.py @@ -4,6 +4,7 @@ import pathlib import cognee from cognee.api.v1.search import SearchType from cognee.modules.retrieval.brute_force_triplet_search import brute_force_triplet_search +from cognee.modules.users.methods import get_default_user logging.basicConfig(level=logging.DEBUG) @@ -44,12 +45,13 @@ async def main(): await cognee.prune.prune_data() await cognee.prune.prune_system(metadata = True) - dataset_name = "cs_explanations" + dataset_name_1 = "natural_language" + dataset_name_2 = "quantum" explanation_file_path = os.path.join( pathlib.Path(__file__).parent, "test_data/Natural_language_processing.txt" ) - await cognee.add([explanation_file_path], dataset_name) + await cognee.add([explanation_file_path], dataset_name_1) text = """A quantum computer is a computer that takes advantage of quantum mechanical phenomena. At small scales, physical matter exhibits properties of both particles and waves, and quantum computing leverages this behavior, specifically quantum superposition and entanglement, using specialized hardware that supports the preparation and manipulation of quantum states. @@ -59,12 +61,23 @@ async def main(): In principle, a non-quantum (classical) computer can solve the same computational problems as a quantum computer, given enough time. Quantum advantage comes in the form of time complexity rather than computability, and quantum complexity theory shows that some quantum algorithms for carefully selected tasks require exponentially fewer computational steps than the best known non-quantum algorithms. Such tasks can in theory be solved on a large-scale quantum computer whereas classical computers would not finish computations in any reasonable amount of time. However, quantum speedup is not universal or even typical across computational tasks, since basic tasks such as sorting are proven to not allow any asymptotic quantum speedup. Claims of quantum supremacy have drawn significant attention to the discipline, but are demonstrated on contrived tasks, while near-term practical use cases remain limited. """ - await cognee.add([text], dataset_name) + await cognee.add([text], dataset_name_2) - await cognee.cognify([dataset_name]) + await cognee.cognify([dataset_name_2, dataset_name_1]) from cognee.infrastructure.databases.vector import get_vector_engine + # Test getting of documents for search per dataset + from cognee.modules.users.permissions.methods import get_document_ids_for_user + user = await get_default_user() + document_ids = await get_document_ids_for_user(user.id, [dataset_name_1]) + assert len(document_ids) == 1, f"Number of expected documents doesn't match {len(document_ids)} != 1" + + # Test getting of documents for search when no dataset is provided + user = await get_default_user() + document_ids = await get_document_ids_for_user(user.id) + assert len(document_ids) == 2, f"Number of expected documents doesn't match {len(document_ids)} != 2" + vector_engine = get_vector_engine() random_node = (await vector_engine.search("entity_name", "Quantum computer"))[0] random_node_name = random_node.payload["text"] @@ -75,7 +88,7 @@ async def main(): for result in search_results: print(f"{result}\n") - search_results = await cognee.search(SearchType.CHUNKS, query_text = random_node_name) + search_results = await cognee.search(SearchType.CHUNKS, query_text = random_node_name, datasets=[dataset_name_2]) assert len(search_results) != 0, "The search results list is empty." print("\n\nExtracted chunks are:\n") for result in search_results: diff --git a/cognee/tests/unit/interfaces/graph/get_graph_from_huge_model_test.py b/cognee/tests/unit/interfaces/graph/get_graph_from_huge_model_test.py index 016f2be33..06c74c854 100644 --- a/cognee/tests/unit/interfaces/graph/get_graph_from_huge_model_test.py +++ b/cognee/tests/unit/interfaces/graph/get_graph_from_huge_model_test.py @@ -2,7 +2,7 @@ import asyncio import random import time from typing import List -from uuid import uuid5, NAMESPACE_OID +from uuid import NAMESPACE_OID, uuid5 from cognee.infrastructure.engine import DataPoint from cognee.modules.graph.utils import get_graph_from_model @@ -11,16 +11,28 @@ random.seed(1500) class Repository(DataPoint): path: str + _metadata = { + "index_fields": [], + "type": "Repository" + } class CodeFile(DataPoint): part_of: Repository contains: List["CodePart"] = [] depends_on: List["CodeFile"] = [] source_code: str + _metadata = { + "index_fields": [], + "type": "CodeFile" + } class CodePart(DataPoint): part_of: CodeFile source_code: str + _metadata = { + "index_fields": [], + "type": "CodePart" + } CodeFile.model_rebuild() CodePart.model_rebuild() diff --git a/cognee/tests/unit/interfaces/graph/get_graph_from_model_test.py b/cognee/tests/unit/interfaces/graph/get_graph_from_model_test.py index 000d45c15..499dc9f3f 100644 --- a/cognee/tests/unit/interfaces/graph/get_graph_from_model_test.py +++ b/cognee/tests/unit/interfaces/graph/get_graph_from_model_test.py @@ -1,25 +1,42 @@ import asyncio import random from typing import List -from uuid import uuid5, NAMESPACE_OID +from uuid import NAMESPACE_OID, uuid5 from cognee.infrastructure.engine import DataPoint from cognee.modules.graph.utils import get_graph_from_model + class Document(DataPoint): path: str + _metadata = { + "index_fields": [], + "type": "Document" + } class DocumentChunk(DataPoint): part_of: Document text: str contains: List["Entity"] = None + _metadata = { + "index_fields": ["text"], + "type": "DocumentChunk" + } class EntityType(DataPoint): name: str + _metadata = { + "index_fields": ["name"], + "type": "EntityType" + } class Entity(DataPoint): name: str is_type: EntityType + _metadata = { + "index_fields": ["name"], + "type": "Entity" + } DocumentChunk.model_rebuild() diff --git a/evals/eval_swe_bench.py b/evals/eval_swe_bench.py index 67826fc12..6c2280d80 100644 --- a/evals/eval_swe_bench.py +++ b/evals/eval_swe_bench.py @@ -7,19 +7,13 @@ from pathlib import Path from swebench.harness.utils import load_swebench_dataset from swebench.inference.make_datasets.create_instance import PATCH_EXAMPLE +from cognee.api.v1.cognify.code_graph_pipeline import run_code_graph_pipeline from cognee.api.v1.search import SearchType from cognee.infrastructure.llm.get_llm_client import get_llm_client from cognee.infrastructure.llm.prompts import read_query_prompt -from cognee.modules.pipelines import Task, run_tasks from cognee.modules.retrieval.brute_force_triplet_search import \ brute_force_triplet_search -# from cognee.shared.data_models import SummarizedContent from cognee.shared.utils import render_graph -from cognee.tasks.repo_processor import (enrich_dependency_graph, - expand_dependency_graph, - get_repo_file_dependencies) -from cognee.tasks.storage import add_data_points -# from cognee.tasks.summarization import summarize_code from evals.eval_utils import download_github_repo, retrieved_edges_to_string @@ -42,48 +36,22 @@ def check_install_package(package_name): async def generate_patch_with_cognee(instance, llm_client, search_type=SearchType.CHUNKS): - import os - import pathlib - import cognee - from cognee.infrastructure.databases.relational import create_db_and_tables - - file_path = Path(__file__).parent - data_directory_path = str(pathlib.Path(os.path.join(file_path, ".data_storage/code_graph")).resolve()) - cognee.config.data_root_directory(data_directory_path) - cognee_directory_path = str(pathlib.Path(os.path.join(file_path, ".cognee_system/code_graph")).resolve()) - cognee.config.system_root_directory(cognee_directory_path) - - await cognee.prune.prune_data() - await cognee.prune.prune_system(metadata = True) - - await create_db_and_tables() - - # repo_path = download_github_repo(instance, '../RAW_GIT_REPOS') - - repo_path = '/Users/borisarzentar/Projects/graphrag' - - tasks = [ - Task(get_repo_file_dependencies), - Task(enrich_dependency_graph, task_config = { "batch_size": 50 }), - Task(expand_dependency_graph, task_config = { "batch_size": 50 }), - Task(add_data_points, task_config = { "batch_size": 50 }), - # Task(summarize_code, summarization_model = SummarizedContent), - ] - - pipeline = run_tasks(tasks, repo_path, "cognify_code_pipeline") + repo_path = download_github_repo(instance, '../RAW_GIT_REPOS') + pipeline = await run_code_graph_pipeline(repo_path) async for result in pipeline: print(result) print('Here we have the repo under the repo_path') - await render_graph(None, include_labels = True, include_nodes = True) + await render_graph(None, include_labels=True, include_nodes=True) problem_statement = instance['problem_statement'] instructions = read_query_prompt("patch_gen_kg_instructions.txt") - retrieved_edges = await brute_force_triplet_search(problem_statement, top_k = 3, collections = ["data_point_source_code", "data_point_text"]) - + retrieved_edges = await brute_force_triplet_search(problem_statement, top_k=3, + collections=["data_point_source_code", "data_point_text"]) + retrieved_edges_str = retrieved_edges_to_string(retrieved_edges) prompt = "\n".join([ @@ -171,7 +139,6 @@ async def main(): with open(predictions_path, "w") as file: json.dump(preds, file) - subprocess.run( [ "python", diff --git a/examples/python/code_graph_example.py b/examples/python/code_graph_example.py new file mode 100644 index 000000000..9189de46c --- /dev/null +++ b/examples/python/code_graph_example.py @@ -0,0 +1,15 @@ +import argparse +import asyncio +from cognee.api.v1.cognify.code_graph_pipeline import run_code_graph_pipeline + + +async def main(repo_path): + async for result in await run_code_graph_pipeline(repo_path): + print(result) + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--repo-path", type=str, required=True, help="Path to the repository") + args = parser.parse_args() + asyncio.run(main(args.repo_path)) +