In memory label propagation community detection (#136)
* WIP * in memory graph detection * format * add comments * update readme * fixed an issue where solo nodes would throw an error when building communities
This commit is contained in:
parent
2fc1b00602
commit
5506a01e24
5 changed files with 106 additions and 43 deletions
|
|
@ -2,7 +2,6 @@
|
||||||
|
|
||||||
<img width="350" alt="Graphiti-ts-small" src="https://github.com/user-attachments/assets/bbd02947-e435-4a05-b25a-bbbac36d52c8">
|
<img width="350" alt="Graphiti-ts-small" src="https://github.com/user-attachments/assets/bbd02947-e435-4a05-b25a-bbbac36d52c8">
|
||||||
|
|
||||||
|
|
||||||
## Temporal Knowledge Graphs for Agentic Applications
|
## Temporal Knowledge Graphs for Agentic Applications
|
||||||
|
|
||||||
<br />
|
<br />
|
||||||
|
|
@ -80,7 +79,6 @@ Requirements:
|
||||||
|
|
||||||
- Python 3.10 or higher
|
- Python 3.10 or higher
|
||||||
- Neo4j 5.21 or higher
|
- Neo4j 5.21 or higher
|
||||||
- Neo4j GraphDataScience Plugin (required for community flows)
|
|
||||||
- OpenAI API key (for LLM inference and embedding)
|
- OpenAI API key (for LLM inference and embedding)
|
||||||
|
|
||||||
Optional:
|
Optional:
|
||||||
|
|
|
||||||
|
|
@ -63,7 +63,7 @@ async def main(use_bulk: bool = True):
|
||||||
messages = parse_podcast_messages()
|
messages = parse_podcast_messages()
|
||||||
|
|
||||||
if not use_bulk:
|
if not use_bulk:
|
||||||
for i, message in enumerate(messages[3:14]):
|
for i, message in enumerate(messages[3:4]):
|
||||||
await client.add_episode(
|
await client.add_episode(
|
||||||
name=f'Message {i}',
|
name=f'Message {i}',
|
||||||
episode_body=f'{message.speaker_name} ({message.role}): {message.content}',
|
episode_body=f'{message.speaker_name} ({message.role}): {message.content}',
|
||||||
|
|
@ -76,15 +76,15 @@ async def main(use_bulk: bool = True):
|
||||||
await client.build_communities()
|
await client.build_communities()
|
||||||
|
|
||||||
# add additional messages to update communities
|
# add additional messages to update communities
|
||||||
for i, message in enumerate(messages[14:20]):
|
# for i, message in enumerate(messages[14:20]):
|
||||||
await client.add_episode(
|
# await client.add_episode(
|
||||||
name=f'Message {i}',
|
# name=f'Message {i}',
|
||||||
episode_body=f'{message.speaker_name} ({message.role}): {message.content}',
|
# episode_body=f'{message.speaker_name} ({message.role}): {message.content}',
|
||||||
reference_time=message.actual_timestamp,
|
# reference_time=message.actual_timestamp,
|
||||||
source_description='Podcast Transcript',
|
# source_description='Podcast Transcript',
|
||||||
group_id='1',
|
# group_id='1',
|
||||||
update_communities=True,
|
# update_communities=True,
|
||||||
)
|
# )
|
||||||
|
|
||||||
return
|
return
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -579,7 +579,7 @@ class Graphiti:
|
||||||
center_node_uuid: str | None = None,
|
center_node_uuid: str | None = None,
|
||||||
group_ids: list[str | None] | None = None,
|
group_ids: list[str | None] | None = None,
|
||||||
num_results=DEFAULT_SEARCH_LIMIT,
|
num_results=DEFAULT_SEARCH_LIMIT,
|
||||||
):
|
) -> list[EntityEdge]:
|
||||||
"""
|
"""
|
||||||
Perform a hybrid search on the knowledge graph.
|
Perform a hybrid search on the knowledge graph.
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -4,6 +4,7 @@ from collections import defaultdict
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|
||||||
from neo4j import AsyncDriver
|
from neo4j import AsyncDriver
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
from graphiti_core.edges import CommunityEdge
|
from graphiti_core.edges import CommunityEdge
|
||||||
from graphiti_core.llm_client import LLMClient
|
from graphiti_core.llm_client import LLMClient
|
||||||
|
|
@ -17,6 +18,11 @@ MAX_COMMUNITY_BUILD_CONCURRENCY = 10
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
# One adjacency entry of the in-memory community projection: a neighbouring
# entity together with the number of RELATES_TO edges counted towards it.
class Neighbor(BaseModel):
|
||||||
|
node_uuid: str  # uuid of the neighbouring Entity node
|
||||||
|
edge_count: int  # count of relationships between the source node and this neighbour
|
||||||
|
|
||||||
|
|
||||||
async def build_community_projection(driver: AsyncDriver) -> str:
|
async def build_community_projection(driver: AsyncDriver) -> str:
|
||||||
records, _, _ = await driver.execute_query("""
|
records, _, _ = await driver.execute_query("""
|
||||||
CALL gds.graph.project("communities", "Entity",
|
CALL gds.graph.project("communities", "Entity",
|
||||||
|
|
@ -32,36 +38,96 @@ async def build_community_projection(driver: AsyncDriver) -> str:
|
||||||
return records[0]['graph']
|
return records[0]['graph']
|
||||||
|
|
||||||
|
|
||||||
async def destroy_projection(driver: AsyncDriver, projection_name: str):
|
async def get_community_clusters(driver: AsyncDriver) -> list[list[EntityNode]]:
|
||||||
await driver.execute_query(
|
community_clusters: list[list[EntityNode]] = []
|
||||||
"""
|
|
||||||
CALL gds.graph.drop($projection_name)
|
|
||||||
""",
|
|
||||||
projection_name=projection_name,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
group_id_values, _, _ = await driver.execute_query("""
|
||||||
async def get_community_clusters(
|
MATCH (n:Entity WHERE n.group_id IS NOT NULL)
|
||||||
driver: AsyncDriver, projection_name: str
|
RETURN
|
||||||
) -> list[list[EntityNode]]:
|
collect(DISTINCT n.group_id) AS group_ids
|
||||||
records, _, _ = await driver.execute_query("""
|
|
||||||
CALL gds.leiden.stream("communities")
|
|
||||||
YIELD nodeId, communityId
|
|
||||||
RETURN gds.util.asNode(nodeId).uuid AS entity_uuid, communityId
|
|
||||||
""")
|
""")
|
||||||
community_map: dict[int, list[str]] = defaultdict(list)
|
|
||||||
for record in records:
|
|
||||||
community_map[record['communityId']].append(record['entity_uuid'])
|
|
||||||
|
|
||||||
community_clusters: list[list[EntityNode]] = list(
|
group_ids = group_id_values[0]['group_ids']
|
||||||
await asyncio.gather(
|
for group_id in group_ids:
|
||||||
*[EntityNode.get_by_uuids(driver, cluster) for cluster in community_map.values()]
|
projection: dict[str, list[Neighbor]] = {}
|
||||||
|
nodes = await EntityNode.get_by_group_ids(driver, [group_id])
|
||||||
|
for node in nodes:
|
||||||
|
records, _, _ = await driver.execute_query(
|
||||||
|
"""
|
||||||
|
MATCH (n:Entity {group_id: $group_id, uuid: $uuid})-[r:RELATES_TO]-(m: Entity {group_id: $group_id})
|
||||||
|
WITH count(r) AS count, m.uuid AS uuid
|
||||||
|
RETURN
|
||||||
|
uuid,
|
||||||
|
count
|
||||||
|
""",
|
||||||
|
uuid=node.uuid,
|
||||||
|
group_id=group_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
projection[node.uuid] = [
|
||||||
|
Neighbor(node_uuid=record['uuid'], edge_count=record['count']) for record in records
|
||||||
|
]
|
||||||
|
|
||||||
|
cluster_uuids = label_propagation(projection)
|
||||||
|
|
||||||
|
community_clusters.extend(
|
||||||
|
list(
|
||||||
|
await asyncio.gather(
|
||||||
|
*[EntityNode.get_by_uuids(driver, cluster) for cluster in cluster_uuids]
|
||||||
|
)
|
||||||
|
)
|
||||||
)
|
)
|
||||||
)
|
|
||||||
|
|
||||||
return community_clusters
|
return community_clusters
|
||||||
|
|
||||||
|
|
||||||
|
def label_propagation(projection: dict[str, list[Neighbor]]) -> list[list[str]]:
    """Group node uuids into communities with synchronous label propagation.

    Algorithm:
      1. Every node starts in its own community; the community id is the
         node's position in ``projection``.
      2. In each round, every node totals ``edge_count`` per neighbouring
         community, using the labels from the *previous* round.
      3. The candidate is the community with the highest total; equal totals
         are resolved in favour of the higher community id.
      4. The node adopts the candidate only when the candidate's id is higher
         than its current label. Labels therefore never decrease, which
         guarantees that the loop terminates.
      5. Rounds repeat until a full round changes no label.

    A node with no neighbours keeps its own label.

    Args:
        projection: Maps each node uuid to its weighted neighbours. Every
            ``node_uuid`` referenced by a neighbour must also be a key.

    Returns:
        One list of node uuids per community, ordered by the first
        appearance of each community while iterating ``projection``.

    Raises:
        KeyError: If a neighbour references a uuid that is not a key of
            ``projection``.
    """
    community_map = {uuid: i for i, uuid in enumerate(projection)}

    while True:
        new_community_map: dict[str, int] = {}

        for uuid, neighbors in projection.items():
            curr_community = community_map[uuid]

            # Total edge weight from this node towards each neighbouring community.
            community_weights: dict[int, int] = defaultdict(int)
            for neighbor in neighbors:
                community_weights[community_map[neighbor.node_uuid]] += neighbor.edge_count

            # Tuple comparison picks the heaviest community and breaks ties on
            # the higher id; -1 is the sentinel for a node without neighbours.
            _, community_candidate = max(
                ((count, community) for community, count in community_weights.items()),
                default=(0, -1),
            )

            # Only ever move to a higher id so that propagation cannot oscillate.
            new_community_map[uuid] = max(community_candidate, curr_community)

        # Both maps share the same keys, so equality means no label changed.
        if new_community_map == community_map:
            break

        community_map = new_community_map

    community_cluster_map: dict[int, list[str]] = defaultdict(list)
    for uuid, community in community_map.items():
        community_cluster_map[community].append(uuid)

    return list(community_cluster_map.values())
|
||||||
|
|
||||||
|
|
||||||
async def summarize_pair(llm_client: LLMClient, summary_pair: tuple[str, str]) -> str:
|
async def summarize_pair(llm_client: LLMClient, summary_pair: tuple[str, str]) -> str:
|
||||||
# Prepare context for LLM
|
# Prepare context for LLM
|
||||||
context = {'node_summaries': [{'summary': summary} for summary in summary_pair]}
|
context = {'node_summaries': [{'summary': summary} for summary in summary_pair]}
|
||||||
|
|
@ -88,7 +154,7 @@ async def generate_summary_description(llm_client: LLMClient, summary: str) -> s
|
||||||
|
|
||||||
|
|
||||||
async def build_community(
|
async def build_community(
|
||||||
llm_client: LLMClient, community_cluster: list[EntityNode]
|
llm_client: LLMClient, community_cluster: list[EntityNode]
|
||||||
) -> tuple[CommunityNode, list[CommunityEdge]]:
|
) -> tuple[CommunityNode, list[CommunityEdge]]:
|
||||||
summaries = [entity.summary for entity in community_cluster]
|
summaries = [entity.summary for entity in community_cluster]
|
||||||
length = len(summaries)
|
length = len(summaries)
|
||||||
|
|
@ -102,7 +168,7 @@ async def build_community(
|
||||||
*[
|
*[
|
||||||
summarize_pair(llm_client, (str(left_summary), str(right_summary)))
|
summarize_pair(llm_client, (str(left_summary), str(right_summary)))
|
||||||
for left_summary, right_summary in zip(
|
for left_summary, right_summary in zip(
|
||||||
summaries[: int(length / 2)], summaries[int(length / 2) :]
|
summaries[: int(length / 2)], summaries[int(length / 2):]
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
@ -130,10 +196,9 @@ async def build_community(
|
||||||
|
|
||||||
|
|
||||||
async def build_communities(
|
async def build_communities(
|
||||||
driver: AsyncDriver, llm_client: LLMClient
|
driver: AsyncDriver, llm_client: LLMClient
|
||||||
) -> tuple[list[CommunityNode], list[CommunityEdge]]:
|
) -> tuple[list[CommunityNode], list[CommunityEdge]]:
|
||||||
projection = await build_community_projection(driver)
|
community_clusters = await get_community_clusters(driver)
|
||||||
community_clusters = await get_community_clusters(driver, projection)
|
|
||||||
|
|
||||||
semaphore = asyncio.Semaphore(MAX_COMMUNITY_BUILD_CONCURRENCY)
|
semaphore = asyncio.Semaphore(MAX_COMMUNITY_BUILD_CONCURRENCY)
|
||||||
|
|
||||||
|
|
@ -151,7 +216,6 @@ async def build_communities(
|
||||||
community_nodes.append(community[0])
|
community_nodes.append(community[0])
|
||||||
community_edges.extend(community[1])
|
community_edges.extend(community[1])
|
||||||
|
|
||||||
await destroy_projection(driver, projection)
|
|
||||||
return community_nodes, community_edges
|
return community_nodes, community_edges
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -163,7 +227,7 @@ async def remove_communities(driver: AsyncDriver):
|
||||||
|
|
||||||
|
|
||||||
async def determine_entity_community(
|
async def determine_entity_community(
|
||||||
driver: AsyncDriver, entity: EntityNode
|
driver: AsyncDriver, entity: EntityNode
|
||||||
) -> tuple[CommunityNode | None, bool]:
|
) -> tuple[CommunityNode | None, bool]:
|
||||||
# Check if the node is already part of a community
|
# Check if the node is already part of a community
|
||||||
records, _, _ = await driver.execute_query(
|
records, _, _ = await driver.execute_query(
|
||||||
|
|
@ -224,7 +288,7 @@ async def determine_entity_community(
|
||||||
|
|
||||||
|
|
||||||
async def update_community(
|
async def update_community(
|
||||||
driver: AsyncDriver, llm_client: LLMClient, embedder, entity: EntityNode
|
driver: AsyncDriver, llm_client: LLMClient, embedder, entity: EntityNode
|
||||||
):
|
):
|
||||||
community, is_new = await determine_entity_community(driver, entity)
|
community, is_new = await determine_entity_community(driver, entity)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -74,6 +74,7 @@ def format_context(facts):
|
||||||
async def test_graphiti_init():
|
async def test_graphiti_init():
|
||||||
logger = setup_logging()
|
logger = setup_logging()
|
||||||
graphiti = Graphiti(NEO4J_URI, NEO4j_USER, NEO4j_PASSWORD)
|
graphiti = Graphiti(NEO4J_URI, NEO4j_USER, NEO4j_PASSWORD)
|
||||||
|
await graphiti.build_communities()
|
||||||
|
|
||||||
edges = await graphiti.search('tania tetlow', group_ids=['1'])
|
edges = await graphiti.search('tania tetlow', group_ids=['1'])
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue