From 5506a01e24a21dd375f41442dfc2bd4adfcfff2e Mon Sep 17 00:00:00 2001
From: Preston Rasmussen <109292228+prasmussen15@users.noreply.github.com>
Date: Mon, 23 Sep 2024 11:05:44 -0400
Subject: [PATCH] In memory label propagation community detection (#136)

* WIP

* in memory graph detection

* format

* add comments

* update readme

* fixed an issue where solo nodes would throw an error when building communities
---
 README.md                                          |   2 -
 examples/podcast/podcast_runner.py                 |  20 +--
 graphiti_core/graphiti.py                          |   2 +-
 .../utils/maintenance/community_operations.py      | 124 +++++++++++++-----
 tests/test_graphiti_int.py                         |   1 +
 5 files changed, 106 insertions(+), 43 deletions(-)

diff --git a/README.md b/README.md
index beab724a..d2f304b7 100644
--- a/README.md
+++ b/README.md
@@ -2,7 +2,6 @@
 Graphiti-ts-small
 
-
 
 ## Temporal Knowledge Graphs for Agentic Applications
 
 
@@ -80,7 +79,6 @@ Requirements:
 
 - Python 3.10 or higher
 - Neo4j 5.21 or higher
-- Neo4j GraphDataScience Plugin (required for community flows)
 - OpenAI API key (for LLM inference and embedding)
 
 Optional:
diff --git a/examples/podcast/podcast_runner.py b/examples/podcast/podcast_runner.py
index 22e1d90b..90a4a205 100644
--- a/examples/podcast/podcast_runner.py
+++ b/examples/podcast/podcast_runner.py
@@ -63,7 +63,7 @@ async def main(use_bulk: bool = True):
     messages = parse_podcast_messages()
 
     if not use_bulk:
-        for i, message in enumerate(messages[3:14]):
+        for i, message in enumerate(messages[3:4]):
             await client.add_episode(
                 name=f'Message {i}',
                 episode_body=f'{message.speaker_name} ({message.role}): {message.content}',
@@ -76,15 +76,15 @@ async def main(use_bulk: bool = True):
 
         await client.build_communities()
 
         # add additional messages to update communities
-        for i, message in enumerate(messages[14:20]):
-            await client.add_episode(
-                name=f'Message {i}',
-                episode_body=f'{message.speaker_name} ({message.role}): {message.content}',
-                reference_time=message.actual_timestamp,
-                source_description='Podcast Transcript',
-                group_id='1',
-                update_communities=True,
-            )
+        # for i, message in enumerate(messages[14:20]):
+        #     await client.add_episode(
+        #         name=f'Message {i}',
+        #         episode_body=f'{message.speaker_name} ({message.role}): {message.content}',
+        #         reference_time=message.actual_timestamp,
+        #         source_description='Podcast Transcript',
+        #         group_id='1',
+        #         update_communities=True,
+        #     )
 
     return
diff --git a/graphiti_core/graphiti.py b/graphiti_core/graphiti.py
index bf6f65e9..8a5eb23a 100644
--- a/graphiti_core/graphiti.py
+++ b/graphiti_core/graphiti.py
@@ -579,7 +579,7 @@ class Graphiti:
         center_node_uuid: str | None = None,
         group_ids: list[str | None] | None = None,
         num_results=DEFAULT_SEARCH_LIMIT,
-    ):
+    ) -> list[EntityEdge]:
         """
         Perform a hybrid search on the knowledge graph.
 
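A note on the `graphiti.py` hunk above: `search` already returned edges, so the new `-> list[EntityEdge]` annotation only makes the contract explicit to callers and type checkers. A minimal usage sketch, assuming a running Neo4j instance and an OpenAI key; the connection values and query string here are placeholders, not part of this patch:

```python
import asyncio

from graphiti_core import Graphiti


async def main() -> None:
    # Placeholder credentials; point these at a real Neo4j 5.21+ instance.
    client = Graphiti('bolt://localhost:7687', 'neo4j', 'password')

    # search() is now annotated as returning list[EntityEdge], so edge
    # attributes such as .fact are visible without a cast.
    edges = await client.search('tania tetlow', group_ids=['1'])
    for edge in edges:
        print(edge.fact)


asyncio.run(main())
```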
diff --git a/graphiti_core/utils/maintenance/community_operations.py b/graphiti_core/utils/maintenance/community_operations.py
index 014469c9..7a384fec 100644
--- a/graphiti_core/utils/maintenance/community_operations.py
+++ b/graphiti_core/utils/maintenance/community_operations.py
@@ -4,6 +4,7 @@ from collections import defaultdict
 from datetime import datetime
 
 from neo4j import AsyncDriver
+from pydantic import BaseModel
 
 from graphiti_core.edges import CommunityEdge
 from graphiti_core.llm_client import LLMClient
@@ -17,6 +18,11 @@ MAX_COMMUNITY_BUILD_CONCURRENCY = 10
 logger = logging.getLogger(__name__)
 
 
+class Neighbor(BaseModel):
+    node_uuid: str
+    edge_count: int
+
+
 async def build_community_projection(driver: AsyncDriver) -> str:
     records, _, _ = await driver.execute_query("""
     CALL gds.graph.project("communities", "Entity",
@@ -32,36 +38,96 @@ async def build_community_projection(driver: AsyncDriver) -> str:
     return records[0]['graph']
 
 
-async def destroy_projection(driver: AsyncDriver, projection_name: str):
-    await driver.execute_query(
-        """
-    CALL gds.graph.drop($projection_name)
-    """,
-        projection_name=projection_name,
-    )
+async def get_community_clusters(driver: AsyncDriver) -> list[list[EntityNode]]:
+    community_clusters: list[list[EntityNode]] = []
 
-
-async def get_community_clusters(
-        driver: AsyncDriver, projection_name: str
-) -> list[list[EntityNode]]:
-    records, _, _ = await driver.execute_query("""
-    CALL gds.leiden.stream("communities")
-    YIELD nodeId, communityId
-    RETURN gds.util.asNode(nodeId).uuid AS entity_uuid, communityId
+    group_id_values, _, _ = await driver.execute_query("""
+        MATCH (n:Entity WHERE n.group_id IS NOT NULL)
+        RETURN
+            collect(DISTINCT n.group_id) AS group_ids
     """)
 
-    community_map: dict[int, list[str]] = defaultdict(list)
-    for record in records:
-        community_map[record['communityId']].append(record['entity_uuid'])
-    community_clusters: list[list[EntityNode]] = list(
-        await asyncio.gather(
-            *[EntityNode.get_by_uuids(driver, cluster) for cluster in community_map.values()]
+    group_ids = group_id_values[0]['group_ids']
+    for group_id in group_ids:
+        projection: dict[str, list[Neighbor]] = {}
+        nodes = await EntityNode.get_by_group_ids(driver, [group_id])
+        for node in nodes:
+            records, _, _ = await driver.execute_query(
+                """
+            MATCH (n:Entity {group_id: $group_id, uuid: $uuid})-[r:RELATES_TO]-(m: Entity {group_id: $group_id})
+            WITH count(r) AS count, m.uuid AS uuid
+            RETURN
+                uuid,
+                count
+            """,
+                uuid=node.uuid,
+                group_id=group_id,
+            )
+
+            projection[node.uuid] = [
+                Neighbor(node_uuid=record['uuid'], edge_count=record['count']) for record in records
+            ]
+
+        cluster_uuids = label_propagation(projection)
+
+        community_clusters.extend(
+            list(
+                await asyncio.gather(
+                    *[EntityNode.get_by_uuids(driver, cluster) for cluster in cluster_uuids]
+                )
+            )
         )
-    )
 
     return community_clusters
 
 
+def label_propagation(projection: dict[str, list[Neighbor]]) -> list[list[str]]:
+    # Implement the label propagation community detection algorithm.
+    # 1. Start with each node assigned to its own community
+    # 2. Each node adopts the plurality community among its neighbors, weighted by edge count
+    # 3. Ties are broken in favor of the larger community id; a node never moves to a smaller id
+    # 4. Continue until no community assignment changes during a pass
+
+    community_map = {uuid: i for i, uuid in enumerate(projection.keys())}
+
+    while True:
+        no_change = True
+        new_community_map: dict[str, int] = {}
+
+        for uuid, neighbors in projection.items():
+            curr_community = community_map[uuid]
+
+            community_candidates: dict[int, int] = defaultdict(int)
+            for neighbor in neighbors:
+                community_candidates[community_map[neighbor.node_uuid]] += neighbor.edge_count
+
+            community_lst = [
+                (count, community) for community, count in community_candidates.items()
+            ]
+
+            community_lst.sort(reverse=True)
+            community_candidate = community_lst[0][1] if len(community_lst) > 0 else -1
+
+            new_community = max(community_candidate, curr_community)
+
+            new_community_map[uuid] = new_community
+
+            if new_community != curr_community:
+                no_change = False
+
+        if no_change:
+            break
+
+        community_map = new_community_map
+
+    community_cluster_map = defaultdict(list)
+    for uuid, community in community_map.items():
+        community_cluster_map[community].append(uuid)
+
+    clusters = [cluster for cluster in community_cluster_map.values()]
+    return clusters
+
+
 async def summarize_pair(llm_client: LLMClient, summary_pair: tuple[str, str]) -> str:
     # Prepare context for LLM
     context = {'node_summaries': [{'summary': summary} for summary in summary_pair]}
@@ -88,7 +154,7 @@ async def generate_summary_description(llm_client: LLMClient, summary: str) -> s
 
 
 async def build_community(
-        llm_client: LLMClient, community_cluster: list[EntityNode]
+    llm_client: LLMClient, community_cluster: list[EntityNode]
 ) -> tuple[CommunityNode, list[CommunityEdge]]:
     summaries = [entity.summary for entity in community_cluster]
     length = len(summaries)
@@ -102,7 +168,7 @@ async def build_community(
             *[
                 summarize_pair(llm_client, (str(left_summary), str(right_summary)))
                 for left_summary, right_summary in zip(
-                    summaries[: int(length / 2)], summaries[int(length / 2) :]
+                    summaries[: int(length / 2)], summaries[int(length / 2):]
                 )
             ]
         )
@@ -130,10 +196,9 @@ async def build_community(
 
 
 async def build_communities(
-        driver: AsyncDriver, llm_client: LLMClient
+    driver: AsyncDriver, llm_client: LLMClient
 ) -> tuple[list[CommunityNode], list[CommunityEdge]]:
-    projection = await build_community_projection(driver)
-    community_clusters = await get_community_clusters(driver, projection)
+    community_clusters = await get_community_clusters(driver)
 
     semaphore = asyncio.Semaphore(MAX_COMMUNITY_BUILD_CONCURRENCY)
 
@@ -151,7 +216,6 @@ async def build_communities(
         community_nodes.append(community[0])
         community_edges.extend(community[1])
 
-    await destroy_projection(driver, projection)
     return community_nodes, community_edges
 
 
@@ -163,7 +227,7 @@ async def remove_communities(driver: AsyncDriver):
 
 
 async def determine_entity_community(
-        driver: AsyncDriver, entity: EntityNode
+    driver: AsyncDriver, entity: EntityNode
 ) -> tuple[CommunityNode | None, bool]:
     # Check if the node is already part of a community
     records, _, _ = await driver.execute_query(
@@ -224,7 +288,7 @@ async def determine_entity_community(
 
 
 async def update_community(
-        driver: AsyncDriver, llm_client: LLMClient, embedder, entity: EntityNode
+    driver: AsyncDriver, llm_client: LLMClient, embedder, entity: EntityNode
 ):
     community, is_new = await determine_entity_community(driver, entity)
 
diff --git a/tests/test_graphiti_int.py b/tests/test_graphiti_int.py
index 9c9af450..488361d2 100644
--- a/tests/test_graphiti_int.py
+++ b/tests/test_graphiti_int.py
@@ -74,6 +74,7 @@ def format_context(facts):
 async def test_graphiti_init():
     logger = setup_logging()
     graphiti = Graphiti(NEO4J_URI, NEO4j_USER, NEO4j_PASSWORD)
+    await graphiti.build_communities()
 
     edges = await graphiti.search('tania tetlow', group_ids=['1'])
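To make the new in-memory path concrete, here is a self-contained sketch of the `label_propagation` routine from the patch, run on a toy projection. The node uuids and edge counts are invented for illustration; in the real flow the projection is built from the per-group `RELATES_TO` counts queried in `get_community_clusters`. The triangle `a`-`b`-`c` collapses into one community, while the isolated `d` keeps its own, which is the solo-node case the commit message calls out:

```python
from collections import defaultdict

from pydantic import BaseModel


class Neighbor(BaseModel):
    node_uuid: str
    edge_count: int


def label_propagation(projection: dict[str, list[Neighbor]]) -> list[list[str]]:
    # Mirrors the implementation in community_operations.py: every node
    # starts in its own community, then repeatedly adopts the plurality
    # community among its neighbors until an entire pass changes nothing.
    community_map = {uuid: i for i, uuid in enumerate(projection.keys())}

    while True:
        no_change = True
        new_community_map: dict[str, int] = {}

        for uuid, neighbors in projection.items():
            curr_community = community_map[uuid]

            # Tally neighboring communities, weighted by edge count.
            community_candidates: dict[int, int] = defaultdict(int)
            for neighbor in neighbors:
                community_candidates[community_map[neighbor.node_uuid]] += neighbor.edge_count

            community_lst = [
                (count, community) for community, count in community_candidates.items()
            ]
            community_lst.sort(reverse=True)

            # A node with no neighbors (empty community_lst) keeps its own
            # community instead of raising -- the solo-node fix.
            community_candidate = community_lst[0][1] if len(community_lst) > 0 else -1
            new_community = max(community_candidate, curr_community)

            new_community_map[uuid] = new_community
            if new_community != curr_community:
                no_change = False

        if no_change:
            break
        community_map = new_community_map

    community_cluster_map = defaultdict(list)
    for uuid, community in community_map.items():
        community_cluster_map[community].append(uuid)
    return list(community_cluster_map.values())


# Toy projection: a-b-c form a triangle, d has no RELATES_TO edges.
projection = {
    'a': [Neighbor(node_uuid='b', edge_count=1), Neighbor(node_uuid='c', edge_count=1)],
    'b': [Neighbor(node_uuid='a', edge_count=1), Neighbor(node_uuid='c', edge_count=1)],
    'c': [Neighbor(node_uuid='a', edge_count=1), Neighbor(node_uuid='b', edge_count=1)],
    'd': [],
}

print(label_propagation(projection))  # [['a', 'b', 'c'], ['d']]
```

Because `max(community_candidate, curr_community)` means a node only ever moves to a community whose id is at least its current one, community ids are non-decreasing across passes, which guarantees the loop terminates.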
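The `build_community` hunks above only show the reformatted `zip` line, but the surrounding loop pairwise-reduces entity summaries: each pass zips the first half of the list against the second half, asks the LLM to merge each pair, and roughly halves the list until a single community summary remains. A rough sketch of that reduction shape; `combine` is a stand-in for the real `summarize_pair` LLM call, and the odd-length handling is an assumption, since that part of the loop sits outside the visible hunk:

```python
import asyncio


async def combine(pair: tuple[str, str]) -> str:
    # Stand-in for summarize_pair(); the real code asks the LLM to merge
    # the two summaries into one string.
    return f'({pair[0]}+{pair[1]})'


async def reduce_summaries(summaries: list[str]) -> str:
    # Each pass zips the first half against the second half and combines
    # the pairs concurrently, so the list roughly halves per round.
    while len(summaries) > 1:
        length = len(summaries)
        # Assumption: an odd-length list carries its last element into the
        # next round unpaired (this detail is not visible in the hunk).
        odd_one_out = summaries[length - 1] if length % 2 == 1 else None
        summaries = list(
            await asyncio.gather(
                *[
                    combine((left, right))
                    for left, right in zip(
                        summaries[: int(length / 2)], summaries[int(length / 2):]
                    )
                ]
            )
        )
        if odd_one_out is not None:
            summaries.append(odd_one_out)
    return summaries[0]


print(asyncio.run(reduce_summaries(['s1', 's2', 's3', 's4', 's5'])))
# -> (((s1+s3)+(s2+s4))+s5)
```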