From 06d8d9359f5219acca39af08b196a4a51bc03f67 Mon Sep 17 00:00:00 2001 From: Preston Rasmussen <109292228+prasmussen15@users.noreply.github.com> Date: Tue, 27 Aug 2024 16:18:01 -0400 Subject: [PATCH] Add Missing Node and edge CRUD (#51) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * add CRUD operations and fix search limit bugs * format * update tests * å * update tests to double limit call * add default field * format * import correct field --- graphiti_core/edges.py | 109 +++++++++++++++++++++++- graphiti_core/helpers.py | 7 ++ graphiti_core/nodes.py | 97 ++++++++++++++++++++- graphiti_core/search/search.py | 26 +++--- graphiti_core/search/search_utils.py | 18 ++-- tests/test_graphiti_int.py | 25 ++++-- tests/utils/search/search_utils_test.py | 8 +- 7 files changed, 251 insertions(+), 39 deletions(-) create mode 100644 graphiti_core/helpers.py diff --git a/graphiti_core/edges.py b/graphiti_core/edges.py index b269fad9..645a3b32 100644 --- a/graphiti_core/edges.py +++ b/graphiti_core/edges.py @@ -23,6 +23,7 @@ from uuid import uuid4 from neo4j import AsyncDriver from pydantic import BaseModel, Field +from graphiti_core.helpers import parse_db_date from graphiti_core.llm_client.config import EMBEDDING_DIM from graphiti_core.nodes import Node @@ -38,6 +39,9 @@ class Edge(BaseModel, ABC): @abstractmethod async def save(self, driver: AsyncDriver): ... + @abstractmethod + async def delete(self, driver: AsyncDriver): ... + def __hash__(self): return hash(self.uuid) @@ -46,6 +50,9 @@ class Edge(BaseModel, ABC): return self.uuid == other.uuid return False + @classmethod + async def get_by_uuid(cls, driver: AsyncDriver, uuid: str): ... + class EpisodicEdge(Edge): async def save(self, driver: AsyncDriver): @@ -66,9 +73,48 @@ class EpisodicEdge(Edge): return result + async def delete(self, driver: AsyncDriver): + result = await driver.execute_query( + """ + MATCH (n:Episodic)-[e:MENTIONS {uuid: $uuid}]->(m:Entity) + DELETE e + """, + uuid=self.uuid, + ) -# TODO: Neo4j doesn't support variables for edge types and labels. -# Right now we have all edge nodes as type RELATES_TO + logger.info(f'Deleted Edge: {self.uuid}') + + return result + + @classmethod + async def get_by_uuid(cls, driver: AsyncDriver, uuid: str): + records, _, _ = await driver.execute_query( + """ + MATCH (n:Episodic)-[e:MENTIONS {uuid: $uuid}]->(m:Entity) + RETURN + e.uuid As uuid, + n.uuid AS source_node_uuid, + m.uuid AS target_node_uuid, + e.created_at AS created_at + """, + uuid=uuid, + ) + + edges: list[EpisodicEdge] = [] + + for record in records: + edges.append( + EpisodicEdge( + uuid=record['uuid'], + source_node_uuid=record['source_node_uuid'], + target_node_uuid=record['target_node_uuid'], + created_at=record['created_at'].to_native(), + ) + ) + + logger.info(f'Found Edge: {uuid}') + + return edges[0] class EntityEdge(Edge): @@ -97,7 +143,7 @@ class EntityEdge(Edge): self.fact_embedding = embedding[:EMBEDDING_DIM] end = time() - logger.info(f'embedded {text} in {end-start} ms') + logger.info(f'embedded {text} in {end - start} ms') return embedding @@ -127,3 +173,60 @@ class EntityEdge(Edge): logger.info(f'Saved edge to neo4j: {self.uuid}') return result + + async def delete(self, driver: AsyncDriver): + result = await driver.execute_query( + """ + MATCH (n:Entity)-[e:RELATES_TO {uuid: $uuid}]->(m:Entity) + DELETE e + """, + uuid=self.uuid, + ) + + logger.info(f'Deleted Edge: {self.uuid}') + + return result + + @classmethod + async def get_by_uuid(cls, driver: AsyncDriver, uuid: str): + records, _, _ = await driver.execute_query( + """ + MATCH (n:Entity)-[e:RELATES_TO {uuid: $uuid}]->(m:Entity) + RETURN + e.uuid AS uuid, + n.uuid AS source_node_uuid, + m.uuid AS target_node_uuid, + e.created_at AS created_at, + e.name AS name, + e.fact AS fact, + e.fact_embedding AS fact_embedding, + e.episodes AS episodes, + e.expired_at AS expired_at, + e.valid_at AS valid_at, + e.invalid_at AS invalid_at + """, + uuid=uuid, + ) + + edges: list[EntityEdge] = [] + + for record in records: + edges.append( + EntityEdge( + uuid=record['uuid'], + source_node_uuid=record['source_node_uuid'], + target_node_uuid=record['target_node_uuid'], + fact=record['fact'], + name=record['name'], + episodes=record['episodes'], + fact_embedding=record['fact_embedding'], + created_at=record['created_at'].to_native(), + expired_at=parse_db_date(record['expired_at']), + valid_at=parse_db_date(record['valid_at']), + invalid_at=parse_db_date(record['invalid_at']), + ) + ) + + logger.info(f'Found Edge: {uuid}') + + return edges[0] diff --git a/graphiti_core/helpers.py b/graphiti_core/helpers.py new file mode 100644 index 00000000..6233d274 --- /dev/null +++ b/graphiti_core/helpers.py @@ -0,0 +1,7 @@ +from datetime import datetime + +from neo4j import time as neo4j_time + + +def parse_db_date(neo_date: neo4j_time.DateTime | None) -> datetime | None: + return neo_date.to_native() if neo_date else None diff --git a/graphiti_core/nodes.py b/graphiti_core/nodes.py index fe4e0341..c2f7c833 100644 --- a/graphiti_core/nodes.py +++ b/graphiti_core/nodes.py @@ -75,6 +75,9 @@ class Node(BaseModel, ABC): @abstractmethod async def save(self, driver: AsyncDriver): ... + @abstractmethod + async def delete(self, driver: AsyncDriver): ... + def __hash__(self): return hash(self.uuid) @@ -83,6 +86,9 @@ class Node(BaseModel, ABC): return self.uuid == other.uuid return False + @classmethod + async def get_by_uuid(cls, driver: AsyncDriver, uuid: str): ... + class EpisodicNode(Node): source: EpisodeType = Field(description='source type') @@ -111,13 +117,58 @@ class EpisodicNode(Node): created_at=self.created_at, valid_at=self.valid_at, source=self.source.value, - _database='neo4j', ) logger.info(f'Saved Node to neo4j: {self.uuid}') return result + async def delete(self, driver: AsyncDriver): + result = await driver.execute_query( + """ + MATCH (n:Episodic {uuid: $uuid}) + DETACH DELETE n + """, + uuid=self.uuid, + ) + + logger.info(f'Deleted Node: {self.uuid}') + + return result + + @classmethod + async def get_by_uuid(cls, driver: AsyncDriver, uuid: str): + records, _, _ = await driver.execute_query( + """ + MATCH (e:Episodic {uuid: $uuid}) + RETURN e.content as content, + e.created_at as created_at, + e.valid_at as valid_at, + e.uuid as uuid, + e.name as name, + e.source_description as source_description, + e.source as source + """, + uuid=uuid, + ) + + episodes = [ + EpisodicNode( + content=record['content'], + created_at=record['created_at'].to_native().timestamp(), + valid_at=(record['valid_at'].to_native()), + uuid=record['uuid'], + source=EpisodeType.from_str(record['source']), + name=record['name'], + source_description=record['source_description'], + ) + for record in records + ] + + logger.info(f'Found Node: {uuid}') + + return episodes[0] + class EntityNode(Node): name_embedding: list[float] | None = Field(default=None, description='embedding of the name') @@ -153,3 +204,47 @@ class EntityNode(Node): logger.info(f'Saved Node to neo4j: {self.uuid}') return result + + async def delete(self, driver: AsyncDriver): + result = await driver.execute_query( + """ + MATCH (n:Entity {uuid: $uuid}) + DETACH DELETE n + """, + uuid=self.uuid, + ) + + logger.info(f'Deleted Node: {self.uuid}') + + return result + + @classmethod + async def get_by_uuid(cls, driver: AsyncDriver, uuid: str): + records, _, _ = await driver.execute_query( + """ + MATCH (n:Entity {uuid: $uuid}) + RETURN + n.uuid As uuid, + n.name AS name, + n.created_at AS created_at, + n.summary AS summary + """, + uuid=uuid, + ) + + nodes: list[EntityNode] = [] + + for record in records: + nodes.append( + EntityNode( + uuid=record['uuid'], + name=record['name'], + labels=['Entity'], + created_at=record['created_at'].to_native(), + summary=record['summary'], + ) + ) + + logger.info(f'Found Node: {uuid}') + + return nodes[0] diff --git a/graphiti_core/search/search.py b/graphiti_core/search/search.py index 03111225..03ed04a4 100644 --- a/graphiti_core/search/search.py +++ b/graphiti_core/search/search.py @@ -20,7 +20,7 @@ from enum import Enum from time import time from neo4j import AsyncDriver -from pydantic import BaseModel +from pydantic import BaseModel, Field from graphiti_core.edges import EntityEdge from graphiti_core.llm_client.config import EMBEDDING_DIM @@ -49,8 +49,8 @@ class Reranker(Enum): class SearchConfig(BaseModel): - num_edges: int = 10 - num_nodes: int = 10 + num_edges: int = Field(default=10) + num_nodes: int = Field(default=10) num_episodes: int = EPISODE_WINDOW_LEN search_methods: list[SearchMethod] reranker: Reranker | None @@ -63,12 +63,12 @@ class SearchResults(BaseModel): async def hybrid_search( - driver: AsyncDriver, - embedder, - query: str, - timestamp: datetime, - config: SearchConfig, - center_node_uuid: str | None = None, + driver: AsyncDriver, + embedder, + query: str, + timestamp: datetime, + config: SearchConfig, + center_node_uuid: str | None = None, ) -> SearchResults: start = time() @@ -79,11 +79,11 @@ async def hybrid_search( search_results = [] if config.num_episodes > 0: - episodes.extend(await retrieve_episodes(driver, timestamp)) + episodes.extend(await retrieve_episodes(driver, timestamp, config.num_episodes)) nodes.extend(await get_mentioned_nodes(driver, episodes)) if SearchMethod.bm25 in config.search_methods: - text_search = await edge_fulltext_search(query, driver) + text_search = await edge_fulltext_search(query, driver, 2 * config.num_edges) search_results.append(text_search) if SearchMethod.cosine_similarity in config.search_methods: @@ -94,7 +94,9 @@ async def hybrid_search( .embedding[:EMBEDDING_DIM] ) - similarity_search = await edge_similarity_search(search_vector, driver) + similarity_search = await edge_similarity_search( + search_vector, driver, 2 * config.num_edges + ) search_results.append(similarity_search) if len(search_results) > 1 and config.reranker is None: diff --git a/graphiti_core/search/search_utils.py b/graphiti_core/search/search_utils.py index e4789fef..6eeb5cea 100644 --- a/graphiti_core/search/search_utils.py +++ b/graphiti_core/search/search_utils.py @@ -3,13 +3,12 @@ import logging import re import typing from collections import defaultdict -from datetime import datetime from time import time from neo4j import AsyncDriver -from neo4j import time as neo4j_time from graphiti_core.edges import EntityEdge +from graphiti_core.helpers import parse_db_date from graphiti_core.nodes import EntityNode, EpisodicNode logger = logging.getLogger(__name__) @@ -17,10 +16,6 @@ logger = logging.getLogger(__name__) RELEVANT_SCHEMA_LIMIT = 3 -def parse_db_date(neo_date: neo4j_time.DateTime | None) -> datetime | None: - return neo_date.to_native() if neo_date else None - - async def get_mentioned_nodes(driver: AsyncDriver, episodes: list[EpisodicNode]): episode_uuids = [episode.uuid for episode in episodes] records, _, _ = await driver.execute_query( @@ -106,7 +101,7 @@ async def edge_similarity_search( # vector similarity search over embedded facts records, _, _ = await driver.execute_query( """ - CALL db.index.vector.queryRelationships("fact_embedding", 5, $search_vector) + CALL db.index.vector.queryRelationships("fact_embedding", $limit, $search_vector) YIELD relationship AS r, score MATCH (n)-[r:RELATES_TO]->(m) RETURN @@ -121,7 +116,7 @@ async def edge_similarity_search( r.expired_at AS expired_at, r.valid_at AS valid_at, r.invalid_at AS invalid_at - ORDER BY score DESC LIMIT $limit + ORDER BY score DESC """, search_vector=search_vector, limit=limit, @@ -316,8 +311,11 @@ async def hybrid_node_search( relevant_node_uuids = set() results = await asyncio.gather( - *[entity_fulltext_search(q, driver, limit or RELEVANT_SCHEMA_LIMIT) for q in queries], - *[entity_similarity_search(e, driver, limit or RELEVANT_SCHEMA_LIMIT) for e in embeddings], + *[entity_fulltext_search(q, driver, 2 * (limit or RELEVANT_SCHEMA_LIMIT)) for q in queries], + *[ + entity_similarity_search(e, driver, 2 * (limit or RELEVANT_SCHEMA_LIMIT)) + for e in embeddings + ], ) for result in results: diff --git a/tests/test_graphiti_int.py b/tests/test_graphiti_int.py index f1b48842..68c54a06 100644 --- a/tests/test_graphiti_int.py +++ b/tests/test_graphiti_int.py @@ -22,8 +22,6 @@ from datetime import datetime import pytest from dotenv import load_dotenv -from neo4j import AsyncGraphDatabase -from openai import OpenAI from graphiti_core.edges import EntityEdge, EpisodicEdge from graphiti_core.graphiti import Graphiti @@ -74,7 +72,7 @@ def format_context(facts): @pytest.mark.asyncio async def test_graphiti_init(): logger = setup_logging() - graphiti = Graphiti(NEO4J_URI, NEO4j_USER, NEO4j_PASSWORD, None) + graphiti = Graphiti(NEO4J_URI, NEO4j_USER, NEO4j_PASSWORD) edges = await graphiti.search('Freakenomics guest') @@ -92,11 +90,9 @@ async def test_graphiti_init(): @pytest.mark.asyncio async def test_graph_integration(): - driver = AsyncGraphDatabase.driver( - NEO4J_URI, - auth=(NEO4j_USER, NEO4j_PASSWORD), - ) - embedder = OpenAI().embeddings + client = Graphiti(NEO4J_URI, NEO4j_USER, NEO4j_PASSWORD) + embedder = client.llm_client.get_embedder() + driver = client.driver now = datetime.now() episode = EpisodicNode( @@ -139,10 +135,21 @@ async def test_graph_integration(): invalid_at=now, ) - entity_edge.generate_embedding(embedder) + await entity_edge.generate_embedding(embedder) nodes = [episode, alice_node, bob_node] edges = [episodic_edge_1, episodic_edge_2, entity_edge] + # test save await asyncio.gather(*[node.save(driver) for node in nodes]) await asyncio.gather(*[edge.save(driver) for edge in edges]) + + # test get + assert await EpisodicNode.get_by_uuid(driver, episode.uuid) is not None + assert await EntityNode.get_by_uuid(driver, alice_node.uuid) is not None + assert await EpisodicEdge.get_by_uuid(driver, episodic_edge_1.uuid) is not None + assert await EntityEdge.get_by_uuid(driver, entity_edge.uuid) is not None + + # test delete + await asyncio.gather(*[node.delete(driver) for node in nodes]) + await asyncio.gather(*[edge.delete(driver) for edge in edges]) diff --git a/tests/utils/search/search_utils_test.py b/tests/utils/search/search_utils_test.py index 01d78aaa..e4760976 100644 --- a/tests/utils/search/search_utils_test.py +++ b/tests/utils/search/search_utils_test.py @@ -113,8 +113,8 @@ async def test_hybrid_node_search_with_limit(): assert mock_fulltext_search.call_count == 1 assert mock_similarity_search.call_count == 1 # Verify that the limit was passed to the search functions - mock_fulltext_search.assert_called_with('Test', mock_driver, 1) - mock_similarity_search.assert_called_with([0.1, 0.2, 0.3], mock_driver, 1) + mock_fulltext_search.assert_called_with('Test', mock_driver, 2) + mock_similarity_search.assert_called_with([0.1, 0.2, 0.3], mock_driver, 2) @pytest.mark.asyncio @@ -148,5 +148,5 @@ async def test_hybrid_node_search_with_limit_and_duplicates(): assert set(node.name for node in results) == {'Alice', 'Bob', 'Charlie'} assert mock_fulltext_search.call_count == 1 assert mock_similarity_search.call_count == 1 - mock_fulltext_search.assert_called_with('Test', mock_driver, 2) - mock_similarity_search.assert_called_with([0.1, 0.2, 0.3], mock_driver, 2) + mock_fulltext_search.assert_called_with('Test', mock_driver, 4) + mock_similarity_search.assert_called_with([0.1, 0.2, 0.3], mock_driver, 4)