Merge pull request #437 from topoteretes/feature/cog-761-project-graphiti-graph-to-memory

feat: adds cognee node and edge embeddings for graphiti graph
This commit is contained in:
Vasilije 2025-01-16 10:03:31 +01:00 committed by GitHub
commit 1c4a605eb7
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 412 additions and 7 deletions

View file

@ -426,6 +426,15 @@ class Neo4jAdapter(GraphDBInterface):
return serialized_properties
async def get_model_independent_graph_data(self):
    """Fetch the raw graph without any model mapping.

    Runs two Cypher queries against the store: one collecting every node,
    one collecting every (source, relationship, target) triple.

    Returns:
        tuple: (nodes, edges) — the raw query results, in that order.
    """
    node_query = "MATCH (n) RETURN collect(n) AS nodes"
    collected_nodes = await self.query(node_query)

    edge_query = "MATCH (n)-[r]->(m) RETURN collect([n, r, m]) AS elements"
    collected_edges = await self.query(edge_query)

    return (collected_nodes, collected_edges)
async def get_graph_data(self):
query = "MATCH (n) RETURN ID(n) AS id, labels(n) AS labels, properties(n) AS properties"

View file

@ -0,0 +1,12 @@
from cognee.infrastructure.engine import DataPoint
from typing import ClassVar, Optional
class GraphitiNode(DataPoint):
    """Cognee DataPoint projection of a graphiti graph node.

    Carries the subset of graphiti node properties (content, name, summary)
    that cognee embeds; ``_metadata["index_fields"]`` names the attributes
    that get a vector index.
    """

    __tablename__ = "graphitinode"

    # Each payload field is optional: a graphiti node may carry any subset.
    content: Optional[str] = None
    name: Optional[str] = None
    summary: Optional[str] = None
    # Discriminator string used when (de)serializing DataPoint subclasses.
    pydantic_type: str = "GraphitiNode"

    # NOTE(review): a class-level dict is shared across instances unless the
    # DataPoint base copies it per instance — confirm before mutating in place.
    _metadata: dict = {"index_fields": ["name", "summary", "content"], "type": "GraphitiNode"}

View file

@ -0,0 +1,84 @@
import logging
from collections import Counter
from cognee.tasks.temporal_awareness.graphiti_model import GraphitiNode
from cognee.infrastructure.databases.vector import get_vector_engine
from cognee.infrastructure.databases.graph import get_graph_engine
from cognee.modules.graph.models.EdgeType import EdgeType
async def _collect_index_points(data_point, vector_engine, created_indexes, index_points):
    """Ensure vector indexes exist for the data point's index fields and queue
    per-field copies of the point for a later bulk indexing pass.

    Args:
        data_point: A DataPoint instance (GraphitiNode or EdgeType).
        vector_engine: Engine used to create the vector indexes.
        created_indexes: dict tracking which "<table>.<field>" indexes exist.
        index_points: dict mapping "<table>.<field>" to queued data points.
    """
    data_point_type = type(data_point)
    for field_name in data_point._metadata["index_fields"]:
        index_name = f"{data_point_type.__tablename__}.{field_name}"

        if index_name not in created_indexes:
            await vector_engine.create_vector_index(data_point_type.__tablename__, field_name)
            created_indexes[index_name] = True

        index_points.setdefault(index_name, [])

        # Only queue points that actually carry a value for this field.
        if getattr(data_point, field_name, None) is not None:
            indexed_data_point = data_point.model_copy()
            # model_copy() is shallow: give the copy its OWN metadata dict so
            # narrowing index_fields to one field does not leak into the source
            # point or into sibling copies sharing the same dict.
            indexed_data_point._metadata = {
                **data_point._metadata,
                "index_fields": [field_name],
            }
            index_points[index_name].append(indexed_data_point)


async def index_and_transform_graphiti_nodes_and_edges():
    """Normalize a graphiti-produced graph into cognee's conventions and build
    vector indexes over its nodes and edge types.

    Steps:
      1. Copy each node's ``uuid`` into ``id`` and stamp edges with
         source/target ids and a ``relationship_name``.
      2. Set each node's ``text`` to ``summary`` falling back to ``content``.
      3. Wrap nodes as GraphitiNode points and edge-type counts as EdgeType
         points, then index each collection exactly once.

    Raises:
        RuntimeError: If the vector or graph engine fails to initialize.
    """
    try:
        created_indexes = {}
        index_points = {}
        vector_engine = get_vector_engine()
        graph_engine = await get_graph_engine()
    except Exception as e:
        logging.error("Failed to initialize engines: %s", e)
        raise RuntimeError("Initialization error") from e

    # Normalize graphiti's schema to the fields cognee expects.
    await graph_engine.query("""MATCH (n) SET n.id = n.uuid RETURN n""")
    await graph_engine.query("""MATCH (source)-[r]->(target) SET r.source_node_id = source.id,
                                   r.target_node_id = target.id,
                                   r.relationship_name = type(r) RETURN r""")
    await graph_engine.query("""MATCH (n) SET n.text = COALESCE(n.summary, n.content) RETURN n""")

    nodes_data, edges_data = await graph_engine.get_model_independent_graph_data()

    # Queue one GraphitiNode per graph node.
    for node_data in nodes_data[0]["nodes"]:
        graphiti_node = GraphitiNode(
            **{key: node_data[key] for key in ("content", "name", "summary") if key in node_data},
            id=node_data.get("uuid"),
        )
        await _collect_index_points(graphiti_node, vector_engine, created_indexes, index_points)

    # Count relationship names across well-formed [source, relation, target] triples.
    edge_types = Counter(
        edge[1][1]
        for edge in edges_data[0]["elements"]
        if isinstance(edge, list) and len(edge) == 3
    )
    for relationship_name, count in edge_types.items():
        edge = EdgeType(relationship_name=relationship_name, number_of_edges=count)
        await _collect_index_points(edge, vector_engine, created_indexes, index_points)

    # Single indexing pass. The original code indexed the node points twice:
    # once right after collecting nodes and again — via the same still-populated
    # index_points dict — after collecting edges. Each collection is now
    # indexed exactly once.
    for index_name, indexable_points in index_points.items():
        table_name, field_name = index_name.split(".")
        await vector_engine.index_data_points(table_name, field_name, indexable_points)

    return None

View file

@ -1,12 +1,20 @@
import asyncio
import cognee
from cognee.api.v1.search import SearchType
import logging
from cognee.modules.pipelines import Task, run_tasks
from cognee.tasks.temporal_awareness import (
build_graph_with_temporal_awareness,
search_graph_with_temporal_awareness,
from cognee.shared.utils import setup_logging
from cognee.tasks.temporal_awareness import build_graph_with_temporal_awareness
from cognee.infrastructure.databases.relational import (
create_db_and_tables as create_relational_db_and_tables,
)
from cognee.tasks.temporal_awareness.index_graphiti_objects import (
index_and_transform_graphiti_nodes_and_edges,
)
from cognee.modules.retrieval.brute_force_triplet_search import brute_force_triplet_search
from cognee.tasks.completion.graph_query_completion import retrieved_edges_to_string
from cognee.infrastructure.llm.prompts import read_query_prompt, render_prompt
from cognee.infrastructure.llm.get_llm_client import get_llm_client
text_list = [
"Kamala Harris is the Attorney General of California. She was previously "
@ -16,11 +24,15 @@ text_list = [
async def main():
await cognee.prune.prune_data()
await cognee.prune.prune_system(metadata=True)
await create_relational_db_and_tables()
for text in text_list:
await cognee.add(text)
tasks = [
Task(build_graph_with_temporal_awareness, text_list=text_list),
Task(
search_graph_with_temporal_awareness, query="Who was the California Attorney General?"
),
]
pipeline = run_tasks(tasks)
@ -28,6 +40,33 @@ async def main():
async for result in pipeline:
print(result)
await index_and_transform_graphiti_nodes_and_edges()
query = "When was Kamala Harris in office?"
triplets = await brute_force_triplet_search(
query=query,
top_k=3,
collections=["graphitinode_content", "graphitinode_name", "graphitinode_summary"],
)
args = {
"question": query,
"context": retrieved_edges_to_string(triplets),
}
user_prompt = render_prompt("graph_context_for_question.txt", args)
system_prompt = read_query_prompt("answer_simple_question_restricted.txt")
llm_client = get_llm_client()
computed_answer = await llm_client.acreate_structured_output(
text_input=user_prompt,
system_prompt=system_prompt,
response_model=str,
)
print(computed_answer)
if __name__ == "__main__":
setup_logging(logging.ERROR)
asyncio.run(main())

View file

@ -0,0 +1,261 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": "# Cognee Graphiti integration demo"
},
{
"cell_type": "markdown",
"metadata": {
"vscode": {
"languageId": "plaintext"
}
},
"source": "First we import the necessary libaries"
},
{
"cell_type": "code",
"metadata": {},
"source": [
"import cognee\n",
"import logging\n",
"import warnings\n",
"from cognee.modules.pipelines import Task, run_tasks\n",
"from cognee.shared.utils import setup_logging\n",
"from cognee.tasks.temporal_awareness import build_graph_with_temporal_awareness\n",
"from cognee.infrastructure.databases.relational import (\n",
" create_db_and_tables as create_relational_db_and_tables,\n",
")\n",
"from cognee.tasks.temporal_awareness.index_graphiti_objects import (\n",
" index_and_transform_graphiti_nodes_and_edges,\n",
")\n",
"from cognee.modules.retrieval.brute_force_triplet_search import brute_force_triplet_search\n",
"from cognee.tasks.completion.graph_query_completion import retrieved_edges_to_string\n",
"from cognee.infrastructure.llm.prompts import read_query_prompt, render_prompt\n",
"from cognee.infrastructure.llm.get_llm_client import get_llm_client"
],
"outputs": [],
"execution_count": null
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Set environment variables"
]
},
{
"cell_type": "code",
"metadata": {
"ExecuteTime": {
"end_time": "2025-01-15T10:43:57.893763Z",
"start_time": "2025-01-15T10:43:57.891332Z"
}
},
"source": [
"import os\n",
"\n",
"# We ignore warnigns for now\n",
"warnings.filterwarnings('ignore')\n",
"\n",
"# API key for cognee\n",
"if \"LLM_API_KEY\" not in os.environ:\n",
" os.environ[\"LLM_API_KEY\"] = \"\"\n",
"\n",
"# API key for graphiti\n",
"if \"OPENAI_API_KEY\" not in os.environ:\n",
" os.environ[\"OPENAI_API_KEY\"] = \"\"\n",
"\n",
"# Graphiti integration is only tested with neo4j + pgvector + postgres for now\n",
"GRAPH_DATABASE_PROVIDER=\"neo4j\"\n",
"GRAPH_DATABASE_URL=\"bolt://localhost:7687\"\n",
"GRAPH_DATABASE_USERNAME=\"neo4j\"\n",
"GRAPH_DATABASE_PASSWORD=\"pleaseletmein\"\n",
"\n",
"os.environ[\"VECTOR_DB_PROVIDER\"] = \"pgvector\"\n",
"\n",
"os.environ[\"DB_PROVIDER\"] = \"postgres\"\n",
"\n",
"os.environ[\"DB_NAME\"] = \"cognee_db\"\n",
"\n",
"os.environ[\"DB_HOST\"]=\"127.0.0.1\"\n",
"os.environ[\"DB_PORT\"]=\"5432\"\n",
"os.environ[\"DB_USERNAME\"]=\"cognee\"\n",
"os.environ[\"DB_PASSWORD\"]=\"cognee\""
],
"outputs": [],
"execution_count": 2
},
{
"cell_type": "markdown",
"metadata": {},
"source": "## Input texts with temporal information"
},
{
"cell_type": "code",
"metadata": {
"ExecuteTime": {
"end_time": "2025-01-15T10:43:57.928664Z",
"start_time": "2025-01-15T10:43:57.927105Z"
}
},
"source": [
"text_list = [\n",
" \"Kamala Harris is the Attorney General of California. She was previously \"\n",
" \"the district attorney for San Francisco.\",\n",
" \"As AG, Harris was in office from January 3, 2011 January 3, 2017\",\n",
"]"
],
"outputs": [],
"execution_count": 3
},
{
"cell_type": "markdown",
"metadata": {},
"source": "## Running graphiti + transforming its graph into cognee's core system (graph transformation + vector embeddings)"
},
{
"cell_type": "code",
"metadata": {
"ExecuteTime": {
"end_time": "2025-01-15T10:44:25.008501Z",
"start_time": "2025-01-15T10:43:57.932240Z"
}
},
"source": [
"# 🔧 Setting Up Logging to Suppress Errors\n",
"setup_logging(logging.ERROR) # Keeping logs clean and focused\n",
"\n",
"# 🧹 Pruning Old Data and Metadata\n",
"await cognee.prune.prune_data() # Removing outdated data\n",
"await cognee.prune.prune_system(metadata=True)\n",
"\n",
"# 🏗️ Creating Relational Database and Tables\n",
"await create_relational_db_and_tables()\n",
"\n",
"# 📚 Adding Text Data to Cognee\n",
"for text in text_list:\n",
" await cognee.add(text)\n",
"\n",
"# 🕰️ Building Temporal-Aware Graphs\n",
"tasks = [\n",
" Task(build_graph_with_temporal_awareness, text_list=text_list),\n",
"]\n",
"\n",
"# 🚀 Running the Task Pipeline\n",
"pipeline = run_tasks(tasks)\n",
"\n",
"# 🌟 Processing Pipeline Results\n",
"async for result in pipeline:\n",
" print(f\"✅ Result Processed: {result}\")\n",
"\n",
"# 🔄 Indexing and Transforming Graph Data\n",
"await index_and_transform_graphiti_nodes_and_edges()"
],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Database deleted successfully.\n",
"Database deleted successfully.\n",
"User d3b51a32-38e1-4fe5-8270-6dc1d6ebfdf0 has registered.\n",
"Pipeline file_load_from_filesystem load step completed in 0.10 seconds\n",
"1 load package(s) were loaded to destination sqlalchemy and into dataset public\n",
"The sqlalchemy destination used postgresql://cognee:***@127.0.0.1:5432/cognee_db location to store data\n",
"Load package 1736937839.7739599 is LOADED and contains no failed jobs\n",
"Pipeline file_load_from_filesystem load step completed in 0.06 seconds\n",
"1 load package(s) were loaded to destination sqlalchemy and into dataset public\n",
"The sqlalchemy destination used postgresql://cognee:***@127.0.0.1:5432/cognee_db location to store data\n",
"Load package 1736937841.8467042 is LOADED and contains no failed jobs\n",
"Graph database initialized.\n",
"Added text: Kamala Harris is the Attorney Gener...\n",
"Added text: As AG, Harris was in office from Ja...\n",
"✅ Result Processed: <graphiti_core.graphiti.Graphiti object at 0x326fe0ce0>\n"
]
}
],
"execution_count": 4
},
{
"metadata": {},
"cell_type": "markdown",
"source": "## Retrieving and generating answer from graphiti graph with cognee retriever"
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2025-01-15T10:44:27.844438Z",
"start_time": "2025-01-15T10:44:25.013325Z"
}
},
"cell_type": "code",
"source": [
"# Step 1: Formulating the Query 🔍\n",
"query = \"When was Kamala Harris in office?\"\n",
"\n",
"# Step 2: Searching for Relevant Triplets 📊\n",
"triplets = await brute_force_triplet_search(\n",
" query=query,\n",
" top_k=3,\n",
" collections=[\"graphitinode_content\", \"graphitinode_name\", \"graphitinode_summary\"],\n",
")\n",
"\n",
"# Step 3: Preparing the Context for the LLM\n",
"context = retrieved_edges_to_string(triplets)\n",
"\n",
"args = {\n",
" \"question\": query,\n",
" \"context\": context\n",
"}\n",
"\n",
"# Step 4: Generating Prompts ✍️\n",
"user_prompt = render_prompt(\"graph_context_for_question.txt\", args)\n",
"system_prompt = read_query_prompt(\"answer_simple_question_restricted.txt\")\n",
"\n",
"# Step 5: Interacting with the LLM 🤖\n",
"llm_client = get_llm_client()\n",
"computed_answer = await llm_client.acreate_structured_output(\n",
" text_input=user_prompt, # Input prompt for the user context\n",
" system_prompt=system_prompt, # System-level instructions for the model\n",
" response_model=str,\n",
")\n",
"\n",
"# Step 6: Displaying the Computed Answer ✨\n",
"print(f\"💡 Answer: {computed_answer}\")"
],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"💡 Answer: Kamala Harris was in office as Attorney General of California from January 3, 2011, to January 3, 2017.\n"
]
}
],
"execution_count": 5
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.6"
}
},
"nbformat": 4,
"nbformat_minor": 2
}