Community nodes (#103)
* add gds * community work * save progress * community updates * e2e communities * troubleshooting * updates * communities * remove unused import
This commit is contained in:
parent
4122d350a5
commit
c0a740ff60
10 changed files with 502 additions and 59 deletions
|
|
@ -84,7 +84,7 @@ async def main(use_bulk: bool = True):
|
||||||
for i, message in enumerate(messages[3:20])
|
for i, message in enumerate(messages[3:20])
|
||||||
]
|
]
|
||||||
|
|
||||||
await client.add_episode_bulk(episodes)
|
await client.add_episode_bulk(episodes, None)
|
||||||
|
|
||||||
|
|
||||||
asyncio.run(main(False))
|
asyncio.run(main(False))
|
||||||
|
|
|
||||||
|
|
@ -41,8 +41,18 @@ class Edge(BaseModel, ABC):
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
async def save(self, driver: AsyncDriver): ...
|
async def save(self, driver: AsyncDriver): ...
|
||||||
|
|
||||||
@abstractmethod
|
async def delete(self, driver: AsyncDriver):
|
||||||
async def delete(self, driver: AsyncDriver): ...
|
result = await driver.execute_query(
|
||||||
|
"""
|
||||||
|
MATCH (n)-[e {uuid: $uuid}]->(m)
|
||||||
|
DELETE e
|
||||||
|
""",
|
||||||
|
uuid=self.uuid,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(f'Deleted Edge: {self.uuid}')
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
def __hash__(self):
|
def __hash__(self):
|
||||||
return hash(self.uuid)
|
return hash(self.uuid)
|
||||||
|
|
@ -76,19 +86,6 @@ class EpisodicEdge(Edge):
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
async def delete(self, driver: AsyncDriver):
|
|
||||||
result = await driver.execute_query(
|
|
||||||
"""
|
|
||||||
MATCH (n:Episodic)-[e:MENTIONS {uuid: $uuid}]->(m:Entity)
|
|
||||||
DELETE e
|
|
||||||
""",
|
|
||||||
uuid=self.uuid,
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.info(f'Deleted Edge: {self.uuid}')
|
|
||||||
|
|
||||||
return result
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def get_by_uuid(cls, driver: AsyncDriver, uuid: str):
|
async def get_by_uuid(cls, driver: AsyncDriver, uuid: str):
|
||||||
records, _, _ = await driver.execute_query(
|
records, _, _ = await driver.execute_query(
|
||||||
|
|
@ -169,19 +166,6 @@ class EntityEdge(Edge):
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
async def delete(self, driver: AsyncDriver):
|
|
||||||
result = await driver.execute_query(
|
|
||||||
"""
|
|
||||||
MATCH (n:Entity)-[e:RELATES_TO {uuid: $uuid}]->(m:Entity)
|
|
||||||
DELETE e
|
|
||||||
""",
|
|
||||||
uuid=self.uuid,
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.info(f'Deleted Edge: {self.uuid}')
|
|
||||||
|
|
||||||
return result
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def get_by_uuid(cls, driver: AsyncDriver, uuid: str):
|
async def get_by_uuid(cls, driver: AsyncDriver, uuid: str):
|
||||||
records, _, _ = await driver.execute_query(
|
records, _, _ = await driver.execute_query(
|
||||||
|
|
@ -211,6 +195,48 @@ class EntityEdge(Edge):
|
||||||
return edges[0]
|
return edges[0]
|
||||||
|
|
||||||
|
|
||||||
|
class CommunityEdge(Edge):
|
||||||
|
async def save(self, driver: AsyncDriver):
|
||||||
|
result = await driver.execute_query(
|
||||||
|
"""
|
||||||
|
MATCH (community:Community {uuid: $community_uuid})
|
||||||
|
MATCH (node:Entity | Community {uuid: $entity_uuid})
|
||||||
|
MERGE (community)-[r:HAS_MEMBER {uuid: $uuid}]->(node)
|
||||||
|
SET r = {uuid: $uuid, group_id: $group_id, created_at: $created_at}
|
||||||
|
RETURN r.uuid AS uuid""",
|
||||||
|
community_uuid=self.source_node_uuid,
|
||||||
|
entity_uuid=self.target_node_uuid,
|
||||||
|
uuid=self.uuid,
|
||||||
|
group_id=self.group_id,
|
||||||
|
created_at=self.created_at,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(f'Saved edge to neo4j: {self.uuid}')
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def get_by_uuid(cls, driver: AsyncDriver, uuid: str):
|
||||||
|
records, _, _ = await driver.execute_query(
|
||||||
|
"""
|
||||||
|
MATCH (n:Community)-[e:HAS_MEMBER {uuid: $uuid}]->(m:Entity | Community)
|
||||||
|
RETURN
|
||||||
|
e.uuid As uuid,
|
||||||
|
e.group_id AS group_id,
|
||||||
|
n.uuid AS source_node_uuid,
|
||||||
|
m.uuid AS target_node_uuid,
|
||||||
|
e.created_at AS created_at
|
||||||
|
""",
|
||||||
|
uuid=uuid,
|
||||||
|
)
|
||||||
|
|
||||||
|
edges = [get_community_edge_from_record(record) for record in records]
|
||||||
|
|
||||||
|
logger.info(f'Found Edge: {uuid}')
|
||||||
|
|
||||||
|
return edges[0]
|
||||||
|
|
||||||
|
|
||||||
# Edge helpers
|
# Edge helpers
|
||||||
def get_episodic_edge_from_record(record: Any) -> EpisodicEdge:
|
def get_episodic_edge_from_record(record: Any) -> EpisodicEdge:
|
||||||
return EpisodicEdge(
|
return EpisodicEdge(
|
||||||
|
|
@ -237,3 +263,13 @@ def get_entity_edge_from_record(record: Any) -> EntityEdge:
|
||||||
valid_at=parse_db_date(record['valid_at']),
|
valid_at=parse_db_date(record['valid_at']),
|
||||||
invalid_at=parse_db_date(record['invalid_at']),
|
invalid_at=parse_db_date(record['invalid_at']),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def get_community_edge_from_record(record: Any):
|
||||||
|
return CommunityEdge(
|
||||||
|
uuid=record['uuid'],
|
||||||
|
group_id=record['group_id'],
|
||||||
|
source_node_uuid=record['source_node_uuid'],
|
||||||
|
target_node_uuid=record['target_node_uuid'],
|
||||||
|
created_at=record['created_at'].to_native(),
|
||||||
|
)
|
||||||
|
|
|
||||||
|
|
@ -46,6 +46,10 @@ from graphiti_core.utils.bulk_utils import (
|
||||||
resolve_edge_pointers,
|
resolve_edge_pointers,
|
||||||
retrieve_previous_episodes_bulk,
|
retrieve_previous_episodes_bulk,
|
||||||
)
|
)
|
||||||
|
from graphiti_core.utils.maintenance.community_operations import (
|
||||||
|
build_communities,
|
||||||
|
remove_communities,
|
||||||
|
)
|
||||||
from graphiti_core.utils.maintenance.edge_operations import (
|
from graphiti_core.utils.maintenance.edge_operations import (
|
||||||
extract_edges,
|
extract_edges,
|
||||||
resolve_extracted_edges,
|
resolve_extracted_edges,
|
||||||
|
|
@ -526,6 +530,19 @@ class Graphiti:
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
|
async def build_communities(self):
|
||||||
|
embedder = self.llm_client.get_embedder()
|
||||||
|
|
||||||
|
# Clear existing communities
|
||||||
|
await remove_communities(self.driver)
|
||||||
|
|
||||||
|
community_nodes, community_edges = await build_communities(self.driver, self.llm_client)
|
||||||
|
|
||||||
|
await asyncio.gather(*[node.generate_name_embedding(embedder) for node in community_nodes])
|
||||||
|
|
||||||
|
await asyncio.gather(*[node.save(self.driver) for node in community_nodes])
|
||||||
|
await asyncio.gather(*[edge.save(self.driver) for edge in community_edges])
|
||||||
|
|
||||||
async def search(
|
async def search(
|
||||||
self,
|
self,
|
||||||
query: str,
|
query: str,
|
||||||
|
|
|
||||||
|
|
@ -76,8 +76,18 @@ class Node(BaseModel, ABC):
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
async def save(self, driver: AsyncDriver): ...
|
async def save(self, driver: AsyncDriver): ...
|
||||||
|
|
||||||
@abstractmethod
|
async def delete(self, driver: AsyncDriver):
|
||||||
async def delete(self, driver: AsyncDriver): ...
|
result = await driver.execute_query(
|
||||||
|
"""
|
||||||
|
MATCH (n {uuid: $uuid})
|
||||||
|
DETACH DELETE n
|
||||||
|
""",
|
||||||
|
uuid=self.uuid,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(f'Deleted Node: {self.uuid}')
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
def __hash__(self):
|
def __hash__(self):
|
||||||
return hash(self.uuid)
|
return hash(self.uuid)
|
||||||
|
|
@ -90,6 +100,9 @@ class Node(BaseModel, ABC):
|
||||||
@classmethod
|
@classmethod
|
||||||
async def get_by_uuid(cls, driver: AsyncDriver, uuid: str): ...
|
async def get_by_uuid(cls, driver: AsyncDriver, uuid: str): ...
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def get_by_uuids(cls, driver: AsyncDriver, uuids: list[str]): ...
|
||||||
|
|
||||||
|
|
||||||
class EpisodicNode(Node):
|
class EpisodicNode(Node):
|
||||||
source: EpisodeType = Field(description='source type')
|
source: EpisodeType = Field(description='source type')
|
||||||
|
|
@ -125,19 +138,6 @@ class EpisodicNode(Node):
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
async def delete(self, driver: AsyncDriver):
|
|
||||||
result = await driver.execute_query(
|
|
||||||
"""
|
|
||||||
MATCH (n:Episodic {uuid: $uuid})
|
|
||||||
DETACH DELETE n
|
|
||||||
""",
|
|
||||||
uuid=self.uuid,
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.info(f'Deleted Node: {self.uuid}')
|
|
||||||
|
|
||||||
return result
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def get_by_uuid(cls, driver: AsyncDriver, uuid: str):
|
async def get_by_uuid(cls, driver: AsyncDriver, uuid: str):
|
||||||
records, _, _ = await driver.execute_query(
|
records, _, _ = await driver.execute_query(
|
||||||
|
|
@ -161,6 +161,29 @@ class EpisodicNode(Node):
|
||||||
|
|
||||||
return episodes[0]
|
return episodes[0]
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def get_by_uuids(cls, driver: AsyncDriver, uuids: list[str]):
|
||||||
|
records, _, _ = await driver.execute_query(
|
||||||
|
"""
|
||||||
|
MATCH (e:Episodic) WHERE e.uuid IN $uuids
|
||||||
|
RETURN e.content AS content,
|
||||||
|
e.created_at AS created_at,
|
||||||
|
e.valid_at AS valid_at,
|
||||||
|
e.uuid AS uuid,
|
||||||
|
e.name AS name,
|
||||||
|
e.group_id AS group_id
|
||||||
|
e.source_description AS source_description,
|
||||||
|
e.source AS source
|
||||||
|
""",
|
||||||
|
uuids=uuids,
|
||||||
|
)
|
||||||
|
|
||||||
|
episodes = [get_episodic_node_from_record(record) for record in records]
|
||||||
|
|
||||||
|
logger.info(f'Found Nodes: {uuids}')
|
||||||
|
|
||||||
|
return episodes
|
||||||
|
|
||||||
|
|
||||||
class EntityNode(Node):
|
class EntityNode(Node):
|
||||||
name_embedding: list[float] | None = Field(default=None, description='embedding of the name')
|
name_embedding: list[float] | None = Field(default=None, description='embedding of the name')
|
||||||
|
|
@ -194,19 +217,6 @@ class EntityNode(Node):
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
async def delete(self, driver: AsyncDriver):
|
|
||||||
result = await driver.execute_query(
|
|
||||||
"""
|
|
||||||
MATCH (n:Entity {uuid: $uuid})
|
|
||||||
DETACH DELETE n
|
|
||||||
""",
|
|
||||||
uuid=self.uuid,
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.info(f'Deleted Node: {self.uuid}')
|
|
||||||
|
|
||||||
return result
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def get_by_uuid(cls, driver: AsyncDriver, uuid: str):
|
async def get_by_uuid(cls, driver: AsyncDriver, uuid: str):
|
||||||
records, _, _ = await driver.execute_query(
|
records, _, _ = await driver.execute_query(
|
||||||
|
|
@ -229,6 +239,105 @@ class EntityNode(Node):
|
||||||
|
|
||||||
return nodes[0]
|
return nodes[0]
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def get_by_uuids(cls, driver: AsyncDriver, uuids: list[str]):
|
||||||
|
records, _, _ = await driver.execute_query(
|
||||||
|
"""
|
||||||
|
MATCH (n:Entity) WHERE n.uuid IN $uuids
|
||||||
|
RETURN
|
||||||
|
n.uuid As uuid,
|
||||||
|
n.name AS name,
|
||||||
|
n.name_embedding AS name_embedding,
|
||||||
|
n.group_id AS group_id,
|
||||||
|
n.created_at AS created_at,
|
||||||
|
n.summary AS summary
|
||||||
|
""",
|
||||||
|
uuids=uuids,
|
||||||
|
)
|
||||||
|
|
||||||
|
nodes = [get_entity_node_from_record(record) for record in records]
|
||||||
|
|
||||||
|
logger.info(f'Found Nodes: {uuids}')
|
||||||
|
|
||||||
|
return nodes
|
||||||
|
|
||||||
|
|
||||||
|
class CommunityNode(Node):
|
||||||
|
name_embedding: list[float] | None = Field(default=None, description='embedding of the name')
|
||||||
|
summary: str = Field(description='region summary of member nodes', default_factory=str)
|
||||||
|
|
||||||
|
async def save(self, driver: AsyncDriver):
|
||||||
|
result = await driver.execute_query(
|
||||||
|
"""
|
||||||
|
MERGE (n:Community {uuid: $uuid})
|
||||||
|
SET n = {uuid: $uuid, name: $name, name_embedding: $name_embedding, group_id: $group_id, summary: $summary, created_at: $created_at}
|
||||||
|
RETURN n.uuid AS uuid""",
|
||||||
|
uuid=self.uuid,
|
||||||
|
name=self.name,
|
||||||
|
group_id=self.group_id,
|
||||||
|
summary=self.summary,
|
||||||
|
name_embedding=self.name_embedding,
|
||||||
|
created_at=self.created_at,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(f'Saved Node to neo4j: {self.uuid}')
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
async def generate_name_embedding(self, embedder, model='text-embedding-3-small'):
|
||||||
|
start = time()
|
||||||
|
text = self.name.replace('\n', ' ')
|
||||||
|
embedding = (await embedder.create(input=[text], model=model)).data[0].embedding
|
||||||
|
self.name_embedding = embedding[:EMBEDDING_DIM]
|
||||||
|
end = time()
|
||||||
|
logger.info(f'embedded {text} in {end - start} ms')
|
||||||
|
|
||||||
|
return embedding
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def get_by_uuid(cls, driver: AsyncDriver, uuid: str):
|
||||||
|
records, _, _ = await driver.execute_query(
|
||||||
|
"""
|
||||||
|
MATCH (n:Community {uuid: $uuid})
|
||||||
|
RETURN
|
||||||
|
n.uuid As uuid,
|
||||||
|
n.name AS name,
|
||||||
|
n.name_embedding AS name_embedding,
|
||||||
|
n.group_id AS group_id
|
||||||
|
n.created_at AS created_at,
|
||||||
|
n.summary AS summary
|
||||||
|
""",
|
||||||
|
uuid=uuid,
|
||||||
|
)
|
||||||
|
|
||||||
|
nodes = [get_community_node_from_record(record) for record in records]
|
||||||
|
|
||||||
|
logger.info(f'Found Node: {uuid}')
|
||||||
|
|
||||||
|
return nodes[0]
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def get_by_uuids(cls, driver: AsyncDriver, uuids: list[str]):
|
||||||
|
records, _, _ = await driver.execute_query(
|
||||||
|
"""
|
||||||
|
MATCH (n:Community) WHERE n.uuid IN $uuids
|
||||||
|
RETURN
|
||||||
|
n.uuid As uuid,
|
||||||
|
n.name AS name,
|
||||||
|
n.name_embedding AS name_embedding,
|
||||||
|
n.group_id AS group_id
|
||||||
|
n.created_at AS created_at,
|
||||||
|
n.summary AS summary
|
||||||
|
""",
|
||||||
|
uuids=uuids,
|
||||||
|
)
|
||||||
|
|
||||||
|
nodes = [get_community_node_from_record(record) for record in records]
|
||||||
|
|
||||||
|
logger.info(f'Found Nodes: {uuids}')
|
||||||
|
|
||||||
|
return nodes
|
||||||
|
|
||||||
|
|
||||||
# Node helpers
|
# Node helpers
|
||||||
def get_episodic_node_from_record(record: Any) -> EpisodicNode:
|
def get_episodic_node_from_record(record: Any) -> EpisodicNode:
|
||||||
|
|
@ -254,3 +363,14 @@ def get_entity_node_from_record(record: Any) -> EntityNode:
|
||||||
created_at=record['created_at'].to_native(),
|
created_at=record['created_at'].to_native(),
|
||||||
summary=record['summary'],
|
summary=record['summary'],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def get_community_node_from_record(record: Any) -> CommunityNode:
|
||||||
|
return CommunityNode(
|
||||||
|
uuid=record['uuid'],
|
||||||
|
name=record['name'],
|
||||||
|
group_id=record['group_id'],
|
||||||
|
name_embedding=record['name_embedding'],
|
||||||
|
created_at=record['created_at'].to_native(),
|
||||||
|
summary=record['summary'],
|
||||||
|
)
|
||||||
|
|
|
||||||
|
|
@ -71,6 +71,9 @@ from .invalidate_edges import (
|
||||||
versions as invalidate_edges_versions,
|
versions as invalidate_edges_versions,
|
||||||
)
|
)
|
||||||
from .models import Message, PromptFunction
|
from .models import Message, PromptFunction
|
||||||
|
from .summarize_nodes import Prompt as SummarizeNodesPrompt
|
||||||
|
from .summarize_nodes import Versions as SummarizeNodesVersions
|
||||||
|
from .summarize_nodes import versions as summarize_nodes_versions
|
||||||
|
|
||||||
|
|
||||||
class PromptLibrary(Protocol):
|
class PromptLibrary(Protocol):
|
||||||
|
|
@ -80,6 +83,7 @@ class PromptLibrary(Protocol):
|
||||||
dedupe_edges: DedupeEdgesPrompt
|
dedupe_edges: DedupeEdgesPrompt
|
||||||
invalidate_edges: InvalidateEdgesPrompt
|
invalidate_edges: InvalidateEdgesPrompt
|
||||||
extract_edge_dates: ExtractEdgeDatesPrompt
|
extract_edge_dates: ExtractEdgeDatesPrompt
|
||||||
|
summarize_nodes: SummarizeNodesPrompt
|
||||||
|
|
||||||
|
|
||||||
class PromptLibraryImpl(TypedDict):
|
class PromptLibraryImpl(TypedDict):
|
||||||
|
|
@ -89,6 +93,7 @@ class PromptLibraryImpl(TypedDict):
|
||||||
dedupe_edges: DedupeEdgesVersions
|
dedupe_edges: DedupeEdgesVersions
|
||||||
invalidate_edges: InvalidateEdgesVersions
|
invalidate_edges: InvalidateEdgesVersions
|
||||||
extract_edge_dates: ExtractEdgeDatesVersions
|
extract_edge_dates: ExtractEdgeDatesVersions
|
||||||
|
summarize_nodes: SummarizeNodesVersions
|
||||||
|
|
||||||
|
|
||||||
class VersionWrapper:
|
class VersionWrapper:
|
||||||
|
|
@ -118,5 +123,6 @@ PROMPT_LIBRARY_IMPL: PromptLibraryImpl = {
|
||||||
'dedupe_edges': dedupe_edges_versions,
|
'dedupe_edges': dedupe_edges_versions,
|
||||||
'invalidate_edges': invalidate_edges_versions,
|
'invalidate_edges': invalidate_edges_versions,
|
||||||
'extract_edge_dates': extract_edge_dates_versions,
|
'extract_edge_dates': extract_edge_dates_versions,
|
||||||
|
'summarize_nodes': summarize_nodes_versions,
|
||||||
}
|
}
|
||||||
prompt_library: PromptLibrary = PromptLibraryWrapper(PROMPT_LIBRARY_IMPL) # type: ignore[assignment]
|
prompt_library: PromptLibrary = PromptLibraryWrapper(PROMPT_LIBRARY_IMPL) # type: ignore[assignment]
|
||||||
|
|
|
||||||
79
graphiti_core/prompts/summarize_nodes.py
Normal file
79
graphiti_core/prompts/summarize_nodes.py
Normal file
|
|
@ -0,0 +1,79 @@
|
||||||
|
"""
|
||||||
|
Copyright 2024, Zep Software, Inc.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
from typing import Any, Protocol, TypedDict
|
||||||
|
|
||||||
|
from .models import Message, PromptFunction, PromptVersion
|
||||||
|
|
||||||
|
|
||||||
|
class Prompt(Protocol):
|
||||||
|
summarize_pair: PromptVersion
|
||||||
|
summary_description: PromptVersion
|
||||||
|
|
||||||
|
|
||||||
|
class Versions(TypedDict):
|
||||||
|
summarize_pair: PromptFunction
|
||||||
|
summary_description: PromptFunction
|
||||||
|
|
||||||
|
|
||||||
|
def summarize_pair(context: dict[str, Any]) -> list[Message]:
|
||||||
|
return [
|
||||||
|
Message(
|
||||||
|
role='system',
|
||||||
|
content='You are a helpful assistant that combines summaries.',
|
||||||
|
),
|
||||||
|
Message(
|
||||||
|
role='user',
|
||||||
|
content=f"""
|
||||||
|
Synthesize the information from the following two summaries into a single succinct summary.
|
||||||
|
|
||||||
|
Summaries:
|
||||||
|
{json.dumps(context['node_summaries'], indent=2)}
|
||||||
|
|
||||||
|
Respond with a JSON object in the following format:
|
||||||
|
{{
|
||||||
|
"summary": "Summary containing the important information from both summaries"
|
||||||
|
}}
|
||||||
|
""",
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def summary_description(context: dict[str, Any]) -> list[Message]:
|
||||||
|
return [
|
||||||
|
Message(
|
||||||
|
role='system',
|
||||||
|
content='You are a helpful assistant that describes provided contents in a single sentence.',
|
||||||
|
),
|
||||||
|
Message(
|
||||||
|
role='user',
|
||||||
|
content=f"""
|
||||||
|
Create a short one sentence description of the summary that explains what kind of information is summarized.
|
||||||
|
|
||||||
|
Summary:
|
||||||
|
{json.dumps(context['summary'], indent=2)}
|
||||||
|
|
||||||
|
Respond with a JSON object in the following format:
|
||||||
|
{{
|
||||||
|
"description": "One sentence description of the provided summary"
|
||||||
|
}}
|
||||||
|
""",
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
versions: Versions = {'summarize_pair': summarize_pair, 'summary_description': summary_description}
|
||||||
155
graphiti_core/utils/maintenance/community_operations.py
Normal file
155
graphiti_core/utils/maintenance/community_operations.py
Normal file
|
|
@ -0,0 +1,155 @@
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
from collections import defaultdict
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
from neo4j import AsyncDriver
|
||||||
|
|
||||||
|
from graphiti_core.edges import CommunityEdge
|
||||||
|
from graphiti_core.llm_client import LLMClient
|
||||||
|
from graphiti_core.nodes import CommunityNode, EntityNode
|
||||||
|
from graphiti_core.prompts import prompt_library
|
||||||
|
from graphiti_core.utils.maintenance.edge_operations import build_community_edges
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
async def build_community_projection(driver: AsyncDriver) -> str:
|
||||||
|
records, _, _ = await driver.execute_query("""
|
||||||
|
CALL gds.graph.project("communities", "Entity",
|
||||||
|
{RELATES_TO: {
|
||||||
|
type: "RELATES_TO",
|
||||||
|
orientation: "UNDIRECTED",
|
||||||
|
properties: {weight: {property: "*", aggregation: "COUNT"}}
|
||||||
|
}}
|
||||||
|
)
|
||||||
|
YIELD graphName AS graph, nodeProjection AS nodes, relationshipProjection AS edges
|
||||||
|
""")
|
||||||
|
|
||||||
|
return records[0]['graph']
|
||||||
|
|
||||||
|
|
||||||
|
async def destroy_projection(driver: AsyncDriver, projection_name: str):
|
||||||
|
await driver.execute_query(
|
||||||
|
"""
|
||||||
|
CALL gds.graph.drop($projection_name)
|
||||||
|
""",
|
||||||
|
projection_name=projection_name,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def get_community_clusters(
|
||||||
|
driver: AsyncDriver, projection_name: str
|
||||||
|
) -> list[list[EntityNode]]:
|
||||||
|
records, _, _ = await driver.execute_query("""
|
||||||
|
CALL gds.leiden.stream("communities")
|
||||||
|
YIELD nodeId, communityId
|
||||||
|
RETURN gds.util.asNode(nodeId).uuid AS entity_uuid, communityId
|
||||||
|
""")
|
||||||
|
community_map: dict[int, list[str]] = defaultdict(list)
|
||||||
|
for record in records:
|
||||||
|
community_map[record['communityId']].append(record['entity_uuid'])
|
||||||
|
|
||||||
|
community_clusters: list[list[EntityNode]] = list(
|
||||||
|
await asyncio.gather(
|
||||||
|
*[EntityNode.get_by_uuids(driver, cluster) for cluster in community_map.values()]
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
return community_clusters
|
||||||
|
|
||||||
|
|
||||||
|
async def summarize_pair(llm_client: LLMClient, summary_pair: tuple[str, str]) -> str:
|
||||||
|
# Prepare context for LLM
|
||||||
|
context = {'node_summaries': [{'summary': summary} for summary in summary_pair]}
|
||||||
|
|
||||||
|
llm_response = await llm_client.generate_response(
|
||||||
|
prompt_library.summarize_nodes.summarize_pair(context)
|
||||||
|
)
|
||||||
|
|
||||||
|
pair_summary = llm_response.get('summary', '')
|
||||||
|
|
||||||
|
return pair_summary
|
||||||
|
|
||||||
|
|
||||||
|
async def generate_summary_description(llm_client: LLMClient, summary: str) -> str:
|
||||||
|
context = {'summary': summary}
|
||||||
|
|
||||||
|
llm_response = await llm_client.generate_response(
|
||||||
|
prompt_library.summarize_nodes.summary_description(context)
|
||||||
|
)
|
||||||
|
|
||||||
|
description = llm_response.get('description', '')
|
||||||
|
|
||||||
|
return description
|
||||||
|
|
||||||
|
|
||||||
|
async def build_community(
|
||||||
|
llm_client: LLMClient, community_cluster: list[EntityNode]
|
||||||
|
) -> tuple[CommunityNode, list[CommunityEdge]]:
|
||||||
|
summaries = [entity.summary for entity in community_cluster]
|
||||||
|
length = len(summaries)
|
||||||
|
while length > 1:
|
||||||
|
odd_one_out: str | None = None
|
||||||
|
if length % 2 == 1:
|
||||||
|
odd_one_out = summaries.pop()
|
||||||
|
length -= 1
|
||||||
|
new_summaries: list[str] = list(
|
||||||
|
await asyncio.gather(
|
||||||
|
*[
|
||||||
|
summarize_pair(llm_client, (str(left_summary), str(right_summary)))
|
||||||
|
for left_summary, right_summary in zip(
|
||||||
|
summaries[: int(length / 2)], summaries[int(length / 2) :]
|
||||||
|
)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
)
|
||||||
|
if odd_one_out is not None:
|
||||||
|
new_summaries.append(odd_one_out)
|
||||||
|
summaries = new_summaries
|
||||||
|
length = len(summaries)
|
||||||
|
|
||||||
|
summary = summaries[0]
|
||||||
|
name = await generate_summary_description(llm_client, summary)
|
||||||
|
now = datetime.now()
|
||||||
|
community_node = CommunityNode(
|
||||||
|
name=name,
|
||||||
|
group_id=community_cluster[0].group_id,
|
||||||
|
labels=['Community'],
|
||||||
|
created_at=now,
|
||||||
|
summary=summary,
|
||||||
|
)
|
||||||
|
community_edges = build_community_edges(community_cluster, community_node, now)
|
||||||
|
|
||||||
|
logger.info((community_node, community_edges))
|
||||||
|
|
||||||
|
return community_node, community_edges
|
||||||
|
|
||||||
|
|
||||||
|
async def build_communities(
|
||||||
|
driver: AsyncDriver, llm_client: LLMClient
|
||||||
|
) -> tuple[list[CommunityNode], list[CommunityEdge]]:
|
||||||
|
projection = await build_community_projection(driver)
|
||||||
|
community_clusters = await get_community_clusters(driver, projection)
|
||||||
|
|
||||||
|
communities: list[tuple[CommunityNode, list[CommunityEdge]]] = list(
|
||||||
|
await asyncio.gather(
|
||||||
|
*[build_community(llm_client, cluster) for cluster in community_clusters]
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
community_nodes: list[CommunityNode] = []
|
||||||
|
community_edges: list[CommunityEdge] = []
|
||||||
|
for community in communities:
|
||||||
|
community_nodes.append(community[0])
|
||||||
|
community_edges.extend(community[1])
|
||||||
|
|
||||||
|
await destroy_projection(driver, projection)
|
||||||
|
return community_nodes, community_edges
|
||||||
|
|
||||||
|
|
||||||
|
async def remove_communities(driver: AsyncDriver):
|
||||||
|
await driver.execute_query("""
|
||||||
|
MATCH (c:Community)
|
||||||
|
DETACH DELETE c
|
||||||
|
""")
|
||||||
|
|
@ -20,9 +20,9 @@ from datetime import datetime
|
||||||
from time import time
|
from time import time
|
||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
from graphiti_core.edges import EntityEdge, EpisodicEdge
|
from graphiti_core.edges import CommunityEdge, EntityEdge, EpisodicEdge
|
||||||
from graphiti_core.llm_client import LLMClient
|
from graphiti_core.llm_client import LLMClient
|
||||||
from graphiti_core.nodes import EntityNode, EpisodicNode
|
from graphiti_core.nodes import CommunityNode, EntityNode, EpisodicNode
|
||||||
from graphiti_core.prompts import prompt_library
|
from graphiti_core.prompts import prompt_library
|
||||||
from graphiti_core.utils.maintenance.temporal_operations import (
|
from graphiti_core.utils.maintenance.temporal_operations import (
|
||||||
extract_edge_dates,
|
extract_edge_dates,
|
||||||
|
|
@ -50,6 +50,24 @@ def build_episodic_edges(
|
||||||
return edges
|
return edges
|
||||||
|
|
||||||
|
|
||||||
|
def build_community_edges(
|
||||||
|
entity_nodes: List[EntityNode],
|
||||||
|
community_node: CommunityNode,
|
||||||
|
created_at: datetime,
|
||||||
|
) -> List[CommunityEdge]:
|
||||||
|
edges: List[CommunityEdge] = [
|
||||||
|
CommunityEdge(
|
||||||
|
source_node_uuid=community_node.uuid,
|
||||||
|
target_node_uuid=node.uuid,
|
||||||
|
created_at=created_at,
|
||||||
|
group_id=community_node.group_id,
|
||||||
|
)
|
||||||
|
for node in entity_nodes
|
||||||
|
]
|
||||||
|
|
||||||
|
return edges
|
||||||
|
|
||||||
|
|
||||||
async def extract_edges(
|
async def extract_edges(
|
||||||
llm_client: LLMClient,
|
llm_client: LLMClient,
|
||||||
episode: EpisodicNode,
|
episode: EpisodicNode,
|
||||||
|
|
|
||||||
|
|
@ -32,8 +32,10 @@ async def build_indices_and_constraints(driver: AsyncDriver):
|
||||||
range_indices: list[LiteralString] = [
|
range_indices: list[LiteralString] = [
|
||||||
'CREATE INDEX entity_uuid IF NOT EXISTS FOR (n:Entity) ON (n.uuid)',
|
'CREATE INDEX entity_uuid IF NOT EXISTS FOR (n:Entity) ON (n.uuid)',
|
||||||
'CREATE INDEX episode_uuid IF NOT EXISTS FOR (n:Episodic) ON (n.uuid)',
|
'CREATE INDEX episode_uuid IF NOT EXISTS FOR (n:Episodic) ON (n.uuid)',
|
||||||
|
'CREATE INDEX community_uuid IF NOT EXISTS FOR (n:Community) ON (n.uuid)',
|
||||||
'CREATE INDEX relation_uuid IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.uuid)',
|
'CREATE INDEX relation_uuid IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.uuid)',
|
||||||
'CREATE INDEX mention_uuid IF NOT EXISTS FOR ()-[e:MENTIONS]-() ON (e.uuid)',
|
'CREATE INDEX mention_uuid IF NOT EXISTS FOR ()-[e:MENTIONS]-() ON (e.uuid)',
|
||||||
|
'CREATE INDEX has_member_uuid IF NOT EXISTS FOR ()-[e:HAS_MEMBER]-() ON (e.uuid)',
|
||||||
'CREATE INDEX entity_group_id IF NOT EXISTS FOR (n:Entity) ON (n.group_id)',
|
'CREATE INDEX entity_group_id IF NOT EXISTS FOR (n:Entity) ON (n.group_id)',
|
||||||
'CREATE INDEX episode_group_id IF NOT EXISTS FOR (n:Episodic) ON (n.group_id)',
|
'CREATE INDEX episode_group_id IF NOT EXISTS FOR (n:Episodic) ON (n.group_id)',
|
||||||
'CREATE INDEX relation_group_id IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.group_id)',
|
'CREATE INDEX relation_group_id IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.group_id)',
|
||||||
|
|
@ -51,6 +53,7 @@ async def build_indices_and_constraints(driver: AsyncDriver):
|
||||||
|
|
||||||
fulltext_indices: list[LiteralString] = [
|
fulltext_indices: list[LiteralString] = [
|
||||||
'CREATE FULLTEXT INDEX name_and_summary IF NOT EXISTS FOR (n:Entity) ON EACH [n.name, n.summary]',
|
'CREATE FULLTEXT INDEX name_and_summary IF NOT EXISTS FOR (n:Entity) ON EACH [n.name, n.summary]',
|
||||||
|
'CREATE FULLTEXT INDEX community_name IF NOT EXISTS FOR (n:Community) ON EACH [n.name]',
|
||||||
'CREATE FULLTEXT INDEX name_and_fact IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON EACH [e.name, e.fact]',
|
'CREATE FULLTEXT INDEX name_and_fact IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON EACH [e.name, e.fact]',
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
@ -71,6 +74,14 @@ async def build_indices_and_constraints(driver: AsyncDriver):
|
||||||
`vector.similarity_function`: 'cosine'
|
`vector.similarity_function`: 'cosine'
|
||||||
}}
|
}}
|
||||||
""",
|
""",
|
||||||
|
"""
|
||||||
|
CREATE VECTOR INDEX community_name_embedding IF NOT EXISTS
|
||||||
|
FOR (n:Community) ON (n.name_embedding)
|
||||||
|
OPTIONS {indexConfig: {
|
||||||
|
`vector.dimensions`: 1024,
|
||||||
|
`vector.similarity_function`: 'cosine'
|
||||||
|
}}
|
||||||
|
""",
|
||||||
]
|
]
|
||||||
index_queries: list[LiteralString] = range_indices + fulltext_indices + vector_indices
|
index_queries: list[LiteralString] = range_indices + fulltext_indices + vector_indices
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -73,6 +73,7 @@ def format_context(facts):
|
||||||
async def test_graphiti_init():
|
async def test_graphiti_init():
|
||||||
logger = setup_logging()
|
logger = setup_logging()
|
||||||
graphiti = Graphiti(NEO4J_URI, NEO4j_USER, NEO4j_PASSWORD)
|
graphiti = Graphiti(NEO4J_URI, NEO4j_USER, NEO4j_PASSWORD)
|
||||||
|
await graphiti.build_communities()
|
||||||
|
|
||||||
edges = await graphiti.search('Freakenomics guest', group_ids=['1'])
|
edges = await graphiti.search('Freakenomics guest', group_ids=['1'])
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue