From 5506a01e24a21dd375f41442dfc2bd4adfcfff2e Mon Sep 17 00:00:00 2001
From: Preston Rasmussen <109292228+prasmussen15@users.noreply.github.com>
Date: Mon, 23 Sep 2024 11:05:44 -0400
Subject: [PATCH] In memory label propagation community detection (#136)
* WIP
* In-memory community detection
* Format
* Add comments
* Update README
* Fix an issue where solo nodes would throw an error when building communities
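
The previous flow projected the graph into GDS and ran Leiden; this patch instead reads each group's RELATES_TO neighborhoods into memory and clusters them with label propagation, so the GraphDataScience plugin is no longer needed. Below is a minimal, self-contained sketch of the approach; plain dicts stand in for the `Neighbor` model added in `community_operations.py`, and the toy graph is invented for illustration:

```python
from collections import defaultdict


def label_propagation(projection: dict[str, dict[str, int]]) -> list[list[str]]:
    """Weighted label propagation over an in-memory adjacency map."""
    # Each node starts in its own community, numbered by insertion order.
    community_map = {uuid: i for i, uuid in enumerate(projection)}

    while True:
        changed = False
        new_map: dict[str, int] = {}
        for uuid, neighbors in projection.items():
            # Tally neighbor communities, weighted by edge count.
            votes: dict[int, int] = defaultdict(int)
            for neighbor_uuid, edge_count in neighbors.items():
                votes[community_map[neighbor_uuid]] += edge_count
            # Plurality winner; ties go to the higher community id. Nodes with
            # no neighbors get -1 so max() below keeps their current community.
            candidate = max(votes, key=lambda c: (votes[c], c)) if votes else -1
            new_map[uuid] = max(candidate, community_map[uuid])
            changed = changed or new_map[uuid] != community_map[uuid]
        if not changed:
            break
        community_map = new_map

    clusters: dict[int, list[str]] = defaultdict(list)
    for uuid, community in community_map.items():
        clusters[community].append(uuid)
    return list(clusters.values())


# Two triangles bridged by a single edge, plus one isolated node.
graph = {
    'a': {'b': 1, 'c': 1}, 'b': {'a': 1, 'c': 1}, 'c': {'a': 1, 'b': 1, 'd': 1},
    'd': {'c': 1, 'e': 1, 'f': 1}, 'e': {'d': 1, 'f': 1}, 'f': {'d': 1, 'e': 1},
    'g': {},
}
print(label_propagation(graph))  # [['a', 'b', 'c'], ['d', 'e', 'f'], ['g']]
```

The public entry point is unchanged: callers still run `await graphiti.build_communities()`, as exercised by the updated integration test.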
---
README.md | 2 -
examples/podcast/podcast_runner.py | 20 +--
graphiti_core/graphiti.py | 2 +-
.../utils/maintenance/community_operations.py | 124 +++++++++++++-----
tests/test_graphiti_int.py | 1 +
5 files changed, 106 insertions(+), 43 deletions(-)
diff --git a/README.md b/README.md
index beab724a..d2f304b7 100644
--- a/README.md
+++ b/README.md
@@ -2,7 +2,6 @@
-
## Temporal Knowledge Graphs for Agentic Applications
@@ -80,7 +79,6 @@ Requirements:
- Python 3.10 or higher
- Neo4j 5.21 or higher
-- Neo4j GraphDataScience Plugin (required for community flows)
- OpenAI API key (for LLM inference and embedding)
Optional:
diff --git a/examples/podcast/podcast_runner.py b/examples/podcast/podcast_runner.py
index 22e1d90b..90a4a205 100644
--- a/examples/podcast/podcast_runner.py
+++ b/examples/podcast/podcast_runner.py
@@ -63,7 +63,7 @@ async def main(use_bulk: bool = True):
messages = parse_podcast_messages()
if not use_bulk:
- for i, message in enumerate(messages[3:14]):
+ for i, message in enumerate(messages[3:4]):
await client.add_episode(
name=f'Message {i}',
episode_body=f'{message.speaker_name} ({message.role}): {message.content}',
@@ -76,15 +76,15 @@ async def main(use_bulk: bool = True):
await client.build_communities()
# add additional messages to update communities
- for i, message in enumerate(messages[14:20]):
- await client.add_episode(
- name=f'Message {i}',
- episode_body=f'{message.speaker_name} ({message.role}): {message.content}',
- reference_time=message.actual_timestamp,
- source_description='Podcast Transcript',
- group_id='1',
- update_communities=True,
- )
+ # for i, message in enumerate(messages[14:20]):
+ # await client.add_episode(
+ # name=f'Message {i}',
+ # episode_body=f'{message.speaker_name} ({message.role}): {message.content}',
+ # reference_time=message.actual_timestamp,
+ # source_description='Podcast Transcript',
+ # group_id='1',
+ # update_communities=True,
+ # )
return
diff --git a/graphiti_core/graphiti.py b/graphiti_core/graphiti.py
index bf6f65e9..8a5eb23a 100644
--- a/graphiti_core/graphiti.py
+++ b/graphiti_core/graphiti.py
@@ -579,7 +579,7 @@ class Graphiti:
center_node_uuid: str | None = None,
group_ids: list[str | None] | None = None,
num_results=DEFAULT_SEARCH_LIMIT,
- ):
+ ) -> list[EntityEdge]:
"""
Perform a hybrid search on the knowledge graph.
diff --git a/graphiti_core/utils/maintenance/community_operations.py b/graphiti_core/utils/maintenance/community_operations.py
index 014469c9..7a384fec 100644
--- a/graphiti_core/utils/maintenance/community_operations.py
+++ b/graphiti_core/utils/maintenance/community_operations.py
@@ -4,6 +4,7 @@ from collections import defaultdict
from datetime import datetime
from neo4j import AsyncDriver
+from pydantic import BaseModel
from graphiti_core.edges import CommunityEdge
from graphiti_core.llm_client import LLMClient
@@ -17,6 +18,11 @@ MAX_COMMUNITY_BUILD_CONCURRENCY = 10
logger = logging.getLogger(__name__)
+class Neighbor(BaseModel):
+ node_uuid: str
+ edge_count: int
+
+
async def build_community_projection(driver: AsyncDriver) -> str:
records, _, _ = await driver.execute_query("""
CALL gds.graph.project("communities", "Entity",
@@ -32,36 +38,96 @@ async def build_community_projection(driver: AsyncDriver) -> str:
return records[0]['graph']
-async def destroy_projection(driver: AsyncDriver, projection_name: str):
- await driver.execute_query(
- """
- CALL gds.graph.drop($projection_name)
- """,
- projection_name=projection_name,
- )
+async def get_community_clusters(driver: AsyncDriver) -> list[list[EntityNode]]:
+ community_clusters: list[list[EntityNode]] = []
-
-async def get_community_clusters(
- driver: AsyncDriver, projection_name: str
-) -> list[list[EntityNode]]:
- records, _, _ = await driver.execute_query("""
- CALL gds.leiden.stream("communities")
- YIELD nodeId, communityId
- RETURN gds.util.asNode(nodeId).uuid AS entity_uuid, communityId
+ group_id_values, _, _ = await driver.execute_query("""
+ MATCH (n:Entity WHERE n.group_id IS NOT NULL)
+ RETURN
+ collect(DISTINCT n.group_id) AS group_ids
""")
- community_map: dict[int, list[str]] = defaultdict(list)
- for record in records:
- community_map[record['communityId']].append(record['entity_uuid'])
- community_clusters: list[list[EntityNode]] = list(
- await asyncio.gather(
- *[EntityNode.get_by_uuids(driver, cluster) for cluster in community_map.values()]
+ group_ids = group_id_values[0]['group_ids']
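+    # Build communities one group_id at a time so clusters never span groups.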
+ for group_id in group_ids:
+ projection: dict[str, list[Neighbor]] = {}
+ nodes = await EntityNode.get_by_group_ids(driver, [group_id])
+ for node in nodes:
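+            # Count RELATES_TO edges between this node and each neighbor; the
+            # counts become the vote weights in label_propagation below.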
+ records, _, _ = await driver.execute_query(
+ """
+ MATCH (n:Entity {group_id: $group_id, uuid: $uuid})-[r:RELATES_TO]-(m: Entity {group_id: $group_id})
+ WITH count(r) AS count, m.uuid AS uuid
+ RETURN
+ uuid,
+ count
+ """,
+ uuid=node.uuid,
+ group_id=group_id,
+ )
+
+ projection[node.uuid] = [
+ Neighbor(node_uuid=record['uuid'], edge_count=record['count']) for record in records
+ ]
+
+ cluster_uuids = label_propagation(projection)
+
+ community_clusters.extend(
+ list(
+ await asyncio.gather(
+ *[EntityNode.get_by_uuids(driver, cluster) for cluster in cluster_uuids]
+ )
+ )
)
- )
return community_clusters
+def label_propagation(projection: dict[str, list[Neighbor]]) -> list[list[str]]:
+    # Implements the label propagation community detection algorithm:
+    # 1. Start with each node assigned to its own community.
+    # 2. On each pass, a node adopts the community held by the plurality of its
+    #    neighbors, weighted by edge count.
+    # 3. Ties are broken in favor of the community with the higher id.
+    # 4. Repeat until no community assignment changes.
+
+ community_map = {uuid: i for i, uuid in enumerate(projection.keys())}
+
+ while True:
+ no_change = True
+ new_community_map: dict[str, int] = {}
+
+ for uuid, neighbors in projection.items():
+ curr_community = community_map[uuid]
+
+ community_candidates: dict[int, int] = defaultdict(int)
+ for neighbor in neighbors:
+ community_candidates[community_map[neighbor.node_uuid]] += neighbor.edge_count
+
+ community_lst = [
+ (count, community) for community, count in community_candidates.items()
+ ]
+
+ community_lst.sort(reverse=True)
+ community_candidate = community_lst[0][1] if len(community_lst) > 0 else -1
+
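+            # max() means a node only ever moves to a higher-numbered community
+            # (solo nodes, with candidate -1, keep theirs). Ids never decrease
+            # and are bounded above, so the loop is guaranteed to terminate.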
+ new_community = max(community_candidate, curr_community)
+
+ new_community_map[uuid] = new_community
+
+ if new_community != curr_community:
+ no_change = False
+
+ if no_change:
+ break
+
+ community_map = new_community_map
+
+ community_cluster_map = defaultdict(list)
+ for uuid, community in community_map.items():
+ community_cluster_map[community].append(uuid)
+
+    return list(community_cluster_map.values())
+
+
async def summarize_pair(llm_client: LLMClient, summary_pair: tuple[str, str]) -> str:
# Prepare context for LLM
context = {'node_summaries': [{'summary': summary} for summary in summary_pair]}
@@ -88,7 +154,7 @@ async def generate_summary_description(llm_client: LLMClient, summary: str) -> s
async def build_community(
- llm_client: LLMClient, community_cluster: list[EntityNode]
+ llm_client: LLMClient, community_cluster: list[EntityNode]
) -> tuple[CommunityNode, list[CommunityEdge]]:
summaries = [entity.summary for entity in community_cluster]
length = len(summaries)
@@ -102,7 +168,7 @@ async def build_community(
*[
summarize_pair(llm_client, (str(left_summary), str(right_summary)))
for left_summary, right_summary in zip(
- summaries[: int(length / 2)], summaries[int(length / 2) :]
+ summaries[: int(length / 2)], summaries[int(length / 2):]
)
]
)
@@ -130,10 +196,9 @@ async def build_community(
async def build_communities(
- driver: AsyncDriver, llm_client: LLMClient
+ driver: AsyncDriver, llm_client: LLMClient
) -> tuple[list[CommunityNode], list[CommunityEdge]]:
- projection = await build_community_projection(driver)
- community_clusters = await get_community_clusters(driver, projection)
+ community_clusters = await get_community_clusters(driver)
semaphore = asyncio.Semaphore(MAX_COMMUNITY_BUILD_CONCURRENCY)
@@ -151,7 +216,6 @@ async def build_communities(
community_nodes.append(community[0])
community_edges.extend(community[1])
- await destroy_projection(driver, projection)
return community_nodes, community_edges
@@ -163,7 +227,7 @@ async def remove_communities(driver: AsyncDriver):
async def determine_entity_community(
- driver: AsyncDriver, entity: EntityNode
+ driver: AsyncDriver, entity: EntityNode
) -> tuple[CommunityNode | None, bool]:
# Check if the node is already part of a community
records, _, _ = await driver.execute_query(
@@ -224,7 +288,7 @@ async def determine_entity_community(
async def update_community(
- driver: AsyncDriver, llm_client: LLMClient, embedder, entity: EntityNode
+ driver: AsyncDriver, llm_client: LLMClient, embedder, entity: EntityNode
):
community, is_new = await determine_entity_community(driver, entity)
diff --git a/tests/test_graphiti_int.py b/tests/test_graphiti_int.py
index 9c9af450..488361d2 100644
--- a/tests/test_graphiti_int.py
+++ b/tests/test_graphiti_int.py
@@ -74,6 +74,7 @@ def format_context(facts):
async def test_graphiti_init():
logger = setup_logging()
graphiti = Graphiti(NEO4J_URI, NEO4j_USER, NEO4j_PASSWORD)
+ await graphiti.build_communities()
edges = await graphiti.search('tania tetlow', group_ids=['1'])