From 83c7640d9c765eee210bb7cfdbb91de4b4385c21 Mon Sep 17 00:00:00 2001
From: Pavlo Paliychuk
Date: Wed, 14 Aug 2024 10:17:12 -0400
Subject: [PATCH] chore: Initial draft of stubs (#2)

* chore: Initial draft of stubs

* updates

* chore: Add comments and mock implementation of the add_episode method

* chore: Add success and error callbacks

* stub updates

---------

Co-authored-by: prestonrasmussen
---
 core/edges.py    |  52 +++++++++---------
 core/graphiti.py | 139 +++++++++++++++++++++++++++++++++++++++++++++--
 core/nodes.py    |  51 +++++++++--------
 core/utils.py    |  45 +++++++++++++++
 4 files changed, 232 insertions(+), 55 deletions(-)
 create mode 100644 core/utils.py

diff --git a/core/edges.py b/core/edges.py
index 686753eb..ca0749fd 100644
--- a/core/edges.py
+++ b/core/edges.py
@@ -1,5 +1,5 @@
 from abc import ABC, abstractmethod
-from pydantic import BaseModel
+from pydantic import BaseModel, Field
 from datetime import datetime
 from neo4j import AsyncDriver
 from uuid import uuid1
@@ -11,36 +11,35 @@ logger = logging.getLogger(__name__)
 
 
 class Edge(BaseModel, ABC):
-    uuid: str | None
+    uuid: str = Field(default_factory=lambda: uuid1().hex)  # the annotation must be a type; Field(...) is the default
     source_node: Node
     target_node: Node
     transaction_from: datetime
 
     @abstractmethod
-    async def save(self, driver: AsyncDriver):
-        ...
+    async def save(self, driver: AsyncDriver): ...
class EpisodicEdge(Edge): async def save(self, driver: AsyncDriver): - if self.uuid is None: - uuid = uuid1() - logger.info(f'Created uuid: {uuid} for episodic edge') - self.uuid = str(uuid) - - result = await driver.execute_query(""" + result = await driver.execute_query( + """ MATCH (episode:Episodic {uuid: $episode_uuid}) MATCH (node:Semantic {uuid: $semantic_uuid}) MERGE (episode)-[r:MENTIONS {uuid: $uuid}]->(node) SET r = {uuid: $uuid, transaction_from: $transaction_from} RETURN r.uuid AS uuid""", - episode_uuid=self.source_node.uuid, semantic_uuid=self.target_node.uuid, - uuid=self.uuid, transaction_from=self.transaction_from) + episode_uuid=self.source_node.uuid, + semantic_uuid=self.target_node.uuid, + uuid=self.uuid, + transaction_from=self.transaction_from, + ) - logger.info(f'Saved edge to neo4j: {self.uuid}') + logger.info(f"Saved edge to neo4j: {self.uuid}") return result + # TODO: Neo4j doesn't support variables for edge types and labels. # Right now we have all edge nodes as type RELATES_TO @@ -62,12 +61,8 @@ class SemanticEdge(Edge): return embedding async def save(self, driver: AsyncDriver): - if self.uuid is None: - uuid = uuid1() - logger.info(f'Created uuid: {uuid} for edge with name: {self.name}') - self.uuid = str(uuid) - - result = await driver.execute_query(""" + result = await driver.execute_query( + """ MATCH (source:Semantic {uuid: $source_uuid}) MATCH (target:Semantic {uuid: $target_uuid}) MERGE (source)-[r:RELATES_TO {uuid: $uuid}]->(target) @@ -75,12 +70,19 @@ class SemanticEdge(Edge): episodes: $episodes, transaction_from: $transaction_from, transaction_to: $transaction_to, valid_from: $valid_from, valid_to: $valid_to} RETURN r.uuid AS uuid""", - source_uuid=self.source_node.uuid, - target_uuid=self.target_node.uuid, uuid=self.uuid, name=self.name, fact=self.fact, - fact_embedding=self.fact_embedding, episodes=self.episodes, - transaction_from=self.transaction_from, transaction_to=self.transaction_to, - valid_from=self.valid_from, 
valid_to=self.valid_to) + source_uuid=self.source_node.uuid, + target_uuid=self.target_node.uuid, + uuid=self.uuid, + name=self.name, + fact=self.fact, + fact_embedding=self.fact_embedding, + episodes=self.episodes, + transaction_from=self.transaction_from, + transaction_to=self.transaction_to, + valid_from=self.valid_from, + valid_to=self.valid_to, + ) - logger.info(f'Saved Node to neo4j: {self.uuid}') + logger.info(f"Saved Node to neo4j: {self.uuid}") return result diff --git a/core/graphiti.py b/core/graphiti.py index 0ad2970f..8ba9060e 100644 --- a/core/graphiti.py +++ b/core/graphiti.py @@ -1,21 +1,148 @@ import asyncio -from typing import Tuple from datetime import datetime import logging - +from typing import Callable, Tuple from neo4j import AsyncGraphDatabase -from openai import OpenAI from core.nodes import SemanticNode, EpisodicNode, Node -from core.edges import SemanticEdge, EpisodicEdge, Edge +from core.edges import SemanticEdge, Edge +from core.utils import bfs, similarity_search, fulltext_search, build_episodic_edges logger = logging.getLogger(__name__) +class LLMConfig: + """Configuration for the language model""" + + def __init__( + self, + api_key: str, + model: str = "gpt-4o", + base_url: str = "https://api.openai.com", + ): + self.base_url = base_url + self.api_key = api_key + self.model = model + + class Graphiti: - def __init__(self, uri, user, password): + def __init__( + self, uri: str, user: str, password: str, llm_config: LLMConfig | None + ): self.driver = AsyncGraphDatabase.driver(uri, auth=(user, password)) self.database = "neo4j" + if llm_config: + self.llm_config = llm_config + else: + self.llm_config = None def close(self): - self.driver.close() \ No newline at end of file + self.driver.close() + + async def retrieve_episodes( + self, last_n: int, sources: list[str] | None = "messages" + ) -> list[EpisodicNode]: + """Retrieve the last n episodic nodes from the graph""" + ... 
+
+    # Utility function, to be removed from this class
+    async def clear_data(self): ...
+
+    async def search(self, query: str, config) -> list[Tuple[SemanticNode, list[SemanticEdge]]]:
+        """Draft search: merge vector and fulltext hits, then expand them with bfs."""
+        (vec_nodes, vec_edges) = await similarity_search(query, embedder)  # TODO(review): `embedder` is undefined here
+        (text_nodes, text_edges) = await fulltext_search(query)
+
+        nodes = vec_nodes + text_nodes  # list.extend() returns None, so concatenate instead
+        edges = vec_edges + text_edges
+
+        results = await bfs(nodes, edges, k=1)
+
+        episode_ids = ["Mode of episode ids"]
+
+        episodes = get_episodes(episode_ids[:episode_count])  # TODO(review): get_episodes / episode_count are undefined
+
+        return [(node, edges)], episodes  # TODO(review): `node` is undefined; a tuple is returned, not the annotated list
+
+    async def get_relevant_schema(
+        self, episode: EpisodicNode, previous_episodes: list[EpisodicNode]
+    ) -> list[Tuple[SemanticNode, list[SemanticEdge]]]: ...
+
+    # Call llm with the specified messages, and return the response
+    # Will be used in conjunction with a prompt library
+    async def generate_llm_response(self, messages: list[any]) -> str: ...
+
+    # Extract new edges from the episode
+    async def extract_new_edges(
+        self,
+        episode: EpisodicNode,
+        new_nodes: list[SemanticNode],
+        relevant_schema: dict[str, any],
+        previous_episodes: list[EpisodicNode],
+    ) -> list[SemanticEdge]: ...
+
+    # Extract new nodes from the episode
+    async def extract_new_nodes(
+        self,
+        episode: EpisodicNode,
+        relevant_schema: dict[str, any],
+        previous_episodes: list[EpisodicNode],
+    ) -> list[SemanticNode]: ...
+
+    # Invalidate edges that are no longer valid
+    async def invalidate_edges(
+        self,
+        episode: EpisodicNode,
+        new_nodes: list[SemanticNode],
+        new_edges: list[SemanticEdge],
+        relevant_schema: dict[str, any],
+        previous_episodes: list[EpisodicNode],
+    ): ...
+
+    async def add_episode(
+        self,
+        name: str,
+        episode_body: str,
+        source_description: str,
+        reference_time: datetime | None = None,
+        episode_type="string",
+        success_callback: Callable | None = None,
+        error_callback: Callable | None = None,
+    ):
+        """Process an episode and update the graph"""
+        episode = None  # bound before the try so the error path cannot raise UnboundLocalError
+        try:
+            nodes: list[Node] = []
+            edges: list[Edge] = []
+            previous_episodes = await self.retrieve_episodes(last_n=3)
+            episode = EpisodicNode()  # TODO(review): required Node fields (name, labels, transaction_from) are not passed
+            await episode.save(self.driver)
+            relevant_schema = await self.get_relevant_schema(episode, previous_episodes)
+            new_nodes = await self.extract_new_nodes(
+                episode, relevant_schema, previous_episodes
+            )
+            nodes.extend(new_nodes)
+            new_edges = await self.extract_new_edges(
+                episode, new_nodes, relevant_schema, previous_episodes
+            )
+            edges.extend(new_edges)
+            episodic_edges = build_episodic_edges(nodes, episode)  # helper takes (semantic_nodes, episode) only
+            edges.extend(episodic_edges)
+
+            invalidated_edges = await self.invalidate_edges(
+                episode, new_nodes, new_edges, relevant_schema, previous_episodes
+            )
+
+            edges.extend(invalidated_edges)
+
+            await asyncio.gather(*[node.save(self.driver) for node in nodes])
+            await asyncio.gather(*[edge.save(self.driver) for edge in edges])
+            for node in nodes:
+                if isinstance(node, SemanticNode):
+                    await node.update_summary(self.driver)  # TODO(review): SemanticNode only defines refresh_summary(driver, llm_client)
+            if success_callback:
+                await success_callback(episode)
+        except Exception as e:
+            if not error_callback:
+                raise
+            await error_callback(episode, e)
diff --git a/core/nodes.py b/core/nodes.py
index 3d06e8b0..864e94bb 100644
--- a/core/nodes.py
+++ b/core/nodes.py
@@ -1,7 +1,9 @@
 from abc import ABC, abstractmethod
 from datetime import datetime
 from uuid import uuid1
-from pydantic import BaseModel
+
+from openai import OpenAI
+from pydantic import BaseModel, Field
 from neo4j import AsyncDriver
 import logging
 
@@ -9,14 +11,13 @@ logger = logging.getLogger(__name__)
 
 
 class Node(BaseModel, ABC):
-    uuid: str | None
+    uuid: str = Field(default_factory=lambda: uuid1().hex)  # the annotation must be a type; Field(...) is the default
     name: str
     labels: list[str]
     transaction_from: datetime
 
     @abstractmethod
-    async def save(self, driver: AsyncDriver):
-        ...
+    async def save(self, driver: AsyncDriver): ...
 
 
 class EpisodicNode(Node):
@@ -27,21 +28,23 @@ class EpisodicNode(Node):
     valid_from: datetime = None  # datetime of when the original document was created
 
     async def save(self, driver: AsyncDriver):
-        if self.uuid is None:
-            uuid = uuid1()
-            logger.info(f'Created uuid: {uuid} for node with name: {self.name}')
-            self.uuid = str(uuid)
-
-        result = await driver.execute_query("""
+        result = await driver.execute_query(
+            """
         MERGE (n:Episodic {uuid: $uuid})
         SET n = {uuid: $uuid, name: $name, source_description: $source_description, content: $content, 
         semantic_edges: $semantic_edges, transaction_from: $transaction_from, valid_from: $valid_from}
         RETURN n.uuid AS uuid""",
-        uuid=self.uuid, name=self.name, source_description=self.source_description,
-        content=self.content, semantic_edges=self.semantic_edges,
-        transaction_from=self.transaction_from, valid_from=self.valid_from, _database='neo4j')
+            uuid=self.uuid,
+            name=self.name,
+            source_description=self.source_description,
+            content=self.content,
+            semantic_edges=self.semantic_edges,
+            transaction_from=self.transaction_from,
+            valid_from=self.valid_from,
+            _database="neo4j",
+        )
 
-        logger.info(f'Saved Node to neo4j: {self.uuid}')
+        logger.info(f"Saved Node to neo4j: {self.uuid}")
         print(self.uuid)
 
         return result
@@ -50,20 +53,20 @@ class EpisodicNode(Node):
 class SemanticNode(Node):
     summary: str  # regional summary of surrounding edges
 
-    async def save(self, driver: AsyncDriver):
-        if self.uuid is None:
-            uuid = uuid1()
-            logger.info(f'Created uuid: {uuid} for node with name: {self.name}')
-            self.uuid = str(uuid)
+    async def refresh_summary(self, driver: AsyncDriver, llm_client: OpenAI): ...
 
-        result = await driver.execute_query("""
+    async def save(self, driver: AsyncDriver):
+        result = await driver.execute_query(
+            """
         MERGE (n:Semantic {uuid: $uuid})
         SET n = {uuid: $uuid, name: $name, summary: $summary, transaction_from: $transaction_from}
         RETURN n.uuid AS uuid""",
-        uuid=self.uuid, name=self.name, summary=self.summary,
-        transaction_from=self.transaction_from)
+            uuid=self.uuid,
+            name=self.name,
+            summary=self.summary,
+            transaction_from=self.transaction_from,
+        )
 
-        logger.info(f'Saved Node to neo4j: {self.uuid}')
+        logger.info(f"Saved Node to neo4j: {self.uuid}")
 
         return result
-
diff --git a/core/utils.py b/core/utils.py
new file mode 100644
index 00000000..74688ae6
--- /dev/null
+++ b/core/utils.py
@@ -0,0 +1,45 @@
+from typing import Tuple
+
+from core.edges import EpisodicEdge, SemanticEdge, Edge
+from core.nodes import SemanticNode, EpisodicNode, Node
+
+
+async def bfs(
+    nodes: list[Node], edges: list[Edge], k: int
+) -> Tuple[list[SemanticNode], list[SemanticEdge]]:
+    """Breadth first search over nodes and edges with desired depth"""
+    ...
+
+
+async def similarity_search(
+    query: str, embedder
+) -> Tuple[list[SemanticNode], list[SemanticEdge]]:
+    """Vector similarity search over embedded facts"""
+    ...
+
+
+async def fulltext_search(
+    query: str,
+) -> Tuple[list[SemanticNode], list[SemanticEdge]]:
+    """Fulltext search over names and summary"""
+    ...
+
+
+def build_episodic_edges(
+    semantic_nodes: list[SemanticNode], episode: EpisodicNode
+) -> list[EpisodicEdge]:
+    """Build the edges that connect an episode to the given semantic nodes.
+
+    One EpisodicEdge is created per entry of ``semantic_nodes``: the episode
+    is the source, the node is the target, and ``transaction_from`` is copied
+    from the episode. An empty ``semantic_nodes`` list yields an empty
+    list.
+    """
+    return [
+        EpisodicEdge(
+            source_node=episode,
+            target_node=node,
+            transaction_from=episode.transaction_from,
+        )
+        for node in semantic_nodes
+    ]