diff --git a/cognee/api/v1/cognify/code_graph_pipeline.py b/cognee/api/v1/cognify/code_graph_pipeline.py
new file mode 100644
index 000000000..2cbb606c1
--- /dev/null
+++ b/cognee/api/v1/cognify/code_graph_pipeline.py
@@ -0,0 +1,110 @@
+import asyncio
+import logging
+from typing import Union
+
+from cognee.shared.SourceCodeGraph import SourceCodeGraph
+from cognee.shared.utils import send_telemetry
+from cognee.modules.data.models import Dataset, Data
+from cognee.modules.data.methods.get_dataset_data import get_dataset_data
+from cognee.modules.data.methods import get_datasets, get_datasets_by_name
+from cognee.modules.pipelines.tasks.Task import Task
+from cognee.modules.pipelines import run_tasks
+from cognee.modules.users.models import User
+from cognee.modules.users.methods import get_default_user
+from cognee.modules.pipelines.models import PipelineRunStatus
+from cognee.modules.pipelines.operations.get_pipeline_status import get_pipeline_status
+from cognee.modules.pipelines.operations.log_pipeline_status import log_pipeline_status
+from cognee.tasks.documents import classify_documents, check_permissions_on_documents, extract_chunks_from_documents
+from cognee.tasks.graph import extract_graph_from_code
+from cognee.tasks.storage import add_data_points
+
+logger = logging.getLogger("code_graph_pipeline")
+
+update_status_lock = asyncio.Lock()
+
+class PermissionDeniedException(Exception):
+    def __init__(self, message: str):
+        self.message = message
+        super().__init__(self.message)
+
+async def code_graph_pipeline(datasets: Union[str, list[str]] = None, user: User = None):
+    if user is None:
+        user = await get_default_user()
+
+    existing_datasets = await get_datasets(user.id)
+
+    if datasets is None or len(datasets) == 0:
+        # If no datasets are provided, cognify all existing datasets.
+        datasets = existing_datasets
+
+    if isinstance(datasets[0], str):
+        datasets = await get_datasets_by_name(datasets, user.id)
+
+    existing_datasets_map = {
+        generate_dataset_name(dataset.name): True for dataset in existing_datasets
+    }
+
+    awaitables = []
+
+    for dataset in datasets:
+        dataset_name = generate_dataset_name(dataset.name)
+
+        if dataset_name in existing_datasets_map:
+            awaitables.append(run_pipeline(dataset, user))
+
+    return await asyncio.gather(*awaitables)
+
+
+async def run_pipeline(dataset: Dataset, user: User):
+    data_documents: list[Data] = await get_dataset_data(dataset_id = dataset.id)
+
+    document_ids_str = [str(document.id) for document in data_documents]
+
+    dataset_id = dataset.id
+    dataset_name = generate_dataset_name(dataset.name)
+
+    send_telemetry("code_graph_pipeline EXECUTION STARTED", user.id)
+
+    async with update_status_lock:
+        task_status = await get_pipeline_status([dataset_id])
+
+        if dataset_id in task_status and task_status[dataset_id] == PipelineRunStatus.DATASET_PROCESSING_STARTED:
+            logger.info("Dataset %s is already being processed.", dataset_name)
+            return
+
+        await log_pipeline_status(dataset_id, PipelineRunStatus.DATASET_PROCESSING_STARTED, {
+            "dataset_name": dataset_name,
+            "files": document_ids_str,
+        })
+    try:
+        tasks = [
+            Task(classify_documents),
+            Task(check_permissions_on_documents, user = user, permissions = ["write"]),
+            Task(extract_chunks_from_documents), # Extract text chunks based on the document type.
+            Task(add_data_points, task_config = { "batch_size": 10 }),
+            Task(extract_graph_from_code, graph_model = SourceCodeGraph, task_config = { "batch_size": 10 }), # Generate knowledge graphs from the document chunks.
+        ]
+
+        pipeline = run_tasks(tasks, data_documents, "code_graph_pipeline")
+
+        async for result in pipeline:
+            print(result)
+
+        send_telemetry("code_graph_pipeline EXECUTION COMPLETED", user.id)
+
+        await log_pipeline_status(dataset_id, PipelineRunStatus.DATASET_PROCESSING_COMPLETED, {
+            "dataset_name": dataset_name,
+            "files": document_ids_str,
+        })
+    except Exception as error:
+        send_telemetry("code_graph_pipeline EXECUTION ERRORED", user.id)
+
+        await log_pipeline_status(dataset_id, PipelineRunStatus.DATASET_PROCESSING_ERRORED, {
+            "dataset_name": dataset_name,
+            "files": document_ids_str,
+        })
+        raise error
+
+
+def generate_dataset_name(dataset_name: str) -> str:
+    return dataset_name.replace(".", "_").replace(" ", "_")
diff --git a/cognee/infrastructure/databases/graph/networkx/adapter.py b/cognee/infrastructure/databases/graph/networkx/adapter.py
index b106e9feb..6c7abd498 100644
--- a/cognee/infrastructure/databases/graph/networkx/adapter.py
+++ b/cognee/infrastructure/databases/graph/networkx/adapter.py
@@ -30,6 +30,10 @@ class NetworkXAdapter(GraphDBInterface):
     def __init__(self, filename = "cognee_graph.pkl"):
         self.filename = filename
 
+    async def get_graph_data(self):
+        await self.load_graph_from_file()
+        return (list(self.graph.nodes(data = True)), list(self.graph.edges(data = True, keys = True)))
+
     async def query(self, query: str, params: dict):
         pass
 
@@ -247,15 +251,27 @@ class NetworkXAdapter(GraphDBInterface):
             async with aiofiles.open(file_path, "r") as file:
                 graph_data = json.loads(await file.read())
                 for node in graph_data["nodes"]:
-                    node["id"] = UUID(node["id"])
-                    node["updated_at"] = datetime.strptime(node["updated_at"], "%Y-%m-%dT%H:%M:%S.%f%z")
+                    try:
+                        node["id"] = UUID(node["id"])
+                    except Exception:
+                        pass
+                    if "updated_at" in node:
+                        node["updated_at"] = datetime.strptime(node["updated_at"], "%Y-%m-%dT%H:%M:%S.%f%z")
 
                 for edge in graph_data["links"]:
-                    edge["source"] = UUID(edge["source"])
-                    edge["target"] = UUID(edge["target"])
-                    edge["source_node_id"] = UUID(edge["source_node_id"])
-                    edge["target_node_id"] = UUID(edge["target_node_id"])
-                    edge["updated_at"] = datetime.strptime(edge["updated_at"], "%Y-%m-%dT%H:%M:%S.%f%z")
+                    try:
+                        source_id = UUID(edge["source"])
+                        target_id = UUID(edge["target"])
+
+                        edge["source"] = source_id
+                        edge["target"] = target_id
+                        edge["source_node_id"] = source_id
+                        edge["target_node_id"] = target_id
+                    except Exception:
+                        pass
+
+                    if "updated_at" in edge:
+                        edge["updated_at"] = datetime.strptime(edge["updated_at"], "%Y-%m-%dT%H:%M:%S.%f%z")
 
                 self.graph = nx.readwrite.json_graph.node_link_graph(graph_data)
             else:
@@ -268,8 +284,8 @@ class NetworkXAdapter(GraphDBInterface):
                 os.makedirs(file_dir, exist_ok = True)
                 await self.save_graph_to_file(file_path)
-        except Exception:
-            logger.error("Failed to load graph from file: %s", file_path)
+        except Exception as e:
+            logger.error("Failed to load graph from file: %s \n %s", file_path, str(e))
 
             # Initialize an empty graph in case of error
             self.graph = nx.MultiDiGraph()
diff --git a/cognee/infrastructure/databases/vector/lancedb/LanceDBAdapter.py b/cognee/infrastructure/databases/vector/lancedb/LanceDBAdapter.py
index 39e43189c..d883a29e7 100644
--- a/cognee/infrastructure/databases/vector/lancedb/LanceDBAdapter.py
+++ b/cognee/infrastructure/databases/vector/lancedb/LanceDBAdapter.py
@@ -164,7 +164,16 @@ class LanceDBAdapter(VectorDBInterface):
             if value < min_value:
                 min_value = value
 
-        normalized_values = [(result["_distance"] - min_value) / (max_value - min_value) for result in result_values]
+        normalized_values = []
+        min_value = min(result["_distance"] for result in result_values)
+        max_value = max(result["_distance"] for result in result_values)
+
+        if max_value == min_value:
+            # Avoid division by zero: Assign all normalized values to 0 (or any constant value like 1)
+            normalized_values = [0 for _ in result_values]
+        else:
+            normalized_values = [(result["_distance"] - min_value) / (max_value - min_value) for result in
+                                 result_values]
 
         return [ScoredResult(
             id = UUID(result["id"]),
diff --git a/cognee/modules/graph/utils/get_graph_from_model.py b/cognee/modules/graph/utils/get_graph_from_model.py
index 35e00fb5d..29137ddc7 100644
--- a/cognee/modules/graph/utils/get_graph_from_model.py
+++ b/cognee/modules/graph/utils/get_graph_from_model.py
@@ -43,7 +43,7 @@ def get_graph_from_model(data_point: DataPoint, include_root = True, added_nodes
             added_edges[str(edge_key)] = True
             continue
 
-        if isinstance(field_value, list) and isinstance(field_value[0], DataPoint):
+        if isinstance(field_value, list) and len(field_value) > 0 and isinstance(field_value[0], DataPoint):
             excluded_properties.add(field_name)
 
             for item in field_value:
diff --git a/cognee/shared/SourceCodeGraph.py b/cognee/shared/SourceCodeGraph.py
index 51b90f296..60f425e32 100644
--- a/cognee/shared/SourceCodeGraph.py
+++ b/cognee/shared/SourceCodeGraph.py
@@ -1,84 +1,95 @@
-from typing import List, Union, Literal, Optional
-from pydantic import BaseModel
+from typing import Any, List, Union, Literal, Optional
+from cognee.infrastructure.engine import DataPoint
 
-class BaseClass(BaseModel):
-    id: str
-    name: str
-    type: Literal["Class"] = "Class"
-    description: str
-    constructor_parameters: Optional[List[str]] = None
-
-class Class(BaseModel):
-    id: str
-    name: str
-    type: Literal["Class"] = "Class"
-    description: str
-    constructor_parameters: Optional[List[str]] = None
-    from_class: Optional[BaseClass] = None
-
-class ClassInstance(BaseModel):
-    id: str
-    name: str
-    type: Literal["ClassInstance"] = "ClassInstance"
-    description: str
-    from_class: Class
-
-class Function(BaseModel):
-    id: str
-    name: str
-    type: Literal["Function"] = "Function"
-    description: str
-    parameters: Optional[List[str]] = None
-    return_type: str
-    is_static: Optional[bool] = False
-
-class Variable(BaseModel):
+class Variable(DataPoint):
     id: str
     name: str
     type: Literal["Variable"] = "Variable"
     description: str
     is_static: Optional[bool] = False
default_value: Optional[str] = None + data_type: str -class Operator(BaseModel): + _metadata = { + "index_fields": ["name"] + } + +class Operator(DataPoint): id: str name: str type: Literal["Operator"] = "Operator" description: str return_type: str -class ExpressionPart(BaseModel): +class Class(DataPoint): + id: str + name: str + type: Literal["Class"] = "Class" + description: str + constructor_parameters: List[Variable] + extended_from_class: Optional["Class"] = None + has_methods: list["Function"] + + _metadata = { + "index_fields": ["name"] + } + +class ClassInstance(DataPoint): + id: str + name: str + type: Literal["ClassInstance"] = "ClassInstance" + description: str + from_class: Class + instantiated_by: Union["Function"] + instantiation_arguments: List[Variable] + + _metadata = { + "index_fields": ["name"] + } + +class Function(DataPoint): + id: str + name: str + type: Literal["Function"] = "Function" + description: str + parameters: List[Variable] + return_type: str + is_static: Optional[bool] = False + + _metadata = { + "index_fields": ["name"] + } + +class FunctionCall(DataPoint): + id: str + type: Literal["FunctionCall"] = "FunctionCall" + called_by: Union[Function, Literal["main"]] + function_called: Function + function_arguments: List[Any] + +class Expression(DataPoint): id: str name: str type: Literal["Expression"] = "Expression" description: str expression: str - members: List[Union[Variable, Function, Operator]] + members: List[Union[Variable, Function, Operator, "Expression"]] -class Expression(BaseModel): - id: str - name: str - type: Literal["Expression"] = "Expression" - description: str - expression: str - members: List[Union[Variable, Function, Operator, ExpressionPart]] - -class Edge(BaseModel): - source_node_id: str - target_node_id: str - relationship_name: Literal["called in", "stored in", "defined in", "returned by", "instantiated in", "uses", "updates"] - -class SourceCodeGraph(BaseModel): +class SourceCodeGraph(DataPoint): id: str name: 
str description: str language: str nodes: List[Union[ Class, + ClassInstance, Function, + FunctionCall, Variable, Operator, Expression, - ClassInstance, ]] - edges: List[Edge] + +Class.model_rebuild() +ClassInstance.model_rebuild() +Expression.model_rebuild() diff --git a/cognee/shared/utils.py b/cognee/shared/utils.py index f3272357f..14578f202 100644 --- a/cognee/shared/utils.py +++ b/cognee/shared/utils.py @@ -91,7 +91,7 @@ def prepare_edges(graph, source, target, edge_key): source: str(edge[0]), target: str(edge[1]), edge_key: str(edge[2]), - } for edge in graph.edges] + } for edge in graph.edges(keys = True, data = True)] return pd.DataFrame(edge_list) diff --git a/cognee/tasks/graph/__init__.py b/cognee/tasks/graph/__init__.py index 94dc82f20..eafc12921 100644 --- a/cognee/tasks/graph/__init__.py +++ b/cognee/tasks/graph/__init__.py @@ -1,2 +1,3 @@ from .extract_graph_from_data import extract_graph_from_data +from .extract_graph_from_code import extract_graph_from_code from .query_graph_connections import query_graph_connections diff --git a/cognee/tasks/graph/extract_graph_from_code.py b/cognee/tasks/graph/extract_graph_from_code.py new file mode 100644 index 000000000..159e9baa4 --- /dev/null +++ b/cognee/tasks/graph/extract_graph_from_code.py @@ -0,0 +1,17 @@ +import asyncio +from typing import Type +from pydantic import BaseModel +from cognee.modules.data.extraction.knowledge_graph import extract_content_graph +from cognee.modules.chunking.models.DocumentChunk import DocumentChunk +from cognee.tasks.storage import add_data_points + +async def extract_graph_from_code(data_chunks: list[DocumentChunk], graph_model: Type[BaseModel]): + chunk_graphs = await asyncio.gather( + *[extract_content_graph(chunk.text, graph_model) for chunk in data_chunks] + ) + + for (chunk_index, chunk) in enumerate(data_chunks): + chunk_graph = chunk_graphs[chunk_index] + await add_data_points(chunk_graph.nodes) + + return data_chunks diff --git 
a/cognee/tasks/storage/index_data_points.py b/cognee/tasks/storage/index_data_points.py index 681fbaa1f..dc74d705d 100644 --- a/cognee/tasks/storage/index_data_points.py +++ b/cognee/tasks/storage/index_data_points.py @@ -47,7 +47,7 @@ def get_data_points_from_model(data_point: DataPoint, added_data_points = {}) -> added_data_points[str(new_point.id)] = True data_points.append(new_point) - if isinstance(field_value, list) and isinstance(field_value[0], DataPoint): + if isinstance(field_value, list) and len(field_value) > 0 and isinstance(field_value[0], DataPoint): for field_value_item in field_value: new_data_points = get_data_points_from_model(field_value_item, added_data_points) diff --git a/cognee/tests/test_code_generation.py b/cognee/tests/test_code_generation.py new file mode 100755 index 000000000..aad59ace8 --- /dev/null +++ b/cognee/tests/test_code_generation.py @@ -0,0 +1,38 @@ +import os +import logging +import pathlib +import cognee +from cognee.api.v1.cognify.code_graph_pipeline import code_graph_pipeline +from cognee.api.v1.search import SearchType +from cognee.shared.utils import render_graph + +logging.basicConfig(level = logging.DEBUG) + +async def main(): + data_directory_path = str(pathlib.Path(os.path.join(pathlib.Path(__file__).parent, ".data_storage/test_code_generation")).resolve()) + cognee.config.data_root_directory(data_directory_path) + cognee_directory_path = str(pathlib.Path(os.path.join(pathlib.Path(__file__).parent, ".cognee_system/test_code_generation")).resolve()) + cognee.config.system_root_directory(cognee_directory_path) + + await cognee.prune.prune_data() + await cognee.prune.prune_system(metadata = True) + + dataset_name = "artificial_intelligence" + + ai_text_file_path = os.path.join(pathlib.Path(__file__).parent, "test_data/code.txt") + await cognee.add([ai_text_file_path], dataset_name) + + await code_graph_pipeline([dataset_name]) + + await render_graph(None, include_nodes = True, include_labels = True) + + search_results 
= await cognee.search(SearchType.CHUNKS, query = "Student") + assert len(search_results) != 0, "The search results list is empty." + print("\n\nExtracted chunks are:\n") + for result in search_results: + print(f"{result}\n") + + +if __name__ == "__main__": + import asyncio + asyncio.run(main(), debug=True) diff --git a/cognee/tests/test_data/code.txt b/cognee/tests/test_data/code.txt new file mode 100644 index 000000000..c40f7124a --- /dev/null +++ b/cognee/tests/test_data/code.txt @@ -0,0 +1,70 @@ +// Class definition for a Person +class Person { + constructor(name, age) { + this.name = name; + this.age = age; + } + + // Method to return a greeting message + greet() { + return `Hello, my name is ${this.name} and I'm ${this.age} years old.`; + } + + // Method to celebrate birthday + celebrateBirthday() { + this.age += 1; + return `Happy Birthday, ${this.name}! You are now ${this.age} years old.`; + } +} + +// Class definition for a Student, extending from Person +class Student extends Person { + constructor(name, age, grade) { + super(name, age); + this.grade = grade; + } + + // Method to describe the student + describe() { + return `${this.name} is a ${this.grade} grade student and is ${this.age} years old.`; + } +} + +// Function to enroll a new student +function enrollStudent(name, age, grade) { + const student = new Student(name, age, grade); + console.log(student.greet()); + console.log(student.describe()); + return student; +} + +// Function to promote a student to the next grade +function promoteStudent(student) { + student.grade += 1; + console.log(`${student.name} has been promoted to grade ${student.grade}.`); + return student; +} + +// Variable definition and assignment +let schoolName = "Greenwood High School"; +let students = []; + +// Enrolling students +students.push(enrollStudent("Alice", 14, 9)); +students.push(enrollStudent("Bob", 15, 10)); + +// Looping through students to celebrate their birthdays +students.forEach(student => { + 
console.log(student.celebrateBirthday()); +}); + +// Promoting all students +students = students.map(promoteStudent); + +// Displaying the final state of all students +console.log("Final Students List:"); +students.forEach(student => console.log(student.describe())); + +// Updating the school name +schoolName = "Greenwood International School"; +console.log(`School Name Updated to: ${schoolName}`);