chore: Initial draft of stubs (#2)

* chore: Initial draft of stubs

* updates

* chore: Add comments and mock implementation of the add_episode method

* chore: Add success and error callbacks

* stub updates

---------

Co-authored-by: prestonrasmussen <prasmuss15@gmail.com>
This commit is contained in:
Pavlo Paliychuk 2024-08-14 10:17:12 -04:00 committed by GitHub
parent dab3a62247
commit 83c7640d9c
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 232 additions and 55 deletions

View file

@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
from pydantic import BaseModel
from pydantic import BaseModel, Field
from datetime import datetime
from neo4j import AsyncDriver
from uuid import uuid1
@ -11,36 +11,35 @@ logger = logging.getLogger(__name__)
class Edge(BaseModel, ABC):
uuid: str | None
uuid: Field(default_factory=lambda: uuid1().hex)
source_node: Node
target_node: Node
transaction_from: datetime
@abstractmethod
async def save(self, driver: AsyncDriver):
...
async def save(self, driver: AsyncDriver): ...
class EpisodicEdge(Edge):
async def save(self, driver: AsyncDriver):
if self.uuid is None:
uuid = uuid1()
logger.info(f'Created uuid: {uuid} for episodic edge')
self.uuid = str(uuid)
result = await driver.execute_query("""
result = await driver.execute_query(
"""
MATCH (episode:Episodic {uuid: $episode_uuid})
MATCH (node:Semantic {uuid: $semantic_uuid})
MERGE (episode)-[r:MENTIONS {uuid: $uuid}]->(node)
SET r = {uuid: $uuid, transaction_from: $transaction_from}
RETURN r.uuid AS uuid""",
episode_uuid=self.source_node.uuid, semantic_uuid=self.target_node.uuid,
uuid=self.uuid, transaction_from=self.transaction_from)
episode_uuid=self.source_node.uuid,
semantic_uuid=self.target_node.uuid,
uuid=self.uuid,
transaction_from=self.transaction_from,
)
logger.info(f'Saved edge to neo4j: {self.uuid}')
logger.info(f"Saved edge to neo4j: {self.uuid}")
return result
# TODO: Neo4j doesn't support variables for edge types and labels.
# Right now we have all edge nodes as type RELATES_TO
@ -62,12 +61,8 @@ class SemanticEdge(Edge):
return embedding
async def save(self, driver: AsyncDriver):
if self.uuid is None:
uuid = uuid1()
logger.info(f'Created uuid: {uuid} for edge with name: {self.name}')
self.uuid = str(uuid)
result = await driver.execute_query("""
result = await driver.execute_query(
"""
MATCH (source:Semantic {uuid: $source_uuid})
MATCH (target:Semantic {uuid: $target_uuid})
MERGE (source)-[r:RELATES_TO {uuid: $uuid}]->(target)
@ -75,12 +70,19 @@ class SemanticEdge(Edge):
episodes: $episodes, transaction_from: $transaction_from, transaction_to: $transaction_to,
valid_from: $valid_from, valid_to: $valid_to}
RETURN r.uuid AS uuid""",
source_uuid=self.source_node.uuid,
target_uuid=self.target_node.uuid, uuid=self.uuid, name=self.name, fact=self.fact,
fact_embedding=self.fact_embedding, episodes=self.episodes,
transaction_from=self.transaction_from, transaction_to=self.transaction_to,
valid_from=self.valid_from, valid_to=self.valid_to)
source_uuid=self.source_node.uuid,
target_uuid=self.target_node.uuid,
uuid=self.uuid,
name=self.name,
fact=self.fact,
fact_embedding=self.fact_embedding,
episodes=self.episodes,
transaction_from=self.transaction_from,
transaction_to=self.transaction_to,
valid_from=self.valid_from,
valid_to=self.valid_to,
)
logger.info(f'Saved Node to neo4j: {self.uuid}')
logger.info(f"Saved Node to neo4j: {self.uuid}")
return result

View file

@ -1,21 +1,148 @@
import asyncio
from typing import Tuple
from datetime import datetime
import logging
from typing import Callable, Tuple
from neo4j import AsyncGraphDatabase
from openai import OpenAI
from core.nodes import SemanticNode, EpisodicNode, Node
from core.edges import SemanticEdge, EpisodicEdge, Edge
from core.edges import SemanticEdge, Edge
from core.utils import bfs, similarity_search, fulltext_search, build_episodic_edges
logger = logging.getLogger(__name__)
class LLMConfig:
"""Configuration for the language model"""
def __init__(
self,
api_key: str,
model: str = "gpt-4o",
base_url: str = "https://api.openai.com",
):
self.base_url = base_url
self.api_key = api_key
self.model = model
class Graphiti:
def __init__(self, uri, user, password):
def __init__(
self, uri: str, user: str, password: str, llm_config: LLMConfig | None
):
self.driver = AsyncGraphDatabase.driver(uri, auth=(user, password))
self.database = "neo4j"
if llm_config:
self.llm_config = llm_config
else:
self.llm_config = None
def close(self):
self.driver.close()
self.driver.close()
async def retrieve_episodes(
self, last_n: int, sources: list[str] | None = "messages"
) -> list[EpisodicNode]:
"""Retrieve the last n episodic nodes from the graph"""
...
# Utility function, to be removed from this class
async def clear_data(self): ...
async def search(self, query: str, config) -> (
list)[Tuple[SemanticNode, list[SemanticEdge]]]:
(vec_nodes, vec_edges) = similarity_search(query, embedder)
(text_nodes, text_edges) = fulltext_search(query)
nodes = vec_nodes.extend(text_nodes)
edges = vec_edges.extend(text_edges)
results = bfs(nodes, edges, k=1)
episode_ids = ["Mode of episode ids"]
episodes = get_episodes(episode_ids[:episode_count])
return [(node, edges)], episodes
async def get_relevant_schema(self, episode: EpisodicNode, previous_episodes: list[EpisodicNode]) -> (
list)[Tuple[SemanticNode, list[SemanticEdge]]]:
pass
# Call llm with the specified messages, and return the response
# Will be used in the conjunction with a prompt library
async def generate_llm_response(self, messages: list[any]) -> str: ...
# Extract new edges from the episode
async def extract_new_edges(
self,
episode: EpisodicNode,
new_nodes: list[SemanticNode],
relevant_schema: dict[str, any],
previous_episodes: list[EpisodicNode],
) -> list[SemanticEdge]: ...
# Extract new nodes from the episode
async def extract_new_nodes(
self,
episode: EpisodicNode,
relevant_schema: dict[str, any],
previous_episodes: list[EpisodicNode],
) -> list[SemanticNode]: ...
# Invalidate edges that are no longer valid
async def invalidate_edges(
self,
episode: EpisodicNode,
new_nodes: list[SemanticNode],
new_edges: list[SemanticEdge],
relevant_schema: dict[str, any],
previous_episodes: list[EpisodicNode],
): ...
async def add_episode(
self,
name: str,
episode_body: str,
source_description: str,
reference_time: datetime = None,
episode_type="string",
success_callback: Callable | None = None,
error_callback: Callable | None = None,
):
"""Process an episode and update the graph"""
try:
nodes: list[Node] = []
edges: list[Edge] = []
previous_episodes = await self.retrieve_episodes(last_n=3)
episode = EpisodicNode()
await episode.save(self.driver)
relevant_schema = await self.retrieve_relevant_schema(episode.content)
new_nodes = await self.extract_new_nodes(
episode, relevant_schema, previous_episodes
)
nodes.extend(new_nodes)
new_edges = await self.extract_new_edges(
episode, new_nodes, relevant_schema, previous_episodes
)
edges.extend(new_edges)
episodic_edges = build_episodic_edges(nodes, episode, datetime.now())
edges.extend(episodic_edges)
invalidated_edges = await self.invalidate_edges(
episode, new_nodes, new_edges, relevant_schema, previous_episodes
)
edges.extend(invalidated_edges)
await asyncio.gather(*[node.save(self.driver) for node in nodes])
await asyncio.gather(*[edge.save(self.driver) for edge in edges])
for node in nodes:
if isinstance(node, SemanticNode):
await node.update_summary(self.driver)
if success_callback:
await success_callback(episode)
except Exception as e:
if error_callback:
await error_callback(episode, e)
else:
raise e

View file

@ -1,7 +1,9 @@
from abc import ABC, abstractmethod
from datetime import datetime
from uuid import uuid1
from pydantic import BaseModel
from openai import OpenAI
from pydantic import BaseModel, Field
from neo4j import AsyncDriver
import logging
@ -9,14 +11,13 @@ logger = logging.getLogger(__name__)
class Node(BaseModel, ABC):
uuid: str | None
uuid: Field(default_factory=lambda: uuid1().hex)
name: str
labels: list[str]
transaction_from: datetime
@abstractmethod
async def save(self, driver: AsyncDriver):
...
async def save(self, driver: AsyncDriver): ...
class EpisodicNode(Node):
@ -27,21 +28,23 @@ class EpisodicNode(Node):
valid_from: datetime = None # datetime of when the original document was created
async def save(self, driver: AsyncDriver):
if self.uuid is None:
uuid = uuid1()
logger.info(f'Created uuid: {uuid} for node with name: {self.name}')
self.uuid = str(uuid)
result = await driver.execute_query("""
result = await driver.execute_query(
"""
MERGE (n:Episodic {uuid: $uuid})
SET n = {uuid: $uuid, name: $name, source_description: $source_description, content: $content,
semantic_edges: $semantic_edges, transaction_from: $transaction_from, valid_from: $valid_from}
RETURN n.uuid AS uuid""",
uuid=self.uuid, name=self.name, source_description=self.source_description,
content=self.content, semantic_edges=self.semantic_edges,
transaction_from=self.transaction_from, valid_from=self.valid_from, _database='neo4j')
uuid=self.uuid,
name=self.name,
source_description=self.source_description,
content=self.content,
semantic_edges=self.semantic_edges,
transaction_from=self.transaction_from,
valid_from=self.valid_from,
_database="neo4j",
)
logger.info(f'Saved Node to neo4j: {self.uuid}')
logger.info(f"Saved Node to neo4j: {self.uuid}")
print(self.uuid)
return result
@ -50,20 +53,20 @@ class EpisodicNode(Node):
class SemanticNode(Node):
summary: str # regional summary of surrounding edges
async def save(self, driver: AsyncDriver):
if self.uuid is None:
uuid = uuid1()
logger.info(f'Created uuid: {uuid} for node with name: {self.name}')
self.uuid = str(uuid)
async def refresh_summary(self, driver: AsyncDriver, llm_client: OpenAI): ...
result = await driver.execute_query("""
async def save(self, driver: AsyncDriver):
result = await driver.execute_query(
"""
MERGE (n:Semantic {uuid: $uuid})
SET n = {uuid: $uuid, name: $name, summary: $summary, transaction_from: $transaction_from}
RETURN n.uuid AS uuid""",
uuid=self.uuid, name=self.name, summary=self.summary,
transaction_from=self.transaction_from)
uuid=self.uuid,
name=self.name,
summary=self.summary,
transaction_from=self.transaction_from,
)
logger.info(f'Saved Node to neo4j: {self.uuid}')
logger.info(f"Saved Node to neo4j: {self.uuid}")
return result

45
core/utils.py Normal file
View file

@ -0,0 +1,45 @@
from typing import Tuple
from core.edges import EpisodicEdge, SemanticEdge, Edge
from core.nodes import SemanticNode, EpisodicNode, Node
async def bfs(
nodes: list[Node], edges: list[Edge], k: int
) -> Tuple[list[SemanticNode], list[SemanticEdge]]: ...
# Breadth first search over nodes and edges with desired depth
async def similarity_search(
query: str, embedder
) -> Tuple[list[SemanticNode], list[SemanticEdge]]: ...
# vector similarity search over embedded facts
async def fulltext_search(
query: str,
) -> Tuple[list[SemanticNode], list[SemanticEdge]]: ...
# fulltext search over names and summary
def build_episodic_edges(
semantic_nodes: list[SemanticNode], episode: EpisodicNode
) -> list[EpisodicEdge]:
edges: list[EpisodicEdge] = []
for node in semantic_nodes:
edges.append(
EpisodicEdge(
source_node=episode,
target_node=node,
transaction_from=episode.transaction_from,
)
)
return edges