Add community update (#121)
* documentation update * update communities * update runner * make format * mypy * oops * add update_communities
This commit is contained in:
parent
ebb1ec2463
commit
a18b3179ee
3 changed files with 113 additions and 3 deletions
|
|
@ -63,7 +63,7 @@ async def main(use_bulk: bool = True):
|
||||||
messages = parse_podcast_messages()
|
messages = parse_podcast_messages()
|
||||||
|
|
||||||
if not use_bulk:
|
if not use_bulk:
|
||||||
for i, message in enumerate(messages[3:20]):
|
for i, message in enumerate(messages[3:14]):
|
||||||
await client.add_episode(
|
await client.add_episode(
|
||||||
name=f'Message {i}',
|
name=f'Message {i}',
|
||||||
episode_body=f'{message.speaker_name} ({message.role}): {message.content}',
|
episode_body=f'{message.speaker_name} ({message.role}): {message.content}',
|
||||||
|
|
@ -71,6 +71,20 @@ async def main(use_bulk: bool = True):
|
||||||
source_description='Podcast Transcript',
|
source_description='Podcast Transcript',
|
||||||
group_id='1',
|
group_id='1',
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# build communities
|
||||||
|
await client.build_communities()
|
||||||
|
|
||||||
|
# add additional messages to update communities
|
||||||
|
for i, message in enumerate(messages[14:20]):
|
||||||
|
await client.add_episode(
|
||||||
|
name=f'Message {i}',
|
||||||
|
episode_body=f'{message.speaker_name} ({message.role}): {message.content}',
|
||||||
|
reference_time=message.actual_timestamp,
|
||||||
|
source_description='Podcast Transcript',
|
||||||
|
group_id='1',
|
||||||
|
)
|
||||||
|
|
||||||
return
|
return
|
||||||
|
|
||||||
episodes: list[RawEpisode] = [
|
episodes: list[RawEpisode] = [
|
||||||
|
|
|
||||||
|
|
@ -54,6 +54,7 @@ from graphiti_core.utils.bulk_utils import (
|
||||||
from graphiti_core.utils.maintenance.community_operations import (
|
from graphiti_core.utils.maintenance.community_operations import (
|
||||||
build_communities,
|
build_communities,
|
||||||
remove_communities,
|
remove_communities,
|
||||||
|
update_community,
|
||||||
)
|
)
|
||||||
from graphiti_core.utils.maintenance.edge_operations import (
|
from graphiti_core.utils.maintenance.edge_operations import (
|
||||||
extract_edges,
|
extract_edges,
|
||||||
|
|
@ -224,6 +225,7 @@ class Graphiti:
|
||||||
source: EpisodeType = EpisodeType.message,
|
source: EpisodeType = EpisodeType.message,
|
||||||
group_id: str | None = None,
|
group_id: str | None = None,
|
||||||
uuid: str | None = None,
|
uuid: str | None = None,
|
||||||
|
update_communities: bool = False,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Process an episode and update the graph.
|
Process an episode and update the graph.
|
||||||
|
|
@ -247,6 +249,8 @@ class Graphiti:
|
||||||
An id for the graph partition the episode is a part of.
|
An id for the graph partition the episode is a part of.
|
||||||
uuid : str | None
|
uuid : str | None
|
||||||
Optional uuid of the episode.
|
Optional uuid of the episode.
|
||||||
|
update_communities: bool
|
||||||
|
Optional. Determines if we should update communities
|
||||||
|
|
||||||
Returns
|
Returns
|
||||||
-------
|
-------
|
||||||
|
|
@ -415,6 +419,14 @@ class Graphiti:
|
||||||
await asyncio.gather(*[edge.save(self.driver) for edge in episodic_edges])
|
await asyncio.gather(*[edge.save(self.driver) for edge in episodic_edges])
|
||||||
await asyncio.gather(*[edge.save(self.driver) for edge in entity_edges])
|
await asyncio.gather(*[edge.save(self.driver) for edge in entity_edges])
|
||||||
|
|
||||||
|
# Update any communities
|
||||||
|
if update_communities:
|
||||||
|
await asyncio.gather(
|
||||||
|
*[
|
||||||
|
update_community(self.driver, self.llm_client, embedder, node)
|
||||||
|
for node in nodes
|
||||||
|
]
|
||||||
|
)
|
||||||
end = time()
|
end = time()
|
||||||
logger.info(f'Completed add_episode in {(end - start) * 1000} ms')
|
logger.info(f'Completed add_episode in {(end - start) * 1000} ms')
|
||||||
|
|
||||||
|
|
@ -569,7 +581,7 @@ class Graphiti:
|
||||||
Facts will be reranked based on proximity to this node
|
Facts will be reranked based on proximity to this node
|
||||||
group_ids : list[str | None] | None, optional
|
group_ids : list[str | None] | None, optional
|
||||||
The graph partitions to return data from.
|
The graph partitions to return data from.
|
||||||
limit : int, optional
|
num_results : int, optional
|
||||||
The maximum number of results to return. Defaults to 10.
|
The maximum number of results to return. Defaults to 10.
|
||||||
|
|
||||||
Returns
|
Returns
|
||||||
|
|
|
||||||
|
|
@ -7,7 +7,7 @@ from neo4j import AsyncDriver
|
||||||
|
|
||||||
from graphiti_core.edges import CommunityEdge
|
from graphiti_core.edges import CommunityEdge
|
||||||
from graphiti_core.llm_client import LLMClient
|
from graphiti_core.llm_client import LLMClient
|
||||||
from graphiti_core.nodes import CommunityNode, EntityNode
|
from graphiti_core.nodes import CommunityNode, EntityNode, get_community_node_from_record
|
||||||
from graphiti_core.prompts import prompt_library
|
from graphiti_core.prompts import prompt_library
|
||||||
from graphiti_core.utils.maintenance.edge_operations import build_community_edges
|
from graphiti_core.utils.maintenance.edge_operations import build_community_edges
|
||||||
|
|
||||||
|
|
@ -153,3 +153,87 @@ async def remove_communities(driver: AsyncDriver):
|
||||||
MATCH (c:Community)
|
MATCH (c:Community)
|
||||||
DETACH DELETE c
|
DETACH DELETE c
|
||||||
""")
|
""")
|
||||||
|
|
||||||
|
|
||||||
|
async def determine_entity_community(
    driver: AsyncDriver, entity: EntityNode
) -> tuple[CommunityNode | None, bool]:
    """Find the community an entity belongs to, or the one it should join.

    Returns a ``(community, is_new)`` pair:

    * ``(community, False)`` when the entity is already the target of a
      ``HAS_MEMBER`` relationship from a community; the first returned record
      is used.
    * ``(community, True)`` when the entity has no community of its own but
      entities it ``RELATES_TO`` do. The community that appears most often
      among those neighbours is picked; on a tie the one returned first wins.
    * ``(None, False)`` when neither query yields a community.
    """
    # Direct membership takes precedence over anything inferred from neighbours.
    membership_rows, _, _ = await driver.execute_query(
        """
    MATCH (c:Community)-[:HAS_MEMBER]->(n:Entity {uuid: $entity_uuid})
    RETURN
        c.uuid As uuid,
        c.name AS name,
        c.name_embedding AS name_embedding,
        c.group_id AS group_id,
        c.created_at AS created_at,
        c.summary AS summary
    """,
        entity_uuid=entity.uuid,
    )

    if membership_rows:
        return get_community_node_from_record(membership_rows[0]), False

    # If the node has no community, add it to the mode community of surrounding entities
    neighbour_rows, _, _ = await driver.execute_query(
        """
    MATCH (c:Community)-[:HAS_MEMBER]->(m:Entity)-[:RELATES_TO]-(n:Entity {uuid: $entity_uuid})
    RETURN
        c.uuid As uuid,
        c.name AS name,
        c.name_embedding AS name_embedding,
        c.group_id AS group_id,
        c.created_at AS created_at,
        c.summary AS summary
    """,
        entity_uuid=entity.uuid,
    )

    # One candidate per (community, neighbour) row, so a community shared by
    # several neighbours shows up several times.
    candidates: list[CommunityNode] = [
        get_community_node_from_record(row) for row in neighbour_rows
    ]

    tally: dict[str, int] = defaultdict(int)
    for candidate in candidates:
        tally[candidate.uuid] += 1

    # No neighbouring community at all: nothing to join.
    if not tally:
        return None, False

    # max() yields the first key holding the highest count, so ties resolve
    # in favour of the community encountered first.
    winning_uuid = max(tally, key=tally.get)

    chosen = next(
        (candidate for candidate in candidates if candidate.uuid == winning_uuid),
        None,
    )
    if chosen is None:
        return None, False

    return chosen, True
|
||||||
|
|
||||||
|
|
||||||
|
async def update_community(
    driver: AsyncDriver, llm_client: LLMClient, embedder, entity: EntityNode
):
    """Refresh the community associated with ``entity`` and persist it.

    The community is looked up with ``determine_entity_community``. When none
    is found the function returns without doing anything. Otherwise the
    entity's summary and the community's summary are combined through
    ``summarize_pair``, a new name is derived from the combined summary with
    ``generate_summary_description``, and both are written onto the community.

    If the entity was not yet a member (``is_new`` from the lookup), a
    community edge linking it to the community is built and saved first.
    Finally the community's name embedding is regenerated with ``embedder``
    and the community itself is saved.
    """
    target, joins_now = await determine_entity_community(driver, entity)
    if target is None:
        return

    # Compute both LLM-derived values before touching the community object.
    merged_summary = await summarize_pair(llm_client, (entity.summary, target.summary))
    refreshed_name = await generate_summary_description(llm_client, merged_summary)

    target.summary = merged_summary
    target.name = refreshed_name

    if joins_now:
        # A single entity is passed in, so only the first edge is used.
        membership_edges = build_community_edges([entity], target, datetime.now())
        await membership_edges[0].save(driver)

    # The name changed above, so its embedding is recomputed before saving.
    await target.generate_name_embedding(embedder)

    await target.save(driver)
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue