Group id fix (#152)

* node distance and group_ids fixed

* get all with no group_id passed

* push

* push

* remove comments

* mypy

* mypy ids

* please mypy

* trust

* last one
This commit is contained in:
Preston Rasmussen 2024-09-24 15:55:30 -04:00 committed by GitHub
parent cfeb58daba
commit 794b705664
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
11 changed files with 93 additions and 110 deletions

View file

@ -63,28 +63,27 @@ async def main(use_bulk: bool = True):
messages = parse_podcast_messages() messages = parse_podcast_messages()
if not use_bulk: if not use_bulk:
for i, message in enumerate(messages[3:4]): for i, message in enumerate(messages[3:14]):
await client.add_episode( await client.add_episode(
name=f'Message {i}', name=f'Message {i}',
episode_body=f'{message.speaker_name} ({message.role}): {message.content}', episode_body=f'{message.speaker_name} ({message.role}): {message.content}',
reference_time=message.actual_timestamp, reference_time=message.actual_timestamp,
source_description='Podcast Transcript', source_description='Podcast Transcript',
group_id='1',
) )
# build communities # build communities
await client.build_communities() await client.build_communities()
# add additional messages to update communities # add additional messages to update communities
# for i, message in enumerate(messages[14:20]): for i, message in enumerate(messages[14:20]):
# await client.add_episode( await client.add_episode(
# name=f'Message {i}', name=f'Message {i}',
# episode_body=f'{message.speaker_name} ({message.role}): {message.content}', episode_body=f'{message.speaker_name} ({message.role}): {message.content}',
# reference_time=message.actual_timestamp, reference_time=message.actual_timestamp,
# source_description='Podcast Transcript', source_description='Podcast Transcript',
# group_id='1', group_id='1',
# update_communities=True, update_communities=True,
# ) )
return return

View file

@ -34,7 +34,7 @@ logger = logging.getLogger(__name__)
class Edge(BaseModel, ABC): class Edge(BaseModel, ABC):
uuid: str = Field(default_factory=lambda: str(uuid4())) uuid: str = Field(default_factory=lambda: str(uuid4()))
group_id: str | None = Field(description='partition of the graph') group_id: str = Field(description='partition of the graph')
source_node_uuid: str source_node_uuid: str
target_node_uuid: str target_node_uuid: str
created_at: datetime created_at: datetime
@ -131,7 +131,7 @@ class EpisodicEdge(Edge):
return edges return edges
@classmethod @classmethod
async def get_by_group_ids(cls, driver: AsyncDriver, group_ids: list[str | None]): async def get_by_group_ids(cls, driver: AsyncDriver, group_ids: list[str]):
records, _, _ = await driver.execute_query( records, _, _ = await driver.execute_query(
""" """
MATCH (n:Episodic)-[e:MENTIONS]->(m:Entity) MATCH (n:Episodic)-[e:MENTIONS]->(m:Entity)
@ -270,7 +270,7 @@ class EntityEdge(Edge):
return edges return edges
@classmethod @classmethod
async def get_by_group_ids(cls, driver: AsyncDriver, group_ids: list[str | None]): async def get_by_group_ids(cls, driver: AsyncDriver, group_ids: list[str]):
records, _, _ = await driver.execute_query( records, _, _ = await driver.execute_query(
""" """
MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity) MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
@ -360,7 +360,7 @@ class CommunityEdge(Edge):
return edges return edges
@classmethod @classmethod
async def get_by_group_ids(cls, driver: AsyncDriver, group_ids: list[str | None]): async def get_by_group_ids(cls, driver: AsyncDriver, group_ids: list[str]):
records, _, _ = await driver.execute_query( records, _, _ = await driver.execute_query(
""" """
MATCH (n:Community)-[e:HAS_MEMBER]->(m:Entity | Community) MATCH (n:Community)-[e:HAS_MEMBER]->(m:Entity | Community)

View file

@ -197,7 +197,7 @@ class Graphiti:
self, self,
reference_time: datetime, reference_time: datetime,
last_n: int = EPISODE_WINDOW_LEN, last_n: int = EPISODE_WINDOW_LEN,
group_ids: list[str | None] | None = None, group_ids: list[str] | None = None,
) -> list[EpisodicNode]: ) -> list[EpisodicNode]:
""" """
Retrieve the last n episodic nodes from the graph. Retrieve the last n episodic nodes from the graph.
@ -233,7 +233,7 @@ class Graphiti:
source_description: str, source_description: str,
reference_time: datetime, reference_time: datetime,
source: EpisodeType = EpisodeType.message, source: EpisodeType = EpisodeType.message,
group_id: str | None = None, group_id: str = '',
uuid: str | None = None, uuid: str | None = None,
update_communities: bool = False, update_communities: bool = False,
): ):
@ -446,7 +446,7 @@ class Graphiti:
except Exception as e: except Exception as e:
raise e raise e
async def add_episode_bulk(self, bulk_episodes: list[RawEpisode], group_id: str | None = None): async def add_episode_bulk(self, bulk_episodes: list[RawEpisode], group_id: str = ''):
""" """
Process multiple episodes in bulk and update the graph. Process multiple episodes in bulk and update the graph.
@ -577,7 +577,7 @@ class Graphiti:
self, self,
query: str, query: str,
center_node_uuid: str | None = None, center_node_uuid: str | None = None,
group_ids: list[str | None] | None = None, group_ids: list[str] | None = None,
num_results=DEFAULT_SEARCH_LIMIT, num_results=DEFAULT_SEARCH_LIMIT,
) -> list[EntityEdge]: ) -> list[EntityEdge]:
""" """
@ -633,7 +633,7 @@ class Graphiti:
self, self,
query: str, query: str,
config: SearchConfig, config: SearchConfig,
group_ids: list[str | None] | None = None, group_ids: list[str] | None = None,
center_node_uuid: str | None = None, center_node_uuid: str | None = None,
) -> SearchResults: ) -> SearchResults:
return await search( return await search(
@ -644,7 +644,7 @@ class Graphiti:
self, self,
query: str, query: str,
center_node_uuid: str | None = None, center_node_uuid: str | None = None,
group_ids: list[str | None] | None = None, group_ids: list[str] | None = None,
limit: int = DEFAULT_SEARCH_LIMIT, limit: int = DEFAULT_SEARCH_LIMIT,
) -> list[EntityNode]: ) -> list[EntityNode]:
""" """

View file

@ -29,7 +29,7 @@ from .errors import RateLimitError
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
DEFAULT_MODEL = 'gpt-4o-2024-08-06' DEFAULT_MODEL = 'gpt-4o-mini'
class OpenAIClient(LLMClient): class OpenAIClient(LLMClient):

View file

@ -70,7 +70,7 @@ class EpisodeType(Enum):
class Node(BaseModel, ABC): class Node(BaseModel, ABC):
uuid: str = Field(default_factory=lambda: str(uuid4())) uuid: str = Field(default_factory=lambda: str(uuid4()))
name: str = Field(description='name of the node') name: str = Field(description='name of the node')
group_id: str | None = Field(description='partition of the graph') group_id: str = Field(description='partition of the graph')
labels: list[str] = Field(default_factory=list) labels: list[str] = Field(default_factory=list)
created_at: datetime = Field(default_factory=lambda: datetime.now()) created_at: datetime = Field(default_factory=lambda: datetime.now())
@ -186,7 +186,7 @@ class EpisodicNode(Node):
return episodes return episodes
@classmethod @classmethod
async def get_by_group_ids(cls, driver: AsyncDriver, group_ids: list[str | None]): async def get_by_group_ids(cls, driver: AsyncDriver, group_ids: list[str]):
records, _, _ = await driver.execute_query( records, _, _ = await driver.execute_query(
""" """
MATCH (e:Episodic) WHERE e.group_id IN $group_ids MATCH (e:Episodic) WHERE e.group_id IN $group_ids
@ -281,7 +281,7 @@ class EntityNode(Node):
return nodes return nodes
@classmethod @classmethod
async def get_by_group_ids(cls, driver: AsyncDriver, group_ids: list[str | None]): async def get_by_group_ids(cls, driver: AsyncDriver, group_ids: list[str]):
records, _, _ = await driver.execute_query( records, _, _ = await driver.execute_query(
""" """
MATCH (n:Entity) WHERE n.group_id IN $group_ids MATCH (n:Entity) WHERE n.group_id IN $group_ids
@ -374,7 +374,7 @@ class CommunityNode(Node):
return communities return communities
@classmethod @classmethod
async def get_by_group_ids(cls, driver: AsyncDriver, group_ids: list[str | None]): async def get_by_group_ids(cls, driver: AsyncDriver, group_ids: list[str]):
records, _, _ = await driver.execute_query( records, _, _ = await driver.execute_query(
""" """
MATCH (n:Community) WHERE n.group_id IN $group_ids MATCH (n:Community) WHERE n.group_id IN $group_ids

View file

@ -15,6 +15,7 @@ limitations under the License.
""" """
import logging import logging
from collections import defaultdict
from time import time from time import time
from neo4j import AsyncDriver from neo4j import AsyncDriver
@ -56,7 +57,7 @@ async def search(
driver: AsyncDriver, driver: AsyncDriver,
embedder, embedder,
query: str, query: str,
group_ids: list[str | None] | None, group_ids: list[str] | None,
config: SearchConfig, config: SearchConfig,
center_node_uuid: str | None = None, center_node_uuid: str | None = None,
) -> SearchResults: ) -> SearchResults:
@ -103,7 +104,7 @@ async def edge_search(
driver: AsyncDriver, driver: AsyncDriver,
embedder, embedder,
query: str, query: str,
group_ids: list[str | None] | None, group_ids: list[str] | None,
config: EdgeSearchConfig, config: EdgeSearchConfig,
center_node_uuid: str | None = None, center_node_uuid: str | None = None,
limit=DEFAULT_SEARCH_LIMIT, limit=DEFAULT_SEARCH_LIMIT,
@ -140,14 +141,21 @@ async def edge_search(
if center_node_uuid is None: if center_node_uuid is None:
raise SearchRerankerError('No center node provided for Node Distance reranker') raise SearchRerankerError('No center node provided for Node Distance reranker')
source_to_edge_uuid_map = { # use rrf as a preliminary sort
edge.source_node_uuid: edge.uuid for result in search_results for edge in result sorted_result_uuids = rrf([[edge.uuid for edge in result] for result in search_results])
} sorted_results = [edge_uuid_map[uuid] for uuid in sorted_result_uuids]
source_uuids = [[edge.source_node_uuid for edge in result] for result in search_results]
# node distance reranking
source_to_edge_uuid_map = defaultdict(list)
for edge in sorted_results:
source_to_edge_uuid_map[edge.source_node_uuid].append(edge.uuid)
source_uuids = [edge.source_node_uuid for edge in sorted_results]
reranked_node_uuids = await node_distance_reranker(driver, source_uuids, center_node_uuid) reranked_node_uuids = await node_distance_reranker(driver, source_uuids, center_node_uuid)
reranked_uuids = [source_to_edge_uuid_map[node_uuid] for node_uuid in reranked_node_uuids] for node_uuid in reranked_node_uuids:
reranked_uuids.extend(source_to_edge_uuid_map[node_uuid])
reranked_edges = [edge_uuid_map[uuid] for uuid in reranked_uuids] reranked_edges = [edge_uuid_map[uuid] for uuid in reranked_uuids]
@ -161,7 +169,7 @@ async def node_search(
driver: AsyncDriver, driver: AsyncDriver,
embedder, embedder,
query: str, query: str,
group_ids: list[str | None] | None, group_ids: list[str] | None,
config: NodeSearchConfig, config: NodeSearchConfig,
center_node_uuid: str | None = None, center_node_uuid: str | None = None,
limit=DEFAULT_SEARCH_LIMIT, limit=DEFAULT_SEARCH_LIMIT,
@ -198,7 +206,9 @@ async def node_search(
elif config.reranker == NodeReranker.node_distance: elif config.reranker == NodeReranker.node_distance:
if center_node_uuid is None: if center_node_uuid is None:
raise SearchRerankerError('No center node provided for Node Distance reranker') raise SearchRerankerError('No center node provided for Node Distance reranker')
reranked_uuids = await node_distance_reranker(driver, search_result_uuids, center_node_uuid) reranked_uuids = await node_distance_reranker(
driver, rrf(search_result_uuids), center_node_uuid
)
reranked_nodes = [node_uuid_map[uuid] for uuid in reranked_uuids] reranked_nodes = [node_uuid_map[uuid] for uuid in reranked_uuids]
@ -209,7 +219,7 @@ async def community_search(
driver: AsyncDriver, driver: AsyncDriver,
embedder, embedder,
query: str, query: str,
group_ids: list[str | None] | None, group_ids: list[str] | None,
config: CommunitySearchConfig, config: CommunitySearchConfig,
limit=DEFAULT_SEARCH_LIMIT, limit=DEFAULT_SEARCH_LIMIT,
) -> list[CommunityNode]: ) -> list[CommunityNode]:

View file

@ -87,7 +87,7 @@ async def edge_fulltext_search(
query: str, query: str,
source_node_uuid: str | None, source_node_uuid: str | None,
target_node_uuid: str | None, target_node_uuid: str | None,
group_ids: list[str | None] | None = None, group_ids: list[str] | None = None,
limit=RELEVANT_SCHEMA_LIMIT, limit=RELEVANT_SCHEMA_LIMIT,
) -> list[EntityEdge]: ) -> list[EntityEdge]:
# fulltext search over facts # fulltext search over facts
@ -95,10 +95,7 @@ async def edge_fulltext_search(
CALL db.index.fulltext.queryRelationships("name_and_fact", $query) CALL db.index.fulltext.queryRelationships("name_and_fact", $query)
YIELD relationship AS rel, score YIELD relationship AS rel, score
MATCH (n:Entity {uuid: $source_uuid})-[r {uuid: rel.uuid}]-(m:Entity {uuid: $target_uuid}) MATCH (n:Entity {uuid: $source_uuid})-[r {uuid: rel.uuid}]-(m:Entity {uuid: $target_uuid})
WHERE CASE WHERE $group_ids IS NULL OR n.group_id IN $group_ids
WHEN $group_ids IS NULL THEN n.group_id IS NULL
ELSE n.group_id IN $group_ids
END
RETURN RETURN
r.uuid AS uuid, r.uuid AS uuid,
r.group_id AS group_id, r.group_id AS group_id,
@ -120,10 +117,7 @@ async def edge_fulltext_search(
CALL db.index.fulltext.queryRelationships("name_and_fact", $query) CALL db.index.fulltext.queryRelationships("name_and_fact", $query)
YIELD relationship AS rel, score YIELD relationship AS rel, score
MATCH (n:Entity)-[r {uuid: rel.uuid}]-(m:Entity) MATCH (n:Entity)-[r {uuid: rel.uuid}]-(m:Entity)
WHERE CASE WHERE $group_ids IS NULL OR r.group_id IN $group_ids
WHEN $group_ids IS NULL THEN r.group_id IS NULL
ELSE r.group_id IN $group_ids
END
RETURN RETURN
r.uuid AS uuid, r.uuid AS uuid,
r.group_id AS group_id, r.group_id AS group_id,
@ -144,10 +138,7 @@ async def edge_fulltext_search(
CALL db.index.fulltext.queryRelationships("name_and_fact", $query) CALL db.index.fulltext.queryRelationships("name_and_fact", $query)
YIELD relationship AS rel, score YIELD relationship AS rel, score
MATCH (n:Entity)-[r {uuid: rel.uuid}]-(m:Entity {uuid: $target_uuid}) MATCH (n:Entity)-[r {uuid: rel.uuid}]-(m:Entity {uuid: $target_uuid})
WHERE CASE WHERE $group_ids IS NULL OR r.group_id IN $group_ids
WHEN $group_ids IS NULL THEN r.group_id IS NULL
ELSE r.group_id IN $group_ids
END
RETURN RETURN
r.uuid AS uuid, r.uuid AS uuid,
r.group_id AS group_id, r.group_id AS group_id,
@ -168,10 +159,7 @@ async def edge_fulltext_search(
CALL db.index.fulltext.queryRelationships("name_and_fact", $query) CALL db.index.fulltext.queryRelationships("name_and_fact", $query)
YIELD relationship AS rel, score YIELD relationship AS rel, score
MATCH (n:Entity {uuid: $source_uuid})-[r {uuid: rel.uuid}]-(m:Entity) MATCH (n:Entity {uuid: $source_uuid})-[r {uuid: rel.uuid}]-(m:Entity)
WHERE CASE WHERE $group_ids IS NULL OR r.group_id IN $group_ids
WHEN $group_ids IS NULL THEN r.group_id IS NULL
ELSE r.group_id IN $group_ids
END
RETURN RETURN
r.uuid AS uuid, r.uuid AS uuid,
r.group_id AS group_id, r.group_id AS group_id,
@ -209,7 +197,7 @@ async def edge_similarity_search(
search_vector: list[float], search_vector: list[float],
source_node_uuid: str | None, source_node_uuid: str | None,
target_node_uuid: str | None, target_node_uuid: str | None,
group_ids: list[str | None] | None = None, group_ids: list[str] | None = None,
limit: int = RELEVANT_SCHEMA_LIMIT, limit: int = RELEVANT_SCHEMA_LIMIT,
) -> list[EntityEdge]: ) -> list[EntityEdge]:
# vector similarity search over embedded facts # vector similarity search over embedded facts
@ -217,10 +205,7 @@ async def edge_similarity_search(
CALL db.index.vector.queryRelationships("fact_embedding", $limit, $search_vector) CALL db.index.vector.queryRelationships("fact_embedding", $limit, $search_vector)
YIELD relationship AS rel, score YIELD relationship AS rel, score
MATCH (n:Entity {uuid: $source_uuid})-[r {uuid: rel.uuid}]-(m:Entity {uuid: $target_uuid}) MATCH (n:Entity {uuid: $source_uuid})-[r {uuid: rel.uuid}]-(m:Entity {uuid: $target_uuid})
WHERE CASE WHERE $group_ids IS NULL OR r.group_id IN $group_ids
WHEN $group_ids IS NULL THEN r.group_id IS NULL
ELSE r.group_id IN $group_ids
END
RETURN RETURN
r.uuid AS uuid, r.uuid AS uuid,
r.group_id AS group_id, r.group_id AS group_id,
@ -242,10 +227,7 @@ async def edge_similarity_search(
CALL db.index.vector.queryRelationships("fact_embedding", $limit, $search_vector) CALL db.index.vector.queryRelationships("fact_embedding", $limit, $search_vector)
YIELD relationship AS rel, score YIELD relationship AS rel, score
MATCH (n:Entity)-[r {uuid: rel.uuid}]-(m:Entity) MATCH (n:Entity)-[r {uuid: rel.uuid}]-(m:Entity)
WHERE CASE WHERE $group_ids IS NULL OR r.group_id IN $group_ids
WHEN $group_ids IS NULL THEN r.group_id IS NULL
ELSE r.group_id IN $group_ids
END
RETURN RETURN
r.uuid AS uuid, r.uuid AS uuid,
r.group_id AS group_id, r.group_id AS group_id,
@ -266,10 +248,7 @@ async def edge_similarity_search(
CALL db.index.vector.queryRelationships("fact_embedding", $limit, $search_vector) CALL db.index.vector.queryRelationships("fact_embedding", $limit, $search_vector)
YIELD relationship AS rel, score YIELD relationship AS rel, score
MATCH (n:Entity)-[r {uuid: rel.uuid}]-(m:Entity {uuid: $target_uuid}) MATCH (n:Entity)-[r {uuid: rel.uuid}]-(m:Entity {uuid: $target_uuid})
WHERE CASE WHERE $group_ids IS NULL OR r.group_id IN $group_ids
WHEN $group_ids IS NULL THEN r.group_id IS NULL
ELSE r.group_id IN $group_ids
END
RETURN RETURN
r.uuid AS uuid, r.uuid AS uuid,
r.group_id AS group_id, r.group_id AS group_id,
@ -290,10 +269,7 @@ async def edge_similarity_search(
CALL db.index.vector.queryRelationships("fact_embedding", $limit, $search_vector) CALL db.index.vector.queryRelationships("fact_embedding", $limit, $search_vector)
YIELD relationship AS rel, score YIELD relationship AS rel, score
MATCH (n:Entity {uuid: $source_uuid})-[r {uuid: rel.uuid}]-(m:Entity) MATCH (n:Entity {uuid: $source_uuid})-[r {uuid: rel.uuid}]-(m:Entity)
WHERE CASE WHERE $group_ids IS NULL OR r.group_id IN $group_ids
WHEN $group_ids IS NULL THEN r.group_id IS NULL
ELSE r.group_id IN $group_ids
END
RETURN RETURN
r.uuid AS uuid, r.uuid AS uuid,
r.group_id AS group_id, r.group_id AS group_id,
@ -327,7 +303,7 @@ async def edge_similarity_search(
async def node_fulltext_search( async def node_fulltext_search(
driver: AsyncDriver, driver: AsyncDriver,
query: str, query: str,
group_ids: list[str | None] | None = None, group_ids: list[str] | None = None,
limit=RELEVANT_SCHEMA_LIMIT, limit=RELEVANT_SCHEMA_LIMIT,
) -> list[EntityNode]: ) -> list[EntityNode]:
# BM25 search to get top nodes # BM25 search to get top nodes
@ -336,10 +312,7 @@ async def node_fulltext_search(
""" """
CALL db.index.fulltext.queryNodes("name_and_summary", $query) CALL db.index.fulltext.queryNodes("name_and_summary", $query)
YIELD node AS n, score YIELD node AS n, score
WHERE CASE WHERE $group_ids IS NULL OR n.group_id IN $group_ids
WHEN $group_ids IS NULL THEN n.group_id IS NULL
ELSE n.group_id IN $group_ids
END
RETURN RETURN
n.uuid AS uuid, n.uuid AS uuid,
n.group_id AS group_id, n.group_id AS group_id,
@ -362,17 +335,16 @@ async def node_fulltext_search(
async def node_similarity_search( async def node_similarity_search(
driver: AsyncDriver, driver: AsyncDriver,
search_vector: list[float], search_vector: list[float],
group_ids: list[str | None] | None = None, group_ids: list[str] | None = None,
limit=RELEVANT_SCHEMA_LIMIT, limit=RELEVANT_SCHEMA_LIMIT,
) -> list[EntityNode]: ) -> list[EntityNode]:
group_ids = group_ids if group_ids is not None else [None]
# vector similarity search over entity names # vector similarity search over entity names
records, _, _ = await driver.execute_query( records, _, _ = await driver.execute_query(
""" """
CALL db.index.vector.queryNodes("name_embedding", $limit, $search_vector) CALL db.index.vector.queryNodes("name_embedding", $limit, $search_vector)
YIELD node AS n, score YIELD node AS n, score
MATCH (n WHERE n.group_id IN $group_ids) MATCH (n:Entity)
WHERE $group_ids IS NULL OR n.group_id IN $group_ids
RETURN RETURN
n.uuid As uuid, n.uuid As uuid,
n.group_id AS group_id, n.group_id AS group_id,
@ -394,18 +366,17 @@ async def node_similarity_search(
async def community_fulltext_search( async def community_fulltext_search(
driver: AsyncDriver, driver: AsyncDriver,
query: str, query: str,
group_ids: list[str | None] | None = None, group_ids: list[str] | None = None,
limit=RELEVANT_SCHEMA_LIMIT, limit=RELEVANT_SCHEMA_LIMIT,
) -> list[CommunityNode]: ) -> list[CommunityNode]:
group_ids = group_ids if group_ids is not None else [None]
# BM25 search to get top communities # BM25 search to get top communities
fuzzy_query = re.sub(r'[^\w\s]', '', query) + '~' fuzzy_query = re.sub(r'[^\w\s]', '', query) + '~'
records, _, _ = await driver.execute_query( records, _, _ = await driver.execute_query(
""" """
CALL db.index.fulltext.queryNodes("community_name", $query) CALL db.index.fulltext.queryNodes("community_name", $query)
YIELD node AS comm, score YIELD node AS comm, score
MATCH (comm WHERE comm.group_id in $group_ids) MATCH (comm:Community)
WHERE $group_ids IS NULL OR comm.group_id in $group_ids
RETURN RETURN
comm.uuid AS uuid, comm.uuid AS uuid,
comm.group_id AS group_id, comm.group_id AS group_id,
@ -428,17 +399,16 @@ async def community_fulltext_search(
async def community_similarity_search( async def community_similarity_search(
driver: AsyncDriver, driver: AsyncDriver,
search_vector: list[float], search_vector: list[float],
group_ids: list[str | None] | None = None, group_ids: list[str] | None = None,
limit=RELEVANT_SCHEMA_LIMIT, limit=RELEVANT_SCHEMA_LIMIT,
) -> list[CommunityNode]: ) -> list[CommunityNode]:
group_ids = group_ids if group_ids is not None else [None]
# vector similarity search over entity names # vector similarity search over entity names
records, _, _ = await driver.execute_query( records, _, _ = await driver.execute_query(
""" """
CALL db.index.vector.queryNodes("community_name_embedding", $limit, $search_vector) CALL db.index.vector.queryNodes("community_name_embedding", $limit, $search_vector)
YIELD node AS comm, score YIELD node AS comm, score
MATCH (comm WHERE comm.group_id IN $group_ids) MATCH (comm:Community)
WHERE $group_ids IS NULL OR comm.group_id IN $group_ids
RETURN RETURN
comm.uuid As uuid, comm.uuid As uuid,
comm.group_id AS group_id, comm.group_id AS group_id,
@ -461,7 +431,7 @@ async def hybrid_node_search(
queries: list[str], queries: list[str],
embeddings: list[list[float]], embeddings: list[list[float]],
driver: AsyncDriver, driver: AsyncDriver,
group_ids: list[str | None] | None = None, group_ids: list[str] | None = None,
limit: int = RELEVANT_SCHEMA_LIMIT, limit: int = RELEVANT_SCHEMA_LIMIT,
) -> list[EntityNode]: ) -> list[EntityNode]:
""" """
@ -503,7 +473,6 @@ async def hybrid_node_search(
""" """
start = time() start = time()
results: list[list[EntityNode]] = list( results: list[list[EntityNode]] = list(
await asyncio.gather( await asyncio.gather(
*[node_fulltext_search(driver, q, group_ids, 2 * limit) for q in queries], *[node_fulltext_search(driver, q, group_ids, 2 * limit) for q in queries],
@ -625,10 +594,10 @@ def rrf(results: list[list[str]], rank_const=1) -> list[str]:
async def node_distance_reranker( async def node_distance_reranker(
driver: AsyncDriver, node_uuids: list[list[str]], center_node_uuid: str driver: AsyncDriver, node_uuids: list[str], center_node_uuid: str
) -> list[str]: ) -> list[str]:
# use rrf as a preliminary ranker # filter out node_uuid center node node uuid
sorted_uuids = rrf(node_uuids) filtered_uuids = list(filter(lambda uuid: uuid != center_node_uuid, node_uuids))
scores: dict[str, float] = {} scores: dict[str, float] = {}
# Find the shortest path to center node # Find the shortest path to center node
@ -644,21 +613,23 @@ async def node_distance_reranker(
node_uuid=uuid, node_uuid=uuid,
center_uuid=center_node_uuid, center_uuid=center_node_uuid,
) )
for uuid in sorted_uuids for uuid in filtered_uuids
] ]
) )
for uuid, result in zip(sorted_uuids, path_results): for uuid, result in zip(filtered_uuids, path_results):
records = result[0] records = result[0]
record = records[0] if len(records) > 0 else None record = records[0] if len(records) > 0 else None
distance: float = record['score'] if record is not None else float('inf') distance: float = record['score'] if record is not None else float('inf')
distance = 0 if uuid == center_node_uuid else distance
scores[uuid] = distance scores[uuid] = distance
# rerank on shortest distance # rerank on shortest distance
sorted_uuids.sort(key=lambda cur_uuid: scores[cur_uuid]) filtered_uuids.sort(key=lambda cur_uuid: scores[cur_uuid])
return sorted_uuids # add back in filtered center uuids
filtered_uuids = [center_node_uuid] + filtered_uuids
return filtered_uuids
async def episode_mentions_reranker(driver: AsyncDriver, node_uuids: list[list[str]]) -> list[str]: async def episode_mentions_reranker(driver: AsyncDriver, node_uuids: list[list[str]]) -> list[str]:

View file

@ -154,7 +154,7 @@ async def generate_summary_description(llm_client: LLMClient, summary: str) -> s
async def build_community( async def build_community(
llm_client: LLMClient, community_cluster: list[EntityNode] llm_client: LLMClient, community_cluster: list[EntityNode]
) -> tuple[CommunityNode, list[CommunityEdge]]: ) -> tuple[CommunityNode, list[CommunityEdge]]:
summaries = [entity.summary for entity in community_cluster] summaries = [entity.summary for entity in community_cluster]
length = len(summaries) length = len(summaries)
@ -168,7 +168,7 @@ async def build_community(
*[ *[
summarize_pair(llm_client, (str(left_summary), str(right_summary))) summarize_pair(llm_client, (str(left_summary), str(right_summary)))
for left_summary, right_summary in zip( for left_summary, right_summary in zip(
summaries[: int(length / 2)], summaries[int(length / 2):] summaries[: int(length / 2)], summaries[int(length / 2) :]
) )
] ]
) )
@ -196,7 +196,7 @@ async def build_community(
async def build_communities( async def build_communities(
driver: AsyncDriver, llm_client: LLMClient driver: AsyncDriver, llm_client: LLMClient
) -> tuple[list[CommunityNode], list[CommunityEdge]]: ) -> tuple[list[CommunityNode], list[CommunityEdge]]:
community_clusters = await get_community_clusters(driver) community_clusters = await get_community_clusters(driver)
@ -227,7 +227,7 @@ async def remove_communities(driver: AsyncDriver):
async def determine_entity_community( async def determine_entity_community(
driver: AsyncDriver, entity: EntityNode driver: AsyncDriver, entity: EntityNode
) -> tuple[CommunityNode | None, bool]: ) -> tuple[CommunityNode | None, bool]:
# Check if the node is already part of a community # Check if the node is already part of a community
records, _, _ = await driver.execute_query( records, _, _ = await driver.execute_query(
@ -288,7 +288,7 @@ async def determine_entity_community(
async def update_community( async def update_community(
driver: AsyncDriver, llm_client: LLMClient, embedder, entity: EntityNode driver: AsyncDriver, llm_client: LLMClient, embedder, entity: EntityNode
): ):
community, is_new = await determine_entity_community(driver, entity) community, is_new = await determine_entity_community(driver, entity)

View file

@ -73,7 +73,7 @@ async def extract_edges(
episode: EpisodicNode, episode: EpisodicNode,
nodes: list[EntityNode], nodes: list[EntityNode],
previous_episodes: list[EpisodicNode], previous_episodes: list[EpisodicNode],
group_id: str | None, group_id: str = '',
) -> list[EntityEdge]: ) -> list[EntityEdge]:
start = time() start = time()

View file

@ -101,7 +101,7 @@ async def retrieve_episodes(
driver: AsyncDriver, driver: AsyncDriver,
reference_time: datetime, reference_time: datetime,
last_n: int = EPISODE_WINDOW_LEN, last_n: int = EPISODE_WINDOW_LEN,
group_ids: list[str | None] | None = None, group_ids: list[str] | None = None,
) -> list[EpisodicNode]: ) -> list[EpisodicNode]:
""" """
Retrieve the last n episodic nodes from the graph. Retrieve the last n episodic nodes from the graph.
@ -119,7 +119,8 @@ async def retrieve_episodes(
""" """
result = await driver.execute_query( result = await driver.execute_query(
""" """
MATCH (e:Episodic) WHERE e.valid_at <= $reference_time AND e.group_id in $group_ids MATCH (e:Episodic) WHERE e.valid_at <= $reference_time
AND ($group_ids IS NULL) OR e.group_id in $group_ids
RETURN e.content AS content, RETURN e.content AS content,
e.created_at AS created_at, e.created_at AS created_at,
e.valid_at AS valid_at, e.valid_at AS valid_at,

View file

@ -76,16 +76,18 @@ async def test_graphiti_init():
graphiti = Graphiti(NEO4J_URI, NEO4j_USER, NEO4j_PASSWORD) graphiti = Graphiti(NEO4J_URI, NEO4j_USER, NEO4j_PASSWORD)
await graphiti.build_communities() await graphiti.build_communities()
edges = await graphiti.search('tania tetlow', group_ids=['1']) edges = await graphiti.search(
'tania tetlow', center_node_uuid='4bf7ebb3-3a98-46c7-90a6-8e516c487961', group_ids=None
)
logger.info('\nQUERY: Tania Tetlow\n' + format_context([edge.fact for edge in edges])) logger.info('\nQUERY: Tania Tetlow\n' + format_context([edge.fact for edge in edges]))
edges = await graphiti.search('issues with higher ed', group_ids=['1']) edges = await graphiti.search('issues with higher ed', group_ids=None)
logger.info('\nQUERY: issues with higher ed\n' + format_context([edge.fact for edge in edges])) logger.info('\nQUERY: issues with higher ed\n' + format_context([edge.fact for edge in edges]))
results = await graphiti._search( results = await graphiti._search(
'issues with higher ed', COMBINED_HYBRID_SEARCH_RRF, group_ids=['1'] 'issues with higher ed', COMBINED_HYBRID_SEARCH_RRF, group_ids=None
) )
pretty_results = { pretty_results = {
'edges': [edge.fact for edge in results.edges], 'edges': [edge.fact for edge in results.edges],