Merge febf8923f6 into d6ff7bb78c
This commit is contained in:
commit
d3e156b162
1 changed file with 77 additions and 19 deletions
|
|
@@ -84,50 +84,108 @@ async def get_community_clusters(
|
||||||
|
|
||||||
|
|
||||||
def label_propagation(projection: dict[str, list[Neighbor]]) -> list[list[str]]:
    """
    Cluster nodes with the label propagation community detection algorithm.

    Algorithm:
    1. Start with each node being assigned its own community.
    2. Each node takes on the community that holds the plurality of its
       neighbors, weighted by edge count.
    3. Ties are broken by going to the largest community ID.
    4. A winning community backed by a total edge count of 1 or less is only
       adopted when its ID is larger than the node's current community ID.
    5. Continue until no communities change during a full pass.

    Oscillation prevention:
    - Asynchronous updates: nodes are visited in a freshly shuffled order on
      every pass and see labels already updated earlier in the same pass.
    - A maximum iteration limit bounds the loop.
    - Early stopping when a pass ends in a labelling that was already seen
      within the most recent passes.

    Args:
        projection: Maps each node uuid to its neighbors. Each neighbor is
            read through its ``node_uuid`` and ``edge_count`` attributes, and
            every ``node_uuid`` must itself be a key of ``projection``
            (otherwise the community lookup raises ``KeyError``).

    Returns:
        One list of node uuids per community. Nodes without neighbors keep
        their initial community and therefore form singleton clusters.
    """
    import random
    from collections import deque

    MAX_ITERATIONS = 100
    OSCILLATION_CHECK_WINDOW = 5

    community_map = {uuid: i for i, uuid in enumerate(projection)}
    node_uuids = list(projection)

    # Labellings of the most recent passes. The current pass is compared
    # against the previous (window - 1) states, so the deque is bounded to that.
    recent_states: deque[dict[str, int]] = deque(maxlen=OSCILLATION_CHECK_WINDOW - 1)

    for iteration in range(MAX_ITERATIONS):
        # Asynchronous update: randomize node processing order to prevent oscillation
        random.shuffle(node_uuids)
        changed_count = 0

        for uuid in node_uuids:
            curr_community = community_map[uuid]

            # Count votes from neighbors, weighted by edge count
            community_candidates: dict[int, int] = defaultdict(int)
            for neighbor in projection[uuid]:
                community_candidates[community_map[neighbor.node_uuid]] += neighbor.edge_count

            if not community_candidates:
                # Isolated node: nothing to vote on, keep the current community
                continue

            # Highest vote count wins; ties go to the largest community ID.
            # This matches the documented rule and keeps the weak-signal branch
            # below consistent, since it also prefers larger IDs.
            community_candidate, candidate_rank = max(
                community_candidates.items(), key=lambda item: (item[1], item[0])
            )

            # Determine new community:
            # - If strong signal (edge count > 1), adopt the neighbor's community
            # - Otherwise, prefer the larger community ID
            if candidate_rank > 1:
                new_community = community_candidate
            else:
                new_community = max(community_candidate, curr_community)

            if new_community != curr_community:
                community_map[uuid] = new_community
                changed_count += 1

        # Check for convergence
        if changed_count == 0:
            logger.debug(f'Label propagation converged after {iteration + 1} iterations')
            break

        # Detect oscillation: the labelling after this pass matches a recent one
        current_state = community_map.copy()
        if current_state in recent_states:
            logger.warning(
                f'Label propagation oscillation detected at iteration {iteration + 1}, '
                'stopping early'
            )
            break
        recent_states.append(current_state)
    else:
        logger.warning(
            f'Label propagation reached maximum iterations ({MAX_ITERATIONS}) without converging'
        )

    # Group nodes by community
    community_cluster_map: dict[int, list[str]] = defaultdict(list)
    for uuid, community in community_map.items():
        community_cluster_map[community].append(uuid)

    return list(community_cluster_map.values())
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue