diff --git a/cognee/modules/data/extraction/knowledge_graph/extract_content_graph_sequential.py b/cognee/modules/data/extraction/knowledge_graph/extract_content_graph_sequential.py index 42c242490..092a239ac 100644 --- a/cognee/modules/data/extraction/knowledge_graph/extract_content_graph_sequential.py +++ b/cognee/modules/data/extraction/knowledge_graph/extract_content_graph_sequential.py @@ -27,7 +27,7 @@ async def extract_content_graph_sequential( edges_json = json.dumps([e.model_dump() for e in current_edges], ensure_ascii=False) graph_user = render_prompt( - graph_user_prompt_path, + graph_user_prompt_path, #:TODO: this could use some formatting due to html #34 codes. { "text": content, "graph": f"nodes: {nodes_json}, edges: {edges_json}", diff --git a/cognee/modules/data/extraction/knowledge_graph/extract_content_node_edge_multi_sequential.py b/cognee/modules/data/extraction/knowledge_graph/extract_content_node_edge_multi_sequential.py new file mode 100644 index 000000000..4336c6b65 --- /dev/null +++ b/cognee/modules/data/extraction/knowledge_graph/extract_content_node_edge_multi_sequential.py @@ -0,0 +1,57 @@ +import json + +from cognee.infrastructure.llm.get_llm_client import get_llm_client +from cognee.infrastructure.llm.prompts import render_prompt +from cognee.shared.data_models import KnowledgeGraph, NodeList, EdgeList + + +async def extract_content_node_edge_multi_sequential( + content: str, node_rounds: int = 2, edge_rounds=2 +): + llm_client = get_llm_client() + + current_nodes = NodeList() + + for pass_idx in range(node_rounds): + nodes_json = json.dumps([n.model_dump() for n in current_nodes.nodes], ensure_ascii=False) + + node_system = render_prompt("node_extraction_prompt_sequential.txt", {}) + node_user = render_prompt( + "node_extraction_prompt_sequential_user.txt", + { + "text": content, + "nodes": {nodes_json}, + "total_rounds": {node_rounds}, + "round_number": {pass_idx}, + }, + ) + + current_nodes = await llm_client.acreate_structured_output(node_user, node_system, NodeList) + + final_nodes = current_nodes + final_nodes_json = json.dumps([n.model_dump() for n in final_nodes.nodes], ensure_ascii=False) + + current_edges = EdgeList() + + for pass_idx in range(edge_rounds): + edges_json = json.dumps([n.model_dump() for n in current_edges.edges], ensure_ascii=False) + + edges_system = render_prompt("edge_extraction_prompt_sequential.txt", {}) + edges_user = render_prompt( + "edge_extraction_prompt_sequential_user.txt", + { + "text": content, + "nodes": {final_nodes_json}, + "edges": {edges_json}, + "total_rounds": {node_rounds}, + "round_number": {pass_idx}, + }, + ) + + current_edges = await llm_client.acreate_structured_output( + edges_user, edges_system, EdgeList + ) + + final_edges = current_edges + + return KnowledgeGraph(nodes=final_nodes.nodes, edges=final_edges.edges) diff --git a/cognee/tasks/graph/extract_graph_from_data.py b/cognee/tasks/graph/extract_graph_from_data.py index 58b2502a2..252467e0c 100644 --- a/cognee/tasks/graph/extract_graph_from_data.py +++ b/cognee/tasks/graph/extract_graph_from_data.py @@ -10,13 +10,18 @@ from cognee.modules.chunking.models.DocumentChunk import DocumentChunk from cognee.modules.data.extraction.knowledge_graph.extract_content_graph import ( extract_content_graph, ) -from cognee.modules.data.extraction.knowledge_graph.extract_content_graph_multi_parallel import ( - extract_content_graph_multi_parallel, +from cognee.modules.data.extraction.knowledge_graph.extract_content_node_edge_multi_parallel import ( + extract_content_node_edge_multi_parallel, ) from cognee.modules.data.extraction.knowledge_graph.extract_content_graph_sequential import ( extract_content_graph_sequential, ) + +from cognee.modules.data.extraction.knowledge_graph.extract_content_node_edge_multi_sequential import ( + extract_content_node_edge_multi_sequential, +) + from cognee.modules.graph.utils import ( expand_with_nodes_and_edges, retrieve_existing_edges, @@ -70,11 +75,16 @@ async def extract_graph_from_data( """ chunk_graphs = await asyncio.gather( # *[extract_content_graph(chunk.text, graph_model) for chunk in data_chunks] - # *[extract_content_graph_multi_parallel(chunk.text, graph_model) for chunk in data_chunks] - *[extract_content_graph_sequential(chunk.text, graph_model) for chunk in data_chunks] + # *[extract_content_node_edge_multi_parallel(content=chunk.text, node_rounds=2) for chunk in data_chunks] + # *[extract_content_graph_sequential(content=chunk.text, response_model=graph_model, graph_extraction_rounds=2) for chunk in data_chunks] + *[ + extract_content_node_edge_multi_sequential( + content=chunk.text, node_rounds=1, edge_rounds=1 + ) + for chunk in data_chunks + ] ) - # Note: Filter edges with missing source or target nodes if graph_model == KnowledgeGraph: for graph in chunk_graphs: valid_node_ids = {node.id for node in graph.nodes}