Merge branch 'dev' into COG-578

This commit is contained in:
Boris 2024-12-17 15:04:13 +01:00 committed by GitHub
commit be424249d7
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
52 changed files with 614 additions and 190 deletions

View file

@ -17,6 +17,7 @@ jobs:
run_deduplication_test: run_deduplication_test:
name: test name: test
runs-on: ubuntu-latest runs-on: ubuntu-latest
if: ${{ github.event.label.name == 'run-checks' }}
defaults: defaults:
run: run:
shell: bash shell: bash

View file

@ -18,6 +18,7 @@ jobs:
run_milvus: run_milvus:
name: test name: test
runs-on: ubuntu-latest runs-on: ubuntu-latest
if: ${{ github.event.label.name == 'run-checks' }}
strategy: strategy:
fail-fast: false fail-fast: false
defaults: defaults:

View file

@ -15,6 +15,7 @@ env:
jobs: jobs:
run_neo4j_integration_test: run_neo4j_integration_test:
name: test name: test
if: ${{ github.event.label.name == 'run-checks' }}
runs-on: ubuntu-latest runs-on: ubuntu-latest
defaults: defaults:

View file

@ -18,6 +18,7 @@ jobs:
run_pgvector_integration_test: run_pgvector_integration_test:
name: test name: test
runs-on: ubuntu-latest runs-on: ubuntu-latest
if: ${{ github.event.label.name == 'run-checks' }}
defaults: defaults:
run: run:
shell: bash shell: bash

View file

@ -18,6 +18,7 @@ jobs:
run_common: run_common:
name: test name: test
runs-on: ubuntu-latest runs-on: ubuntu-latest
if: ${{ github.event.label.name == 'run-checks' }}
strategy: strategy:
fail-fast: false fail-fast: false
defaults: defaults:

View file

@ -18,6 +18,7 @@ jobs:
run_common: run_common:
name: test name: test
runs-on: ubuntu-latest runs-on: ubuntu-latest
if: ${{ github.event.label.name == 'run-checks' }}
strategy: strategy:
fail-fast: false fail-fast: false
defaults: defaults:

View file

@ -18,6 +18,7 @@ jobs:
run_common: run_common:
name: test name: test
runs-on: ubuntu-latest runs-on: ubuntu-latest
if: ${{ github.event.label.name == 'run-checks' }}
strategy: strategy:
fail-fast: false fail-fast: false
defaults: defaults:

View file

@ -18,6 +18,7 @@ jobs:
run_qdrant_integration_test: run_qdrant_integration_test:
name: test name: test
runs-on: ubuntu-latest runs-on: ubuntu-latest
if: ${{ github.event.label.name == 'run-checks' }}
defaults: defaults:
run: run:

View file

@ -18,6 +18,7 @@ jobs:
run_weaviate_integration_test: run_weaviate_integration_test:
name: test name: test
runs-on: ubuntu-latest runs-on: ubuntu-latest
if: ${{ github.event.label.name == 'run-checks' }}
defaults: defaults:
run: run:

View file

@ -1,22 +1,35 @@
# NOTICE: This module contains deprecated functions.
# Use only the run_code_graph_pipeline function; all other functions are deprecated.
# Related issue: COG-906
import asyncio import asyncio
import logging import logging
from pathlib import Path
from typing import Union from typing import Union
from cognee.modules.data.methods import get_datasets, get_datasets_by_name
from cognee.modules.data.methods.get_dataset_data import get_dataset_data
from cognee.modules.data.models import Data, Dataset
from cognee.modules.pipelines import run_tasks
from cognee.modules.pipelines.models import PipelineRunStatus
from cognee.modules.pipelines.operations.get_pipeline_status import \
get_pipeline_status
from cognee.modules.pipelines.operations.log_pipeline_status import \
log_pipeline_status
from cognee.modules.pipelines.tasks.Task import Task
from cognee.modules.users.methods import get_default_user
from cognee.modules.users.models import User
from cognee.shared.SourceCodeGraph import SourceCodeGraph from cognee.shared.SourceCodeGraph import SourceCodeGraph
from cognee.shared.utils import send_telemetry from cognee.shared.utils import send_telemetry
from cognee.modules.data.models import Dataset, Data from cognee.tasks.documents import (check_permissions_on_documents,
from cognee.modules.data.methods.get_dataset_data import get_dataset_data classify_documents,
from cognee.modules.data.methods import get_datasets, get_datasets_by_name extract_chunks_from_documents)
from cognee.modules.pipelines.tasks.Task import Task
from cognee.modules.pipelines import run_tasks
from cognee.modules.users.models import User
from cognee.modules.users.methods import get_default_user
from cognee.modules.pipelines.models import PipelineRunStatus
from cognee.modules.pipelines.operations.get_pipeline_status import get_pipeline_status
from cognee.modules.pipelines.operations.log_pipeline_status import log_pipeline_status
from cognee.tasks.documents import classify_documents, check_permissions_on_documents, extract_chunks_from_documents
from cognee.tasks.graph import extract_graph_from_code from cognee.tasks.graph import extract_graph_from_code
from cognee.tasks.repo_processor import (enrich_dependency_graph,
expand_dependency_graph,
get_repo_file_dependencies)
from cognee.tasks.storage import add_data_points from cognee.tasks.storage import add_data_points
from cognee.tasks.summarization import summarize_code
logger = logging.getLogger("code_graph_pipeline") logger = logging.getLogger("code_graph_pipeline")
@ -51,6 +64,7 @@ async def code_graph_pipeline(datasets: Union[str, list[str]] = None, user: User
async def run_pipeline(dataset: Dataset, user: User): async def run_pipeline(dataset: Dataset, user: User):
'''DEPRECATED: Use `run_code_graph_pipeline` instead. This function will be removed.'''
data_documents: list[Data] = await get_dataset_data(dataset_id = dataset.id) data_documents: list[Data] = await get_dataset_data(dataset_id = dataset.id)
document_ids_str = [str(document.id) for document in data_documents] document_ids_str = [str(document.id) for document in data_documents]
@ -103,3 +117,30 @@ async def run_pipeline(dataset: Dataset, user: User):
def generate_dataset_name(dataset_name: str) -> str: def generate_dataset_name(dataset_name: str) -> str:
return dataset_name.replace(".", "_").replace(" ", "_") return dataset_name.replace(".", "_").replace(" ", "_")
async def run_code_graph_pipeline(repo_path):
import os
import pathlib
import cognee
from cognee.infrastructure.databases.relational import create_db_and_tables
file_path = Path(__file__).parent
data_directory_path = str(pathlib.Path(os.path.join(file_path, ".data_storage/code_graph")).resolve())
cognee.config.data_root_directory(data_directory_path)
cognee_directory_path = str(pathlib.Path(os.path.join(file_path, ".cognee_system/code_graph")).resolve())
cognee.config.system_root_directory(cognee_directory_path)
await cognee.prune.prune_data()
await cognee.prune.prune_system(metadata=True)
await create_db_and_tables()
tasks = [
Task(get_repo_file_dependencies),
Task(enrich_dependency_graph, task_config={"batch_size": 50}),
Task(expand_dependency_graph, task_config={"batch_size": 50}),
Task(summarize_code, task_config={"batch_size": 50}),
Task(add_data_points, task_config={"batch_size": 50}),
]
return run_tasks(tasks, repo_path, "cognify_code_pipeline")

View file

@ -69,17 +69,18 @@ async def run_cognify_pipeline(dataset: Dataset, user: User, graph_model: BaseMo
send_telemetry("cognee.cognify EXECUTION STARTED", user.id) send_telemetry("cognee.cognify EXECUTION STARTED", user.id)
async with update_status_lock: #async with update_status_lock: TODO: Add UI lock to prevent multiple backend requests
task_status = await get_pipeline_status([dataset_id]) task_status = await get_pipeline_status([dataset_id])
if dataset_id in task_status and task_status[dataset_id] == PipelineRunStatus.DATASET_PROCESSING_STARTED: if dataset_id in task_status and task_status[dataset_id] == PipelineRunStatus.DATASET_PROCESSING_STARTED:
logger.info("Dataset %s is already being processed.", dataset_name) logger.info("Dataset %s is already being processed.", dataset_name)
return return
await log_pipeline_status(dataset_id, PipelineRunStatus.DATASET_PROCESSING_STARTED, {
"dataset_name": dataset_name,
"files": document_ids_str,
})
await log_pipeline_status(dataset_id, PipelineRunStatus.DATASET_PROCESSING_STARTED, {
"dataset_name": dataset_name,
"files": document_ids_str,
})
try: try:
cognee_config = get_cognify_config() cognee_config = get_cognify_config()

View file

@ -1,13 +1,15 @@
from fastapi import APIRouter from fastapi import APIRouter
from typing import List from typing import List, Optional
from pydantic import BaseModel from pydantic import BaseModel
from cognee.modules.users.models import User from cognee.modules.users.models import User
from fastapi.responses import JSONResponse from fastapi.responses import JSONResponse
from cognee.modules.users.methods import get_authenticated_user from cognee.modules.users.methods import get_authenticated_user
from fastapi import Depends from fastapi import Depends
from cognee.shared.data_models import KnowledgeGraph
class CognifyPayloadDTO(BaseModel): class CognifyPayloadDTO(BaseModel):
datasets: List[str] datasets: List[str]
graph_model: Optional[BaseModel] = KnowledgeGraph
def get_cognify_router() -> APIRouter: def get_cognify_router() -> APIRouter:
router = APIRouter() router = APIRouter()
@ -17,7 +19,7 @@ def get_cognify_router() -> APIRouter:
""" This endpoint is responsible for the cognitive processing of the content.""" """ This endpoint is responsible for the cognitive processing of the content."""
from cognee.api.v1.cognify.cognify_v2 import cognify as cognee_cognify from cognee.api.v1.cognify.cognify_v2 import cognify as cognee_cognify
try: try:
await cognee_cognify(payload.datasets, user) await cognee_cognify(payload.datasets, user, payload.graph_model)
except Exception as error: except Exception as error:
return JSONResponse( return JSONResponse(
status_code=409, status_code=409,

View file

@ -1,7 +1,7 @@
import json import json
from uuid import UUID from uuid import UUID
from enum import Enum from enum import Enum
from typing import Callable, Dict from typing import Callable, Dict, Union
from cognee.exceptions import InvalidValueError from cognee.exceptions import InvalidValueError
from cognee.modules.search.operations import log_query, log_result from cognee.modules.search.operations import log_query, log_result
@ -22,7 +22,12 @@ class SearchType(Enum):
CHUNKS = "CHUNKS" CHUNKS = "CHUNKS"
COMPLETION = "COMPLETION" COMPLETION = "COMPLETION"
async def search(query_type: SearchType, query_text: str, user: User = None) -> list: async def search(query_type: SearchType, query_text: str, user: User = None,
datasets: Union[list[str], str, None] = None) -> list:
# We use lists from now on for datasets
if isinstance(datasets, str):
datasets = [datasets]
if user is None: if user is None:
user = await get_default_user() user = await get_default_user()
@ -31,7 +36,7 @@ async def search(query_type: SearchType, query_text: str, user: User = None) ->
query = await log_query(query_text, str(query_type), user.id) query = await log_query(query_text, str(query_type), user.id)
own_document_ids = await get_document_ids_for_user(user.id) own_document_ids = await get_document_ids_for_user(user.id, datasets)
search_results = await specific_search(query_type, query_text, user) search_results = await specific_search(query_type, query_text, user)
filtered_search_results = [] filtered_search_results = []

View file

@ -1,21 +1,26 @@
import asyncio import asyncio
# from datetime import datetime # from datetime import datetime
import json import json
from uuid import UUID
from textwrap import dedent from textwrap import dedent
from uuid import UUID
from falkordb import FalkorDB from falkordb import FalkorDB
from cognee.exceptions import InvalidValueError from cognee.exceptions import InvalidValueError
from cognee.infrastructure.engine import DataPoint from cognee.infrastructure.databases.graph.graph_db_interface import \
from cognee.infrastructure.databases.graph.graph_db_interface import GraphDBInterface GraphDBInterface
from cognee.infrastructure.databases.vector.embeddings import EmbeddingEngine from cognee.infrastructure.databases.vector.embeddings import EmbeddingEngine
from cognee.infrastructure.databases.vector.vector_db_interface import VectorDBInterface from cognee.infrastructure.databases.vector.vector_db_interface import \
VectorDBInterface
from cognee.infrastructure.engine import DataPoint
class IndexSchema(DataPoint): class IndexSchema(DataPoint):
text: str text: str
_metadata: dict = { _metadata: dict = {
"index_fields": ["text"] "index_fields": ["text"],
"type": "IndexSchema"
} }
class FalkorDBAdapter(VectorDBInterface, GraphDBInterface): class FalkorDBAdapter(VectorDBInterface, GraphDBInterface):

View file

@ -1,25 +1,29 @@
from typing import List, Optional, get_type_hints, Generic, TypeVar
import asyncio import asyncio
from typing import Generic, List, Optional, TypeVar, get_type_hints
from uuid import UUID from uuid import UUID
import lancedb import lancedb
from lancedb.pydantic import LanceModel, Vector
from pydantic import BaseModel from pydantic import BaseModel
from lancedb.pydantic import Vector, LanceModel
from cognee.exceptions import InvalidValueError from cognee.exceptions import InvalidValueError
from cognee.infrastructure.engine import DataPoint from cognee.infrastructure.engine import DataPoint
from cognee.infrastructure.files.storage import LocalStorage from cognee.infrastructure.files.storage import LocalStorage
from cognee.modules.storage.utils import copy_model, get_own_properties from cognee.modules.storage.utils import copy_model, get_own_properties
from ..models.ScoredResult import ScoredResult
from ..vector_db_interface import VectorDBInterface
from ..utils import normalize_distances
from ..embeddings.EmbeddingEngine import EmbeddingEngine from ..embeddings.EmbeddingEngine import EmbeddingEngine
from ..models.ScoredResult import ScoredResult
from ..utils import normalize_distances
from ..vector_db_interface import VectorDBInterface
class IndexSchema(DataPoint): class IndexSchema(DataPoint):
id: str id: str
text: str text: str
_metadata: dict = { _metadata: dict = {
"index_fields": ["text"] "index_fields": ["text"],
"type": "IndexSchema"
} }
class LanceDBAdapter(VectorDBInterface): class LanceDBAdapter(VectorDBInterface):

View file

@ -4,10 +4,12 @@ import asyncio
import logging import logging
from typing import List, Optional from typing import List, Optional
from uuid import UUID from uuid import UUID
from cognee.infrastructure.engine import DataPoint from cognee.infrastructure.engine import DataPoint
from ..vector_db_interface import VectorDBInterface
from ..models.ScoredResult import ScoredResult
from ..embeddings.EmbeddingEngine import EmbeddingEngine from ..embeddings.EmbeddingEngine import EmbeddingEngine
from ..models.ScoredResult import ScoredResult
from ..vector_db_interface import VectorDBInterface
logger = logging.getLogger("MilvusAdapter") logger = logging.getLogger("MilvusAdapter")
@ -16,7 +18,8 @@ class IndexSchema(DataPoint):
text: str text: str
_metadata: dict = { _metadata: dict = {
"index_fields": ["text"] "index_fields": ["text"],
"type": "IndexSchema"
} }

View file

@ -1,27 +1,30 @@
import asyncio import asyncio
from uuid import UUID
from typing import List, Optional, get_type_hints from typing import List, Optional, get_type_hints
from uuid import UUID
from sqlalchemy.orm import Mapped, mapped_column from sqlalchemy.orm import Mapped, mapped_column
from sqlalchemy import JSON, Column, Table, select, delete from sqlalchemy import JSON, Column, Table, select, delete, MetaData
from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker
from cognee.exceptions import InvalidValueError from cognee.exceptions import InvalidValueError
from cognee.infrastructure.databases.exceptions import EntityNotFoundError from cognee.infrastructure.databases.exceptions import EntityNotFoundError
from cognee.infrastructure.engine import DataPoint from cognee.infrastructure.engine import DataPoint
from .serialize_data import serialize_data
from ..models.ScoredResult import ScoredResult
from ..vector_db_interface import VectorDBInterface
from ..utils import normalize_distances
from ..embeddings.EmbeddingEngine import EmbeddingEngine
from ...relational.sqlalchemy.SqlAlchemyAdapter import SQLAlchemyAdapter
from ...relational.ModelBase import Base from ...relational.ModelBase import Base
from ...relational.sqlalchemy.SqlAlchemyAdapter import SQLAlchemyAdapter
from ..embeddings.EmbeddingEngine import EmbeddingEngine
from ..models.ScoredResult import ScoredResult
from ..utils import normalize_distances
from ..vector_db_interface import VectorDBInterface
from .serialize_data import serialize_data
class IndexSchema(DataPoint): class IndexSchema(DataPoint):
text: str text: str
_metadata: dict = { _metadata: dict = {
"index_fields": ["text"] "index_fields": ["text"],
"type": "IndexSchema"
} }
class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface): class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
@ -48,10 +51,12 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
async def has_collection(self, collection_name: str) -> bool: async def has_collection(self, collection_name: str) -> bool:
async with self.engine.begin() as connection: async with self.engine.begin() as connection:
# Load the schema information into the MetaData object # Create a MetaData instance to load table information
await connection.run_sync(Base.metadata.reflect) metadata = MetaData()
# Load table information from schema into MetaData
await connection.run_sync(metadata.reflect)
if collection_name in Base.metadata.tables: if collection_name in metadata.tables:
return True return True
else: else:
return False return False
@ -87,6 +92,7 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
async def create_data_points( async def create_data_points(
self, collection_name: str, data_points: List[DataPoint] self, collection_name: str, data_points: List[DataPoint]
): ):
data_point_types = get_type_hints(DataPoint)
if not await self.has_collection(collection_name): if not await self.has_collection(collection_name):
await self.create_collection( await self.create_collection(
collection_name = collection_name, collection_name = collection_name,
@ -106,7 +112,7 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
primary_key: Mapped[int] = mapped_column( primary_key: Mapped[int] = mapped_column(
primary_key=True, autoincrement=True primary_key=True, autoincrement=True
) )
id: Mapped[type(data_points[0].id)] id: Mapped[data_point_types["id"]]
payload = Column(JSON) payload = Column(JSON)
vector = Column(self.Vector(vector_size)) vector = Column(self.Vector(vector_size))
@ -145,10 +151,12 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
with an async engine. with an async engine.
""" """
async with self.engine.begin() as connection: async with self.engine.begin() as connection:
# Load the schema information into the MetaData object # Create a MetaData instance to load table information
await connection.run_sync(Base.metadata.reflect) metadata = MetaData()
if collection_name in Base.metadata.tables: # Load table information from schema into MetaData
return Base.metadata.tables[collection_name] await connection.run_sync(metadata.reflect)
if collection_name in metadata.tables:
return metadata.tables[collection_name]
else: else:
raise EntityNotFoundError(message=f"Table '{collection_name}' not found.") raise EntityNotFoundError(message=f"Table '{collection_name}' not found.")

View file

@ -1,13 +1,16 @@
import logging import logging
from typing import Dict, List, Optional
from uuid import UUID from uuid import UUID
from typing import List, Dict, Optional
from qdrant_client import AsyncQdrantClient, models from qdrant_client import AsyncQdrantClient, models
from cognee.exceptions import InvalidValueError from cognee.exceptions import InvalidValueError
from cognee.infrastructure.databases.vector.models.ScoredResult import ScoredResult from cognee.infrastructure.databases.vector.models.ScoredResult import \
ScoredResult
from cognee.infrastructure.engine import DataPoint from cognee.infrastructure.engine import DataPoint
from ..vector_db_interface import VectorDBInterface
from ..embeddings.EmbeddingEngine import EmbeddingEngine from ..embeddings.EmbeddingEngine import EmbeddingEngine
from ..vector_db_interface import VectorDBInterface
logger = logging.getLogger("QDrantAdapter") logger = logging.getLogger("QDrantAdapter")
@ -15,7 +18,8 @@ class IndexSchema(DataPoint):
text: str text: str
_metadata: dict = { _metadata: dict = {
"index_fields": ["text"] "index_fields": ["text"],
"type": "IndexSchema"
} }
# class CollectionConfig(BaseModel, extra = "forbid"): # class CollectionConfig(BaseModel, extra = "forbid"):

View file

@ -5,9 +5,10 @@ from uuid import UUID
from cognee.exceptions import InvalidValueError from cognee.exceptions import InvalidValueError
from cognee.infrastructure.engine import DataPoint from cognee.infrastructure.engine import DataPoint
from ..vector_db_interface import VectorDBInterface
from ..models.ScoredResult import ScoredResult
from ..embeddings.EmbeddingEngine import EmbeddingEngine from ..embeddings.EmbeddingEngine import EmbeddingEngine
from ..models.ScoredResult import ScoredResult
from ..vector_db_interface import VectorDBInterface
logger = logging.getLogger("WeaviateAdapter") logger = logging.getLogger("WeaviateAdapter")
@ -15,7 +16,8 @@ class IndexSchema(DataPoint):
text: str text: str
_metadata: dict = { _metadata: dict = {
"index_fields": ["text"] "index_fields": ["text"],
"type": "IndexSchema"
} }
class WeaviateAdapter(VectorDBInterface): class WeaviateAdapter(VectorDBInterface):

View file

@ -1,8 +1,10 @@
from typing_extensions import TypedDict
from uuid import UUID, uuid4
from typing import Optional
from datetime import datetime, timezone from datetime import datetime, timezone
from typing import Optional
from uuid import UUID, uuid4
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from typing_extensions import TypedDict
class MetaData(TypedDict): class MetaData(TypedDict):
index_fields: list[str] index_fields: list[str]
@ -13,7 +15,8 @@ class DataPoint(BaseModel):
updated_at: Optional[datetime] = datetime.now(timezone.utc) updated_at: Optional[datetime] = datetime.now(timezone.utc)
topological_rank: Optional[int] = 0 topological_rank: Optional[int] = 0
_metadata: Optional[MetaData] = { _metadata: Optional[MetaData] = {
"index_fields": [] "index_fields": [],
"type": "DataPoint"
} }
# class Config: # class Config:

View file

@ -0,0 +1,10 @@
You are an expert Python programmer and technical writer. Your task is to summarize the given Python code snippet or file.
The code may contain multiple imports, classes, functions, constants and logic. Provide a clear, structured explanation of its components
and their relationships.
Instructions:
Provide an overview: Start with a high-level summary of what the code does as a whole.
Break it down: Summarize each class and function individually, explaining their purpose and how they interact.
Describe the workflow: Outline how the classes and functions work together. Mention any control flow (e.g., main functions, entry points, loops).
Key features: Highlight important elements like arguments, return values, or unique logic.
Maintain clarity: Write in plain English for someone familiar with Python but unfamiliar with this code.

View file

@ -1,8 +1,10 @@
from typing import List, Optional from typing import List, Optional
from cognee.infrastructure.engine import DataPoint from cognee.infrastructure.engine import DataPoint
from cognee.modules.data.processing.document_types import Document from cognee.modules.data.processing.document_types import Document
from cognee.modules.engine.models import Entity from cognee.modules.engine.models import Entity
class DocumentChunk(DataPoint): class DocumentChunk(DataPoint):
__tablename__ = "document_chunk" __tablename__ = "document_chunk"
text: str text: str
@ -12,6 +14,7 @@ class DocumentChunk(DataPoint):
is_part_of: Document is_part_of: Document
contains: List[Entity] = None contains: List[Entity] = None
_metadata: Optional[dict] = { _metadata: dict = {
"index_fields": ["text"], "index_fields": ["text"],
"type": "DocumentChunk"
} }

View file

@ -1,7 +1,11 @@
from typing import Type from typing import Type
from pydantic import BaseModel from pydantic import BaseModel
from cognee.infrastructure.llm.prompts import read_query_prompt
from cognee.infrastructure.llm.get_llm_client import get_llm_client from cognee.infrastructure.llm.get_llm_client import get_llm_client
from cognee.infrastructure.llm.prompts import read_query_prompt
from cognee.shared.data_models import SummarizedCode
async def extract_summary(content: str, response_model: Type[BaseModel]): async def extract_summary(content: str, response_model: Type[BaseModel]):
llm_client = get_llm_client() llm_client = get_llm_client()
@ -11,3 +15,7 @@ async def extract_summary(content: str, response_model: Type[BaseModel]):
llm_output = await llm_client.acreate_structured_output(content, system_prompt, response_model) llm_output = await llm_client.acreate_structured_output(content, system_prompt, response_model)
return llm_output return llm_output
async def extract_code_summary(content: str):
return await extract_summary(content, response_model=SummarizedCode)

View file

@ -1,2 +1,3 @@
from .Data import Data from .Data import Data
from .Dataset import Dataset from .Dataset import Dataset
from .DatasetData import DatasetData

View file

@ -1,6 +1,6 @@
from cognee.infrastructure.llm.get_llm_client import get_llm_client from cognee.infrastructure.llm.get_llm_client import get_llm_client
from cognee.modules.chunking.TextChunker import TextChunker
from .Document import Document from .Document import Document
from .ChunkerMapping import ChunkerConfig
class AudioDocument(Document): class AudioDocument(Document):
type: str = "audio" type: str = "audio"
@ -9,11 +9,12 @@ class AudioDocument(Document):
result = get_llm_client().create_transcript(self.raw_data_location) result = get_llm_client().create_transcript(self.raw_data_location)
return(result.text) return(result.text)
def read(self, chunk_size: int): def read(self, chunk_size: int, chunker: str):
# Transcribe the audio file # Transcribe the audio file
text = self.create_transcript() text = self.create_transcript()
chunker = TextChunker(self, chunk_size = chunk_size, get_text = lambda: [text]) chunker_func = ChunkerConfig.get_chunker(chunker)
chunker = chunker_func(self, chunk_size = chunk_size, get_text = lambda: [text])
yield from chunker.read() yield from chunker.read()

View file

@ -0,0 +1,15 @@
from cognee.modules.chunking.TextChunker import TextChunker
class ChunkerConfig:
chunker_mapping = {
"text_chunker": TextChunker
}
@classmethod
def get_chunker(cls, chunker_name: str):
chunker_class = cls.chunker_mapping.get(chunker_name)
if chunker_class is None:
raise NotImplementedError(
f"Chunker '{chunker_name}' is not implemented. Available options: {list(cls.chunker_mapping.keys())}"
)
return chunker_class

View file

@ -1,12 +1,17 @@
from cognee.infrastructure.engine import DataPoint
from uuid import UUID from uuid import UUID
from cognee.infrastructure.engine import DataPoint
class Document(DataPoint): class Document(DataPoint):
type: str
name: str name: str
raw_data_location: str raw_data_location: str
metadata_id: UUID metadata_id: UUID
mime_type: str mime_type: str
_metadata: dict = {
"index_fields": ["name"],
"type": "Document"
}
def read(self, chunk_size: int) -> str: def read(self, chunk_size: int, chunker = str) -> str:
pass pass

View file

@ -1,6 +1,6 @@
from cognee.infrastructure.llm.get_llm_client import get_llm_client from cognee.infrastructure.llm.get_llm_client import get_llm_client
from cognee.modules.chunking.TextChunker import TextChunker
from .Document import Document from .Document import Document
from .ChunkerMapping import ChunkerConfig
class ImageDocument(Document): class ImageDocument(Document):
type: str = "image" type: str = "image"
@ -10,10 +10,11 @@ class ImageDocument(Document):
result = get_llm_client().transcribe_image(self.raw_data_location) result = get_llm_client().transcribe_image(self.raw_data_location)
return(result.choices[0].message.content) return(result.choices[0].message.content)
def read(self, chunk_size: int): def read(self, chunk_size: int, chunker: str):
# Transcribe the image file # Transcribe the image file
text = self.transcribe_image() text = self.transcribe_image()
chunker = TextChunker(self, chunk_size = chunk_size, get_text = lambda: [text]) chunker_func = ChunkerConfig.get_chunker(chunker)
chunker = chunker_func(self, chunk_size = chunk_size, get_text = lambda: [text])
yield from chunker.read() yield from chunker.read()

View file

@ -1,11 +1,11 @@
from pypdf import PdfReader from pypdf import PdfReader
from cognee.modules.chunking.TextChunker import TextChunker
from .Document import Document from .Document import Document
from .ChunkerMapping import ChunkerConfig
class PdfDocument(Document): class PdfDocument(Document):
type: str = "pdf" type: str = "pdf"
def read(self, chunk_size: int): def read(self, chunk_size: int, chunker: str):
file = PdfReader(self.raw_data_location) file = PdfReader(self.raw_data_location)
def get_text(): def get_text():
@ -13,7 +13,8 @@ class PdfDocument(Document):
page_text = page.extract_text() page_text = page.extract_text()
yield page_text yield page_text
chunker = TextChunker(self, chunk_size = chunk_size, get_text = get_text) chunker_func = ChunkerConfig.get_chunker(chunker)
chunker = chunker_func(self, chunk_size = chunk_size, get_text = get_text)
yield from chunker.read() yield from chunker.read()

View file

@ -1,10 +1,10 @@
from cognee.modules.chunking.TextChunker import TextChunker
from .Document import Document from .Document import Document
from .ChunkerMapping import ChunkerConfig
class TextDocument(Document): class TextDocument(Document):
type: str = "text" type: str = "text"
def read(self, chunk_size: int): def read(self, chunk_size: int, chunker: str):
def get_text(): def get_text():
with open(self.raw_data_location, mode = "r", encoding = "utf-8") as file: with open(self.raw_data_location, mode = "r", encoding = "utf-8") as file:
while True: while True:
@ -15,6 +15,8 @@ class TextDocument(Document):
yield text yield text
chunker = TextChunker(self, chunk_size = chunk_size, get_text = get_text) chunker_func = ChunkerConfig.get_chunker(chunker)
chunker = chunker_func(self, chunk_size = chunk_size, get_text = get_text)
yield from chunker.read() yield from chunker.read()

View file

@ -10,4 +10,5 @@ class Entity(DataPoint):
_metadata: dict = { _metadata: dict = {
"index_fields": ["name"], "index_fields": ["name"],
"type": "Entity"
} }

View file

@ -1,11 +1,12 @@
from cognee.infrastructure.engine import DataPoint from cognee.infrastructure.engine import DataPoint
class EntityType(DataPoint): class EntityType(DataPoint):
__tablename__ = "entity_type" __tablename__ = "entity_type"
name: str name: str
type: str
description: str description: str
_metadata: dict = { _metadata: dict = {
"index_fields": ["name"], "index_fields": ["name"],
"type": "EntityType"
} }

View file

@ -1,11 +1,14 @@
from typing import Optional from typing import Optional
from cognee.infrastructure.engine import DataPoint from cognee.infrastructure.engine import DataPoint
class EdgeType(DataPoint): class EdgeType(DataPoint):
__tablename__ = "edge_type" __tablename__ = "edge_type"
relationship_name: str relationship_name: str
number_of_edges: int number_of_edges: int
_metadata: Optional[dict] = { _metadata: dict = {
"index_fields": ["relationship_name"], "index_fields": ["relationship_name"],
"type": "EdgeType"
} }

View file

@ -2,7 +2,7 @@ from cognee.infrastructure.engine import DataPoint
def convert_node_to_data_point(node_data: dict) -> DataPoint: def convert_node_to_data_point(node_data: dict) -> DataPoint:
subclass = find_subclass_by_name(DataPoint, node_data["type"]) subclass = find_subclass_by_name(DataPoint, node_data._metadata["type"])
return subclass(**node_data) return subclass(**node_data)

View file

@ -1,9 +1,11 @@
from uuid import UUID from uuid import UUID
from sqlalchemy import select from sqlalchemy import select
from cognee.infrastructure.databases.relational import get_relational_engine from cognee.infrastructure.databases.relational import get_relational_engine
from cognee.modules.data.models import Dataset, DatasetData
from ...models import ACL, Resource, Permission from ...models import ACL, Resource, Permission
async def get_document_ids_for_user(user_id: UUID) -> list[str]:
async def get_document_ids_for_user(user_id: UUID, datasets: list[str] = None) -> list[str]:
db_engine = get_relational_engine() db_engine = get_relational_engine()
async with db_engine.get_async_session() as session: async with db_engine.get_async_session() as session:
@ -18,4 +20,31 @@ async def get_document_ids_for_user(user_id: UUID) -> list[str]:
) )
)).all() )).all()
if datasets:
documents_ids_in_dataset = set()
# If datasets are specified filter out documents that aren't part of the specified datasets
for dataset in datasets:
# Find dataset id for dataset element
dataset_id = (await session.scalars(
select(Dataset.id)
.where(
Dataset.name == dataset,
Dataset.owner_id == user_id,
)
)).one_or_none()
# Check which documents are connected to this dataset
for document_id in document_ids:
data_id = (await session.scalars(
select(DatasetData.data_id)
.where(
DatasetData.dataset_id == dataset_id,
DatasetData.data_id == document_id,
)
)).one_or_none()
# If document is related to dataset added it to return value
if data_id:
documents_ids_in_dataset.add(document_id)
return list(documents_ids_in_dataset)
return document_ids return document_ids

View file

@ -1,15 +1,19 @@
from typing import List, Optional from typing import List, Optional
from cognee.infrastructure.engine import DataPoint from cognee.infrastructure.engine import DataPoint
class Repository(DataPoint): class Repository(DataPoint):
__tablename__ = "Repository" __tablename__ = "Repository"
path: str path: str
type: Optional[str] = "Repository" _metadata: dict = {
"index_fields": ["source_code"],
"type": "Repository"
}
class CodeFile(DataPoint): class CodeFile(DataPoint):
__tablename__ = "codefile" __tablename__ = "codefile"
extracted_id: str # actually file path extracted_id: str # actually file path
type: Optional[str] = "CodeFile"
source_code: Optional[str] = None source_code: Optional[str] = None
part_of: Optional[Repository] = None part_of: Optional[Repository] = None
depends_on: Optional[List["CodeFile"]] = None depends_on: Optional[List["CodeFile"]] = None
@ -17,24 +21,27 @@ class CodeFile(DataPoint):
contains: Optional[List["CodePart"]] = None contains: Optional[List["CodePart"]] = None
_metadata: dict = { _metadata: dict = {
"index_fields": ["source_code"] "index_fields": ["source_code"],
"type": "CodeFile"
} }
class CodePart(DataPoint): class CodePart(DataPoint):
__tablename__ = "codepart" __tablename__ = "codepart"
# part_of: Optional[CodeFile] # part_of: Optional[CodeFile]
source_code: str source_code: str
type: Optional[str] = "CodePart"
_metadata: dict = { _metadata: dict = {
"index_fields": ["source_code"] "index_fields": ["source_code"],
"type": "CodePart"
} }
class CodeRelationship(DataPoint): class CodeRelationship(DataPoint):
source_id: str source_id: str
target_id: str target_id: str
type: str # between files
relation: str # depends on or depends directly relation: str # depends on or depends directly
_metadata: dict = {
"type": "CodeRelationship"
}
CodeFile.model_rebuild() CodeFile.model_rebuild()
CodePart.model_rebuild() CodePart.model_rebuild()

View file

@ -1,79 +1,90 @@
from typing import Any, List, Union, Literal, Optional from typing import Any, List, Literal, Optional, Union
from cognee.infrastructure.engine import DataPoint from cognee.infrastructure.engine import DataPoint
class Variable(DataPoint): class Variable(DataPoint):
id: str id: str
name: str name: str
type: Literal["Variable"] = "Variable"
description: str description: str
is_static: Optional[bool] = False is_static: Optional[bool] = False
default_value: Optional[str] = None default_value: Optional[str] = None
data_type: str data_type: str
_metadata = { _metadata = {
"index_fields": ["name"] "index_fields": ["name"],
"type": "Variable"
} }
class Operator(DataPoint): class Operator(DataPoint):
id: str id: str
name: str name: str
type: Literal["Operator"] = "Operator"
description: str description: str
return_type: str return_type: str
_metadata = {
"index_fields": ["name"],
"type": "Operator"
}
class Class(DataPoint): class Class(DataPoint):
id: str id: str
name: str name: str
type: Literal["Class"] = "Class"
description: str description: str
constructor_parameters: List[Variable] constructor_parameters: List[Variable]
extended_from_class: Optional["Class"] = None extended_from_class: Optional["Class"] = None
has_methods: List["Function"] has_methods: List["Function"]
_metadata = { _metadata = {
"index_fields": ["name"] "index_fields": ["name"],
"type": "Class"
} }
class ClassInstance(DataPoint): class ClassInstance(DataPoint):
id: str id: str
name: str name: str
type: Literal["ClassInstance"] = "ClassInstance"
description: str description: str
from_class: Class from_class: Class
instantiated_by: Union["Function"] instantiated_by: Union["Function"]
instantiation_arguments: List[Variable] instantiation_arguments: List[Variable]
_metadata = { _metadata = {
"index_fields": ["name"] "index_fields": ["name"],
"type": "ClassInstance"
} }
class Function(DataPoint): class Function(DataPoint):
id: str id: str
name: str name: str
type: Literal["Function"] = "Function"
description: str description: str
parameters: List[Variable] parameters: List[Variable]
return_type: str return_type: str
is_static: Optional[bool] = False is_static: Optional[bool] = False
_metadata = { _metadata = {
"index_fields": ["name"] "index_fields": ["name"],
"type": "Function"
} }
class FunctionCall(DataPoint): class FunctionCall(DataPoint):
id: str id: str
type: Literal["FunctionCall"] = "FunctionCall"
called_by: Union[Function, Literal["main"]] called_by: Union[Function, Literal["main"]]
function_called: Function function_called: Function
function_arguments: List[Any] function_arguments: List[Any]
_metadata = {
"index_fields": [],
"type": "FunctionCall"
}
class Expression(DataPoint): class Expression(DataPoint):
id: str id: str
name: str name: str
type: Literal["Expression"] = "Expression"
description: str description: str
expression: str expression: str
members: List[Union[Variable, Function, Operator, "Expression"]] members: List[Union[Variable, Function, Operator, "Expression"]]
_metadata = {
"index_fields": ["name"],
"type": "Expression"
}
class SourceCodeGraph(DataPoint): class SourceCodeGraph(DataPoint):
id: str id: str
@ -89,6 +100,11 @@ class SourceCodeGraph(DataPoint):
Operator, Operator,
Expression, Expression,
]] ]]
_metadata = {
"index_fields": ["name"],
"type": "SourceCodeGraph"
}
Class.model_rebuild() Class.model_rebuild()
ClassInstance.model_rebuild() ClassInstance.model_rebuild()
Expression.model_rebuild() Expression.model_rebuild()

View file

@ -1,9 +1,11 @@
"""Data models for the cognitive architecture.""" """Data models for the cognitive architecture."""
from enum import Enum, auto from enum import Enum, auto
from typing import Optional, List, Union, Dict, Any from typing import Any, Dict, List, Optional, Union
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
class Node(BaseModel): class Node(BaseModel):
"""Node in a knowledge graph.""" """Node in a knowledge graph."""
id: str id: str
@ -194,6 +196,29 @@ class SummarizedContent(BaseModel):
summary: str summary: str
description: str description: str
class SummarizedFunction(BaseModel):
name: str
description: str
inputs: Optional[List[str]] = None
outputs: Optional[List[str]] = None
decorators: Optional[List[str]] = None
class SummarizedClass(BaseModel):
name: str
description: str
methods: Optional[List[SummarizedFunction]] = None
decorators: Optional[List[str]] = None
class SummarizedCode(BaseModel):
file_name: str
high_level_summary: str
key_features: List[str]
imports: List[str] = []
constants: List[str] = []
classes: List[SummarizedClass] = []
functions: List[SummarizedFunction] = []
workflow_description: Optional[str] = None
class GraphDBType(Enum): class GraphDBType(Enum):
NETWORKX = auto() NETWORKX = auto()

View file

@ -1,7 +1,7 @@
from cognee.modules.data.processing.document_types.Document import Document from cognee.modules.data.processing.document_types.Document import Document
async def extract_chunks_from_documents(documents: list[Document], chunk_size: int = 1024): async def extract_chunks_from_documents(documents: list[Document], chunk_size: int = 1024, chunker = 'text_chunker'):
for document in documents: for document in documents:
for document_chunk in document.read(chunk_size = chunk_size): for document_chunk in document.read(chunk_size = chunk_size, chunker = chunker):
yield document_chunk yield document_chunk

View file

@ -0,0 +1,171 @@
import os
import jedi
import parso
from tqdm import tqdm
from . import logger
# Maps parso node types to the labels used in the extracted-object dicts.
# Both "async_funcdef" and "async_stmt" collapse to "async_func_def";
# "simple_stmt" covers top-level variable assignments.
_NODE_TYPE_MAP = {
    "funcdef": "func_def",
    "classdef": "class_def",
    "async_funcdef": "async_func_def",
    "async_stmt": "async_func_def",
    "simple_stmt": "var_def",
}
def _create_object_dict(name_node, type_name=None):
return {
"name": name_node.value,
"line": name_node.start_pos[0],
"column": name_node.start_pos[1],
"type": type_name,
}
def _parse_node(node):
    """Parse a node to extract importable object details, including async functions and classes."""
    mapped_type = _NODE_TYPE_MAP.get(node.type)

    # Plain (or explicitly async) function/class definitions carry a name node directly.
    if node.type in ("funcdef", "classdef", "async_funcdef"):
        return [_create_object_dict(node.name, type_name=mapped_type)]

    # `async def` usually parses as async_stmt wrapping a funcdef child.
    if node.type == "async_stmt" and len(node.children) > 1:
        inner = node.children[1]
        if inner.type == "funcdef":
            return [_create_object_dict(inner.name, type_name=_NODE_TYPE_MAP.get(inner.type))]
        return []

    if node.type == "simple_stmt":
        # TODO: Handle multi-level/nested unpacking variable definitions in the future
        statement = node.children[0]
        if statement.type != "expr_stmt":
            return []
        first_child = statement.children[0]
        if first_child.type == "testlist_star_expr":
            # Tuple assignment: `a, b = ...` — targets live one level deeper.
            targets = first_child.children
        else:
            targets = statement.children
        return [
            _create_object_dict(target, type_name=_NODE_TYPE_MAP.get(target.type))
            for target in targets
            if target.type == "name"
        ]

    return []
def extract_importable_objects_with_positions_from_source_code(source_code):
    """Extract top-level objects in a Python source code string with their positions (line/column)."""
    try:
        module = parso.parse(source_code)
    except Exception as e:
        logger.error(f"Error parsing source code: {e}")
        return []

    collected = []
    try:
        for child in module.children:
            collected.extend(_parse_node(child))
    except Exception as e:
        logger.error(f"Error extracting nodes from parsed tree: {e}")
        return []
    return collected
def extract_importable_objects_with_positions(file_path):
    """Extract top-level objects in a Python file with their positions (line/column).

    Reads the file as UTF-8 (the Python-source default encoding) and delegates
    to the source-code variant. Returns [] if the file cannot be read.
    """
    try:
        # Python source defaults to UTF-8 (PEP 3120); relying on the platform
        # default encoding mis-decodes non-ASCII sources on some systems.
        with open(file_path, "r", encoding="utf-8") as file:
            source_code = file.read()
    except Exception as e:
        logger.error(f"Error reading file {file_path}: {e}")
        return []
    return extract_importable_objects_with_positions_from_source_code(source_code)
def find_entity_usages(script, line, column):
    """
    Return a list of files in the repo where the entity at module_path:line,column is used.
    """
    # Infer what entity sits at the requested position; bail out quietly on failure.
    try:
        inferred = script.infer(line, column)
    except Exception as e:
        logger.error(f"Error inferring entity at {script.path}:{line},{column}: {e}")
        return []
    if not inferred or not inferred[0]:
        logger.info(f"No entity inferred at {script.path}:{line},{column}")
        return []
    logger.debug(f"Inferred entity: {inferred[0].name}, type: {inferred[0].type}")

    # Collect project-wide references; an error here degrades to "no usages".
    try:
        references = script.get_references(line=line, column=column, scope="project", include_builtins=False)
    except Exception as e:
        logger.error(f"Error retrieving references for entity at {script.path}:{line},{column}: {e}")
        references = []

    usages = set()
    for ref in references:
        if not ref.module_path:
            continue
        usages.add(ref.module_path)  # Collect unique module paths
        logger.info(f"Entity used in: {ref.module_path}")
    return list(usages)
def parse_file_with_references(project, file_path):
    """Parse a file to extract object names and their references within a project."""
    try:
        importable_objects = extract_importable_objects_with_positions(file_path)
    except Exception as e:
        logger.error(f"Error extracting objects from {file_path}: {e}")
        return []

    if not os.path.isfile(file_path):
        logger.warning(f"Module file does not exist: {file_path}")
        return []

    # A Jedi Script is needed to resolve references for each extracted object.
    try:
        script = jedi.Script(path=file_path, project=project)
    except Exception as e:
        logger.error(f"Error initializing Jedi Script: {e}")
        return []

    parsed_results = []
    for obj in importable_objects:
        parsed_results.append({
            "name": obj["name"],
            "type": obj["type"],
            "references": find_entity_usages(script, obj["line"], obj["column"]),
        })
    return parsed_results
def parse_repo(repo_path):
    """Parse a repository to extract object names, types, and references for all Python files.

    Walks the repository, skipping virtualenv/VCS/cache/build directories, and
    returns a mapping of file path -> parse_file_with_references results.
    Returns {} if the Jedi project cannot be created.
    """
    try:
        project = jedi.Project(path=repo_path)
    except Exception as e:
        logger.error(f"Error creating Jedi project for repository at {repo_path}: {e}")
        return {}

    EXCLUDE_DIRS = {'venv', '.git', '__pycache__', 'build'}

    def _is_excluded(directory):
        # Match whole path components, not substrings: a plain `excluded in
        # directory` check would wrongly skip e.g. "rebuild/" or "myvenv/".
        return bool(EXCLUDE_DIRS & set(os.path.normpath(directory).split(os.sep)))

    python_files = [
        os.path.join(directory, file)
        for directory, _, filenames in os.walk(repo_path)
        if not _is_excluded(directory)
        for file in filenames
        # Skip empty files: there is nothing to extract and Jedi work is wasted.
        if file.endswith(".py") and os.path.getsize(os.path.join(directory, file)) > 0
    ]

    results = {
        file_path: parse_file_with_references(project, file_path)
        for file_path in tqdm(python_files)
    }
    return results

View file

@ -1,6 +1,7 @@
from cognee.infrastructure.databases.vector import get_vector_engine from cognee.infrastructure.databases.vector import get_vector_engine
from cognee.infrastructure.engine import DataPoint from cognee.infrastructure.engine import DataPoint
async def index_data_points(data_points: list[DataPoint]): async def index_data_points(data_points: list[DataPoint]):
created_indexes = {} created_indexes = {}
index_points = {} index_points = {}
@ -80,11 +81,20 @@ if __name__ == "__main__":
class Car(DataPoint): class Car(DataPoint):
model: str model: str
color: str color: str
_metadata = {
"index_fields": ["name"],
"type": "Car"
}
class Person(DataPoint): class Person(DataPoint):
name: str name: str
age: int age: int
owns_car: list[Car] owns_car: list[Car]
_metadata = {
"index_fields": ["name"],
"type": "Person"
}
car1 = Car(model = "Tesla Model S", color = "Blue") car1 = Car(model = "Tesla Model S", color = "Blue")
car2 = Car(model = "Toyota Camry", color = "Red") car2 = Car(model = "Toyota Camry", color = "Red")

View file

@ -10,6 +10,7 @@ class TextSummary(DataPoint):
_metadata: dict = { _metadata: dict = {
"index_fields": ["text"], "index_fields": ["text"],
"type": "TextSummary"
} }
@ -20,4 +21,5 @@ class CodeSummary(DataPoint):
_metadata: dict = { _metadata: dict = {
"index_fields": ["text"], "index_fields": ["text"],
"type": "CodeSummary"
} }

View file

@ -1,39 +1,40 @@
import asyncio import asyncio
from typing import Type from typing import AsyncGenerator, Union
from uuid import uuid5 from uuid import uuid5
from typing import Type
from pydantic import BaseModel
from cognee.infrastructure.engine import DataPoint from cognee.infrastructure.engine import DataPoint
from cognee.modules.data.extraction.extract_summary import extract_summary from cognee.modules.data.extraction.extract_summary import extract_code_summary
from cognee.shared.CodeGraphEntities import CodeFile
from cognee.tasks.storage import add_data_points
from .models import CodeSummary from .models import CodeSummary
async def summarize_code( async def summarize_code(
code_files: list[DataPoint], code_graph_nodes: list[DataPoint],
summarization_model: Type[BaseModel], ) -> AsyncGenerator[Union[DataPoint, CodeSummary], None]:
) -> list[DataPoint]: if len(code_graph_nodes) == 0:
if len(code_files) == 0: return
return code_files
code_files_data_points = [file for file in code_files if isinstance(file, CodeFile)] code_data_points = [file for file in code_graph_nodes if hasattr(file, "source_code")]
file_summaries = await asyncio.gather( file_summaries = await asyncio.gather(
*[extract_summary(file.source_code, summarization_model) for file in code_files_data_points] *[extract_code_summary(file.source_code) for file in code_data_points]
) )
summaries = [ file_summaries_map = {
CodeSummary( code_data_point.extracted_id: str(file_summary)
id = uuid5(file.id, "CodeSummary"), for code_data_point, file_summary in zip(code_data_points, file_summaries)
made_from = file, }
text = file_summaries[file_index].summary,
for node in code_graph_nodes:
if not isinstance(node, DataPoint):
continue
yield node
if not hasattr(node, "source_code"):
continue
yield CodeSummary(
id=uuid5(node.id, "CodeSummary"),
made_from=node,
text=file_summaries_map[node.extracted_id],
) )
for (file_index, file) in enumerate(code_files_data_points)
]
await add_data_points(summaries)
return code_files

View file

@ -31,7 +31,7 @@ def test_AudioDocument():
) )
with patch.object(AudioDocument, "create_transcript", return_value=TEST_TEXT): with patch.object(AudioDocument, "create_transcript", return_value=TEST_TEXT):
for ground_truth, paragraph_data in zip( for ground_truth, paragraph_data in zip(
GROUND_TRUTH, document.read(chunk_size=64) GROUND_TRUTH, document.read(chunk_size=64, chunker='text_chunker')
): ):
assert ( assert (
ground_truth["word_count"] == paragraph_data.word_count ground_truth["word_count"] == paragraph_data.word_count

View file

@ -21,7 +21,7 @@ def test_ImageDocument():
with patch.object(ImageDocument, "transcribe_image", return_value=TEST_TEXT): with patch.object(ImageDocument, "transcribe_image", return_value=TEST_TEXT):
for ground_truth, paragraph_data in zip( for ground_truth, paragraph_data in zip(
GROUND_TRUTH, document.read(chunk_size=64) GROUND_TRUTH, document.read(chunk_size=64, chunker='text_chunker')
): ):
assert ( assert (
ground_truth["word_count"] == paragraph_data.word_count ground_truth["word_count"] == paragraph_data.word_count

View file

@ -22,7 +22,7 @@ def test_PdfDocument():
) )
for ground_truth, paragraph_data in zip( for ground_truth, paragraph_data in zip(
GROUND_TRUTH, document.read(chunk_size=1024) GROUND_TRUTH, document.read(chunk_size=1024, chunker='text_chunker')
): ):
assert ( assert (
ground_truth["word_count"] == paragraph_data.word_count ground_truth["word_count"] == paragraph_data.word_count

View file

@ -33,7 +33,7 @@ def test_TextDocument(input_file, chunk_size):
) )
for ground_truth, paragraph_data in zip( for ground_truth, paragraph_data in zip(
GROUND_TRUTH[input_file], document.read(chunk_size=chunk_size) GROUND_TRUTH[input_file], document.read(chunk_size=chunk_size, chunker='text_chunker')
): ):
assert ( assert (
ground_truth["word_count"] == paragraph_data.word_count ground_truth["word_count"] == paragraph_data.word_count

View file

@ -4,6 +4,7 @@ import pathlib
import cognee import cognee
from cognee.api.v1.search import SearchType from cognee.api.v1.search import SearchType
from cognee.modules.retrieval.brute_force_triplet_search import brute_force_triplet_search from cognee.modules.retrieval.brute_force_triplet_search import brute_force_triplet_search
from cognee.modules.users.methods import get_default_user
logging.basicConfig(level=logging.DEBUG) logging.basicConfig(level=logging.DEBUG)
@ -44,12 +45,13 @@ async def main():
await cognee.prune.prune_data() await cognee.prune.prune_data()
await cognee.prune.prune_system(metadata = True) await cognee.prune.prune_system(metadata = True)
dataset_name = "cs_explanations" dataset_name_1 = "natural_language"
dataset_name_2 = "quantum"
explanation_file_path = os.path.join( explanation_file_path = os.path.join(
pathlib.Path(__file__).parent, "test_data/Natural_language_processing.txt" pathlib.Path(__file__).parent, "test_data/Natural_language_processing.txt"
) )
await cognee.add([explanation_file_path], dataset_name) await cognee.add([explanation_file_path], dataset_name_1)
text = """A quantum computer is a computer that takes advantage of quantum mechanical phenomena. text = """A quantum computer is a computer that takes advantage of quantum mechanical phenomena.
At small scales, physical matter exhibits properties of both particles and waves, and quantum computing leverages this behavior, specifically quantum superposition and entanglement, using specialized hardware that supports the preparation and manipulation of quantum states. At small scales, physical matter exhibits properties of both particles and waves, and quantum computing leverages this behavior, specifically quantum superposition and entanglement, using specialized hardware that supports the preparation and manipulation of quantum states.
@ -59,12 +61,23 @@ async def main():
In principle, a non-quantum (classical) computer can solve the same computational problems as a quantum computer, given enough time. Quantum advantage comes in the form of time complexity rather than computability, and quantum complexity theory shows that some quantum algorithms for carefully selected tasks require exponentially fewer computational steps than the best known non-quantum algorithms. Such tasks can in theory be solved on a large-scale quantum computer whereas classical computers would not finish computations in any reasonable amount of time. However, quantum speedup is not universal or even typical across computational tasks, since basic tasks such as sorting are proven to not allow any asymptotic quantum speedup. Claims of quantum supremacy have drawn significant attention to the discipline, but are demonstrated on contrived tasks, while near-term practical use cases remain limited. In principle, a non-quantum (classical) computer can solve the same computational problems as a quantum computer, given enough time. Quantum advantage comes in the form of time complexity rather than computability, and quantum complexity theory shows that some quantum algorithms for carefully selected tasks require exponentially fewer computational steps than the best known non-quantum algorithms. Such tasks can in theory be solved on a large-scale quantum computer whereas classical computers would not finish computations in any reasonable amount of time. However, quantum speedup is not universal or even typical across computational tasks, since basic tasks such as sorting are proven to not allow any asymptotic quantum speedup. Claims of quantum supremacy have drawn significant attention to the discipline, but are demonstrated on contrived tasks, while near-term practical use cases remain limited.
""" """
await cognee.add([text], dataset_name) await cognee.add([text], dataset_name_2)
await cognee.cognify([dataset_name]) await cognee.cognify([dataset_name_2, dataset_name_1])
from cognee.infrastructure.databases.vector import get_vector_engine from cognee.infrastructure.databases.vector import get_vector_engine
# Test getting of documents for search per dataset
from cognee.modules.users.permissions.methods import get_document_ids_for_user
user = await get_default_user()
document_ids = await get_document_ids_for_user(user.id, [dataset_name_1])
assert len(document_ids) == 1, f"Number of expected documents doesn't match {len(document_ids)} != 1"
# Test getting of documents for search when no dataset is provided
user = await get_default_user()
document_ids = await get_document_ids_for_user(user.id)
assert len(document_ids) == 2, f"Number of expected documents doesn't match {len(document_ids)} != 2"
vector_engine = get_vector_engine() vector_engine = get_vector_engine()
random_node = (await vector_engine.search("entity_name", "Quantum computer"))[0] random_node = (await vector_engine.search("entity_name", "Quantum computer"))[0]
random_node_name = random_node.payload["text"] random_node_name = random_node.payload["text"]
@ -75,7 +88,7 @@ async def main():
for result in search_results: for result in search_results:
print(f"{result}\n") print(f"{result}\n")
search_results = await cognee.search(SearchType.CHUNKS, query_text = random_node_name) search_results = await cognee.search(SearchType.CHUNKS, query_text = random_node_name, datasets=[dataset_name_2])
assert len(search_results) != 0, "The search results list is empty." assert len(search_results) != 0, "The search results list is empty."
print("\n\nExtracted chunks are:\n") print("\n\nExtracted chunks are:\n")
for result in search_results: for result in search_results:

View file

@ -2,7 +2,7 @@ import asyncio
import random import random
import time import time
from typing import List from typing import List
from uuid import uuid5, NAMESPACE_OID from uuid import NAMESPACE_OID, uuid5
from cognee.infrastructure.engine import DataPoint from cognee.infrastructure.engine import DataPoint
from cognee.modules.graph.utils import get_graph_from_model from cognee.modules.graph.utils import get_graph_from_model
@ -11,16 +11,28 @@ random.seed(1500)
class Repository(DataPoint): class Repository(DataPoint):
path: str path: str
_metadata = {
"index_fields": [],
"type": "Repository"
}
class CodeFile(DataPoint): class CodeFile(DataPoint):
part_of: Repository part_of: Repository
contains: List["CodePart"] = [] contains: List["CodePart"] = []
depends_on: List["CodeFile"] = [] depends_on: List["CodeFile"] = []
source_code: str source_code: str
_metadata = {
"index_fields": [],
"type": "CodeFile"
}
class CodePart(DataPoint): class CodePart(DataPoint):
part_of: CodeFile part_of: CodeFile
source_code: str source_code: str
_metadata = {
"index_fields": [],
"type": "CodePart"
}
CodeFile.model_rebuild() CodeFile.model_rebuild()
CodePart.model_rebuild() CodePart.model_rebuild()

View file

@ -1,25 +1,42 @@
import asyncio import asyncio
import random import random
from typing import List from typing import List
from uuid import uuid5, NAMESPACE_OID from uuid import NAMESPACE_OID, uuid5
from cognee.infrastructure.engine import DataPoint from cognee.infrastructure.engine import DataPoint
from cognee.modules.graph.utils import get_graph_from_model from cognee.modules.graph.utils import get_graph_from_model
class Document(DataPoint): class Document(DataPoint):
path: str path: str
_metadata = {
"index_fields": [],
"type": "Document"
}
class DocumentChunk(DataPoint): class DocumentChunk(DataPoint):
part_of: Document part_of: Document
text: str text: str
contains: List["Entity"] = None contains: List["Entity"] = None
_metadata = {
"index_fields": ["text"],
"type": "DocumentChunk"
}
class EntityType(DataPoint): class EntityType(DataPoint):
name: str name: str
_metadata = {
"index_fields": ["name"],
"type": "EntityType"
}
class Entity(DataPoint): class Entity(DataPoint):
name: str name: str
is_type: EntityType is_type: EntityType
_metadata = {
"index_fields": ["name"],
"type": "Entity"
}
DocumentChunk.model_rebuild() DocumentChunk.model_rebuild()

View file

@ -7,19 +7,13 @@ from pathlib import Path
from swebench.harness.utils import load_swebench_dataset from swebench.harness.utils import load_swebench_dataset
from swebench.inference.make_datasets.create_instance import PATCH_EXAMPLE from swebench.inference.make_datasets.create_instance import PATCH_EXAMPLE
from cognee.api.v1.cognify.code_graph_pipeline import run_code_graph_pipeline
from cognee.api.v1.search import SearchType from cognee.api.v1.search import SearchType
from cognee.infrastructure.llm.get_llm_client import get_llm_client from cognee.infrastructure.llm.get_llm_client import get_llm_client
from cognee.infrastructure.llm.prompts import read_query_prompt from cognee.infrastructure.llm.prompts import read_query_prompt
from cognee.modules.pipelines import Task, run_tasks
from cognee.modules.retrieval.brute_force_triplet_search import \ from cognee.modules.retrieval.brute_force_triplet_search import \
brute_force_triplet_search brute_force_triplet_search
# from cognee.shared.data_models import SummarizedContent
from cognee.shared.utils import render_graph from cognee.shared.utils import render_graph
from cognee.tasks.repo_processor import (enrich_dependency_graph,
expand_dependency_graph,
get_repo_file_dependencies)
from cognee.tasks.storage import add_data_points
# from cognee.tasks.summarization import summarize_code
from evals.eval_utils import download_github_repo, retrieved_edges_to_string from evals.eval_utils import download_github_repo, retrieved_edges_to_string
@ -42,47 +36,21 @@ def check_install_package(package_name):
async def generate_patch_with_cognee(instance, llm_client, search_type=SearchType.CHUNKS): async def generate_patch_with_cognee(instance, llm_client, search_type=SearchType.CHUNKS):
import os repo_path = download_github_repo(instance, '../RAW_GIT_REPOS')
import pathlib pipeline = await run_code_graph_pipeline(repo_path)
import cognee
from cognee.infrastructure.databases.relational import create_db_and_tables
file_path = Path(__file__).parent
data_directory_path = str(pathlib.Path(os.path.join(file_path, ".data_storage/code_graph")).resolve())
cognee.config.data_root_directory(data_directory_path)
cognee_directory_path = str(pathlib.Path(os.path.join(file_path, ".cognee_system/code_graph")).resolve())
cognee.config.system_root_directory(cognee_directory_path)
await cognee.prune.prune_data()
await cognee.prune.prune_system(metadata = True)
await create_db_and_tables()
# repo_path = download_github_repo(instance, '../RAW_GIT_REPOS')
repo_path = '/Users/borisarzentar/Projects/graphrag'
tasks = [
Task(get_repo_file_dependencies),
Task(enrich_dependency_graph, task_config = { "batch_size": 50 }),
Task(expand_dependency_graph, task_config = { "batch_size": 50 }),
Task(add_data_points, task_config = { "batch_size": 50 }),
# Task(summarize_code, summarization_model = SummarizedContent),
]
pipeline = run_tasks(tasks, repo_path, "cognify_code_pipeline")
async for result in pipeline: async for result in pipeline:
print(result) print(result)
print('Here we have the repo under the repo_path') print('Here we have the repo under the repo_path')
await render_graph(None, include_labels = True, include_nodes = True) await render_graph(None, include_labels=True, include_nodes=True)
problem_statement = instance['problem_statement'] problem_statement = instance['problem_statement']
instructions = read_query_prompt("patch_gen_kg_instructions.txt") instructions = read_query_prompt("patch_gen_kg_instructions.txt")
retrieved_edges = await brute_force_triplet_search(problem_statement, top_k = 3, collections = ["data_point_source_code", "data_point_text"]) retrieved_edges = await brute_force_triplet_search(problem_statement, top_k=3,
collections=["data_point_source_code", "data_point_text"])
retrieved_edges_str = retrieved_edges_to_string(retrieved_edges) retrieved_edges_str = retrieved_edges_to_string(retrieved_edges)
@ -171,7 +139,6 @@ async def main():
with open(predictions_path, "w") as file: with open(predictions_path, "w") as file:
json.dump(preds, file) json.dump(preds, file)
subprocess.run( subprocess.run(
[ [
"python", "python",

View file

@ -0,0 +1,15 @@
import argparse
import asyncio
from cognee.api.v1.cognify.code_graph_pipeline import run_code_graph_pipeline
async def main(repo_path):
    """Run the code graph pipeline over repo_path and print each yielded result."""
    pipeline = await run_code_graph_pipeline(repo_path)
    async for result in pipeline:
        print(result)
if __name__ == "__main__":
    # CLI entry point: require a --repo-path flag and drive the async
    # pipeline to completion with asyncio.run.
    parser = argparse.ArgumentParser()
    parser.add_argument("--repo-path", type=str, required=True, help="Path to the repository")
    args = parser.parse_args()
    asyncio.run(main(args.repo_path))