chore: Initial draft of stubs (#2)
* chore: Initial draft of stubs
* updates
* chore: Add comments and mock implementation of the add_episode method
* chore: Add success and error callbacks
* stub updates

---------

Co-authored-by: prestonrasmussen <prasmuss15@gmail.com>
This commit is contained in:
parent
dab3a62247
commit
83c7640d9c
4 changed files with 232 additions and 55 deletions
|
|
@ -1,5 +1,5 @@
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel, Field
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from neo4j import AsyncDriver
|
from neo4j import AsyncDriver
|
||||||
from uuid import uuid1
|
from uuid import uuid1
|
||||||
|
|
@ -11,36 +11,35 @@ logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class Edge(BaseModel, ABC):
|
class Edge(BaseModel, ABC):
|
||||||
uuid: str | None
|
uuid: Field(default_factory=lambda: uuid1().hex)
|
||||||
source_node: Node
|
source_node: Node
|
||||||
target_node: Node
|
target_node: Node
|
||||||
transaction_from: datetime
|
transaction_from: datetime
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
async def save(self, driver: AsyncDriver):
|
async def save(self, driver: AsyncDriver): ...
|
||||||
...
|
|
||||||
|
|
||||||
|
|
||||||
class EpisodicEdge(Edge):
|
class EpisodicEdge(Edge):
|
||||||
async def save(self, driver: AsyncDriver):
|
async def save(self, driver: AsyncDriver):
|
||||||
if self.uuid is None:
|
result = await driver.execute_query(
|
||||||
uuid = uuid1()
|
"""
|
||||||
logger.info(f'Created uuid: {uuid} for episodic edge')
|
|
||||||
self.uuid = str(uuid)
|
|
||||||
|
|
||||||
result = await driver.execute_query("""
|
|
||||||
MATCH (episode:Episodic {uuid: $episode_uuid})
|
MATCH (episode:Episodic {uuid: $episode_uuid})
|
||||||
MATCH (node:Semantic {uuid: $semantic_uuid})
|
MATCH (node:Semantic {uuid: $semantic_uuid})
|
||||||
MERGE (episode)-[r:MENTIONS {uuid: $uuid}]->(node)
|
MERGE (episode)-[r:MENTIONS {uuid: $uuid}]->(node)
|
||||||
SET r = {uuid: $uuid, transaction_from: $transaction_from}
|
SET r = {uuid: $uuid, transaction_from: $transaction_from}
|
||||||
RETURN r.uuid AS uuid""",
|
RETURN r.uuid AS uuid""",
|
||||||
episode_uuid=self.source_node.uuid, semantic_uuid=self.target_node.uuid,
|
episode_uuid=self.source_node.uuid,
|
||||||
uuid=self.uuid, transaction_from=self.transaction_from)
|
semantic_uuid=self.target_node.uuid,
|
||||||
|
uuid=self.uuid,
|
||||||
|
transaction_from=self.transaction_from,
|
||||||
|
)
|
||||||
|
|
||||||
logger.info(f'Saved edge to neo4j: {self.uuid}')
|
logger.info(f"Saved edge to neo4j: {self.uuid}")
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
# TODO: Neo4j doesn't support variables for edge types and labels.
|
# TODO: Neo4j doesn't support variables for edge types and labels.
|
||||||
# Right now we have all edge nodes as type RELATES_TO
|
# Right now we have all edge nodes as type RELATES_TO
|
||||||
|
|
||||||
|
|
@ -62,12 +61,8 @@ class SemanticEdge(Edge):
|
||||||
return embedding
|
return embedding
|
||||||
|
|
||||||
async def save(self, driver: AsyncDriver):
|
async def save(self, driver: AsyncDriver):
|
||||||
if self.uuid is None:
|
result = await driver.execute_query(
|
||||||
uuid = uuid1()
|
"""
|
||||||
logger.info(f'Created uuid: {uuid} for edge with name: {self.name}')
|
|
||||||
self.uuid = str(uuid)
|
|
||||||
|
|
||||||
result = await driver.execute_query("""
|
|
||||||
MATCH (source:Semantic {uuid: $source_uuid})
|
MATCH (source:Semantic {uuid: $source_uuid})
|
||||||
MATCH (target:Semantic {uuid: $target_uuid})
|
MATCH (target:Semantic {uuid: $target_uuid})
|
||||||
MERGE (source)-[r:RELATES_TO {uuid: $uuid}]->(target)
|
MERGE (source)-[r:RELATES_TO {uuid: $uuid}]->(target)
|
||||||
|
|
@ -75,12 +70,19 @@ class SemanticEdge(Edge):
|
||||||
episodes: $episodes, transaction_from: $transaction_from, transaction_to: $transaction_to,
|
episodes: $episodes, transaction_from: $transaction_from, transaction_to: $transaction_to,
|
||||||
valid_from: $valid_from, valid_to: $valid_to}
|
valid_from: $valid_from, valid_to: $valid_to}
|
||||||
RETURN r.uuid AS uuid""",
|
RETURN r.uuid AS uuid""",
|
||||||
source_uuid=self.source_node.uuid,
|
source_uuid=self.source_node.uuid,
|
||||||
target_uuid=self.target_node.uuid, uuid=self.uuid, name=self.name, fact=self.fact,
|
target_uuid=self.target_node.uuid,
|
||||||
fact_embedding=self.fact_embedding, episodes=self.episodes,
|
uuid=self.uuid,
|
||||||
transaction_from=self.transaction_from, transaction_to=self.transaction_to,
|
name=self.name,
|
||||||
valid_from=self.valid_from, valid_to=self.valid_to)
|
fact=self.fact,
|
||||||
|
fact_embedding=self.fact_embedding,
|
||||||
|
episodes=self.episodes,
|
||||||
|
transaction_from=self.transaction_from,
|
||||||
|
transaction_to=self.transaction_to,
|
||||||
|
valid_from=self.valid_from,
|
||||||
|
valid_to=self.valid_to,
|
||||||
|
)
|
||||||
|
|
||||||
logger.info(f'Saved Node to neo4j: {self.uuid}')
|
logger.info(f"Saved Node to neo4j: {self.uuid}")
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
|
||||||
137
core/graphiti.py
137
core/graphiti.py
|
|
@ -1,21 +1,148 @@
|
||||||
import asyncio
|
import asyncio
|
||||||
from typing import Tuple
|
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
import logging
|
import logging
|
||||||
|
from typing import Callable, Tuple
|
||||||
from neo4j import AsyncGraphDatabase
|
from neo4j import AsyncGraphDatabase
|
||||||
from openai import OpenAI
|
|
||||||
|
|
||||||
from core.nodes import SemanticNode, EpisodicNode, Node
|
from core.nodes import SemanticNode, EpisodicNode, Node
|
||||||
from core.edges import SemanticEdge, EpisodicEdge, Edge
|
from core.edges import SemanticEdge, Edge
|
||||||
|
from core.utils import bfs, similarity_search, fulltext_search, build_episodic_edges
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class LLMConfig:
|
||||||
|
"""Configuration for the language model"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
api_key: str,
|
||||||
|
model: str = "gpt-4o",
|
||||||
|
base_url: str = "https://api.openai.com",
|
||||||
|
):
|
||||||
|
self.base_url = base_url
|
||||||
|
self.api_key = api_key
|
||||||
|
self.model = model
|
||||||
|
|
||||||
|
|
||||||
class Graphiti:
|
class Graphiti:
|
||||||
def __init__(self, uri, user, password):
|
def __init__(
|
||||||
|
self, uri: str, user: str, password: str, llm_config: LLMConfig | None
|
||||||
|
):
|
||||||
self.driver = AsyncGraphDatabase.driver(uri, auth=(user, password))
|
self.driver = AsyncGraphDatabase.driver(uri, auth=(user, password))
|
||||||
self.database = "neo4j"
|
self.database = "neo4j"
|
||||||
|
if llm_config:
|
||||||
|
self.llm_config = llm_config
|
||||||
|
else:
|
||||||
|
self.llm_config = None
|
||||||
|
|
||||||
def close(self):
|
def close(self):
|
||||||
self.driver.close()
|
self.driver.close()
|
||||||
|
|
||||||
|
async def retrieve_episodes(
|
||||||
|
self, last_n: int, sources: list[str] | None = "messages"
|
||||||
|
) -> list[EpisodicNode]:
|
||||||
|
"""Retrieve the last n episodic nodes from the graph"""
|
||||||
|
...
|
||||||
|
|
||||||
|
# Utility function, to be removed from this class
|
||||||
|
async def clear_data(self): ...
|
||||||
|
|
||||||
|
async def search(self, query: str, config) -> (
|
||||||
|
list)[Tuple[SemanticNode, list[SemanticEdge]]]:
|
||||||
|
(vec_nodes, vec_edges) = similarity_search(query, embedder)
|
||||||
|
(text_nodes, text_edges) = fulltext_search(query)
|
||||||
|
|
||||||
|
nodes = vec_nodes.extend(text_nodes)
|
||||||
|
edges = vec_edges.extend(text_edges)
|
||||||
|
|
||||||
|
results = bfs(nodes, edges, k=1)
|
||||||
|
|
||||||
|
episode_ids = ["Mode of episode ids"]
|
||||||
|
|
||||||
|
episodes = get_episodes(episode_ids[:episode_count])
|
||||||
|
|
||||||
|
return [(node, edges)], episodes
|
||||||
|
|
||||||
|
async def get_relevant_schema(self, episode: EpisodicNode, previous_episodes: list[EpisodicNode]) -> (
|
||||||
|
list)[Tuple[SemanticNode, list[SemanticEdge]]]:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Call llm with the specified messages, and return the response
|
||||||
|
# Will be used in the conjunction with a prompt library
|
||||||
|
async def generate_llm_response(self, messages: list[any]) -> str: ...
|
||||||
|
|
||||||
|
# Extract new edges from the episode
|
||||||
|
async def extract_new_edges(
|
||||||
|
self,
|
||||||
|
episode: EpisodicNode,
|
||||||
|
new_nodes: list[SemanticNode],
|
||||||
|
relevant_schema: dict[str, any],
|
||||||
|
previous_episodes: list[EpisodicNode],
|
||||||
|
) -> list[SemanticEdge]: ...
|
||||||
|
|
||||||
|
# Extract new nodes from the episode
|
||||||
|
async def extract_new_nodes(
|
||||||
|
self,
|
||||||
|
episode: EpisodicNode,
|
||||||
|
relevant_schema: dict[str, any],
|
||||||
|
previous_episodes: list[EpisodicNode],
|
||||||
|
) -> list[SemanticNode]: ...
|
||||||
|
|
||||||
|
# Invalidate edges that are no longer valid
|
||||||
|
async def invalidate_edges(
|
||||||
|
self,
|
||||||
|
episode: EpisodicNode,
|
||||||
|
new_nodes: list[SemanticNode],
|
||||||
|
new_edges: list[SemanticEdge],
|
||||||
|
relevant_schema: dict[str, any],
|
||||||
|
previous_episodes: list[EpisodicNode],
|
||||||
|
): ...
|
||||||
|
|
||||||
|
async def add_episode(
|
||||||
|
self,
|
||||||
|
name: str,
|
||||||
|
episode_body: str,
|
||||||
|
source_description: str,
|
||||||
|
reference_time: datetime = None,
|
||||||
|
episode_type="string",
|
||||||
|
success_callback: Callable | None = None,
|
||||||
|
error_callback: Callable | None = None,
|
||||||
|
):
|
||||||
|
"""Process an episode and update the graph"""
|
||||||
|
try:
|
||||||
|
nodes: list[Node] = []
|
||||||
|
edges: list[Edge] = []
|
||||||
|
previous_episodes = await self.retrieve_episodes(last_n=3)
|
||||||
|
episode = EpisodicNode()
|
||||||
|
await episode.save(self.driver)
|
||||||
|
relevant_schema = await self.retrieve_relevant_schema(episode.content)
|
||||||
|
new_nodes = await self.extract_new_nodes(
|
||||||
|
episode, relevant_schema, previous_episodes
|
||||||
|
)
|
||||||
|
nodes.extend(new_nodes)
|
||||||
|
new_edges = await self.extract_new_edges(
|
||||||
|
episode, new_nodes, relevant_schema, previous_episodes
|
||||||
|
)
|
||||||
|
edges.extend(new_edges)
|
||||||
|
episodic_edges = build_episodic_edges(nodes, episode, datetime.now())
|
||||||
|
edges.extend(episodic_edges)
|
||||||
|
|
||||||
|
invalidated_edges = await self.invalidate_edges(
|
||||||
|
episode, new_nodes, new_edges, relevant_schema, previous_episodes
|
||||||
|
)
|
||||||
|
|
||||||
|
edges.extend(invalidated_edges)
|
||||||
|
|
||||||
|
await asyncio.gather(*[node.save(self.driver) for node in nodes])
|
||||||
|
await asyncio.gather(*[edge.save(self.driver) for edge in edges])
|
||||||
|
for node in nodes:
|
||||||
|
if isinstance(node, SemanticNode):
|
||||||
|
await node.update_summary(self.driver)
|
||||||
|
if success_callback:
|
||||||
|
await success_callback(episode)
|
||||||
|
except Exception as e:
|
||||||
|
if error_callback:
|
||||||
|
await error_callback(episode, e)
|
||||||
|
else:
|
||||||
|
raise e
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,9 @@
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from uuid import uuid1
|
from uuid import uuid1
|
||||||
from pydantic import BaseModel
|
|
||||||
|
from openai import OpenAI
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
from neo4j import AsyncDriver
|
from neo4j import AsyncDriver
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
|
|
@ -9,14 +11,13 @@ logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class Node(BaseModel, ABC):
|
class Node(BaseModel, ABC):
|
||||||
uuid: str | None
|
uuid: Field(default_factory=lambda: uuid1().hex)
|
||||||
name: str
|
name: str
|
||||||
labels: list[str]
|
labels: list[str]
|
||||||
transaction_from: datetime
|
transaction_from: datetime
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
async def save(self, driver: AsyncDriver):
|
async def save(self, driver: AsyncDriver): ...
|
||||||
...
|
|
||||||
|
|
||||||
|
|
||||||
class EpisodicNode(Node):
|
class EpisodicNode(Node):
|
||||||
|
|
@ -27,21 +28,23 @@ class EpisodicNode(Node):
|
||||||
valid_from: datetime = None # datetime of when the original document was created
|
valid_from: datetime = None # datetime of when the original document was created
|
||||||
|
|
||||||
async def save(self, driver: AsyncDriver):
|
async def save(self, driver: AsyncDriver):
|
||||||
if self.uuid is None:
|
result = await driver.execute_query(
|
||||||
uuid = uuid1()
|
"""
|
||||||
logger.info(f'Created uuid: {uuid} for node with name: {self.name}')
|
|
||||||
self.uuid = str(uuid)
|
|
||||||
|
|
||||||
result = await driver.execute_query("""
|
|
||||||
MERGE (n:Episodic {uuid: $uuid})
|
MERGE (n:Episodic {uuid: $uuid})
|
||||||
SET n = {uuid: $uuid, name: $name, source_description: $source_description, content: $content,
|
SET n = {uuid: $uuid, name: $name, source_description: $source_description, content: $content,
|
||||||
semantic_edges: $semantic_edges, transaction_from: $transaction_from, valid_from: $valid_from}
|
semantic_edges: $semantic_edges, transaction_from: $transaction_from, valid_from: $valid_from}
|
||||||
RETURN n.uuid AS uuid""",
|
RETURN n.uuid AS uuid""",
|
||||||
uuid=self.uuid, name=self.name, source_description=self.source_description,
|
uuid=self.uuid,
|
||||||
content=self.content, semantic_edges=self.semantic_edges,
|
name=self.name,
|
||||||
transaction_from=self.transaction_from, valid_from=self.valid_from, _database='neo4j')
|
source_description=self.source_description,
|
||||||
|
content=self.content,
|
||||||
|
semantic_edges=self.semantic_edges,
|
||||||
|
transaction_from=self.transaction_from,
|
||||||
|
valid_from=self.valid_from,
|
||||||
|
_database="neo4j",
|
||||||
|
)
|
||||||
|
|
||||||
logger.info(f'Saved Node to neo4j: {self.uuid}')
|
logger.info(f"Saved Node to neo4j: {self.uuid}")
|
||||||
print(self.uuid)
|
print(self.uuid)
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
@ -50,20 +53,20 @@ class EpisodicNode(Node):
|
||||||
class SemanticNode(Node):
|
class SemanticNode(Node):
|
||||||
summary: str # regional summary of surrounding edges
|
summary: str # regional summary of surrounding edges
|
||||||
|
|
||||||
async def save(self, driver: AsyncDriver):
|
async def refresh_summary(self, driver: AsyncDriver, llm_client: OpenAI): ...
|
||||||
if self.uuid is None:
|
|
||||||
uuid = uuid1()
|
|
||||||
logger.info(f'Created uuid: {uuid} for node with name: {self.name}')
|
|
||||||
self.uuid = str(uuid)
|
|
||||||
|
|
||||||
result = await driver.execute_query("""
|
async def save(self, driver: AsyncDriver):
|
||||||
|
result = await driver.execute_query(
|
||||||
|
"""
|
||||||
MERGE (n:Semantic {uuid: $uuid})
|
MERGE (n:Semantic {uuid: $uuid})
|
||||||
SET n = {uuid: $uuid, name: $name, summary: $summary, transaction_from: $transaction_from}
|
SET n = {uuid: $uuid, name: $name, summary: $summary, transaction_from: $transaction_from}
|
||||||
RETURN n.uuid AS uuid""",
|
RETURN n.uuid AS uuid""",
|
||||||
uuid=self.uuid, name=self.name, summary=self.summary,
|
uuid=self.uuid,
|
||||||
transaction_from=self.transaction_from)
|
name=self.name,
|
||||||
|
summary=self.summary,
|
||||||
|
transaction_from=self.transaction_from,
|
||||||
|
)
|
||||||
|
|
||||||
logger.info(f'Saved Node to neo4j: {self.uuid}')
|
logger.info(f"Saved Node to neo4j: {self.uuid}")
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
|
||||||
45
core/utils.py
Normal file
45
core/utils.py
Normal file
|
|
@ -0,0 +1,45 @@
|
||||||
|
from typing import Tuple
|
||||||
|
|
||||||
|
from core.edges import EpisodicEdge, SemanticEdge, Edge
|
||||||
|
from core.nodes import SemanticNode, EpisodicNode, Node
|
||||||
|
|
||||||
|
|
||||||
|
async def bfs(
|
||||||
|
nodes: list[Node], edges: list[Edge], k: int
|
||||||
|
) -> Tuple[list[SemanticNode], list[SemanticEdge]]: ...
|
||||||
|
|
||||||
|
|
||||||
|
# Breadth first search over nodes and edges with desired depth
|
||||||
|
|
||||||
|
|
||||||
|
async def similarity_search(
|
||||||
|
query: str, embedder
|
||||||
|
) -> Tuple[list[SemanticNode], list[SemanticEdge]]: ...
|
||||||
|
|
||||||
|
|
||||||
|
# vector similarity search over embedded facts
|
||||||
|
|
||||||
|
|
||||||
|
async def fulltext_search(
|
||||||
|
query: str,
|
||||||
|
) -> Tuple[list[SemanticNode], list[SemanticEdge]]: ...
|
||||||
|
|
||||||
|
|
||||||
|
# fulltext search over names and summary
|
||||||
|
|
||||||
|
|
||||||
|
def build_episodic_edges(
|
||||||
|
semantic_nodes: list[SemanticNode], episode: EpisodicNode
|
||||||
|
) -> list[EpisodicEdge]:
|
||||||
|
edges: list[EpisodicEdge] = []
|
||||||
|
|
||||||
|
for node in semantic_nodes:
|
||||||
|
edges.append(
|
||||||
|
EpisodicEdge(
|
||||||
|
source_node=episode,
|
||||||
|
target_node=node,
|
||||||
|
transaction_from=episode.transaction_from,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
return edges
|
||||||
Loading…
Add table
Reference in a new issue