Patch: add community support (Community nodes, HAS_MEMBER edges, Leiden clustering via GDS,
summary prompts, indices) and move the per-class delete() methods into the Node/Edge bases.

Review notes (commentary placed before the first "diff --git" header):
  * nodes.py: added the missing comma after "group_id AS group_id" in the RETURN clause of
    EpisodicNode.get_by_uuids, CommunityNode.get_by_uuid and CommunityNode.get_by_uuids.
  * nodes.py: CommunityNode.generate_name_embedding logged seconds labelled as "ms"; the value
    is now multiplied by 1000.
  * community_operations.py: get_community_clusters now streams from $projection_name instead of
    ignoring its argument; build_communities drops the projection in a finally block.
  * edges.py: get_community_edge_from_record gained its "-> CommunityEdge" annotation.
  * Docstrings were added only in the two new files; their hunk headers are now +1,81 and +1,171.
    All other hunks keep their original line counts.
  * NOTE(review): the source had its whitespace collapsed. Indentation and blank context lines
    below were reconstructed from the hunk counts - confirm against the tree before applying.
    The "index" blob ids are those of the original commit.

diff --git a/examples/podcast/podcast_runner.py b/examples/podcast/podcast_runner.py
index ec144987..d3b2404c 100644
--- a/examples/podcast/podcast_runner.py
+++ b/examples/podcast/podcast_runner.py
@@ -84,7 +84,7 @@ async def main(use_bulk: bool = True):
         for i, message in enumerate(messages[3:20])
     ]
 
-    await client.add_episode_bulk(episodes)
+    await client.add_episode_bulk(episodes, None)
 
 
 asyncio.run(main(False))
diff --git a/graphiti_core/edges.py b/graphiti_core/edges.py
index 1e60c944..3b0f3992 100644
--- a/graphiti_core/edges.py
+++ b/graphiti_core/edges.py
@@ -41,8 +41,18 @@ class Edge(BaseModel, ABC):
     @abstractmethod
     async def save(self, driver: AsyncDriver): ...
 
-    @abstractmethod
-    async def delete(self, driver: AsyncDriver): ...
+    async def delete(self, driver: AsyncDriver):
+        result = await driver.execute_query(
+            """
+        MATCH (n)-[e {uuid: $uuid}]->(m)
+        DELETE e
+        """,
+            uuid=self.uuid,
+        )
+
+        logger.info(f'Deleted Edge: {self.uuid}')
+
+        return result
 
     def __hash__(self):
         return hash(self.uuid)
@@ -76,19 +86,6 @@ class EpisodicEdge(Edge):
 
         return result
 
-    async def delete(self, driver: AsyncDriver):
-        result = await driver.execute_query(
-            """
-        MATCH (n:Episodic)-[e:MENTIONS {uuid: $uuid}]->(m:Entity)
-        DELETE e
-        """,
-            uuid=self.uuid,
-        )
-
-        logger.info(f'Deleted Edge: {self.uuid}')
-
-        return result
-
     @classmethod
     async def get_by_uuid(cls, driver: AsyncDriver, uuid: str):
         records, _, _ = await driver.execute_query(
@@ -169,19 +166,6 @@ class EntityEdge(Edge):
 
         return result
 
-    async def delete(self, driver: AsyncDriver):
-        result = await driver.execute_query(
-            """
-        MATCH (n:Entity)-[e:RELATES_TO {uuid: $uuid}]->(m:Entity)
-        DELETE e
-        """,
-            uuid=self.uuid,
-        )
-
-        logger.info(f'Deleted Edge: {self.uuid}')
-
-        return result
-
     @classmethod
     async def get_by_uuid(cls, driver: AsyncDriver, uuid: str):
         records, _, _ = await driver.execute_query(
@@ -211,6 +195,48 @@ class EntityEdge(Edge):
         return edges[0]
 
 
+class CommunityEdge(Edge):
+    async def save(self, driver: AsyncDriver):
+        result = await driver.execute_query(
+            """
+        MATCH (community:Community {uuid: $community_uuid})
+        MATCH (node:Entity | Community {uuid: $entity_uuid})
+        MERGE (community)-[r:HAS_MEMBER {uuid: $uuid}]->(node)
+        SET r = {uuid: $uuid, group_id: $group_id, created_at: $created_at}
+        RETURN r.uuid AS uuid""",
+            community_uuid=self.source_node_uuid,
+            entity_uuid=self.target_node_uuid,
+            uuid=self.uuid,
+            group_id=self.group_id,
+            created_at=self.created_at,
+        )
+
+        logger.info(f'Saved edge to neo4j: {self.uuid}')
+
+        return result
+
+    @classmethod
+    async def get_by_uuid(cls, driver: AsyncDriver, uuid: str):
+        records, _, _ = await driver.execute_query(
+            """
+        MATCH (n:Community)-[e:HAS_MEMBER {uuid: $uuid}]->(m:Entity | Community)
+        RETURN
+            e.uuid As uuid,
+            e.group_id AS group_id,
+            n.uuid AS source_node_uuid,
+            m.uuid AS target_node_uuid,
+            e.created_at AS created_at
+        """,
+            uuid=uuid,
+        )
+
+        edges = [get_community_edge_from_record(record) for record in records]
+
+        logger.info(f'Found Edge: {uuid}')
+
+        return edges[0]
+
+
 # Edge helpers
 def get_episodic_edge_from_record(record: Any) -> EpisodicEdge:
     return EpisodicEdge(
@@ -237,3 +263,13 @@ def get_entity_edge_from_record(record: Any) -> EntityEdge:
         valid_at=parse_db_date(record['valid_at']),
         invalid_at=parse_db_date(record['invalid_at']),
     )
+
+
+def get_community_edge_from_record(record: Any) -> CommunityEdge:
+    return CommunityEdge(
+        uuid=record['uuid'],
+        group_id=record['group_id'],
+        source_node_uuid=record['source_node_uuid'],
+        target_node_uuid=record['target_node_uuid'],
+        created_at=record['created_at'].to_native(),
+    )
diff --git a/graphiti_core/graphiti.py b/graphiti_core/graphiti.py
index 2b945d9b..4536f75e 100644
--- a/graphiti_core/graphiti.py
+++ b/graphiti_core/graphiti.py
@@ -46,6 +46,10 @@ from graphiti_core.utils.bulk_utils import (
     resolve_edge_pointers,
     retrieve_previous_episodes_bulk,
 )
+from graphiti_core.utils.maintenance.community_operations import (
+    build_communities,
+    remove_communities,
+)
 from graphiti_core.utils.maintenance.edge_operations import (
     extract_edges,
     resolve_extracted_edges,
@@ -526,6 +530,19 @@ class Graphiti:
         except Exception as e:
             raise e
 
+    async def build_communities(self):
+        embedder = self.llm_client.get_embedder()
+
+        # Clear existing communities
+        await remove_communities(self.driver)
+
+        community_nodes, community_edges = await build_communities(self.driver, self.llm_client)
+
+        await asyncio.gather(*[node.generate_name_embedding(embedder) for node in community_nodes])
+
+        await asyncio.gather(*[node.save(self.driver) for node in community_nodes])
+        await asyncio.gather(*[edge.save(self.driver) for edge in community_edges])
+
     async def search(
         self,
         query: str,
diff --git a/graphiti_core/nodes.py b/graphiti_core/nodes.py
index 907d52b4..354c9d67 100644
--- a/graphiti_core/nodes.py
+++ b/graphiti_core/nodes.py
@@ -76,8 +76,18 @@ class Node(BaseModel, ABC):
     @abstractmethod
     async def save(self, driver: AsyncDriver): ...
 
-    @abstractmethod
-    async def delete(self, driver: AsyncDriver): ...
+    async def delete(self, driver: AsyncDriver):
+        result = await driver.execute_query(
+            """
+        MATCH (n {uuid: $uuid})
+        DETACH DELETE n
+        """,
+            uuid=self.uuid,
+        )
+
+        logger.info(f'Deleted Node: {self.uuid}')
+
+        return result
 
     def __hash__(self):
         return hash(self.uuid)
@@ -90,6 +100,9 @@ class Node(BaseModel, ABC):
     @classmethod
     async def get_by_uuid(cls, driver: AsyncDriver, uuid: str): ...
 
+    @classmethod
+    async def get_by_uuids(cls, driver: AsyncDriver, uuids: list[str]): ...
+
 
 class EpisodicNode(Node):
     source: EpisodeType = Field(description='source type')
@@ -125,19 +138,6 @@ class EpisodicNode(Node):
 
         return result
 
-    async def delete(self, driver: AsyncDriver):
-        result = await driver.execute_query(
-            """
-        MATCH (n:Episodic {uuid: $uuid})
-        DETACH DELETE n
-        """,
-            uuid=self.uuid,
-        )
-
-        logger.info(f'Deleted Node: {self.uuid}')
-
-        return result
-
     @classmethod
     async def get_by_uuid(cls, driver: AsyncDriver, uuid: str):
         records, _, _ = await driver.execute_query(
@@ -161,6 +161,29 @@ class EpisodicNode(Node):
 
         return episodes[0]
 
+    @classmethod
+    async def get_by_uuids(cls, driver: AsyncDriver, uuids: list[str]):
+        records, _, _ = await driver.execute_query(
+            """
+        MATCH (e:Episodic) WHERE e.uuid IN $uuids
+            RETURN e.content AS content,
+            e.created_at AS created_at,
+            e.valid_at AS valid_at,
+            e.uuid AS uuid,
+            e.name AS name,
+            e.group_id AS group_id,
+            e.source_description AS source_description,
+            e.source AS source
+        """,
+            uuids=uuids,
+        )
+
+        episodes = [get_episodic_node_from_record(record) for record in records]
+
+        logger.info(f'Found Nodes: {uuids}')
+
+        return episodes
+
 
 class EntityNode(Node):
     name_embedding: list[float] | None = Field(default=None, description='embedding of the name')
@@ -194,19 +217,6 @@ class EntityNode(Node):
 
         return result
 
-    async def delete(self, driver: AsyncDriver):
-        result = await driver.execute_query(
-            """
-        MATCH (n:Entity {uuid: $uuid})
-        DETACH DELETE n
-        """,
-            uuid=self.uuid,
-        )
-
-        logger.info(f'Deleted Node: {self.uuid}')
-
-        return result
-
     @classmethod
     async def get_by_uuid(cls, driver: AsyncDriver, uuid: str):
         records, _, _ = await driver.execute_query(
@@ -229,6 +239,105 @@ class EntityNode(Node):
 
         return nodes[0]
 
+    @classmethod
+    async def get_by_uuids(cls, driver: AsyncDriver, uuids: list[str]):
+        records, _, _ = await driver.execute_query(
+            """
+        MATCH (n:Entity) WHERE n.uuid IN $uuids
+        RETURN
+            n.uuid As uuid,
+            n.name AS name,
+            n.name_embedding AS name_embedding,
+            n.group_id AS group_id,
+            n.created_at AS created_at,
+            n.summary AS summary
+        """,
+            uuids=uuids,
+        )
+
+        nodes = [get_entity_node_from_record(record) for record in records]
+
+        logger.info(f'Found Nodes: {uuids}')
+
+        return nodes
+
+
+class CommunityNode(Node):
+    name_embedding: list[float] | None = Field(default=None, description='embedding of the name')
+    summary: str = Field(description='region summary of member nodes', default_factory=str)
+
+    async def save(self, driver: AsyncDriver):
+        result = await driver.execute_query(
+            """
+        MERGE (n:Community {uuid: $uuid})
+        SET n = {uuid: $uuid, name: $name, name_embedding: $name_embedding, group_id: $group_id, summary: $summary, created_at: $created_at}
+        RETURN n.uuid AS uuid""",
+            uuid=self.uuid,
+            name=self.name,
+            group_id=self.group_id,
+            summary=self.summary,
+            name_embedding=self.name_embedding,
+            created_at=self.created_at,
+        )
+
+        logger.info(f'Saved Node to neo4j: {self.uuid}')
+
+        return result
+
+    async def generate_name_embedding(self, embedder, model='text-embedding-3-small'):
+        start = time()
+        text = self.name.replace('\n', ' ')
+        embedding = (await embedder.create(input=[text], model=model)).data[0].embedding
+        self.name_embedding = embedding[:EMBEDDING_DIM]
+        end = time()
+        logger.info(f'embedded {text} in {(end - start) * 1000} ms')
+
+        return embedding
+
+    @classmethod
+    async def get_by_uuid(cls, driver: AsyncDriver, uuid: str):
+        records, _, _ = await driver.execute_query(
+            """
+        MATCH (n:Community {uuid: $uuid})
+        RETURN
+            n.uuid As uuid,
+            n.name AS name,
+            n.name_embedding AS name_embedding,
+            n.group_id AS group_id,
+            n.created_at AS created_at,
+            n.summary AS summary
+        """,
+            uuid=uuid,
+        )
+
+        nodes = [get_community_node_from_record(record) for record in records]
+
+        logger.info(f'Found Node: {uuid}')
+
+        return nodes[0]
+
+    @classmethod
+    async def get_by_uuids(cls, driver: AsyncDriver, uuids: list[str]):
+        records, _, _ = await driver.execute_query(
+            """
+        MATCH (n:Community) WHERE n.uuid IN $uuids
+        RETURN
+            n.uuid As uuid,
+            n.name AS name,
+            n.name_embedding AS name_embedding,
+            n.group_id AS group_id,
+            n.created_at AS created_at,
+            n.summary AS summary
+        """,
+            uuids=uuids,
+        )
+
+        nodes = [get_community_node_from_record(record) for record in records]
+
+        logger.info(f'Found Nodes: {uuids}')
+
+        return nodes
+
 
 # Node helpers
 def get_episodic_node_from_record(record: Any) -> EpisodicNode:
@@ -254,3 +363,14 @@ def get_entity_node_from_record(record: Any) -> EntityNode:
         created_at=record['created_at'].to_native(),
         summary=record['summary'],
     )
+
+
+def get_community_node_from_record(record: Any) -> CommunityNode:
+    return CommunityNode(
+        uuid=record['uuid'],
+        name=record['name'],
+        group_id=record['group_id'],
+        name_embedding=record['name_embedding'],
+        created_at=record['created_at'].to_native(),
+        summary=record['summary'],
+    )
diff --git a/graphiti_core/prompts/lib.py b/graphiti_core/prompts/lib.py
index e329d837..09cc3ec0 100644
--- a/graphiti_core/prompts/lib.py
+++ b/graphiti_core/prompts/lib.py
@@ -71,6 +71,9 @@ from .invalidate_edges import (
     versions as invalidate_edges_versions,
 )
 from .models import Message, PromptFunction
+from .summarize_nodes import Prompt as SummarizeNodesPrompt
+from .summarize_nodes import Versions as SummarizeNodesVersions
+from .summarize_nodes import versions as summarize_nodes_versions
 
 
 class PromptLibrary(Protocol):
@@ -80,6 +83,7 @@ class PromptLibrary(Protocol):
     dedupe_edges: DedupeEdgesPrompt
     invalidate_edges: InvalidateEdgesPrompt
     extract_edge_dates: ExtractEdgeDatesPrompt
+    summarize_nodes: SummarizeNodesPrompt
 
 
 class PromptLibraryImpl(TypedDict):
@@ -89,6 +93,7 @@ class PromptLibraryImpl(TypedDict):
     dedupe_edges: DedupeEdgesVersions
     invalidate_edges: InvalidateEdgesVersions
     extract_edge_dates: ExtractEdgeDatesVersions
+    summarize_nodes: SummarizeNodesVersions
 
 
 class VersionWrapper:
@@ -118,5 +123,6 @@ PROMPT_LIBRARY_IMPL: PromptLibraryImpl = {
     'dedupe_edges': dedupe_edges_versions,
     'invalidate_edges': invalidate_edges_versions,
     'extract_edge_dates': extract_edge_dates_versions,
+    'summarize_nodes': summarize_nodes_versions,
 }
 prompt_library: PromptLibrary = PromptLibraryWrapper(PROMPT_LIBRARY_IMPL)  # type: ignore[assignment]
diff --git a/graphiti_core/prompts/summarize_nodes.py b/graphiti_core/prompts/summarize_nodes.py
new file mode 100644
index 00000000..83702221
--- /dev/null
+++ b/graphiti_core/prompts/summarize_nodes.py
@@ -0,0 +1,81 @@
+"""
+Copyright 2024, Zep Software, Inc.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
+import json
+from typing import Any, Protocol, TypedDict
+
+from .models import Message, PromptFunction, PromptVersion
+
+
+class Prompt(Protocol):
+    summarize_pair: PromptVersion
+    summary_description: PromptVersion
+
+
+class Versions(TypedDict):
+    summarize_pair: PromptFunction
+    summary_description: PromptFunction
+
+
+def summarize_pair(context: dict[str, Any]) -> list[Message]:
+    """Build the prompt that merges the summaries in context['node_summaries'] into one."""
+    return [
+        Message(
+            role='system',
+            content='You are a helpful assistant that combines summaries.',
+        ),
+        Message(
+            role='user',
+            content=f"""
+        Synthesize the information from the following two summaries into a single succinct summary.
+
+        Summaries:
+        {json.dumps(context['node_summaries'], indent=2)}
+
+        Respond with a JSON object in the following format:
+            {{
+                "summary": "Summary containing the important information from both summaries"
+            }}
+        """,
+        ),
+    ]
+
+
+def summary_description(context: dict[str, Any]) -> list[Message]:
+    """Build the prompt asking for a one-sentence description of context['summary']."""
+    return [
+        Message(
+            role='system',
+            content='You are a helpful assistant that describes provided contents in a single sentence.',
+        ),
+        Message(
+            role='user',
+            content=f"""
+        Create a short one sentence description of the summary that explains what kind of information is summarized.
+
+        Summary:
+        {json.dumps(context['summary'], indent=2)}
+
+        Respond with a JSON object in the following format:
+            {{
+                "description": "One sentence description of the provided summary"
+            }}
+        """,
+        ),
+    ]
+
+
+versions: Versions = {'summarize_pair': summarize_pair, 'summary_description': summary_description}
diff --git a/graphiti_core/utils/maintenance/community_operations.py b/graphiti_core/utils/maintenance/community_operations.py
new file mode 100644
index 00000000..868342ad
--- /dev/null
+++ b/graphiti_core/utils/maintenance/community_operations.py
@@ -0,0 +1,171 @@
+import asyncio
+import logging
+from collections import defaultdict
+from datetime import datetime
+
+from neo4j import AsyncDriver
+
+from graphiti_core.edges import CommunityEdge
+from graphiti_core.llm_client import LLMClient
+from graphiti_core.nodes import CommunityNode, EntityNode
+from graphiti_core.prompts import prompt_library
+from graphiti_core.utils.maintenance.edge_operations import build_community_edges
+
+logger = logging.getLogger(__name__)
+
+
+async def build_community_projection(driver: AsyncDriver) -> str:
+    """Create the GDS projection "communities" over Entity/RELATES_TO and return its name."""
+    records, _, _ = await driver.execute_query("""
+    CALL gds.graph.project("communities", "Entity",
+            {RELATES_TO: {
+                type: "RELATES_TO",
+                orientation: "UNDIRECTED",
+                properties: {weight: {property: "*", aggregation: "COUNT"}}
+            }}
+    )
+    YIELD graphName AS graph, nodeProjection AS nodes, relationshipProjection AS edges
+    """)
+
+    return records[0]['graph']
+
+
+async def destroy_projection(driver: AsyncDriver, projection_name: str):
+    """Drop the GDS graph projection called `projection_name`."""
+    await driver.execute_query(
+        """
+    CALL gds.graph.drop($projection_name)
+    """,
+        projection_name=projection_name,
+    )
+
+
+async def get_community_clusters(
+    driver: AsyncDriver, projection_name: str
+) -> list[list[EntityNode]]:
+    """Run Leiden over `projection_name` and return the entities grouped by community id."""
+    # Stream from the projection passed in rather than a hard-coded graph name.
+    records, _, _ = await driver.execute_query(
+        """
+    CALL gds.leiden.stream($projection_name)
+    YIELD nodeId, communityId
+    RETURN gds.util.asNode(nodeId).uuid AS entity_uuid, communityId
+    """,
+        projection_name=projection_name,
+    )
+    community_map: dict[int, list[str]] = defaultdict(list)
+    for record in records:
+        community_map[record['communityId']].append(record['entity_uuid'])
+
+    community_clusters: list[list[EntityNode]] = list(
+        await asyncio.gather(
+            *[EntityNode.get_by_uuids(driver, cluster) for cluster in community_map.values()]
+        )
+    )
+
+    return community_clusters
+
+
+async def summarize_pair(llm_client: LLMClient, summary_pair: tuple[str, str]) -> str:
+    """Ask the LLM to merge two summaries; returns '' if the response has no 'summary' key."""
+    # Prepare context for LLM
+    context = {'node_summaries': [{'summary': summary} for summary in summary_pair]}
+
+    llm_response = await llm_client.generate_response(
+        prompt_library.summarize_nodes.summarize_pair(context)
+    )
+
+    pair_summary = llm_response.get('summary', '')
+
+    return pair_summary
+
+
+async def generate_summary_description(llm_client: LLMClient, summary: str) -> str:
+    """Ask the LLM for a one-sentence description of `summary`; '' if none is returned."""
+    context = {'summary': summary}
+
+    llm_response = await llm_client.generate_response(
+        prompt_library.summarize_nodes.summary_description(context)
+    )
+
+    description = llm_response.get('description', '')
+
+    return description
+
+
+async def build_community(
+    llm_client: LLMClient, community_cluster: list[EntityNode]
+) -> tuple[CommunityNode, list[CommunityEdge]]:
+    """Build one CommunityNode and its member edges by pairwise-merging member summaries."""
+    summaries = [entity.summary for entity in community_cluster]
+    length = len(summaries)
+    # Halve the list each round: merge first-half/second-half pairs, carrying any odd item over.
+    while length > 1:
+        odd_one_out: str | None = None
+        if length % 2 == 1:
+            odd_one_out = summaries.pop()
+            length -= 1
+        new_summaries: list[str] = list(
+            await asyncio.gather(
+                *[
+                    summarize_pair(llm_client, (str(left_summary), str(right_summary)))
+                    for left_summary, right_summary in zip(
+                        summaries[: int(length / 2)], summaries[int(length / 2) :]
+                    )
+                ]
+            )
+        )
+        if odd_one_out is not None:
+            new_summaries.append(odd_one_out)
+        summaries = new_summaries
+        length = len(summaries)
+
+    summary = summaries[0]
+    name = await generate_summary_description(llm_client, summary)
+    now = datetime.now()
+    community_node = CommunityNode(
+        name=name,
+        group_id=community_cluster[0].group_id,
+        labels=['Community'],
+        created_at=now,
+        summary=summary,
+    )
+    community_edges = build_community_edges(community_cluster, community_node, now)
+
+    logger.info((community_node, community_edges))
+
+    return community_node, community_edges
+
+
+async def build_communities(
+    driver: AsyncDriver, llm_client: LLMClient
+) -> tuple[list[CommunityNode], list[CommunityEdge]]:
+    """Cluster entities with Leiden and build a community node and member edges per cluster."""
+    projection = await build_community_projection(driver)
+    try:
+        community_clusters = await get_community_clusters(driver, projection)
+
+        communities: list[tuple[CommunityNode, list[CommunityEdge]]] = list(
+            await asyncio.gather(
+                *[build_community(llm_client, cluster) for cluster in community_clusters]
+            )
+        )
+    finally:
+        # Always drop the projection so a failed run does not leave "communities" behind.
+        await destroy_projection(driver, projection)
+
+    community_nodes: list[CommunityNode] = []
+    community_edges: list[CommunityEdge] = []
+    for community in communities:
+        community_nodes.append(community[0])
+        community_edges.extend(community[1])
+
+    return community_nodes, community_edges
+
+
+async def remove_communities(driver: AsyncDriver):
+    """Detach-delete every node labelled Community."""
+    await driver.execute_query("""
+    MATCH (c:Community)
+    DETACH DELETE c
+    """)
diff --git a/graphiti_core/utils/maintenance/edge_operations.py b/graphiti_core/utils/maintenance/edge_operations.py
index 4518c8da..83334e73 100644
--- a/graphiti_core/utils/maintenance/edge_operations.py
+++ b/graphiti_core/utils/maintenance/edge_operations.py
@@ -20,9 +20,9 @@ from datetime import datetime
 from time import time
 from typing import List
 
-from graphiti_core.edges import EntityEdge, EpisodicEdge
+from graphiti_core.edges import CommunityEdge, EntityEdge, EpisodicEdge
 from graphiti_core.llm_client import LLMClient
-from graphiti_core.nodes import EntityNode, EpisodicNode
+from graphiti_core.nodes import CommunityNode, EntityNode, EpisodicNode
 from graphiti_core.prompts import prompt_library
 from graphiti_core.utils.maintenance.temporal_operations import (
     extract_edge_dates,
@@ -50,6 +50,24 @@ def build_episodic_edges(
     return edges
 
 
+def build_community_edges(
+    entity_nodes: List[EntityNode],
+    community_node: CommunityNode,
+    created_at: datetime,
+) -> List[CommunityEdge]:
+    edges: List[CommunityEdge] = [
+        CommunityEdge(
+            source_node_uuid=community_node.uuid,
+            target_node_uuid=node.uuid,
+            created_at=created_at,
+            group_id=community_node.group_id,
+        )
+        for node in entity_nodes
+    ]
+
+    return edges
+
+
 async def extract_edges(
     llm_client: LLMClient,
     episode: EpisodicNode,
diff --git a/graphiti_core/utils/maintenance/graph_data_operations.py b/graphiti_core/utils/maintenance/graph_data_operations.py
index a942a00b..446cb889 100644
--- a/graphiti_core/utils/maintenance/graph_data_operations.py
+++ b/graphiti_core/utils/maintenance/graph_data_operations.py
@@ -32,8 +32,10 @@ async def build_indices_and_constraints(driver: AsyncDriver):
     range_indices: list[LiteralString] = [
         'CREATE INDEX entity_uuid IF NOT EXISTS FOR (n:Entity) ON (n.uuid)',
         'CREATE INDEX episode_uuid IF NOT EXISTS FOR (n:Episodic) ON (n.uuid)',
+        'CREATE INDEX community_uuid IF NOT EXISTS FOR (n:Community) ON (n.uuid)',
         'CREATE INDEX relation_uuid IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.uuid)',
         'CREATE INDEX mention_uuid IF NOT EXISTS FOR ()-[e:MENTIONS]-() ON (e.uuid)',
+        'CREATE INDEX has_member_uuid IF NOT EXISTS FOR ()-[e:HAS_MEMBER]-() ON (e.uuid)',
         'CREATE INDEX entity_group_id IF NOT EXISTS FOR (n:Entity) ON (n.group_id)',
         'CREATE INDEX episode_group_id IF NOT EXISTS FOR (n:Episodic) ON (n.group_id)',
         'CREATE INDEX relation_group_id IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.group_id)',
@@ -51,6 +53,7 @@ async def build_indices_and_constraints(driver: AsyncDriver):
 
     fulltext_indices: list[LiteralString] = [
         'CREATE FULLTEXT INDEX name_and_summary IF NOT EXISTS FOR (n:Entity) ON EACH [n.name, n.summary]',
+        'CREATE FULLTEXT INDEX community_name IF NOT EXISTS FOR (n:Community) ON EACH [n.name]',
         'CREATE FULLTEXT INDEX name_and_fact IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON EACH [e.name, e.fact]',
     ]
 
@@ -71,6 +74,14 @@ async def build_indices_and_constraints(driver: AsyncDriver):
             `vector.similarity_function`: 'cosine'
         }}
         """,
+        """
+        CREATE VECTOR INDEX community_name_embedding IF NOT EXISTS
+        FOR (n:Community) ON (n.name_embedding)
+        OPTIONS {indexConfig: {
+            `vector.dimensions`: 1024,
+            `vector.similarity_function`: 'cosine'
+        }}
+        """,
     ]
     index_queries: list[LiteralString] = range_indices + fulltext_indices + vector_indices
 
diff --git a/tests/test_graphiti_int.py b/tests/test_graphiti_int.py
index 2c2ebc35..e73500f1 100644
--- a/tests/test_graphiti_int.py
+++ b/tests/test_graphiti_int.py
@@ -73,6 +73,7 @@ def format_context(facts):
 async def test_graphiti_init():
     logger = setup_logging()
     graphiti = Graphiti(NEO4J_URI, NEO4j_USER, NEO4j_PASSWORD)
+    await graphiti.build_communities()
 
     edges = await graphiti.search('Freakenomics guest', group_ids=['1'])
 