renaming and add indices (#3)

Rename Semantic nodes and edges to Entity, rename the temporal fields (transaction_from → created_at, transaction_to → expired_at, valid_from → valid_at, valid_to → invalid_at), and add Neo4j range, fulltext, and vector indices built at Graphiti initialization.
Preston Rasmussen 2024-08-15 11:04:57 -04:00 committed by GitHub
parent 83c7640d9c
commit b728ff0f68
7 changed files with 181 additions and 55 deletions


@@ -11,10 +11,10 @@ logger = logging.getLogger(__name__)
 class Edge(BaseModel, ABC):
-    uuid: Field(default_factory=lambda: uuid1().hex)
+    uuid: str = Field(default_factory=lambda: uuid1().hex)
     source_node: Node
     target_node: Node
-    transaction_from: datetime
+    created_at: datetime

     @abstractmethod
     async def save(self, driver: AsyncDriver): ...
@@ -25,14 +25,14 @@ class EpisodicEdge(Edge):
         result = await driver.execute_query(
             """
         MATCH (episode:Episodic {uuid: $episode_uuid})
-        MATCH (node:Semantic {uuid: $semantic_uuid})
+        MATCH (node:Entity {uuid: $entity_uuid})
         MERGE (episode)-[r:MENTIONS {uuid: $uuid}]->(node)
-        SET r = {uuid: $uuid, transaction_from: $transaction_from}
+        SET r = {uuid: $uuid, created_at: $created_at}
         RETURN r.uuid AS uuid""",
             episode_uuid=self.source_node.uuid,
-            semantic_uuid=self.target_node.uuid,
+            entity_uuid=self.target_node.uuid,
             uuid=self.uuid,
-            transaction_from=self.transaction_from,
+            created_at=self.created_at,
         )

         logger.info(f"Saved edge to neo4j: {self.uuid}")
@@ -44,14 +44,14 @@ class EpisodicEdge(Edge):
 # Right now we have all edge nodes as type RELATES_TO
-class SemanticEdge(Edge):
+class EntityEdge(Edge):
     name: str
     fact: str
-    fact_embedding: list[int] = None
+    fact_embedding: list[float] = None
-    episodes: list[str] = None  # list of episodes that reference these semantic edges
+    episodes: list[str] = None  # list of episode ids that reference these entity edges
-    transaction_to: datetime = None  # datetime of when the node was invalidated
+    expired_at: datetime = None  # datetime of when the node was invalidated
-    valid_from: datetime = None  # datetime of when the fact became true
+    valid_at: datetime = None  # datetime of when the fact became true
-    valid_to: datetime = None  # datetime of when the fact stopped being true
+    invalid_at: datetime = None  # datetime of when the fact stopped being true

     def generate_embedding(self, embedder, model="text-embedding-3-large"):
         text = self.fact.replace("\n", " ")
@@ -63,12 +63,12 @@ class SemanticEdge(Edge):
     async def save(self, driver: AsyncDriver):
         result = await driver.execute_query(
             """
-        MATCH (source:Semantic {uuid: $source_uuid})
+        MATCH (source:Entity {uuid: $source_uuid})
-        MATCH (target:Semantic {uuid: $target_uuid})
+        MATCH (target:Entity {uuid: $target_uuid})
         MERGE (source)-[r:RELATES_TO {uuid: $uuid}]->(target)
         SET r = {uuid: $uuid, name: $name, fact: $fact, fact_embedding: $fact_embedding,
-        episodes: $episodes, transaction_from: $transaction_from, transaction_to: $transaction_to,
+        episodes: $episodes, created_at: $created_at, expired_at: $expired_at,
-        valid_from: $valid_from, valid_to: $valid_to}
+        valid_at: $valid_at, invalid_at: $invalid_at}
         RETURN r.uuid AS uuid""",
             source_uuid=self.source_node.uuid,
             target_uuid=self.target_node.uuid,
@@ -77,10 +77,10 @@ class SemanticEdge(Edge):
             fact=self.fact,
             fact_embedding=self.fact_embedding,
             episodes=self.episodes,
-            transaction_from=self.transaction_from,
+            created_at=self.created_at,
-            transaction_to=self.transaction_to,
+            expired_at=self.expired_at,
-            valid_from=self.valid_from,
+            valid_at=self.valid_at,
-            valid_to=self.valid_to,
+            invalid_at=self.invalid_at,
         )

         logger.info(f"Saved Node to neo4j: {self.uuid}")


@@ -1,11 +1,11 @@
 import asyncio
 from datetime import datetime
 import logging
-from typing import Callable, Tuple
+from typing import Callable, Tuple, LiteralString

 from neo4j import AsyncGraphDatabase

-from core.nodes import SemanticNode, EpisodicNode, Node
+from core.nodes import EntityNode, EpisodicNode, Node
-from core.edges import SemanticEdge, Edge
+from core.edges import EntityEdge, Edge
 from core.utils import bfs, similarity_search, fulltext_search, build_episodic_edges

 logger = logging.getLogger(__name__)
@@ -31,6 +31,9 @@ class Graphiti:
     ):
         self.driver = AsyncGraphDatabase.driver(uri, auth=(user, password))
         self.database = "neo4j"
+        self.build_indices()
         if llm_config:
             self.llm_config = llm_config
         else:
@@ -39,6 +42,40 @@ class Graphiti:
     def close(self):
         self.driver.close()

+    async def build_indices(self):
+        index_queries: list[LiteralString] = [
+            "CREATE INDEX name_entity_index IF NOT EXISTS FOR (n:Entity) ON (n.name)",
+            "CREATE INDEX created_at_entity_index IF NOT EXISTS FOR (n:Entity) ON (n.created_at)",
+            "CREATE INDEX created_at_episodic_index IF NOT EXISTS FOR (n:Episodic) ON (n.created_at)",
+            "CREATE INDEX valid_at_episodic_index IF NOT EXISTS FOR (n:Episodic) ON (n.valid_at)",
+            "CREATE INDEX name_edge_index IF NOT EXISTS FOR ()-[r:RELATES_TO]-() ON (r.name)",
+            "CREATE INDEX created_at_edge_index IF NOT EXISTS FOR ()-[r:RELATES_TO]-() ON (r.created_at)",
+            "CREATE INDEX expired_at_edge_index IF NOT EXISTS FOR ()-[r:RELATES_TO]-() ON (r.expired_at)",
+            "CREATE INDEX valid_at_edge_index IF NOT EXISTS FOR ()-[r:RELATES_TO]-() ON (r.valid_at)",
+            "CREATE INDEX invalid_at_edge_index IF NOT EXISTS FOR ()-[r:RELATES_TO]-() ON (r.invalid_at)",
+        ]
+        # Add the range indices
+        for query in index_queries:
+            await self.driver.execute_query(query)
+
+        # Add the entity indices
+        await self.driver.execute_query(
+            """
+            CREATE FULLTEXT INDEX name_and_summary IF NOT EXISTS FOR (n:Entity) ON EACH [n.name, n.summary]
+            """
+        )
+        await self.driver.execute_query(
+            """
+            CREATE VECTOR INDEX fact_embedding IF NOT EXISTS
+            FOR ()-[r:RELATES_TO]-() ON (r.fact_embedding)
+            OPTIONS {indexConfig: {
+                `vector.dimensions`: 1024,
+                `vector.similarity_function`: 'cosine'
+            }}
+            """
+        )
+
     async def retrieve_episodes(
         self, last_n: int, sources: list[str] | None = "messages"
     ) -> list[EpisodicNode]:
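Note: similarity_search and fulltext_search remain stubs in this commit. Purely as an illustration of what the indices created above enable, here is a hedged sketch of how they might be queried using standard Neo4j 5.x procedures (db.index.fulltext.queryNodes, and db.index.vector.queryRelationships, which requires Neo4j 5.13+); none of this is part of the commit:

# Hypothetical sketch of querying the indices created by build_indices; not part of this commit.
async def fulltext_search_sketch(driver, query: str, limit: int = 10):
    # Hits the name_and_summary fulltext index on Entity nodes.
    records, _, _ = await driver.execute_query(
        """
        CALL db.index.fulltext.queryNodes("name_and_summary", $query, {limit: $limit})
        YIELD node, score
        RETURN node.uuid AS uuid, node.name AS name, score
        """,
        query=query,
        limit=limit,
    )
    return records

async def similarity_search_sketch(driver, query_embedding: list[float], k: int = 10):
    # Hits the fact_embedding vector index on RELATES_TO relationships (Neo4j 5.13+).
    records, _, _ = await driver.execute_query(
        """
        CALL db.index.vector.queryRelationships("fact_embedding", $k, $embedding)
        YIELD relationship, score
        RETURN relationship.uuid AS uuid, relationship.fact AS fact, score
        """,
        k=k,
        embedding=query_embedding,
    )
    return records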
@@ -48,8 +85,9 @@ class Graphiti:
     # Utility function, to be removed from this class
     async def clear_data(self): ...

-    async def search(self, query: str, config) -> (
-        list)[Tuple[SemanticNode, list[SemanticEdge]]]:
+    async def search(
+        self, query: str, config
+    ) -> (list)[Tuple[EntityNode, list[EntityEdge]]]:
         (vec_nodes, vec_edges) = similarity_search(query, embedder)
         (text_nodes, text_edges) = fulltext_search(query)
@@ -64,8 +102,9 @@ class Graphiti:
         return [(node, edges)], episodes

-    async def get_relevant_schema(self, episode: EpisodicNode, previous_episodes: list[EpisodicNode]) -> (
-        list)[Tuple[SemanticNode, list[SemanticEdge]]]:
+    async def get_relevant_schema(
+        self, episode: EpisodicNode, previous_episodes: list[EpisodicNode]
+    ) -> list[Tuple[EntityNode, list[EntityEdge]]]:
         pass

     # Call llm with the specified messages, and return the response
@@ -76,10 +115,10 @@ class Graphiti:
     async def extract_new_edges(
         self,
         episode: EpisodicNode,
-        new_nodes: list[SemanticNode],
+        new_nodes: list[EntityNode],
         relevant_schema: dict[str, any],
         previous_episodes: list[EpisodicNode],
-    ) -> list[SemanticEdge]: ...
+    ) -> list[EntityEdge]: ...

     # Extract new nodes from the episode
     async def extract_new_nodes(
@@ -87,14 +126,14 @@ class Graphiti:
         episode: EpisodicNode,
         relevant_schema: dict[str, any],
         previous_episodes: list[EpisodicNode],
-    ) -> list[SemanticNode]: ...
+    ) -> list[EntityNode]: ...

     # Invalidate edges that are no longer valid
     async def invalidate_edges(
         self,
         episode: EpisodicNode,
-        new_nodes: list[SemanticNode],
+        new_nodes: list[EntityNode],
-        new_edges: list[SemanticEdge],
+        new_edges: list[EntityEdge],
         relevant_schema: dict[str, any],
         previous_episodes: list[EpisodicNode],
     ): ...
@@ -137,7 +176,7 @@ class Graphiti:
            await asyncio.gather(*[node.save(self.driver) for node in nodes])
            await asyncio.gather(*[edge.save(self.driver) for edge in edges])

            for node in nodes:
-               if isinstance(node, SemanticNode):
+               if isinstance(node, EntityNode):
                    await node.update_summary(self.driver)

            if success_callback:
                await success_callback(episode)


@@ -11,10 +11,10 @@ logger = logging.getLogger(__name__)
 class Node(BaseModel, ABC):
-    uuid: Field(default_factory=lambda: uuid1().hex)
+    uuid: str = Field(default_factory=lambda: uuid1().hex)
     name: str
     labels: list[str]
-    transaction_from: datetime
+    created_at: datetime

     @abstractmethod
     async def save(self, driver: AsyncDriver): ...
@@ -24,23 +24,23 @@ class EpisodicNode(Node):
     source: str  # source type
     source_description: str  # description of the data source
     content: str  # raw episode data
-    semantic_edges: list[str]  # list of semantic edges referenced in this episode
+    entity_edges: list[str]  # list of entity edge ids referenced in this episode
-    valid_from: datetime = None  # datetime of when the original document was created
+    valid_at: datetime = None  # datetime of when the original document was created

     async def save(self, driver: AsyncDriver):
         result = await driver.execute_query(
             """
         MERGE (n:Episodic {uuid: $uuid})
         SET n = {uuid: $uuid, name: $name, source_description: $source_description, content: $content,
-        semantic_edges: $semantic_edges, transaction_from: $transaction_from, valid_from: $valid_from}
+        entity_edges: $entity_edges, created_at: $created_at, valid_at: $valid_at}
         RETURN n.uuid AS uuid""",
             uuid=self.uuid,
             name=self.name,
             source_description=self.source_description,
             content=self.content,
-            semantic_edges=self.semantic_edges,
+            entity_edges=self.entity_edges,
-            transaction_from=self.transaction_from,
+            created_at=self.created_at,
-            valid_from=self.valid_from,
+            valid_at=self.valid_at,
             _database="neo4j",
         )
@@ -50,7 +50,7 @@ class EpisodicNode(Node):
         return result

-class SemanticNode(Node):
+class EntityNode(Node):
     summary: str  # regional summary of surrounding edges

     async def refresh_summary(self, driver: AsyncDriver, llm_client: OpenAI): ...
@@ -58,13 +58,13 @@ class SemanticNode(Node):
     async def save(self, driver: AsyncDriver):
         result = await driver.execute_query(
             """
-        MERGE (n:Semantic {uuid: $uuid})
+        MERGE (n:Entity {uuid: $uuid})
-        SET n = {uuid: $uuid, name: $name, summary: $summary, transaction_from: $transaction_from}
+        SET n = {uuid: $uuid, name: $name, summary: $summary, created_at: $created_at}
         RETURN n.uuid AS uuid""",
             uuid=self.uuid,
             name=self.name,
             summary=self.summary,
-            transaction_from=self.transaction_from,
+            created_at=self.created_at,
         )

         logger.info(f"Saved Node to neo4j: {self.uuid}")


@@ -1,12 +1,12 @@
 from typing import Tuple

-from core.edges import EpisodicEdge, SemanticEdge, Edge
+from core.edges import EpisodicEdge, EntityEdge, Edge
-from core.nodes import SemanticNode, EpisodicNode, Node
+from core.nodes import EntityNode, EpisodicNode, Node

 async def bfs(
     nodes: list[Node], edges: list[Edge], k: int
-) -> Tuple[list[SemanticNode], list[SemanticEdge]]: ...
+) -> Tuple[list[EntityNode], list[EntityEdge]]: ...

 # Breadth first search over nodes and edges with desired depth
@@ -14,7 +14,7 @@ async def bfs(
 async def similarity_search(
     query: str, embedder
-) -> Tuple[list[SemanticNode], list[SemanticEdge]]: ...
+) -> Tuple[list[EntityNode], list[EntityEdge]]: ...

 # vector similarity search over embedded facts
@@ -22,23 +22,23 @@ async def similarity_search(
 async def fulltext_search(
     query: str,
-) -> Tuple[list[SemanticNode], list[SemanticEdge]]: ...
+) -> Tuple[list[EntityNode], list[EntityEdge]]: ...

 # fulltext search over names and summary

 def build_episodic_edges(
-    semantic_nodes: list[SemanticNode], episode: EpisodicNode
+    entity_nodes: list[EntityNode], episode: EpisodicNode
 ) -> list[EpisodicEdge]:
     edges: list[EpisodicEdge] = []

-    for node in semantic_nodes:
+    for node in entity_nodes:
         edges.append(
             EpisodicEdge(
                 source_node=episode,
                 target_node=node,
-                transaction_from=episode.transaction_from,
+                created_at=episode.created_at,
             )
         )

poetry.lock (generated)

@@ -1,4 +1,4 @@
-# This file is automatically @generated by Poetry 1.7.1 and should not be changed by hand.
+# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand.

 [[package]]
 name = "aiohappyeyeballs"
@@ -2186,6 +2186,7 @@ description = "Nvidia JIT LTO Library"
 optional = false
 python-versions = ">=3"
 files = [
+    {file = "nvidia_nvjitlink_cu12-12.6.20-py3-none-manylinux2014_aarch64.whl", hash = "sha256:84fb38465a5bc7c70cbc320cfd0963eb302ee25a5e939e9f512bbba55b6072fb"},
     {file = "nvidia_nvjitlink_cu12-12.6.20-py3-none-manylinux2014_x86_64.whl", hash = "sha256:562ab97ea2c23164823b2a89cb328d01d45cb99634b8c65fe7cd60d14562bd79"},
     {file = "nvidia_nvjitlink_cu12-12.6.20-py3-none-win_amd64.whl", hash = "sha256:ed3c43a17f37b0c922a919203d2d36cbef24d41cc3e6b625182f8b58203644f6"},
 ]
@@ -4937,4 +4938,4 @@ test = ["big-O", "importlib-resources", "jaraco.functools", "jaraco.itertools",
 [metadata]
 lock-version = "2.0"
 python-versions = "^3.10"
-content-hash = "cefc4469afc33f38b93547ee72ed623000f15faae3889d432a12ddcb33643848"
+content-hash = "142d26cbdbf9c07019dfdb8599b70e8efb9c3842a3c95588d1f59b9c187e44ba"


@@ -23,6 +23,7 @@ python-dotenv = "^1.0.1"
 pandas = "^2.2.2"
 pytest-asyncio = "^0.23.8"
 pytest-xdist = "^3.6.1"
+pytest = "^8.3.2"

 [build-system]


@@ -0,0 +1,85 @@
+import os
+import pytest
+import asyncio
+
+from dotenv import load_dotenv
+from neo4j import AsyncGraphDatabase
+from openai import OpenAI
+
+from core.edges import EpisodicEdge, EntityEdge
+from core.graphiti import Graphiti
+from core.nodes import EpisodicNode, EntityNode
+from datetime import datetime
+
+pytest_plugins = ("pytest_asyncio",)
+
+load_dotenv()
+
+NEO4J_URI = os.getenv("NEO4J_URI")
+NEO4j_USER = os.getenv("NEO4J_USER")
+NEO4j_PASSWORD = os.getenv("NEO4J_PASSWORD")
+
+
+@pytest.mark.asyncio
+async def test_graphiti_init():
+    graphiti = Graphiti(NEO4J_URI, NEO4j_USER, NEO4j_PASSWORD, None)
+    await graphiti.build_indices()
+    graphiti.close()
+
+
+@pytest.mark.asyncio
+async def test_graph_integration():
+    driver = AsyncGraphDatabase.driver(
+        NEO4J_URI,
+        auth=(NEO4j_USER, NEO4j_PASSWORD),
+    )
+    embedder = OpenAI().embeddings
+
+    now = datetime.now()
+
+    episode = EpisodicNode(
+        name="test_episode",
+        labels=[],
+        created_at=now,
+        source="message",
+        source_description="conversation message",
+        content="Alice likes Bob",
+        entity_edges=[],
+    )
+
+    alice_node = EntityNode(
+        name="Alice",
+        labels=[],
+        created_at=now,
+        summary="Alice summary",
+    )
+
+    bob_node = EntityNode(name="Bob", labels=[], created_at=now, summary="Bob summary")
+
+    episodic_edge_1 = EpisodicEdge(
+        source_node=episode, target_node=alice_node, created_at=now
+    )
+
+    episodic_edge_2 = EpisodicEdge(
+        source_node=episode, target_node=bob_node, created_at=now
+    )
+
+    entity_edge = EntityEdge(
+        source_node=alice_node,
+        target_node=bob_node,
+        created_at=now,
+        name="likes",
+        fact="Alice likes Bob",
+        episodes=[],
+        expired_at=now,
+        valid_at=now,
+        invalid_at=now,
+    )
+
+    entity_edge.generate_embedding(embedder)
+
+    nodes = [episode, alice_node, bob_node]
+    edges = [episodic_edge_1, episodic_edge_2, entity_edge]
+
+    await asyncio.gather(*[node.save(driver) for node in nodes])
+    await asyncio.gather(*[edge.save(driver) for edge in edges])