fix: add summaries to the graph

This commit is contained in:
Boris Arzentar 2024-11-07 15:38:03 +01:00 committed by Leon Luithlen
parent 63900f6b0a
commit 7ea5f638fe
14 changed files with 106 additions and 61 deletions

View file

@ -247,27 +247,15 @@ class NetworkXAdapter(GraphDBInterface):
async with aiofiles.open(file_path, "r") as file:
graph_data = json.loads(await file.read())
for node in graph_data["nodes"]:
try:
node["id"] = UUID(node["id"])
except:
pass
if "updated_at" in node:
node["updated_at"] = datetime.strptime(node["updated_at"], "%Y-%m-%dT%H:%M:%S.%f%z")
node["id"] = UUID(node["id"])
node["updated_at"] = datetime.strptime(node["updated_at"], "%Y-%m-%dT%H:%M:%S.%f%z")
for edge in graph_data["links"]:
try:
source_id = UUID(edge["source"])
target_id = UUID(edge["target"])
edge["source"] = source_id
edge["target"] = target_id
edge["source_node_id"] = source_id
edge["target_node_id"] = target_id
except:
pass
if "updated_at" in edge:
edge["updated_at"] = datetime.strptime(edge["updated_at"], "%Y-%m-%dT%H:%M:%S.%f%z")
edge["source"] = UUID(edge["source"])
edge["target"] = UUID(edge["target"])
edge["source_node_id"] = UUID(edge["source_node_id"])
edge["target_node_id"] = UUID(edge["target_node_id"])
edge["updated_at"] = datetime.strptime(edge["updated_at"], "%Y-%m-%dT%H:%M:%S.%f%z")
self.graph = nx.readwrite.json_graph.node_link_graph(graph_data)
else:

View file

@ -1,6 +1,7 @@
import asyncio
from textwrap import dedent
from typing import Any
from uuid import UUID
from falkordb import FalkorDB
from cognee.infrastructure.engine import DataPoint
@ -161,6 +162,35 @@ class FalkorDBAdapter(VectorDBInterface, GraphDBInterface):
async def extract_nodes(self, data_point_ids: list[str]):
return await self.retrieve(data_point_ids)
async def get_connections(self, node_id: UUID) -> list:
predecessors_query = """
MATCH (node)<-[relation]-(neighbour)
WHERE node.id = $node_id
RETURN neighbour, relation, node
"""
successors_query = """
MATCH (node)-[relation]->(neighbour)
WHERE node.id = $node_id
RETURN node, relation, neighbour
"""
predecessors, successors = await asyncio.gather(
self.query(predecessors_query, dict(node_id = node_id)),
self.query(successors_query, dict(node_id = node_id)),
)
connections = []
for neighbour in predecessors:
neighbour = neighbour["relation"]
connections.append((neighbour[0], { "relationship_name": neighbour[1] }, neighbour[2]))
for neighbour in successors:
neighbour = neighbour["relation"]
connections.append((neighbour[0], { "relationship_name": neighbour[1] }, neighbour[2]))
return connections
async def search(
self,
collection_name: str,

View file

@ -178,7 +178,7 @@ class WeaviateAdapter(VectorDBInterface):
return [
ScoredResult(
id = UUID(result.id),
id = UUID(result.uuid),
payload = result.properties,
score = float(result.metadata.score)
) for result in search_result.objects

View file

@ -29,7 +29,7 @@ class TextChunker():
else:
if len(self.paragraph_chunks) == 0:
yield DocumentChunk(
id = str(chunk_data["chunk_id"]),
id = chunk_data["chunk_id"],
text = chunk_data["text"],
word_count = chunk_data["word_count"],
is_part_of = self.document,
@ -42,7 +42,7 @@ class TextChunker():
chunk_text = " ".join(chunk["text"] for chunk in self.paragraph_chunks)
try:
yield DocumentChunk(
id = str(uuid5(NAMESPACE_OID, f"{str(self.document.id)}-{self.chunk_index}")),
id = uuid5(NAMESPACE_OID, f"{str(self.document.id)}-{self.chunk_index}"),
text = chunk_text,
word_count = self.chunk_size,
is_part_of = self.document,
@ -59,7 +59,7 @@ class TextChunker():
if len(self.paragraph_chunks) > 0:
try:
yield DocumentChunk(
id = str(uuid5(NAMESPACE_OID, f"{str(self.document.id)}-{self.chunk_index}")),
id = uuid5(NAMESPACE_OID, f"{str(self.document.id)}-{self.chunk_index}"),
text = " ".join(chunk["text"] for chunk in self.paragraph_chunks),
word_count = self.chunk_size,
is_part_of = self.document,

View file

@ -1,9 +1,8 @@
from datetime import datetime, timezone
from cognee.infrastructure.engine import DataPoint
from cognee.modules import data
from cognee.modules.storage.utils import copy_model
def get_graph_from_model(data_point: DataPoint, include_root = True):
def get_graph_from_model(data_point: DataPoint, include_root = True, added_nodes = {}, added_edges = {}):
nodes = []
edges = []
@ -17,29 +16,55 @@ def get_graph_from_model(data_point: DataPoint, include_root = True):
if isinstance(field_value, DataPoint):
excluded_properties.add(field_name)
property_nodes, property_edges = get_graph_from_model(field_value, True)
nodes[:0] = property_nodes
edges[:0] = property_edges
property_nodes, property_edges = get_graph_from_model(field_value, True, added_nodes, added_edges)
for node in property_nodes:
if str(node.id) not in added_nodes:
nodes.append(node)
added_nodes[str(node.id)] = True
for edge in property_edges:
edge_key = str(edge[0]) + str(edge[1]) + edge[2]
if str(edge_key) not in added_edges:
edges.append(edge)
added_edges[str(edge_key)] = True
for property_node in get_own_properties(property_nodes, property_edges):
edges.append((data_point.id, property_node.id, field_name, {
"source_node_id": data_point.id,
"target_node_id": property_node.id,
"relationship_name": field_name,
"updated_at": datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S"),
}))
edge_key = str(data_point.id) + str(property_node.id) + field_name
if str(edge_key) not in added_edges:
edges.append((data_point.id, property_node.id, field_name, {
"source_node_id": data_point.id,
"target_node_id": property_node.id,
"relationship_name": field_name,
"updated_at": datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S"),
}))
added_edges[str(edge_key)] = True
continue
if isinstance(field_value, list):
if isinstance(field_value[0], DataPoint):
excluded_properties.add(field_name)
if isinstance(field_value, list) and isinstance(field_value[0], DataPoint):
excluded_properties.add(field_name)
for item in field_value:
property_nodes, property_edges = get_graph_from_model(item, True)
nodes[:0] = property_nodes
edges[:0] = property_edges
for item in field_value:
property_nodes, property_edges = get_graph_from_model(item, True, added_nodes, added_edges)
for property_node in get_own_properties(property_nodes, property_edges):
for node in property_nodes:
if str(node.id) not in added_nodes:
nodes.append(node)
added_nodes[str(node.id)] = True
for edge in property_edges:
edge_key = str(edge[0]) + str(edge[1]) + edge[2]
if str(edge_key) not in added_edges:
edges.append(edge)
added_edges[edge_key] = True
for property_node in get_own_properties(property_nodes, property_edges):
edge_key = str(data_point.id) + str(property_node.id) + field_name
if str(edge_key) not in added_edges:
edges.append((data_point.id, property_node.id, field_name, {
"source_node_id": data_point.id,
"target_node_id": property_node.id,
@ -49,7 +74,8 @@ def get_graph_from_model(data_point: DataPoint, include_root = True):
"type": "list"
},
}))
continue
added_edges[edge_key] = True
continue
data_point_properties[field_name] = field_value

View file

@ -27,8 +27,8 @@ async def query_graph_connections(query: str, exploration_levels = 1) -> list[(s
else:
vector_engine = get_vector_engine()
results = await asyncio.gather(
vector_engine.search("Entity_text", query_text = query, limit = 5),
vector_engine.search("EntityType_text", query_text = query, limit = 5),
vector_engine.search("Entity_name", query_text = query, limit = 5),
vector_engine.search("EntityType_name", query_text = query, limit = 5),
)
results = [*results[0], *results[1]]
relevant_results = [result for result in results if result.score < 0.5][:5]

View file

@ -56,7 +56,8 @@ def get_data_points_from_model(data_point: DataPoint, added_data_points = {}) ->
added_data_points[str(new_point.id)] = True
data_points.append(new_point)
data_points.append(data_point)
if (str(data_point.id) not in added_data_points):
data_points.append(data_point)
return data_points

View file

@ -4,9 +4,8 @@ from cognee.modules.data.processing.document_types import Document
class TextSummary(DataPoint):
text: str
chunk: DocumentChunk
made_from: DocumentChunk
_metadata: dict = {
"index_fields": ["text"],
}

View file

@ -5,6 +5,7 @@ from pydantic import BaseModel
from cognee.modules.data.extraction.extract_summary import extract_summary
from cognee.modules.chunking.models.DocumentChunk import DocumentChunk
from cognee.tasks.storage import add_data_points
from cognee.tasks.storage.index_data_points import get_data_points_from_model
from .models.TextSummary import TextSummary
async def summarize_text(data_chunks: list[DocumentChunk], summarization_model: Type[BaseModel]):
@ -17,12 +18,12 @@ async def summarize_text(data_chunks: list[DocumentChunk], summarization_model:
summaries = [
TextSummary(
id = uuid5(chunk.id, "summary"),
chunk = chunk,
id = uuid5(chunk.id, "TextSummary"),
made_from = chunk,
text = chunk_summaries[chunk_index].summary,
) for (chunk_index, chunk) in enumerate(data_chunks)
]
add_data_points(summaries)
await add_data_points(summaries)
return data_chunks

View file

@ -32,8 +32,8 @@ async def main():
from cognee.infrastructure.databases.vector import get_vector_engine
vector_engine = get_vector_engine()
random_node = (await vector_engine.search("Entity", "AI"))[0]
random_node_name = random_node.payload["name"]
random_node = (await vector_engine.search("Entity_name", "AI"))[0]
random_node_name = random_node.payload["text"]
search_results = await cognee.search(SearchType.INSIGHTS, query = random_node_name)
assert len(search_results) != 0, "The search results list is empty."

View file

@ -36,8 +36,8 @@ async def main():
from cognee.infrastructure.databases.vector import get_vector_engine
vector_engine = get_vector_engine()
random_node = (await vector_engine.search("Entity", "AI"))[0]
random_node_name = random_node.payload["name"]
random_node = (await vector_engine.search("Entity_name", "AI"))[0]
random_node_name = random_node.payload["text"]
search_results = await cognee.search(SearchType.INSIGHTS, query = random_node_name)
assert len(search_results) != 0, "The search results list is empty."

View file

@ -65,8 +65,8 @@ async def main():
from cognee.infrastructure.databases.vector import get_vector_engine
vector_engine = get_vector_engine()
random_node = (await vector_engine.search("Entity", "AI"))[0]
random_node_name = random_node.payload["name"]
random_node = (await vector_engine.search("Entity_name", "AI"))[0]
random_node_name = random_node.payload["text"]
search_results = await cognee.search(SearchType.INSIGHTS, query=random_node_name)
assert len(search_results) != 0, "The search results list is empty."

View file

@ -37,8 +37,8 @@ async def main():
from cognee.infrastructure.databases.vector import get_vector_engine
vector_engine = get_vector_engine()
random_node = (await vector_engine.search("Entity", "AI"))[0]
random_node_name = random_node.payload["name"]
random_node = (await vector_engine.search("Entity_name", "AI"))[0]
random_node_name = random_node.payload["text"]
search_results = await cognee.search(SearchType.INSIGHTS, query = random_node_name)
assert len(search_results) != 0, "The search results list is empty."

View file

@ -35,8 +35,8 @@ async def main():
from cognee.infrastructure.databases.vector import get_vector_engine
vector_engine = get_vector_engine()
random_node = (await vector_engine.search("Entity", "AI"))[0]
random_node_name = random_node.payload["name"]
random_node = (await vector_engine.search("Entity_name", "AI"))[0]
random_node_name = random_node.payload["text"]
search_results = await cognee.search(SearchType.INSIGHTS, query = random_node_name)
assert len(search_results) != 0, "The search results list is empty."