Add community update (#121)
* documentation update * update communities * update runner * make format * mypy * oops * add update_communities
This commit is contained in:
parent
ebb1ec2463
commit
a18b3179ee
3 changed files with 113 additions and 3 deletions
|
|
@ -63,7 +63,7 @@ async def main(use_bulk: bool = True):
|
|||
messages = parse_podcast_messages()
|
||||
|
||||
if not use_bulk:
|
||||
for i, message in enumerate(messages[3:20]):
|
||||
for i, message in enumerate(messages[3:14]):
|
||||
await client.add_episode(
|
||||
name=f'Message {i}',
|
||||
episode_body=f'{message.speaker_name} ({message.role}): {message.content}',
|
||||
|
|
@ -71,6 +71,20 @@ async def main(use_bulk: bool = True):
|
|||
source_description='Podcast Transcript',
|
||||
group_id='1',
|
||||
)
|
||||
|
||||
# build communities
|
||||
await client.build_communities()
|
||||
|
||||
# add additional messages to update communities
|
||||
for i, message in enumerate(messages[14:20]):
|
||||
await client.add_episode(
|
||||
name=f'Message {i}',
|
||||
episode_body=f'{message.speaker_name} ({message.role}): {message.content}',
|
||||
reference_time=message.actual_timestamp,
|
||||
source_description='Podcast Transcript',
|
||||
group_id='1',
|
||||
)
|
||||
|
||||
return
|
||||
|
||||
episodes: list[RawEpisode] = [
|
||||
|
|
|
|||
|
|
@ -54,6 +54,7 @@ from graphiti_core.utils.bulk_utils import (
|
|||
from graphiti_core.utils.maintenance.community_operations import (
|
||||
build_communities,
|
||||
remove_communities,
|
||||
update_community,
|
||||
)
|
||||
from graphiti_core.utils.maintenance.edge_operations import (
|
||||
extract_edges,
|
||||
|
|
@ -224,6 +225,7 @@ class Graphiti:
|
|||
source: EpisodeType = EpisodeType.message,
|
||||
group_id: str | None = None,
|
||||
uuid: str | None = None,
|
||||
update_communities: bool = False,
|
||||
):
|
||||
"""
|
||||
Process an episode and update the graph.
|
||||
|
|
@ -247,6 +249,8 @@ class Graphiti:
|
|||
An id for the graph partition the episode is a part of.
|
||||
uuid : str | None
|
||||
Optional uuid of the episode.
|
||||
update_communities: bool
|
||||
Optional. Determines if we should update communities
|
||||
|
||||
Returns
|
||||
-------
|
||||
|
|
@ -415,6 +419,14 @@ class Graphiti:
|
|||
await asyncio.gather(*[edge.save(self.driver) for edge in episodic_edges])
|
||||
await asyncio.gather(*[edge.save(self.driver) for edge in entity_edges])
|
||||
|
||||
# Update any communities
|
||||
if update_communities:
|
||||
await asyncio.gather(
|
||||
*[
|
||||
update_community(self.driver, self.llm_client, embedder, node)
|
||||
for node in nodes
|
||||
]
|
||||
)
|
||||
end = time()
|
||||
logger.info(f'Completed add_episode in {(end - start) * 1000} ms')
|
||||
|
||||
|
|
@ -569,7 +581,7 @@ class Graphiti:
|
|||
Facts will be reranked based on proximity to this node
|
||||
group_ids : list[str | None] | None, optional
|
||||
The graph partitions to return data from.
|
||||
limit : int, optional
|
||||
num_results : int, optional
|
||||
The maximum number of results to return. Defaults to 10.
|
||||
|
||||
Returns
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@ from neo4j import AsyncDriver
|
|||
|
||||
from graphiti_core.edges import CommunityEdge
|
||||
from graphiti_core.llm_client import LLMClient
|
||||
from graphiti_core.nodes import CommunityNode, EntityNode
|
||||
from graphiti_core.nodes import CommunityNode, EntityNode, get_community_node_from_record
|
||||
from graphiti_core.prompts import prompt_library
|
||||
from graphiti_core.utils.maintenance.edge_operations import build_community_edges
|
||||
|
||||
|
|
@ -153,3 +153,87 @@ async def remove_communities(driver: AsyncDriver):
|
|||
MATCH (c:Community)
|
||||
DETACH DELETE c
|
||||
""")
|
||||
|
||||
|
||||
async def determine_entity_community(
    driver: AsyncDriver, entity: EntityNode
) -> tuple[CommunityNode | None, bool]:
    """
    Resolve the community associated with an entity.

    Parameters
    ----------
    driver : AsyncDriver
        Driver used to run the two lookup queries.
    entity : EntityNode
        The entity whose community is being resolved; only its ``uuid`` is read.

    Returns
    -------
    tuple[CommunityNode | None, bool]
        ``(community, False)`` when a ``HAS_MEMBER`` edge already links a community
        to the entity (the first returned record is used).
        ``(community, True)`` when the entity has no community yet and one was chosen
        from its ``RELATES_TO`` neighbours: the community that appears most often
        among the returned rows, ties going to the one that appears first.
        ``(None, False)`` when no neighbouring entity belongs to any community.
    """
    # Fast path: the entity is already a member of a community.
    member_records, _, _ = await driver.execute_query(
        """
    MATCH (c:Community)-[:HAS_MEMBER]->(n:Entity {uuid: $entity_uuid})
    RETURN
        c.uuid As uuid,
        c.name AS name,
        c.name_embedding AS name_embedding,
        c.group_id AS group_id,
        c.created_at AS created_at,
        c.summary AS summary
    """,
        entity_uuid=entity.uuid,
    )
    if member_records:
        return get_community_node_from_record(member_records[0]), False

    # Otherwise collect the communities of neighbouring entities; one row is
    # returned per matched path, so a community can appear several times.
    neighbour_records, _, _ = await driver.execute_query(
        """
    MATCH (c:Community)-[:HAS_MEMBER]->(m:Entity)-[:RELATES_TO]-(n:Entity {uuid: $entity_uuid})
    RETURN
        c.uuid As uuid,
        c.name AS name,
        c.name_embedding AS name_embedding,
        c.group_id AS group_id,
        c.created_at AS created_at,
        c.summary AS summary
    """,
        entity_uuid=entity.uuid,
    )

    # Single pass: count occurrences per community uuid and remember the first
    # node object seen for each uuid so it can be returned directly.
    occurrences: dict[str, int] = defaultdict(int)
    first_node_by_uuid: dict[str, CommunityNode] = {}
    for record in neighbour_records:
        candidate = get_community_node_from_record(record)
        occurrences[candidate.uuid] += 1
        first_node_by_uuid.setdefault(candidate.uuid, candidate)

    if not occurrences:
        return None, False

    # dicts iterate in insertion order and max() keeps the first maximal key,
    # so a tie resolves to the community encountered earliest.
    winning_uuid = max(occurrences, key=lambda community_uuid: occurrences[community_uuid])
    return first_node_by_uuid[winning_uuid], True
|
||||
|
||||
|
||||
async def update_community(
    driver: AsyncDriver, llm_client: LLMClient, embedder, entity: EntityNode
):
    """
    Fold an entity into its community and persist the refreshed community.

    Parameters
    ----------
    driver : AsyncDriver
        Driver used for the community lookup and for saving the edge and node.
    llm_client : LLMClient
        Client passed to ``summarize_pair`` and ``generate_summary_description``.
    embedder
        Passed to ``community.generate_name_embedding``.
    entity : EntityNode
        The entity whose summary is combined with the community summary.

    Returns
    -------
    None
        Returns early, without writing anything, when
        ``determine_entity_community`` finds no community for the entity.
    """
    community, is_new = await determine_entity_community(driver, entity)
    if community is None:
        return

    # Build the combined summary first, then derive the new name from it.
    merged_summary = await summarize_pair(llm_client, (entity.summary, community.summary))
    refreshed_name = await generate_summary_description(llm_client, merged_summary)
    community.summary, community.name = merged_summary, refreshed_name

    if is_new:
        # The entity was not yet linked to this community: save the edge
        # produced by build_community_edges for this single entity.
        membership_edges = build_community_edges([entity], community, datetime.now())
        await membership_edges[0].save(driver)

    # The name changed, so regenerate its embedding before saving the node.
    await community.generate_name_embedding(embedder)
    await community.save(driver)
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue