fix: add code graph generation pipeline

parent c89063602e
commit 19d62f2c84

11 changed files with 338 additions and 66 deletions

cognee/api/v1/cognify/code_graph_pipeline.py (Normal file, +110)

@@ -0,0 +1,110 @@
import asyncio
import logging
from typing import Union

from cognee.shared.SourceCodeGraph import SourceCodeGraph
from cognee.shared.utils import send_telemetry
from cognee.modules.data.models import Dataset, Data
from cognee.modules.data.methods.get_dataset_data import get_dataset_data
from cognee.modules.data.methods import get_datasets, get_datasets_by_name
from cognee.modules.pipelines.tasks.Task import Task
from cognee.modules.pipelines import run_tasks
from cognee.modules.users.models import User
from cognee.modules.users.methods import get_default_user
from cognee.modules.pipelines.models import PipelineRunStatus
from cognee.modules.pipelines.operations.get_pipeline_status import get_pipeline_status
from cognee.modules.pipelines.operations.log_pipeline_status import log_pipeline_status
from cognee.tasks.documents import classify_documents, check_permissions_on_documents, extract_chunks_from_documents
from cognee.tasks.graph import extract_graph_from_code
from cognee.tasks.storage import add_data_points

logger = logging.getLogger("code_graph_pipeline")

update_status_lock = asyncio.Lock()

class PermissionDeniedException(Exception):
    def __init__(self, message: str):
        self.message = message
        super().__init__(self.message)

async def code_graph_pipeline(datasets: Union[str, list[str]] = None, user: User = None):
    if user is None:
        user = await get_default_user()

    existing_datasets = await get_datasets(user.id)

    if datasets is None or len(datasets) == 0:
        # If no datasets are provided, cognify all existing datasets.
        datasets = existing_datasets

    if type(datasets[0]) == str:
        datasets = await get_datasets_by_name(datasets, user.id)

    existing_datasets_map = {
        generate_dataset_name(dataset.name): True for dataset in existing_datasets
    }

    awaitables = []

    for dataset in datasets:
        dataset_name = generate_dataset_name(dataset.name)

        if dataset_name in existing_datasets_map:
            awaitables.append(run_pipeline(dataset, user))

    return await asyncio.gather(*awaitables)


async def run_pipeline(dataset: Dataset, user: User):
    data_documents: list[Data] = await get_dataset_data(dataset_id = dataset.id)

    document_ids_str = [str(document.id) for document in data_documents]

    dataset_id = dataset.id
    dataset_name = generate_dataset_name(dataset.name)

    send_telemetry("code_graph_pipeline EXECUTION STARTED", user.id)

    async with update_status_lock:
        task_status = await get_pipeline_status([dataset_id])

        if dataset_id in task_status and task_status[dataset_id] == PipelineRunStatus.DATASET_PROCESSING_STARTED:
            logger.info("Dataset %s is already being processed.", dataset_name)
            return

        await log_pipeline_status(dataset_id, PipelineRunStatus.DATASET_PROCESSING_STARTED, {
            "dataset_name": dataset_name,
            "files": document_ids_str,
        })

    try:
        tasks = [
            Task(classify_documents),
            Task(check_permissions_on_documents, user = user, permissions = ["write"]),
            Task(extract_chunks_from_documents), # Extract text chunks based on the document type.
            Task(add_data_points, task_config = { "batch_size": 10 }),
            Task(extract_graph_from_code, graph_model = SourceCodeGraph, task_config = { "batch_size": 10 }), # Generate knowledge graphs from the document chunks.
        ]

        pipeline = run_tasks(tasks, data_documents, "code_graph_pipeline")

        async for result in pipeline:
            print(result)

        send_telemetry("code_graph_pipeline EXECUTION COMPLETED", user.id)

        await log_pipeline_status(dataset_id, PipelineRunStatus.DATASET_PROCESSING_COMPLETED, {
            "dataset_name": dataset_name,
            "files": document_ids_str,
        })
    except Exception as error:
        send_telemetry("code_graph_pipeline EXECUTION ERRORED", user.id)

        await log_pipeline_status(dataset_id, PipelineRunStatus.DATASET_PROCESSING_ERRORED, {
            "dataset_name": dataset_name,
            "files": document_ids_str,
        })
        raise error


def generate_dataset_name(dataset_name: str) -> str:
    return dataset_name.replace(".", "_").replace(" ", "_")
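
A minimal usage sketch for the pipeline above (not part of the diff; the file path and dataset name are illustrative, and it assumes a configured cognee environment):

    import asyncio

    import cognee
    from cognee.api.v1.cognify.code_graph_pipeline import code_graph_pipeline

    async def main():
        # Ingest a source file into a dataset, then build its code graph.
        await cognee.add(["path/to/code.txt"], "my_code_dataset")  # illustrative path and name
        await code_graph_pipeline(["my_code_dataset"])

    asyncio.run(main())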

@@ -30,6 +30,10 @@ class NetworkXAdapter(GraphDBInterface):
     def __init__(self, filename = "cognee_graph.pkl"):
         self.filename = filename

+    async def get_graph_data(self):
+        await self.load_graph_from_file()
+        return (list(self.graph.nodes(data = True)), list(self.graph.edges(data = True, keys = True)))
+
     async def query(self, query: str, params: dict):
         pass

@@ -247,15 +251,27 @@ class NetworkXAdapter(GraphDBInterface):
             async with aiofiles.open(file_path, "r") as file:
                 graph_data = json.loads(await file.read())
                 for node in graph_data["nodes"]:
-                    node["id"] = UUID(node["id"])
-                    node["updated_at"] = datetime.strptime(node["updated_at"], "%Y-%m-%dT%H:%M:%S.%f%z")
+                    try:
+                        node["id"] = UUID(node["id"])
+                    except:
+                        pass
+                    if "updated_at" in node:
+                        node["updated_at"] = datetime.strptime(node["updated_at"], "%Y-%m-%dT%H:%M:%S.%f%z")

                 for edge in graph_data["links"]:
-                    edge["source"] = UUID(edge["source"])
-                    edge["target"] = UUID(edge["target"])
-                    edge["source_node_id"] = UUID(edge["source_node_id"])
-                    edge["target_node_id"] = UUID(edge["target_node_id"])
-                    edge["updated_at"] = datetime.strptime(edge["updated_at"], "%Y-%m-%dT%H:%M:%S.%f%z")
+                    try:
+                        source_id = UUID(edge["source"])
+                        target_id = UUID(edge["target"])
+
+                        edge["source"] = source_id
+                        edge["target"] = target_id
+                        edge["source_node_id"] = source_id
+                        edge["target_node_id"] = target_id
+                    except:
+                        pass
+
+                    if "updated_at" in edge:
+                        edge["updated_at"] = datetime.strptime(edge["updated_at"], "%Y-%m-%dT%H:%M:%S.%f%z")

                 self.graph = nx.readwrite.json_graph.node_link_graph(graph_data)
             else:

@@ -268,8 +284,8 @@ class NetworkXAdapter(GraphDBInterface):
                 os.makedirs(file_dir, exist_ok = True)

                 await self.save_graph_to_file(file_path)
-        except Exception:
-            logger.error("Failed to load graph from file: %s", file_path)
+        except Exception as e:
+            logger.error("Failed to load graph from file: %s \n %s", file_path, str(e))
             # Initialize an empty graph in case of error
             self.graph = nx.MultiDiGraph()
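
For context (not part of the diff): node-link JSON serializes every attribute to plain JSON values, which is why UUIDs and timestamps come back as strings and need the coercion added above. A standalone round-trip sketch:

    import json
    import networkx as nx

    graph = nx.MultiDiGraph()
    graph.add_node("n1", updated_at = "2024-01-01T00:00:00.000000+0000")
    graph.add_edge("n1", "n2", key = "uses")

    # Attributes survive the round trip only as plain JSON values (strings here).
    data = json.loads(json.dumps(nx.readwrite.json_graph.node_link_data(graph)))
    loaded = nx.readwrite.json_graph.node_link_graph(data)
    print(loaded.nodes(data = True))  # updated_at is still a str after loading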

@@ -164,7 +164,16 @@ class LanceDBAdapter(VectorDBInterface):
             if value < min_value:
                 min_value = value

-        normalized_values = [(result["_distance"] - min_value) / (max_value - min_value) for result in result_values]
+        normalized_values = []
+        min_value = min(result["_distance"] for result in result_values)
+        max_value = max(result["_distance"] for result in result_values)
+
+        if max_value == min_value:
+            # Avoid division by zero: Assign all normalized values to 0 (or any constant value like 1)
+            normalized_values = [0 for _ in result_values]
+        else:
+            normalized_values = [(result["_distance"] - min_value) / (max_value - min_value) for result in
+                                 result_values]

         return [ScoredResult(
             id = UUID(result["id"]),
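
The guard above handles the degenerate case where every hit has the same distance (max_value == min_value), which previously divided by zero. The same min-max normalization in isolation (illustrative values, not LanceDB code):

    def normalize(distances):
        # Min-max normalization; a constant list has zero range,
        # so map everything to 0 instead of dividing by zero.
        min_value, max_value = min(distances), max(distances)
        if max_value == min_value:
            return [0 for _ in distances]
        return [(d - min_value) / (max_value - min_value) for d in distances]

    print(normalize([1.0, 2.0, 3.0]))  # [0.0, 0.5, 1.0]
    print(normalize([0.4, 0.4]))       # [0, 0]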

@@ -43,7 +43,7 @@ def get_graph_from_model(data_point: DataPoint, include_root = True, added_nodes
             added_edges[str(edge_key)] = True
             continue

-        if isinstance(field_value, list) and isinstance(field_value[0], DataPoint):
+        if isinstance(field_value, list) and len(field_value) > 0 and isinstance(field_value[0], DataPoint):
             excluded_properties.add(field_name)

             for item in field_value:
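
Both this hunk and the matching change in get_data_points_from_model below add the same guard: `and` short-circuits left to right, so `field_value[0]` is only evaluated once the length check passes, where the old form raised IndexError on an empty list. A standalone illustration:

    field_value: list = []

    # Old form raised IndexError on an empty list:
    #     isinstance(field_value, list) and isinstance(field_value[0], int)
    # New form short-circuits at len() before indexing:
    if isinstance(field_value, list) and len(field_value) > 0 and isinstance(field_value[0], int):
        print("non-empty list of ints")
    else:
        print("empty or not a list")  # taken here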

cognee/shared/SourceCodeGraph.py

@@ -1,84 +1,95 @@
-from typing import List, Union, Literal, Optional
+from typing import Any, List, Union, Literal, Optional
 from pydantic import BaseModel
+from cognee.infrastructure.engine import DataPoint

-class BaseClass(BaseModel):
-    id: str
-    name: str
-    type: Literal["Class"] = "Class"
-    description: str
-    constructor_parameters: Optional[List[str]] = None
-
-class Class(BaseModel):
-    id: str
-    name: str
-    type: Literal["Class"] = "Class"
-    description: str
-    constructor_parameters: Optional[List[str]] = None
-    from_class: Optional[BaseClass] = None
-
-class ClassInstance(BaseModel):
-    id: str
-    name: str
-    type: Literal["ClassInstance"] = "ClassInstance"
-    description: str
-    from_class: Class
-
-class Function(BaseModel):
-    id: str
-    name: str
-    type: Literal["Function"] = "Function"
-    description: str
-    parameters: Optional[List[str]] = None
-    return_type: str
-    is_static: Optional[bool] = False
-
-class Variable(BaseModel):
+class Variable(DataPoint):
     id: str
     name: str
     type: Literal["Variable"] = "Variable"
     description: str
     is_static: Optional[bool] = False
     default_value: Optional[str] = None
     data_type: str

-class Operator(BaseModel):
+    _metadata = {
+        "index_fields": ["name"]
+    }
+
+class Operator(DataPoint):
     id: str
     name: str
     type: Literal["Operator"] = "Operator"
     description: str
     return_type: str

-class ExpressionPart(BaseModel):
-    id: str
-    name: str
-    description: str
-    expression: str
-    members: List[Union[Variable, Function, Operator]]
-
-class Expression(BaseModel):
-    id: str
-    name: str
-    type: Literal["Expression"] = "Expression"
-    description: str
-    expression: str
-    members: List[Union[Variable, Function, Operator, ExpressionPart]]
+class Class(DataPoint):
+    id: str
+    name: str
+    type: Literal["Class"] = "Class"
+    description: str
+    constructor_parameters: List[Variable]
+    extended_from_class: Optional["Class"] = None
+    has_methods: list["Function"]
+
+    _metadata = {
+        "index_fields": ["name"]
+    }
+
+class ClassInstance(DataPoint):
+    id: str
+    name: str
+    type: Literal["ClassInstance"] = "ClassInstance"
+    description: str
+    from_class: Class
+    instantiated_by: Union["Function"]
+    instantiation_arguments: List[Variable]
+
+    _metadata = {
+        "index_fields": ["name"]
+    }
+
+class Function(DataPoint):
+    id: str
+    name: str
+    type: Literal["Function"] = "Function"
+    description: str
+    parameters: List[Variable]
+    return_type: str
+    is_static: Optional[bool] = False
+
+    _metadata = {
+        "index_fields": ["name"]
+    }
+
+class FunctionCall(DataPoint):
+    id: str
+    type: Literal["FunctionCall"] = "FunctionCall"
+    called_by: Union[Function, Literal["main"]]
+    function_called: Function
+    function_arguments: List[Any]
+
+class Expression(DataPoint):
+    id: str
+    name: str
+    type: Literal["Expression"] = "Expression"
+    description: str
+    expression: str
+    members: List[Union[Variable, Function, Operator, "Expression"]]

 class Edge(BaseModel):
     source_node_id: str
     target_node_id: str
     relationship_name: Literal["called in", "stored in", "defined in", "returned by", "instantiated in", "uses", "updates"]

-class SourceCodeGraph(BaseModel):
+class SourceCodeGraph(DataPoint):
     id: str
     name: str
     description: str
     language: str
     nodes: List[Union[
         Class,
         ClassInstance,
         Function,
+        FunctionCall,
         Variable,
         Operator,
         Expression,
     ]]
     edges: List[Edge]
+
+Class.model_rebuild()
+ClassInstance.model_rebuild()
+Expression.model_rebuild()
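
The model_rebuild() calls added at the end resolve the string forward references ("Class", "Function", "Expression") used by the self-referential models. The same Pydantic pattern in miniature (standalone sketch, not cognee code):

    from typing import List
    from pydantic import BaseModel

    class Node(BaseModel):
        name: str
        # "Node" is a forward reference: the class is still being defined here.
        children: List["Node"] = []

    Node.model_rebuild()  # resolves the forward reference

    tree = Node(name = "root", children = [Node(name = "leaf")])
    print(tree.children[0].name)  # leaf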

@@ -91,7 +91,7 @@ def prepare_edges(graph, source, target, edge_key):
         source: str(edge[0]),
         target: str(edge[1]),
         edge_key: str(edge[2]),
-    } for edge in graph.edges]
+    } for edge in graph.edges(keys = True, data = True)]

     return pd.DataFrame(edge_list)
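
The keys = True, data = True change matters for multigraphs: plain graph.edges yields (source, target) 2-tuples, so edge[2] raised IndexError, while with keys each edge is a (source, target, key, data) 4-tuple. A standalone illustration:

    import networkx as nx

    graph = nx.MultiDiGraph()
    graph.add_edge("a", "b", key = "calls")
    graph.add_edge("a", "b", key = "defines")

    print(list(graph.edges))  # [('a', 'b'), ('a', 'b')]: no keys, edge[2] would fail
    for edge in graph.edges(keys = True, data = True):
        print(edge[0], edge[1], edge[2])  # a b calls / a b defines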

cognee/tasks/graph/__init__.py

@@ -1,2 +1,3 @@
 from .extract_graph_from_data import extract_graph_from_data
+from .extract_graph_from_code import extract_graph_from_code
 from .query_graph_connections import query_graph_connections

cognee/tasks/graph/extract_graph_from_code.py (Normal file, +17)

@@ -0,0 +1,17 @@
import asyncio
from typing import Type
from pydantic import BaseModel
from cognee.modules.data.extraction.knowledge_graph import extract_content_graph
from cognee.modules.chunking.models.DocumentChunk import DocumentChunk
from cognee.tasks.storage import add_data_points

async def extract_graph_from_code(data_chunks: list[DocumentChunk], graph_model: Type[BaseModel]):
    chunk_graphs = await asyncio.gather(
        *[extract_content_graph(chunk.text, graph_model) for chunk in data_chunks]
    )

    for (chunk_index, chunk) in enumerate(data_chunks):
        chunk_graph = chunk_graphs[chunk_index]
        await add_data_points(chunk_graph.nodes)

    return data_chunks

@@ -47,7 +47,7 @@ def get_data_points_from_model(data_point: DataPoint, added_data_points = {}) ->
             added_data_points[str(new_point.id)] = True
             data_points.append(new_point)

-        if isinstance(field_value, list) and isinstance(field_value[0], DataPoint):
+        if isinstance(field_value, list) and len(field_value) > 0 and isinstance(field_value[0], DataPoint):
             for field_value_item in field_value:
                 new_data_points = get_data_points_from_model(field_value_item, added_data_points)

cognee/tests/test_code_generation.py (Executable file, +38)

@@ -0,0 +1,38 @@
import os
import logging
import pathlib
import cognee
from cognee.api.v1.cognify.code_graph_pipeline import code_graph_pipeline
from cognee.api.v1.search import SearchType
from cognee.shared.utils import render_graph

logging.basicConfig(level = logging.DEBUG)

async def main():
    data_directory_path = str(pathlib.Path(os.path.join(pathlib.Path(__file__).parent, ".data_storage/test_code_generation")).resolve())
    cognee.config.data_root_directory(data_directory_path)
    cognee_directory_path = str(pathlib.Path(os.path.join(pathlib.Path(__file__).parent, ".cognee_system/test_code_generation")).resolve())
    cognee.config.system_root_directory(cognee_directory_path)

    await cognee.prune.prune_data()
    await cognee.prune.prune_system(metadata = True)

    dataset_name = "artificial_intelligence"

    ai_text_file_path = os.path.join(pathlib.Path(__file__).parent, "test_data/code.txt")
    await cognee.add([ai_text_file_path], dataset_name)

    await code_graph_pipeline([dataset_name])

    await render_graph(None, include_nodes = True, include_labels = True)

    search_results = await cognee.search(SearchType.CHUNKS, query = "Student")
    assert len(search_results) != 0, "The search results list is empty."
    print("\n\nExtracted chunks are:\n")
    for result in search_results:
        print(f"{result}\n")


if __name__ == "__main__":
    import asyncio
    asyncio.run(main(), debug=True)

cognee/tests/test_data/code.txt (Normal file, +70)

@@ -0,0 +1,70 @@
// Class definition for a Person
class Person {
    constructor(name, age) {
        this.name = name;
        this.age = age;
    }

    // Method to return a greeting message
    greet() {
        return `Hello, my name is ${this.name} and I'm ${this.age} years old.`;
    }

    // Method to celebrate birthday
    celebrateBirthday() {
        this.age += 1;
        return `Happy Birthday, ${this.name}! You are now ${this.age} years old.`;
    }
}

// Class definition for a Student, extending from Person
class Student extends Person {
    constructor(name, age, grade) {
        super(name, age);
        this.grade = grade;
    }

    // Method to describe the student
    describe() {
        return `${this.name} is a ${this.grade} grade student and is ${this.age} years old.`;
    }
}

// Function to enroll a new student
function enrollStudent(name, age, grade) {
    const student = new Student(name, age, grade);
    console.log(student.greet());
    console.log(student.describe());
    return student;
}

// Function to promote a student to the next grade
function promoteStudent(student) {
    student.grade += 1;
    console.log(`${student.name} has been promoted to grade ${student.grade}.`);
    return student;
}

// Variable definition and assignment
let schoolName = "Greenwood High School";
let students = [];

// Enrolling students
students.push(enrollStudent("Alice", 14, 9));
students.push(enrollStudent("Bob", 15, 10));

// Looping through students to celebrate their birthdays
students.forEach(student => {
    console.log(student.celebrateBirthday());
});

// Promoting all students
students = students.map(promoteStudent);

// Displaying the final state of all students
console.log("Final Students List:");
students.forEach(student => console.log(student.describe()));

// Updating the school name
schoolName = "Greenwood International School";
console.log(`School Name Updated to: ${schoolName}`);