In-memory label propagation community detection (#136)

* WIP

* in memory graph detection

* format

* add comments

* update readme

* fixed an issue where solo nodes would throw an error when building communities
This commit is contained in:
Preston Rasmussen 2024-09-23 11:05:44 -04:00 committed by GitHub
parent 2fc1b00602
commit 5506a01e24
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 106 additions and 43 deletions

View file

@ -2,7 +2,6 @@
<img width="350" alt="Graphiti-ts-small" src="https://github.com/user-attachments/assets/bbd02947-e435-4a05-b25a-bbbac36d52c8"> <img width="350" alt="Graphiti-ts-small" src="https://github.com/user-attachments/assets/bbd02947-e435-4a05-b25a-bbbac36d52c8">
## Temporal Knowledge Graphs for Agentic Applications ## Temporal Knowledge Graphs for Agentic Applications
<br /> <br />
@ -80,7 +79,6 @@ Requirements:
- Python 3.10 or higher - Python 3.10 or higher
- Neo4j 5.21 or higher - Neo4j 5.21 or higher
- Neo4j GraphDataScience Plugin (required for community flows)
- OpenAI API key (for LLM inference and embedding) - OpenAI API key (for LLM inference and embedding)
Optional: Optional:

View file

@ -63,7 +63,7 @@ async def main(use_bulk: bool = True):
messages = parse_podcast_messages() messages = parse_podcast_messages()
if not use_bulk: if not use_bulk:
for i, message in enumerate(messages[3:14]): for i, message in enumerate(messages[3:4]):
await client.add_episode( await client.add_episode(
name=f'Message {i}', name=f'Message {i}',
episode_body=f'{message.speaker_name} ({message.role}): {message.content}', episode_body=f'{message.speaker_name} ({message.role}): {message.content}',
@ -76,15 +76,15 @@ async def main(use_bulk: bool = True):
await client.build_communities() await client.build_communities()
# add additional messages to update communities # add additional messages to update communities
for i, message in enumerate(messages[14:20]): # for i, message in enumerate(messages[14:20]):
await client.add_episode( # await client.add_episode(
name=f'Message {i}', # name=f'Message {i}',
episode_body=f'{message.speaker_name} ({message.role}): {message.content}', # episode_body=f'{message.speaker_name} ({message.role}): {message.content}',
reference_time=message.actual_timestamp, # reference_time=message.actual_timestamp,
source_description='Podcast Transcript', # source_description='Podcast Transcript',
group_id='1', # group_id='1',
update_communities=True, # update_communities=True,
) # )
return return

View file

@ -579,7 +579,7 @@ class Graphiti:
center_node_uuid: str | None = None, center_node_uuid: str | None = None,
group_ids: list[str | None] | None = None, group_ids: list[str | None] | None = None,
num_results=DEFAULT_SEARCH_LIMIT, num_results=DEFAULT_SEARCH_LIMIT,
): ) -> list[EntityEdge]:
""" """
Perform a hybrid search on the knowledge graph. Perform a hybrid search on the knowledge graph.

View file

@ -4,6 +4,7 @@ from collections import defaultdict
from datetime import datetime from datetime import datetime
from neo4j import AsyncDriver from neo4j import AsyncDriver
from pydantic import BaseModel
from graphiti_core.edges import CommunityEdge from graphiti_core.edges import CommunityEdge
from graphiti_core.llm_client import LLMClient from graphiti_core.llm_client import LLMClient
@ -17,6 +18,11 @@ MAX_COMMUNITY_BUILD_CONCURRENCY = 10
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class Neighbor(BaseModel):
    """One adjacency entry in the in-memory projection: a neighboring node and how strongly it is connected."""

    # uuid of the node on the other end of the relationship(s)
    node_uuid: str
    # number of relationships to that node; label_propagation sums this as the neighbor's weight
    edge_count: int
async def build_community_projection(driver: AsyncDriver) -> str: async def build_community_projection(driver: AsyncDriver) -> str:
records, _, _ = await driver.execute_query(""" records, _, _ = await driver.execute_query("""
CALL gds.graph.project("communities", "Entity", CALL gds.graph.project("communities", "Entity",
@ -32,36 +38,96 @@ async def build_community_projection(driver: AsyncDriver) -> str:
return records[0]['graph'] return records[0]['graph']
async def destroy_projection(driver: AsyncDriver, projection_name: str): async def get_community_clusters(driver: AsyncDriver) -> list[list[EntityNode]]:
await driver.execute_query( community_clusters: list[list[EntityNode]] = []
group_id_values, _, _ = await driver.execute_query("""
MATCH (n:Entity WHERE n.group_id IS NOT NULL)
RETURN
collect(DISTINCT n.group_id) AS group_ids
""")
group_ids = group_id_values[0]['group_ids']
for group_id in group_ids:
projection: dict[str, list[Neighbor]] = {}
nodes = await EntityNode.get_by_group_ids(driver, [group_id])
for node in nodes:
records, _, _ = await driver.execute_query(
""" """
CALL gds.graph.drop($projection_name) MATCH (n:Entity {group_id: $group_id, uuid: $uuid})-[r:RELATES_TO]-(m: Entity {group_id: $group_id})
WITH count(r) AS count, m.uuid AS uuid
RETURN
uuid,
count
""", """,
projection_name=projection_name, uuid=node.uuid,
group_id=group_id,
) )
projection[node.uuid] = [
Neighbor(node_uuid=record['uuid'], edge_count=record['count']) for record in records
]
async def get_community_clusters( cluster_uuids = label_propagation(projection)
driver: AsyncDriver, projection_name: str
) -> list[list[EntityNode]]:
records, _, _ = await driver.execute_query("""
CALL gds.leiden.stream("communities")
YIELD nodeId, communityId
RETURN gds.util.asNode(nodeId).uuid AS entity_uuid, communityId
""")
community_map: dict[int, list[str]] = defaultdict(list)
for record in records:
community_map[record['communityId']].append(record['entity_uuid'])
community_clusters: list[list[EntityNode]] = list( community_clusters.extend(
list(
await asyncio.gather( await asyncio.gather(
*[EntityNode.get_by_uuids(driver, cluster) for cluster in community_map.values()] *[EntityNode.get_by_uuids(driver, cluster) for cluster in cluster_uuids]
)
) )
) )
return community_clusters return community_clusters
def label_propagation(projection: dict[str, list['Neighbor']]) -> list[list[str]]:
    """
    Partition the nodes of ``projection`` into communities by label propagation.

    Algorithm, as implemented:
    1. Every node starts in its own community, numbered by its position in
       ``projection``.
    2. Each round, every node sums ``edge_count`` per community over its
       neighbors, using the assignments from the previous round, and picks the
       heaviest community. Equal weights are broken in favor of the higher
       community id.
    3. The node then keeps whichever id is higher: that candidate or its own
       current community.
    4. Rounds repeat until a full round changes nothing.

    NOTE(review): because of step 3 a node only ever moves to a higher-numbered
    community, so this is not plain "adopt the plurality of your neighbors".
    It does guarantee termination, since ids never decrease and are bounded by
    the number of nodes. Confirm this is the intended rule.

    Args:
        projection: Maps each node uuid to its neighbors. Each neighbor must
            expose ``node_uuid`` and ``edge_count``.

    Returns:
        One list of node uuids per community. Communities appear in the order
        their id is first met while walking ``projection``; uuids inside a
        community keep ``projection`` order. A node without neighbors ends up
        in a community of its own. An empty ``projection`` yields ``[]``.

    Raises:
        KeyError: If a neighbor's ``node_uuid`` is not a key of ``projection``.
    """
    community_map = {uuid: label for label, uuid in enumerate(projection)}

    while True:
        new_community_map: dict[str, int] = {}

        for uuid, neighbors in projection.items():
            curr_community = community_map[uuid]

            # Total edge weight towards each neighboring community.
            community_weights: dict[int, int] = defaultdict(int)
            for neighbor in neighbors:
                community_weights[community_map[neighbor.node_uuid]] += neighbor.edge_count

            if community_weights:
                # Tuples compare by weight first, then by community id, so the
                # maximum is the heaviest community with ties going to the higher id.
                _, community_candidate = max(
                    (weight, community) for community, weight in community_weights.items()
                )
            else:
                # Isolated node: -1 never beats a real community id.
                community_candidate = -1

            new_community_map[uuid] = max(community_candidate, curr_community)

        # A round that reproduces the previous assignment means we have converged.
        if new_community_map == community_map:
            break

        community_map = new_community_map

    community_cluster_map: dict[int, list[str]] = defaultdict(list)
    for uuid, community in community_map.items():
        community_cluster_map[community].append(uuid)

    return list(community_cluster_map.values())
async def summarize_pair(llm_client: LLMClient, summary_pair: tuple[str, str]) -> str: async def summarize_pair(llm_client: LLMClient, summary_pair: tuple[str, str]) -> str:
# Prepare context for LLM # Prepare context for LLM
context = {'node_summaries': [{'summary': summary} for summary in summary_pair]} context = {'node_summaries': [{'summary': summary} for summary in summary_pair]}
@ -132,8 +198,7 @@ async def build_community(
async def build_communities( async def build_communities(
driver: AsyncDriver, llm_client: LLMClient driver: AsyncDriver, llm_client: LLMClient
) -> tuple[list[CommunityNode], list[CommunityEdge]]: ) -> tuple[list[CommunityNode], list[CommunityEdge]]:
projection = await build_community_projection(driver) community_clusters = await get_community_clusters(driver)
community_clusters = await get_community_clusters(driver, projection)
semaphore = asyncio.Semaphore(MAX_COMMUNITY_BUILD_CONCURRENCY) semaphore = asyncio.Semaphore(MAX_COMMUNITY_BUILD_CONCURRENCY)
@ -151,7 +216,6 @@ async def build_communities(
community_nodes.append(community[0]) community_nodes.append(community[0])
community_edges.extend(community[1]) community_edges.extend(community[1])
await destroy_projection(driver, projection)
return community_nodes, community_edges return community_nodes, community_edges

View file

@ -74,6 +74,7 @@ def format_context(facts):
async def test_graphiti_init(): async def test_graphiti_init():
logger = setup_logging() logger = setup_logging()
graphiti = Graphiti(NEO4J_URI, NEO4j_USER, NEO4j_PASSWORD) graphiti = Graphiti(NEO4J_URI, NEO4j_USER, NEO4j_PASSWORD)
await graphiti.build_communities()
edges = await graphiti.search('tania tetlow', group_ids=['1']) edges = await graphiti.search('tania tetlow', group_ids=['1'])