From b728ff0f680eff6ef67c86329794687d99dbf304 Mon Sep 17 00:00:00 2001 From: Preston Rasmussen <109292228+prasmussen15@users.noreply.github.com> Date: Thu, 15 Aug 2024 11:04:57 -0400 Subject: [PATCH] renaming and add indices (#3) rename and add indices --- core/edges.py | 40 ++++++++--------- core/graphiti.py | 65 ++++++++++++++++++++++------ core/nodes.py | 24 +++++------ core/utils.py | 16 +++---- poetry.lock | 5 ++- pyproject.toml | 1 + tests/graphiti_int_tests.py | 85 +++++++++++++++++++++++++++++++++++++ 7 files changed, 181 insertions(+), 55 deletions(-) create mode 100644 tests/graphiti_int_tests.py diff --git a/core/edges.py b/core/edges.py index ca0749fd..bc0ff449 100644 --- a/core/edges.py +++ b/core/edges.py @@ -11,10 +11,10 @@ logger = logging.getLogger(__name__) class Edge(BaseModel, ABC): - uuid: Field(default_factory=lambda: uuid1().hex) + uuid: str = Field(default_factory=lambda: uuid1().hex) source_node: Node target_node: Node - transaction_from: datetime + created_at: datetime @abstractmethod async def save(self, driver: AsyncDriver): ... @@ -25,14 +25,14 @@ class EpisodicEdge(Edge): result = await driver.execute_query( """ MATCH (episode:Episodic {uuid: $episode_uuid}) - MATCH (node:Semantic {uuid: $semantic_uuid}) + MATCH (node:Entity {uuid: $entity_uuid}) MERGE (episode)-[r:MENTIONS {uuid: $uuid}]->(node) - SET r = {uuid: $uuid, transaction_from: $transaction_from} + SET r = {uuid: $uuid, created_at: $created_at} RETURN r.uuid AS uuid""", episode_uuid=self.source_node.uuid, - semantic_uuid=self.target_node.uuid, + entity_uuid=self.target_node.uuid, uuid=self.uuid, - transaction_from=self.transaction_from, + created_at=self.created_at, ) logger.info(f"Saved edge to neo4j: {self.uuid}") @@ -44,14 +44,14 @@ class EpisodicEdge(Edge): # Right now we have all edge nodes as type RELATES_TO -class SemanticEdge(Edge): +class EntityEdge(Edge): name: str fact: str - fact_embedding: list[int] = None - episodes: list[str] = None # list of episodes that reference these semantic edges - transaction_to: datetime = None # datetime of when the node was invalidated - valid_from: datetime = None # datetime of when the fact became true - valid_to: datetime = None # datetime of when the fact stopped being true + fact_embedding: list[float] = None + episodes: list[str] = None # list of episode ids that reference these entity edges + expired_at: datetime = None # datetime of when the node was invalidated + valid_at: datetime = None # datetime of when the fact became true + invalid_at: datetime = None # datetime of when the fact stopped being true def generate_embedding(self, embedder, model="text-embedding-3-large"): text = self.fact.replace("\n", " ") @@ -63,12 +63,12 @@ class SemanticEdge(Edge): async def save(self, driver: AsyncDriver): result = await driver.execute_query( """ - MATCH (source:Semantic {uuid: $source_uuid}) - MATCH (target:Semantic {uuid: $target_uuid}) + MATCH (source:Entity {uuid: $source_uuid}) + MATCH (target:Entity {uuid: $target_uuid}) MERGE (source)-[r:RELATES_TO {uuid: $uuid}]->(target) SET r = {uuid: $uuid, name: $name, fact: $fact, fact_embedding: $fact_embedding, - episodes: $episodes, transaction_from: $transaction_from, transaction_to: $transaction_to, - valid_from: $valid_from, valid_to: $valid_to} + episodes: $episodes, created_at: $created_at, expired_at: $expired_at, + valid_at: $valid_at, invalid_at: $invalid_at} RETURN r.uuid AS uuid""", source_uuid=self.source_node.uuid, target_uuid=self.target_node.uuid, @@ -77,10 +77,10 @@ class SemanticEdge(Edge): fact=self.fact, fact_embedding=self.fact_embedding, episodes=self.episodes, - transaction_from=self.transaction_from, - transaction_to=self.transaction_to, - valid_from=self.valid_from, - valid_to=self.valid_to, + created_at=self.created_at, + expired_at=self.expired_at, + valid_at=self.valid_at, + invalid_at=self.invalid_at, ) logger.info(f"Saved Node to neo4j: {self.uuid}") diff --git a/core/graphiti.py b/core/graphiti.py index 8ba9060e..00c5a273 100644 --- a/core/graphiti.py +++ b/core/graphiti.py @@ -1,11 +1,11 @@ import asyncio from datetime import datetime import logging -from typing import Callable, Tuple +from typing import Callable, Tuple, LiteralString from neo4j import AsyncGraphDatabase -from core.nodes import SemanticNode, EpisodicNode, Node -from core.edges import SemanticEdge, Edge +from core.nodes import EntityNode, EpisodicNode, Node +from core.edges import EntityEdge, Edge from core.utils import bfs, similarity_search, fulltext_search, build_episodic_edges logger = logging.getLogger(__name__) @@ -31,6 +31,9 @@ class Graphiti: ): self.driver = AsyncGraphDatabase.driver(uri, auth=(user, password)) self.database = "neo4j" + + self.build_indices() + if llm_config: self.llm_config = llm_config else: @@ -39,6 +42,40 @@ class Graphiti: def close(self): self.driver.close() + async def build_indices(self): + index_queries: list[LiteralString] = [ + "CREATE INDEX name_entity_index IF NOT EXISTS FOR (n:Entity) ON (n.name)", + "CREATE INDEX created_at_entity_index IF NOT EXISTS FOR (n:Entity) ON (n.created_at)", + "CREATE INDEX created_at_episodic_index IF NOT EXISTS FOR (n:Episodic) ON (n.created_at)", + "CREATE INDEX valid_at_episodic_index IF NOT EXISTS FOR (n:Episodic) ON (n.valid_at)", + "CREATE INDEX name_edge_index IF NOT EXISTS FOR ()-[r:RELATES_TO]-() ON (r.name)", + "CREATE INDEX created_at_edge_index IF NOT EXISTS FOR ()-[r:RELATES_TO]-() ON (r.created_at)", + "CREATE INDEX expired_at_edge_index IF NOT EXISTS FOR ()-[r:RELATES_TO]-() ON (r.expired_at)", + "CREATE INDEX valid_at_edge_index IF NOT EXISTS FOR ()-[r:RELATES_TO]-() ON (r.valid_at)", + "CREATE INDEX invalid_at_edge_index IF NOT EXISTS FOR ()-[r:RELATES_TO]-() ON (r.invalid_at)", + ] + # Add the range indices + for query in index_queries: + await self.driver.execute_query(query) + + # Add the entity indices + await self.driver.execute_query( + """ + CREATE FULLTEXT INDEX name_and_summary IF NOT EXISTS FOR (n:Entity) ON EACH [n.name, n.summary] + """ + ) + + await self.driver.execute_query( + """ + CREATE VECTOR INDEX fact_embedding IF NOT EXISTS + FOR ()-[r:RELATES_TO]-() ON (r.fact_embedding) + OPTIONS {indexConfig: { + `vector.dimensions`: 1024, + `vector.similarity_function`: 'cosine' + }} + """ + ) + async def retrieve_episodes( self, last_n: int, sources: list[str] | None = "messages" ) -> list[EpisodicNode]: @@ -48,8 +85,9 @@ class Graphiti: # Utility function, to be removed from this class async def clear_data(self): ... - async def search(self, query: str, config) -> ( - list)[Tuple[SemanticNode, list[SemanticEdge]]]: + async def search( + self, query: str, config + ) -> (list)[Tuple[EntityNode, list[EntityEdge]]]: (vec_nodes, vec_edges) = similarity_search(query, embedder) (text_nodes, text_edges) = fulltext_search(query) @@ -64,8 +102,9 @@ class Graphiti: return [(node, edges)], episodes - async def get_relevant_schema(self, episode: EpisodicNode, previous_episodes: list[EpisodicNode]) -> ( - list)[Tuple[SemanticNode, list[SemanticEdge]]]: + async def get_relevant_schema( + self, episode: EpisodicNode, previous_episodes: list[EpisodicNode] + ) -> list[Tuple[EntityNode, list[EntityEdge]]]: pass # Call llm with the specified messages, and return the response @@ -76,10 +115,10 @@ class Graphiti: async def extract_new_edges( self, episode: EpisodicNode, - new_nodes: list[SemanticNode], + new_nodes: list[EntityNode], relevant_schema: dict[str, any], previous_episodes: list[EpisodicNode], - ) -> list[SemanticEdge]: ... + ) -> list[EntityEdge]: ... # Extract new nodes from the episode async def extract_new_nodes( @@ -87,14 +126,14 @@ class Graphiti: episode: EpisodicNode, relevant_schema: dict[str, any], previous_episodes: list[EpisodicNode], - ) -> list[SemanticNode]: ... + ) -> list[EntityNode]: ... # Invalidate edges that are no longer valid async def invalidate_edges( self, episode: EpisodicNode, - new_nodes: list[SemanticNode], - new_edges: list[SemanticEdge], + new_nodes: list[EntityNode], + new_edges: list[EntityEdge], relevant_schema: dict[str, any], previous_episodes: list[EpisodicNode], ): ... @@ -137,7 +176,7 @@ class Graphiti: await asyncio.gather(*[node.save(self.driver) for node in nodes]) await asyncio.gather(*[edge.save(self.driver) for edge in edges]) for node in nodes: - if isinstance(node, SemanticNode): + if isinstance(node, EntityNode): await node.update_summary(self.driver) if success_callback: await success_callback(episode) diff --git a/core/nodes.py b/core/nodes.py index 864e94bb..61fcd7b9 100644 --- a/core/nodes.py +++ b/core/nodes.py @@ -11,10 +11,10 @@ logger = logging.getLogger(__name__) class Node(BaseModel, ABC): - uuid: Field(default_factory=lambda: uuid1().hex) + uuid: str = Field(default_factory=lambda: uuid1().hex) name: str labels: list[str] - transaction_from: datetime + created_at: datetime @abstractmethod async def save(self, driver: AsyncDriver): ... @@ -24,23 +24,23 @@ class EpisodicNode(Node): source: str # source type source_description: str # description of the data source content: str # raw episode data - semantic_edges: list[str] # list of semantic edges referenced in this episode - valid_from: datetime = None # datetime of when the original document was created + entity_edges: list[str] # list of entity edge ids referenced in this episode + valid_at: datetime = None # datetime of when the original document was created async def save(self, driver: AsyncDriver): result = await driver.execute_query( """ MERGE (n:Episodic {uuid: $uuid}) SET n = {uuid: $uuid, name: $name, source_description: $source_description, content: $content, - semantic_edges: $semantic_edges, transaction_from: $transaction_from, valid_from: $valid_from} + entity_edges: $entity_edges, created_at: $created_at, valid_at: $valid_at} RETURN n.uuid AS uuid""", uuid=self.uuid, name=self.name, source_description=self.source_description, content=self.content, - semantic_edges=self.semantic_edges, - transaction_from=self.transaction_from, - valid_from=self.valid_from, + entity_edges=self.entity_edges, + created_at=self.created_at, + valid_at=self.valid_at, _database="neo4j", ) @@ -50,7 +50,7 @@ class EpisodicNode(Node): return result -class SemanticNode(Node): +class EntityNode(Node): summary: str # regional summary of surrounding edges async def refresh_summary(self, driver: AsyncDriver, llm_client: OpenAI): ... @@ -58,13 +58,13 @@ class SemanticNode(Node): async def save(self, driver: AsyncDriver): result = await driver.execute_query( """ - MERGE (n:Semantic {uuid: $uuid}) - SET n = {uuid: $uuid, name: $name, summary: $summary, transaction_from: $transaction_from} + MERGE (n:Entity {uuid: $uuid}) + SET n = {uuid: $uuid, name: $name, summary: $summary, created_at: $created_at} RETURN n.uuid AS uuid""", uuid=self.uuid, name=self.name, summary=self.summary, - transaction_from=self.transaction_from, + created_at=self.created_at, ) logger.info(f"Saved Node to neo4j: {self.uuid}") diff --git a/core/utils.py b/core/utils.py index 74688ae6..72ed57d8 100644 --- a/core/utils.py +++ b/core/utils.py @@ -1,12 +1,12 @@ from typing import Tuple -from core.edges import EpisodicEdge, SemanticEdge, Edge -from core.nodes import SemanticNode, EpisodicNode, Node +from core.edges import EpisodicEdge, EntityEdge, Edge +from core.nodes import EntityNode, EpisodicNode, Node async def bfs( nodes: list[Node], edges: list[Edge], k: int -) -> Tuple[list[SemanticNode], list[SemanticEdge]]: ... +) -> Tuple[list[EntityNode], list[EntityEdge]]: ... # Breadth first search over nodes and edges with desired depth @@ -14,7 +14,7 @@ async def bfs( async def similarity_search( query: str, embedder -) -> Tuple[list[SemanticNode], list[SemanticEdge]]: ... +) -> Tuple[list[EntityNode], list[EntityEdge]]: ... # vector similarity search over embedded facts @@ -22,23 +22,23 @@ async def similarity_search( async def fulltext_search( query: str, -) -> Tuple[list[SemanticNode], list[SemanticEdge]]: ... +) -> Tuple[list[EntityNode], list[EntityEdge]]: ... # fulltext search over names and summary def build_episodic_edges( - semantic_nodes: list[SemanticNode], episode: EpisodicNode + entity_nodes: list[EntityNode], episode: EpisodicNode ) -> list[EpisodicEdge]: edges: list[EpisodicEdge] = [] - for node in semantic_nodes: + for node in entity_nodes: edges.append( EpisodicEdge( source_node=episode, target_node=node, - transaction_from=episode.transaction_from, + created_at=episode.created_at, ) ) diff --git a/poetry.lock b/poetry.lock index f1bf0121..cbdc32ca 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.7.1 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand. [[package]] name = "aiohappyeyeballs" @@ -2186,6 +2186,7 @@ description = "Nvidia JIT LTO Library" optional = false python-versions = ">=3" files = [ + {file = "nvidia_nvjitlink_cu12-12.6.20-py3-none-manylinux2014_aarch64.whl", hash = "sha256:84fb38465a5bc7c70cbc320cfd0963eb302ee25a5e939e9f512bbba55b6072fb"}, {file = "nvidia_nvjitlink_cu12-12.6.20-py3-none-manylinux2014_x86_64.whl", hash = "sha256:562ab97ea2c23164823b2a89cb328d01d45cb99634b8c65fe7cd60d14562bd79"}, {file = "nvidia_nvjitlink_cu12-12.6.20-py3-none-win_amd64.whl", hash = "sha256:ed3c43a17f37b0c922a919203d2d36cbef24d41cc3e6b625182f8b58203644f6"}, ] @@ -4937,4 +4938,4 @@ test = ["big-O", "importlib-resources", "jaraco.functools", "jaraco.itertools", [metadata] lock-version = "2.0" python-versions = "^3.10" -content-hash = "cefc4469afc33f38b93547ee72ed623000f15faae3889d432a12ddcb33643848" +content-hash = "142d26cbdbf9c07019dfdb8599b70e8efb9c3842a3c95588d1f59b9c187e44ba" diff --git a/pyproject.toml b/pyproject.toml index be49232b..94aacf87 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,6 +23,7 @@ python-dotenv = "^1.0.1" pandas = "^2.2.2" pytest-asyncio = "^0.23.8" pytest-xdist = "^3.6.1" +pytest = "^8.3.2" [build-system] diff --git a/tests/graphiti_int_tests.py b/tests/graphiti_int_tests.py new file mode 100644 index 00000000..1915c81f --- /dev/null +++ b/tests/graphiti_int_tests.py @@ -0,0 +1,85 @@ +import os + +import pytest +import asyncio +from dotenv import load_dotenv + +from neo4j import AsyncGraphDatabase +from openai import OpenAI + +from core.edges import EpisodicEdge, EntityEdge +from core.graphiti import Graphiti +from core.nodes import EpisodicNode, EntityNode +from datetime import datetime + +pytest_plugins = ("pytest_asyncio",) + +load_dotenv() + +NEO4J_URI = os.getenv("NEO4J_URI") +NEO4j_USER = os.getenv("NEO4J_USER") +NEO4j_PASSWORD = os.getenv("NEO4J_PASSWORD") + + +@pytest.mark.asyncio +async def test_graphiti_init(): + graphiti = Graphiti(NEO4J_URI, NEO4j_USER, NEO4j_PASSWORD, None) + await graphiti.build_indices() + graphiti.close() + + +@pytest.mark.asyncio +async def test_graph_integration(): + driver = AsyncGraphDatabase.driver( + NEO4J_URI, + auth=(NEO4j_USER, NEO4j_PASSWORD), + ) + embedder = OpenAI().embeddings + + now = datetime.now() + episode = EpisodicNode( + name="test_episode", + labels=[], + created_at=now, + source="message", + source_description="conversation message", + content="Alice likes Bob", + entity_edges=[], + ) + + alice_node = EntityNode( + name="Alice", + labels=[], + created_at=now, + summary="Alice summary", + ) + + bob_node = EntityNode(name="Bob", labels=[], created_at=now, summary="Bob summary") + + episodic_edge_1 = EpisodicEdge( + source_node=episode, target_node=alice_node, created_at=now + ) + + episodic_edge_2 = EpisodicEdge( + source_node=episode, target_node=bob_node, created_at=now + ) + + entity_edge = EntityEdge( + source_node=alice_node, + target_node=bob_node, + created_at=now, + name="likes", + fact="Alice likes Bob", + episodes=[], + expired_at=now, + valid_at=now, + invalid_at=now, + ) + + entity_edge.generate_embedding(embedder) + + nodes = [episode, alice_node, bob_node] + edges = [episodic_edge_1, episodic_edge_2, entity_edge] + + await asyncio.gather(*[node.save(driver) for node in nodes]) + await asyncio.gather(*[edge.save(driver) for edge in edges])