From a18b3179ee8a3ea8f3b094b8cca9c8a79ee697c3 Mon Sep 17 00:00:00 2001 From: Preston Rasmussen <109292228+prasmussen15@users.noreply.github.com> Date: Wed, 18 Sep 2024 11:37:34 -0400 Subject: [PATCH] Add community update (#121) * documentation update * update communities * update runner * make format * mypy * oops * add update_communities --- examples/podcast/podcast_runner.py | 16 +++- graphiti_core/graphiti.py | 14 ++- .../utils/maintenance/community_operations.py | 86 ++++++++++++++++++- 3 files changed, 113 insertions(+), 3 deletions(-) diff --git a/examples/podcast/podcast_runner.py b/examples/podcast/podcast_runner.py index 37ec9345..ad49a71c 100644 --- a/examples/podcast/podcast_runner.py +++ b/examples/podcast/podcast_runner.py @@ -63,7 +63,7 @@ async def main(use_bulk: bool = True): messages = parse_podcast_messages() if not use_bulk: - for i, message in enumerate(messages[3:20]): + for i, message in enumerate(messages[3:14]): await client.add_episode( name=f'Message {i}', episode_body=f'{message.speaker_name} ({message.role}): {message.content}', @@ -71,6 +71,20 @@ async def main(use_bulk: bool = True): source_description='Podcast Transcript', group_id='1', ) + + # build communities + await client.build_communities() + + # add additional messages to update communities + for i, message in enumerate(messages[14:20]): + await client.add_episode( + name=f'Message {i}', + episode_body=f'{message.speaker_name} ({message.role}): {message.content}', + reference_time=message.actual_timestamp, + source_description='Podcast Transcript', + group_id='1', + ) + return episodes: list[RawEpisode] = [ diff --git a/graphiti_core/graphiti.py b/graphiti_core/graphiti.py index f388dccc..f029ce48 100644 --- a/graphiti_core/graphiti.py +++ b/graphiti_core/graphiti.py @@ -54,6 +54,7 @@ from graphiti_core.utils.bulk_utils import ( from graphiti_core.utils.maintenance.community_operations import ( build_communities, remove_communities, + update_community, ) from 
graphiti_core.utils.maintenance.edge_operations import ( extract_edges, @@ -224,6 +225,7 @@ class Graphiti: source: EpisodeType = EpisodeType.message, group_id: str | None = None, uuid: str | None = None, + update_communities: bool = False, ): """ Process an episode and update the graph. @@ -247,6 +249,8 @@ class Graphiti: An id for the graph partition the episode is a part of. uuid : str | None Optional uuid of the episode. + update_communities: bool + Optional. Determines if we should update communities Returns ------- @@ -415,6 +419,14 @@ class Graphiti: await asyncio.gather(*[edge.save(self.driver) for edge in episodic_edges]) await asyncio.gather(*[edge.save(self.driver) for edge in entity_edges]) + # Update any communities + if update_communities: + await asyncio.gather( + *[ + update_community(self.driver, self.llm_client, embedder, node) + for node in nodes + ] + ) end = time() logger.info(f'Completed add_episode in {(end - start) * 1000} ms') @@ -569,7 +581,7 @@ class Graphiti: Facts will be reranked based on proximity to this node group_ids : list[str | None] | None, optional The graph partitions to return data from. - limit : int, optional + num_results : int, optional The maximum number of results to return. Defaults to 10. 
Returns
diff --git a/graphiti_core/utils/maintenance/community_operations.py b/graphiti_core/utils/maintenance/community_operations.py
index 868342ad..53fe6122 100644
--- a/graphiti_core/utils/maintenance/community_operations.py
+++ b/graphiti_core/utils/maintenance/community_operations.py
@@ -7,7 +7,7 @@ from neo4j import AsyncDriver
 
 from graphiti_core.edges import CommunityEdge
 from graphiti_core.llm_client import LLMClient
-from graphiti_core.nodes import CommunityNode, EntityNode
+from graphiti_core.nodes import CommunityNode, EntityNode, get_community_node_from_record
 from graphiti_core.prompts import prompt_library
 from graphiti_core.utils.maintenance.edge_operations import build_community_edges
 
@@ -153,3 +153,113 @@ async def remove_communities(driver: AsyncDriver):
     MATCH (c:Community)
     DETACH DELETE c
     """)
+
+
+async def determine_entity_community(
+    driver: AsyncDriver, entity: EntityNode
+) -> tuple[CommunityNode | None, bool]:
+    """
+    Find the community an entity belongs to, or the one it should join.
+
+    Parameters
+    ----------
+    driver : AsyncDriver
+        Neo4j driver used to run the two lookup queries.
+    entity : EntityNode
+        The entity whose community is being resolved.
+
+    Returns
+    -------
+    tuple[CommunityNode | None, bool]
+        The chosen community, or None when neither the entity nor any of its
+        RELATES_TO neighbours is a member of one, and a flag that is True only
+        when the entity is not yet a member of the returned community.
+    """
+    # Check if the node is already part of a community
+    records, _, _ = await driver.execute_query(
+        """
+    MATCH (c:Community)-[:HAS_MEMBER]->(n:Entity {uuid: $entity_uuid})
+    RETURN
+        c.uuid As uuid,
+        c.name AS name,
+        c.name_embedding AS name_embedding,
+        c.group_id AS group_id,
+        c.created_at AS created_at,
+        c.summary AS summary
+    """,
+        entity_uuid=entity.uuid,
+    )
+
+    if records:
+        return get_community_node_from_record(records[0]), False
+
+    # If the node has no community, add it to the mode community of surrounding entities
+    records, _, _ = await driver.execute_query(
+        """
+    MATCH (c:Community)-[:HAS_MEMBER]->(m:Entity)-[:RELATES_TO]-(n:Entity {uuid: $entity_uuid})
+    RETURN
+        c.uuid As uuid,
+        c.name AS name,
+        c.name_embedding AS name_embedding,
+        c.group_id AS group_id,
+        c.created_at AS created_at,
+        c.summary AS summary
+    """,
+        entity_uuid=entity.uuid,
+    )
+
+    if not records:
+        return None, False
+
+    # Each record is one matched path through a neighbouring entity, so the tally
+    # weights every community by the number of such paths that lead to it.
+    community_counts: dict[str, int] = defaultdict(int)
+    first_seen: dict[str, CommunityNode] = {}
+    for record in records:
+        community = get_community_node_from_record(record)
+        community_counts[community.uuid] += 1
+        first_seen.setdefault(community.uuid, community)
+
+    # max() returns the first maximal key, so ties go to the community seen first
+    mode_uuid = max(community_counts, key=lambda key: community_counts[key])
+    return first_seen[mode_uuid], True
+
+
+async def update_community(
+    driver: AsyncDriver, llm_client: LLMClient, embedder, entity: EntityNode
+) -> None:
+    """
+    Merge an entity's summary into its community and save the community.
+
+    Does nothing when no community can be determined for the entity.
+
+    Parameters
+    ----------
+    driver : AsyncDriver
+        Neo4j driver used for the lookups and for saving.
+    llm_client : LLMClient
+        Client used to produce the merged summary and the new community name.
+    embedder
+        Passed through to the community's generate_name_embedding call.
+    entity : EntityNode
+        The entity whose summary is folded into the community summary.
+    """
+    community, is_new = await determine_entity_community(driver, entity)
+
+    if community is None:
+        return
+
+    new_summary = await summarize_pair(llm_client, (entity.summary, community.summary))
+    new_name = await generate_summary_description(llm_client, new_summary)
+
+    community.summary = new_summary
+    community.name = new_name
+
+    if is_new:
+        # Only link the entity when it is not already a member of the community
+        community_edge = build_community_edges([entity], community, datetime.now())[0]
+        await community_edge.save(driver)
+
+    await community.generate_name_embedding(embedder)
+
+    await community.save(driver)