Merge branch 'dev' into COG-578
This commit is contained in:
commit
be424249d7
52 changed files with 614 additions and 190 deletions
1
.github/workflows/test_deduplication.yml
vendored
1
.github/workflows/test_deduplication.yml
vendored
|
|
@ -17,6 +17,7 @@ jobs:
|
|||
run_deduplication_test:
|
||||
name: test
|
||||
runs-on: ubuntu-latest
|
||||
if: ${{ github.event.label.name == 'run-checks' }}
|
||||
defaults:
|
||||
run:
|
||||
shell: bash
|
||||
|
|
|
|||
1
.github/workflows/test_milvus.yml
vendored
1
.github/workflows/test_milvus.yml
vendored
|
|
@ -18,6 +18,7 @@ jobs:
|
|||
run_milvus:
|
||||
name: test
|
||||
runs-on: ubuntu-latest
|
||||
if: ${{ github.event.label.name == 'run-checks' }}
|
||||
strategy:
|
||||
fail-fast: false
|
||||
defaults:
|
||||
|
|
|
|||
1
.github/workflows/test_neo4j.yml
vendored
1
.github/workflows/test_neo4j.yml
vendored
|
|
@ -15,6 +15,7 @@ env:
|
|||
jobs:
|
||||
run_neo4j_integration_test:
|
||||
name: test
|
||||
if: ${{ github.event.label.name == 'run-checks' }}
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
defaults:
|
||||
|
|
|
|||
1
.github/workflows/test_pgvector.yml
vendored
1
.github/workflows/test_pgvector.yml
vendored
|
|
@ -18,6 +18,7 @@ jobs:
|
|||
run_pgvector_integration_test:
|
||||
name: test
|
||||
runs-on: ubuntu-latest
|
||||
if: ${{ github.event.label.name == 'run-checks' }}
|
||||
defaults:
|
||||
run:
|
||||
shell: bash
|
||||
|
|
|
|||
1
.github/workflows/test_python_3_10.yml
vendored
1
.github/workflows/test_python_3_10.yml
vendored
|
|
@ -18,6 +18,7 @@ jobs:
|
|||
run_common:
|
||||
name: test
|
||||
runs-on: ubuntu-latest
|
||||
if: ${{ github.event.label.name == 'run-checks' }}
|
||||
strategy:
|
||||
fail-fast: false
|
||||
defaults:
|
||||
|
|
|
|||
1
.github/workflows/test_python_3_11.yml
vendored
1
.github/workflows/test_python_3_11.yml
vendored
|
|
@ -18,6 +18,7 @@ jobs:
|
|||
run_common:
|
||||
name: test
|
||||
runs-on: ubuntu-latest
|
||||
if: ${{ github.event.label.name == 'run-checks' }}
|
||||
strategy:
|
||||
fail-fast: false
|
||||
defaults:
|
||||
|
|
|
|||
1
.github/workflows/test_python_3_9.yml
vendored
1
.github/workflows/test_python_3_9.yml
vendored
|
|
@ -18,6 +18,7 @@ jobs:
|
|||
run_common:
|
||||
name: test
|
||||
runs-on: ubuntu-latest
|
||||
if: ${{ github.event.label.name == 'run-checks' }}
|
||||
strategy:
|
||||
fail-fast: false
|
||||
defaults:
|
||||
|
|
|
|||
1
.github/workflows/test_qdrant.yml
vendored
1
.github/workflows/test_qdrant.yml
vendored
|
|
@ -18,6 +18,7 @@ jobs:
|
|||
run_qdrant_integration_test:
|
||||
name: test
|
||||
runs-on: ubuntu-latest
|
||||
if: ${{ github.event.label.name == 'run-checks' }}
|
||||
|
||||
defaults:
|
||||
run:
|
||||
|
|
|
|||
1
.github/workflows/test_weaviate.yml
vendored
1
.github/workflows/test_weaviate.yml
vendored
|
|
@ -18,6 +18,7 @@ jobs:
|
|||
run_weaviate_integration_test:
|
||||
name: test
|
||||
runs-on: ubuntu-latest
|
||||
if: ${{ github.event.label.name == 'run-checks' }}
|
||||
|
||||
defaults:
|
||||
run:
|
||||
|
|
|
|||
|
|
@ -1,22 +1,35 @@
|
|||
# NOTICE: This module contains deprecated functions.
|
||||
# Use only the run_code_graph_pipeline function; all other functions are deprecated.
|
||||
# Related issue: COG-906
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
|
||||
from cognee.modules.data.methods import get_datasets, get_datasets_by_name
|
||||
from cognee.modules.data.methods.get_dataset_data import get_dataset_data
|
||||
from cognee.modules.data.models import Data, Dataset
|
||||
from cognee.modules.pipelines import run_tasks
|
||||
from cognee.modules.pipelines.models import PipelineRunStatus
|
||||
from cognee.modules.pipelines.operations.get_pipeline_status import \
|
||||
get_pipeline_status
|
||||
from cognee.modules.pipelines.operations.log_pipeline_status import \
|
||||
log_pipeline_status
|
||||
from cognee.modules.pipelines.tasks.Task import Task
|
||||
from cognee.modules.users.methods import get_default_user
|
||||
from cognee.modules.users.models import User
|
||||
from cognee.shared.SourceCodeGraph import SourceCodeGraph
|
||||
from cognee.shared.utils import send_telemetry
|
||||
from cognee.modules.data.models import Dataset, Data
|
||||
from cognee.modules.data.methods.get_dataset_data import get_dataset_data
|
||||
from cognee.modules.data.methods import get_datasets, get_datasets_by_name
|
||||
from cognee.modules.pipelines.tasks.Task import Task
|
||||
from cognee.modules.pipelines import run_tasks
|
||||
from cognee.modules.users.models import User
|
||||
from cognee.modules.users.methods import get_default_user
|
||||
from cognee.modules.pipelines.models import PipelineRunStatus
|
||||
from cognee.modules.pipelines.operations.get_pipeline_status import get_pipeline_status
|
||||
from cognee.modules.pipelines.operations.log_pipeline_status import log_pipeline_status
|
||||
from cognee.tasks.documents import classify_documents, check_permissions_on_documents, extract_chunks_from_documents
|
||||
from cognee.tasks.documents import (check_permissions_on_documents,
|
||||
classify_documents,
|
||||
extract_chunks_from_documents)
|
||||
from cognee.tasks.graph import extract_graph_from_code
|
||||
from cognee.tasks.repo_processor import (enrich_dependency_graph,
|
||||
expand_dependency_graph,
|
||||
get_repo_file_dependencies)
|
||||
from cognee.tasks.storage import add_data_points
|
||||
from cognee.tasks.summarization import summarize_code
|
||||
|
||||
logger = logging.getLogger("code_graph_pipeline")
|
||||
|
||||
|
|
@ -51,6 +64,7 @@ async def code_graph_pipeline(datasets: Union[str, list[str]] = None, user: User
|
|||
|
||||
|
||||
async def run_pipeline(dataset: Dataset, user: User):
|
||||
'''DEPRECATED: Use `run_code_graph_pipeline` instead. This function will be removed.'''
|
||||
data_documents: list[Data] = await get_dataset_data(dataset_id = dataset.id)
|
||||
|
||||
document_ids_str = [str(document.id) for document in data_documents]
|
||||
|
|
@ -103,3 +117,30 @@ async def run_pipeline(dataset: Dataset, user: User):
|
|||
|
||||
def generate_dataset_name(dataset_name: str) -> str:
|
||||
return dataset_name.replace(".", "_").replace(" ", "_")
|
||||
|
||||
|
||||
async def run_code_graph_pipeline(repo_path):
|
||||
import os
|
||||
import pathlib
|
||||
import cognee
|
||||
from cognee.infrastructure.databases.relational import create_db_and_tables
|
||||
|
||||
file_path = Path(__file__).parent
|
||||
data_directory_path = str(pathlib.Path(os.path.join(file_path, ".data_storage/code_graph")).resolve())
|
||||
cognee.config.data_root_directory(data_directory_path)
|
||||
cognee_directory_path = str(pathlib.Path(os.path.join(file_path, ".cognee_system/code_graph")).resolve())
|
||||
cognee.config.system_root_directory(cognee_directory_path)
|
||||
|
||||
await cognee.prune.prune_data()
|
||||
await cognee.prune.prune_system(metadata=True)
|
||||
await create_db_and_tables()
|
||||
|
||||
tasks = [
|
||||
Task(get_repo_file_dependencies),
|
||||
Task(enrich_dependency_graph, task_config={"batch_size": 50}),
|
||||
Task(expand_dependency_graph, task_config={"batch_size": 50}),
|
||||
Task(summarize_code, task_config={"batch_size": 50}),
|
||||
Task(add_data_points, task_config={"batch_size": 50}),
|
||||
]
|
||||
|
||||
return run_tasks(tasks, repo_path, "cognify_code_pipeline")
|
||||
|
|
|
|||
|
|
@ -69,17 +69,18 @@ async def run_cognify_pipeline(dataset: Dataset, user: User, graph_model: BaseMo
|
|||
|
||||
send_telemetry("cognee.cognify EXECUTION STARTED", user.id)
|
||||
|
||||
async with update_status_lock:
|
||||
task_status = await get_pipeline_status([dataset_id])
|
||||
#async with update_status_lock: TODO: Add UI lock to prevent multiple backend requests
|
||||
task_status = await get_pipeline_status([dataset_id])
|
||||
|
||||
if dataset_id in task_status and task_status[dataset_id] == PipelineRunStatus.DATASET_PROCESSING_STARTED:
|
||||
logger.info("Dataset %s is already being processed.", dataset_name)
|
||||
return
|
||||
if dataset_id in task_status and task_status[dataset_id] == PipelineRunStatus.DATASET_PROCESSING_STARTED:
|
||||
logger.info("Dataset %s is already being processed.", dataset_name)
|
||||
return
|
||||
|
||||
await log_pipeline_status(dataset_id, PipelineRunStatus.DATASET_PROCESSING_STARTED, {
|
||||
"dataset_name": dataset_name,
|
||||
"files": document_ids_str,
|
||||
})
|
||||
|
||||
await log_pipeline_status(dataset_id, PipelineRunStatus.DATASET_PROCESSING_STARTED, {
|
||||
"dataset_name": dataset_name,
|
||||
"files": document_ids_str,
|
||||
})
|
||||
try:
|
||||
cognee_config = get_cognify_config()
|
||||
|
||||
|
|
|
|||
|
|
@ -1,13 +1,15 @@
|
|||
from fastapi import APIRouter
|
||||
from typing import List
|
||||
from typing import List, Optional
|
||||
from pydantic import BaseModel
|
||||
from cognee.modules.users.models import User
|
||||
from fastapi.responses import JSONResponse
|
||||
from cognee.modules.users.methods import get_authenticated_user
|
||||
from fastapi import Depends
|
||||
from cognee.shared.data_models import KnowledgeGraph
|
||||
|
||||
class CognifyPayloadDTO(BaseModel):
|
||||
datasets: List[str]
|
||||
graph_model: Optional[BaseModel] = KnowledgeGraph
|
||||
|
||||
def get_cognify_router() -> APIRouter:
|
||||
router = APIRouter()
|
||||
|
|
@ -17,11 +19,11 @@ def get_cognify_router() -> APIRouter:
|
|||
""" This endpoint is responsible for the cognitive processing of the content."""
|
||||
from cognee.api.v1.cognify.cognify_v2 import cognify as cognee_cognify
|
||||
try:
|
||||
await cognee_cognify(payload.datasets, user)
|
||||
await cognee_cognify(payload.datasets, user, payload.graph_model)
|
||||
except Exception as error:
|
||||
return JSONResponse(
|
||||
status_code=409,
|
||||
content={"error": str(error)}
|
||||
)
|
||||
|
||||
return router
|
||||
return router
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
import json
|
||||
from uuid import UUID
|
||||
from enum import Enum
|
||||
from typing import Callable, Dict
|
||||
from typing import Callable, Dict, Union
|
||||
|
||||
from cognee.exceptions import InvalidValueError
|
||||
from cognee.modules.search.operations import log_query, log_result
|
||||
|
|
@ -22,7 +22,12 @@ class SearchType(Enum):
|
|||
CHUNKS = "CHUNKS"
|
||||
COMPLETION = "COMPLETION"
|
||||
|
||||
async def search(query_type: SearchType, query_text: str, user: User = None) -> list:
|
||||
async def search(query_type: SearchType, query_text: str, user: User = None,
|
||||
datasets: Union[list[str], str, None] = None) -> list:
|
||||
# We use lists from now on for datasets
|
||||
if isinstance(datasets, str):
|
||||
datasets = [datasets]
|
||||
|
||||
if user is None:
|
||||
user = await get_default_user()
|
||||
|
||||
|
|
@ -31,7 +36,7 @@ async def search(query_type: SearchType, query_text: str, user: User = None) ->
|
|||
|
||||
query = await log_query(query_text, str(query_type), user.id)
|
||||
|
||||
own_document_ids = await get_document_ids_for_user(user.id)
|
||||
own_document_ids = await get_document_ids_for_user(user.id, datasets)
|
||||
search_results = await specific_search(query_type, query_text, user)
|
||||
|
||||
filtered_search_results = []
|
||||
|
|
|
|||
|
|
@ -1,21 +1,26 @@
|
|||
import asyncio
|
||||
# from datetime import datetime
|
||||
import json
|
||||
from uuid import UUID
|
||||
from textwrap import dedent
|
||||
from uuid import UUID
|
||||
|
||||
from falkordb import FalkorDB
|
||||
|
||||
from cognee.exceptions import InvalidValueError
|
||||
from cognee.infrastructure.engine import DataPoint
|
||||
from cognee.infrastructure.databases.graph.graph_db_interface import GraphDBInterface
|
||||
from cognee.infrastructure.databases.graph.graph_db_interface import \
|
||||
GraphDBInterface
|
||||
from cognee.infrastructure.databases.vector.embeddings import EmbeddingEngine
|
||||
from cognee.infrastructure.databases.vector.vector_db_interface import VectorDBInterface
|
||||
from cognee.infrastructure.databases.vector.vector_db_interface import \
|
||||
VectorDBInterface
|
||||
from cognee.infrastructure.engine import DataPoint
|
||||
|
||||
|
||||
class IndexSchema(DataPoint):
|
||||
text: str
|
||||
|
||||
_metadata: dict = {
|
||||
"index_fields": ["text"]
|
||||
"index_fields": ["text"],
|
||||
"type": "IndexSchema"
|
||||
}
|
||||
|
||||
class FalkorDBAdapter(VectorDBInterface, GraphDBInterface):
|
||||
|
|
|
|||
|
|
@ -1,25 +1,29 @@
|
|||
from typing import List, Optional, get_type_hints, Generic, TypeVar
|
||||
import asyncio
|
||||
from typing import Generic, List, Optional, TypeVar, get_type_hints
|
||||
from uuid import UUID
|
||||
|
||||
import lancedb
|
||||
from lancedb.pydantic import LanceModel, Vector
|
||||
from pydantic import BaseModel
|
||||
from lancedb.pydantic import Vector, LanceModel
|
||||
|
||||
from cognee.exceptions import InvalidValueError
|
||||
from cognee.infrastructure.engine import DataPoint
|
||||
from cognee.infrastructure.files.storage import LocalStorage
|
||||
from cognee.modules.storage.utils import copy_model, get_own_properties
|
||||
from ..models.ScoredResult import ScoredResult
|
||||
from ..vector_db_interface import VectorDBInterface
|
||||
from ..utils import normalize_distances
|
||||
|
||||
from ..embeddings.EmbeddingEngine import EmbeddingEngine
|
||||
from ..models.ScoredResult import ScoredResult
|
||||
from ..utils import normalize_distances
|
||||
from ..vector_db_interface import VectorDBInterface
|
||||
|
||||
|
||||
class IndexSchema(DataPoint):
|
||||
id: str
|
||||
text: str
|
||||
|
||||
_metadata: dict = {
|
||||
"index_fields": ["text"]
|
||||
"index_fields": ["text"],
|
||||
"type": "IndexSchema"
|
||||
}
|
||||
|
||||
class LanceDBAdapter(VectorDBInterface):
|
||||
|
|
|
|||
|
|
@ -4,10 +4,12 @@ import asyncio
|
|||
import logging
|
||||
from typing import List, Optional
|
||||
from uuid import UUID
|
||||
|
||||
from cognee.infrastructure.engine import DataPoint
|
||||
from ..vector_db_interface import VectorDBInterface
|
||||
from ..models.ScoredResult import ScoredResult
|
||||
|
||||
from ..embeddings.EmbeddingEngine import EmbeddingEngine
|
||||
from ..models.ScoredResult import ScoredResult
|
||||
from ..vector_db_interface import VectorDBInterface
|
||||
|
||||
logger = logging.getLogger("MilvusAdapter")
|
||||
|
||||
|
|
@ -16,7 +18,8 @@ class IndexSchema(DataPoint):
|
|||
text: str
|
||||
|
||||
_metadata: dict = {
|
||||
"index_fields": ["text"]
|
||||
"index_fields": ["text"],
|
||||
"type": "IndexSchema"
|
||||
}
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1,27 +1,30 @@
|
|||
import asyncio
|
||||
from uuid import UUID
|
||||
from typing import List, Optional, get_type_hints
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
from sqlalchemy import JSON, Column, Table, select, delete
|
||||
from sqlalchemy import JSON, Column, Table, select, delete, MetaData
|
||||
from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker
|
||||
|
||||
from cognee.exceptions import InvalidValueError
|
||||
from cognee.infrastructure.databases.exceptions import EntityNotFoundError
|
||||
from cognee.infrastructure.engine import DataPoint
|
||||
|
||||
from .serialize_data import serialize_data
|
||||
from ..models.ScoredResult import ScoredResult
|
||||
from ..vector_db_interface import VectorDBInterface
|
||||
from ..utils import normalize_distances
|
||||
from ..embeddings.EmbeddingEngine import EmbeddingEngine
|
||||
from ...relational.sqlalchemy.SqlAlchemyAdapter import SQLAlchemyAdapter
|
||||
from ...relational.ModelBase import Base
|
||||
from ...relational.sqlalchemy.SqlAlchemyAdapter import SQLAlchemyAdapter
|
||||
from ..embeddings.EmbeddingEngine import EmbeddingEngine
|
||||
from ..models.ScoredResult import ScoredResult
|
||||
from ..utils import normalize_distances
|
||||
from ..vector_db_interface import VectorDBInterface
|
||||
from .serialize_data import serialize_data
|
||||
|
||||
|
||||
class IndexSchema(DataPoint):
|
||||
text: str
|
||||
|
||||
_metadata: dict = {
|
||||
"index_fields": ["text"]
|
||||
"index_fields": ["text"],
|
||||
"type": "IndexSchema"
|
||||
}
|
||||
|
||||
class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
|
||||
|
|
@ -48,10 +51,12 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
|
|||
|
||||
async def has_collection(self, collection_name: str) -> bool:
|
||||
async with self.engine.begin() as connection:
|
||||
# Load the schema information into the MetaData object
|
||||
await connection.run_sync(Base.metadata.reflect)
|
||||
# Create a MetaData instance to load table information
|
||||
metadata = MetaData()
|
||||
# Load table information from schema into MetaData
|
||||
await connection.run_sync(metadata.reflect)
|
||||
|
||||
if collection_name in Base.metadata.tables:
|
||||
if collection_name in metadata.tables:
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
|
@ -87,6 +92,7 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
|
|||
async def create_data_points(
|
||||
self, collection_name: str, data_points: List[DataPoint]
|
||||
):
|
||||
data_point_types = get_type_hints(DataPoint)
|
||||
if not await self.has_collection(collection_name):
|
||||
await self.create_collection(
|
||||
collection_name = collection_name,
|
||||
|
|
@ -106,7 +112,7 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
|
|||
primary_key: Mapped[int] = mapped_column(
|
||||
primary_key=True, autoincrement=True
|
||||
)
|
||||
id: Mapped[type(data_points[0].id)]
|
||||
id: Mapped[data_point_types["id"]]
|
||||
payload = Column(JSON)
|
||||
vector = Column(self.Vector(vector_size))
|
||||
|
||||
|
|
@ -145,10 +151,12 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
|
|||
with an async engine.
|
||||
"""
|
||||
async with self.engine.begin() as connection:
|
||||
# Load the schema information into the MetaData object
|
||||
await connection.run_sync(Base.metadata.reflect)
|
||||
if collection_name in Base.metadata.tables:
|
||||
return Base.metadata.tables[collection_name]
|
||||
# Create a MetaData instance to load table information
|
||||
metadata = MetaData()
|
||||
# Load table information from schema into MetaData
|
||||
await connection.run_sync(metadata.reflect)
|
||||
if collection_name in metadata.tables:
|
||||
return metadata.tables[collection_name]
|
||||
else:
|
||||
raise EntityNotFoundError(message=f"Table '{collection_name}' not found.")
|
||||
|
||||
|
|
|
|||
|
|
@ -1,13 +1,16 @@
|
|||
import logging
|
||||
from typing import Dict, List, Optional
|
||||
from uuid import UUID
|
||||
from typing import List, Dict, Optional
|
||||
|
||||
from qdrant_client import AsyncQdrantClient, models
|
||||
|
||||
from cognee.exceptions import InvalidValueError
|
||||
from cognee.infrastructure.databases.vector.models.ScoredResult import ScoredResult
|
||||
from cognee.infrastructure.databases.vector.models.ScoredResult import \
|
||||
ScoredResult
|
||||
from cognee.infrastructure.engine import DataPoint
|
||||
from ..vector_db_interface import VectorDBInterface
|
||||
|
||||
from ..embeddings.EmbeddingEngine import EmbeddingEngine
|
||||
from ..vector_db_interface import VectorDBInterface
|
||||
|
||||
logger = logging.getLogger("QDrantAdapter")
|
||||
|
||||
|
|
@ -15,7 +18,8 @@ class IndexSchema(DataPoint):
|
|||
text: str
|
||||
|
||||
_metadata: dict = {
|
||||
"index_fields": ["text"]
|
||||
"index_fields": ["text"],
|
||||
"type": "IndexSchema"
|
||||
}
|
||||
|
||||
# class CollectionConfig(BaseModel, extra = "forbid"):
|
||||
|
|
|
|||
|
|
@ -5,9 +5,10 @@ from uuid import UUID
|
|||
|
||||
from cognee.exceptions import InvalidValueError
|
||||
from cognee.infrastructure.engine import DataPoint
|
||||
from ..vector_db_interface import VectorDBInterface
|
||||
from ..models.ScoredResult import ScoredResult
|
||||
|
||||
from ..embeddings.EmbeddingEngine import EmbeddingEngine
|
||||
from ..models.ScoredResult import ScoredResult
|
||||
from ..vector_db_interface import VectorDBInterface
|
||||
|
||||
logger = logging.getLogger("WeaviateAdapter")
|
||||
|
||||
|
|
@ -15,7 +16,8 @@ class IndexSchema(DataPoint):
|
|||
text: str
|
||||
|
||||
_metadata: dict = {
|
||||
"index_fields": ["text"]
|
||||
"index_fields": ["text"],
|
||||
"type": "IndexSchema"
|
||||
}
|
||||
|
||||
class WeaviateAdapter(VectorDBInterface):
|
||||
|
|
|
|||
|
|
@ -1,8 +1,10 @@
|
|||
from typing_extensions import TypedDict
|
||||
from uuid import UUID, uuid4
|
||||
from typing import Optional
|
||||
from datetime import datetime, timezone
|
||||
from typing import Optional
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
|
||||
class MetaData(TypedDict):
|
||||
index_fields: list[str]
|
||||
|
|
@ -13,7 +15,8 @@ class DataPoint(BaseModel):
|
|||
updated_at: Optional[datetime] = datetime.now(timezone.utc)
|
||||
topological_rank: Optional[int] = 0
|
||||
_metadata: Optional[MetaData] = {
|
||||
"index_fields": []
|
||||
"index_fields": [],
|
||||
"type": "DataPoint"
|
||||
}
|
||||
|
||||
# class Config:
|
||||
|
|
@ -39,4 +42,4 @@ class DataPoint(BaseModel):
|
|||
|
||||
@classmethod
|
||||
def get_embeddable_property_names(self, data_point):
|
||||
return data_point._metadata["index_fields"] or []
|
||||
return data_point._metadata["index_fields"] or []
|
||||
10
cognee/infrastructure/llm/prompts/summarize_code.txt
Normal file
10
cognee/infrastructure/llm/prompts/summarize_code.txt
Normal file
|
|
@ -0,0 +1,10 @@
|
|||
You are an expert Python programmer and technical writer. Your task is to summarize the given Python code snippet or file.
|
||||
The code may contain multiple imports, classes, functions, constants and logic. Provide a clear, structured explanation of its components
|
||||
and their relationships.
|
||||
|
||||
Instructions:
|
||||
Provide an overview: Start with a high-level summary of what the code does as a whole.
|
||||
Break it down: Summarize each class and function individually, explaining their purpose and how they interact.
|
||||
Describe the workflow: Outline how the classes and functions work together. Mention any control flow (e.g., main functions, entry points, loops).
|
||||
Key features: Highlight important elements like arguments, return values, or unique logic.
|
||||
Maintain clarity: Write in plain English for someone familiar with Python but unfamiliar with this code.
|
||||
|
|
@ -1,8 +1,10 @@
|
|||
from typing import List, Optional
|
||||
|
||||
from cognee.infrastructure.engine import DataPoint
|
||||
from cognee.modules.data.processing.document_types import Document
|
||||
from cognee.modules.engine.models import Entity
|
||||
|
||||
|
||||
class DocumentChunk(DataPoint):
|
||||
__tablename__ = "document_chunk"
|
||||
text: str
|
||||
|
|
@ -12,6 +14,7 @@ class DocumentChunk(DataPoint):
|
|||
is_part_of: Document
|
||||
contains: List[Entity] = None
|
||||
|
||||
_metadata: Optional[dict] = {
|
||||
_metadata: dict = {
|
||||
"index_fields": ["text"],
|
||||
"type": "DocumentChunk"
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,7 +1,11 @@
|
|||
from typing import Type
|
||||
|
||||
from pydantic import BaseModel
|
||||
from cognee.infrastructure.llm.prompts import read_query_prompt
|
||||
|
||||
from cognee.infrastructure.llm.get_llm_client import get_llm_client
|
||||
from cognee.infrastructure.llm.prompts import read_query_prompt
|
||||
from cognee.shared.data_models import SummarizedCode
|
||||
|
||||
|
||||
async def extract_summary(content: str, response_model: Type[BaseModel]):
|
||||
llm_client = get_llm_client()
|
||||
|
|
@ -11,3 +15,7 @@ async def extract_summary(content: str, response_model: Type[BaseModel]):
|
|||
llm_output = await llm_client.acreate_structured_output(content, system_prompt, response_model)
|
||||
|
||||
return llm_output
|
||||
|
||||
async def extract_code_summary(content: str):
|
||||
|
||||
return await extract_summary(content, response_model=SummarizedCode)
|
||||
|
|
|
|||
|
|
@ -1,2 +1,3 @@
|
|||
from .Data import Data
|
||||
from .Dataset import Dataset
|
||||
from .DatasetData import DatasetData
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
from cognee.infrastructure.llm.get_llm_client import get_llm_client
|
||||
from cognee.modules.chunking.TextChunker import TextChunker
|
||||
from .Document import Document
|
||||
from .ChunkerMapping import ChunkerConfig
|
||||
|
||||
class AudioDocument(Document):
|
||||
type: str = "audio"
|
||||
|
|
@ -9,11 +9,12 @@ class AudioDocument(Document):
|
|||
result = get_llm_client().create_transcript(self.raw_data_location)
|
||||
return(result.text)
|
||||
|
||||
def read(self, chunk_size: int):
|
||||
def read(self, chunk_size: int, chunker: str):
|
||||
# Transcribe the audio file
|
||||
|
||||
text = self.create_transcript()
|
||||
|
||||
chunker = TextChunker(self, chunk_size = chunk_size, get_text = lambda: [text])
|
||||
chunker_func = ChunkerConfig.get_chunker(chunker)
|
||||
chunker = chunker_func(self, chunk_size = chunk_size, get_text = lambda: [text])
|
||||
|
||||
yield from chunker.read()
|
||||
|
|
|
|||
|
|
@ -0,0 +1,15 @@
|
|||
from cognee.modules.chunking.TextChunker import TextChunker
|
||||
|
||||
class ChunkerConfig:
|
||||
chunker_mapping = {
|
||||
"text_chunker": TextChunker
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def get_chunker(cls, chunker_name: str):
|
||||
chunker_class = cls.chunker_mapping.get(chunker_name)
|
||||
if chunker_class is None:
|
||||
raise NotImplementedError(
|
||||
f"Chunker '{chunker_name}' is not implemented. Available options: {list(cls.chunker_mapping.keys())}"
|
||||
)
|
||||
return chunker_class
|
||||
|
|
@ -1,12 +1,17 @@
|
|||
from cognee.infrastructure.engine import DataPoint
|
||||
from uuid import UUID
|
||||
|
||||
from cognee.infrastructure.engine import DataPoint
|
||||
|
||||
|
||||
class Document(DataPoint):
|
||||
type: str
|
||||
name: str
|
||||
raw_data_location: str
|
||||
metadata_id: UUID
|
||||
mime_type: str
|
||||
_metadata: dict = {
|
||||
"index_fields": ["name"],
|
||||
"type": "Document"
|
||||
}
|
||||
|
||||
def read(self, chunk_size: int) -> str:
|
||||
def read(self, chunk_size: int, chunker = str) -> str:
|
||||
pass
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
from cognee.infrastructure.llm.get_llm_client import get_llm_client
|
||||
from cognee.modules.chunking.TextChunker import TextChunker
|
||||
from .Document import Document
|
||||
from .ChunkerMapping import ChunkerConfig
|
||||
|
||||
class ImageDocument(Document):
|
||||
type: str = "image"
|
||||
|
|
@ -10,10 +10,11 @@ class ImageDocument(Document):
|
|||
result = get_llm_client().transcribe_image(self.raw_data_location)
|
||||
return(result.choices[0].message.content)
|
||||
|
||||
def read(self, chunk_size: int):
|
||||
def read(self, chunk_size: int, chunker: str):
|
||||
# Transcribe the image file
|
||||
text = self.transcribe_image()
|
||||
|
||||
chunker = TextChunker(self, chunk_size = chunk_size, get_text = lambda: [text])
|
||||
chunker_func = ChunkerConfig.get_chunker(chunker)
|
||||
chunker = chunker_func(self, chunk_size = chunk_size, get_text = lambda: [text])
|
||||
|
||||
yield from chunker.read()
|
||||
|
|
|
|||
|
|
@ -1,11 +1,11 @@
|
|||
from pypdf import PdfReader
|
||||
from cognee.modules.chunking.TextChunker import TextChunker
|
||||
from .Document import Document
|
||||
from .ChunkerMapping import ChunkerConfig
|
||||
|
||||
class PdfDocument(Document):
|
||||
type: str = "pdf"
|
||||
|
||||
def read(self, chunk_size: int):
|
||||
def read(self, chunk_size: int, chunker: str):
|
||||
file = PdfReader(self.raw_data_location)
|
||||
|
||||
def get_text():
|
||||
|
|
@ -13,7 +13,8 @@ class PdfDocument(Document):
|
|||
page_text = page.extract_text()
|
||||
yield page_text
|
||||
|
||||
chunker = TextChunker(self, chunk_size = chunk_size, get_text = get_text)
|
||||
chunker_func = ChunkerConfig.get_chunker(chunker)
|
||||
chunker = chunker_func(self, chunk_size = chunk_size, get_text = get_text)
|
||||
|
||||
yield from chunker.read()
|
||||
|
||||
|
|
|
|||
|
|
@ -1,10 +1,10 @@
|
|||
from cognee.modules.chunking.TextChunker import TextChunker
|
||||
from .Document import Document
|
||||
from .ChunkerMapping import ChunkerConfig
|
||||
|
||||
class TextDocument(Document):
|
||||
type: str = "text"
|
||||
|
||||
def read(self, chunk_size: int):
|
||||
def read(self, chunk_size: int, chunker: str):
|
||||
def get_text():
|
||||
with open(self.raw_data_location, mode = "r", encoding = "utf-8") as file:
|
||||
while True:
|
||||
|
|
@ -15,6 +15,8 @@ class TextDocument(Document):
|
|||
|
||||
yield text
|
||||
|
||||
chunker = TextChunker(self, chunk_size = chunk_size, get_text = get_text)
|
||||
chunker_func = ChunkerConfig.get_chunker(chunker)
|
||||
|
||||
chunker = chunker_func(self, chunk_size = chunk_size, get_text = get_text)
|
||||
|
||||
yield from chunker.read()
|
||||
|
|
|
|||
|
|
@ -10,4 +10,5 @@ class Entity(DataPoint):
|
|||
|
||||
_metadata: dict = {
|
||||
"index_fields": ["name"],
|
||||
"type": "Entity"
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,11 +1,12 @@
|
|||
from cognee.infrastructure.engine import DataPoint
|
||||
|
||||
|
||||
class EntityType(DataPoint):
|
||||
__tablename__ = "entity_type"
|
||||
name: str
|
||||
type: str
|
||||
description: str
|
||||
|
||||
_metadata: dict = {
|
||||
"index_fields": ["name"],
|
||||
"type": "EntityType"
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,11 +1,14 @@
|
|||
from typing import Optional
|
||||
|
||||
from cognee.infrastructure.engine import DataPoint
|
||||
|
||||
|
||||
class EdgeType(DataPoint):
|
||||
__tablename__ = "edge_type"
|
||||
relationship_name: str
|
||||
number_of_edges: int
|
||||
|
||||
_metadata: Optional[dict] = {
|
||||
_metadata: dict = {
|
||||
"index_fields": ["relationship_name"],
|
||||
"type": "EdgeType"
|
||||
}
|
||||
|
|
@ -2,7 +2,7 @@ from cognee.infrastructure.engine import DataPoint
|
|||
|
||||
|
||||
def convert_node_to_data_point(node_data: dict) -> DataPoint:
|
||||
subclass = find_subclass_by_name(DataPoint, node_data["type"])
|
||||
subclass = find_subclass_by_name(DataPoint, node_data._metadata["type"])
|
||||
|
||||
return subclass(**node_data)
|
||||
|
||||
|
|
|
|||
|
|
@ -1,9 +1,11 @@
|
|||
from uuid import UUID
|
||||
from sqlalchemy import select
|
||||
from cognee.infrastructure.databases.relational import get_relational_engine
|
||||
from cognee.modules.data.models import Dataset, DatasetData
|
||||
from ...models import ACL, Resource, Permission
|
||||
|
||||
async def get_document_ids_for_user(user_id: UUID) -> list[str]:
|
||||
|
||||
async def get_document_ids_for_user(user_id: UUID, datasets: list[str] = None) -> list[str]:
|
||||
db_engine = get_relational_engine()
|
||||
|
||||
async with db_engine.get_async_session() as session:
|
||||
|
|
@ -18,4 +20,31 @@ async def get_document_ids_for_user(user_id: UUID) -> list[str]:
|
|||
)
|
||||
)).all()
|
||||
|
||||
if datasets:
|
||||
documents_ids_in_dataset = set()
|
||||
# If datasets are specified filter out documents that aren't part of the specified datasets
|
||||
for dataset in datasets:
|
||||
# Find dataset id for dataset element
|
||||
dataset_id = (await session.scalars(
|
||||
select(Dataset.id)
|
||||
.where(
|
||||
Dataset.name == dataset,
|
||||
Dataset.owner_id == user_id,
|
||||
)
|
||||
)).one_or_none()
|
||||
|
||||
# Check which documents are connected to this dataset
|
||||
for document_id in document_ids:
|
||||
data_id = (await session.scalars(
|
||||
select(DatasetData.data_id)
|
||||
.where(
|
||||
DatasetData.dataset_id == dataset_id,
|
||||
DatasetData.data_id == document_id,
|
||||
)
|
||||
)).one_or_none()
|
||||
|
||||
# If document is related to dataset added it to return value
|
||||
if data_id:
|
||||
documents_ids_in_dataset.add(document_id)
|
||||
return list(documents_ids_in_dataset)
|
||||
return document_ids
|
||||
|
|
|
|||
|
|
@ -1,15 +1,19 @@
|
|||
from typing import List, Optional
|
||||
|
||||
from cognee.infrastructure.engine import DataPoint
|
||||
|
||||
|
||||
class Repository(DataPoint):
|
||||
__tablename__ = "Repository"
|
||||
path: str
|
||||
type: Optional[str] = "Repository"
|
||||
_metadata: dict = {
|
||||
"index_fields": ["source_code"],
|
||||
"type": "Repository"
|
||||
}
|
||||
|
||||
class CodeFile(DataPoint):
|
||||
__tablename__ = "codefile"
|
||||
extracted_id: str # actually file path
|
||||
type: Optional[str] = "CodeFile"
|
||||
source_code: Optional[str] = None
|
||||
part_of: Optional[Repository] = None
|
||||
depends_on: Optional[List["CodeFile"]] = None
|
||||
|
|
@ -17,24 +21,27 @@ class CodeFile(DataPoint):
|
|||
contains: Optional[List["CodePart"]] = None
|
||||
|
||||
_metadata: dict = {
|
||||
"index_fields": ["source_code"]
|
||||
"index_fields": ["source_code"],
|
||||
"type": "CodeFile"
|
||||
}
|
||||
|
||||
class CodePart(DataPoint):
|
||||
__tablename__ = "codepart"
|
||||
# part_of: Optional[CodeFile]
|
||||
source_code: str
|
||||
type: Optional[str] = "CodePart"
|
||||
|
||||
|
||||
_metadata: dict = {
|
||||
"index_fields": ["source_code"]
|
||||
"index_fields": ["source_code"],
|
||||
"type": "CodePart"
|
||||
}
|
||||
|
||||
class CodeRelationship(DataPoint):
|
||||
source_id: str
|
||||
target_id: str
|
||||
type: str # between files
|
||||
relation: str # depends on or depends directly
|
||||
_metadata: dict = {
|
||||
"type": "CodeRelationship"
|
||||
}
|
||||
|
||||
CodeFile.model_rebuild()
|
||||
CodePart.model_rebuild()
|
||||
|
|
|
|||
|
|
@ -1,79 +1,90 @@
|
|||
from typing import Any, List, Union, Literal, Optional
|
||||
from typing import Any, List, Literal, Optional, Union
|
||||
|
||||
from cognee.infrastructure.engine import DataPoint
|
||||
|
||||
|
||||
class Variable(DataPoint):
|
||||
id: str
|
||||
name: str
|
||||
type: Literal["Variable"] = "Variable"
|
||||
description: str
|
||||
is_static: Optional[bool] = False
|
||||
default_value: Optional[str] = None
|
||||
data_type: str
|
||||
|
||||
_metadata = {
|
||||
"index_fields": ["name"]
|
||||
"index_fields": ["name"],
|
||||
"type": "Variable"
|
||||
}
|
||||
|
||||
class Operator(DataPoint):
|
||||
id: str
|
||||
name: str
|
||||
type: Literal["Operator"] = "Operator"
|
||||
description: str
|
||||
return_type: str
|
||||
_metadata = {
|
||||
"index_fields": ["name"],
|
||||
"type": "Operator"
|
||||
}
|
||||
|
||||
class Class(DataPoint):
|
||||
id: str
|
||||
name: str
|
||||
type: Literal["Class"] = "Class"
|
||||
description: str
|
||||
constructor_parameters: List[Variable]
|
||||
extended_from_class: Optional["Class"] = None
|
||||
has_methods: List["Function"]
|
||||
|
||||
_metadata = {
|
||||
"index_fields": ["name"]
|
||||
"index_fields": ["name"],
|
||||
"type": "Class"
|
||||
}
|
||||
|
||||
class ClassInstance(DataPoint):
|
||||
id: str
|
||||
name: str
|
||||
type: Literal["ClassInstance"] = "ClassInstance"
|
||||
description: str
|
||||
from_class: Class
|
||||
instantiated_by: Union["Function"]
|
||||
instantiation_arguments: List[Variable]
|
||||
|
||||
_metadata = {
|
||||
"index_fields": ["name"]
|
||||
"index_fields": ["name"],
|
||||
"type": "ClassInstance"
|
||||
}
|
||||
|
||||
class Function(DataPoint):
|
||||
id: str
|
||||
name: str
|
||||
type: Literal["Function"] = "Function"
|
||||
description: str
|
||||
parameters: List[Variable]
|
||||
return_type: str
|
||||
is_static: Optional[bool] = False
|
||||
|
||||
_metadata = {
|
||||
"index_fields": ["name"]
|
||||
"index_fields": ["name"],
|
||||
"type": "Function"
|
||||
}
|
||||
|
||||
class FunctionCall(DataPoint):
|
||||
id: str
|
||||
type: Literal["FunctionCall"] = "FunctionCall"
|
||||
called_by: Union[Function, Literal["main"]]
|
||||
function_called: Function
|
||||
function_arguments: List[Any]
|
||||
_metadata = {
|
||||
"index_fields": [],
|
||||
"type": "FunctionCall"
|
||||
}
|
||||
|
||||
class Expression(DataPoint):
|
||||
id: str
|
||||
name: str
|
||||
type: Literal["Expression"] = "Expression"
|
||||
description: str
|
||||
expression: str
|
||||
members: List[Union[Variable, Function, Operator, "Expression"]]
|
||||
_metadata = {
|
||||
"index_fields": ["name"],
|
||||
"type": "Expression"
|
||||
}
|
||||
|
||||
class SourceCodeGraph(DataPoint):
|
||||
id: str
|
||||
|
|
@ -89,8 +100,13 @@ class SourceCodeGraph(DataPoint):
|
|||
Operator,
|
||||
Expression,
|
||||
]]
|
||||
_metadata = {
|
||||
"index_fields": ["name"],
|
||||
"type": "SourceCodeGraph"
|
||||
}
|
||||
|
||||
Class.model_rebuild()
|
||||
ClassInstance.model_rebuild()
|
||||
Expression.model_rebuild()
|
||||
FunctionCall.model_rebuild()
|
||||
SourceCodeGraph.model_rebuild()
|
||||
SourceCodeGraph.model_rebuild()
|
||||
|
|
@ -1,9 +1,11 @@
|
|||
"""Data models for the cognitive architecture."""
|
||||
|
||||
from enum import Enum, auto
|
||||
from typing import Optional, List, Union, Dict, Any
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class Node(BaseModel):
|
||||
"""Node in a knowledge graph."""
|
||||
id: str
|
||||
|
|
@ -194,6 +196,29 @@ class SummarizedContent(BaseModel):
|
|||
summary: str
|
||||
description: str
|
||||
|
||||
class SummarizedFunction(BaseModel):
|
||||
name: str
|
||||
description: str
|
||||
inputs: Optional[List[str]] = None
|
||||
outputs: Optional[List[str]] = None
|
||||
decorators: Optional[List[str]] = None
|
||||
|
||||
class SummarizedClass(BaseModel):
|
||||
name: str
|
||||
description: str
|
||||
methods: Optional[List[SummarizedFunction]] = None
|
||||
decorators: Optional[List[str]] = None
|
||||
|
||||
class SummarizedCode(BaseModel):
|
||||
file_name: str
|
||||
high_level_summary: str
|
||||
key_features: List[str]
|
||||
imports: List[str] = []
|
||||
constants: List[str] = []
|
||||
classes: List[SummarizedClass] = []
|
||||
functions: List[SummarizedFunction] = []
|
||||
workflow_description: Optional[str] = None
|
||||
|
||||
|
||||
class GraphDBType(Enum):
|
||||
NETWORKX = auto()
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
from cognee.modules.data.processing.document_types.Document import Document
|
||||
|
||||
|
||||
async def extract_chunks_from_documents(documents: list[Document], chunk_size: int = 1024):
|
||||
async def extract_chunks_from_documents(documents: list[Document], chunk_size: int = 1024, chunker = 'text_chunker'):
|
||||
for document in documents:
|
||||
for document_chunk in document.read(chunk_size = chunk_size):
|
||||
for document_chunk in document.read(chunk_size = chunk_size, chunker = chunker):
|
||||
yield document_chunk
|
||||
|
|
|
|||
171
cognee/tasks/repo_processor/top_down_repo_parse.py
Normal file
171
cognee/tasks/repo_processor/top_down_repo_parse.py
Normal file
|
|
@ -0,0 +1,171 @@
|
|||
import os
|
||||
|
||||
import jedi
|
||||
import parso
|
||||
from tqdm import tqdm
|
||||
|
||||
from . import logger
|
||||
|
||||
_NODE_TYPE_MAP = {
|
||||
"funcdef": "func_def",
|
||||
"classdef": "class_def",
|
||||
"async_funcdef": "async_func_def",
|
||||
"async_stmt": "async_func_def",
|
||||
"simple_stmt": "var_def",
|
||||
}
|
||||
|
||||
def _create_object_dict(name_node, type_name=None):
|
||||
return {
|
||||
"name": name_node.value,
|
||||
"line": name_node.start_pos[0],
|
||||
"column": name_node.start_pos[1],
|
||||
"type": type_name,
|
||||
}
|
||||
|
||||
|
||||
def _parse_node(node):
|
||||
"""Parse a node to extract importable object details, including async functions and classes."""
|
||||
node_type = _NODE_TYPE_MAP.get(node.type)
|
||||
|
||||
if node.type in {"funcdef", "classdef", "async_funcdef"}:
|
||||
return [_create_object_dict(node.name, type_name=node_type)]
|
||||
if node.type == "async_stmt" and len(node.children) > 1:
|
||||
function_node = node.children[1]
|
||||
if function_node.type == "funcdef":
|
||||
return [_create_object_dict(function_node.name, type_name=_NODE_TYPE_MAP.get(function_node.type))]
|
||||
if node.type == "simple_stmt":
|
||||
# TODO: Handle multi-level/nested unpacking variable definitions in the future
|
||||
expr_child = node.children[0]
|
||||
if expr_child.type != "expr_stmt":
|
||||
return []
|
||||
if expr_child.children[0].type == "testlist_star_expr":
|
||||
name_targets = expr_child.children[0].children
|
||||
else:
|
||||
name_targets = expr_child.children
|
||||
return [
|
||||
_create_object_dict(target, type_name=_NODE_TYPE_MAP.get(target.type))
|
||||
for target in name_targets
|
||||
if target.type == "name"
|
||||
]
|
||||
return []
|
||||
|
||||
|
||||
|
||||
def extract_importable_objects_with_positions_from_source_code(source_code):
|
||||
"""Extract top-level objects in a Python source code string with their positions (line/column)."""
|
||||
try:
|
||||
tree = parso.parse(source_code)
|
||||
except Exception as e:
|
||||
logger.error(f"Error parsing source code: {e}")
|
||||
return []
|
||||
|
||||
importable_objects = []
|
||||
try:
|
||||
for node in tree.children:
|
||||
importable_objects.extend(_parse_node(node))
|
||||
except Exception as e:
|
||||
logger.error(f"Error extracting nodes from parsed tree: {e}")
|
||||
return []
|
||||
|
||||
return importable_objects
|
||||
|
||||
|
||||
def extract_importable_objects_with_positions(file_path):
|
||||
"""Extract top-level objects in a Python file with their positions (line/column)."""
|
||||
try:
|
||||
with open(file_path, "r") as file:
|
||||
source_code = file.read()
|
||||
except Exception as e:
|
||||
logger.error(f"Error reading file {file_path}: {e}")
|
||||
return []
|
||||
|
||||
return extract_importable_objects_with_positions_from_source_code(source_code)
|
||||
|
||||
|
||||
|
||||
def find_entity_usages(script, line, column):
|
||||
"""
|
||||
Return a list of files in the repo where the entity at module_path:line,column is used.
|
||||
"""
|
||||
usages = set()
|
||||
|
||||
|
||||
try:
|
||||
inferred = script.infer(line, column)
|
||||
except Exception as e:
|
||||
logger.error(f"Error inferring entity at {script.path}:{line},{column}: {e}")
|
||||
return []
|
||||
|
||||
if not inferred or not inferred[0]:
|
||||
logger.info(f"No entity inferred at {script.path}:{line},{column}")
|
||||
return []
|
||||
|
||||
logger.debug(f"Inferred entity: {inferred[0].name}, type: {inferred[0].type}")
|
||||
|
||||
try:
|
||||
references = script.get_references(line=line, column=column, scope="project", include_builtins=False)
|
||||
except Exception as e:
|
||||
logger.error(f"Error retrieving references for entity at {script.path}:{line},{column}: {e}")
|
||||
references = []
|
||||
|
||||
for ref in references:
|
||||
if ref.module_path: # Collect unique module paths
|
||||
usages.add(ref.module_path)
|
||||
logger.info(f"Entity used in: {ref.module_path}")
|
||||
|
||||
return list(usages)
|
||||
|
||||
def parse_file_with_references(project, file_path):
|
||||
"""Parse a file to extract object names and their references within a project."""
|
||||
try:
|
||||
importable_objects = extract_importable_objects_with_positions(file_path)
|
||||
except Exception as e:
|
||||
logger.error(f"Error extracting objects from {file_path}: {e}")
|
||||
return []
|
||||
|
||||
if not os.path.isfile(file_path):
|
||||
logger.warning(f"Module file does not exist: {file_path}")
|
||||
return []
|
||||
|
||||
try:
|
||||
script = jedi.Script(path=file_path, project=project)
|
||||
except Exception as e:
|
||||
logger.error(f"Error initializing Jedi Script: {e}")
|
||||
return []
|
||||
|
||||
parsed_results = [
|
||||
{
|
||||
"name": obj["name"],
|
||||
"type": obj["type"],
|
||||
"references": find_entity_usages(script, obj["line"], obj["column"]),
|
||||
}
|
||||
for obj in importable_objects
|
||||
]
|
||||
return parsed_results
|
||||
|
||||
|
||||
def parse_repo(repo_path):
|
||||
"""Parse a repository to extract object names, types, and references for all Python files."""
|
||||
try:
|
||||
project = jedi.Project(path=repo_path)
|
||||
except Exception as e:
|
||||
logger.error(f"Error creating Jedi project for repository at {repo_path}: {e}")
|
||||
return {}
|
||||
|
||||
EXCLUDE_DIRS = {'venv', '.git', '__pycache__', 'build'}
|
||||
|
||||
python_files = [
|
||||
os.path.join(directory, file)
|
||||
for directory, _, filenames in os.walk(repo_path)
|
||||
if not any(excluded in directory for excluded in EXCLUDE_DIRS)
|
||||
for file in filenames
|
||||
if file.endswith(".py") and os.path.getsize(os.path.join(directory, file)) > 0
|
||||
]
|
||||
|
||||
results = {
|
||||
file_path: parse_file_with_references(project, file_path)
|
||||
for file_path in tqdm(python_files)
|
||||
}
|
||||
|
||||
return results
|
||||
|
||||
|
|
@ -1,6 +1,7 @@
|
|||
from cognee.infrastructure.databases.vector import get_vector_engine
|
||||
from cognee.infrastructure.engine import DataPoint
|
||||
|
||||
|
||||
async def index_data_points(data_points: list[DataPoint]):
|
||||
created_indexes = {}
|
||||
index_points = {}
|
||||
|
|
@ -80,11 +81,20 @@ if __name__ == "__main__":
|
|||
class Car(DataPoint):
|
||||
model: str
|
||||
color: str
|
||||
_metadata = {
|
||||
"index_fields": ["name"],
|
||||
"type": "Car"
|
||||
}
|
||||
|
||||
|
||||
class Person(DataPoint):
|
||||
name: str
|
||||
age: int
|
||||
owns_car: list[Car]
|
||||
_metadata = {
|
||||
"index_fields": ["name"],
|
||||
"type": "Person"
|
||||
}
|
||||
|
||||
car1 = Car(model = "Tesla Model S", color = "Blue")
|
||||
car2 = Car(model = "Toyota Camry", color = "Red")
|
||||
|
|
@ -92,4 +102,4 @@ if __name__ == "__main__":
|
|||
|
||||
data_points = get_data_points_from_model(person)
|
||||
|
||||
print(data_points)
|
||||
print(data_points)
|
||||
|
|
@ -10,6 +10,7 @@ class TextSummary(DataPoint):
|
|||
|
||||
_metadata: dict = {
|
||||
"index_fields": ["text"],
|
||||
"type": "TextSummary"
|
||||
}
|
||||
|
||||
|
||||
|
|
@ -20,4 +21,5 @@ class CodeSummary(DataPoint):
|
|||
|
||||
_metadata: dict = {
|
||||
"index_fields": ["text"],
|
||||
"type": "CodeSummary"
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,39 +1,40 @@
|
|||
import asyncio
|
||||
from typing import Type
|
||||
from typing import AsyncGenerator, Union
|
||||
from uuid import uuid5
|
||||
|
||||
from pydantic import BaseModel
|
||||
from typing import Type
|
||||
|
||||
from cognee.infrastructure.engine import DataPoint
|
||||
from cognee.modules.data.extraction.extract_summary import extract_summary
|
||||
from cognee.shared.CodeGraphEntities import CodeFile
|
||||
from cognee.tasks.storage import add_data_points
|
||||
|
||||
from cognee.modules.data.extraction.extract_summary import extract_code_summary
|
||||
from .models import CodeSummary
|
||||
|
||||
|
||||
async def summarize_code(
|
||||
code_files: list[DataPoint],
|
||||
summarization_model: Type[BaseModel],
|
||||
) -> list[DataPoint]:
|
||||
if len(code_files) == 0:
|
||||
return code_files
|
||||
code_graph_nodes: list[DataPoint],
|
||||
) -> AsyncGenerator[Union[DataPoint, CodeSummary], None]:
|
||||
if len(code_graph_nodes) == 0:
|
||||
return
|
||||
|
||||
code_files_data_points = [file for file in code_files if isinstance(file, CodeFile)]
|
||||
code_data_points = [file for file in code_graph_nodes if hasattr(file, "source_code")]
|
||||
|
||||
file_summaries = await asyncio.gather(
|
||||
*[extract_summary(file.source_code, summarization_model) for file in code_files_data_points]
|
||||
*[extract_code_summary(file.source_code) for file in code_data_points]
|
||||
)
|
||||
|
||||
summaries = [
|
||||
CodeSummary(
|
||||
id = uuid5(file.id, "CodeSummary"),
|
||||
made_from = file,
|
||||
text = file_summaries[file_index].summary,
|
||||
file_summaries_map = {
|
||||
code_data_point.extracted_id: str(file_summary)
|
||||
for code_data_point, file_summary in zip(code_data_points, file_summaries)
|
||||
}
|
||||
|
||||
for node in code_graph_nodes:
|
||||
if not isinstance(node, DataPoint):
|
||||
continue
|
||||
yield node
|
||||
|
||||
if not hasattr(node, "source_code"):
|
||||
continue
|
||||
|
||||
yield CodeSummary(
|
||||
id=uuid5(node.id, "CodeSummary"),
|
||||
made_from=node,
|
||||
text=file_summaries_map[node.extracted_id],
|
||||
)
|
||||
for (file_index, file) in enumerate(code_files_data_points)
|
||||
]
|
||||
|
||||
await add_data_points(summaries)
|
||||
|
||||
return code_files
|
||||
|
|
|
|||
|
|
@ -31,7 +31,7 @@ def test_AudioDocument():
|
|||
)
|
||||
with patch.object(AudioDocument, "create_transcript", return_value=TEST_TEXT):
|
||||
for ground_truth, paragraph_data in zip(
|
||||
GROUND_TRUTH, document.read(chunk_size=64)
|
||||
GROUND_TRUTH, document.read(chunk_size=64, chunker='text_chunker')
|
||||
):
|
||||
assert (
|
||||
ground_truth["word_count"] == paragraph_data.word_count
|
||||
|
|
|
|||
|
|
@ -21,7 +21,7 @@ def test_ImageDocument():
|
|||
with patch.object(ImageDocument, "transcribe_image", return_value=TEST_TEXT):
|
||||
|
||||
for ground_truth, paragraph_data in zip(
|
||||
GROUND_TRUTH, document.read(chunk_size=64)
|
||||
GROUND_TRUTH, document.read(chunk_size=64, chunker='text_chunker')
|
||||
):
|
||||
assert (
|
||||
ground_truth["word_count"] == paragraph_data.word_count
|
||||
|
|
|
|||
|
|
@ -22,7 +22,7 @@ def test_PdfDocument():
|
|||
)
|
||||
|
||||
for ground_truth, paragraph_data in zip(
|
||||
GROUND_TRUTH, document.read(chunk_size=1024)
|
||||
GROUND_TRUTH, document.read(chunk_size=1024, chunker='text_chunker')
|
||||
):
|
||||
assert (
|
||||
ground_truth["word_count"] == paragraph_data.word_count
|
||||
|
|
|
|||
|
|
@ -33,7 +33,7 @@ def test_TextDocument(input_file, chunk_size):
|
|||
)
|
||||
|
||||
for ground_truth, paragraph_data in zip(
|
||||
GROUND_TRUTH[input_file], document.read(chunk_size=chunk_size)
|
||||
GROUND_TRUTH[input_file], document.read(chunk_size=chunk_size, chunker='text_chunker')
|
||||
):
|
||||
assert (
|
||||
ground_truth["word_count"] == paragraph_data.word_count
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@ import pathlib
|
|||
import cognee
|
||||
from cognee.api.v1.search import SearchType
|
||||
from cognee.modules.retrieval.brute_force_triplet_search import brute_force_triplet_search
|
||||
from cognee.modules.users.methods import get_default_user
|
||||
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
|
||||
|
|
@ -44,12 +45,13 @@ async def main():
|
|||
await cognee.prune.prune_data()
|
||||
await cognee.prune.prune_system(metadata = True)
|
||||
|
||||
dataset_name = "cs_explanations"
|
||||
dataset_name_1 = "natural_language"
|
||||
dataset_name_2 = "quantum"
|
||||
|
||||
explanation_file_path = os.path.join(
|
||||
pathlib.Path(__file__).parent, "test_data/Natural_language_processing.txt"
|
||||
)
|
||||
await cognee.add([explanation_file_path], dataset_name)
|
||||
await cognee.add([explanation_file_path], dataset_name_1)
|
||||
|
||||
text = """A quantum computer is a computer that takes advantage of quantum mechanical phenomena.
|
||||
At small scales, physical matter exhibits properties of both particles and waves, and quantum computing leverages this behavior, specifically quantum superposition and entanglement, using specialized hardware that supports the preparation and manipulation of quantum states.
|
||||
|
|
@ -59,12 +61,23 @@ async def main():
|
|||
In principle, a non-quantum (classical) computer can solve the same computational problems as a quantum computer, given enough time. Quantum advantage comes in the form of time complexity rather than computability, and quantum complexity theory shows that some quantum algorithms for carefully selected tasks require exponentially fewer computational steps than the best known non-quantum algorithms. Such tasks can in theory be solved on a large-scale quantum computer whereas classical computers would not finish computations in any reasonable amount of time. However, quantum speedup is not universal or even typical across computational tasks, since basic tasks such as sorting are proven to not allow any asymptotic quantum speedup. Claims of quantum supremacy have drawn significant attention to the discipline, but are demonstrated on contrived tasks, while near-term practical use cases remain limited.
|
||||
"""
|
||||
|
||||
await cognee.add([text], dataset_name)
|
||||
await cognee.add([text], dataset_name_2)
|
||||
|
||||
await cognee.cognify([dataset_name])
|
||||
await cognee.cognify([dataset_name_2, dataset_name_1])
|
||||
|
||||
from cognee.infrastructure.databases.vector import get_vector_engine
|
||||
|
||||
# Test getting of documents for search per dataset
|
||||
from cognee.modules.users.permissions.methods import get_document_ids_for_user
|
||||
user = await get_default_user()
|
||||
document_ids = await get_document_ids_for_user(user.id, [dataset_name_1])
|
||||
assert len(document_ids) == 1, f"Number of expected documents doesn't match {len(document_ids)} != 1"
|
||||
|
||||
# Test getting of documents for search when no dataset is provided
|
||||
user = await get_default_user()
|
||||
document_ids = await get_document_ids_for_user(user.id)
|
||||
assert len(document_ids) == 2, f"Number of expected documents doesn't match {len(document_ids)} != 2"
|
||||
|
||||
vector_engine = get_vector_engine()
|
||||
random_node = (await vector_engine.search("entity_name", "Quantum computer"))[0]
|
||||
random_node_name = random_node.payload["text"]
|
||||
|
|
@ -75,7 +88,7 @@ async def main():
|
|||
for result in search_results:
|
||||
print(f"{result}\n")
|
||||
|
||||
search_results = await cognee.search(SearchType.CHUNKS, query_text = random_node_name)
|
||||
search_results = await cognee.search(SearchType.CHUNKS, query_text = random_node_name, datasets=[dataset_name_2])
|
||||
assert len(search_results) != 0, "The search results list is empty."
|
||||
print("\n\nExtracted chunks are:\n")
|
||||
for result in search_results:
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ import asyncio
|
|||
import random
|
||||
import time
|
||||
from typing import List
|
||||
from uuid import uuid5, NAMESPACE_OID
|
||||
from uuid import NAMESPACE_OID, uuid5
|
||||
|
||||
from cognee.infrastructure.engine import DataPoint
|
||||
from cognee.modules.graph.utils import get_graph_from_model
|
||||
|
|
@ -11,16 +11,28 @@ random.seed(1500)
|
|||
|
||||
class Repository(DataPoint):
|
||||
path: str
|
||||
_metadata = {
|
||||
"index_fields": [],
|
||||
"type": "Repository"
|
||||
}
|
||||
|
||||
class CodeFile(DataPoint):
|
||||
part_of: Repository
|
||||
contains: List["CodePart"] = []
|
||||
depends_on: List["CodeFile"] = []
|
||||
source_code: str
|
||||
_metadata = {
|
||||
"index_fields": [],
|
||||
"type": "CodeFile"
|
||||
}
|
||||
|
||||
class CodePart(DataPoint):
|
||||
part_of: CodeFile
|
||||
source_code: str
|
||||
_metadata = {
|
||||
"index_fields": [],
|
||||
"type": "CodePart"
|
||||
}
|
||||
|
||||
CodeFile.model_rebuild()
|
||||
CodePart.model_rebuild()
|
||||
|
|
|
|||
|
|
@ -1,25 +1,42 @@
|
|||
import asyncio
|
||||
import random
|
||||
from typing import List
|
||||
from uuid import uuid5, NAMESPACE_OID
|
||||
from uuid import NAMESPACE_OID, uuid5
|
||||
|
||||
from cognee.infrastructure.engine import DataPoint
|
||||
from cognee.modules.graph.utils import get_graph_from_model
|
||||
|
||||
|
||||
class Document(DataPoint):
|
||||
path: str
|
||||
_metadata = {
|
||||
"index_fields": [],
|
||||
"type": "Document"
|
||||
}
|
||||
|
||||
class DocumentChunk(DataPoint):
|
||||
part_of: Document
|
||||
text: str
|
||||
contains: List["Entity"] = None
|
||||
_metadata = {
|
||||
"index_fields": ["text"],
|
||||
"type": "DocumentChunk"
|
||||
}
|
||||
|
||||
class EntityType(DataPoint):
|
||||
name: str
|
||||
_metadata = {
|
||||
"index_fields": ["name"],
|
||||
"type": "EntityType"
|
||||
}
|
||||
|
||||
class Entity(DataPoint):
|
||||
name: str
|
||||
is_type: EntityType
|
||||
_metadata = {
|
||||
"index_fields": ["name"],
|
||||
"type": "Entity"
|
||||
}
|
||||
|
||||
DocumentChunk.model_rebuild()
|
||||
|
||||
|
|
|
|||
|
|
@ -7,19 +7,13 @@ from pathlib import Path
|
|||
from swebench.harness.utils import load_swebench_dataset
|
||||
from swebench.inference.make_datasets.create_instance import PATCH_EXAMPLE
|
||||
|
||||
from cognee.api.v1.cognify.code_graph_pipeline import run_code_graph_pipeline
|
||||
from cognee.api.v1.search import SearchType
|
||||
from cognee.infrastructure.llm.get_llm_client import get_llm_client
|
||||
from cognee.infrastructure.llm.prompts import read_query_prompt
|
||||
from cognee.modules.pipelines import Task, run_tasks
|
||||
from cognee.modules.retrieval.brute_force_triplet_search import \
|
||||
brute_force_triplet_search
|
||||
# from cognee.shared.data_models import SummarizedContent
|
||||
from cognee.shared.utils import render_graph
|
||||
from cognee.tasks.repo_processor import (enrich_dependency_graph,
|
||||
expand_dependency_graph,
|
||||
get_repo_file_dependencies)
|
||||
from cognee.tasks.storage import add_data_points
|
||||
# from cognee.tasks.summarization import summarize_code
|
||||
from evals.eval_utils import download_github_repo, retrieved_edges_to_string
|
||||
|
||||
|
||||
|
|
@ -42,48 +36,22 @@ def check_install_package(package_name):
|
|||
|
||||
|
||||
async def generate_patch_with_cognee(instance, llm_client, search_type=SearchType.CHUNKS):
|
||||
import os
|
||||
import pathlib
|
||||
import cognee
|
||||
from cognee.infrastructure.databases.relational import create_db_and_tables
|
||||
|
||||
file_path = Path(__file__).parent
|
||||
data_directory_path = str(pathlib.Path(os.path.join(file_path, ".data_storage/code_graph")).resolve())
|
||||
cognee.config.data_root_directory(data_directory_path)
|
||||
cognee_directory_path = str(pathlib.Path(os.path.join(file_path, ".cognee_system/code_graph")).resolve())
|
||||
cognee.config.system_root_directory(cognee_directory_path)
|
||||
|
||||
await cognee.prune.prune_data()
|
||||
await cognee.prune.prune_system(metadata = True)
|
||||
|
||||
await create_db_and_tables()
|
||||
|
||||
# repo_path = download_github_repo(instance, '../RAW_GIT_REPOS')
|
||||
|
||||
repo_path = '/Users/borisarzentar/Projects/graphrag'
|
||||
|
||||
tasks = [
|
||||
Task(get_repo_file_dependencies),
|
||||
Task(enrich_dependency_graph, task_config = { "batch_size": 50 }),
|
||||
Task(expand_dependency_graph, task_config = { "batch_size": 50 }),
|
||||
Task(add_data_points, task_config = { "batch_size": 50 }),
|
||||
# Task(summarize_code, summarization_model = SummarizedContent),
|
||||
]
|
||||
|
||||
pipeline = run_tasks(tasks, repo_path, "cognify_code_pipeline")
|
||||
repo_path = download_github_repo(instance, '../RAW_GIT_REPOS')
|
||||
pipeline = await run_code_graph_pipeline(repo_path)
|
||||
|
||||
async for result in pipeline:
|
||||
print(result)
|
||||
|
||||
print('Here we have the repo under the repo_path')
|
||||
|
||||
await render_graph(None, include_labels = True, include_nodes = True)
|
||||
await render_graph(None, include_labels=True, include_nodes=True)
|
||||
|
||||
problem_statement = instance['problem_statement']
|
||||
instructions = read_query_prompt("patch_gen_kg_instructions.txt")
|
||||
|
||||
retrieved_edges = await brute_force_triplet_search(problem_statement, top_k = 3, collections = ["data_point_source_code", "data_point_text"])
|
||||
|
||||
retrieved_edges = await brute_force_triplet_search(problem_statement, top_k=3,
|
||||
collections=["data_point_source_code", "data_point_text"])
|
||||
|
||||
retrieved_edges_str = retrieved_edges_to_string(retrieved_edges)
|
||||
|
||||
prompt = "\n".join([
|
||||
|
|
@ -171,7 +139,6 @@ async def main():
|
|||
with open(predictions_path, "w") as file:
|
||||
json.dump(preds, file)
|
||||
|
||||
|
||||
subprocess.run(
|
||||
[
|
||||
"python",
|
||||
|
|
|
|||
15
examples/python/code_graph_example.py
Normal file
15
examples/python/code_graph_example.py
Normal file
|
|
@ -0,0 +1,15 @@
|
|||
import argparse
|
||||
import asyncio
|
||||
from cognee.api.v1.cognify.code_graph_pipeline import run_code_graph_pipeline
|
||||
|
||||
|
||||
async def main(repo_path):
|
||||
async for result in await run_code_graph_pipeline(repo_path):
|
||||
print(result)
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--repo-path", type=str, required=True, help="Path to the repository")
|
||||
args = parser.parse_args()
|
||||
asyncio.run(main(args.repo_path))
|
||||
|
||||
Loading…
Add table
Reference in a new issue