diff --git a/graphiti_core/edges.py b/graphiti_core/edges.py index 8f657a1f..e9b065ac 100644 --- a/graphiti_core/edges.py +++ b/graphiti_core/edges.py @@ -103,6 +103,7 @@ class EpisodicEdge(Edge): """, uuid=uuid, database_=DEFAULT_DATABASE, + routing_='r', ) edges = [get_episodic_edge_from_record(record) for record in records] @@ -126,6 +127,7 @@ class EpisodicEdge(Edge): """, uuids=uuids, database_=DEFAULT_DATABASE, + routing_='r', ) edges = [get_episodic_edge_from_record(record) for record in records] @@ -149,6 +151,7 @@ class EpisodicEdge(Edge): """, group_ids=group_ids, database_=DEFAULT_DATABASE, + routing_='r', ) edges = [get_episodic_edge_from_record(record) for record in records] @@ -230,6 +233,7 @@ class EntityEdge(Edge): """, uuid=uuid, database_=DEFAULT_DATABASE, + routing_='r', ) edges = [get_entity_edge_from_record(record) for record in records] @@ -260,6 +264,7 @@ class EntityEdge(Edge): """, uuids=uuids, database_=DEFAULT_DATABASE, + routing_='r', ) edges = [get_entity_edge_from_record(record) for record in records] @@ -290,6 +295,7 @@ class EntityEdge(Edge): """, group_ids=group_ids, database_=DEFAULT_DATABASE, + routing_='r', ) edges = [get_entity_edge_from_record(record) for record in records] @@ -329,6 +335,7 @@ class CommunityEdge(Edge): """, uuid=uuid, database_=DEFAULT_DATABASE, + routing_='r', ) edges = [get_community_edge_from_record(record) for record in records] @@ -350,6 +357,7 @@ class CommunityEdge(Edge): """, uuids=uuids, database_=DEFAULT_DATABASE, + routing_='r', ) edges = [get_community_edge_from_record(record) for record in records] @@ -371,6 +379,7 @@ class CommunityEdge(Edge): """, group_ids=group_ids, database_=DEFAULT_DATABASE, + routing_='r', ) edges = [get_community_edge_from_record(record) for record in records] diff --git a/graphiti_core/nodes.py b/graphiti_core/nodes.py index 41f76920..14629003 100644 --- a/graphiti_core/nodes.py +++ b/graphiti_core/nodes.py @@ -86,7 +86,7 @@ class Node(BaseModel, ABC): async def delete(self, driver: AsyncDriver): result = await driver.execute_query( """ - MATCH (n {uuid: $uuid}) + MATCH (n:Entity|Episodic|Community {uuid: $uuid}) DETACH DELETE n """, uuid=self.uuid, @@ -105,6 +105,19 @@ class Node(BaseModel, ABC): return self.uuid == other.uuid return False + @classmethod + async def delete_by_group_id(cls, driver: AsyncDriver, group_id: str): + await driver.execute_query( + """ + MATCH (n:Entity|Episodic|Community {group_id: $group_id}) + DETACH DELETE n + """, + group_id=group_id, + database_=DEFAULT_DATABASE, + ) + + return 'SUCCESS' + @classmethod async def get_by_uuid(cls, driver: AsyncDriver, uuid: str): ... @@ -159,6 +172,7 @@ class EpisodicNode(Node): """, uuid=uuid, database_=DEFAULT_DATABASE, + routing_='r', ) episodes = [get_episodic_node_from_record(record) for record in records] @@ -185,6 +199,7 @@ class EpisodicNode(Node): """, uuids=uuids, database_=DEFAULT_DATABASE, + routing_='r', ) episodes = [get_episodic_node_from_record(record) for record in records] @@ -208,6 +223,7 @@ class EpisodicNode(Node): """, group_ids=group_ids, database_=DEFAULT_DATABASE, + routing_='r', ) episodes = [get_episodic_node_from_record(record) for record in records] @@ -259,6 +275,7 @@ class EntityNode(Node): """, uuid=uuid, database_=DEFAULT_DATABASE, + routing_='r', ) nodes = [get_entity_node_from_record(record) for record in records] @@ -283,6 +300,7 @@ class EntityNode(Node): """, uuids=uuids, database_=DEFAULT_DATABASE, + routing_='r', ) nodes = [get_entity_node_from_record(record) for record in records] @@ -304,6 +322,7 @@ class EntityNode(Node): """, group_ids=group_ids, database_=DEFAULT_DATABASE, + routing_='r', ) nodes = [get_entity_node_from_record(record) for record in records] @@ -355,6 +374,7 @@ class CommunityNode(Node): """, uuid=uuid, database_=DEFAULT_DATABASE, + routing_='r', ) nodes = [get_community_node_from_record(record) for record in records] @@ -379,6 +399,7 @@ class CommunityNode(Node): """, uuids=uuids, database_=DEFAULT_DATABASE, + routing_='r', ) communities = [get_community_node_from_record(record) for record in records] @@ -400,6 +421,7 @@ class CommunityNode(Node): """, group_ids=group_ids, database_=DEFAULT_DATABASE, + routing_='r', ) communities = [get_community_node_from_record(record) for record in records] diff --git a/pyproject.toml b/pyproject.toml index d65c42b9..bc000c7d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "graphiti-core" -version = "0.4.1" +version = "0.4.2" description = "A temporal graph building library" authors = [ "Paul Paliychuk ", diff --git a/tests/test_graphiti_int.py b/tests/test_graphiti_int.py index 5ce2ad74..e42a24d4 100644 --- a/tests/test_graphiti_int.py +++ b/tests/test_graphiti_int.py @@ -66,40 +66,9 @@ def setup_logging(): async def test_graphiti_init(): logger = setup_logging() graphiti = Graphiti(NEO4J_URI, NEO4j_USER, NEO4j_PASSWORD) - now = datetime.now(timezone.utc) - - alice_node = EntityNode( - name='Alice', - labels=[], - created_at=now, - summary='Alice summary', - group_id='test', - ) - - bob_node = EntityNode( - name='Bob', - labels=[], - created_at=now, - summary='Bob summary', - group_id='test', - ) - - entity_edge = EntityEdge( - source_node_uuid=alice_node.uuid, - target_node_uuid=bob_node.uuid, - created_at=now, - name='likes', - fact='Alice likes Bob', - episodes=[], - expired_at=now, - valid_at=now, - group_id='test', - ) - - await graphiti.add_triplet(alice_node, entity_edge, bob_node) results = await graphiti._search( - "Emily: I can't log in", + 'My name is Alice', COMBINED_HYBRID_SEARCH_CROSS_ENCODER, group_ids=['test'], )