Add community update (#121)

* documentation update

* update communities

* update runner

* make format

* mypy

* oops

* add update_communities
This commit is contained in:
Preston Rasmussen 2024-09-18 11:37:34 -04:00 committed by GitHub
parent ebb1ec2463
commit a18b3179ee
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 113 additions and 3 deletions

View file

@ -63,7 +63,7 @@ async def main(use_bulk: bool = True):
messages = parse_podcast_messages()
if not use_bulk:
for i, message in enumerate(messages[3:20]):
for i, message in enumerate(messages[3:14]):
await client.add_episode(
name=f'Message {i}',
episode_body=f'{message.speaker_name} ({message.role}): {message.content}',
@ -71,6 +71,20 @@ async def main(use_bulk: bool = True):
source_description='Podcast Transcript',
group_id='1',
)
# build communities
await client.build_communities()
# add additional messages to update communities
for i, message in enumerate(messages[14:20]):
await client.add_episode(
name=f'Message {i}',
episode_body=f'{message.speaker_name} ({message.role}): {message.content}',
reference_time=message.actual_timestamp,
source_description='Podcast Transcript',
group_id='1',
)
return
episodes: list[RawEpisode] = [

View file

@ -54,6 +54,7 @@ from graphiti_core.utils.bulk_utils import (
from graphiti_core.utils.maintenance.community_operations import (
build_communities,
remove_communities,
update_community,
)
from graphiti_core.utils.maintenance.edge_operations import (
extract_edges,
@ -224,6 +225,7 @@ class Graphiti:
source: EpisodeType = EpisodeType.message,
group_id: str | None = None,
uuid: str | None = None,
update_communities: bool = False,
):
"""
Process an episode and update the graph.
@ -247,6 +249,8 @@ class Graphiti:
An id for the graph partition the episode is a part of.
uuid : str | None
Optional uuid of the episode.
update_communities: bool
Optional. Determines if we should update communities
Returns
-------
@ -415,6 +419,14 @@ class Graphiti:
await asyncio.gather(*[edge.save(self.driver) for edge in episodic_edges])
await asyncio.gather(*[edge.save(self.driver) for edge in entity_edges])
# Update any communities
if update_communities:
await asyncio.gather(
*[
update_community(self.driver, self.llm_client, embedder, node)
for node in nodes
]
)
end = time()
logger.info(f'Completed add_episode in {(end - start) * 1000} ms')
@ -569,7 +581,7 @@ class Graphiti:
Facts will be reranked based on proximity to this node
group_ids : list[str | None] | None, optional
The graph partitions to return data from.
limit : int, optional
num_results : int, optional
The maximum number of results to return. Defaults to 10.
Returns

View file

@ -7,7 +7,7 @@ from neo4j import AsyncDriver
from graphiti_core.edges import CommunityEdge
from graphiti_core.llm_client import LLMClient
from graphiti_core.nodes import CommunityNode, EntityNode
from graphiti_core.nodes import CommunityNode, EntityNode, get_community_node_from_record
from graphiti_core.prompts import prompt_library
from graphiti_core.utils.maintenance.edge_operations import build_community_edges
@ -153,3 +153,87 @@ async def remove_communities(driver: AsyncDriver):
MATCH (c:Community)
DETACH DELETE c
""")
async def determine_entity_community(
driver: AsyncDriver, entity: EntityNode
) -> tuple[CommunityNode | None, bool]:
# Check if the node is already part of a community
records, _, _ = await driver.execute_query(
"""
MATCH (c:Community)-[:HAS_MEMBER]->(n:Entity {uuid: $entity_uuid})
RETURN
c.uuid As uuid,
c.name AS name,
c.name_embedding AS name_embedding,
c.group_id AS group_id,
c.created_at AS created_at,
c.summary AS summary
""",
entity_uuid=entity.uuid,
)
if len(records) > 0:
return get_community_node_from_record(records[0]), False
# If the node has no community, add it to the mode community of surrounding entities
records, _, _ = await driver.execute_query(
"""
MATCH (c:Community)-[:HAS_MEMBER]->(m:Entity)-[:RELATES_TO]-(n:Entity {uuid: $entity_uuid})
RETURN
c.uuid As uuid,
c.name AS name,
c.name_embedding AS name_embedding,
c.group_id AS group_id,
c.created_at AS created_at,
c.summary AS summary
""",
entity_uuid=entity.uuid,
)
communities: list[CommunityNode] = [
get_community_node_from_record(record) for record in records
]
community_map: dict[str, int] = defaultdict(int)
for community in communities:
community_map[community.uuid] += 1
community_uuid = None
max_count = 0
for uuid, count in community_map.items():
if count > max_count:
community_uuid = uuid
max_count = count
if max_count == 0:
return None, False
for community in communities:
if community.uuid == community_uuid:
return community, True
return None, False
async def update_community(
driver: AsyncDriver, llm_client: LLMClient, embedder, entity: EntityNode
):
community, is_new = await determine_entity_community(driver, entity)
if community is None:
return
new_summary = await summarize_pair(llm_client, (entity.summary, community.summary))
new_name = await generate_summary_description(llm_client, new_summary)
community.summary = new_summary
community.name = new_name
if is_new:
community_edge = (build_community_edges([entity], community, datetime.now()))[0]
await community_edge.save(driver)
await community.generate_name_embedding(embedder)
await community.save(driver)