Add group ids (#89)

* set and retrieve group ids

* update add episode with group id support

* add episode and search functional

* update bulk

* mypy updates

* remove unused imports

* update unit tests

* unit tests

* add optional uuid field

* format

* mypy

* ellipsis
This commit is contained in:
Preston Rasmussen 2024-09-06 12:33:42 -04:00 committed by GitHub
parent c7fc057106
commit 42fb590606
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
15 changed files with 329 additions and 356 deletions

View file

@ -69,6 +69,7 @@ async def main(use_bulk: bool = True):
episode_body=f'{message.speaker_name} ({message.role}): {message.content}', episode_body=f'{message.speaker_name} ({message.role}): {message.content}',
reference_time=message.actual_timestamp, reference_time=message.actual_timestamp,
source_description='Podcast Transcript', source_description='Podcast Transcript',
group_id='1',
) )
return return

View file

@ -18,6 +18,7 @@ import logging
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from datetime import datetime from datetime import datetime
from time import time from time import time
from typing import Any
from uuid import uuid4 from uuid import uuid4
from neo4j import AsyncDriver from neo4j import AsyncDriver
@ -32,6 +33,7 @@ logger = logging.getLogger(__name__)
class Edge(BaseModel, ABC): class Edge(BaseModel, ABC):
uuid: str = Field(default_factory=lambda: uuid4().hex) uuid: str = Field(default_factory=lambda: uuid4().hex)
group_id: str | None = Field(description='partition of the graph')
source_node_uuid: str source_node_uuid: str
target_node_uuid: str target_node_uuid: str
created_at: datetime created_at: datetime
@ -61,11 +63,12 @@ class EpisodicEdge(Edge):
MATCH (episode:Episodic {uuid: $episode_uuid}) MATCH (episode:Episodic {uuid: $episode_uuid})
MATCH (node:Entity {uuid: $entity_uuid}) MATCH (node:Entity {uuid: $entity_uuid})
MERGE (episode)-[r:MENTIONS {uuid: $uuid}]->(node) MERGE (episode)-[r:MENTIONS {uuid: $uuid}]->(node)
SET r = {uuid: $uuid, created_at: $created_at} SET r = {uuid: $uuid, group_id: $group_id, created_at: $created_at}
RETURN r.uuid AS uuid""", RETURN r.uuid AS uuid""",
episode_uuid=self.source_node_uuid, episode_uuid=self.source_node_uuid,
entity_uuid=self.target_node_uuid, entity_uuid=self.target_node_uuid,
uuid=self.uuid, uuid=self.uuid,
group_id=self.group_id,
created_at=self.created_at, created_at=self.created_at,
) )
@ -92,7 +95,8 @@ class EpisodicEdge(Edge):
""" """
MATCH (n:Episodic)-[e:MENTIONS {uuid: $uuid}]->(m:Entity) MATCH (n:Episodic)-[e:MENTIONS {uuid: $uuid}]->(m:Entity)
RETURN RETURN
e.uuid As uuid, e.uuid As uuid,
e.group_id AS group_id,
n.uuid AS source_node_uuid, n.uuid AS source_node_uuid,
m.uuid AS target_node_uuid, m.uuid AS target_node_uuid,
e.created_at AS created_at e.created_at AS created_at
@ -100,17 +104,7 @@ class EpisodicEdge(Edge):
uuid=uuid, uuid=uuid,
) )
edges: list[EpisodicEdge] = [] edges = [get_episodic_edge_from_record(record) for record in records]
for record in records:
edges.append(
EpisodicEdge(
uuid=record['uuid'],
source_node_uuid=record['source_node_uuid'],
target_node_uuid=record['target_node_uuid'],
created_at=record['created_at'].to_native(),
)
)
logger.info(f'Found Edge: {uuid}') logger.info(f'Found Edge: {uuid}')
@ -153,7 +147,7 @@ class EntityEdge(Edge):
MATCH (source:Entity {uuid: $source_uuid}) MATCH (source:Entity {uuid: $source_uuid})
MATCH (target:Entity {uuid: $target_uuid}) MATCH (target:Entity {uuid: $target_uuid})
MERGE (source)-[r:RELATES_TO {uuid: $uuid}]->(target) MERGE (source)-[r:RELATES_TO {uuid: $uuid}]->(target)
SET r = {uuid: $uuid, name: $name, fact: $fact, fact_embedding: $fact_embedding, SET r = {uuid: $uuid, name: $name, group_id: $group_id, fact: $fact, fact_embedding: $fact_embedding,
episodes: $episodes, created_at: $created_at, expired_at: $expired_at, episodes: $episodes, created_at: $created_at, expired_at: $expired_at,
valid_at: $valid_at, invalid_at: $invalid_at} valid_at: $valid_at, invalid_at: $invalid_at}
RETURN r.uuid AS uuid""", RETURN r.uuid AS uuid""",
@ -161,6 +155,7 @@ class EntityEdge(Edge):
target_uuid=self.target_node_uuid, target_uuid=self.target_node_uuid,
uuid=self.uuid, uuid=self.uuid,
name=self.name, name=self.name,
group_id=self.group_id,
fact=self.fact, fact=self.fact,
fact_embedding=self.fact_embedding, fact_embedding=self.fact_embedding,
episodes=self.episodes, episodes=self.episodes,
@ -198,6 +193,7 @@ class EntityEdge(Edge):
m.uuid AS target_node_uuid, m.uuid AS target_node_uuid,
e.created_at AS created_at, e.created_at AS created_at,
e.name AS name, e.name AS name,
e.group_id AS group_id,
e.fact AS fact, e.fact AS fact,
e.fact_embedding AS fact_embedding, e.fact_embedding AS fact_embedding,
e.episodes AS episodes, e.episodes AS episodes,
@ -208,25 +204,36 @@ class EntityEdge(Edge):
uuid=uuid, uuid=uuid,
) )
edges: list[EntityEdge] = [] edges = [get_entity_edge_from_record(record) for record in records]
for record in records:
edges.append(
EntityEdge(
uuid=record['uuid'],
source_node_uuid=record['source_node_uuid'],
target_node_uuid=record['target_node_uuid'],
fact=record['fact'],
name=record['name'],
episodes=record['episodes'],
fact_embedding=record['fact_embedding'],
created_at=record['created_at'].to_native(),
expired_at=parse_db_date(record['expired_at']),
valid_at=parse_db_date(record['valid_at']),
invalid_at=parse_db_date(record['invalid_at']),
)
)
logger.info(f'Found Edge: {uuid}') logger.info(f'Found Edge: {uuid}')
return edges[0] return edges[0]
# Edge helpers
def get_episodic_edge_from_record(record: Any) -> EpisodicEdge:
    """Construct an EpisodicEdge from a neo4j query result record.

    The record must expose 'uuid', 'group_id', 'source_node_uuid',
    'target_node_uuid', and a driver DateTime under 'created_at'.
    """
    fields = {
        key: record[key]
        for key in ('uuid', 'group_id', 'source_node_uuid', 'target_node_uuid')
    }
    # Driver DateTime -> native Python datetime.
    fields['created_at'] = record['created_at'].to_native()
    return EpisodicEdge(**fields)
def get_entity_edge_from_record(record: Any) -> EntityEdge:
    """Construct an EntityEdge from a neo4j query result record.

    Plain fields are copied through; 'created_at' is converted from the
    driver DateTime, and the optional temporal fields are parsed with
    parse_db_date (which tolerates NULLs).
    """
    plain_keys = (
        'uuid',
        'source_node_uuid',
        'target_node_uuid',
        'fact',
        'name',
        'group_id',
        'episodes',
        'fact_embedding',
    )
    kwargs = {key: record[key] for key in plain_keys}
    kwargs['created_at'] = record['created_at'].to_native()
    # expired_at / valid_at / invalid_at may be NULL in the database.
    for date_key in ('expired_at', 'valid_at', 'invalid_at'):
        kwargs[date_key] = parse_db_date(record[date_key])
    return EntityEdge(**kwargs)

View file

@ -18,7 +18,6 @@ import asyncio
import logging import logging
from datetime import datetime from datetime import datetime
from time import time from time import time
from typing import Callable
from dotenv import load_dotenv from dotenv import load_dotenv
from neo4j import AsyncGraphDatabase from neo4j import AsyncGraphDatabase
@ -120,7 +119,7 @@ class Graphiti:
Parameters Parameters
---------- ----------
None self
Returns Returns
------- -------
@ -151,7 +150,7 @@ class Graphiti:
Parameters Parameters
---------- ----------
None self
Returns Returns
------- -------
@ -178,6 +177,7 @@ class Graphiti:
self, self,
reference_time: datetime, reference_time: datetime,
last_n: int = EPISODE_WINDOW_LEN, last_n: int = EPISODE_WINDOW_LEN,
group_ids: list[str | None] | None = None,
) -> list[EpisodicNode]: ) -> list[EpisodicNode]:
""" """
Retrieve the last n episodic nodes from the graph. Retrieve the last n episodic nodes from the graph.
@ -191,6 +191,8 @@ class Graphiti:
The reference time to retrieve episodes before. The reference time to retrieve episodes before.
last_n : int, optional last_n : int, optional
The number of episodes to retrieve. Defaults to EPISODE_WINDOW_LEN. The number of episodes to retrieve. Defaults to EPISODE_WINDOW_LEN.
group_ids : list[str | None], optional
The group ids to return data from.
Returns Returns
------- -------
@ -202,7 +204,7 @@ class Graphiti:
The actual retrieval is performed by the `retrieve_episodes` function The actual retrieval is performed by the `retrieve_episodes` function
from the `graphiti_core.utils` module. from the `graphiti_core.utils` module.
""" """
return await retrieve_episodes(self.driver, reference_time, last_n) return await retrieve_episodes(self.driver, reference_time, last_n, group_ids)
async def add_episode( async def add_episode(
self, self,
@ -211,8 +213,8 @@ class Graphiti:
source_description: str, source_description: str,
reference_time: datetime, reference_time: datetime,
source: EpisodeType = EpisodeType.message, source: EpisodeType = EpisodeType.message,
success_callback: Callable | None = None, group_id: str | None = None,
error_callback: Callable | None = None, uuid: str | None = None,
): ):
""" """
Process an episode and update the graph. Process an episode and update the graph.
@ -232,10 +234,10 @@ class Graphiti:
The reference time for the episode. The reference time for the episode.
source : EpisodeType, optional source : EpisodeType, optional
The type of the episode. Defaults to EpisodeType.message. The type of the episode. Defaults to EpisodeType.message.
success_callback : Callable | None, optional group_id : str | None
A callback function to be called upon successful processing. An id for the graph partition the episode is a part of.
error_callback : Callable | None, optional uuid : str | None
A callback function to be called if an error occurs during processing. Optional uuid of the episode.
Returns Returns
------- -------
@ -266,9 +268,12 @@ class Graphiti:
embedder = self.llm_client.get_embedder() embedder = self.llm_client.get_embedder()
now = datetime.now() now = datetime.now()
previous_episodes = await self.retrieve_episodes(reference_time, last_n=3) previous_episodes = await self.retrieve_episodes(
reference_time, last_n=3, group_ids=[group_id]
)
episode = EpisodicNode( episode = EpisodicNode(
name=name, name=name,
group_id=group_id,
labels=[], labels=[],
source=source, source=source,
content=episode_body, content=episode_body,
@ -276,6 +281,7 @@ class Graphiti:
created_at=now, created_at=now,
valid_at=reference_time, valid_at=reference_time,
) )
episode.uuid = uuid if uuid is not None else episode.uuid
# Extract entities as nodes # Extract entities as nodes
@ -299,7 +305,9 @@ class Graphiti:
(mentioned_nodes, uuid_map), extracted_edges = await asyncio.gather( (mentioned_nodes, uuid_map), extracted_edges = await asyncio.gather(
resolve_extracted_nodes(self.llm_client, extracted_nodes, existing_nodes_lists), resolve_extracted_nodes(self.llm_client, extracted_nodes, existing_nodes_lists),
extract_edges(self.llm_client, episode, extracted_nodes, previous_episodes), extract_edges(
self.llm_client, episode, extracted_nodes, previous_episodes, group_id
),
) )
logger.info(f'Adjusted mentioned nodes: {[(n.name, n.uuid) for n in mentioned_nodes]}') logger.info(f'Adjusted mentioned nodes: {[(n.name, n.uuid) for n in mentioned_nodes]}')
nodes.extend(mentioned_nodes) nodes.extend(mentioned_nodes)
@ -388,11 +396,7 @@ class Graphiti:
logger.info(f'Resolved edges: {[(e.name, e.uuid) for e in resolved_edges]}') logger.info(f'Resolved edges: {[(e.name, e.uuid) for e in resolved_edges]}')
episodic_edges: list[EpisodicEdge] = build_episodic_edges( episodic_edges: list[EpisodicEdge] = build_episodic_edges(mentioned_nodes, episode, now)
mentioned_nodes,
episode,
now,
)
logger.info(f'Built episodic edges: {episodic_edges}') logger.info(f'Built episodic edges: {episodic_edges}')
@ -405,18 +409,10 @@ class Graphiti:
end = time() end = time()
logger.info(f'Completed add_episode in {(end - start) * 1000} ms') logger.info(f'Completed add_episode in {(end - start) * 1000} ms')
if success_callback:
await success_callback(episode)
except Exception as e: except Exception as e:
if error_callback: raise e
await error_callback(episode, e)
else:
raise e
async def add_episode_bulk( async def add_episode_bulk(self, bulk_episodes: list[RawEpisode], group_id: str | None):
self,
bulk_episodes: list[RawEpisode],
):
""" """
Process multiple episodes in bulk and update the graph. Process multiple episodes in bulk and update the graph.
@ -427,6 +423,8 @@ class Graphiti:
---------- ----------
bulk_episodes : list[RawEpisode] bulk_episodes : list[RawEpisode]
A list of RawEpisode objects to be processed and added to the graph. A list of RawEpisode objects to be processed and added to the graph.
group_id : str | None
An id for the graph partition the episode is a part of.
Returns Returns
------- -------
@ -463,6 +461,7 @@ class Graphiti:
source=episode.source, source=episode.source,
content=episode.content, content=episode.content,
source_description=episode.source_description, source_description=episode.source_description,
group_id=group_id,
created_at=now, created_at=now,
valid_at=episode.reference_time, valid_at=episode.reference_time,
) )
@ -527,7 +526,13 @@ class Graphiti:
except Exception as e: except Exception as e:
raise e raise e
async def search(self, query: str, center_node_uuid: str | None = None, num_results=10): async def search(
self,
query: str,
center_node_uuid: str | None = None,
group_ids: list[str | None] | None = None,
num_results=10,
):
""" """
Perform a hybrid search on the knowledge graph. Perform a hybrid search on the knowledge graph.
@ -540,6 +545,8 @@ class Graphiti:
The search query string. The search query string.
center_node_uuid: str, optional center_node_uuid: str, optional
Facts will be reranked based on proximity to this node Facts will be reranked based on proximity to this node
group_ids : list[str | None] | None, optional
The graph partitions to return data from.
num_results : int, optional num_results : int, optional
The maximum number of results to return. Defaults to 10. The maximum number of results to return. Defaults to 10.
@ -562,6 +569,7 @@ class Graphiti:
num_episodes=0, num_episodes=0,
num_edges=num_results, num_edges=num_results,
num_nodes=0, num_nodes=0,
group_ids=group_ids,
search_methods=[SearchMethod.bm25, SearchMethod.cosine_similarity], search_methods=[SearchMethod.bm25, SearchMethod.cosine_similarity],
reranker=reranker, reranker=reranker,
) )
@ -590,7 +598,10 @@ class Graphiti:
) )
async def get_nodes_by_query( async def get_nodes_by_query(
self, query: str, limit: int = RELEVANT_SCHEMA_LIMIT self,
query: str,
group_ids: list[str | None] | None = None,
limit: int = RELEVANT_SCHEMA_LIMIT,
) -> list[EntityNode]: ) -> list[EntityNode]:
""" """
Retrieve nodes from the graph database based on a text query. Retrieve nodes from the graph database based on a text query.
@ -602,6 +613,8 @@ class Graphiti:
---------- ----------
query : str query : str
The text query to search for in the graph. The text query to search for in the graph.
group_ids : list[str | None] | None, optional
The graph partitions to return data from.
limit : int | None, optional limit : int | None, optional
The maximum number of results to return per search method. The maximum number of results to return per search method.
If None, a default limit will be applied. If None, a default limit will be applied.
@ -626,5 +639,7 @@ class Graphiti:
""" """
embedder = self.llm_client.get_embedder() embedder = self.llm_client.get_embedder()
query_embedding = await generate_embedding(embedder, query) query_embedding = await generate_embedding(embedder, query)
relevant_nodes = await hybrid_node_search([query], [query_embedding], self.driver, limit) relevant_nodes = await hybrid_node_search(
[query], [query_embedding], self.driver, group_ids, limit
)
return relevant_nodes return relevant_nodes

View file

@ -19,10 +19,10 @@ from abc import ABC, abstractmethod
from datetime import datetime from datetime import datetime
from enum import Enum from enum import Enum
from time import time from time import time
from typing import Any
from uuid import uuid4 from uuid import uuid4
from neo4j import AsyncDriver from neo4j import AsyncDriver
from openai import OpenAI
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from graphiti_core.llm_client.config import EMBEDDING_DIM from graphiti_core.llm_client.config import EMBEDDING_DIM
@ -69,6 +69,7 @@ class EpisodeType(Enum):
class Node(BaseModel, ABC): class Node(BaseModel, ABC):
uuid: str = Field(default_factory=lambda: uuid4().hex) uuid: str = Field(default_factory=lambda: uuid4().hex)
name: str = Field(description='name of the node') name: str = Field(description='name of the node')
group_id: str | None = Field(description='partition of the graph')
labels: list[str] = Field(default_factory=list) labels: list[str] = Field(default_factory=list)
created_at: datetime = Field(default_factory=lambda: datetime.now()) created_at: datetime = Field(default_factory=lambda: datetime.now())
@ -106,11 +107,12 @@ class EpisodicNode(Node):
result = await driver.execute_query( result = await driver.execute_query(
""" """
MERGE (n:Episodic {uuid: $uuid}) MERGE (n:Episodic {uuid: $uuid})
SET n = {uuid: $uuid, name: $name, source_description: $source_description, source: $source, content: $content, SET n = {uuid: $uuid, name: $name, group_id: $group_id, source_description: $source_description, source: $source, content: $content,
entity_edges: $entity_edges, created_at: $created_at, valid_at: $valid_at} entity_edges: $entity_edges, created_at: $created_at, valid_at: $valid_at}
RETURN n.uuid AS uuid""", RETURN n.uuid AS uuid""",
uuid=self.uuid, uuid=self.uuid,
name=self.name, name=self.name,
group_id=self.group_id,
source_description=self.source_description, source_description=self.source_description,
content=self.content, content=self.content,
entity_edges=self.entity_edges, entity_edges=self.entity_edges,
@ -141,29 +143,19 @@ class EpisodicNode(Node):
records, _, _ = await driver.execute_query( records, _, _ = await driver.execute_query(
""" """
MATCH (e:Episodic {uuid: $uuid}) MATCH (e:Episodic {uuid: $uuid})
RETURN e.content as content, RETURN e.content AS content,
e.created_at as created_at, e.created_at AS created_at,
e.valid_at as valid_at, e.valid_at AS valid_at,
e.uuid as uuid, e.uuid AS uuid,
e.name as name, e.name AS name,
e.source_description as source_description, e.group_id AS group_id,
e.source as source e.source_description AS source_description,
e.source AS source
""", """,
uuid=uuid, uuid=uuid,
) )
episodes = [ episodes = [get_episodic_node_from_record(record) for record in records]
EpisodicNode(
content=record['content'],
created_at=record['created_at'].to_native().timestamp(),
valid_at=(record['valid_at'].to_native()),
uuid=record['uuid'],
source=EpisodeType.from_str(record['source']),
name=record['name'],
source_description=record['source_description'],
)
for record in records
]
logger.info(f'Found Node: {uuid}') logger.info(f'Found Node: {uuid}')
@ -174,10 +166,6 @@ class EntityNode(Node):
name_embedding: list[float] | None = Field(default=None, description='embedding of the name') name_embedding: list[float] | None = Field(default=None, description='embedding of the name')
summary: str = Field(description='regional summary of surrounding edges', default_factory=str) summary: str = Field(description='regional summary of surrounding edges', default_factory=str)
async def update_summary(self, driver: AsyncDriver): ...
async def refresh_summary(self, driver: AsyncDriver, llm_client: OpenAI): ...
async def generate_name_embedding(self, embedder, model='text-embedding-3-small'): async def generate_name_embedding(self, embedder, model='text-embedding-3-small'):
start = time() start = time()
text = self.name.replace('\n', ' ') text = self.name.replace('\n', ' ')
@ -192,10 +180,11 @@ class EntityNode(Node):
result = await driver.execute_query( result = await driver.execute_query(
""" """
MERGE (n:Entity {uuid: $uuid}) MERGE (n:Entity {uuid: $uuid})
SET n = {uuid: $uuid, name: $name, name_embedding: $name_embedding, summary: $summary, created_at: $created_at} SET n = {uuid: $uuid, name: $name, name_embedding: $name_embedding, group_id: $group_id, summary: $summary, created_at: $created_at}
RETURN n.uuid AS uuid""", RETURN n.uuid AS uuid""",
uuid=self.uuid, uuid=self.uuid,
name=self.name, name=self.name,
group_id=self.group_id,
summary=self.summary, summary=self.summary,
name_embedding=self.name_embedding, name_embedding=self.name_embedding,
created_at=self.created_at, created_at=self.created_at,
@ -227,25 +216,14 @@ class EntityNode(Node):
n.uuid As uuid, n.uuid As uuid,
n.name AS name, n.name AS name,
n.name_embedding AS name_embedding, n.name_embedding AS name_embedding,
n.group_id AS group_id,
n.created_at AS created_at, n.created_at AS created_at,
n.summary AS summary n.summary AS summary
""", """,
uuid=uuid, uuid=uuid,
) )
nodes: list[EntityNode] = [] nodes = [get_entity_node_from_record(record) for record in records]
for record in records:
nodes.append(
EntityNode(
uuid=record['uuid'],
name=record['name'],
name_embedding=record['name_embedding'],
labels=['Entity'],
created_at=record['created_at'].to_native(),
summary=record['summary'],
)
)
logger.info(f'Found Node: {uuid}') logger.info(f'Found Node: {uuid}')
@ -253,3 +231,26 @@ class EntityNode(Node):
# Node helpers # Node helpers
def get_episodic_node_from_record(record: Any) -> EpisodicNode:
    """Construct an EpisodicNode from a neo4j query result record.

    The record must expose 'content', 'created_at', 'valid_at', 'uuid',
    'group_id', 'source', 'name', and 'source_description'; the two
    temporal fields are driver DateTimes.
    """
    return EpisodicNode(
        content=record['content'],
        # Pass the datetime through directly, like the sibling edge/node
        # helpers. The previous .to_native().timestamp() round-trip
        # interpreted naive datetimes in local time and re-parsed the
        # epoch as UTC, which could shift the stored instant.
        created_at=record['created_at'].to_native(),
        valid_at=record['valid_at'].to_native(),
        uuid=record['uuid'],
        group_id=record['group_id'],
        source=EpisodeType.from_str(record['source']),
        name=record['name'],
        source_description=record['source_description'],
    )
def get_entity_node_from_record(record: Any) -> EntityNode:
    """Construct an EntityNode from a neo4j query result record.

    Copies identity/embedding/summary fields from the record, converts
    'created_at' from the driver DateTime, and tags the node with the
    'Entity' label.
    """
    attrs = {
        field: record[field]
        for field in ('uuid', 'name', 'group_id', 'name_embedding', 'summary')
    }
    attrs['labels'] = ['Entity']
    attrs['created_at'] = record['created_at'].to_native()
    return EntityNode(**attrs)

View file

@ -52,6 +52,7 @@ class SearchConfig(BaseModel):
num_edges: int = Field(default=10) num_edges: int = Field(default=10)
num_nodes: int = Field(default=10) num_nodes: int = Field(default=10)
num_episodes: int = EPISODE_WINDOW_LEN num_episodes: int = EPISODE_WINDOW_LEN
group_ids: list[str | None] | None
search_methods: list[SearchMethod] search_methods: list[SearchMethod]
reranker: Reranker | None reranker: Reranker | None
@ -83,7 +84,9 @@ async def hybrid_search(
nodes.extend(await get_mentioned_nodes(driver, episodes)) nodes.extend(await get_mentioned_nodes(driver, episodes))
if SearchMethod.bm25 in config.search_methods: if SearchMethod.bm25 in config.search_methods:
text_search = await edge_fulltext_search(driver, query, None, None, 2 * config.num_edges) text_search = await edge_fulltext_search(
driver, query, None, None, config.group_ids, 2 * config.num_edges
)
search_results.append(text_search) search_results.append(text_search)
if SearchMethod.cosine_similarity in config.search_methods: if SearchMethod.cosine_similarity in config.search_methods:
@ -95,7 +98,7 @@ async def hybrid_search(
) )
similarity_search = await edge_similarity_search( similarity_search = await edge_similarity_search(
driver, search_vector, None, None, 2 * config.num_edges driver, search_vector, None, None, config.group_ids, 2 * config.num_edges
) )
search_results.append(similarity_search) search_results.append(similarity_search)

View file

@ -3,13 +3,11 @@ import logging
import re import re
from collections import defaultdict from collections import defaultdict
from time import time from time import time
from typing import Any
from neo4j import AsyncDriver, Query from neo4j import AsyncDriver, Query
from graphiti_core.edges import EntityEdge from graphiti_core.edges import EntityEdge, get_entity_edge_from_record
from graphiti_core.helpers import parse_db_date from graphiti_core.nodes import EntityNode, EpisodicNode, get_entity_node_from_record
from graphiti_core.nodes import EntityNode, EpisodicNode
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -23,6 +21,7 @@ async def get_mentioned_nodes(driver: AsyncDriver, episodes: list[EpisodicNode])
MATCH (episode:Episodic)-[:MENTIONS]->(n:Entity) WHERE episode.uuid IN $uuids MATCH (episode:Episodic)-[:MENTIONS]->(n:Entity) WHERE episode.uuid IN $uuids
RETURN DISTINCT RETURN DISTINCT
n.uuid As uuid, n.uuid As uuid,
n.group_id AS group_id,
n.name AS name, n.name AS name,
n.name_embedding AS name_embedding,
n.created_at AS created_at, n.created_at AS created_at,
@ -31,86 +30,29 @@ async def get_mentioned_nodes(driver: AsyncDriver, episodes: list[EpisodicNode])
uuids=episode_uuids, uuids=episode_uuids,
) )
nodes: list[EntityNode] = [] nodes = [get_entity_node_from_record(record) for record in records]
for record in records:
nodes.append(
EntityNode(
uuid=record['uuid'],
name=record['name'],
name_embedding=record['name_embedding'],
labels=['Entity'],
created_at=record['created_at'].to_native(),
summary=record['summary'],
)
)
return nodes return nodes
async def bfs(node_ids: list[str], driver: AsyncDriver):
records, _, _ = await driver.execute_query(
"""
MATCH (n WHERE n.uuid in $node_ids)-[r]->(m)
RETURN DISTINCT
n.uuid AS source_node_uuid,
n.name AS source_name,
n.summary AS source_summary,
m.uuid AS target_node_uuid,
m.name AS target_name,
m.summary AS target_summary,
r.uuid AS uuid,
r.created_at AS created_at,
r.name AS name,
r.fact AS fact,
r.fact_embedding AS fact_embedding,
r.episodes AS episodes,
r.expired_at AS expired_at,
r.valid_at AS valid_at,
r.invalid_at AS invalid_at
""",
node_ids=node_ids,
)
context: dict[str, Any] = {}
for record in records:
n_uuid = record['source_node_uuid']
if n_uuid in context:
context[n_uuid]['facts'].append(record['fact'])
else:
context[n_uuid] = {
'name': record['source_name'],
'summary': record['source_summary'],
'facts': [record['fact']],
}
m_uuid = record['target_node_uuid']
if m_uuid not in context:
context[m_uuid] = {
'name': record['target_name'],
'summary': record['target_summary'],
'facts': [],
}
logger.info(f'bfs search returned context: {context}')
return context
async def edge_similarity_search( async def edge_similarity_search(
driver: AsyncDriver, driver: AsyncDriver,
search_vector: list[float], search_vector: list[float],
source_node_uuid: str | None, source_node_uuid: str | None,
target_node_uuid: str | None, target_node_uuid: str | None,
group_ids: list[str | None] | None = None,
limit: int = RELEVANT_SCHEMA_LIMIT, limit: int = RELEVANT_SCHEMA_LIMIT,
) -> list[EntityEdge]: ) -> list[EntityEdge]:
group_ids = group_ids if group_ids is not None else [None]
# vector similarity search over embedded facts # vector similarity search over embedded facts
query = Query(""" query = Query("""
CALL db.index.vector.queryRelationships("fact_embedding", $limit, $search_vector) CALL db.index.vector.queryRelationships("fact_embedding", $limit, $search_vector)
YIELD relationship AS rel, score YIELD relationship AS rel, score
MATCH (n:Entity {uuid: $source_uuid})-[r {uuid: rel.uuid}]-(m:Entity {uuid: $target_uuid}) MATCH (n:Entity {uuid: $source_uuid})-[r {uuid: rel.uuid}]-(m:Entity {uuid: $target_uuid})
WHERE r.group_id IN $group_ids
RETURN RETURN
r.uuid AS uuid, r.uuid AS uuid,
r.group_id AS group_id,
n.uuid AS source_node_uuid, n.uuid AS source_node_uuid,
m.uuid AS target_node_uuid, m.uuid AS target_node_uuid,
r.created_at AS created_at, r.created_at AS created_at,
@ -129,8 +71,10 @@ async def edge_similarity_search(
CALL db.index.vector.queryRelationships("fact_embedding", $limit, $search_vector) CALL db.index.vector.queryRelationships("fact_embedding", $limit, $search_vector)
YIELD relationship AS rel, score YIELD relationship AS rel, score
MATCH (n:Entity)-[r {uuid: rel.uuid}]-(m:Entity) MATCH (n:Entity)-[r {uuid: rel.uuid}]-(m:Entity)
WHERE r.group_id IN $group_ids
RETURN RETURN
r.uuid AS uuid, r.uuid AS uuid,
r.group_id AS group_id,
n.uuid AS source_node_uuid, n.uuid AS source_node_uuid,
m.uuid AS target_node_uuid, m.uuid AS target_node_uuid,
r.created_at AS created_at, r.created_at AS created_at,
@ -148,8 +92,10 @@ async def edge_similarity_search(
CALL db.index.vector.queryRelationships("fact_embedding", $limit, $search_vector) CALL db.index.vector.queryRelationships("fact_embedding", $limit, $search_vector)
YIELD relationship AS rel, score YIELD relationship AS rel, score
MATCH (n:Entity)-[r {uuid: rel.uuid}]-(m:Entity {uuid: $target_uuid}) MATCH (n:Entity)-[r {uuid: rel.uuid}]-(m:Entity {uuid: $target_uuid})
WHERE r.group_id IN $group_ids
RETURN RETURN
r.uuid AS uuid, r.uuid AS uuid,
r.group_id AS group_id,
n.uuid AS source_node_uuid, n.uuid AS source_node_uuid,
m.uuid AS target_node_uuid, m.uuid AS target_node_uuid,
r.created_at AS created_at, r.created_at AS created_at,
@ -167,8 +113,10 @@ async def edge_similarity_search(
CALL db.index.vector.queryRelationships("fact_embedding", $limit, $search_vector) CALL db.index.vector.queryRelationships("fact_embedding", $limit, $search_vector)
YIELD relationship AS rel, score YIELD relationship AS rel, score
MATCH (n:Entity {uuid: $source_uuid})-[r {uuid: rel.uuid}]-(m:Entity) MATCH (n:Entity {uuid: $source_uuid})-[r {uuid: rel.uuid}]-(m:Entity)
WHERE r.group_id IN $group_ids
RETURN RETURN
r.uuid AS uuid, r.uuid AS uuid,
r.group_id AS group_id,
n.uuid AS source_node_uuid, n.uuid AS source_node_uuid,
m.uuid AS target_node_uuid, m.uuid AS target_node_uuid,
r.created_at AS created_at, r.created_at AS created_at,
@ -187,41 +135,32 @@ async def edge_similarity_search(
search_vector=search_vector, search_vector=search_vector,
source_uuid=source_node_uuid, source_uuid=source_node_uuid,
target_uuid=target_node_uuid, target_uuid=target_node_uuid,
group_ids=group_ids,
limit=limit, limit=limit,
) )
edges: list[EntityEdge] = [] edges = [get_entity_edge_from_record(record) for record in records]
for record in records:
edge = EntityEdge(
uuid=record['uuid'],
source_node_uuid=record['source_node_uuid'],
target_node_uuid=record['target_node_uuid'],
fact=record['fact'],
name=record['name'],
episodes=record['episodes'],
fact_embedding=record['fact_embedding'],
created_at=record['created_at'].to_native(),
expired_at=parse_db_date(record['expired_at']),
valid_at=parse_db_date(record['valid_at']),
invalid_at=parse_db_date(record['invalid_at']),
)
edges.append(edge)
return edges return edges
async def entity_similarity_search( async def entity_similarity_search(
search_vector: list[float], driver: AsyncDriver, limit=RELEVANT_SCHEMA_LIMIT search_vector: list[float],
driver: AsyncDriver,
group_ids: list[str | None] | None = None,
limit=RELEVANT_SCHEMA_LIMIT,
) -> list[EntityNode]: ) -> list[EntityNode]:
group_ids = group_ids if group_ids is not None else [None]
# vector similarity search over entity names # vector similarity search over entity names
records, _, _ = await driver.execute_query( records, _, _ = await driver.execute_query(
""" """
CALL db.index.vector.queryNodes("name_embedding", $limit, $search_vector) CALL db.index.vector.queryNodes("name_embedding", $limit, $search_vector)
YIELD node AS n, score YIELD node AS n, score
MATCH (n WHERE n.group_id IN $group_ids)
RETURN RETURN
n.uuid As uuid, n.uuid As uuid,
n.group_id AS group_id,
n.name AS name, n.name AS name,
n.name_embedding AS name_embedding, n.name_embedding AS name_embedding,
n.created_at AS created_at, n.created_at AS created_at,
@ -229,58 +168,44 @@ async def entity_similarity_search(
ORDER BY score DESC ORDER BY score DESC
""", """,
search_vector=search_vector, search_vector=search_vector,
group_ids=group_ids,
limit=limit, limit=limit,
) )
nodes: list[EntityNode] = [] nodes = [get_entity_node_from_record(record) for record in records]
for record in records:
nodes.append(
EntityNode(
uuid=record['uuid'],
name=record['name'],
name_embedding=record['name_embedding'],
labels=['Entity'],
created_at=record['created_at'].to_native(),
summary=record['summary'],
)
)
return nodes return nodes
async def entity_fulltext_search( async def entity_fulltext_search(
query: str, driver: AsyncDriver, limit=RELEVANT_SCHEMA_LIMIT query: str,
driver: AsyncDriver,
group_ids: list[str | None] | None = None,
limit=RELEVANT_SCHEMA_LIMIT,
) -> list[EntityNode]: ) -> list[EntityNode]:
group_ids = group_ids if group_ids is not None else [None]
# BM25 search to get top nodes # BM25 search to get top nodes
fuzzy_query = re.sub(r'[^\w\s]', '', query) + '~' fuzzy_query = re.sub(r'[^\w\s]', '', query) + '~'
records, _, _ = await driver.execute_query( records, _, _ = await driver.execute_query(
""" """
CALL db.index.fulltext.queryNodes("name_and_summary", $query) YIELD node, score CALL db.index.fulltext.queryNodes("name_and_summary", $query)
YIELD node AS n, score
MATCH (n WHERE n.group_id in $group_ids)
RETURN RETURN
node.uuid AS uuid, n.uuid AS uuid,
node.name AS name, n.group_id AS group_id,
node.name_embedding AS name_embedding, n.name AS name,
node.created_at AS created_at, n.name_embedding AS name_embedding,
node.summary AS summary n.created_at AS created_at,
n.summary AS summary
ORDER BY score DESC ORDER BY score DESC
LIMIT $limit LIMIT $limit
""", """,
query=fuzzy_query, query=fuzzy_query,
group_ids=group_ids,
limit=limit, limit=limit,
) )
nodes: list[EntityNode] = [] nodes = [get_entity_node_from_record(record) for record in records]
for record in records:
nodes.append(
EntityNode(
uuid=record['uuid'],
name=record['name'],
name_embedding=record['name_embedding'],
labels=['Entity'],
created_at=record['created_at'].to_native(),
summary=record['summary'],
)
)
return nodes return nodes
@ -290,15 +215,20 @@ async def edge_fulltext_search(
query: str, query: str,
source_node_uuid: str | None, source_node_uuid: str | None,
target_node_uuid: str | None, target_node_uuid: str | None,
group_ids: list[str | None] | None = None,
limit=RELEVANT_SCHEMA_LIMIT, limit=RELEVANT_SCHEMA_LIMIT,
) -> list[EntityEdge]: ) -> list[EntityEdge]:
group_ids = group_ids if group_ids is not None else [None]
# fulltext search over facts # fulltext search over facts
cypher_query = Query(""" cypher_query = Query("""
CALL db.index.fulltext.queryRelationships("name_and_fact", $query) CALL db.index.fulltext.queryRelationships("name_and_fact", $query)
YIELD relationship AS rel, score YIELD relationship AS rel, score
MATCH (n:Entity {uuid: $source_uuid})-[r {uuid: rel.uuid}]-(m:Entity {uuid: $target_uuid}) MATCH (n:Entity {uuid: $source_uuid})-[r {uuid: rel.uuid}]-(m:Entity {uuid: $target_uuid})
WHERE r.group_id IN $group_ids
RETURN RETURN
r.uuid AS uuid, r.uuid AS uuid,
r.group_id AS group_id,
n.uuid AS source_node_uuid, n.uuid AS source_node_uuid,
m.uuid AS target_node_uuid, m.uuid AS target_node_uuid,
r.created_at AS created_at, r.created_at AS created_at,
@ -317,8 +247,10 @@ async def edge_fulltext_search(
CALL db.index.fulltext.queryRelationships("name_and_fact", $query) CALL db.index.fulltext.queryRelationships("name_and_fact", $query)
YIELD relationship AS rel, score YIELD relationship AS rel, score
MATCH (n:Entity)-[r {uuid: rel.uuid}]-(m:Entity) MATCH (n:Entity)-[r {uuid: rel.uuid}]-(m:Entity)
WHERE r.group_id IN $group_ids
RETURN RETURN
r.uuid AS uuid, r.uuid AS uuid,
r.group_id AS group_id,
n.uuid AS source_node_uuid, n.uuid AS source_node_uuid,
m.uuid AS target_node_uuid, m.uuid AS target_node_uuid,
r.created_at AS created_at, r.created_at AS created_at,
@ -335,9 +267,11 @@ async def edge_fulltext_search(
cypher_query = Query(""" cypher_query = Query("""
CALL db.index.fulltext.queryRelationships("name_and_fact", $query) CALL db.index.fulltext.queryRelationships("name_and_fact", $query)
YIELD relationship AS rel, score YIELD relationship AS rel, score
MATCH (n:Entity)-[r {uuid: rel.uuid}]-(m:Entity {uuid: $target_uuid}) MATCH (n:Entity)-[r {uuid: rel.uuid}]-(m:Entity {uuid: $target_uuid})
WHERE r.group_id IN $group_ids
RETURN RETURN
r.uuid AS uuid, r.uuid AS uuid,
r.group_id AS group_id,
n.uuid AS source_node_uuid, n.uuid AS source_node_uuid,
m.uuid AS target_node_uuid, m.uuid AS target_node_uuid,
r.created_at AS created_at, r.created_at AS created_at,
@ -354,9 +288,11 @@ async def edge_fulltext_search(
cypher_query = Query(""" cypher_query = Query("""
CALL db.index.fulltext.queryRelationships("name_and_fact", $query) CALL db.index.fulltext.queryRelationships("name_and_fact", $query)
YIELD relationship AS rel, score YIELD relationship AS rel, score
MATCH (n:Entity {uuid: $source_uuid})-[r {uuid: rel.uuid}]-(m:Entity) MATCH (n:Entity {uuid: $source_uuid})-[r {uuid: rel.uuid}]-(m:Entity)
WHERE r.group_id IN $group_ids
RETURN RETURN
r.uuid AS uuid, r.uuid AS uuid,
r.group_id AS group_id,
n.uuid AS source_node_uuid, n.uuid AS source_node_uuid,
m.uuid AS target_node_uuid, m.uuid AS target_node_uuid,
r.created_at AS created_at, r.created_at AS created_at,
@ -377,27 +313,11 @@ async def edge_fulltext_search(
query=fuzzy_query, query=fuzzy_query,
source_uuid=source_node_uuid, source_uuid=source_node_uuid,
target_uuid=target_node_uuid, target_uuid=target_node_uuid,
group_ids=group_ids,
limit=limit, limit=limit,
) )
edges: list[EntityEdge] = [] edges = [get_entity_edge_from_record(record) for record in records]
for record in records:
edge = EntityEdge(
uuid=record['uuid'],
source_node_uuid=record['source_node_uuid'],
target_node_uuid=record['target_node_uuid'],
fact=record['fact'],
name=record['name'],
episodes=record['episodes'],
fact_embedding=record['fact_embedding'],
created_at=record['created_at'].to_native(),
expired_at=parse_db_date(record['expired_at']),
valid_at=parse_db_date(record['valid_at']),
invalid_at=parse_db_date(record['invalid_at']),
)
edges.append(edge)
return edges return edges
@ -406,6 +326,7 @@ async def hybrid_node_search(
queries: list[str], queries: list[str],
embeddings: list[list[float]], embeddings: list[list[float]],
driver: AsyncDriver, driver: AsyncDriver,
group_ids: list[str | None] | None = None,
limit: int = RELEVANT_SCHEMA_LIMIT, limit: int = RELEVANT_SCHEMA_LIMIT,
) -> list[EntityNode]: ) -> list[EntityNode]:
""" """
@ -422,6 +343,8 @@ async def hybrid_node_search(
A list of embedding vectors corresponding to the queries. If empty only fulltext search is performed. A list of embedding vectors corresponding to the queries. If empty only fulltext search is performed.
driver : AsyncDriver driver : AsyncDriver
The Neo4j driver instance for database operations. The Neo4j driver instance for database operations.
group_ids : list[str] | None, optional
The list of group ids to retrieve nodes from.
limit : int | None, optional limit : int | None, optional
The maximum number of results to return per search method. If None, a default limit will be applied. The maximum number of results to return per search method. If None, a default limit will be applied.
@ -448,8 +371,8 @@ async def hybrid_node_search(
results: list[list[EntityNode]] = list( results: list[list[EntityNode]] = list(
await asyncio.gather( await asyncio.gather(
*[entity_fulltext_search(q, driver, 2 * limit) for q in queries], *[entity_fulltext_search(q, driver, group_ids, 2 * limit) for q in queries],
*[entity_similarity_search(e, driver, 2 * limit) for e in embeddings], *[entity_similarity_search(e, driver, group_ids, 2 * limit) for e in embeddings],
) )
) )
@ -500,6 +423,7 @@ async def get_relevant_nodes(
[node.name for node in nodes], [node.name for node in nodes],
[node.name_embedding for node in nodes if node.name_embedding is not None], [node.name_embedding for node in nodes if node.name_embedding is not None],
driver, driver,
[node.group_id for node in nodes],
) )
return relevant_nodes return relevant_nodes
@ -518,13 +442,20 @@ async def get_relevant_edges(
results = await asyncio.gather( results = await asyncio.gather(
*[ *[
edge_similarity_search( edge_similarity_search(
driver, edge.fact_embedding, source_node_uuid, target_node_uuid, limit driver,
edge.fact_embedding,
source_node_uuid,
target_node_uuid,
[edge.group_id],
limit,
) )
for edge in edges for edge in edges
if edge.fact_embedding is not None if edge.fact_embedding is not None
], ],
*[ *[
edge_fulltext_search(driver, edge.fact, source_node_uuid, target_node_uuid, limit) edge_fulltext_search(
driver, edge.fact, source_node_uuid, target_node_uuid, [edge.group_id], limit
)
for edge in edges for edge in edges
], ],
) )

View file

@ -17,6 +17,7 @@ limitations under the License.
import asyncio import asyncio
import logging import logging
import typing import typing
from collections import defaultdict
from datetime import datetime from datetime import datetime
from math import ceil from math import ceil
@ -42,7 +43,6 @@ from graphiti_core.utils.maintenance.node_operations import (
extract_nodes, extract_nodes,
) )
from graphiti_core.utils.maintenance.temporal_operations import extract_edge_dates from graphiti_core.utils.maintenance.temporal_operations import extract_edge_dates
from graphiti_core.utils.utils import chunk_edges_by_nodes
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -62,7 +62,9 @@ async def retrieve_previous_episodes_bulk(
) -> list[tuple[EpisodicNode, list[EpisodicNode]]]: ) -> list[tuple[EpisodicNode, list[EpisodicNode]]]:
previous_episodes_list = await asyncio.gather( previous_episodes_list = await asyncio.gather(
*[ *[
retrieve_episodes(driver, episode.valid_at, last_n=EPISODE_WINDOW_LEN) retrieve_episodes(
driver, episode.valid_at, last_n=EPISODE_WINDOW_LEN, group_ids=[episode.group_id]
)
for episode in episodes for episode in episodes
] ]
) )
@ -90,7 +92,13 @@ async def extract_nodes_and_edges_bulk(
extracted_edges_bulk = await asyncio.gather( extracted_edges_bulk = await asyncio.gather(
*[ *[
extract_edges(llm_client, episode, extracted_nodes_bulk[i], previous_episodes_list[i]) extract_edges(
llm_client,
episode,
extracted_nodes_bulk[i],
previous_episodes_list[i],
episode.group_id,
)
for i, episode in enumerate(episodes) for i, episode in enumerate(episodes)
] ]
) )
@ -343,3 +351,23 @@ async def extract_edge_dates_bulk(
edge.expired_at = datetime.now() edge.expired_at = datetime.now()
return edges return edges
def chunk_edges_by_nodes(edges: list[EntityEdge]) -> list[list[EntityEdge]]:
# We only want to dedupe edges that are between the same pair of nodes
# We build a map of the edges based on their source and target nodes.
edge_chunk_map: dict[str, list[EntityEdge]] = defaultdict(list)
for edge in edges:
# We drop loop edges
if edge.source_node_uuid == edge.target_node_uuid:
continue
# Keep the order of the two nodes consistent, we want to be direction agnostic during edge resolution
pointers = [edge.source_node_uuid, edge.target_node_uuid]
pointers.sort()
edge_chunk_map[pointers[0] + pointers[1]].append(edge)
edge_chunks = [chunk for chunk in edge_chunk_map.values()]
return edge_chunks

View file

@ -37,15 +37,15 @@ def build_episodic_edges(
episode: EpisodicNode, episode: EpisodicNode,
created_at: datetime, created_at: datetime,
) -> List[EpisodicEdge]: ) -> List[EpisodicEdge]:
edges: List[EpisodicEdge] = [] edges: List[EpisodicEdge] = [
EpisodicEdge(
for node in entity_nodes:
edge = EpisodicEdge(
source_node_uuid=episode.uuid, source_node_uuid=episode.uuid,
target_node_uuid=node.uuid, target_node_uuid=node.uuid,
created_at=created_at, created_at=created_at,
group_id=episode.group_id,
) )
edges.append(edge) for node in entity_nodes
]
return edges return edges
@ -55,6 +55,7 @@ async def extract_edges(
episode: EpisodicNode, episode: EpisodicNode,
nodes: list[EntityNode], nodes: list[EntityNode],
previous_episodes: list[EpisodicNode], previous_episodes: list[EpisodicNode],
group_id: str | None,
) -> list[EntityEdge]: ) -> list[EntityEdge]:
start = time() start = time()
@ -88,6 +89,7 @@ async def extract_edges(
source_node_uuid=edge_data['source_node_uuid'], source_node_uuid=edge_data['source_node_uuid'],
target_node_uuid=edge_data['target_node_uuid'], target_node_uuid=edge_data['target_node_uuid'],
name=edge_data['relation_type'], name=edge_data['relation_type'],
group_id=group_id,
fact=edge_data['fact'], fact=edge_data['fact'],
episodes=[episode.uuid], episodes=[episode.uuid],
created_at=datetime.now(), created_at=datetime.now(),

View file

@ -34,6 +34,10 @@ async def build_indices_and_constraints(driver: AsyncDriver):
'CREATE INDEX episode_uuid IF NOT EXISTS FOR (n:Episodic) ON (n.uuid)', 'CREATE INDEX episode_uuid IF NOT EXISTS FOR (n:Episodic) ON (n.uuid)',
'CREATE INDEX relation_uuid IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.uuid)', 'CREATE INDEX relation_uuid IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.uuid)',
'CREATE INDEX mention_uuid IF NOT EXISTS FOR ()-[e:MENTIONS]-() ON (e.uuid)', 'CREATE INDEX mention_uuid IF NOT EXISTS FOR ()-[e:MENTIONS]-() ON (e.uuid)',
'CREATE INDEX entity_group_id IF NOT EXISTS FOR (n:Entity) ON (n.group_id)',
'CREATE INDEX episode_group_id IF NOT EXISTS FOR (n:Episodic) ON (n.group_id)',
'CREATE INDEX relation_group_id IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.group_id)',
'CREATE INDEX mention_group_id IF NOT EXISTS FOR ()-[e:MENTIONS]-() ON (e.group_id)',
'CREATE INDEX name_entity_index IF NOT EXISTS FOR (n:Entity) ON (n.name)', 'CREATE INDEX name_entity_index IF NOT EXISTS FOR (n:Entity) ON (n.name)',
'CREATE INDEX created_at_entity_index IF NOT EXISTS FOR (n:Entity) ON (n.created_at)', 'CREATE INDEX created_at_entity_index IF NOT EXISTS FOR (n:Entity) ON (n.created_at)',
'CREATE INDEX created_at_episodic_index IF NOT EXISTS FOR (n:Episodic) ON (n.created_at)', 'CREATE INDEX created_at_episodic_index IF NOT EXISTS FOR (n:Episodic) ON (n.created_at)',
@ -86,6 +90,7 @@ async def retrieve_episodes(
driver: AsyncDriver, driver: AsyncDriver,
reference_time: datetime, reference_time: datetime,
last_n: int = EPISODE_WINDOW_LEN, last_n: int = EPISODE_WINDOW_LEN,
group_ids: list[str | None] | None = None,
) -> list[EpisodicNode]: ) -> list[EpisodicNode]:
""" """
Retrieve the last n episodic nodes from the graph. Retrieve the last n episodic nodes from the graph.
@ -96,25 +101,28 @@ async def retrieve_episodes(
less than or equal to this reference_time will be retrieved. This allows for less than or equal to this reference_time will be retrieved. This allows for
querying the graph's state at a specific point in time. querying the graph's state at a specific point in time.
last_n (int, optional): The number of most recent episodes to retrieve, relative to the reference_time. last_n (int, optional): The number of most recent episodes to retrieve, relative to the reference_time.
group_ids (list[str], optional): The list of group ids to return data from.
Returns: Returns:
list[EpisodicNode]: A list of EpisodicNode objects representing the retrieved episodes. list[EpisodicNode]: A list of EpisodicNode objects representing the retrieved episodes.
""" """
result = await driver.execute_query( result = await driver.execute_query(
""" """
MATCH (e:Episodic) WHERE e.valid_at <= $reference_time MATCH (e:Episodic) WHERE e.valid_at <= $reference_time AND e.group_id in $group_ids
RETURN e.content as content, RETURN e.content AS content,
e.created_at as created_at, e.created_at AS created_at,
e.valid_at as valid_at, e.valid_at AS valid_at,
e.uuid as uuid, e.uuid AS uuid,
e.name as name, e.group_id AS group_id,
e.source_description as source_description, e.name AS name,
e.source as source e.source_description AS source_description,
e.source AS source
ORDER BY e.created_at DESC ORDER BY e.created_at DESC
LIMIT $num_episodes LIMIT $num_episodes
""", """,
reference_time=reference_time, reference_time=reference_time,
num_episodes=last_n, num_episodes=last_n,
group_ids=group_ids,
) )
episodes = [ episodes = [
EpisodicNode( EpisodicNode(
@ -124,6 +132,7 @@ async def retrieve_episodes(
), ),
valid_at=(record['valid_at'].to_native()), valid_at=(record['valid_at'].to_native()),
uuid=record['uuid'], uuid=record['uuid'],
group_id=record['group_id'],
source=EpisodeType.from_str(record['source']), source=EpisodeType.from_str(record['source']),
name=record['name'], name=record['name'],
source_description=record['source_description'], source_description=record['source_description'],

View file

@ -85,6 +85,7 @@ async def extract_nodes(
for node_data in extracted_node_data: for node_data in extracted_node_data:
new_node = EntityNode( new_node = EntityNode(
name=node_data['name'], name=node_data['name'],
group_id=episode.group_id,
labels=node_data['labels'], labels=node_data['labels'],
summary=node_data['summary'], summary=node_data['summary'],
created_at=datetime.now(), created_at=datetime.now(),

View file

@ -1,60 +0,0 @@
"""
Copyright 2024, Zep Software, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import logging
from collections import defaultdict
from graphiti_core.edges import EntityEdge, EpisodicEdge
from graphiti_core.nodes import EntityNode, EpisodicNode
logger = logging.getLogger(__name__)
def build_episodic_edges(
entity_nodes: list[EntityNode], episode: EpisodicNode
) -> list[EpisodicEdge]:
edges: list[EpisodicEdge] = []
for node in entity_nodes:
edges.append(
EpisodicEdge(
source_node_uuid=episode.uuid,
target_node_uuid=node.uuid,
created_at=episode.created_at,
)
)
return edges
def chunk_edges_by_nodes(edges: list[EntityEdge]) -> list[list[EntityEdge]]:
# We only want to dedupe edges that are between the same pair of nodes
# We build a map of the edges based on their source and target nodes.
edge_chunk_map: dict[str, list[EntityEdge]] = defaultdict(list)
for edge in edges:
# We drop loop edges
if edge.source_node_uuid == edge.target_node_uuid:
continue
# Keep the order of the two nodes consistent, we want to be direction agnostic during edge resolution
pointers = [edge.source_node_uuid, edge.target_node_uuid]
pointers.sort()
edge_chunk_map[pointers[0] + pointers[1]].append(edge)
edge_chunks = [chunk for chunk in edge_chunk_map.values()]
return edge_chunks

View file

@ -74,15 +74,15 @@ async def test_graphiti_init():
logger = setup_logging() logger = setup_logging()
graphiti = Graphiti(NEO4J_URI, NEO4j_USER, NEO4j_PASSWORD) graphiti = Graphiti(NEO4J_URI, NEO4j_USER, NEO4j_PASSWORD)
edges = await graphiti.search('Freakenomics guest') edges = await graphiti.search('Freakenomics guest', group_ids=['1'])
logger.info('\nQUERY: Freakenomics guest\n' + format_context([edge.fact for edge in edges])) logger.info('\nQUERY: Freakenomics guest\n' + format_context([edge.fact for edge in edges]))
edges = await graphiti.search('tania tetlow\n') edges = await graphiti.search('tania tetlow', group_ids=['1'])
logger.info('\nQUERY: Tania Tetlow\n' + format_context([edge.fact for edge in edges])) logger.info('\nQUERY: Tania Tetlow\n' + format_context([edge.fact for edge in edges]))
edges = await graphiti.search('issues with higher ed') edges = await graphiti.search('issues with higher ed', group_ids=['1'])
logger.info('\nQUERY: issues with higher ed\n' + format_context([edge.fact for edge in edges])) logger.info('\nQUERY: issues with higher ed\n' + format_context([edge.fact for edge in edges]))
graphiti.close() graphiti.close()

View file

@ -33,9 +33,9 @@ def create_test_data():
now = datetime.now() now = datetime.now()
# Create nodes # Create nodes
node1 = EntityNode(uuid='1', name='Node1', labels=['Person'], created_at=now) node1 = EntityNode(uuid='1', name='Node1', labels=['Person'], created_at=now, group_id='1')
node2 = EntityNode(uuid='2', name='Node2', labels=['Person'], created_at=now) node2 = EntityNode(uuid='2', name='Node2', labels=['Person'], created_at=now, group_id='1')
node3 = EntityNode(uuid='3', name='Node3', labels=['Person'], created_at=now) node3 = EntityNode(uuid='3', name='Node3', labels=['Person'], created_at=now, group_id='1')
# Create edges # Create edges
existing_edge1 = EntityEdge( existing_edge1 = EntityEdge(
@ -45,6 +45,7 @@ def create_test_data():
name='KNOWS', name='KNOWS',
fact='Node1 knows Node2', fact='Node1 knows Node2',
created_at=now, created_at=now,
group_id='1',
) )
existing_edge2 = EntityEdge( existing_edge2 = EntityEdge(
uuid='e2', uuid='e2',
@ -53,6 +54,7 @@ def create_test_data():
name='LIKES', name='LIKES',
fact='Node2 likes Node3', fact='Node2 likes Node3',
created_at=now, created_at=now,
group_id='1',
) )
new_edge1 = EntityEdge( new_edge1 = EntityEdge(
uuid='e3', uuid='e3',
@ -61,6 +63,7 @@ def create_test_data():
name='WORKS_WITH', name='WORKS_WITH',
fact='Node1 works with Node3', fact='Node1 works with Node3',
created_at=now, created_at=now,
group_id='1',
) )
new_edge2 = EntityEdge( new_edge2 = EntityEdge(
uuid='e4', uuid='e4',
@ -69,6 +72,7 @@ def create_test_data():
name='DISLIKES', name='DISLIKES',
fact='Node1 dislikes Node2', fact='Node1 dislikes Node2',
created_at=now, created_at=now,
group_id='1',
) )
return { return {
@ -135,9 +139,9 @@ def test_prepare_invalidation_context():
now = datetime.now() now = datetime.now()
# Create nodes # Create nodes
node1 = EntityNode(uuid='1', name='Node1', labels=['Person'], created_at=now) node1 = EntityNode(uuid='1', name='Node1', labels=['Person'], created_at=now, group_id='1')
node2 = EntityNode(uuid='2', name='Node2', labels=['Person'], created_at=now) node2 = EntityNode(uuid='2', name='Node2', labels=['Person'], created_at=now, group_id='1')
node3 = EntityNode(uuid='3', name='Node3', labels=['Person'], created_at=now) node3 = EntityNode(uuid='3', name='Node3', labels=['Person'], created_at=now, group_id='1')
# Create edges # Create edges
edge1 = EntityEdge( edge1 = EntityEdge(
@ -147,6 +151,7 @@ def test_prepare_invalidation_context():
name='KNOWS', name='KNOWS',
fact='Node1 knows Node2', fact='Node1 knows Node2',
created_at=now, created_at=now,
group_id='1',
) )
edge2 = EntityEdge( edge2 = EntityEdge(
uuid='e2', uuid='e2',
@ -155,6 +160,7 @@ def test_prepare_invalidation_context():
name='LIKES', name='LIKES',
fact='Node2 likes Node3', fact='Node2 likes Node3',
created_at=now, created_at=now,
group_id='1',
) )
# Create NodeEdgeNodeTriplet objects # Create NodeEdgeNodeTriplet objects
@ -173,6 +179,7 @@ def test_prepare_invalidation_context():
valid_at=now, valid_at=now,
source=EpisodeType.message, source=EpisodeType.message,
source_description='Test episode for unit testing', source_description='Test episode for unit testing',
group_id='1',
) )
previous_episodes = [ previous_episodes = [
EpisodicNode( EpisodicNode(
@ -182,6 +189,7 @@ def test_prepare_invalidation_context():
valid_at=now - timedelta(days=1), valid_at=now - timedelta(days=1),
source=EpisodeType.message, source=EpisodeType.message,
source_description='Test previous episode 1 for unit testing', source_description='Test previous episode 1 for unit testing',
group_id='1',
), ),
EpisodicNode( EpisodicNode(
name='Previous Episode 2', name='Previous Episode 2',
@ -190,6 +198,7 @@ def test_prepare_invalidation_context():
valid_at=now - timedelta(days=2), valid_at=now - timedelta(days=2),
source=EpisodeType.message, source=EpisodeType.message,
source_description='Test previous episode 2 for unit testing', source_description='Test previous episode 2 for unit testing',
group_id='1',
), ),
] ]
@ -235,6 +244,7 @@ def test_prepare_invalidation_context_empty_input():
valid_at=now, valid_at=now,
source=EpisodeType.message, source=EpisodeType.message,
source_description='Test empty episode for unit testing', source_description='Test empty episode for unit testing',
group_id='1',
) )
result = prepare_invalidation_context([], [], current_episode, []) result = prepare_invalidation_context([], [], current_episode, [])
assert isinstance(result, dict) assert isinstance(result, dict)
@ -252,8 +262,8 @@ def test_prepare_invalidation_context_sorting():
now = datetime.now() now = datetime.now()
# Create nodes # Create nodes
node1 = EntityNode(uuid='1', name='Node1', labels=['Person'], created_at=now) node1 = EntityNode(uuid='1', name='Node1', labels=['Person'], created_at=now, group_id='1')
node2 = EntityNode(uuid='2', name='Node2', labels=['Person'], created_at=now) node2 = EntityNode(uuid='2', name='Node2', labels=['Person'], created_at=now, group_id='1')
# Create edges with different timestamps # Create edges with different timestamps
edge1 = EntityEdge( edge1 = EntityEdge(
@ -263,6 +273,7 @@ def test_prepare_invalidation_context_sorting():
name='KNOWS', name='KNOWS',
fact='Node1 knows Node2', fact='Node1 knows Node2',
created_at=now, created_at=now,
group_id='1',
) )
edge2 = EntityEdge( edge2 = EntityEdge(
uuid='e2', uuid='e2',
@ -271,6 +282,7 @@ def test_prepare_invalidation_context_sorting():
name='LIKES', name='LIKES',
fact='Node2 likes Node1', fact='Node2 likes Node1',
created_at=now + timedelta(hours=1), created_at=now + timedelta(hours=1),
group_id='1',
) )
edge_with_nodes1 = (node1, edge1, node2) edge_with_nodes1 = (node1, edge1, node2)
@ -287,6 +299,7 @@ def test_prepare_invalidation_context_sorting():
valid_at=now, valid_at=now,
source=EpisodeType.message, source=EpisodeType.message,
source_description='Test episode for unit testing', source_description='Test episode for unit testing',
group_id='1',
) )
previous_episodes = [ previous_episodes = [
EpisodicNode( EpisodicNode(
@ -296,6 +309,7 @@ def test_prepare_invalidation_context_sorting():
valid_at=now - timedelta(days=1), valid_at=now - timedelta(days=1),
source=EpisodeType.message, source=EpisodeType.message,
source_description='Test previous episode for unit testing', source_description='Test previous episode for unit testing',
group_id='1',
), ),
] ]
@ -321,6 +335,7 @@ class TestExtractDateStringsFromEdge(unittest.TestCase):
created_at=datetime.now(), created_at=datetime.now(),
valid_at=valid_at, valid_at=valid_at,
invalid_at=invalid_at, invalid_at=invalid_at,
group_id='1',
) )
def test_both_dates_present(self): def test_both_dates_present(self):

View file

@ -76,6 +76,7 @@ def create_test_data():
valid_at=now, valid_at=now,
source=EpisodeType.message, source=EpisodeType.message,
source_description='Test episode for unit testing', source_description='Test episode for unit testing',
group_id='1',
) )
# Create previous episodes # Create previous episodes
@ -87,6 +88,7 @@ def create_test_data():
valid_at=now - timedelta(days=1), valid_at=now - timedelta(days=1),
source=EpisodeType.message, source=EpisodeType.message,
source_description='Test previous episode for unit testing', source_description='Test previous episode for unit testing',
group_id='1',
) )
] ]
@ -142,10 +144,12 @@ def create_complex_test_data():
now = datetime.now() now = datetime.now()
# Create nodes # Create nodes
node1 = EntityNode(uuid='1', name='Alice', labels=['Person'], created_at=now) node1 = EntityNode(uuid='1', name='Alice', labels=['Person'], created_at=now, group_id='1')
node2 = EntityNode(uuid='2', name='Bob', labels=['Person'], created_at=now) node2 = EntityNode(uuid='2', name='Bob', labels=['Person'], created_at=now, group_id='1')
node3 = EntityNode(uuid='3', name='Charlie', labels=['Person'], created_at=now) node3 = EntityNode(uuid='3', name='Charlie', labels=['Person'], created_at=now, group_id='1')
node4 = EntityNode(uuid='4', name='Company XYZ', labels=['Organization'], created_at=now) node4 = EntityNode(
uuid='4', name='Company XYZ', labels=['Organization'], created_at=now, group_id='1'
)
# Create edges # Create edges
edge1 = EntityEdge( edge1 = EntityEdge(
@ -154,6 +158,7 @@ def create_complex_test_data():
target_node_uuid='2', target_node_uuid='2',
name='LIKES', name='LIKES',
fact='Alice likes Bob', fact='Alice likes Bob',
group_id='1',
created_at=now - timedelta(days=5), created_at=now - timedelta(days=5),
) )
edge2 = EntityEdge( edge2 = EntityEdge(
@ -162,6 +167,7 @@ def create_complex_test_data():
target_node_uuid='3', target_node_uuid='3',
name='FRIENDS_WITH', name='FRIENDS_WITH',
fact='Alice is friends with Charlie', fact='Alice is friends with Charlie',
group_id='1',
created_at=now - timedelta(days=3), created_at=now - timedelta(days=3),
) )
edge3 = EntityEdge( edge3 = EntityEdge(
@ -170,6 +176,7 @@ def create_complex_test_data():
target_node_uuid='4', target_node_uuid='4',
name='WORKS_FOR', name='WORKS_FOR',
fact='Bob works for Company XYZ', fact='Bob works for Company XYZ',
group_id='1',
created_at=now - timedelta(days=2), created_at=now - timedelta(days=2),
) )
@ -199,6 +206,7 @@ async def test_invalidate_edges_complex():
target_node_uuid='2', target_node_uuid='2',
name='DISLIKES', name='DISLIKES',
fact='Alice dislikes Bob', fact='Alice dislikes Bob',
group_id='1',
created_at=datetime.now(), created_at=datetime.now(),
), ),
nodes[1], nodes[1],
@ -225,6 +233,7 @@ async def test_invalidate_edges_temporal_update():
target_node_uuid='4', target_node_uuid='4',
name='LEFT_JOB', name='LEFT_JOB',
fact='Bob left his job at Company XYZ', fact='Bob left his job at Company XYZ',
group_id='1',
created_at=datetime.now(), created_at=datetime.now(),
), ),
nodes[3], nodes[3],
@ -251,6 +260,7 @@ async def test_invalidate_edges_multiple_invalidations():
target_node_uuid='2', target_node_uuid='2',
name='ENEMIES_WITH', name='ENEMIES_WITH',
fact='Alice and Bob are now enemies', fact='Alice and Bob are now enemies',
group_id='1',
created_at=datetime.now(), created_at=datetime.now(),
), ),
nodes[1], nodes[1],
@ -263,6 +273,7 @@ async def test_invalidate_edges_multiple_invalidations():
target_node_uuid='3', target_node_uuid='3',
name='ENDED_FRIENDSHIP', name='ENDED_FRIENDSHIP',
fact='Alice ended her friendship with Charlie', fact='Alice ended her friendship with Charlie',
group_id='1',
created_at=datetime.now(), created_at=datetime.now(),
), ),
nodes[2], nodes[2],
@ -292,6 +303,7 @@ async def test_invalidate_edges_no_effect():
target_node_uuid='4', target_node_uuid='4',
name='APPLIED_TO', name='APPLIED_TO',
fact='Charlie applied to Company XYZ', fact='Charlie applied to Company XYZ',
group_id='1',
created_at=datetime.now(), created_at=datetime.now(),
), ),
nodes[3], nodes[3],
@ -316,6 +328,7 @@ async def test_invalidate_edges_partial_update():
target_node_uuid='4', target_node_uuid='4',
name='CHANGED_POSITION', name='CHANGED_POSITION',
fact='Bob changed his position at Company XYZ', fact='Bob changed his position at Company XYZ',
group_id='1',
created_at=datetime.now(), created_at=datetime.now(),
), ),
nodes[3], nodes[3],

View file

@ -19,12 +19,12 @@ async def test_hybrid_node_search_deduplication():
) as mock_similarity_search: ) as mock_similarity_search:
# Set up mock return values # Set up mock return values
mock_fulltext_search.side_effect = [ mock_fulltext_search.side_effect = [
[EntityNode(uuid='1', name='Alice', labels=['Entity'])], [EntityNode(uuid='1', name='Alice', labels=['Entity'], group_id='1')],
[EntityNode(uuid='2', name='Bob', labels=['Entity'])], [EntityNode(uuid='2', name='Bob', labels=['Entity'], group_id='1')],
] ]
mock_similarity_search.side_effect = [ mock_similarity_search.side_effect = [
[EntityNode(uuid='1', name='Alice', labels=['Entity'])], [EntityNode(uuid='1', name='Alice', labels=['Entity'], group_id='1')],
[EntityNode(uuid='3', name='Charlie', labels=['Entity'])], [EntityNode(uuid='3', name='Charlie', labels=['Entity'], group_id='1')],
] ]
# Call the function with test data # Call the function with test data
@ -70,7 +70,9 @@ async def test_hybrid_node_search_only_fulltext():
) as mock_fulltext_search, patch( ) as mock_fulltext_search, patch(
'graphiti_core.search.search_utils.entity_similarity_search' 'graphiti_core.search.search_utils.entity_similarity_search'
) as mock_similarity_search: ) as mock_similarity_search:
mock_fulltext_search.return_value = [EntityNode(uuid='1', name='Alice', labels=['Entity'])] mock_fulltext_search.return_value = [
EntityNode(uuid='1', name='Alice', labels=['Entity'], group_id='1')
]
mock_similarity_search.return_value = [] mock_similarity_search.return_value = []
queries = ['Alice'] queries = ['Alice']
@ -93,18 +95,23 @@ async def test_hybrid_node_search_with_limit():
'graphiti_core.search.search_utils.entity_similarity_search' 'graphiti_core.search.search_utils.entity_similarity_search'
) as mock_similarity_search: ) as mock_similarity_search:
mock_fulltext_search.return_value = [ mock_fulltext_search.return_value = [
EntityNode(uuid='1', name='Alice', labels=['Entity']), EntityNode(uuid='1', name='Alice', labels=['Entity'], group_id='1'),
EntityNode(uuid='2', name='Bob', labels=['Entity']), EntityNode(uuid='2', name='Bob', labels=['Entity'], group_id='1'),
] ]
mock_similarity_search.return_value = [ mock_similarity_search.return_value = [
EntityNode(uuid='3', name='Charlie', labels=['Entity']), EntityNode(uuid='3', name='Charlie', labels=['Entity'], group_id='1'),
EntityNode(uuid='4', name='David', labels=['Entity']), EntityNode(
uuid='4',
name='David',
labels=['Entity'],
group_id='1',
),
] ]
queries = ['Test'] queries = ['Test']
embeddings = [[0.1, 0.2, 0.3]] embeddings = [[0.1, 0.2, 0.3]]
limit = 1 limit = 1
results = await hybrid_node_search(queries, embeddings, mock_driver, limit) results = await hybrid_node_search(queries, embeddings, mock_driver, ['1'], limit)
# We expect 4 results because the limit is applied per search method # We expect 4 results because the limit is applied per search method
# before deduplication, and we're not actually limiting the results # before deduplication, and we're not actually limiting the results
@ -113,8 +120,8 @@ async def test_hybrid_node_search_with_limit():
assert mock_fulltext_search.call_count == 1 assert mock_fulltext_search.call_count == 1
assert mock_similarity_search.call_count == 1 assert mock_similarity_search.call_count == 1
# Verify that the limit was passed to the search functions # Verify that the limit was passed to the search functions
mock_fulltext_search.assert_called_with('Test', mock_driver, 2) mock_fulltext_search.assert_called_with('Test', mock_driver, ['1'], 2)
mock_similarity_search.assert_called_with([0.1, 0.2, 0.3], mock_driver, 2) mock_similarity_search.assert_called_with([0.1, 0.2, 0.3], mock_driver, ['1'], 2)
@pytest.mark.asyncio @pytest.mark.asyncio
@ -127,18 +134,18 @@ async def test_hybrid_node_search_with_limit_and_duplicates():
'graphiti_core.search.search_utils.entity_similarity_search' 'graphiti_core.search.search_utils.entity_similarity_search'
) as mock_similarity_search: ) as mock_similarity_search:
mock_fulltext_search.return_value = [ mock_fulltext_search.return_value = [
EntityNode(uuid='1', name='Alice', labels=['Entity']), EntityNode(uuid='1', name='Alice', labels=['Entity'], group_id='1'),
EntityNode(uuid='2', name='Bob', labels=['Entity']), EntityNode(uuid='2', name='Bob', labels=['Entity'], group_id='1'),
] ]
mock_similarity_search.return_value = [ mock_similarity_search.return_value = [
EntityNode(uuid='1', name='Alice', labels=['Entity']), # Duplicate EntityNode(uuid='1', name='Alice', labels=['Entity'], group_id='1'), # Duplicate
EntityNode(uuid='3', name='Charlie', labels=['Entity']), EntityNode(uuid='3', name='Charlie', labels=['Entity'], group_id='1'),
] ]
queries = ['Test'] queries = ['Test']
embeddings = [[0.1, 0.2, 0.3]] embeddings = [[0.1, 0.2, 0.3]]
limit = 2 limit = 2
results = await hybrid_node_search(queries, embeddings, mock_driver, limit) results = await hybrid_node_search(queries, embeddings, mock_driver, ['1'], limit)
# We expect 3 results because: # We expect 3 results because:
# 1. The limit of 2 is applied to each search method # 1. The limit of 2 is applied to each search method
@ -148,5 +155,5 @@ async def test_hybrid_node_search_with_limit_and_duplicates():
assert set(node.name for node in results) == {'Alice', 'Bob', 'Charlie'} assert set(node.name for node in results) == {'Alice', 'Bob', 'Charlie'}
assert mock_fulltext_search.call_count == 1 assert mock_fulltext_search.call_count == 1
assert mock_similarity_search.call_count == 1 assert mock_similarity_search.call_count == 1
mock_fulltext_search.assert_called_with('Test', mock_driver, 4) mock_fulltext_search.assert_called_with('Test', mock_driver, ['1'], 4)
mock_similarity_search.assert_called_with([0.1, 0.2, 0.3], mock_driver, 4) mock_similarity_search.assert_called_with([0.1, 0.2, 0.3], mock_driver, ['1'], 4)