async def get_model_independent_graph_data(self):
    """Fetch every node and every relationship without mapping them to models.

    Two Cypher queries are sent one after the other through ``self.query``:
    the first aggregates all nodes under the ``nodes`` column, the second
    aggregates all directed relationships under the ``relationships`` column.

    Returns:
        A 2-tuple ``(node_result, relationship_result)`` holding exactly what
        ``self.query`` returned for each query, in that order.
    """
    node_result = await self.query("MATCH (n) RETURN collect(n) AS nodes")
    relationship_result = await self.query(
        "MATCH ()-[r]->() RETURN collect(r) AS relationships"
    )
    return node_result, relationship_result
async def _index_data_points(vector_engine, data_points, created_indexes):
    """Create missing vector indexes for *data_points* and index their fields.

    Args:
        vector_engine: Object exposing the async methods
            ``create_vector_index(table, field)`` and
            ``index_data_points(table, field, points)``.
        data_points: Objects whose class defines ``__tablename__`` and whose
            ``_metadata["index_fields"]`` lists the attributes to embed.
        created_indexes: Set of ``(table, field)`` pairs already created.
            It is updated in place so that several calls share it.
    """
    # Batches are local to this call, so one pass can never re-send the
    # batches that an earlier pass has already indexed.
    batches = {}

    for data_point in data_points:
        table_name = type(data_point).__tablename__

        for field_name in data_point._metadata["index_fields"]:
            # A tuple key avoids the "table.field" round-trip through
            # str.split, which breaks if either part contains a dot.
            index_key = (table_name, field_name)

            if index_key not in created_indexes:
                await vector_engine.create_vector_index(table_name, field_name)
                created_indexes.add(index_key)

            batch = batches.setdefault(index_key, [])

            # Nothing to embed for this field on this object.
            if getattr(data_point, field_name, None) is None:
                continue

            indexed_data_point = data_point.model_copy()
            # model_copy() is shallow by default, so the copy may share its
            # metadata dict with the source object. Assign a fresh dict
            # instead of writing into one that could be shared.
            indexed_data_point._metadata = {
                **data_point._metadata,
                "index_fields": [field_name],
            }
            batch.append(indexed_data_point)

    for (table_name, field_name), batch in batches.items():
        # An empty batch carries no data; skip the pointless engine call.
        if batch:
            await vector_engine.index_data_points(table_name, field_name, batch)


async def index_graphiti_nodes_and_edges():
    """Index Graphiti node text fields and relationship names in the vector store.

    The raw graph is read through
    ``graph_engine.get_model_independent_graph_data()``. Each node becomes a
    ``GraphitiNode`` built from whichever of ``content``, ``name`` and
    ``summary`` it carries. Relationships that arrive as 3-tuples are grouped
    by their middle element, and one ``EdgeType`` is created per distinct
    value with its occurrence count. Nodes and edge types are indexed in two
    independent passes, so each batch reaches the vector engine exactly once.

    Returns:
        None.

    Raises:
        RuntimeError: If the vector or graph engine cannot be obtained.
    """
    try:
        vector_engine = get_vector_engine()
        graph_engine = await get_graph_engine()
    except Exception as e:
        logging.error("Failed to initialize engines: %s", e)
        raise RuntimeError("Initialization error") from e

    nodes_data, edges_data = await graph_engine.get_model_independent_graph_data()

    # Guard against an empty result set instead of failing on ``[0]``.
    node_rows = nodes_data[0]["nodes"] if nodes_data else []
    relationship_rows = edges_data[0]["relationships"] if edges_data else []

    # Shared by both passes so an index is never created twice.
    created_indexes = set()

    graphiti_nodes = [
        GraphitiNode(
            **{key: node_data[key] for key in ("content", "name", "summary") if key in node_data}
        )
        for node_data in node_rows
    ]
    await _index_data_points(vector_engine, graphiti_nodes, created_indexes)

    edge_type_counts = Counter(
        edge[1] for edge in relationship_rows if isinstance(edge, tuple) and len(edge) == 3
    )
    edge_types = [
        EdgeType(relationship_name=text, number_of_edges=count)
        for text, count in edge_type_counts.items()
    ]
    await _index_data_points(vector_engine, edge_types, created_indexes)
cognee.infrastructure.databases.relational import ( + create_db_and_tables as create_relational_db_and_tables, +) +from cognee.tasks.storage.index_graph_edges import index_graphiti_nodes_and_edges text_list = [ "Kamala Harris is the Attorney General of California. She was previously " @@ -16,11 +20,15 @@ text_list = [ async def main(): + await cognee.prune.prune_data() + await cognee.prune.prune_system(metadata=True) + await create_relational_db_and_tables() + + for text in text_list: + await cognee.add(text) + tasks = [ Task(build_graph_with_temporal_awareness, text_list=text_list), - Task( - search_graph_with_temporal_awareness, query="Who was the California Attorney General?" - ), ] pipeline = run_tasks(tasks) @@ -28,6 +36,8 @@ async def main(): async for result in pipeline: print(result) + await index_graphiti_nodes_and_edges() + if __name__ == "__main__": asyncio.run(main())