Merge febf8923f6 into d6ff7bb78c
This commit is contained in:
commit
d3e156b162
1 changed files with 77 additions and 19 deletions
|
|
@ -84,50 +84,108 @@ async def get_community_clusters(
|
|||
|
||||
|
||||
def label_propagation(projection: dict[str, list[Neighbor]]) -> list[list[str]]:
|
||||
# Implement the label propagation community detection algorithm.
|
||||
# 1. Start with each node being assigned its own community
|
||||
# 2. Each node will take on the community of the plurality of its neighbors
|
||||
# 3. Ties are broken by going to the largest community
|
||||
# 4. Continue until no communities change during propagation
|
||||
"""
|
||||
Implement the label propagation community detection algorithm.
|
||||
|
||||
Algorithm:
|
||||
1. Start with each node being assigned its own community
|
||||
2. Each node will take on the community of the plurality of its neighbors
|
||||
3. Ties are broken by going to the largest community
|
||||
4. Continue until no communities change during propagation
|
||||
|
||||
Oscillation prevention:
|
||||
- Uses asynchronous updates (randomized node order)
|
||||
- Maximum iteration limit to prevent infinite loops
|
||||
- Early stopping if oscillation is detected
|
||||
"""
|
||||
import random
|
||||
|
||||
MAX_ITERATIONS = 100
|
||||
OSCILLATION_CHECK_WINDOW = 5
|
||||
|
||||
community_map = {uuid: i for i, uuid in enumerate(projection.keys())}
|
||||
node_uuids = list(projection.keys())
|
||||
|
||||
while True:
|
||||
no_change = True
|
||||
new_community_map: dict[str, int] = {}
|
||||
# Track history to detect oscillations
|
||||
history: list[dict[str, int]] = []
|
||||
|
||||
for uuid, neighbors in projection.items():
|
||||
for iteration in range(MAX_ITERATIONS):
|
||||
# Asynchronous update: randomize node processing order to prevent oscillation
|
||||
random.shuffle(node_uuids)
|
||||
|
||||
changed_count = 0
|
||||
|
||||
for uuid in node_uuids:
|
||||
neighbors = projection[uuid]
|
||||
curr_community = community_map[uuid]
|
||||
|
||||
# Count votes from neighbors
|
||||
community_candidates: dict[int, int] = defaultdict(int)
|
||||
for neighbor in neighbors:
|
||||
community_candidates[community_map[neighbor.node_uuid]] += neighbor.edge_count
|
||||
|
||||
if not community_candidates:
|
||||
continue
|
||||
|
||||
# Sort by count (descending), then by community ID for deterministic tie-breaking
|
||||
community_lst = [
|
||||
(count, community) for community, count in community_candidates.items()
|
||||
]
|
||||
community_lst.sort(key=lambda x: (-x[0], x[1]))
|
||||
|
||||
community_lst.sort(reverse=True)
|
||||
candidate_rank, community_candidate = community_lst[0] if community_lst else (0, -1)
|
||||
if community_candidate != -1 and candidate_rank > 1:
|
||||
candidate_rank, community_candidate = community_lst[0]
|
||||
|
||||
# Determine new community:
|
||||
# - If strong signal (edge count > 1), adopt the neighbor's community
|
||||
# - Otherwise, prefer the larger community ID (original behavior)
|
||||
if candidate_rank > 1:
|
||||
new_community = community_candidate
|
||||
else:
|
||||
new_community = max(community_candidate, curr_community)
|
||||
|
||||
new_community_map[uuid] = new_community
|
||||
|
||||
if new_community != curr_community:
|
||||
no_change = False
|
||||
community_map[uuid] = new_community
|
||||
changed_count += 1
|
||||
|
||||
if no_change:
|
||||
# Check for convergence
|
||||
if changed_count == 0:
|
||||
logger.debug(f'Label propagation converged after {iteration + 1} iterations')
|
||||
break
|
||||
|
||||
community_map = new_community_map
|
||||
# Check for oscillation by comparing with recent history
|
||||
current_state = community_map.copy()
|
||||
history.append(current_state)
|
||||
|
||||
community_cluster_map = defaultdict(list)
|
||||
# Keep only recent history
|
||||
if len(history) > OSCILLATION_CHECK_WINDOW:
|
||||
history.pop(0)
|
||||
|
||||
# Detect oscillation: if current state matches any recent state
|
||||
if len(history) >= 2:
|
||||
for past_state in history[:-1]:
|
||||
if past_state == current_state:
|
||||
logger.warning(
|
||||
f'Label propagation oscillation detected at iteration {iteration + 1}, '
|
||||
'stopping early'
|
||||
)
|
||||
# Break out of the for loop
|
||||
break
|
||||
else:
|
||||
# No oscillation detected, continue to next iteration
|
||||
continue
|
||||
# Oscillation detected, break out of the main loop
|
||||
break
|
||||
else:
|
||||
logger.warning(
|
||||
f'Label propagation reached maximum iterations ({MAX_ITERATIONS}) without converging'
|
||||
)
|
||||
|
||||
# Group nodes by community
|
||||
community_cluster_map: dict[int, list[str]] = defaultdict(list)
|
||||
for uuid, community in community_map.items():
|
||||
community_cluster_map[community].append(uuid)
|
||||
|
||||
clusters = [cluster for cluster in community_cluster_map.values()]
|
||||
clusters = list(community_cluster_map.values())
|
||||
return clusters
|
||||
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue