Speed up add episode (#77)
* WIP
* updates
* use uuid for node dedupe
* pret-testing
* parallelized node resolution
* working add_episode
* revert to 4o
* format
* mypy update
* update types
This commit is contained in:
parent db12ac548d
commit e9e6039b1e
12 changed files with 427 additions and 177 deletions

@@ -70,6 +70,7 @@ async def main(use_bulk: bool = True):
reference_time=message.actual_timestamp,
source_description='Podcast Transcript',
)
return

episodes: list[RawEpisode] = [
RawEpisode(

@@ -79,10 +80,10 @@ async def main(use_bulk: bool = True):
source_description='Podcast Transcript',
reference_time=message.actual_timestamp,
)
for i, message in enumerate(messages[3:14])
for i, message in enumerate(messages[3:20])
]

await client.add_episode_bulk(episodes)

asyncio.run(main(True))
asyncio.run(main(False))

@@ -48,14 +48,17 @@ from graphiti_core.utils.bulk_utils import (
retrieve_previous_episodes_bulk,
)
from graphiti_core.utils.maintenance.edge_operations import (
dedupe_extracted_edges,
extract_edges,
resolve_extracted_edges,
)
from graphiti_core.utils.maintenance.graph_data_operations import (
EPISODE_WINDOW_LEN,
build_indices_and_constraints,
)
from graphiti_core.utils.maintenance.node_operations import dedupe_extracted_nodes, extract_nodes
from graphiti_core.utils.maintenance.node_operations import (
extract_nodes,
resolve_extracted_nodes,
)
from graphiti_core.utils.maintenance.temporal_operations import (
extract_edge_dates,
invalidate_edges,

@@ -177,9 +180,9 @@ class Graphiti:
await build_indices_and_constraints(self.driver)

async def retrieve_episodes(
self,
reference_time: datetime,
last_n: int = EPISODE_WINDOW_LEN,
self,
reference_time: datetime,
last_n: int = EPISODE_WINDOW_LEN,
) -> list[EpisodicNode]:
"""
Retrieve the last n episodic nodes from the graph.

@@ -207,14 +210,14 @@ class Graphiti:
return await retrieve_episodes(self.driver, reference_time, last_n)

async def add_episode(
self,
name: str,
episode_body: str,
source_description: str,
reference_time: datetime,
source: EpisodeType = EpisodeType.message,
success_callback: Callable | None = None,
error_callback: Callable | None = None,
self,
name: str,
episode_body: str,
source_description: str,
reference_time: datetime,
source: EpisodeType = EpisodeType.message,
success_callback: Callable | None = None,
error_callback: Callable | None = None,
):
"""
Process an episode and update the graph.

@@ -265,7 +268,6 @@ class Graphiti:

nodes: list[EntityNode] = []
entity_edges: list[EntityEdge] = []
episodic_edges: list[EpisodicEdge] = []
embedder = self.llm_client.get_embedder()
now = datetime.now()

@@ -280,6 +282,8 @@ class Graphiti:
valid_at=reference_time,
)

# Extract entities as nodes

extracted_nodes = await extract_nodes(self.llm_client, episode, previous_episodes)
logger.info(f'Extracted nodes: {[(n.name, n.uuid) for n in extracted_nodes]}')
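For orientation, a rough usage sketch of the add_episode signature shown in the hunk above. The client construction and callback bodies are illustrative assumptions, not part of this commit; the diff only shows that success_callback is awaited with the episode.

import asyncio
from datetime import datetime

# Sketch only: assumes `client` is an already-constructed Graphiti instance and
# that EpisodeType is imported from graphiti_core as the signature above implies.
async def on_success(episode):
    print(f'episode processed: {episode.name}')

async def main():
    await client.add_episode(
        name='podcast-message-3',
        episode_body='Alice: I moved to Paris last spring.',
        source_description='Podcast Transcript',
        reference_time=datetime.now(),
        source=EpisodeType.message,
        success_callback=on_success,
        error_callback=None,
    )

asyncio.run(main())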

@@ -288,57 +292,82 @@ class Graphiti:
await asyncio.gather(
*[node.generate_name_embedding(embedder) for node in extracted_nodes]
)
existing_nodes = await get_relevant_nodes(extracted_nodes, self.driver)

# Resolve extracted nodes with nodes already in the graph
existing_nodes_lists: list[list[EntityNode]] = list(
await asyncio.gather(
*[get_relevant_nodes([node], self.driver) for node in extracted_nodes]
)
)

logger.info(f'Extracted nodes: {[(n.name, n.uuid) for n in extracted_nodes]}')
touched_nodes, _, brand_new_nodes = await dedupe_extracted_nodes(
self.llm_client, extracted_nodes, existing_nodes
)
logger.info(f'Adjusted touched nodes: {[(n.name, n.uuid) for n in touched_nodes]}')
nodes.extend(touched_nodes)

mentioned_nodes, _ = await resolve_extracted_nodes(
self.llm_client, extracted_nodes, existing_nodes_lists
)
logger.info(f'Adjusted mentioned nodes: {[(n.name, n.uuid) for n in mentioned_nodes]}')
nodes.extend(mentioned_nodes)

# Extract facts as edges given entity nodes
extracted_edges = await extract_edges(
self.llm_client, episode, touched_nodes, previous_episodes
self.llm_client, episode, mentioned_nodes, previous_episodes
)

# calculate embeddings
await asyncio.gather(*[edge.generate_embedding(embedder) for edge in extracted_edges])

existing_edges = await get_relevant_edges(extracted_edges, self.driver)
logger.info(f'Existing edges: {[(e.name, e.uuid) for e in existing_edges]}')
# Resolve extracted edges with edges already in the graph
existing_edges_list: list[list[EntityEdge]] = list(
await asyncio.gather(
*[
get_relevant_edges(
[edge],
self.driver,
RELEVANT_SCHEMA_LIMIT,
edge.source_node_uuid,
edge.target_node_uuid,
)
for edge in extracted_edges
]
)
)
logger.info(
f'Existing edges lists: {[(e.name, e.uuid) for edges_lst in existing_edges_list for e in edges_lst]}'
)
logger.info(f'Extracted edges: {[(e.name, e.uuid) for e in extracted_edges]}')

deduped_edges = await dedupe_extracted_edges(
self.llm_client,
extracted_edges,
existing_edges,
deduped_edges: list[EntityEdge] = await resolve_extracted_edges(
self.llm_client, extracted_edges, existing_edges_list
)

edge_touched_node_uuids = [n.uuid for n in brand_new_nodes]
for edge in deduped_edges:
edge_touched_node_uuids.append(edge.source_node_uuid)
edge_touched_node_uuids.append(edge.target_node_uuid)
# Extract dates for the newly extracted edges
edge_dates = await asyncio.gather(
*[
extract_edge_dates(
self.llm_client,
edge,
episode,
previous_episodes,
)
for edge in deduped_edges
]
)

for i, edge in enumerate(deduped_edges):
valid_at = edge_dates[i][0]
invalid_at = edge_dates[i][1]

for edge in deduped_edges:
valid_at, invalid_at, _ = await extract_edge_dates(
self.llm_client,
edge,
episode,
previous_episodes,
)
edge.valid_at = valid_at
edge.invalid_at = invalid_at
if edge.invalid_at:
edge.expired_at = now
for edge in existing_edges:
valid_at, invalid_at, _ = await extract_edge_dates(
self.llm_client,
edge,
episode,
previous_episodes,
)
edge.valid_at = valid_at
edge.invalid_at = invalid_at
if edge.invalid_at:
if edge.invalid_at is not None:
edge.expired_at = now

entity_edges.extend(deduped_edges)

existing_edges: list[EntityEdge] = [
e for edge_lst in existing_edges_list for e in edge_lst
]

(
old_edges_with_nodes_pending_invalidation,
new_edges_with_nodes,
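The heart of the speedup in the hunk above is fanning out one small relevance query per extracted node or edge and resolving each candidate set concurrently, rather than running a single large sequential dedupe pass. A minimal, self-contained sketch of that pattern with stubbed search and resolve steps (none of these names come from the library):

import asyncio

async def find_candidates(item: str) -> list[str]:
    # stand-in for a per-item relevance lookup such as get_relevant_nodes([node], driver)
    await asyncio.sleep(0.1)
    return [f'candidate-for-{item}']

async def resolve(item: str, candidates: list[str]) -> str:
    # stand-in for an LLM call deciding whether `item` duplicates one of its candidates
    await asyncio.sleep(0.1)
    return candidates[0] if candidates else item

async def resolve_all(items: list[str]) -> list[str]:
    # one candidate search per item, all in flight at once
    candidate_lists = await asyncio.gather(*[find_candidates(item) for item in items])
    # one resolution per (item, candidates) pair, also concurrent
    return list(
        await asyncio.gather(
            *[resolve(item, candidates) for item, candidates in zip(items, candidate_lists)]
        )
    )

print(asyncio.run(resolve_all(['Alice', 'Paris', 'podcast'])))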

@@ -361,30 +390,18 @@ class Graphiti:
for deduped_edge in deduped_edges:
if deduped_edge.uuid == edge.uuid:
deduped_edge.expired_at = edge.expired_at
edge_touched_node_uuids.append(edge.source_node_uuid)
edge_touched_node_uuids.append(edge.target_node_uuid)
logger.info(f'Invalidated edges: {[(e.name, e.uuid) for e in invalidated_edges]}')

edges_to_save = existing_edges + deduped_edges

entity_edges.extend(edges_to_save)

edge_touched_node_uuids = list(set(edge_touched_node_uuids))
involved_nodes = [node for node in nodes if node.uuid in edge_touched_node_uuids]

logger.info(f'Edge touched nodes: {[(n.name, n.uuid) for n in involved_nodes]}')
entity_edges.extend(existing_edges)

logger.info(f'Deduped edges: {[(e.name, e.uuid) for e in deduped_edges]}')

episodic_edges.extend(
build_episodic_edges(
# There may be an overlap between new_nodes and affected_nodes, so we're deduplicating them
involved_nodes,
episode,
now,
)
episodic_edges: list[EpisodicEdge] = build_episodic_edges(
mentioned_nodes,
episode,
now,
)
# Important to append the episode to the nodes at the end so that self referencing episodic edges are not built

logger.info(f'Built episodic edges: {episodic_edges}')

# Future optimization would be using batch operations to save nodes and edges

@@ -395,9 +412,7 @@ class Graphiti:

end = time()
logger.info(f'Completed add_episode in {(end - start) * 1000} ms')
# for node in nodes:
# if isinstance(node, EntityNode):
# await node.update_summary(self.driver)

if success_callback:
await success_callback(episode)
except Exception as e:

@@ -407,8 +422,8 @@ class Graphiti:
raise e

async def add_episode_bulk(
self,
bulk_episodes: list[RawEpisode],
self,
bulk_episodes: list[RawEpisode],
):
"""
Process multiple episodes in bulk and update the graph.

@@ -572,18 +587,18 @@ class Graphiti:
return edges

async def _search(
self,
query: str,
timestamp: datetime,
config: SearchConfig,
center_node_uuid: str | None = None,
self,
query: str,
timestamp: datetime,
config: SearchConfig,
center_node_uuid: str | None = None,
):
return await hybrid_search(
self.driver, self.llm_client.get_embedder(), query, timestamp, config, center_node_uuid
)

async def get_nodes_by_query(
self, query: str, limit: int = RELEVANT_SCHEMA_LIMIT
self, query: str, limit: int = RELEVANT_SCHEMA_LIMIT
) -> list[EntityNode]:
"""
Retrieve nodes from the graph database based on a text query.

@@ -23,12 +23,14 @@ from .models import Message, PromptFunction, PromptVersion
class Prompt(Protocol):
v1: PromptVersion
v2: PromptVersion
v3: PromptVersion
edge_list: PromptVersion


class Versions(TypedDict):
v1: PromptFunction
v2: PromptFunction
v3: PromptFunction
edge_list: PromptFunction

@@ -41,17 +43,17 @@ def v1(context: dict[str, Any]) -> list[Message]:
Message(
role='user',
content=f"""
Given the following context, deduplicate facts from a list of new facts given a list of existing facts:
Given the following context, deduplicate facts from a list of new facts given a list of existing edges:

Existing Facts:
Existing Edges:
{json.dumps(context['existing_edges'], indent=2)}

New Facts:
New Edges:
{json.dumps(context['extracted_edges'], indent=2)}

Task:
If any facts in New Facts is a duplicate of a fact in Existing Facts,
do not return it in the list of unique facts.
If any edge in New Edges is a duplicate of an edge in Existing Edges, add their uuids to the output list.
When finding duplicates edges, synthesize their facts into a short new fact.

Guidelines:
1. identical or near identical facts are duplicates

@@ -60,9 +62,11 @@ def v1(context: dict[str, Any]) -> list[Message]:

Respond with a JSON object in the following format:
{{
"unique_facts": [
"duplicates": [
{{
"uuid": "unique identifier of the fact"
"uuid": "uuid of the new node like 5d643020624c42fa9de13f97b1b3fa39",
"duplicate_of": "uuid of the existing node",
"fact": "one sentence description of the fact"
}}
]
}}

@@ -113,6 +117,40 @@ def v2(context: dict[str, Any]) -> list[Message]:
]


def v3(context: dict[str, Any]) -> list[Message]:
return [
Message(
role='system',
content='You are a helpful assistant that de-duplicates edges from edge lists.',
),
Message(
role='user',
content=f"""
Given the following context, determine whether the New Edge represents any of the edges in the list of Existing Edges.

Existing Edges:
{json.dumps(context['existing_edges'], indent=2)}

New Edge:
{json.dumps(context['extracted_edges'], indent=2)}
Task:
1. If the New Edges represents the same factual information as any edge in Existing Edges, return 'is_duplicate: true' in the
response. Otherwise, return 'is_duplicate: false'
2. If is_duplicate is true, also return the uuid of the existing edge in the response

Guidelines:
1. The facts do not need to be completely identical to be duplicates, they just need to express the same information.

Respond with a JSON object in the following format:
{{
"is_duplicate": true or false,
"uuid": uuid of the existing edge like "5d643020624c42fa9de13f97b1b3fa39" or null,
}}
""",
),
]


def edge_list(context: dict[str, Any]) -> list[Message]:
return [
Message(

@@ -151,4 +189,4 @@ def edge_list(context: dict[str, Any]) -> list[Message]:
]


versions: Versions = {'v1': v1, 'v2': v2, 'edge_list': edge_list}
versions: Versions = {'v1': v1, 'v2': v2, 'v3': v3, 'edge_list': edge_list}
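For reference, a response that satisfies the v3 dedupe_edges format above might look like the following; calling code can branch on it roughly like this (toy values, and the real parsing lives in resolve_extracted_edge further down):

import json

# Illustrative LLM output matching the v3 dedupe_edges format; the uuid is made up
raw = '{"is_duplicate": true, "uuid": "5d643020624c42fa9de13f97b1b3fa39"}'
response = json.loads(raw)

is_duplicate = response.get('is_duplicate', False)
existing_uuid = response.get('uuid')

if is_duplicate and existing_uuid is not None:
    print(f'reuse existing edge {existing_uuid}')
else:
    print('keep the newly extracted edge')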

@@ -23,13 +23,15 @@ from .models import Message, PromptFunction, PromptVersion
class Prompt(Protocol):
v1: PromptVersion
v2: PromptVersion
v3: PromptVersion
node_list: PromptVersion


class Versions(TypedDict):
v1: PromptFunction
v2: PromptFunction
node_list: PromptVersion
v3: PromptFunction
node_list: PromptFunction


def v1(context: dict[str, Any]) -> list[Message]:

@@ -94,22 +96,22 @@ def v2(context: dict[str, Any]) -> list[Message]:
Important:
If a node in the new nodes is describing the same entity as a node in the existing nodes, mark it as a duplicate!!!
Task:
If any node in New Nodes is a duplicate of a node in Existing Nodes, add their names to the output list
If any node in New Nodes is a duplicate of a node in Existing Nodes, add their uuids to the output list
When finding duplicates nodes, synthesize their summaries into a short new summary that contains the
relevant information of the summaries of the new and existing nodes.

Guidelines:
1. Use both the name and summary of nodes to determine if they are duplicates,
duplicate nodes may have different names
2. In the output, name should always be the name of the New Node that is a duplicate. duplicate_of should be
the name of the Existing Node.
2. In the output, uuid should always be the uuid of the New Node that is a duplicate. duplicate_of should be
the uuid of the Existing Node.

Respond with a JSON object in the following format:
{{
"duplicates": [
{{
"name": "name of the new node",
"duplicate_of": "name of the existing node",
"uuid": "uuid of the new node like 5d643020624c42fa9de13f97b1b3fa39",
"duplicate_of": "uuid of the existing node",
"summary": "Brief summary of the node's role or significance. Takes information from the new and existing nodes"
}}
]

@@ -119,6 +121,44 @@ def v2(context: dict[str, Any]) -> list[Message]:
]


def v3(context: dict[str, Any]) -> list[Message]:
return [
Message(
role='system',
content='You are a helpful assistant that de-duplicates nodes from node lists.',
),
Message(
role='user',
content=f"""
Given the following context, determine whether the New Node represents any of the entities in the list of Existing Nodes.

Existing Nodes:
{json.dumps(context['existing_nodes'], indent=2)}

New Node:
{json.dumps(context['extracted_nodes'], indent=2)}
Task:
1. If the New Node represents the same entity as any node in Existing Nodes, return 'is_duplicate: true' in the
response. Otherwise, return 'is_duplicate: false'
2. If is_duplicate is true, also return the uuid of the existing node in the response
3. If is_duplicate is true, return a summary that synthesizes the information in the New Node summary and the
summary of the Existing Node it is a duplicate of.

Guidelines:
1. Use both the name and summary of nodes to determine if the entities are duplicates,
duplicate nodes may have different names

Respond with a JSON object in the following format:
{{
"is_duplicate": true or false,
"uuid": "uuid of the existing node like 5d643020624c42fa9de13f97b1b3fa39 or null",
"summary": "Brief summary of the node's role or significance. Takes information from the new and existing node"
}}
""",
),
]


def node_list(context: dict[str, Any]) -> list[Message]:
return [
Message(

@@ -134,19 +174,19 @@ def node_list(context: dict[str, Any]) -> list[Message]:
{json.dumps(context['nodes'], indent=2)}

Task:
1. Group nodes together such that all duplicate nodes are in the same list of names
2. All duplicate names should be grouped together in the same list
1. Group nodes together such that all duplicate nodes are in the same list of uuids
2. All duplicate uuids should be grouped together in the same list
3. Also return a new summary that synthesizes the summary into a new short summary

Guidelines:
1. Each name from the list of nodes should appear EXACTLY once in your response
2. If a node has no duplicates, it should appear in the response in a list of only one name
1. Each uuid from the list of nodes should appear EXACTLY once in your response
2. If a node has no duplicates, it should appear in the response in a list of only one uuid

Respond with a JSON object in the following format:
{{
"nodes": [
{{
"names": ["myNode", "node that is a duplicate of myNode"],
"uuids": ["5d643020624c42fa9de13f97b1b3fa39", "node that is a duplicate of 5d643020624c42fa9de13f97b1b3fa39"],
"summary": "Brief summary of the node summaries that appear in the list of names."
}}
]

@@ -156,4 +196,4 @@ def node_list(context: dict[str, Any]) -> list[Message]:
]


versions: Versions = {'v1': v1, 'v2': v2, 'node_list': node_list}
versions: Versions = {'v1': v1, 'v2': v2, 'v3': v3, 'node_list': node_list}
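To illustrate the node_list format above, here is a toy grouped-by-uuid response and the kind of duplicate-to-canonical uuid map a caller could derive from it (made-up uuids; the real handling is in dedupe_node_list further down):

# Illustrative node_list-style response: each group lists uuids that refer to the same entity
response = {
    'nodes': [
        {'uuids': ['aaa111', 'bbb222'], 'summary': 'Alice, the podcast host'},
        {'uuids': ['ccc333'], 'summary': 'Paris, the city Alice moved to'},
    ]
}

# Map every duplicate uuid onto the first (canonical) uuid of its group
uuid_map: dict[str, str] = {}
for group in response['nodes']:
    canonical = group['uuids'][0]
    for duplicate in group['uuids'][1:]:
        uuid_map[duplicate] = canonical

print(uuid_map)  # {'bbb222': 'aaa111'}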

@@ -55,6 +55,7 @@ def v1(context: dict[str, Any]) -> list[Message]:
1. Focus on entities, concepts, or actors that are central to the current episode.
2. Avoid creating nodes for relationships or actions (these will be handled as edges later).
3. Provide a brief but informative summary for each node.
4. Be as explicit as possible in your node names, using full names and avoiding abbreviations.

Respond with a JSON object in the following format:
{{

@@ -90,6 +91,7 @@ Guidelines:
3. Provide concise but informative summaries for each extracted node.
4. Avoid creating nodes for relationships or actions.
5. Avoid creating nodes for temporal information like dates, times or years (these will be added to edges later).
6. Be as explicit as possible in your node names, using full names and avoiding abbreviations.

Respond with a JSON object in the following format:
{{

@@ -83,7 +83,7 @@ async def hybrid_search(
nodes.extend(await get_mentioned_nodes(driver, episodes))

if SearchMethod.bm25 in config.search_methods:
text_search = await edge_fulltext_search(query, driver, 2 * config.num_edges)
text_search = await edge_fulltext_search(driver, query, 2 * config.num_edges)
search_results.append(text_search)

if SearchMethod.cosine_similarity in config.search_methods:

@@ -95,7 +95,7 @@ async def hybrid_search(
)

similarity_search = await edge_similarity_search(
search_vector, driver, 2 * config.num_edges
driver, search_vector, 2 * config.num_edges
)
search_results.append(similarity_search)

@@ -96,14 +96,18 @@ async def bfs(node_ids: list[str], driver: AsyncDriver):


async def edge_similarity_search(
search_vector: list[float], driver: AsyncDriver, limit=RELEVANT_SCHEMA_LIMIT
driver: AsyncDriver,
search_vector: list[float],
limit: int = RELEVANT_SCHEMA_LIMIT,
source_node_uuid: str = '*',
target_node_uuid: str = '*',
) -> list[EntityEdge]:
# vector similarity search over embedded facts
records, _, _ = await driver.execute_query(
"""
CALL db.index.vector.queryRelationships("fact_embedding", $limit, $search_vector)
YIELD relationship AS r, score
MATCH (n)-[r:RELATES_TO]->(m)
YIELD relationship AS rel, score
MATCH (n:Entity {uuid: $source_uuid})-[r {uuid: rel.uuid}]-(m:Entity {uuid: $target_uuid})
RETURN
r.uuid AS uuid,
n.uuid AS source_node_uuid,

@@ -119,6 +123,8 @@ async def edge_similarity_search(
ORDER BY score DESC
""",
search_vector=search_vector,
source_uuid=source_node_uuid,
target_uuid=target_node_uuid,
limit=limit,
)

@@ -214,7 +220,11 @@ async def entity_fulltext_search(


async def edge_fulltext_search(
query: str, driver: AsyncDriver, limit=RELEVANT_SCHEMA_LIMIT
driver: AsyncDriver,
query: str,
limit=RELEVANT_SCHEMA_LIMIT,
source_node_uuid: str = '*',
target_node_uuid: str = '*',
) -> list[EntityEdge]:
# fulltext search over facts
fuzzy_query = re.sub(r'[^\w\s]', '', query) + '~'

@@ -222,8 +232,8 @@ async def edge_fulltext_search(
records, _, _ = await driver.execute_query(
"""
CALL db.index.fulltext.queryRelationships("name_and_fact", $query)
YIELD relationship AS r, score
MATCH (n:Entity)-[r]->(m:Entity)
YIELD relationship AS rel, score
MATCH (n:Entity {uuid: $source_uuid})-[r {uuid: rel.uuid}]-(m:Entity {uuid: $target_uuid})
RETURN
r.uuid AS uuid,
n.uuid AS source_node_uuid,

@@ -239,6 +249,8 @@ async def edge_fulltext_search(
ORDER BY score DESC LIMIT $limit
""",
query=fuzzy_query,
source_uuid=source_node_uuid,
target_uuid=target_node_uuid,
limit=limit,
)
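As a quick illustration of the fuzzy_query line above: the raw query is stripped of punctuation and suffixed with Lucene's fuzzy operator before being sent to the fulltext index (toy input, not from the codebase):

import re

query = 'Where did Alice move to, exactly?'
fuzzy_query = re.sub(r'[^\w\s]', '', query) + '~'
print(fuzzy_query)  # Where did Alice move to exactly~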

@@ -369,6 +381,9 @@ async def get_relevant_nodes(
async def get_relevant_edges(
edges: list[EntityEdge],
driver: AsyncDriver,
limit: int = RELEVANT_SCHEMA_LIMIT,
source_node_uuid: str = '*',
target_node_uuid: str = '*',
) -> list[EntityEdge]:
start = time()
relevant_edges: list[EntityEdge] = []

@@ -376,11 +391,16 @@ async def get_relevant_edges(

results = await asyncio.gather(
*[
edge_similarity_search(edge.fact_embedding, driver)
edge_similarity_search(
driver, edge.fact_embedding, limit, source_node_uuid, target_node_uuid
)
for edge in edges
if edge.fact_embedding is not None
],
*[edge_fulltext_search(edge.fact, driver) for edge in edges],
*[
edge_fulltext_search(driver, edge.fact, limit, source_node_uuid, target_node_uuid)
for edge in edges
],
)

for result in results:

@@ -17,7 +17,6 @@ limitations under the License.
import asyncio
import logging
import typing
from collections import defaultdict
from datetime import datetime
from math import ceil

@@ -43,6 +42,7 @@ from graphiti_core.utils.maintenance.node_operations import (
extract_nodes,
)
from graphiti_core.utils.maintenance.temporal_operations import extract_edge_dates
from graphiti_core.utils.utils import chunk_edges_by_nodes

logger = logging.getLogger(__name__)

@@ -128,7 +128,7 @@ async def dedupe_nodes_bulk(
)
)

results: list[tuple[list[EntityNode], dict[str, str], list[EntityNode]]] = list(
results: list[tuple[list[EntityNode], dict[str, str]]] = list(
await asyncio.gather(
*[
dedupe_extracted_nodes(llm_client, node_chunk, existing_nodes_chunks[i])

@@ -265,19 +265,7 @@ async def compress_edges(llm_client: LLMClient, edges: list[EntityEdge]) -> list
return edges
# We only want to dedupe edges that are between the same pair of nodes
# We build a map of the edges based on their source and target nodes.
edge_chunk_map: dict[str, list[EntityEdge]] = defaultdict(list)
for edge in edges:
# We drop loop edges
if edge.source_node_uuid == edge.target_node_uuid:
continue

# Keep the order of the two nodes consistent, we want to be direction agnostic during edge resolution
pointers = [edge.source_node_uuid, edge.target_node_uuid]
pointers.sort()

edge_chunk_map[pointers[0] + pointers[1]].append(edge)

edge_chunks = [chunk for chunk in edge_chunk_map.values()]
edge_chunks = chunk_edges_by_nodes(edges)

results = await asyncio.gather(*[dedupe_edge_list(llm_client, chunk) for chunk in edge_chunks])

@@ -14,6 +14,7 @@ See the License for the specific language governing permissions and
limitations under the License.
"""

import asyncio
import logging
from datetime import datetime
from time import time

@@ -109,8 +110,8 @@ async def dedupe_extracted_edges(
existing_edges: list[EntityEdge],
) -> list[EntityEdge]:
# Create edge map
edge_map = {}
for edge in extracted_edges:
edge_map: dict[str, EntityEdge] = {}
for edge in existing_edges:
edge_map[edge.uuid] = edge

# Prepare context for LLM

@@ -124,18 +125,85 @@ async def dedupe_extracted_edges(
}

llm_response = await llm_client.generate_response(prompt_library.dedupe_edges.v1(context))
unique_edge_data = llm_response.get('unique_facts', [])
logger.info(f'Extracted unique edges: {unique_edge_data}')
duplicate_data = llm_response.get('duplicates', [])
logger.info(f'Extracted unique edges: {duplicate_data}')

duplicate_uuid_map: dict[str, str] = {}
for duplicate in duplicate_data:
uuid_value = duplicate['duplicate_of']
duplicate_uuid_map[duplicate['uuid']] = uuid_value

# Get full edge data
edges = []
for unique_edge in unique_edge_data:
edge = edge_map[unique_edge['uuid']]
edges.append(edge)
edges: list[EntityEdge] = []
for edge in extracted_edges:
if edge.uuid in duplicate_uuid_map:
existing_uuid = duplicate_uuid_map[edge.uuid]
existing_edge = edge_map[existing_uuid]
edges.append(existing_edge)
else:
edges.append(edge)

return edges


async def resolve_extracted_edges(
llm_client: LLMClient,
extracted_edges: list[EntityEdge],
existing_edges_lists: list[list[EntityEdge]],
) -> list[EntityEdge]:
resolved_edges: list[EntityEdge] = list(
await asyncio.gather(
*[
resolve_extracted_edge(llm_client, extracted_edge, existing_edges)
for extracted_edge, existing_edges in zip(extracted_edges, existing_edges_lists)
]
)
)

return resolved_edges


async def resolve_extracted_edge(
llm_client: LLMClient, extracted_edge: EntityEdge, existing_edges: list[EntityEdge]
) -> EntityEdge:
start = time()

# Prepare context for LLM
existing_edges_context = [
{'uuid': edge.uuid, 'name': edge.name, 'fact': edge.fact} for edge in existing_edges
]

extracted_edge_context = {
'uuid': extracted_edge.uuid,
'name': extracted_edge.name,
'fact': extracted_edge.fact,
}

context = {
'existing_edges': existing_edges_context,
'extracted_edges': extracted_edge_context,
}

llm_response = await llm_client.generate_response(prompt_library.dedupe_edges.v3(context))

is_duplicate: bool = llm_response.get('is_duplicate', False)
uuid: str | None = llm_response.get('uuid', None)

edge = extracted_edge
if is_duplicate:
for existing_edge in existing_edges:
if existing_edge.uuid != uuid:
continue
edge = existing_edge

end = time()
logger.info(
f'Resolved node: {extracted_edge.name} is {edge.name}, in {(end - start) * 1000} ms'
)

return edge


async def dedupe_edge_list(
llm_client: LLMClient,
edges: list[EntityEdge],

@@ -14,6 +14,7 @@ See the License for the specific language governing permissions and
limitations under the License.
"""

import asyncio
import logging
from datetime import datetime
from time import time

@@ -27,7 +28,7 @@ logger = logging.getLogger(__name__)


async def extract_message_nodes(
llm_client: LLMClient, episode: EpisodicNode, previous_episodes: list[EpisodicNode]
llm_client: LLMClient, episode: EpisodicNode, previous_episodes: list[EpisodicNode]
) -> list[dict[str, Any]]:
# Prepare context for LLM
context = {

@@ -48,8 +49,8 @@ async def extract_message_nodes(


async def extract_json_nodes(
llm_client: LLMClient,
episode: EpisodicNode,
llm_client: LLMClient,
episode: EpisodicNode,
) -> list[dict[str, Any]]:
# Prepare context for LLM
context = {

@@ -66,9 +67,9 @@ async def extract_json_nodes(


async def extract_nodes(
llm_client: LLMClient,
episode: EpisodicNode,
previous_episodes: list[EpisodicNode],
llm_client: LLMClient,
episode: EpisodicNode,
previous_episodes: list[EpisodicNode],
) -> list[EntityNode]:
start = time()
extracted_node_data: list[dict[str, Any]] = []

@@ -95,29 +96,24 @@ async def extract_nodes(


async def dedupe_extracted_nodes(
llm_client: LLMClient,
extracted_nodes: list[EntityNode],
existing_nodes: list[EntityNode],
) -> tuple[list[EntityNode], dict[str, str], list[EntityNode]]:
llm_client: LLMClient,
extracted_nodes: list[EntityNode],
existing_nodes: list[EntityNode],
) -> tuple[list[EntityNode], dict[str, str]]:
start = time()

# build existing node map
node_map: dict[str, EntityNode] = {}
for node in existing_nodes:
node_map[node.name] = node

# Temp hack
new_nodes_map: dict[str, EntityNode] = {}
for node in extracted_nodes:
new_nodes_map[node.name] = node
node_map[node.uuid] = node

# Prepare context for LLM
existing_nodes_context = [
{'name': node.name, 'summary': node.summary} for node in existing_nodes
{'uuid': node.uuid, 'name': node.name, 'summary': node.summary} for node in existing_nodes
]

extracted_nodes_context = [
{'name': node.name, 'summary': node.summary} for node in extracted_nodes
{'uuid': node.uuid, 'name': node.name, 'summary': node.summary} for node in extracted_nodes
]

context = {

@@ -134,42 +130,104 @@ async def dedupe_extracted_nodes(

uuid_map: dict[str, str] = {}
for duplicate in duplicate_data:
uuid = new_nodes_map[duplicate['name']].uuid
uuid_value = node_map[duplicate['duplicate_of']].uuid
uuid_map[uuid] = uuid_value
uuid_value = duplicate['duplicate_of']
uuid_map[duplicate['uuid']] = uuid_value

nodes: list[EntityNode] = []
brand_new_nodes: list[EntityNode] = []
for node in extracted_nodes:
if node.uuid in uuid_map:
existing_uuid = uuid_map[node.uuid]
# TODO(Preston): This is a bit of a hack I implemented because we were getting incorrect uuids for existing nodes,
# can you revisit the node dedup function and make it somewhat cleaner and add more comments/tests please?
# find an existing node by the uuid from the nodes_map (each key is name, so we need to iterate by uuid value)
existing_node = next((v for k, v in node_map.items() if v.uuid == existing_uuid), None)
if existing_node:
nodes.append(existing_node)
existing_node = node_map[existing_uuid]
nodes.append(existing_node)
else:
nodes.append(node)

continue
brand_new_nodes.append(node)
nodes.append(node)
return nodes, uuid_map

return nodes, uuid_map, brand_new_nodes

async def resolve_extracted_nodes(
llm_client: LLMClient,
extracted_nodes: list[EntityNode],
existing_nodes_lists: list[list[EntityNode]],
) -> tuple[list[EntityNode], dict[str, str]]:
uuid_map: dict[str, str] = {}
resolved_nodes: list[EntityNode] = []
results: list[tuple[EntityNode, dict[str, str]]] = list(
await asyncio.gather(
*[
resolve_extracted_node(llm_client, extracted_node, existing_nodes)
for extracted_node, existing_nodes in zip(extracted_nodes, existing_nodes_lists)
]
)
)

for result in results:
uuid_map.update(result[1])
resolved_nodes.append(result[0])

return resolved_nodes, uuid_map


async def resolve_extracted_node(
llm_client: LLMClient, extracted_node: EntityNode, existing_nodes: list[EntityNode]
) -> tuple[EntityNode, dict[str, str]]:
start = time()

# Prepare context for LLM
existing_nodes_context = [
{'uuid': node.uuid, 'name': node.name, 'summary': node.summary} for node in existing_nodes
]

extracted_node_context = {
'uuid': extracted_node.uuid,
'name': extracted_node.name,
'summary': extracted_node.summary,
}

context = {
'existing_nodes': existing_nodes_context,
'extracted_nodes': extracted_node_context,
}

llm_response = await llm_client.generate_response(prompt_library.dedupe_nodes.v3(context))

is_duplicate: bool = llm_response.get('is_duplicate', False)
uuid: str | None = llm_response.get('uuid', None)
summary = llm_response.get('summary', '')

node = extracted_node
uuid_map: dict[str, str] = {}
if is_duplicate:
for existing_node in existing_nodes:
if existing_node.uuid != uuid:
continue
node = existing_node
node.summary = summary
uuid_map[extracted_node.uuid] = existing_node.uuid

end = time()
logger.info(
f'Resolved node: {extracted_node.name} is {node.name}, in {(end - start) * 1000} ms'
)

return node, uuid_map


async def dedupe_node_list(
llm_client: LLMClient,
nodes: list[EntityNode],
llm_client: LLMClient,
nodes: list[EntityNode],
) -> tuple[list[EntityNode], dict[str, str]]:
start = time()

# build node map
node_map = {}
for node in nodes:
node_map[node.name] = node
node_map[node.uuid] = node

# Prepare context for LLM
nodes_context = [{'name': node.name, 'summary': node.summary} for node in nodes]
nodes_context = [
{'uuid': node.uuid, 'name': node.name, 'summary': node.summary} for node in nodes
]

context = {
'nodes': nodes_context,
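resolve_extracted_nodes returns both the canonical nodes and a uuid_map from each duplicate node's uuid to the uuid it resolved to. One plausible way a caller could apply such a map to re-point extracted edges at canonical nodes; the helper below is a hypothetical illustration, not code from this commit:

from types import SimpleNamespace

# Hypothetical helper (not from this commit): re-point edges using a
# duplicate-uuid -> canonical-uuid map like the one resolve_extracted_nodes returns.
def resolve_edge_pointers(edges, uuid_map: dict[str, str]):
    for edge in edges:
        edge.source_node_uuid = uuid_map.get(edge.source_node_uuid, edge.source_node_uuid)
        edge.target_node_uuid = uuid_map.get(edge.target_node_uuid, edge.target_node_uuid)
    return edges

edge = SimpleNamespace(source_node_uuid='bbb222', target_node_uuid='ccc333')
resolve_edge_pointers([edge], {'bbb222': 'aaa111'})
print(edge.source_node_uuid)  # aaa111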

@@ -188,13 +246,12 @@ async def dedupe_node_list(
unique_nodes = []
uuid_map: dict[str, str] = {}
for node_data in nodes_data:
node = node_map[node_data['names'][0]]
node = node_map[node_data['uuids'][0]]
node.summary = node_data['summary']
unique_nodes.append(node)

for name in node_data['names'][1:]:
uuid = node_map[name].uuid
uuid_value = node_map[node_data['names'][0]].uuid
for uuid in node_data['uuids'][1:]:
uuid_value = node_map[node_data['uuids'][0]].uuid
uuid_map[uuid] = uuid_value

return unique_nodes, uuid_map

@@ -149,7 +149,7 @@ async def extract_edge_dates(
edge: EntityEdge,
current_episode: EpisodicNode,
previous_episodes: List[EpisodicNode],
) -> tuple[datetime | None, datetime | None, str]:
) -> tuple[datetime | None, datetime | None]:
context = {
'edge_name': edge.name,
'edge_fact': edge.fact,

@@ -180,4 +180,4 @@ async def extract_edge_dates(

logger.info(f'Edge date extraction explanation: {explanation}')

return valid_at_datetime, invalid_at_datetime, explanation
return valid_at_datetime, invalid_at_datetime

@@ -15,8 +15,9 @@ limitations under the License.
"""

import logging
from collections import defaultdict

from graphiti_core.edges import EpisodicEdge
from graphiti_core.edges import EntityEdge, EpisodicEdge
from graphiti_core.nodes import EntityNode, EpisodicNode

logger = logging.getLogger(__name__)

@@ -37,3 +38,23 @@ def build_episodic_edges(
)

return edges


def chunk_edges_by_nodes(edges: list[EntityEdge]) -> list[list[EntityEdge]]:
# We only want to dedupe edges that are between the same pair of nodes
# We build a map of the edges based on their source and target nodes.
edge_chunk_map: dict[str, list[EntityEdge]] = defaultdict(list)
for edge in edges:
# We drop loop edges
if edge.source_node_uuid == edge.target_node_uuid:
continue

# Keep the order of the two nodes consistent, we want to be direction agnostic during edge resolution
pointers = [edge.source_node_uuid, edge.target_node_uuid]
pointers.sort()

edge_chunk_map[pointers[0] + pointers[1]].append(edge)

edge_chunks = [chunk for chunk in edge_chunk_map.values()]

return edge_chunks
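A toy run of the grouping logic in chunk_edges_by_nodes, showing that edges between the same unordered pair of nodes land in one chunk while self-loop edges are dropped (made-up names and uuids):

from collections import defaultdict
from types import SimpleNamespace

edges = [
    SimpleNamespace(name='LIVES_IN', source_node_uuid='alice', target_node_uuid='paris'),
    SimpleNamespace(name='MOVED_TO', source_node_uuid='paris', target_node_uuid='alice'),
    SimpleNamespace(name='SELF_LOOP', source_node_uuid='alice', target_node_uuid='alice'),
]

edge_chunk_map = defaultdict(list)
for edge in edges:
    if edge.source_node_uuid == edge.target_node_uuid:
        continue  # self-loops are skipped
    pointers = sorted([edge.source_node_uuid, edge.target_node_uuid])
    edge_chunk_map[pointers[0] + pointers[1]].append(edge)

print([[e.name for e in chunk] for chunk in edge_chunk_map.values()])
# [['LIVES_IN', 'MOVED_TO']]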