fix: add summaries to the graph

Boris Arzentar 2024-11-07 15:38:03 +01:00 committed by Leon Luithlen
parent 63900f6b0a
commit 7ea5f638fe
14 changed files with 106 additions and 61 deletions

View file

@@ -247,27 +247,15 @@ class NetworkXAdapter(GraphDBInterface):
         async with aiofiles.open(file_path, "r") as file:
             graph_data = json.loads(await file.read())

             for node in graph_data["nodes"]:
-                try:
-                    node["id"] = UUID(node["id"])
-                except:
-                    pass
-
-                if "updated_at" in node:
-                    node["updated_at"] = datetime.strptime(node["updated_at"], "%Y-%m-%dT%H:%M:%S.%f%z")
+                node["id"] = UUID(node["id"])
+                node["updated_at"] = datetime.strptime(node["updated_at"], "%Y-%m-%dT%H:%M:%S.%f%z")

             for edge in graph_data["links"]:
-                try:
-                    source_id = UUID(edge["source"])
-                    target_id = UUID(edge["target"])
-
-                    edge["source"] = source_id
-                    edge["target"] = target_id
-                    edge["source_node_id"] = source_id
-                    edge["target_node_id"] = target_id
-                except:
-                    pass
-
-                if "updated_at" in edge:
-                    edge["updated_at"] = datetime.strptime(edge["updated_at"], "%Y-%m-%dT%H:%M:%S.%f%z")
+                edge["source"] = UUID(edge["source"])
+                edge["target"] = UUID(edge["target"])
+                edge["source_node_id"] = UUID(edge["source_node_id"])
+                edge["target_node_id"] = UUID(edge["target_node_id"])
+                edge["updated_at"] = datetime.strptime(edge["updated_at"], "%Y-%m-%dT%H:%M:%S.%f%z")

             self.graph = nx.readwrite.json_graph.node_link_graph(graph_data)
         else:
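
The guard-free loader above now assumes every serialized node and edge carries a parseable UUID and a timezone-aware timestamp; a malformed entry raises instead of being silently skipped. A minimal sketch of the node-link payload it expects (ids and timestamps are placeholders, not values from this repo):

# Placeholder ids/timestamps; shape inferred from the parsing code above.
graph_data = {
    "nodes": [
        {"id": "00000000-0000-0000-0000-000000000001",
         "updated_at": "2024-11-07T15:38:03.000000+0100"},
    ],
    "links": [
        {"source": "00000000-0000-0000-0000-000000000001",
         "target": "00000000-0000-0000-0000-000000000002",
         "source_node_id": "00000000-0000-0000-0000-000000000001",
         "target_node_id": "00000000-0000-0000-0000-000000000002",
         "updated_at": "2024-11-07T15:38:03.000000+0100"},
    ],
}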

View file

@@ -1,6 +1,7 @@
 import asyncio
 from textwrap import dedent
 from typing import Any
+from uuid import UUID

 from falkordb import FalkorDB

 from cognee.infrastructure.engine import DataPoint
@@ -161,6 +162,35 @@ class FalkorDBAdapter(VectorDBInterface, GraphDBInterface):
     async def extract_nodes(self, data_point_ids: list[str]):
         return await self.retrieve(data_point_ids)

+    async def get_connections(self, node_id: UUID) -> list:
+        predecessors_query = """
+        MATCH (node)<-[relation]-(neighbour)
+        WHERE node.id = $node_id
+        RETURN neighbour, relation, node
+        """
+
+        successors_query = """
+        MATCH (node)-[relation]->(neighbour)
+        WHERE node.id = $node_id
+        RETURN node, relation, neighbour
+        """
+
+        predecessors, successors = await asyncio.gather(
+            self.query(predecessors_query, dict(node_id = node_id)),
+            self.query(successors_query, dict(node_id = node_id)),
+        )
+
+        connections = []
+
+        for neighbour in predecessors:
+            neighbour = neighbour["relation"]
+            connections.append((neighbour[0], { "relationship_name": neighbour[1] }, neighbour[2]))
+
+        for neighbour in successors:
+            neighbour = neighbour["relation"]
+            connections.append((neighbour[0], { "relationship_name": neighbour[1] }, neighbour[2]))
+
+        return connections
+
     async def search(
         self,
         collection_name: str,
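
A rough usage sketch of the new method, assuming an already-initialized FalkorDBAdapter named adapter and an existing node_id (both hypothetical here). get_connections returns (source, edge_properties, target) triples covering both edge directions:

# Inside an async context; adapter construction elided.
connections = await adapter.get_connections(node_id)

for source, properties, target in connections:
    # Each triple is (source node, {"relationship_name": ...}, target node).
    print(source, "-[" + properties["relationship_name"] + "]->", target)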

View file

@@ -178,7 +178,7 @@ class WeaviateAdapter(VectorDBInterface):
         return [
             ScoredResult(
-                id = UUID(result.id),
+                id = UUID(result.uuid),
                 payload = result.properties,
                 score = float(result.metadata.score)
             ) for result in search_result.objects

View file

@@ -29,7 +29,7 @@ class TextChunker():
             else:
                 if len(self.paragraph_chunks) == 0:
                     yield DocumentChunk(
-                        id = str(chunk_data["chunk_id"]),
+                        id = chunk_data["chunk_id"],
                         text = chunk_data["text"],
                         word_count = chunk_data["word_count"],
                         is_part_of = self.document,
@@ -42,7 +42,7 @@ class TextChunker():
                 chunk_text = " ".join(chunk["text"] for chunk in self.paragraph_chunks)
                 try:
                     yield DocumentChunk(
-                        id = str(uuid5(NAMESPACE_OID, f"{str(self.document.id)}-{self.chunk_index}")),
+                        id = uuid5(NAMESPACE_OID, f"{str(self.document.id)}-{self.chunk_index}"),
                         text = chunk_text,
                         word_count = self.chunk_size,
                         is_part_of = self.document,
@@ -59,7 +59,7 @@ class TextChunker():
         if len(self.paragraph_chunks) > 0:
             try:
                 yield DocumentChunk(
-                    id = str(uuid5(NAMESPACE_OID, f"{str(self.document.id)}-{self.chunk_index}")),
+                    id = uuid5(NAMESPACE_OID, f"{str(self.document.id)}-{self.chunk_index}"),
                    text = " ".join(chunk["text"] for chunk in self.paragraph_chunks),
                    word_count = self.chunk_size,
                    is_part_of = self.document,
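
Chunk ids are now passed through as UUID objects instead of strings, matching the UUID-typed ids the graph and vector adapters above parse and serialize. A small self-contained sketch of the uuid5 determinism this relies on ("doc-1" stands in for str(document.id)):

from uuid import NAMESPACE_OID, uuid5

# The same document id and chunk index always hash to the same UUID,
# so re-chunking a document cannot create duplicate chunk ids.
first = uuid5(NAMESPACE_OID, "doc-1-0")
again = uuid5(NAMESPACE_OID, "doc-1-0")
assert first == again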

View file

@@ -1,9 +1,8 @@
 from datetime import datetime, timezone
 from cognee.infrastructure.engine import DataPoint
-from cognee.modules import data
 from cognee.modules.storage.utils import copy_model

-def get_graph_from_model(data_point: DataPoint, include_root = True):
+def get_graph_from_model(data_point: DataPoint, include_root = True, added_nodes = {}, added_edges = {}):
     nodes = []
     edges = []
@@ -17,29 +16,55 @@ def get_graph_from_model(data_point: DataPoint, include_root = True):
         if isinstance(field_value, DataPoint):
             excluded_properties.add(field_name)

-            property_nodes, property_edges = get_graph_from_model(field_value, True)
-            nodes[:0] = property_nodes
-            edges[:0] = property_edges
+            property_nodes, property_edges = get_graph_from_model(field_value, True, added_nodes, added_edges)
+
+            for node in property_nodes:
+                if str(node.id) not in added_nodes:
+                    nodes.append(node)
+                    added_nodes[str(node.id)] = True
+
+            for edge in property_edges:
+                edge_key = str(edge[0]) + str(edge[1]) + edge[2]
+
+                if str(edge_key) not in added_edges:
+                    edges.append(edge)
+                    added_edges[str(edge_key)] = True

             for property_node in get_own_properties(property_nodes, property_edges):
-                edges.append((data_point.id, property_node.id, field_name, {
-                    "source_node_id": data_point.id,
-                    "target_node_id": property_node.id,
-                    "relationship_name": field_name,
-                    "updated_at": datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S"),
-                }))
+                edge_key = str(data_point.id) + str(property_node.id) + field_name
+
+                if str(edge_key) not in added_edges:
+                    edges.append((data_point.id, property_node.id, field_name, {
+                        "source_node_id": data_point.id,
+                        "target_node_id": property_node.id,
+                        "relationship_name": field_name,
+                        "updated_at": datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S"),
+                    }))
+                    added_edges[str(edge_key)] = True

             continue

-        if isinstance(field_value, list):
-            if isinstance(field_value[0], DataPoint):
-                excluded_properties.add(field_name)
+        if isinstance(field_value, list) and isinstance(field_value[0], DataPoint):
+            excluded_properties.add(field_name)

-                for item in field_value:
-                    property_nodes, property_edges = get_graph_from_model(item, True)
-                    nodes[:0] = property_nodes
-                    edges[:0] = property_edges
+            for item in field_value:
+                property_nodes, property_edges = get_graph_from_model(item, True, added_nodes, added_edges)
+
+                for node in property_nodes:
+                    if str(node.id) not in added_nodes:
+                        nodes.append(node)
+                        added_nodes[str(node.id)] = True
+
+                for edge in property_edges:
+                    edge_key = str(edge[0]) + str(edge[1]) + edge[2]
+
+                    if str(edge_key) not in added_edges:
+                        edges.append(edge)
+                        added_edges[edge_key] = True

-                for property_node in get_own_properties(property_nodes, property_edges):
+            for property_node in get_own_properties(property_nodes, property_edges):
+                edge_key = str(data_point.id) + str(property_node.id) + field_name
+
+                if str(edge_key) not in added_edges:
                     edges.append((data_point.id, property_node.id, field_name, {
                         "source_node_id": data_point.id,
                         "target_node_id": property_node.id,
@@ -49,7 +74,8 @@ def get_graph_from_model(data_point: DataPoint, include_root = True):
                             "type": "list"
                         },
                     }))
-                continue
+                    added_edges[edge_key] = True
+
+            continue

         data_point_properties[field_name] = field_value
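
The recursion now threads added_nodes/added_edges through every call, so a DataPoint reachable from several parents is emitted exactly once. A minimal self-contained sketch of the edge-key bookkeeping used above (names hypothetical):

# Edge identity = source id + target id + relationship name; seeing the
# same combination twice must not produce a second edge.
added_edges = {}
edges = []

def add_edge(source_id, target_id, relationship_name):
    edge_key = str(source_id) + str(target_id) + relationship_name
    if edge_key not in added_edges:
        edges.append((source_id, target_id, relationship_name))
        added_edges[edge_key] = True

add_edge("a", "b", "contains")
add_edge("a", "b", "contains")  # duplicate, skipped
assert len(edges) == 1

One caveat with the new signature: mutable default arguments like added_nodes = {} are shared across top-level calls in Python, so callers that need a fresh traversal should pass new dicts explicitly.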

View file

@@ -27,8 +27,8 @@ async def query_graph_connections(query: str, exploration_levels = 1) -> list[(str, str, str)]:
     else:
         vector_engine = get_vector_engine()
         results = await asyncio.gather(
-            vector_engine.search("Entity_text", query_text = query, limit = 5),
-            vector_engine.search("EntityType_text", query_text = query, limit = 5),
+            vector_engine.search("Entity_name", query_text = query, limit = 5),
+            vector_engine.search("EntityType_name", query_text = query, limit = 5),
         )

     results = [*results[0], *results[1]]
     relevant_results = [result for result in results if result.score < 0.5][:5]
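
Both searches run concurrently; their hits are merged, filtered by distance (a lower score means a closer match), and capped at five. A toy sketch of that filter with hypothetical scores:

class Hit:  # stand-in for the engine's ScoredResult
    def __init__(self, score):
        self.score = score

results = [Hit(0.2), Hit(0.7), Hit(0.4), Hit(0.9)]
relevant_results = [result for result in results if result.score < 0.5][:5]
assert len(relevant_results) == 2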

View file

@@ -56,7 +56,8 @@ def get_data_points_from_model(data_point: DataPoint, added_data_points = {}) -> list[DataPoint]:
             added_data_points[str(new_point.id)] = True
             data_points.append(new_point)

-    data_points.append(data_point)
+    if (str(data_point.id) not in added_data_points):
+        data_points.append(data_point)

     return data_points

View file

@@ -4,9 +4,8 @@ from cognee.modules.data.processing.document_types import Document

 class TextSummary(DataPoint):
     text: str
-    chunk: DocumentChunk
+    made_from: DocumentChunk

     _metadata: dict = {
         "index_fields": ["text"],
     }
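
Because get_graph_from_model (above) uses the field name as the relationship name, renaming chunk to made_from relabels the edge between a summary and its source chunk. Roughly, the emitted triple becomes (ids are placeholders):

# Shape follows the edges.append(...) calls in get_graph_from_model above.
summary_id, chunk_id = "summary-uuid", "chunk-uuid"  # placeholder ids
edge = (summary_id, chunk_id, "made_from", {
    "source_node_id": summary_id,
    "target_node_id": chunk_id,
    "relationship_name": "made_from",
})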

View file

@@ -5,6 +5,7 @@ from pydantic import BaseModel
 from cognee.modules.data.extraction.extract_summary import extract_summary
 from cognee.modules.chunking.models.DocumentChunk import DocumentChunk
 from cognee.tasks.storage import add_data_points
+from cognee.tasks.storage.index_data_points import get_data_points_from_model

 from .models.TextSummary import TextSummary

 async def summarize_text(data_chunks: list[DocumentChunk], summarization_model: Type[BaseModel]):
@@ -17,12 +18,12 @@ async def summarize_text(data_chunks: list[DocumentChunk], summarization_model: Type[BaseModel]):
     summaries = [
         TextSummary(
-            id = uuid5(chunk.id, "summary"),
-            chunk = chunk,
+            id = uuid5(chunk.id, "TextSummary"),
+            made_from = chunk,
             text = chunk_summaries[chunk_index].summary,
         ) for (chunk_index, chunk) in enumerate(data_chunks)
     ]

-    add_data_points(summaries)
+    await add_data_points(summaries)

     return data_chunks
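
The added await is the functional core of this commit: add_data_points is a coroutine, so calling it without awaiting merely created a coroutine object and the summaries never reached the graph. A minimal reproduction of that failure mode (stub function, illustrative only):

import asyncio

async def add_data_points_stub(points):
    return points  # stands in for the real persistence call

async def main():
    add_data_points_stub(["summary"])        # never runs; RuntimeWarning only
    await add_data_points_stub(["summary"])  # actually executes

asyncio.run(main())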

View file

@@ -32,8 +32,8 @@ async def main():
     from cognee.infrastructure.databases.vector import get_vector_engine
     vector_engine = get_vector_engine()

-    random_node = (await vector_engine.search("Entity", "AI"))[0]
-    random_node_name = random_node.payload["name"]
+    random_node = (await vector_engine.search("Entity_name", "AI"))[0]
+    random_node_name = random_node.payload["text"]

     search_results = await cognee.search(SearchType.INSIGHTS, query = random_node_name)
     assert len(search_results) != 0, "The search results list is empty."
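
The same rename recurs in the four tests below and mirrors the retriever change above: indexed collections appear to follow a "<Type>_<field>" naming convention, with the indexed value exposed under the "text" payload key (consistent with the index_fields declaration in TextSummary). A hedged excerpt of the lookup as it runs inside the test's async main():

# Assumed convention, inferred from this commit: collection "Entity_name"
# indexes Entity.name and returns it in the payload as "text".
random_node = (await vector_engine.search("Entity_name", "AI"))[0]
random_node_name = random_node.payload["text"]  # previously payload["name"]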

View file

@@ -36,8 +36,8 @@ async def main():
     from cognee.infrastructure.databases.vector import get_vector_engine
     vector_engine = get_vector_engine()

-    random_node = (await vector_engine.search("Entity", "AI"))[0]
-    random_node_name = random_node.payload["name"]
+    random_node = (await vector_engine.search("Entity_name", "AI"))[0]
+    random_node_name = random_node.payload["text"]

     search_results = await cognee.search(SearchType.INSIGHTS, query = random_node_name)
     assert len(search_results) != 0, "The search results list is empty."

View file

@@ -65,8 +65,8 @@ async def main():
     from cognee.infrastructure.databases.vector import get_vector_engine
     vector_engine = get_vector_engine()

-    random_node = (await vector_engine.search("Entity", "AI"))[0]
-    random_node_name = random_node.payload["name"]
+    random_node = (await vector_engine.search("Entity_name", "AI"))[0]
+    random_node_name = random_node.payload["text"]

     search_results = await cognee.search(SearchType.INSIGHTS, query=random_node_name)
     assert len(search_results) != 0, "The search results list is empty."

View file

@@ -37,8 +37,8 @@ async def main():
     from cognee.infrastructure.databases.vector import get_vector_engine
     vector_engine = get_vector_engine()

-    random_node = (await vector_engine.search("Entity", "AI"))[0]
-    random_node_name = random_node.payload["name"]
+    random_node = (await vector_engine.search("Entity_name", "AI"))[0]
+    random_node_name = random_node.payload["text"]

     search_results = await cognee.search(SearchType.INSIGHTS, query = random_node_name)
     assert len(search_results) != 0, "The search results list is empty."

View file

@@ -35,8 +35,8 @@ async def main():
     from cognee.infrastructure.databases.vector import get_vector_engine
     vector_engine = get_vector_engine()

-    random_node = (await vector_engine.search("Entity", "AI"))[0]
-    random_node_name = random_node.payload["name"]
+    random_node = (await vector_engine.search("Entity_name", "AI"))[0]
+    random_node_name = random_node.payload["text"]

     search_results = await cognee.search(SearchType.INSIGHTS, query = random_node_name)
     assert len(search_results) != 0, "The search results list is empty."