From 7ea5f638fe0478bf34f528481aec23f033ecc978 Mon Sep 17 00:00:00 2001 From: Boris Arzentar Date: Thu, 7 Nov 2024 15:38:03 +0100 Subject: [PATCH] fix: add summaries to the graph --- .../databases/graph/networkx/adapter.py | 26 ++------ .../hybrid/falkordb/FalkorDBAdapter.py | 30 +++++++++ .../vector/weaviate_db/WeaviateAdapter.py | 2 +- cognee/modules/chunking/TextChunker.py | 6 +- .../graph/utils/get_graph_from_model.py | 66 +++++++++++++------ cognee/tasks/graph/query_graph_connections.py | 4 +- cognee/tasks/storage/index_data_points.py | 3 +- .../tasks/summarization/models/TextSummary.py | 3 +- cognee/tasks/summarization/summarize_text.py | 7 +- cognee/tests/test_library.py | 4 +- cognee/tests/test_neo4j.py | 4 +- cognee/tests/test_pgvector.py | 4 +- cognee/tests/test_qdrant.py | 4 +- cognee/tests/test_weaviate.py | 4 +- 14 files changed, 106 insertions(+), 61 deletions(-) diff --git a/cognee/infrastructure/databases/graph/networkx/adapter.py b/cognee/infrastructure/databases/graph/networkx/adapter.py index b0a9e7a13..dcb05c2ed 100644 --- a/cognee/infrastructure/databases/graph/networkx/adapter.py +++ b/cognee/infrastructure/databases/graph/networkx/adapter.py @@ -247,27 +247,15 @@ class NetworkXAdapter(GraphDBInterface): async with aiofiles.open(file_path, "r") as file: graph_data = json.loads(await file.read()) for node in graph_data["nodes"]: - try: - node["id"] = UUID(node["id"]) - except: - pass - if "updated_at" in node: - node["updated_at"] = datetime.strptime(node["updated_at"], "%Y-%m-%dT%H:%M:%S.%f%z") + node["id"] = UUID(node["id"]) + node["updated_at"] = datetime.strptime(node["updated_at"], "%Y-%m-%dT%H:%M:%S.%f%z") for edge in graph_data["links"]: - try: - source_id = UUID(edge["source"]) - target_id = UUID(edge["target"]) - - edge["source"] = source_id - edge["target"] = target_id - edge["source_node_id"] = source_id - edge["target_node_id"] = target_id - except: - pass - - if "updated_at" in edge: - edge["updated_at"] = datetime.strptime(edge["updated_at"], "%Y-%m-%dT%H:%M:%S.%f%z") + edge["source"] = UUID(edge["source"]) + edge["target"] = UUID(edge["target"]) + edge["source_node_id"] = UUID(edge["source_node_id"]) + edge["target_node_id"] = UUID(edge["target_node_id"]) + edge["updated_at"] = datetime.strptime(edge["updated_at"], "%Y-%m-%dT%H:%M:%S.%f%z") self.graph = nx.readwrite.json_graph.node_link_graph(graph_data) else: diff --git a/cognee/infrastructure/databases/hybrid/falkordb/FalkorDBAdapter.py b/cognee/infrastructure/databases/hybrid/falkordb/FalkorDBAdapter.py index effe9e682..ea5a75088 100644 --- a/cognee/infrastructure/databases/hybrid/falkordb/FalkorDBAdapter.py +++ b/cognee/infrastructure/databases/hybrid/falkordb/FalkorDBAdapter.py @@ -1,6 +1,7 @@ import asyncio from textwrap import dedent from typing import Any +from uuid import UUID from falkordb import FalkorDB from cognee.infrastructure.engine import DataPoint @@ -161,6 +162,35 @@ class FalkorDBAdapter(VectorDBInterface, GraphDBInterface): async def extract_nodes(self, data_point_ids: list[str]): return await self.retrieve(data_point_ids) + async def get_connections(self, node_id: UUID) -> list: + predecessors_query = """ + MATCH (node)<-[relation]-(neighbour) + WHERE node.id = $node_id + RETURN neighbour, relation, node + """ + successors_query = """ + MATCH (node)-[relation]->(neighbour) + WHERE node.id = $node_id + RETURN node, relation, neighbour + """ + + predecessors, successors = await asyncio.gather( + self.query(predecessors_query, dict(node_id = node_id)), + self.query(successors_query, dict(node_id = node_id)), + ) + + connections = [] + + for neighbour in predecessors: + neighbour = neighbour["relation"] + connections.append((neighbour[0], { "relationship_name": neighbour[1] }, neighbour[2])) + + for neighbour in successors: + neighbour = neighbour["relation"] + connections.append((neighbour[0], { "relationship_name": neighbour[1] }, neighbour[2])) + + return connections + async def search( self, collection_name: str, diff --git a/cognee/infrastructure/databases/vector/weaviate_db/WeaviateAdapter.py b/cognee/infrastructure/databases/vector/weaviate_db/WeaviateAdapter.py index 2e4e88323..dd7539118 100644 --- a/cognee/infrastructure/databases/vector/weaviate_db/WeaviateAdapter.py +++ b/cognee/infrastructure/databases/vector/weaviate_db/WeaviateAdapter.py @@ -178,7 +178,7 @@ class WeaviateAdapter(VectorDBInterface): return [ ScoredResult( - id = UUID(result.id), + id = UUID(result.uuid), payload = result.properties, score = float(result.metadata.score) ) for result in search_result.objects diff --git a/cognee/modules/chunking/TextChunker.py b/cognee/modules/chunking/TextChunker.py index 4717d108d..714383804 100644 --- a/cognee/modules/chunking/TextChunker.py +++ b/cognee/modules/chunking/TextChunker.py @@ -29,7 +29,7 @@ class TextChunker(): else: if len(self.paragraph_chunks) == 0: yield DocumentChunk( - id = str(chunk_data["chunk_id"]), + id = chunk_data["chunk_id"], text = chunk_data["text"], word_count = chunk_data["word_count"], is_part_of = self.document, @@ -42,7 +42,7 @@ class TextChunker(): chunk_text = " ".join(chunk["text"] for chunk in self.paragraph_chunks) try: yield DocumentChunk( - id = str(uuid5(NAMESPACE_OID, f"{str(self.document.id)}-{self.chunk_index}")), + id = uuid5(NAMESPACE_OID, f"{str(self.document.id)}-{self.chunk_index}"), text = chunk_text, word_count = self.chunk_size, is_part_of = self.document, @@ -59,7 +59,7 @@ class TextChunker(): if len(self.paragraph_chunks) > 0: try: yield DocumentChunk( - id = str(uuid5(NAMESPACE_OID, f"{str(self.document.id)}-{self.chunk_index}")), + id = uuid5(NAMESPACE_OID, f"{str(self.document.id)}-{self.chunk_index}"), text = " ".join(chunk["text"] for chunk in self.paragraph_chunks), word_count = self.chunk_size, is_part_of = self.document, diff --git a/cognee/modules/graph/utils/get_graph_from_model.py b/cognee/modules/graph/utils/get_graph_from_model.py index ef402e4d6..35e00fb5d 100644 --- a/cognee/modules/graph/utils/get_graph_from_model.py +++ b/cognee/modules/graph/utils/get_graph_from_model.py @@ -1,9 +1,8 @@ from datetime import datetime, timezone from cognee.infrastructure.engine import DataPoint -from cognee.modules import data from cognee.modules.storage.utils import copy_model -def get_graph_from_model(data_point: DataPoint, include_root = True): +def get_graph_from_model(data_point: DataPoint, include_root = True, added_nodes = {}, added_edges = {}): nodes = [] edges = [] @@ -17,29 +16,55 @@ def get_graph_from_model(data_point: DataPoint, include_root = True): if isinstance(field_value, DataPoint): excluded_properties.add(field_name) - property_nodes, property_edges = get_graph_from_model(field_value, True) - nodes[:0] = property_nodes - edges[:0] = property_edges + property_nodes, property_edges = get_graph_from_model(field_value, True, added_nodes, added_edges) + + for node in property_nodes: + if str(node.id) not in added_nodes: + nodes.append(node) + added_nodes[str(node.id)] = True + + for edge in property_edges: + edge_key = str(edge[0]) + str(edge[1]) + edge[2] + + if str(edge_key) not in added_edges: + edges.append(edge) + added_edges[str(edge_key)] = True for property_node in get_own_properties(property_nodes, property_edges): - edges.append((data_point.id, property_node.id, field_name, { - "source_node_id": data_point.id, - "target_node_id": property_node.id, - "relationship_name": field_name, - "updated_at": datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S"), - })) + edge_key = str(data_point.id) + str(property_node.id) + field_name + + if str(edge_key) not in added_edges: + edges.append((data_point.id, property_node.id, field_name, { + "source_node_id": data_point.id, + "target_node_id": property_node.id, + "relationship_name": field_name, + "updated_at": datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S"), + })) + added_edges[str(edge_key)] = True continue - if isinstance(field_value, list): - if isinstance(field_value[0], DataPoint): - excluded_properties.add(field_name) + if isinstance(field_value, list) and isinstance(field_value[0], DataPoint): + excluded_properties.add(field_name) - for item in field_value: - property_nodes, property_edges = get_graph_from_model(item, True) - nodes[:0] = property_nodes - edges[:0] = property_edges + for item in field_value: + property_nodes, property_edges = get_graph_from_model(item, True, added_nodes, added_edges) - for property_node in get_own_properties(property_nodes, property_edges): + for node in property_nodes: + if str(node.id) not in added_nodes: + nodes.append(node) + added_nodes[str(node.id)] = True + + for edge in property_edges: + edge_key = str(edge[0]) + str(edge[1]) + edge[2] + + if str(edge_key) not in added_edges: + edges.append(edge) + added_edges[edge_key] = True + + for property_node in get_own_properties(property_nodes, property_edges): + edge_key = str(data_point.id) + str(property_node.id) + field_name + + if str(edge_key) not in added_edges: edges.append((data_point.id, property_node.id, field_name, { "source_node_id": data_point.id, "target_node_id": property_node.id, @@ -49,7 +74,8 @@ def get_graph_from_model(data_point: DataPoint, include_root = True): "type": "list" }, })) - continue + added_edges[edge_key] = True + continue data_point_properties[field_name] = field_value diff --git a/cognee/tasks/graph/query_graph_connections.py b/cognee/tasks/graph/query_graph_connections.py index 5c538a994..cd4d76a5e 100644 --- a/cognee/tasks/graph/query_graph_connections.py +++ b/cognee/tasks/graph/query_graph_connections.py @@ -27,8 +27,8 @@ async def query_graph_connections(query: str, exploration_levels = 1) -> list[(s else: vector_engine = get_vector_engine() results = await asyncio.gather( - vector_engine.search("Entity_text", query_text = query, limit = 5), - vector_engine.search("EntityType_text", query_text = query, limit = 5), + vector_engine.search("Entity_name", query_text = query, limit = 5), + vector_engine.search("EntityType_name", query_text = query, limit = 5), ) results = [*results[0], *results[1]] relevant_results = [result for result in results if result.score < 0.5][:5] diff --git a/cognee/tasks/storage/index_data_points.py b/cognee/tasks/storage/index_data_points.py index a28335e24..681fbaa1f 100644 --- a/cognee/tasks/storage/index_data_points.py +++ b/cognee/tasks/storage/index_data_points.py @@ -56,7 +56,8 @@ def get_data_points_from_model(data_point: DataPoint, added_data_points = {}) -> added_data_points[str(new_point.id)] = True data_points.append(new_point) - data_points.append(data_point) + if (str(data_point.id) not in added_data_points): + data_points.append(data_point) return data_points diff --git a/cognee/tasks/summarization/models/TextSummary.py b/cognee/tasks/summarization/models/TextSummary.py index 5e724cd63..c6a932b37 100644 --- a/cognee/tasks/summarization/models/TextSummary.py +++ b/cognee/tasks/summarization/models/TextSummary.py @@ -4,9 +4,8 @@ from cognee.modules.data.processing.document_types import Document class TextSummary(DataPoint): text: str - chunk: DocumentChunk + made_from: DocumentChunk _metadata: dict = { "index_fields": ["text"], } - diff --git a/cognee/tasks/summarization/summarize_text.py b/cognee/tasks/summarization/summarize_text.py index a1abacccf..756f65e39 100644 --- a/cognee/tasks/summarization/summarize_text.py +++ b/cognee/tasks/summarization/summarize_text.py @@ -5,6 +5,7 @@ from pydantic import BaseModel from cognee.modules.data.extraction.extract_summary import extract_summary from cognee.modules.chunking.models.DocumentChunk import DocumentChunk from cognee.tasks.storage import add_data_points +from cognee.tasks.storage.index_data_points import get_data_points_from_model from .models.TextSummary import TextSummary async def summarize_text(data_chunks: list[DocumentChunk], summarization_model: Type[BaseModel]): @@ -17,12 +18,12 @@ async def summarize_text(data_chunks: list[DocumentChunk], summarization_model: summaries = [ TextSummary( - id = uuid5(chunk.id, "summary"), - chunk = chunk, + id = uuid5(chunk.id, "TextSummary"), + made_from = chunk, text = chunk_summaries[chunk_index].summary, ) for (chunk_index, chunk) in enumerate(data_chunks) ] - add_data_points(summaries) + await add_data_points(summaries) return data_chunks diff --git a/cognee/tests/test_library.py b/cognee/tests/test_library.py index d7e7e5fe8..2e707b64c 100755 --- a/cognee/tests/test_library.py +++ b/cognee/tests/test_library.py @@ -32,8 +32,8 @@ async def main(): from cognee.infrastructure.databases.vector import get_vector_engine vector_engine = get_vector_engine() - random_node = (await vector_engine.search("Entity", "AI"))[0] - random_node_name = random_node.payload["name"] + random_node = (await vector_engine.search("Entity_name", "AI"))[0] + random_node_name = random_node.payload["text"] search_results = await cognee.search(SearchType.INSIGHTS, query = random_node_name) assert len(search_results) != 0, "The search results list is empty." diff --git a/cognee/tests/test_neo4j.py b/cognee/tests/test_neo4j.py index 2f9abf124..0783e973a 100644 --- a/cognee/tests/test_neo4j.py +++ b/cognee/tests/test_neo4j.py @@ -36,8 +36,8 @@ async def main(): from cognee.infrastructure.databases.vector import get_vector_engine vector_engine = get_vector_engine() - random_node = (await vector_engine.search("Entity", "AI"))[0] - random_node_name = random_node.payload["name"] + random_node = (await vector_engine.search("Entity_name", "AI"))[0] + random_node_name = random_node.payload["text"] search_results = await cognee.search(SearchType.INSIGHTS, query = random_node_name) assert len(search_results) != 0, "The search results list is empty." diff --git a/cognee/tests/test_pgvector.py b/cognee/tests/test_pgvector.py index 4adc6accc..cea7c8f72 100644 --- a/cognee/tests/test_pgvector.py +++ b/cognee/tests/test_pgvector.py @@ -65,8 +65,8 @@ async def main(): from cognee.infrastructure.databases.vector import get_vector_engine vector_engine = get_vector_engine() - random_node = (await vector_engine.search("Entity", "AI"))[0] - random_node_name = random_node.payload["name"] + random_node = (await vector_engine.search("Entity_name", "AI"))[0] + random_node_name = random_node.payload["text"] search_results = await cognee.search(SearchType.INSIGHTS, query=random_node_name) assert len(search_results) != 0, "The search results list is empty." diff --git a/cognee/tests/test_qdrant.py b/cognee/tests/test_qdrant.py index 84fac6a2e..faa2cbcf4 100644 --- a/cognee/tests/test_qdrant.py +++ b/cognee/tests/test_qdrant.py @@ -37,8 +37,8 @@ async def main(): from cognee.infrastructure.databases.vector import get_vector_engine vector_engine = get_vector_engine() - random_node = (await vector_engine.search("Entity", "AI"))[0] - random_node_name = random_node.payload["name"] + random_node = (await vector_engine.search("Entity_name", "AI"))[0] + random_node_name = random_node.payload["text"] search_results = await cognee.search(SearchType.INSIGHTS, query = random_node_name) assert len(search_results) != 0, "The search results list is empty." diff --git a/cognee/tests/test_weaviate.py b/cognee/tests/test_weaviate.py index e943e1ec9..121c1749e 100644 --- a/cognee/tests/test_weaviate.py +++ b/cognee/tests/test_weaviate.py @@ -35,8 +35,8 @@ async def main(): from cognee.infrastructure.databases.vector import get_vector_engine vector_engine = get_vector_engine() - random_node = (await vector_engine.search("Entity", "AI"))[0] - random_node_name = random_node.payload["name"] + random_node = (await vector_engine.search("Entity_name", "AI"))[0] + random_node_name = random_node.payload["text"] search_results = await cognee.search(SearchType.INSIGHTS, query = random_node_name) assert len(search_results) != 0, "The search results list is empty."