Add episode refactor (#85)

* temp commit while moving

* fix name embedding bug

* invalidation

* format

* tests on runner examples

* format

* ellipsis

* ruff

* fix

* format

* minor prompt change
Preston Rasmussen 2024-09-05 12:05:44 -04:00 committed by GitHub
parent 1d31442751
commit 299021173b
8 changed files with 261 additions and 106 deletions

View file

@@ -94,7 +94,7 @@ async def main():
 async def ingest_products_data(client: Graphiti):
     script_dir = Path(__file__).parent
-    json_file_path = script_dir / 'allbirds_products.json'
+    json_file_path = script_dir / '../data/manybirds_products.json'
     with open(json_file_path) as file:
         products = json.load(file)['products']
@@ -110,7 +110,14 @@ async def ingest_products_data(client: Graphiti):
         for i, product in enumerate(products)
     ]
-    await client.add_episode_bulk(episodes)
+    for episode in episodes:
+        await client.add_episode(
+            episode.name,
+            episode.content,
+            episode.source_description,
+            episode.reference_time,
+            episode.source,
+        )


 asyncio.run(main())
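For orientation, a minimal sketch of the kind of object this loop consumes: each entry in `episodes` only needs the five attributes read above. The container and values below are illustrative, not the example's actual episode type:

    from dataclasses import dataclass
    from datetime import datetime, timezone

    @dataclass
    class SketchEpisode:
        # Hypothetical stand-in for whatever the runner builds per product
        name: str
        content: str
        source_description: str
        reference_time: datetime
        source: str

    episode = SketchEpisode(
        name='Product 0',
        content='{"title": "Wool Runner"}',
        source_description='Manybirds products JSON',
        reference_time=datetime.now(timezone.utc),
        source='json',
    )
    # await client.add_episode(episode.name, episode.content, ...)  # as in the loop above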

View file

@@ -86,4 +86,4 @@ async def main(use_bulk: bool = True):
     await client.add_episode_bulk(episodes)


-asyncio.run(main(True))
+asyncio.run(main(False))

View file

@@ -59,11 +59,6 @@ from graphiti_core.utils.maintenance.node_operations import (
     extract_nodes,
     resolve_extracted_nodes,
 )
-from graphiti_core.utils.maintenance.temporal_operations import (
-    extract_edge_dates,
-    invalidate_edges,
-    prepare_edges_for_invalidation,
-)

 logger = logging.getLogger(__name__)
@@ -293,7 +288,7 @@ class Graphiti:
                 *[node.generate_name_embedding(embedder) for node in extracted_nodes]
             )

-            # Resolve extracted nodes with nodes already in the graph
+            # Resolve extracted nodes with nodes already in the graph and extract facts
             existing_nodes_lists: list[list[EntityNode]] = list(
                 await asyncio.gather(
                     *[get_relevant_nodes([node], self.driver) for node in extracted_nodes]
@@ -302,22 +297,27 @@ class Graphiti:
             logger.info(f'Extracted nodes: {[(n.name, n.uuid) for n in extracted_nodes]}')

-            mentioned_nodes, _ = await resolve_extracted_nodes(
-                self.llm_client, extracted_nodes, existing_nodes_lists
+            (mentioned_nodes, uuid_map), extracted_edges = await asyncio.gather(
+                resolve_extracted_nodes(self.llm_client, extracted_nodes, existing_nodes_lists),
+                extract_edges(self.llm_client, episode, extracted_nodes, previous_episodes),
             )
             logger.info(f'Adjusted mentioned nodes: {[(n.name, n.uuid) for n in mentioned_nodes]}')
             nodes.extend(mentioned_nodes)

-            # Extract facts as edges given entity nodes
-            extracted_edges = await extract_edges(
-                self.llm_client, episode, mentioned_nodes, previous_episodes
+            extracted_edges_with_resolved_pointers = resolve_edge_pointers(
+                extracted_edges, uuid_map
             )

             # calculate embeddings
-            await asyncio.gather(*[edge.generate_embedding(embedder) for edge in extracted_edges])
+            await asyncio.gather(
+                *[
+                    edge.generate_embedding(embedder)
+                    for edge in extracted_edges_with_resolved_pointers
+                ]
+            )

-            # Resolve extracted edges with edges already in the graph
-            existing_edges_list: list[list[EntityEdge]] = list(
+            # Resolve extracted edges with related edges already in the graph
+            related_edges_list: list[list[EntityEdge]] = list(
                 await asyncio.gather(
                     *[
                         get_relevant_edges(
@@ -327,74 +327,66 @@ class Graphiti:
                             edge.target_node_uuid,
                             RELEVANT_SCHEMA_LIMIT,
                         )
-                        for edge in extracted_edges
+                        for edge in extracted_edges_with_resolved_pointers
                     ]
                 )
             )
             logger.info(
-                f'Existing edges lists: {[(e.name, e.uuid) for edges_lst in existing_edges_list for e in edges_lst]}'
+                f'Related edges lists: {[(e.name, e.uuid) for edges_lst in related_edges_list for e in edges_lst]}'
             )
-            logger.info(f'Extracted edges: {[(e.name, e.uuid) for e in extracted_edges]}')
-
-            deduped_edges: list[EntityEdge] = await resolve_extracted_edges(
-                self.llm_client, extracted_edges, existing_edges_list
+            logger.info(
+                f'Extracted edges: {[(e.name, e.uuid) for e in extracted_edges_with_resolved_pointers]}'
             )

-            # Extract dates for the newly extracted edges
-            edge_dates = await asyncio.gather(
-                *[
-                    extract_edge_dates(
-                        self.llm_client,
-                        edge,
-                        episode,
-                        previous_episodes,
-                    )
-                    for edge in deduped_edges
-                ]
+            existing_source_edges_list: list[list[EntityEdge]] = list(
+                await asyncio.gather(
+                    *[
+                        get_relevant_edges(
+                            self.driver,
+                            [edge],
+                            edge.source_node_uuid,
+                            None,
+                            RELEVANT_SCHEMA_LIMIT,
+                        )
+                        for edge in extracted_edges_with_resolved_pointers
+                    ]
+                )
             )

-            for i, edge in enumerate(deduped_edges):
-                valid_at = edge_dates[i][0]
-                invalid_at = edge_dates[i][1]
-
-                edge.valid_at = valid_at
-                edge.invalid_at = invalid_at
-                if edge.invalid_at is not None:
-                    edge.expired_at = now
-
-            entity_edges.extend(deduped_edges)
-
-            existing_edges: list[EntityEdge] = [
-                e for edge_lst in existing_edges_list for e in edge_lst
+            existing_target_edges_list: list[list[EntityEdge]] = list(
+                await asyncio.gather(
+                    *[
+                        get_relevant_edges(
+                            self.driver,
+                            [edge],
+                            None,
+                            edge.target_node_uuid,
+                            RELEVANT_SCHEMA_LIMIT,
+                        )
+                        for edge in extracted_edges_with_resolved_pointers
+                    ]
+                )
+            )
+
+            existing_edges_list: list[list[EntityEdge]] = [
+                source_lst + target_lst
+                for source_lst, target_lst in zip(
+                    existing_source_edges_list, existing_target_edges_list
+                )
             ]

-            (
-                old_edges_with_nodes_pending_invalidation,
-                new_edges_with_nodes,
-            ) = prepare_edges_for_invalidation(
-                existing_edges=existing_edges, new_edges=deduped_edges, nodes=nodes
-            )
-
-            invalidated_edges = await invalidate_edges(
+            resolved_edges, invalidated_edges = await resolve_extracted_edges(
                 self.llm_client,
-                old_edges_with_nodes_pending_invalidation,
-                new_edges_with_nodes,
+                extracted_edges_with_resolved_pointers,
+                related_edges_list,
+                existing_edges_list,
                 episode,
                 previous_episodes,
             )

-            for edge in invalidated_edges:
-                for existing_edge in existing_edges:
-                    if existing_edge.uuid == edge.uuid:
-                        existing_edge.expired_at = edge.expired_at
-                for deduped_edge in deduped_edges:
-                    if deduped_edge.uuid == edge.uuid:
-                        deduped_edge.expired_at = edge.expired_at
-
-            logger.info(f'Invalidated edges: {[(e.name, e.uuid) for e in invalidated_edges]}')
-
-            entity_edges.extend(existing_edges)
+            entity_edges.extend(resolved_edges + invalidated_edges)

-            logger.info(f'Deduped edges: {[(e.name, e.uuid) for e in deduped_edges]}')
+            logger.info(f'Resolved edges: {[(e.name, e.uuid) for e in resolved_edges]}')

             episodic_edges: list[EpisodicEdge] = build_episodic_edges(
                 mentioned_nodes,
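One step worth spelling out from the hunks above: resolve_extracted_nodes now also returns a uuid_map, and resolve_edge_pointers rewrites edge endpoints through it so edges extracted against provisional nodes end up pointing at the nodes chosen during deduplication. The helper itself is not shown in this diff, so the following is only a simplified, self-contained sketch of the idea (types, names, and values are illustrative):

    from dataclasses import dataclass

    @dataclass
    class SketchEdge:
        source_node_uuid: str
        target_node_uuid: str

    def resolve_edge_pointers_sketch(edges, uuid_map):
        # Remap provisional endpoints to canonical node uuids; keep the uuid if unmapped.
        for edge in edges:
            edge.source_node_uuid = uuid_map.get(edge.source_node_uuid, edge.source_node_uuid)
            edge.target_node_uuid = uuid_map.get(edge.target_node_uuid, edge.target_node_uuid)
        return edges

    edges = [SketchEdge(source_node_uuid='tmp-1', target_node_uuid='tmp-2')]
    resolve_edge_pointers_sketch(edges, {'tmp-1': 'existing-42'})
    # edges[0] now runs from 'existing-42' to 'tmp-2'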

View file

@@ -129,7 +129,7 @@ def v3(context: dict[str, Any]) -> list[Message]:
         Given the following context, determine whether the New Edge represents any of the edges in the list of Existing Edges.

         Existing Edges:
-        {json.dumps(context['existing_edges'], indent=2)}
+        {json.dumps(context['related_edges'], indent=2)}

         New Edge:
         {json.dumps(context['extracted_edges'], indent=2)}

View file

@@ -21,10 +21,12 @@ from .models import Message, PromptFunction, PromptVersion

 class Prompt(Protocol):
     v1: PromptVersion
+    v2: PromptVersion


 class Versions(TypedDict):
     v1: PromptFunction
+    v2: PromptFunction


 def v1(context: dict[str, Any]) -> list[Message]:
@@ -71,4 +73,38 @@ def v1(context: dict[str, Any]) -> list[Message]:
     ]


-versions: Versions = {'v1': v1}
+def v2(context: dict[str, Any]) -> list[Message]:
+    return [
+        Message(
+            role='system',
+            content='You are an AI assistant that helps determine which relationships in a knowledge graph should be invalidated based solely on explicit contradictions in newer information.',
+        ),
+        Message(
+            role='user',
+            content=f"""
+                Based on the provided Existing Edges and a New Edge, determine which existing edges, if any, should be marked as invalidated due to contradictions with the New Edge.
+
+                Existing Edges:
+                {context['existing_edges']}
+
+                New Edge:
+                {context['new_edge']}
+
+                For each existing edge that should be invalidated, respond with a JSON object in the following format:
+                {{
+                    "invalidated_edges": [
+                        {{
+                            "uuid": "The UUID of the edge to be invalidated",
+                            "fact": "Updated fact of the edge"
+                        }}
+                    ]
+                }}
+
+                If no relationships need to be invalidated based on these strict criteria, return an empty list for "invalidated_edges".
+            """,
+        ),
+    ]
+
+
+versions: Versions = {'v1': v1, 'v2': v2}
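For illustration only (the uuids and facts below are made up, not taken from the repository), the v2 prompt receives plain edge dictionaries and is expected to answer in exactly the JSON shape it specifies:

    # Hypothetical input context for prompt_library.invalidate_edges.v2
    context = {
        'existing_edges': [
            {'uuid': 'edge-1', 'name': 'WORKS_AT', 'fact': 'Alice works at Acme'},
        ],
        'new_edge': {'uuid': 'edge-2', 'name': 'WORKS_AT', 'fact': 'Alice now works at Globex'},
    }

    # A well-formed model response would then look like:
    response = {
        'invalidated_edges': [
            {'uuid': 'edge-1', 'fact': 'Alice no longer works at Acme'},
        ]
    }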

View file

@@ -96,11 +96,11 @@ async def bfs(node_ids: list[str], driver: AsyncDriver):
 async def edge_similarity_search(
     driver: AsyncDriver,
     search_vector: list[float],
     source_node_uuid: str | None,
     target_node_uuid: str | None,
     limit: int = RELEVANT_SCHEMA_LIMIT,
 ) -> list[EntityEdge]:
     # vector similarity search over embedded facts
     query = Query("""
@@ -211,7 +211,7 @@ async def edge_similarity_search(
 async def entity_similarity_search(
     search_vector: list[float], driver: AsyncDriver, limit=RELEVANT_SCHEMA_LIMIT
 ) -> list[EntityNode]:
     # vector similarity search over entity names
     records, _, _ = await driver.execute_query(
@@ -247,7 +247,7 @@ async def entity_similarity_search(
 async def entity_fulltext_search(
     query: str, driver: AsyncDriver, limit=RELEVANT_SCHEMA_LIMIT
 ) -> list[EntityNode]:
     # BM25 search to get top nodes
     fuzzy_query = re.sub(r'[^\w\s]', '', query) + '~'
@@ -284,11 +284,11 @@ async def entity_fulltext_search(
 async def edge_fulltext_search(
     driver: AsyncDriver,
     query: str,
     source_node_uuid: str | None,
     target_node_uuid: str | None,
     limit=RELEVANT_SCHEMA_LIMIT,
 ) -> list[EntityEdge]:
     # fulltext search over facts
     cypher_query = Query("""
@@ -401,10 +401,10 @@ async def edge_fulltext_search(
 async def hybrid_node_search(
     queries: list[str],
     embeddings: list[list[float]],
     driver: AsyncDriver,
     limit: int = RELEVANT_SCHEMA_LIMIT,
 ) -> list[EntityNode]:
     """
     Perform a hybrid search for nodes using both text queries and embeddings.
@@ -466,8 +466,8 @@ async def hybrid_node_search(
 async def get_relevant_nodes(
     nodes: list[EntityNode],
     driver: AsyncDriver,
 ) -> list[EntityNode]:
     """
     Retrieve relevant nodes based on the provided list of EntityNodes.
@@ -503,11 +503,11 @@ async def get_relevant_nodes(
 async def get_relevant_edges(
     driver: AsyncDriver,
     edges: list[EntityEdge],
     source_node_uuid: str | None,
     target_node_uuid: str | None,
     limit: int = RELEVANT_SCHEMA_LIMIT,
 ) -> list[EntityEdge]:
     start = time()
     relevant_edges: list[EntityEdge] = []
@@ -557,7 +557,7 @@ def rrf(results: list[list[str]], rank_const=1) -> list[str]:
 async def node_distance_reranker(
     driver: AsyncDriver, results: list[list[str]], center_node_uuid: str
 ) -> list[str]:
     # use rrf as a preliminary ranker
     sorted_uuids = rrf(results)
@@ -579,8 +579,8 @@ async def node_distance_reranker(
     for record in records:
         if (
             record['source_uuid'] == center_node_uuid
             or record['target_uuid'] == center_node_uuid
         ):
             continue
         distance = record['score']

View file

@@ -24,6 +24,10 @@ from graphiti_core.edges import EntityEdge, EpisodicEdge
 from graphiti_core.llm_client import LLMClient
 from graphiti_core.nodes import EntityNode, EpisodicNode
 from graphiti_core.prompts import prompt_library
+from graphiti_core.utils.maintenance.temporal_operations import (
+    extract_edge_dates,
+    get_edge_contradictions,
+)

 logger = logging.getLogger(__name__)
@@ -149,28 +153,110 @@ async def dedupe_extracted_edges(
 async def resolve_extracted_edges(
     llm_client: LLMClient,
     extracted_edges: list[EntityEdge],
+    related_edges_lists: list[list[EntityEdge]],
     existing_edges_lists: list[list[EntityEdge]],
-) -> list[EntityEdge]:
-    resolved_edges: list[EntityEdge] = list(
+    current_episode: EpisodicNode,
+    previous_episodes: list[EpisodicNode],
+) -> tuple[list[EntityEdge], list[EntityEdge]]:
+    # resolve edges with related edges in the graph, extract temporal information, and find invalidation candidates
+    results: list[tuple[EntityEdge, list[EntityEdge]]] = list(
         await asyncio.gather(
             *[
-                resolve_extracted_edge(llm_client, extracted_edge, existing_edges)
-                for extracted_edge, existing_edges in zip(extracted_edges, existing_edges_lists)
+                resolve_extracted_edge(
+                    llm_client,
+                    extracted_edge,
+                    related_edges,
+                    existing_edges,
+                    current_episode,
+                    previous_episodes,
+                )
+                for extracted_edge, related_edges, existing_edges in zip(
+                    extracted_edges, related_edges_lists, existing_edges_lists
+                )
             ]
         )
     )

-    return resolved_edges
+    resolved_edges: list[EntityEdge] = []
+    invalidated_edges: list[EntityEdge] = []
+    for result in results:
+        resolved_edge = result[0]
+        invalidated_edge_chunk = result[1]
+
+        resolved_edges.append(resolved_edge)
+        invalidated_edges.extend(invalidated_edge_chunk)
+
+    return resolved_edges, invalidated_edges


 async def resolve_extracted_edge(
-    llm_client: LLMClient, extracted_edge: EntityEdge, existing_edges: list[EntityEdge]
+    llm_client: LLMClient,
+    extracted_edge: EntityEdge,
+    related_edges: list[EntityEdge],
+    existing_edges: list[EntityEdge],
+    current_episode: EpisodicNode,
+    previous_episodes: list[EpisodicNode],
+) -> tuple[EntityEdge, list[EntityEdge]]:
+    resolved_edge, (valid_at, invalid_at), invalidation_candidates = await asyncio.gather(
+        dedupe_extracted_edge(llm_client, extracted_edge, related_edges),
+        extract_edge_dates(llm_client, extracted_edge, current_episode, previous_episodes),
+        get_edge_contradictions(llm_client, extracted_edge, existing_edges),
+    )
+
+    now = datetime.now()
+
+    resolved_edge.valid_at = valid_at if valid_at is not None else resolved_edge.valid_at
+    resolved_edge.invalid_at = invalid_at if invalid_at is not None else resolved_edge.invalid_at
+    if invalid_at is not None and resolved_edge.expired_at is None:
+        resolved_edge.expired_at = now
+
+    # Determine if the new_edge needs to be expired
+    if resolved_edge.expired_at is None:
+        invalidation_candidates.sort(key=lambda c: (c.valid_at is None, c.valid_at))
+        for candidate in invalidation_candidates:
+            if (
+                candidate.valid_at is not None and resolved_edge.valid_at is not None
+            ) and candidate.valid_at > resolved_edge.valid_at:
+                # Expire new edge since we have information about more recent events
+                resolved_edge.invalid_at = candidate.valid_at
+                resolved_edge.expired_at = now
+                break
+
+    # Determine which contradictory edges need to be expired
+    invalidated_edges: list[EntityEdge] = []
+    for edge in invalidation_candidates:
+        # (Edge invalid before new edge becomes valid) or (new edge invalid before edge becomes valid)
+        if (
+            edge.invalid_at is not None
+            and resolved_edge.valid_at is not None
+            and edge.invalid_at < resolved_edge.valid_at
+        ) or (
+            edge.valid_at is not None
+            and resolved_edge.invalid_at is not None
+            and resolved_edge.invalid_at < edge.valid_at
+        ):
+            continue
+        # New edge invalidates edge
+        elif (
+            edge.valid_at is not None
+            and resolved_edge.valid_at is not None
+            and edge.valid_at < resolved_edge.valid_at
+        ):
+            edge.invalid_at = resolved_edge.valid_at
+            edge.expired_at = edge.expired_at if edge.expired_at is not None else now
+            invalidated_edges.append(edge)
+
+    return resolved_edge, invalidated_edges
+
+
+async def dedupe_extracted_edge(
+    llm_client: LLMClient, extracted_edge: EntityEdge, related_edges: list[EntityEdge]
 ) -> EntityEdge:
     start = time()

     # Prepare context for LLM
-    existing_edges_context = [
-        {'uuid': edge.uuid, 'name': edge.name, 'fact': edge.fact} for edge in existing_edges
+    related_edges_context = [
+        {'uuid': edge.uuid, 'name': edge.name, 'fact': edge.fact} for edge in related_edges
     ]

     extracted_edge_context = {
@@ -180,7 +266,7 @@ async def resolve_extracted_edge(
     }

     context = {
-        'existing_edges': existing_edges_context,
+        'related_edges': related_edges_context,
         'extracted_edges': extracted_edge_context,
     }
@@ -191,14 +277,14 @@ async def resolve_extracted_edge(
     edge = extracted_edge
     if is_duplicate:
-        for existing_edge in existing_edges:
+        for existing_edge in related_edges:
             if existing_edge.uuid != uuid:
                 continue
             edge = existing_edge

     end = time()
     logger.info(
-        f'Resolved node: {extracted_edge.name} is {edge.name}, in {(end - start) * 1000} ms'
+        f'Resolved Edge: {extracted_edge.name} is {edge.name}, in {(end - start) * 1000} ms'
     )

     return edge
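To make the date bookkeeping in resolve_extracted_edge concrete, here is a small standalone sketch of the two rules applied to one invalidation candidate. It mirrors the comparisons in the hunk above but uses a throwaway dataclass and made-up facts rather than the real EntityEdge type:

    from dataclasses import dataclass
    from datetime import datetime

    @dataclass
    class Fact:
        fact: str
        valid_at: datetime | None = None
        invalid_at: datetime | None = None
        expired_at: datetime | None = None

    now = datetime.now()
    new_edge = Fact('Alice works at Globex', valid_at=datetime(2024, 6, 1))
    candidate = Fact('Alice works at Acme', valid_at=datetime(2023, 1, 1))

    # Rule 1: a candidate that became valid after the new edge expires the new edge itself.
    if candidate.valid_at is not None and new_edge.valid_at is not None and candidate.valid_at > new_edge.valid_at:
        new_edge.invalid_at = candidate.valid_at
        new_edge.expired_at = now

    # Rule 2: a candidate that became valid before the new edge is invalidated by it.
    if candidate.valid_at is not None and new_edge.valid_at is not None and candidate.valid_at < new_edge.valid_at:
        candidate.invalid_at = new_edge.valid_at
        candidate.expired_at = candidate.expired_at if candidate.expired_at is not None else now

    print(candidate)  # the Acme fact is now invalid as of 2024-06-01 and marked expired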

View file

@@ -16,6 +16,7 @@ limitations under the License.

 import logging
 from datetime import datetime
+from time import time
 from typing import List

 from graphiti_core.edges import EntityEdge
@@ -181,3 +182,36 @@ async def extract_edge_dates(
     logger.info(f'Edge date extraction explanation: {explanation}')

     return valid_at_datetime, invalid_at_datetime
+
+
+async def get_edge_contradictions(
+    llm_client: LLMClient, new_edge: EntityEdge, existing_edges: list[EntityEdge]
+) -> list[EntityEdge]:
+    start = time()
+    existing_edge_map = {edge.uuid: edge for edge in existing_edges}
+
+    new_edge_context = {'uuid': new_edge.uuid, 'name': new_edge.name, 'fact': new_edge.fact}
+    existing_edge_context = [
+        {'uuid': existing_edge.uuid, 'name': existing_edge.name, 'fact': existing_edge.fact}
+        for existing_edge in existing_edges
+    ]
+
+    context = {'new_edge': new_edge_context, 'existing_edges': existing_edge_context}
+
+    llm_response = await llm_client.generate_response(prompt_library.invalidate_edges.v2(context))
+
+    contradicted_edge_data = llm_response.get('invalidated_edges', [])
+
+    contradicted_edges: list[EntityEdge] = []
+    for edge_data in contradicted_edge_data:
+        if edge_data['uuid'] in existing_edge_map:
+            contradicted_edge = existing_edge_map[edge_data['uuid']]
+            contradicted_edge.fact = edge_data['fact']
+            contradicted_edges.append(contradicted_edge)
+
+    end = time()
+    logger.info(
+        f'Found invalidated edge candidates from {new_edge.fact}, in {(end - start) * 1000} ms'
+    )
+
+    return contradicted_edges