fix: add code graph generation pipeline

Boris Arzentar 2024-11-08 15:31:02 +01:00
parent c89063602e
commit 19d62f2c84
11 changed files with 338 additions and 66 deletions

View file

@@ -0,0 +1,110 @@
import asyncio
import logging
from typing import Union
from cognee.shared.SourceCodeGraph import SourceCodeGraph
from cognee.shared.utils import send_telemetry
from cognee.modules.data.models import Dataset, Data
from cognee.modules.data.methods.get_dataset_data import get_dataset_data
from cognee.modules.data.methods import get_datasets, get_datasets_by_name
from cognee.modules.pipelines.tasks.Task import Task
from cognee.modules.pipelines import run_tasks
from cognee.modules.users.models import User
from cognee.modules.users.methods import get_default_user
from cognee.modules.pipelines.models import PipelineRunStatus
from cognee.modules.pipelines.operations.get_pipeline_status import get_pipeline_status
from cognee.modules.pipelines.operations.log_pipeline_status import log_pipeline_status
from cognee.tasks.documents import classify_documents, check_permissions_on_documents, extract_chunks_from_documents
from cognee.tasks.graph import extract_graph_from_code
from cognee.tasks.storage import add_data_points
logger = logging.getLogger("code_graph_pipeline")
update_status_lock = asyncio.Lock()
class PermissionDeniedException(Exception):
def __init__(self, message: str):
self.message = message
super().__init__(self.message)
async def code_graph_pipeline(datasets: Union[str, list[str]] = None, user: User = None):
if user is None:
user = await get_default_user()
existing_datasets = await get_datasets(user.id)
if datasets is None or len(datasets) == 0:
# If no datasets are provided, cognify all existing datasets.
datasets = existing_datasets
if isinstance(datasets[0], str):
datasets = await get_datasets_by_name(datasets, user.id)
existing_datasets_map = {
generate_dataset_name(dataset.name): True for dataset in existing_datasets
}
awaitables = []
for dataset in datasets:
dataset_name = generate_dataset_name(dataset.name)
if dataset_name in existing_datasets_map:
awaitables.append(run_pipeline(dataset, user))
return await asyncio.gather(*awaitables)
async def run_pipeline(dataset: Dataset, user: User):
data_documents: list[Data] = await get_dataset_data(dataset_id = dataset.id)
document_ids_str = [str(document.id) for document in data_documents]
dataset_id = dataset.id
dataset_name = generate_dataset_name(dataset.name)
send_telemetry("code_graph_pipeline EXECUTION STARTED", user.id)
async with update_status_lock:
task_status = await get_pipeline_status([dataset_id])
if dataset_id in task_status and task_status[dataset_id] == PipelineRunStatus.DATASET_PROCESSING_STARTED:
logger.info("Dataset %s is already being processed.", dataset_name)
return
await log_pipeline_status(dataset_id, PipelineRunStatus.DATASET_PROCESSING_STARTED, {
"dataset_name": dataset_name,
"files": document_ids_str,
})
try:
tasks = [
Task(classify_documents),
Task(check_permissions_on_documents, user = user, permissions = ["write"]),
Task(extract_chunks_from_documents), # Extract text chunks based on the document type.
Task(add_data_points, task_config = { "batch_size": 10 }),
Task(extract_graph_from_code, graph_model = SourceCodeGraph, task_config = { "batch_size": 10 }), # Generate code graphs from the document chunks.
]
pipeline = run_tasks(tasks, data_documents, "code_graph_pipeline")
async for result in pipeline:
print(result)
send_telemetry("code_graph_pipeline EXECUTION COMPLETED", user.id)
await log_pipeline_status(dataset_id, PipelineRunStatus.DATASET_PROCESSING_COMPLETED, {
"dataset_name": dataset_name,
"files": document_ids_str,
})
except Exception as error:
send_telemetry("code_graph_pipeline EXECUTION ERRORED", user.id)
await log_pipeline_status(dataset_id, PipelineRunStatus.DATASET_PROCESSING_ERRORED, {
"dataset_name": dataset_name,
"files": document_ids_str,
})
raise error
def generate_dataset_name(dataset_name: str) -> str:
return dataset_name.replace(".", "_").replace(" ", "_")
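
For orientation, a minimal usage sketch of this pipeline (the dataset name and file path are hypothetical; it assumes the source files were first ingested with cognee.add, as in the example script further below):

import asyncio

import cognee
from cognee.api.v1.cognify.code_graph_pipeline import code_graph_pipeline

async def main():
    # Ingest a source file into a dataset, then build its code graph.
    await cognee.add(["path/to/code.js"], "my_code_dataset")
    await code_graph_pipeline(["my_code_dataset"])

asyncio.run(main())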

View file

@@ -30,6 +30,10 @@ class NetworkXAdapter(GraphDBInterface):
def __init__(self, filename = "cognee_graph.pkl"):
self.filename = filename
async def get_graph_data(self):
await self.load_graph_from_file()
return (list(self.graph.nodes(data = True)), list(self.graph.edges(data = True, keys = True)))
async def query(self, query: str, params: dict):
pass
@@ -247,15 +251,27 @@ class NetworkXAdapter(GraphDBInterface):
async with aiofiles.open(file_path, "r") as file:
graph_data = json.loads(await file.read())
for node in graph_data["nodes"]:
node["id"] = UUID(node["id"])
node["updated_at"] = datetime.strptime(node["updated_at"], "%Y-%m-%dT%H:%M:%S.%f%z")
try:
    node["id"] = UUID(node["id"])
except (ValueError, TypeError):
    # Leave non-UUID node ids (e.g. named nodes) as plain strings.
    pass
if "updated_at" in node:
    node["updated_at"] = datetime.strptime(node["updated_at"], "%Y-%m-%dT%H:%M:%S.%f%z")
for edge in graph_data["links"]:
edge["source"] = UUID(edge["source"])
edge["target"] = UUID(edge["target"])
edge["source_node_id"] = UUID(edge["source_node_id"])
edge["target_node_id"] = UUID(edge["target_node_id"])
edge["updated_at"] = datetime.strptime(edge["updated_at"], "%Y-%m-%dT%H:%M:%S.%f%z")
try:
    source_id = UUID(edge["source"])
    target_id = UUID(edge["target"])
    edge["source"] = source_id
    edge["target"] = target_id
    edge["source_node_id"] = source_id
    edge["target_node_id"] = target_id
except (ValueError, TypeError):
    # Leave non-UUID edge endpoints as plain strings.
    pass
if "updated_at" in edge:
    edge["updated_at"] = datetime.strptime(edge["updated_at"], "%Y-%m-%dT%H:%M:%S.%f%z")
self.graph = nx.readwrite.json_graph.node_link_graph(graph_data)
else:
@@ -268,8 +284,8 @@ class NetworkXAdapter(GraphDBInterface):
os.makedirs(file_dir, exist_ok = True)
await self.save_graph_to_file(file_path)
except Exception:
logger.error("Failed to load graph from file: %s", file_path)
except Exception as e:
logger.error("Failed to load graph from file: %s \n %s", file_path, str(e))
# Initialize an empty graph in case of error
self.graph = nx.MultiDiGraph()
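
The try/except around UUID(...) above is needed because code-graph nodes can carry non-UUID string ids (for example the literal "main" used by FunctionCall.called_by). A standalone sketch of the pattern, with a hypothetical helper name:

from typing import Union
from uuid import UUID

def parse_node_id(value: str) -> Union[UUID, str]:
    # Prefer a UUID; fall back to the raw string for named nodes like "main".
    try:
        return UUID(value)
    except (ValueError, TypeError):
        return value

assert isinstance(parse_node_id("6f1f9a5e-0000-4000-8000-000000000000"), UUID)
assert parse_node_id("main") == "main"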

View file

@@ -164,7 +164,16 @@ class LanceDBAdapter(VectorDBInterface):
if value < min_value:
min_value = value
normalized_values = [(result["_distance"] - min_value) / (max_value - min_value) for result in result_values]
normalized_values = []
min_value = min(result["_distance"] for result in result_values)
max_value = max(result["_distance"] for result in result_values)
if max_value == min_value:
    # Avoid division by zero: give every result the same constant score.
    normalized_values = [0 for _ in result_values]
else:
    normalized_values = [(result["_distance"] - min_value) / (max_value - min_value) for result in result_values]
return [ScoredResult(
id = UUID(result["id"]),
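
The scoring change above is plain min-max normalization with a guard for a degenerate range; as a self-contained sketch (helper name hypothetical):

def normalize_distances(distances: list[float]) -> list[float]:
    # Scale distances into [0, 1]; when all values are equal the range is zero,
    # so return a constant instead of dividing by zero.
    lo, hi = min(distances), max(distances)
    if hi == lo:
        return [0.0 for _ in distances]
    return [(d - lo) / (hi - lo) for d in distances]

print(normalize_distances([1.0, 2.0, 3.0]))  # [0.0, 0.5, 1.0]
print(normalize_distances([0.3, 0.3]))       # [0.0, 0.0]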

View file

@@ -43,7 +43,7 @@ def get_graph_from_model(data_point: DataPoint, include_root = True, added_nodes
added_edges[str(edge_key)] = True
continue
if isinstance(field_value, list) and isinstance(field_value[0], DataPoint):
if isinstance(field_value, list) and len(field_value) > 0 and isinstance(field_value[0], DataPoint):
excluded_properties.add(field_name)
for item in field_value:

View file

@@ -1,84 +1,95 @@
from typing import List, Union, Literal, Optional
from pydantic import BaseModel
from typing import Any, List, Union, Literal, Optional
from cognee.infrastructure.engine import DataPoint
class BaseClass(BaseModel):
id: str
name: str
type: Literal["Class"] = "Class"
description: str
constructor_parameters: Optional[List[str]] = None
class Class(BaseModel):
id: str
name: str
type: Literal["Class"] = "Class"
description: str
constructor_parameters: Optional[List[str]] = None
from_class: Optional[BaseClass] = None
class ClassInstance(BaseModel):
id: str
name: str
type: Literal["ClassInstance"] = "ClassInstance"
description: str
from_class: Class
class Function(BaseModel):
id: str
name: str
type: Literal["Function"] = "Function"
description: str
parameters: Optional[List[str]] = None
return_type: str
is_static: Optional[bool] = False
class Variable(BaseModel):
class Variable(DataPoint):
id: str
name: str
type: Literal["Variable"] = "Variable"
description: str
is_static: Optional[bool] = False
default_value: Optional[str] = None
data_type: str
class Operator(BaseModel):
_metadata = {
"index_fields": ["name"]
}
class Operator(DataPoint):
id: str
name: str
type: Literal["Operator"] = "Operator"
description: str
return_type: str
class ExpressionPart(BaseModel):
class Class(DataPoint):
id: str
name: str
type: Literal["Class"] = "Class"
description: str
constructor_parameters: List[Variable]
extended_from_class: Optional["Class"] = None
has_methods: list["Function"]
_metadata = {
"index_fields": ["name"]
}
class ClassInstance(DataPoint):
id: str
name: str
type: Literal["ClassInstance"] = "ClassInstance"
description: str
from_class: Class
instantiated_by: "Function"
instantiation_arguments: List[Variable]
_metadata = {
"index_fields": ["name"]
}
class Function(DataPoint):
id: str
name: str
type: Literal["Function"] = "Function"
description: str
parameters: List[Variable]
return_type: str
is_static: Optional[bool] = False
_metadata = {
"index_fields": ["name"]
}
class FunctionCall(DataPoint):
id: str
type: Literal["FunctionCall"] = "FunctionCall"
called_by: Union[Function, Literal["main"]]
function_called: Function
function_arguments: List[Any]
class Expression(DataPoint):
id: str
name: str
type: Literal["Expression"] = "Expression"
description: str
expression: str
members: List[Union[Variable, Function, Operator]]
members: List[Union[Variable, Function, Operator, "Expression"]]
class Expression(BaseModel):
id: str
name: str
type: Literal["Expression"] = "Expression"
description: str
expression: str
members: List[Union[Variable, Function, Operator, ExpressionPart]]
class Edge(BaseModel):
source_node_id: str
target_node_id: str
relationship_name: Literal["called in", "stored in", "defined in", "returned by", "instantiated in", "uses", "updates"]
class SourceCodeGraph(BaseModel):
class SourceCodeGraph(DataPoint):
id: str
name: str
description: str
language: str
nodes: List[Union[
    Class,
    ClassInstance,
    Function,
    FunctionCall,
    Variable,
    Operator,
    Expression,
]]
edges: List[Edge]
Class.model_rebuild()
ClassInstance.model_rebuild()
Expression.model_rebuild()
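
A minimal sketch instantiating the reworked models (all field values are hypothetical, and it assumes the DataPoint base class adds no further required fields):

counter = Variable(
    id = "var:counter",
    name = "counter",
    description = "Loop counter",
    data_type = "number",
    default_value = "0",
)
increment = Function(
    id = "fn:increment",
    name = "increment",
    description = "Adds one to the counter",
    parameters = [counter],
    return_type = "number",
)
print(increment.type)  # "Function"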

View file

@@ -91,7 +91,7 @@ def prepare_edges(graph, source, target, edge_key):
source: str(edge[0]),
target: str(edge[1]),
edge_key: str(edge[2]),
} for edge in graph.edges]
} for edge in graph.edges(keys = True, data = True)]
return pd.DataFrame(edge_list)
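
The keys = True change matters because MultiDiGraph.edges yields plain (u, v) pairs by default, so edge[2] raised IndexError; with keys and data requested, each edge is a (source, target, key, data) tuple:

import networkx as nx

g = nx.MultiDiGraph()
g.add_edge("a", "b", key = "uses", weight = 1)

print(list(g.edges))                            # [('a', 'b')]
print(list(g.edges(keys = True, data = True)))  # [('a', 'b', 'uses', {'weight': 1})]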

View file

@@ -1,2 +1,3 @@
from .extract_graph_from_data import extract_graph_from_data
from .extract_graph_from_code import extract_graph_from_code
from .query_graph_connections import query_graph_connections

View file

@@ -0,0 +1,17 @@
import asyncio
from typing import Type
from pydantic import BaseModel
from cognee.modules.data.extraction.knowledge_graph import extract_content_graph
from cognee.modules.chunking.models.DocumentChunk import DocumentChunk
from cognee.tasks.storage import add_data_points
async def extract_graph_from_code(data_chunks: list[DocumentChunk], graph_model: Type[BaseModel]):
chunk_graphs = await asyncio.gather(
*[extract_content_graph(chunk.text, graph_model) for chunk in data_chunks]
)
for chunk_graph in chunk_graphs:
    await add_data_points(chunk_graph.nodes)
return data_chunks
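
The gather above fans out one extraction per chunk and awaits them together; a self-contained sketch of that pattern (the extract helper is a stand-in, not the real extract_content_graph):

import asyncio

async def extract(text: str) -> str:
    # Stand-in for extract_content_graph(chunk.text, graph_model).
    await asyncio.sleep(0)
    return f"graph({text})"

async def main():
    chunks = ["def a(): ...", "def b(): ..."]
    # One concurrent extraction per chunk, as the task above does.
    graphs = await asyncio.gather(*[extract(text) for text in chunks])
    print(graphs)

asyncio.run(main())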

View file

@@ -47,7 +47,7 @@ def get_data_points_from_model(data_point: DataPoint, added_data_points = {}) ->
added_data_points[str(new_point.id)] = True
data_points.append(new_point)
if isinstance(field_value, list) and isinstance(field_value[0], DataPoint):
if isinstance(field_value, list) and len(field_value) > 0 and isinstance(field_value[0], DataPoint):
for field_value_item in field_value:
new_data_points = get_data_points_from_model(field_value_item, added_data_points)

View file

@@ -0,0 +1,38 @@
import os
import logging
import pathlib
import cognee
from cognee.api.v1.cognify.code_graph_pipeline import code_graph_pipeline
from cognee.api.v1.search import SearchType
from cognee.shared.utils import render_graph
logging.basicConfig(level = logging.DEBUG)
async def main():
data_directory_path = str(pathlib.Path(os.path.join(pathlib.Path(__file__).parent, ".data_storage/test_code_generation")).resolve())
cognee.config.data_root_directory(data_directory_path)
cognee_directory_path = str(pathlib.Path(os.path.join(pathlib.Path(__file__).parent, ".cognee_system/test_code_generation")).resolve())
cognee.config.system_root_directory(cognee_directory_path)
await cognee.prune.prune_data()
await cognee.prune.prune_system(metadata = True)
dataset_name = "artificial_intelligence"
ai_text_file_path = os.path.join(pathlib.Path(__file__).parent, "test_data/code.txt")
await cognee.add([ai_text_file_path], dataset_name)
await code_graph_pipeline([dataset_name])
await render_graph(None, include_nodes = True, include_labels = True)
search_results = await cognee.search(SearchType.CHUNKS, query = "Student")
assert len(search_results) != 0, "The search results list is empty."
print("\n\nExtracted chunks are:\n")
for result in search_results:
print(f"{result}\n")
if __name__ == "__main__":
import asyncio
asyncio.run(main(), debug=True)

View file

@@ -0,0 +1,70 @@
// Class definition for a Person
class Person {
constructor(name, age) {
this.name = name;
this.age = age;
}
// Method to return a greeting message
greet() {
return `Hello, my name is ${this.name} and I'm ${this.age} years old.`;
}
// Method to celebrate birthday
celebrateBirthday() {
this.age += 1;
return `Happy Birthday, ${this.name}! You are now ${this.age} years old.`;
}
}
// Class definition for a Student, extending from Person
class Student extends Person {
constructor(name, age, grade) {
super(name, age);
this.grade = grade;
}
// Method to describe the student
describe() {
return `${this.name} is a ${this.grade} grade student and is ${this.age} years old.`;
}
}
// Function to enroll a new student
function enrollStudent(name, age, grade) {
const student = new Student(name, age, grade);
console.log(student.greet());
console.log(student.describe());
return student;
}
// Function to promote a student to the next grade
function promoteStudent(student) {
student.grade += 1;
console.log(`${student.name} has been promoted to grade ${student.grade}.`);
return student;
}
// Variable definition and assignment
let schoolName = "Greenwood High School";
let students = [];
// Enrolling students
students.push(enrollStudent("Alice", 14, 9));
students.push(enrollStudent("Bob", 15, 10));
// Looping through students to celebrate their birthdays
students.forEach(student => {
console.log(student.celebrateBirthday());
});
// Promoting all students
students = students.map(promoteStudent);
// Displaying the final state of all students
console.log("Final Students List:");
students.forEach(student => console.log(student.describe()));
// Updating the school name
schoolName = "Greenwood International School";
console.log(`School Name Updated to: ${schoolName}`);