fix: add summaries to the graph

This commit is contained in:
Boris Arzentar 2024-11-07 15:38:03 +01:00
parent 897bbac699
commit f569088a2e
16 changed files with 127 additions and 58 deletions

View file

@ -338,7 +338,7 @@ class Neo4jAdapter(GraphDBInterface):
return predecessors + successors return predecessors + successors
async def get_connections(self, node_id: str) -> list: async def get_connections(self, node_id: UUID) -> list:
predecessors_query = """ predecessors_query = """
MATCH (node)<-[relation]-(neighbour) MATCH (node)<-[relation]-(neighbour)
WHERE node.id = $node_id WHERE node.id = $node_id

View file

@ -7,6 +7,7 @@ import asyncio
import logging import logging
from re import A from re import A
from typing import Dict, Any, List from typing import Dict, Any, List
from uuid import UUID
import aiofiles import aiofiles
import aiofiles.os as aiofiles_os import aiofiles.os as aiofiles_os
import networkx as nx import networkx as nx
@ -130,7 +131,7 @@ class NetworkXAdapter(GraphDBInterface):
async def extract_nodes(self, node_ids: List[str]) -> List[dict]: async def extract_nodes(self, node_ids: List[str]) -> List[dict]:
return [self.graph.nodes[node_id] for node_id in node_ids if self.graph.has_node(node_id)] return [self.graph.nodes[node_id] for node_id in node_ids if self.graph.has_node(node_id)]
async def get_predecessors(self, node_id: str, edge_label: str = None) -> list: async def get_predecessors(self, node_id: UUID, edge_label: str = None) -> list:
if self.graph.has_node(node_id): if self.graph.has_node(node_id):
if edge_label is None: if edge_label is None:
return [ return [
@ -146,7 +147,7 @@ class NetworkXAdapter(GraphDBInterface):
return nodes return nodes
async def get_successors(self, node_id: str, edge_label: str = None) -> list: async def get_successors(self, node_id: UUID, edge_label: str = None) -> list:
if self.graph.has_node(node_id): if self.graph.has_node(node_id):
if edge_label is None: if edge_label is None:
return [ return [
@ -175,13 +176,13 @@ class NetworkXAdapter(GraphDBInterface):
return neighbours return neighbours
async def get_connections(self, node_id: str) -> list: async def get_connections(self, node_id: UUID) -> list:
if not self.graph.has_node(node_id): if not self.graph.has_node(node_id):
return [] return []
node = self.graph.nodes[node_id] node = self.graph.nodes[node_id]
if "uuid" not in node: if "id" not in node:
return [] return []
predecessors, successors = await asyncio.gather( predecessors, successors = await asyncio.gather(
@ -192,14 +193,14 @@ class NetworkXAdapter(GraphDBInterface):
connections = [] connections = []
for neighbor in predecessors: for neighbor in predecessors:
if "uuid" in neighbor: if "id" in neighbor:
edge_data = self.graph.get_edge_data(neighbor["uuid"], node["uuid"]) edge_data = self.graph.get_edge_data(neighbor["id"], node["id"])
for edge_properties in edge_data.values(): for edge_properties in edge_data.values():
connections.append((neighbor, edge_properties, node)) connections.append((neighbor, edge_properties, node))
for neighbor in successors: for neighbor in successors:
if "uuid" in neighbor: if "id" in neighbor:
edge_data = self.graph.get_edge_data(node["uuid"], neighbor["uuid"]) edge_data = self.graph.get_edge_data(node["id"], neighbor["id"])
for edge_properties in edge_data.values(): for edge_properties in edge_data.values():
connections.append((node, edge_properties, neighbor)) connections.append((node, edge_properties, neighbor))
@ -245,6 +246,17 @@ class NetworkXAdapter(GraphDBInterface):
if os.path.exists(file_path): if os.path.exists(file_path):
async with aiofiles.open(file_path, "r") as file: async with aiofiles.open(file_path, "r") as file:
graph_data = json.loads(await file.read()) graph_data = json.loads(await file.read())
for node in graph_data["nodes"]:
node["id"] = UUID(node["id"])
node["updated_at"] = datetime.strptime(node["updated_at"], "%Y-%m-%dT%H:%M:%S.%f%z")
for edge in graph_data["links"]:
edge["source"] = UUID(edge["source"])
edge["target"] = UUID(edge["target"])
edge["source_node_id"] = UUID(edge["source_node_id"])
edge["target_node_id"] = UUID(edge["target_node_id"])
edge["updated_at"] = datetime.strptime(edge["updated_at"], "%Y-%m-%dT%H:%M:%S.%f%z")
self.graph = nx.readwrite.json_graph.node_link_graph(graph_data) self.graph = nx.readwrite.json_graph.node_link_graph(graph_data)
else: else:
# Log that the file does not exist and an empty graph is initialized # Log that the file does not exist and an empty graph is initialized

View file

@ -1,6 +1,7 @@
import asyncio import asyncio
from textwrap import dedent from textwrap import dedent
from typing import Any from typing import Any
from uuid import UUID
from falkordb import FalkorDB from falkordb import FalkorDB
from cognee.infrastructure.engine import DataPoint from cognee.infrastructure.engine import DataPoint
@ -161,6 +162,35 @@ class FalkorDBAdapter(VectorDBInterface, GraphDBInterface):
async def extract_nodes(self, data_point_ids: list[str]): async def extract_nodes(self, data_point_ids: list[str]):
return await self.retrieve(data_point_ids) return await self.retrieve(data_point_ids)
async def get_connections(self, node_id: UUID) -> list:
predecessors_query = """
MATCH (node)<-[relation]-(neighbour)
WHERE node.id = $node_id
RETURN neighbour, relation, node
"""
successors_query = """
MATCH (node)-[relation]->(neighbour)
WHERE node.id = $node_id
RETURN node, relation, neighbour
"""
predecessors, successors = await asyncio.gather(
self.query(predecessors_query, dict(node_id = node_id)),
self.query(successors_query, dict(node_id = node_id)),
)
connections = []
for neighbour in predecessors:
neighbour = neighbour["relation"]
connections.append((neighbour[0], { "relationship_name": neighbour[1] }, neighbour[2]))
for neighbour in successors:
neighbour = neighbour["relation"]
connections.append((neighbour[0], { "relationship_name": neighbour[1] }, neighbour[2]))
return connections
async def search( async def search(
self, self,
collection_name: str, collection_name: str,

View file

@ -168,7 +168,7 @@ class WeaviateAdapter(VectorDBInterface):
return [ return [
ScoredResult( ScoredResult(
id = UUID(result.id), id = UUID(result.uuid),
payload = result.properties, payload = result.properties,
score = float(result.metadata.score) score = float(result.metadata.score)
) for result in search_result.objects ) for result in search_result.objects

View file

@ -29,7 +29,7 @@ class TextChunker():
else: else:
if len(self.paragraph_chunks) == 0: if len(self.paragraph_chunks) == 0:
yield DocumentChunk( yield DocumentChunk(
id = str(chunk_data["chunk_id"]), id = chunk_data["chunk_id"],
text = chunk_data["text"], text = chunk_data["text"],
word_count = chunk_data["word_count"], word_count = chunk_data["word_count"],
is_part_of = self.document, is_part_of = self.document,
@ -42,7 +42,7 @@ class TextChunker():
chunk_text = " ".join(chunk["text"] for chunk in self.paragraph_chunks) chunk_text = " ".join(chunk["text"] for chunk in self.paragraph_chunks)
try: try:
yield DocumentChunk( yield DocumentChunk(
id = str(uuid5(NAMESPACE_OID, f"{str(self.document.id)}-{self.chunk_index}")), id = uuid5(NAMESPACE_OID, f"{str(self.document.id)}-{self.chunk_index}"),
text = chunk_text, text = chunk_text,
word_count = self.chunk_size, word_count = self.chunk_size,
is_part_of = self.document, is_part_of = self.document,
@ -59,7 +59,7 @@ class TextChunker():
if len(self.paragraph_chunks) > 0: if len(self.paragraph_chunks) > 0:
try: try:
yield DocumentChunk( yield DocumentChunk(
id = str(uuid5(NAMESPACE_OID, f"{str(self.document.id)}-{self.chunk_index}")), id = uuid5(NAMESPACE_OID, f"{str(self.document.id)}-{self.chunk_index}"),
text = " ".join(chunk["text"] for chunk in self.paragraph_chunks), text = " ".join(chunk["text"] for chunk in self.paragraph_chunks),
word_count = self.chunk_size, word_count = self.chunk_size,
is_part_of = self.document, is_part_of = self.document,

View file

@ -1,9 +1,8 @@
from datetime import datetime, timezone from datetime import datetime, timezone
from cognee.infrastructure.engine import DataPoint from cognee.infrastructure.engine import DataPoint
from cognee.modules import data
from cognee.modules.storage.utils import copy_model from cognee.modules.storage.utils import copy_model
def get_graph_from_model(data_point: DataPoint, include_root = True): def get_graph_from_model(data_point: DataPoint, include_root = True, added_nodes = {}, added_edges = {}):
nodes = [] nodes = []
edges = [] edges = []
@ -17,29 +16,55 @@ def get_graph_from_model(data_point: DataPoint, include_root = True):
if isinstance(field_value, DataPoint): if isinstance(field_value, DataPoint):
excluded_properties.add(field_name) excluded_properties.add(field_name)
property_nodes, property_edges = get_graph_from_model(field_value, True) property_nodes, property_edges = get_graph_from_model(field_value, True, added_nodes, added_edges)
nodes[:0] = property_nodes
edges[:0] = property_edges for node in property_nodes:
if str(node.id) not in added_nodes:
nodes.append(node)
added_nodes[str(node.id)] = True
for edge in property_edges:
edge_key = str(edge[0]) + str(edge[1]) + edge[2]
if str(edge_key) not in added_edges:
edges.append(edge)
added_edges[str(edge_key)] = True
for property_node in get_own_properties(property_nodes, property_edges): for property_node in get_own_properties(property_nodes, property_edges):
edges.append((data_point.id, property_node.id, field_name, { edge_key = str(data_point.id) + str(property_node.id) + field_name
"source_node_id": data_point.id,
"target_node_id": property_node.id, if str(edge_key) not in added_edges:
"relationship_name": field_name, edges.append((data_point.id, property_node.id, field_name, {
"updated_at": datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S"), "source_node_id": data_point.id,
})) "target_node_id": property_node.id,
"relationship_name": field_name,
"updated_at": datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S"),
}))
added_edges[str(edge_key)] = True
continue continue
if isinstance(field_value, list): if isinstance(field_value, list) and isinstance(field_value[0], DataPoint):
if isinstance(field_value[0], DataPoint): excluded_properties.add(field_name)
excluded_properties.add(field_name)
for item in field_value: for item in field_value:
property_nodes, property_edges = get_graph_from_model(item, True) property_nodes, property_edges = get_graph_from_model(item, True, added_nodes, added_edges)
nodes[:0] = property_nodes
edges[:0] = property_edges
for property_node in get_own_properties(property_nodes, property_edges): for node in property_nodes:
if str(node.id) not in added_nodes:
nodes.append(node)
added_nodes[str(node.id)] = True
for edge in property_edges:
edge_key = str(edge[0]) + str(edge[1]) + edge[2]
if str(edge_key) not in added_edges:
edges.append(edge)
added_edges[edge_key] = True
for property_node in get_own_properties(property_nodes, property_edges):
edge_key = str(data_point.id) + str(property_node.id) + field_name
if str(edge_key) not in added_edges:
edges.append((data_point.id, property_node.id, field_name, { edges.append((data_point.id, property_node.id, field_name, {
"source_node_id": data_point.id, "source_node_id": data_point.id,
"target_node_id": property_node.id, "target_node_id": property_node.id,
@ -49,7 +74,8 @@ def get_graph_from_model(data_point: DataPoint, include_root = True):
"type": "list" "type": "list"
}, },
})) }))
continue added_edges[edge_key] = True
continue
data_point_properties[field_name] = field_value data_point_properties[field_name] = field_value

View file

@ -7,7 +7,7 @@ from ..tasks.Task import Task
logger = logging.getLogger("run_tasks(tasks: [Task], data)") logger = logging.getLogger("run_tasks(tasks: [Task], data)")
async def run_tasks_base(tasks: [Task], data = None, user: User = None): async def run_tasks_base(tasks: list[Task], data = None, user: User = None):
if len(tasks) == 0: if len(tasks) == 0:
yield data yield data
return return
@ -16,7 +16,7 @@ async def run_tasks_base(tasks: [Task], data = None, user: User = None):
running_task = tasks[0] running_task = tasks[0]
leftover_tasks = tasks[1:] leftover_tasks = tasks[1:]
next_task = leftover_tasks[0] if len(leftover_tasks) > 1 else None next_task = leftover_tasks[0] if len(leftover_tasks) > 0 else None
next_task_batch_size = next_task.task_config["batch_size"] if next_task else 1 next_task_batch_size = next_task.task_config["batch_size"] if next_task else 1
if inspect.isasyncgenfunction(running_task.executable): if inspect.isasyncgenfunction(running_task.executable):

View file

@ -22,13 +22,13 @@ async def query_graph_connections(query: str, exploration_levels = 1) -> list[(s
exact_node = await graph_engine.extract_node(node_id) exact_node = await graph_engine.extract_node(node_id)
if exact_node is not None and "uuid" in exact_node: if exact_node is not None and "id" in exact_node:
node_connections = await graph_engine.get_connections(str(exact_node["uuid"])) node_connections = await graph_engine.get_connections(str(exact_node["id"]))
else: else:
vector_engine = get_vector_engine() vector_engine = get_vector_engine()
results = await asyncio.gather( results = await asyncio.gather(
vector_engine.search("Entity_text", query_text = query, limit = 5), vector_engine.search("Entity_name", query_text = query, limit = 5),
vector_engine.search("EntityType_text", query_text = query, limit = 5), vector_engine.search("EntityType_name", query_text = query, limit = 5),
) )
results = [*results[0], *results[1]] results = [*results[0], *results[1]]
relevant_results = [result for result in results if result.score < 0.5][:5] relevant_results = [result for result in results if result.score < 0.5][:5]
@ -37,7 +37,7 @@ async def query_graph_connections(query: str, exploration_levels = 1) -> list[(s
return [] return []
node_connections_results = await asyncio.gather( node_connections_results = await asyncio.gather(
*[graph_engine.get_connections(str(result.payload["uuid"])) for result in relevant_results] *[graph_engine.get_connections(result.id) for result in relevant_results]
) )
node_connections = [] node_connections = []
@ -48,10 +48,10 @@ async def query_graph_connections(query: str, exploration_levels = 1) -> list[(s
unique_node_connections_map = {} unique_node_connections_map = {}
unique_node_connections = [] unique_node_connections = []
for node_connection in node_connections: for node_connection in node_connections:
if "uuid" not in node_connection[0] or "uuid" not in node_connection[2]: if "id" not in node_connection[0] or "id" not in node_connection[2]:
continue continue
unique_id = f"{node_connection[0]['uuid']} {node_connection[1]['relationship_name']} {node_connection[2]['uuid']}" unique_id = f"{node_connection[0]['id']} {node_connection[1]['relationship_name']} {node_connection[2]['id']}"
if unique_id not in unique_node_connections_map: if unique_id not in unique_node_connections_map:
unique_node_connections_map[unique_id] = True unique_node_connections_map[unique_id] = True

View file

@ -56,7 +56,8 @@ def get_data_points_from_model(data_point: DataPoint, added_data_points = {}) ->
added_data_points[str(new_point.id)] = True added_data_points[str(new_point.id)] = True
data_points.append(new_point) data_points.append(new_point)
data_points.append(data_point) if (str(data_point.id) not in added_data_points):
data_points.append(data_point)
return data_points return data_points

View file

@ -4,9 +4,8 @@ from cognee.modules.data.processing.document_types import Document
class TextSummary(DataPoint): class TextSummary(DataPoint):
text: str text: str
chunk: DocumentChunk made_from: DocumentChunk
_metadata: dict = { _metadata: dict = {
"index_fields": ["text"], "index_fields": ["text"],
} }

View file

@ -5,6 +5,7 @@ from pydantic import BaseModel
from cognee.modules.data.extraction.extract_summary import extract_summary from cognee.modules.data.extraction.extract_summary import extract_summary
from cognee.modules.chunking.models.DocumentChunk import DocumentChunk from cognee.modules.chunking.models.DocumentChunk import DocumentChunk
from cognee.tasks.storage import add_data_points from cognee.tasks.storage import add_data_points
from cognee.tasks.storage.index_data_points import get_data_points_from_model
from .models.TextSummary import TextSummary from .models.TextSummary import TextSummary
async def summarize_text(data_chunks: list[DocumentChunk], summarization_model: Type[BaseModel]): async def summarize_text(data_chunks: list[DocumentChunk], summarization_model: Type[BaseModel]):
@ -17,12 +18,12 @@ async def summarize_text(data_chunks: list[DocumentChunk], summarization_model:
summaries = [ summaries = [
TextSummary( TextSummary(
id = uuid5(chunk.id, "summary"), id = uuid5(chunk.id, "TextSummary"),
chunk = chunk, made_from = chunk,
text = chunk_summaries[chunk_index].summary, text = chunk_summaries[chunk_index].summary,
) for (chunk_index, chunk) in enumerate(data_chunks) ) for (chunk_index, chunk) in enumerate(data_chunks)
] ]
add_data_points(summaries) await add_data_points(summaries)
return data_chunks return data_chunks

View file

@ -32,8 +32,8 @@ async def main():
from cognee.infrastructure.databases.vector import get_vector_engine from cognee.infrastructure.databases.vector import get_vector_engine
vector_engine = get_vector_engine() vector_engine = get_vector_engine()
random_node = (await vector_engine.search("Entity", "AI"))[0] random_node = (await vector_engine.search("Entity_name", "AI"))[0]
random_node_name = random_node.payload["name"] random_node_name = random_node.payload["text"]
search_results = await cognee.search(SearchType.INSIGHTS, query = random_node_name) search_results = await cognee.search(SearchType.INSIGHTS, query = random_node_name)
assert len(search_results) != 0, "The search results list is empty." assert len(search_results) != 0, "The search results list is empty."

View file

@ -36,8 +36,8 @@ async def main():
from cognee.infrastructure.databases.vector import get_vector_engine from cognee.infrastructure.databases.vector import get_vector_engine
vector_engine = get_vector_engine() vector_engine = get_vector_engine()
random_node = (await vector_engine.search("Entity", "AI"))[0] random_node = (await vector_engine.search("Entity_name", "AI"))[0]
random_node_name = random_node.payload["name"] random_node_name = random_node.payload["text"]
search_results = await cognee.search(SearchType.INSIGHTS, query = random_node_name) search_results = await cognee.search(SearchType.INSIGHTS, query = random_node_name)
assert len(search_results) != 0, "The search results list is empty." assert len(search_results) != 0, "The search results list is empty."

View file

@ -65,8 +65,8 @@ async def main():
from cognee.infrastructure.databases.vector import get_vector_engine from cognee.infrastructure.databases.vector import get_vector_engine
vector_engine = get_vector_engine() vector_engine = get_vector_engine()
random_node = (await vector_engine.search("Entity", "AI"))[0] random_node = (await vector_engine.search("Entity_name", "AI"))[0]
random_node_name = random_node.payload["name"] random_node_name = random_node.payload["text"]
search_results = await cognee.search(SearchType.INSIGHTS, query=random_node_name) search_results = await cognee.search(SearchType.INSIGHTS, query=random_node_name)
assert len(search_results) != 0, "The search results list is empty." assert len(search_results) != 0, "The search results list is empty."

View file

@ -37,8 +37,8 @@ async def main():
from cognee.infrastructure.databases.vector import get_vector_engine from cognee.infrastructure.databases.vector import get_vector_engine
vector_engine = get_vector_engine() vector_engine = get_vector_engine()
random_node = (await vector_engine.search("Entity", "AI"))[0] random_node = (await vector_engine.search("Entity_name", "AI"))[0]
random_node_name = random_node.payload["name"] random_node_name = random_node.payload["text"]
search_results = await cognee.search(SearchType.INSIGHTS, query = random_node_name) search_results = await cognee.search(SearchType.INSIGHTS, query = random_node_name)
assert len(search_results) != 0, "The search results list is empty." assert len(search_results) != 0, "The search results list is empty."

View file

@ -35,8 +35,8 @@ async def main():
from cognee.infrastructure.databases.vector import get_vector_engine from cognee.infrastructure.databases.vector import get_vector_engine
vector_engine = get_vector_engine() vector_engine = get_vector_engine()
random_node = (await vector_engine.search("Entity", "AI"))[0] random_node = (await vector_engine.search("Entity_name", "AI"))[0]
random_node_name = random_node.payload["name"] random_node_name = random_node.payload["text"]
search_results = await cognee.search(SearchType.INSIGHTS, query = random_node_name) search_results = await cognee.search(SearchType.INSIGHTS, query = random_node_name)
assert len(search_results) != 0, "The search results list is empty." assert len(search_results) != 0, "The search results list is empty."