diff --git a/core/__init__.py b/core/__init__.py new file mode 100644 index 00000000..23540544 --- /dev/null +++ b/core/__init__.py @@ -0,0 +1,3 @@ +from .graphiti import Graphiti + +__all__ = ["Graphiti"] diff --git a/core/edges.py b/core/edges.py index bc0ff449..47cfbfed 100644 --- a/core/edges.py +++ b/core/edges.py @@ -2,7 +2,7 @@ from abc import ABC, abstractmethod from pydantic import BaseModel, Field from datetime import datetime from neo4j import AsyncDriver -from uuid import uuid1 +from uuid import uuid4 import logging from core.nodes import Node @@ -11,7 +11,7 @@ logger = logging.getLogger(__name__) class Edge(BaseModel, ABC): - uuid: str = Field(default_factory=lambda: uuid1().hex) + uuid: str = Field(default_factory=lambda: str(uuid4())) source_node: Node target_node: Node created_at: datetime @@ -22,6 +22,11 @@ class Edge(BaseModel, ABC): class EpisodicEdge(Edge): async def save(self, driver: AsyncDriver): + if self.uuid is None: + uuid = uuid4() + logger.info(f"Created uuid: {uuid} for episodic edge") + self.uuid = str(uuid) + result = await driver.execute_query( """ MATCH (episode:Episodic {uuid: $episode_uuid}) @@ -45,13 +50,25 @@ class EpisodicEdge(Edge): class EntityEdge(Edge): - name: str - fact: str - fact_embedding: list[float] = None - episodes: list[str] = None # list of episode ids that reference these entity edges - expired_at: datetime = None # datetime of when the node was invalidated - valid_at: datetime = None # datetime of when the fact became true - invalid_at: datetime = None # datetime of when the fact stopped being true + name: str = Field(description="name of the edge, relation name") + fact: str = Field( + description="fact representing the edge and nodes that it connects" + ) + fact_embedding: list[float] | None = Field( + default=None, description="embedding of the fact" + ) + episodes: list[str] | None = Field( + default=None, description="list of episode ids that reference these entity edges" + ) + expired_at: datetime | None = Field( + default=None, description="datetime of when the node was invalidated" + ) + valid_at: datetime | None = Field( + default=None, description="datetime of when the fact became true" + ) + invalid_at: datetime | None = Field( + default=None, description="datetime of when the fact stopped being true" + ) def generate_embedding(self, embedder, model="text-embedding-3-large"): text = self.fact.replace("\n", " ") @@ -62,6 +79,7 @@ class EntityEdge(Edge): async def save(self, driver: AsyncDriver): result = await driver.execute_query( + """ MATCH (source:Entity {uuid: $source_uuid}) MATCH (target:Entity {uuid: $target_uuid}) diff --git a/core/graphiti.py b/core/graphiti.py index 00c5a273..d86a8014 100644 --- a/core/graphiti.py +++ b/core/graphiti.py @@ -1,47 +1,124 @@ import asyncio from datetime import datetime import logging -from typing import Callable, Tuple, LiteralString +from typing import Callable, LiteralString, Tuple from neo4j import AsyncGraphDatabase - +from dotenv import load_dotenv +import os from core.nodes import EntityNode, EpisodicNode, Node from core.edges import EntityEdge, Edge -from core.utils import bfs, similarity_search, fulltext_search, build_episodic_edges +from core.utils import ( + build_episodic_edges, + retrieve_relevant_schema, + extract_new_edges, + extract_new_nodes, + clear_data, + retrieve_episodes, +) +from core.llm_client import LLMClient, OpenAIClient, LLMConfig logger = logging.getLogger(__name__) - -class LLMConfig: - """Configuration for the language model""" - - def __init__( - self, 
-        api_key: str,
-        model: str = "gpt-4o",
-        base_url: str = "https://api.openai.com",
-    ):
-        self.base_url = base_url
-        self.api_key = api_key
-        self.model = model
+load_dotenv()


 class Graphiti:
     def __init__(
-        self, uri: str, user: str, password: str, llm_config: LLMConfig | None
+        self, uri: str, user: str, password: str, llm_client: LLMClient | None = None
     ):
         self.driver = AsyncGraphDatabase.driver(uri, auth=(user, password))
         self.database = "neo4j"
-
-        self.build_indices()
-
-        if llm_config:
-            self.llm_config = llm_config
+        if llm_client:
+            self.llm_client = llm_client
         else:
-            self.llm_config = None
+            self.llm_client = OpenAIClient(
+                LLMConfig(
+                    api_key=os.getenv("OPENAI_API_KEY"),
+                    model="gpt-4o",
+                    base_url="https://api.openai.com/v1",
+                )
+            )

     def close(self):
         self.driver.close()

+    async def retrieve_episodes(
+        self, last_n: int, sources: list[str] | None = None
+    ) -> list[EpisodicNode]:
+        """Retrieve the last n episodic nodes from the graph"""
+        return await retrieve_episodes(self.driver, last_n, sources)
+
+    async def retrieve_relevant_schema(self, query: str | None = None) -> dict[str, any]:
+        """Retrieve relevant nodes and edges to a specific query"""
+        return await retrieve_relevant_schema(self.driver, query)
+    ...
+
+    # Invalidate edges that are no longer valid
+    async def invalidate_edges(
+        self,
+        episode: EpisodicNode,
+        new_nodes: list[EntityNode],
+        new_edges: list[EntityEdge],
+        relevant_schema: dict[str, any],
+        previous_episodes: list[EpisodicNode],
+    ): ...
+
+    async def add_episode(
+        self,
+        name: str,
+        episode_body: str,
+        source_description: str,
+        reference_time: datetime | None = None,
+        episode_type: str = "string",
+        success_callback: Callable | None = None,
+        error_callback: Callable | None = None,
+    ):
+        """Process an episode and update the graph"""
+        try:
+            nodes: list[Node] = []
+            edges: list[Edge] = []
+            previous_episodes = await self.retrieve_episodes(last_n=3)
+            episode = EpisodicNode(
+                name=name,
+                labels=[],
+                source="messages",
+                content=episode_body,
+                source_description=source_description,
+                created_at=datetime.now(),
+                valid_at=reference_time,
+            )
+            await episode.save(self.driver)
+            relevant_schema = await self.retrieve_relevant_schema(episode.content)
+            new_nodes = await extract_new_nodes(
+                self.llm_client, episode, relevant_schema, previous_episodes
+            )
+            nodes.extend(new_nodes)
+            new_edges = await extract_new_edges(
+                self.llm_client, episode, new_nodes, relevant_schema, previous_episodes
+            )
+            edges.extend(new_edges)
+            episodic_edges = build_episodic_edges(nodes, episode, datetime.now())
+            edges.extend(episodic_edges)
+
+            # invalidated_edges = await self.invalidate_edges(
+            #     episode, new_nodes, new_edges, relevant_schema, previous_episodes
+            # )
+
+            # edges.extend(invalidated_edges)
+            # Future optimization would be using batch operations to save nodes and edges
+            await asyncio.gather(*[node.save(self.driver) for node in nodes])
+            await asyncio.gather(*[edge.save(self.driver) for edge in edges])
+            # for node in nodes:
+            #     if isinstance(node, EntityNode):
+            #         await node.update_summary(self.driver)
+            if success_callback:
+                await success_callback(episode)
+        except Exception as e:
+            if error_callback:
+                await error_callback(episode, e)
+            else:
+                raise

     async def build_indices(self):
         index_queries: list[LiteralString] = [
             "CREATE INDEX name_entity_index IF NOT EXISTS FOR (n:Entity) ON (n.name)",
@@ -76,18 +153,9 @@ class Graphiti:
             """
         )

-    async def retrieve_episodes(
-        self, last_n: int, sources: list[str] | None = "messages"
-    ) -> list[EpisodicNode]:
- """Retrieve the last n episodic nodes from the graph""" - ... - - # Utility function, to be removed from this class - async def clear_data(self): ... - async def search( self, query: str, config - ) -> (list)[Tuple[EntityNode, list[EntityEdge]]]: + ) -> (list)[tuple[EntityNode, list[EntityEdge]]]: (vec_nodes, vec_edges) = similarity_search(query, embedder) (text_nodes, text_edges) = fulltext_search(query) @@ -102,32 +170,6 @@ class Graphiti: return [(node, edges)], episodes - async def get_relevant_schema( - self, episode: EpisodicNode, previous_episodes: list[EpisodicNode] - ) -> list[Tuple[EntityNode, list[EntityEdge]]]: - pass - - # Call llm with the specified messages, and return the response - # Will be used in the conjunction with a prompt library - async def generate_llm_response(self, messages: list[any]) -> str: ... - - # Extract new edges from the episode - async def extract_new_edges( - self, - episode: EpisodicNode, - new_nodes: list[EntityNode], - relevant_schema: dict[str, any], - previous_episodes: list[EpisodicNode], - ) -> list[EntityEdge]: ... - - # Extract new nodes from the episode - async def extract_new_nodes( - self, - episode: EpisodicNode, - relevant_schema: dict[str, any], - previous_episodes: list[EpisodicNode], - ) -> list[EntityNode]: ... - # Invalidate edges that are no longer valid async def invalidate_edges( self, @@ -137,51 +179,3 @@ class Graphiti: relevant_schema: dict[str, any], previous_episodes: list[EpisodicNode], ): ... - - async def add_episode( - self, - name: str, - episode_body: str, - source_description: str, - reference_time: datetime = None, - episode_type="string", - success_callback: Callable | None = None, - error_callback: Callable | None = None, - ): - """Process an episode and update the graph""" - try: - nodes: list[Node] = [] - edges: list[Edge] = [] - previous_episodes = await self.retrieve_episodes(last_n=3) - episode = EpisodicNode() - await episode.save(self.driver) - relevant_schema = await self.retrieve_relevant_schema(episode.content) - new_nodes = await self.extract_new_nodes( - episode, relevant_schema, previous_episodes - ) - nodes.extend(new_nodes) - new_edges = await self.extract_new_edges( - episode, new_nodes, relevant_schema, previous_episodes - ) - edges.extend(new_edges) - episodic_edges = build_episodic_edges(nodes, episode, datetime.now()) - edges.extend(episodic_edges) - - invalidated_edges = await self.invalidate_edges( - episode, new_nodes, new_edges, relevant_schema, previous_episodes - ) - - edges.extend(invalidated_edges) - - await asyncio.gather(*[node.save(self.driver) for node in nodes]) - await asyncio.gather(*[edge.save(self.driver) for edge in edges]) - for node in nodes: - if isinstance(node, EntityNode): - await node.update_summary(self.driver) - if success_callback: - await success_callback(episode) - except Exception as e: - if error_callback: - await error_callback(episode, e) - else: - raise e diff --git a/core/llm_client/__init__.py b/core/llm_client/__init__.py new file mode 100644 index 00000000..67571a59 --- /dev/null +++ b/core/llm_client/__init__.py @@ -0,0 +1,5 @@ +from .client import LLMClient +from .openai_client import OpenAIClient +from .config import LLMConfig + +__all__ = ["LLMClient", "OpenAIClient", "LLMConfig"] diff --git a/core/llm_client/client.py b/core/llm_client/client.py new file mode 100644 index 00000000..926a188d --- /dev/null +++ b/core/llm_client/client.py @@ -0,0 +1,12 @@ +from abc import ABC, abstractmethod +from .config import LLMConfig + + +class LLMClient(ABC): + 
@abstractmethod + def __init__(self, config: LLMConfig): + pass + + @abstractmethod + async def generate_response(self, messages: list[dict[str, str]]) -> dict[str, any]: + pass diff --git a/core/llm_client/config.py b/core/llm_client/config.py new file mode 100644 index 00000000..af9dad41 --- /dev/null +++ b/core/llm_client/config.py @@ -0,0 +1,33 @@ +class LLMConfig: + """ + Configuration class for the Language Learning Model (LLM). + + This class encapsulates the necessary parameters to interact with an LLM API, + such as OpenAI's GPT models. It stores the API key, model name, and base URL + for making requests to the LLM service. + """ + + def __init__( + self, + api_key: str, + model: str = "gpt-4o", + base_url: str = "https://api.openai.com", + ): + """ + Initialize the LLMConfig with the provided parameters. + + Args: + api_key (str): The authentication key for accessing the LLM API. + This is required for making authorized requests. + + model (str, optional): The specific LLM model to use for generating responses. + Defaults to "gpt-4o", which appears to be a custom model name. + Common values might include "gpt-3.5-turbo" or "gpt-4". + + base_url (str, optional): The base URL of the LLM API service. + Defaults to "https://api.openai.com", which is OpenAI's standard API endpoint. + This can be changed if using a different provider or a custom endpoint. + """ + self.base_url = base_url + self.api_key = api_key + self.model = model diff --git a/core/llm_client/openai_client.py b/core/llm_client/openai_client.py new file mode 100644 index 00000000..27903baf --- /dev/null +++ b/core/llm_client/openai_client.py @@ -0,0 +1,24 @@ +import json +from openai import AsyncOpenAI +from .client import LLMClient +from .config import LLMConfig + + +class OpenAIClient(LLMClient): + def __init__(self, config: LLMConfig): + self.client = AsyncOpenAI(api_key=config.api_key, base_url=config.base_url) + self.model = config.model + + async def generate_response(self, messages: list[dict[str, str]]) -> dict[str, any]: + try: + response = await self.client.chat.completions.create( + model=self.model, + messages=messages, + temperature=0.1, + max_tokens=3000, + response_format={"type": "json_object"}, + ) + return json.loads(response.choices[0].message.content) + except Exception as e: + print(f"Error in generating LLM response: {e}") + raise diff --git a/core/nodes.py b/core/nodes.py index 61fcd7b9..82dbccd9 100644 --- a/core/nodes.py +++ b/core/nodes.py @@ -1,6 +1,7 @@ from abc import ABC, abstractmethod +from pydantic import Field from datetime import datetime -from uuid import uuid1 +from uuid import uuid4 from openai import OpenAI from pydantic import BaseModel, Field @@ -11,9 +12,9 @@ logger = logging.getLogger(__name__) class Node(BaseModel, ABC): - uuid: str = Field(default_factory=lambda: uuid1().hex) + uuid: str = Field(default_factory=lambda: str(uuid4())) name: str - labels: list[str] + labels: list[str] = Field(default_factory=list) created_at: datetime @abstractmethod @@ -21,11 +22,17 @@ class Node(BaseModel, ABC): class EpisodicNode(Node): - source: str # source type - source_description: str # description of the data source - content: str # raw episode data - entity_edges: list[str] # list of entity edge ids referenced in this episode - valid_at: datetime = None # datetime of when the original document was created + source: str = Field(description="source type") + source_description: str = Field(description="description of the data source") + content: str = Field(description="raw episode 
data") + entity_edges: list[str] = Field( + description="list of entity edges referenced in this episode", + default_factory=list, + ) + valid_at: datetime | None = Field( + description="datetime of when the original document was created", + default=None, + ) async def save(self, driver: AsyncDriver): result = await driver.execute_query( @@ -51,7 +58,9 @@ class EpisodicNode(Node): class EntityNode(Node): - summary: str # regional summary of surrounding edges + summary: str = Field(description="regional summary of surrounding edges") + + async def update_summary(self, driver: AsyncDriver): ... async def refresh_summary(self, driver: AsyncDriver, llm_client: OpenAI): ... diff --git a/core/prompts/__init__.py b/core/prompts/__init__.py new file mode 100644 index 00000000..2dd3483c --- /dev/null +++ b/core/prompts/__init__.py @@ -0,0 +1,4 @@ +from .lib import prompt_library +from .models import Message + +__all__ = ["prompt_library", "Message"] diff --git a/core/prompts/extract_edges.py b/core/prompts/extract_edges.py new file mode 100644 index 00000000..9e9c3750 --- /dev/null +++ b/core/prompts/extract_edges.py @@ -0,0 +1,73 @@ +import json +from typing import TypedDict, Protocol + +from .models import Message, PromptVersion, PromptFunction + + +class Prompt(Protocol): + v1: PromptVersion + + +class Versions(TypedDict): + v1: PromptFunction + + +def v1(context: dict[str, any]) -> list[Message]: + return [ + Message( + role="system", + content="You are a helpful assistant that extracts graph edges from provided context.", + ), + Message( + role="user", + content=f""" + Given the following context, extract new semantic edges (relationships) that need to be added to the knowledge graph: + + Current Graph Structure: + {context['relevant_schema']} + + New Nodes: + {json.dumps(context['new_nodes'], indent=2)} + + New Episode: + Content: {context['episode_content']} + Timestamp: {context['episode_timestamp']} + + Previous Episodes: + {json.dumps([ep['content'] for ep in context['previous_episodes']], indent=2)} + + Extract new semantic edges based on the content of the current episode, considering the existing graph structure, new nodes, and context from previous episodes. + + Guidelines: + 1. Create edges only between semantic nodes (not episodic nodes like messages). + 2. Each edge should represent a clear relationship between two semantic nodes. + 3. The relation_type should be a concise, all-caps description of the relationship (e.g., LOVES, IS_FRIENDS_WITH, WORKS_FOR). + 4. Provide a more detailed fact describing the relationship. + 5. If a relationship seems to update an existing one, create a new edge with the updated information. + 6. Consider temporal aspects of relationships when relevant. + 7. Do not create edges involving episodic nodes (like Message 1 or Message 2). + 8. Use existing nodes from the current graph structure when appropriate. + + Respond with a JSON object in the following format: + {{ + "new_edges": [ + {{ + "relation_type": "RELATION_TYPE_IN_CAPS", + "source_node": "Name of the source semantic node", + "target_node": "Name of the target semantic node", + "fact": "Detailed description of the relationship", + "valid_at": "YYYY-MM-DDTHH:MM:SSZ or null if not explicitly mentioned", + "invalid_at": "YYYY-MM-DDTHH:MM:SSZ or null if ongoing or not explicitly mentioned" + }} + ] + }} + + If no new edges need to be added, return an empty list for "new_edges". 
+ """, + ), + ] + + +versions: Versions = { + "v1": v1, +} diff --git a/core/prompts/extract_nodes.py b/core/prompts/extract_nodes.py new file mode 100644 index 00000000..1d171943 --- /dev/null +++ b/core/prompts/extract_nodes.py @@ -0,0 +1,65 @@ +import json +from typing import TypedDict, Protocol + +from .models import Message, PromptVersion, PromptFunction + + +class Prompt(Protocol): + v1: PromptVersion + + +class Versions(TypedDict): + v1: PromptFunction + + +def v1(context: dict[str, any]) -> list[Message]: + return [ + Message( + role="system", + content="You are a helpful assistant that extracts graph nodes from provided context.", + ), + Message( + role="user", + content=f""" + Given the following context, extract new semantic nodes that need to be added to the knowledge graph: + + Existing Nodes: + {json.dumps(context['existing_nodes'], indent=2)} + + Previous Episodes: + {json.dumps([ep['content'] for ep in context['previous_episodes']], indent=2)} + + New Episode: + Content: {context["episode_content"]} + Timestamp: {context['episode_timestamp']} + + Extract new semantic nodes based on the content of the current episode, while considering the existing nodes and context from previous episodes. + + Guidelines: + 1. Only extract new nodes that don't already exist in the graph structure. + 2. Focus on entities, concepts, or actors that are central to the current episode. + 3. Avoid creating nodes for relationships or actions (these will be handled as edges later). + 4. Provide a brief but informative summary for each node. + 5. If a node seems to represent an existing concept but with updated information, don't create a new node. This will be handled by edge updates. + 6. Do not create nodes for episodic content (like Message 1 or Message 2). + + Respond with a JSON object in the following format: + {{ + "new_nodes": [ + {{ + "name": "Unique identifier for the node", + "labels": ["Semantic", "OptionalAdditionalLabel"], + "summary": "Brief summary of the node's role or significance" + }} + ] + }} + + If no new nodes need to be added, return an empty list for "new_nodes". 
+ """, + ), + ] + + +versions: Versions = { + "v1": v1, +} diff --git a/core/prompts/lib.py b/core/prompts/lib.py new file mode 100644 index 00000000..7d47a650 --- /dev/null +++ b/core/prompts/lib.py @@ -0,0 +1,53 @@ +from typing import TypedDict, Protocol +from .models import Message, PromptFunction +from typing import TypedDict, Protocol +from .models import Message, PromptFunction +from .extract_nodes import ( + Prompt as ExtractNodesPrompt, + Versions as ExtractNodesVersions, + versions as extract_nodes_versions, +) + +from .extract_edges import ( + Prompt as ExtractEdgesPrompt, + Versions as ExtractEdgesVersions, + versions as extract_edges_versions, +) + + +class PromptLibrary(Protocol): + extract_nodes: ExtractNodesPrompt + extract_edges: ExtractEdgesPrompt + + +class PromptLibraryImpl(TypedDict): + extract_nodes: ExtractNodesVersions + extract_edges: ExtractEdgesVersions + + +class VersionWrapper: + def __init__(self, func: PromptFunction): + self.func = func + + def __call__(self, context: dict[str, any]) -> list[Message]: + return self.func(context) + + +class PromptTypeWrapper: + def __init__(self, versions: dict[str, PromptFunction]): + for version, func in versions.items(): + setattr(self, version, VersionWrapper(func)) + + +class PromptLibraryWrapper: + def __init__(self, library: PromptLibraryImpl): + for prompt_type, versions in library.items(): + setattr(self, prompt_type, PromptTypeWrapper(versions)) + + +PROMPT_LIBRARY_IMPL: PromptLibraryImpl = { + "extract_nodes": extract_nodes_versions, + "extract_edges": extract_edges_versions, +} + +prompt_library: PromptLibrary = PromptLibraryWrapper(PROMPT_LIBRARY_IMPL) diff --git a/core/prompts/models.py b/core/prompts/models.py new file mode 100644 index 00000000..cc63e597 --- /dev/null +++ b/core/prompts/models.py @@ -0,0 +1,15 @@ +from typing import Callable, Protocol + +from pydantic import BaseModel + + +class Message(BaseModel): + role: str + content: str + + +class PromptVersion(Protocol): + def __call__(self, context: dict[str, any]) -> list[Message]: ... 
+
+
+PromptFunction = Callable[[dict[str, any]], list[Message]]
diff --git a/core/utils/__init__.py b/core/utils/__init__.py
new file mode 100644
index 00000000..581acf75
--- /dev/null
+++ b/core/utils/__init__.py
@@ -0,0 +1,17 @@
+from .maintenance import (
+    extract_new_edges,
+    build_episodic_edges,
+    extract_new_nodes,
+    clear_data,
+    retrieve_relevant_schema,
+    retrieve_episodes,
+)
+
+__all__ = [
+    "extract_new_edges",
+    "build_episodic_edges",
+    "extract_new_nodes",
+    "clear_data",
+    "retrieve_relevant_schema",
+    "retrieve_episodes",
+]
diff --git a/core/utils/maintenance/__init__.py b/core/utils/maintenance/__init__.py
new file mode 100644
index 00000000..79e6536b
--- /dev/null
+++ b/core/utils/maintenance/__init__.py
@@ -0,0 +1,16 @@
+from .edge_operations import extract_new_edges, build_episodic_edges
+from .node_operations import extract_new_nodes
+from .graph_data_operations import (
+    clear_data,
+    retrieve_relevant_schema,
+    retrieve_episodes,
+)
+
+__all__ = [
+    "extract_new_edges",
+    "build_episodic_edges",
+    "extract_new_nodes",
+    "clear_data",
+    "retrieve_relevant_schema",
+    "retrieve_episodes",
+]
diff --git a/core/utils/maintenance/edge_operations.py b/core/utils/maintenance/edge_operations.py
new file mode 100644
index 00000000..925ced9d
--- /dev/null
+++ b/core/utils/maintenance/edge_operations.py
@@ -0,0 +1,129 @@
+import json
+from typing import List
+from datetime import datetime
+
+from core.nodes import EntityNode, EpisodicNode
+from core.edges import EpisodicEdge, EntityEdge
+import logging
+
+from core.prompts import prompt_library
+from core.llm_client import LLMClient
+
+logger = logging.getLogger(__name__)
+
+
+def build_episodic_edges(
+    semantic_nodes: List[EntityNode],
+    episode: EpisodicNode,
+    transaction_from: datetime,
+) -> List[EpisodicEdge]:
+    edges: List[EpisodicEdge] = []
+
+    for node in semantic_nodes:
+        edge = EpisodicEdge(
+            source_node=episode, target_node=node, created_at=transaction_from
+        )
+        edges.append(edge)
+
+    return edges
+
+
+async def extract_new_edges(
+    llm_client: LLMClient,
+    episode: EpisodicNode,
+    new_nodes: list[EntityNode],
+    relevant_schema: dict[str, any],
+    previous_episodes: list[EpisodicNode],
+) -> list[EntityEdge]:
+    # Prepare context for LLM
+    context = {
+        "episode_content": episode.content,
+        "episode_timestamp": (
+            episode.valid_at.isoformat() if episode.valid_at else None
+        ),
+        "relevant_schema": json.dumps(relevant_schema, indent=2),
+        "new_nodes": [
+            {"name": node.name, "summary": node.summary} for node in new_nodes
+        ],
+        "previous_episodes": [
+            {
+                "content": ep.content,
+                "timestamp": ep.valid_at.isoformat() if ep.valid_at else None,
+            }
+            for ep in previous_episodes
+        ],
+    }
+
+    llm_response = await llm_client.generate_response(
+        prompt_library.extract_edges.v1(context)
+    )
+    new_edges_data = llm_response.get("new_edges", [])
+
+    # Convert the extracted data into EntityEdge objects
+    new_edges = []
+    for edge_data in new_edges_data:
+        source_node = next(
+            (node for node in new_nodes if node.name == edge_data["source_node"]),
+            None,
+        )
+        target_node = next(
+            (node for node in new_nodes if node.name == edge_data["target_node"]),
+            None,
+        )
+
+        # If source or target is not in new_nodes, check if it's an existing node
+        if source_node is None and edge_data["source_node"] in relevant_schema["nodes"]:
+            existing_node_data = relevant_schema["nodes"][edge_data["source_node"]]
+            source_node = EntityNode(
+                uuid=existing_node_data["uuid"],
+                name=edge_data["source_node"],
+                labels=[existing_node_data["label"]],
+                summary="",
+                created_at=datetime.now(),
+            )
+        if target_node is None and edge_data["target_node"] in relevant_schema["nodes"]:
+            existing_node_data = relevant_schema["nodes"][edge_data["target_node"]]
+            target_node = EntityNode(
+                uuid=existing_node_data["uuid"],
+                name=edge_data["target_node"],
+                labels=[existing_node_data["label"]],
+                summary="",
+                created_at=datetime.now(),
+            )
+
+        if (
+            source_node
+            and target_node
+            and not (
+                source_node.name.startswith("Message")
+                or target_node.name.startswith("Message")
+            )
+        ):
+            # The prompt asks for "Z"-suffixed timestamps, which fromisoformat() rejects before Python 3.11
+            valid_at = (
+                datetime.fromisoformat(edge_data["valid_at"].replace("Z", "+00:00"))
+                if edge_data["valid_at"]
+                else episode.valid_at or datetime.now()
+            )
+            invalid_at = (
+                datetime.fromisoformat(edge_data["invalid_at"].replace("Z", "+00:00"))
+                if edge_data["invalid_at"]
+                else None
+            )
+
+            new_edge = EntityEdge(
+                source_node=source_node,
+                target_node=target_node,
+                name=edge_data["relation_type"],
+                fact=edge_data["fact"],
+                episodes=[episode.uuid],
+                created_at=datetime.now(),
+                valid_at=valid_at,
+                invalid_at=invalid_at,
+            )
+            new_edges.append(new_edge)
+            logger.info(
+                f"Created new edge: {new_edge.name} from {source_node.name} (UUID: {source_node.uuid}) to {target_node.name} (UUID: {target_node.uuid})"
+            )
+
+    return new_edges
diff --git a/core/utils/maintenance/graph_data_operations.py b/core/utils/maintenance/graph_data_operations.py
new file mode 100644
index 00000000..5f4333c5
--- /dev/null
+++ b/core/utils/maintenance/graph_data_operations.py
@@ -0,0 +1,101 @@
+from datetime import datetime, timezone
+
+from core.nodes import EpisodicNode
+from neo4j import AsyncDriver
+import logging
+
+
+logger = logging.getLogger(__name__)
+
+
+async def clear_data(driver: AsyncDriver):
+    async with driver.session() as session:
+
+        async def delete_all(tx):
+            await tx.run("MATCH (n) DETACH DELETE n")
+
+        await session.execute_write(delete_all)
+
+
+async def retrieve_relevant_schema(
+    driver: AsyncDriver, query: str | None = None
+) -> dict[str, any]:
+    async with driver.session() as session:
+        summary_query = """
+        MATCH (n)
+        OPTIONAL MATCH (n)-[r]->(m)
+        RETURN DISTINCT labels(n) AS node_labels, n.uuid AS node_uuid, n.name AS node_name,
+               type(r) AS relationship_type, r.name AS relationship_name, m.name AS related_node_name
+        """
+        result = await session.run(summary_query)
+        records = [record async for record in result]
+
+        schema = {"nodes": {}, "relationships": []}
+
+        for record in records:
+            node_label = record["node_labels"][0]  # Assuming one label per node
+            node_uuid = record["node_uuid"]
+            node_name = record["node_name"]
+            rel_type = record["relationship_type"]
+            rel_name = record["relationship_name"]
+            related_node = record["related_node_name"]
+
+            if node_name not in schema["nodes"]:
+                schema["nodes"][node_name] = {
+                    "uuid": node_uuid,
+                    "label": node_label,
+                    "relationships": [],
+                }
+
+            if rel_type and related_node:
+                schema["nodes"][node_name]["relationships"].append(
+                    {"type": rel_type, "name": rel_name, "target": related_node}
+                )
+                schema["relationships"].append(
+                    {
+                        "source": node_name,
+                        "type": rel_type,
+                        "name": rel_name,
+                        "target": related_node,
+                    }
+                )
+
+        return schema
+
+
+async def retrieve_episodes(
+    driver: AsyncDriver, last_n: int, sources: list[str] | None = None
+) -> list[EpisodicNode]:
+    """Retrieve the last n episodic nodes from the graph"""
+    async with driver.session() as session:
+        # Match the `Episodic` label and the property names written by EpisodicNode.save
+        query = """
+        MATCH (e:Episodic)
+        RETURN e.uuid AS uuid, e.name AS name, e.source AS source,
+               e.source_description AS source_description, e.content AS content,
+               e.created_at AS created_at, e.valid_at AS valid_at
+        ORDER BY e.created_at DESC
+        LIMIT $num_episodes
+        """
+        result = await session.run(query, num_episodes=last_n)
+        episodes = [
+            EpisodicNode(
+                uuid=record["uuid"],
+                name=record["name"],
+                source=record["source"],
+                source_description=record["source_description"],
+                content=record["content"],
+                created_at=datetime.fromtimestamp(
+                    record["created_at"].to_native().timestamp(), timezone.utc
+                ),
+                valid_at=(
+                    datetime.fromtimestamp(
+                        record["valid_at"].to_native().timestamp(), timezone.utc
+                    )
+                    if record["valid_at"] is not None
+                    else None
+                ),
+            )
+            async for record in result
+        ]
+        return list(reversed(episodes))  # Return in chronological order
diff --git a/core/utils/maintenance/node_operations.py b/core/utils/maintenance/node_operations.py
new file mode 100644
index 00000000..4c184f90
--- /dev/null
+++ b/core/utils/maintenance/node_operations.py
@@ -0,0 +1,63 @@
+from datetime import datetime
+
+from core.nodes import EntityNode, EpisodicNode
+import logging
+from core.llm_client import LLMClient
+
+from core.prompts import prompt_library
+
+logger = logging.getLogger(__name__)
+
+
+async def extract_new_nodes(
+    llm_client: LLMClient,
+    episode: EpisodicNode,
+    relevant_schema: dict[str, any],
+    previous_episodes: list[EpisodicNode],
+) -> list[EntityNode]:
+    # Prepare context for LLM
+    existing_nodes = [
+        {"name": node_name, "label": node_info["label"], "uuid": node_info["uuid"]}
+        for node_name, node_info in relevant_schema["nodes"].items()
+    ]
+
+    context = {
+        "episode_content": episode.content,
+        "episode_timestamp": (
+            episode.valid_at.isoformat() if episode.valid_at else None
+        ),
+        "existing_nodes": existing_nodes,
+        "previous_episodes": [
+            {
+                "content": ep.content,
+                "timestamp": ep.valid_at.isoformat() if ep.valid_at else None,
+            }
+            for ep in previous_episodes
+        ],
+    }
+
+    llm_response = await llm_client.generate_response(
+        prompt_library.extract_nodes.v1(context)
+    )
+    new_nodes_data = llm_response.get("new_nodes", [])
+    logger.info(f"Extracted new nodes: {new_nodes_data}")
+    # Convert the extracted data into EntityNode objects
+    new_nodes = []
+    for node_data in new_nodes_data:
+        # Check if the node already exists
+        if not any(
+            existing_node["name"] == node_data["name"]
+            for existing_node in existing_nodes
+        ):
+            new_node = EntityNode(
+                name=node_data["name"],
+                labels=node_data["labels"],
+                summary=node_data["summary"],
+                created_at=datetime.now(),
+            )
+            new_nodes.append(new_node)
+            logger.info(f"Created new node: {new_node.name} (UUID: {new_node.uuid})")
+        else:
+            logger.info(f"Node {node_data['name']} already exists, skipping creation.")
+
+    return new_nodes
diff --git a/core/utils/maintenance/utils.py b/core/utils/maintenance/utils.py
new file mode 100644
index 00000000..e69de29b
diff --git a/runner.py b/runner.py
new file mode 100644
index 00000000..db25b65b
--- /dev/null
+++ b/runner.py
@@ -0,0 +1,66 @@
+from core import Graphiti
+from core.utils.maintenance.graph_data_operations import clear_data
+from dotenv import load_dotenv
+import os
+import asyncio
+import logging
+import sys
+
+load_dotenv()
+
+neo4j_uri = os.environ.get("NEO4J_URI") or "bolt://localhost:7687"
+neo4j_user = os.environ.get("NEO4J_USER") or "neo4j"
+neo4j_password = os.environ.get("NEO4J_PASSWORD") or "password"
+
+
+def setup_logging():
+    # Create a logger
+    logger = logging.getLogger()
+    logger.setLevel(logging.INFO)  # Set the logging level to INFO
+
+    # Create console handler and set level to INFO
+    console_handler = logging.StreamHandler(sys.stdout)
+    console_handler.setLevel(logging.INFO)
+
+    # Create formatter
+    formatter = logging.Formatter(
+        "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
+    )
+
+    # Add formatter to console handler
+    console_handler.setFormatter(formatter)
+
+    # Add console handler to logger
+    logger.addHandler(console_handler)
+
+    return logger
+
+
+async def main():
+    setup_logging()
+    client = Graphiti(neo4j_uri, neo4j_user, neo4j_password)
+    await clear_data(client.driver)
+    # await client.build_indices()
+    await client.add_episode(
+        name="Message 1",
+        episode_body="Paul: I love apples",
+        source_description="WhatsApp Message",
+    )
+    await client.add_episode(
+        name="Message 2",
+        episode_body="Paul: I love bananas",
+        source_description="WhatsApp Message",
+    )
+    await client.add_episode(
+        name="Message 3",
+        episode_body="Assistant: The best type of apples available are Fuji apples",
+        source_description="WhatsApp Message",
+    )
+    await client.add_episode(
+        name="Message 4",
+        episode_body="Paul: Oh, I actually hate those",
+        source_description="WhatsApp Message",
+    )
+
+
+asyncio.run(main())
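
Usage note (not part of the patch): besides runner.py above, any provider can back Graphiti by subclassing the LLMClient ABC added in core/llm_client/client.py. A minimal sketch follows; the FakeLLMClient name and its canned empty response are illustrative assumptions, not code from this change:

from core.llm_client import LLMClient, LLMConfig


class FakeLLMClient(LLMClient):
    """Offline stand-in that satisfies the LLMClient interface, e.g. for tests."""

    def __init__(self, config: LLMConfig):
        self.config = config

    async def generate_response(self, messages: list[dict[str, str]]) -> dict[str, any]:
        # Always report "nothing new", so add_episode completes without network access
        return {"new_nodes": [], "new_edges": []}


# client = Graphiti(uri, user, password, llm_client=FakeLLMClient(LLMConfig(api_key="unused")))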