Add episode refactor (#85)
* temp commit while moving
* fix name embedding bug
* invalidation
* format
* tests on runner examples
* format
* ellipsis
* ruff
* fix
* format
* minor prompt change
This commit is contained in:
parent
1d31442751
commit
299021173b
8 changed files with 261 additions and 106 deletions
|
|
@ -94,7 +94,7 @@ async def main():
|
||||||
|
|
||||||
async def ingest_products_data(client: Graphiti):
|
async def ingest_products_data(client: Graphiti):
|
||||||
script_dir = Path(__file__).parent
|
script_dir = Path(__file__).parent
|
||||||
json_file_path = script_dir / 'allbirds_products.json'
|
json_file_path = script_dir / '../data/manybirds_products.json'
|
||||||
|
|
||||||
with open(json_file_path) as file:
|
with open(json_file_path) as file:
|
||||||
products = json.load(file)['products']
|
products = json.load(file)['products']
|
||||||
|
|
@ -110,7 +110,14 @@ async def ingest_products_data(client: Graphiti):
|
||||||
for i, product in enumerate(products)
|
for i, product in enumerate(products)
|
||||||
]
|
]
|
||||||
|
|
||||||
await client.add_episode_bulk(episodes)
|
for episode in episodes:
|
||||||
|
await client.add_episode(
|
||||||
|
episode.name,
|
||||||
|
episode.content,
|
||||||
|
episode.source_description,
|
||||||
|
episode.reference_time,
|
||||||
|
episode.source,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
asyncio.run(main())
|
asyncio.run(main())
|
||||||
|
|
|
||||||
|
|
@ -86,4 +86,4 @@ async def main(use_bulk: bool = True):
|
||||||
await client.add_episode_bulk(episodes)
|
await client.add_episode_bulk(episodes)
|
||||||
|
|
||||||
|
|
||||||
asyncio.run(main(True))
|
asyncio.run(main(False))
|
||||||
|
|
|
||||||
|
|
@ -59,11 +59,6 @@ from graphiti_core.utils.maintenance.node_operations import (
|
||||||
extract_nodes,
|
extract_nodes,
|
||||||
resolve_extracted_nodes,
|
resolve_extracted_nodes,
|
||||||
)
|
)
|
||||||
from graphiti_core.utils.maintenance.temporal_operations import (
|
|
||||||
extract_edge_dates,
|
|
||||||
invalidate_edges,
|
|
||||||
prepare_edges_for_invalidation,
|
|
||||||
)
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
@ -293,7 +288,7 @@ class Graphiti:
|
||||||
*[node.generate_name_embedding(embedder) for node in extracted_nodes]
|
*[node.generate_name_embedding(embedder) for node in extracted_nodes]
|
||||||
)
|
)
|
||||||
|
|
||||||
# Resolve extracted nodes with nodes already in the graph
|
# Resolve extracted nodes with nodes already in the graph and extract facts
|
||||||
existing_nodes_lists: list[list[EntityNode]] = list(
|
existing_nodes_lists: list[list[EntityNode]] = list(
|
||||||
await asyncio.gather(
|
await asyncio.gather(
|
||||||
*[get_relevant_nodes([node], self.driver) for node in extracted_nodes]
|
*[get_relevant_nodes([node], self.driver) for node in extracted_nodes]
|
||||||
|
|
@ -302,22 +297,27 @@ class Graphiti:
|
||||||
|
|
||||||
logger.info(f'Extracted nodes: {[(n.name, n.uuid) for n in extracted_nodes]}')
|
logger.info(f'Extracted nodes: {[(n.name, n.uuid) for n in extracted_nodes]}')
|
||||||
|
|
||||||
mentioned_nodes, _ = await resolve_extracted_nodes(
|
(mentioned_nodes, uuid_map), extracted_edges = await asyncio.gather(
|
||||||
self.llm_client, extracted_nodes, existing_nodes_lists
|
resolve_extracted_nodes(self.llm_client, extracted_nodes, existing_nodes_lists),
|
||||||
|
extract_edges(self.llm_client, episode, extracted_nodes, previous_episodes),
|
||||||
)
|
)
|
||||||
logger.info(f'Adjusted mentioned nodes: {[(n.name, n.uuid) for n in mentioned_nodes]}')
|
logger.info(f'Adjusted mentioned nodes: {[(n.name, n.uuid) for n in mentioned_nodes]}')
|
||||||
nodes.extend(mentioned_nodes)
|
nodes.extend(mentioned_nodes)
|
||||||
|
|
||||||
# Extract facts as edges given entity nodes
|
extracted_edges_with_resolved_pointers = resolve_edge_pointers(
|
||||||
extracted_edges = await extract_edges(
|
extracted_edges, uuid_map
|
||||||
self.llm_client, episode, mentioned_nodes, previous_episodes
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# calculate embeddings
|
# calculate embeddings
|
||||||
await asyncio.gather(*[edge.generate_embedding(embedder) for edge in extracted_edges])
|
await asyncio.gather(
|
||||||
|
*[
|
||||||
|
edge.generate_embedding(embedder)
|
||||||
|
for edge in extracted_edges_with_resolved_pointers
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
# Resolve extracted edges with edges already in the graph
|
# Resolve extracted edges with related edges already in the graph
|
||||||
existing_edges_list: list[list[EntityEdge]] = list(
|
related_edges_list: list[list[EntityEdge]] = list(
|
||||||
await asyncio.gather(
|
await asyncio.gather(
|
||||||
*[
|
*[
|
||||||
get_relevant_edges(
|
get_relevant_edges(
|
||||||
|
|
@ -327,74 +327,66 @@ class Graphiti:
|
||||||
edge.target_node_uuid,
|
edge.target_node_uuid,
|
||||||
RELEVANT_SCHEMA_LIMIT,
|
RELEVANT_SCHEMA_LIMIT,
|
||||||
)
|
)
|
||||||
for edge in extracted_edges
|
for edge in extracted_edges_with_resolved_pointers
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
logger.info(
|
logger.info(
|
||||||
f'Existing edges lists: {[(e.name, e.uuid) for edges_lst in existing_edges_list for e in edges_lst]}'
|
f'Related edges lists: {[(e.name, e.uuid) for edges_lst in related_edges_list for e in edges_lst]}'
|
||||||
)
|
)
|
||||||
logger.info(f'Extracted edges: {[(e.name, e.uuid) for e in extracted_edges]}')
|
logger.info(
|
||||||
|
f'Extracted edges: {[(e.name, e.uuid) for e in extracted_edges_with_resolved_pointers]}'
|
||||||
deduped_edges: list[EntityEdge] = await resolve_extracted_edges(
|
|
||||||
self.llm_client, extracted_edges, existing_edges_list
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Extract dates for the newly extracted edges
|
existing_source_edges_list: list[list[EntityEdge]] = list(
|
||||||
edge_dates = await asyncio.gather(
|
await asyncio.gather(
|
||||||
*[
|
*[
|
||||||
extract_edge_dates(
|
get_relevant_edges(
|
||||||
self.llm_client,
|
self.driver,
|
||||||
edge,
|
[edge],
|
||||||
episode,
|
edge.source_node_uuid,
|
||||||
previous_episodes,
|
None,
|
||||||
)
|
RELEVANT_SCHEMA_LIMIT,
|
||||||
for edge in deduped_edges
|
)
|
||||||
]
|
for edge in extracted_edges_with_resolved_pointers
|
||||||
|
]
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
for i, edge in enumerate(deduped_edges):
|
existing_target_edges_list: list[list[EntityEdge]] = list(
|
||||||
valid_at = edge_dates[i][0]
|
await asyncio.gather(
|
||||||
invalid_at = edge_dates[i][1]
|
*[
|
||||||
|
get_relevant_edges(
|
||||||
|
self.driver,
|
||||||
|
[edge],
|
||||||
|
None,
|
||||||
|
edge.target_node_uuid,
|
||||||
|
RELEVANT_SCHEMA_LIMIT,
|
||||||
|
)
|
||||||
|
for edge in extracted_edges_with_resolved_pointers
|
||||||
|
]
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
edge.valid_at = valid_at
|
existing_edges_list: list[list[EntityEdge]] = [
|
||||||
edge.invalid_at = invalid_at
|
source_lst + target_lst
|
||||||
if edge.invalid_at is not None:
|
for source_lst, target_lst in zip(
|
||||||
edge.expired_at = now
|
existing_source_edges_list, existing_target_edges_list
|
||||||
|
)
|
||||||
entity_edges.extend(deduped_edges)
|
|
||||||
|
|
||||||
existing_edges: list[EntityEdge] = [
|
|
||||||
e for edge_lst in existing_edges_list for e in edge_lst
|
|
||||||
]
|
]
|
||||||
|
|
||||||
(
|
resolved_edges, invalidated_edges = await resolve_extracted_edges(
|
||||||
old_edges_with_nodes_pending_invalidation,
|
|
||||||
new_edges_with_nodes,
|
|
||||||
) = prepare_edges_for_invalidation(
|
|
||||||
existing_edges=existing_edges, new_edges=deduped_edges, nodes=nodes
|
|
||||||
)
|
|
||||||
|
|
||||||
invalidated_edges = await invalidate_edges(
|
|
||||||
self.llm_client,
|
self.llm_client,
|
||||||
old_edges_with_nodes_pending_invalidation,
|
extracted_edges_with_resolved_pointers,
|
||||||
new_edges_with_nodes,
|
related_edges_list,
|
||||||
|
existing_edges_list,
|
||||||
episode,
|
episode,
|
||||||
previous_episodes,
|
previous_episodes,
|
||||||
)
|
)
|
||||||
|
|
||||||
for edge in invalidated_edges:
|
entity_edges.extend(resolved_edges + invalidated_edges)
|
||||||
for existing_edge in existing_edges:
|
|
||||||
if existing_edge.uuid == edge.uuid:
|
|
||||||
existing_edge.expired_at = edge.expired_at
|
|
||||||
for deduped_edge in deduped_edges:
|
|
||||||
if deduped_edge.uuid == edge.uuid:
|
|
||||||
deduped_edge.expired_at = edge.expired_at
|
|
||||||
logger.info(f'Invalidated edges: {[(e.name, e.uuid) for e in invalidated_edges]}')
|
|
||||||
|
|
||||||
entity_edges.extend(existing_edges)
|
logger.info(f'Resolved edges: {[(e.name, e.uuid) for e in resolved_edges]}')
|
||||||
|
|
||||||
logger.info(f'Deduped edges: {[(e.name, e.uuid) for e in deduped_edges]}')
|
|
||||||
|
|
||||||
episodic_edges: list[EpisodicEdge] = build_episodic_edges(
|
episodic_edges: list[EpisodicEdge] = build_episodic_edges(
|
||||||
mentioned_nodes,
|
mentioned_nodes,
|
||||||
|
|
|
||||||
|
|
@ -129,7 +129,7 @@ def v3(context: dict[str, Any]) -> list[Message]:
|
||||||
Given the following context, determine whether the New Edge represents any of the edges in the list of Existing Edges.
|
Given the following context, determine whether the New Edge represents any of the edges in the list of Existing Edges.
|
||||||
|
|
||||||
Existing Edges:
|
Existing Edges:
|
||||||
{json.dumps(context['existing_edges'], indent=2)}
|
{json.dumps(context['related_edges'], indent=2)}
|
||||||
|
|
||||||
New Edge:
|
New Edge:
|
||||||
{json.dumps(context['extracted_edges'], indent=2)}
|
{json.dumps(context['extracted_edges'], indent=2)}
|
||||||
|
|
|
||||||
|
|
@ -21,10 +21,12 @@ from .models import Message, PromptFunction, PromptVersion
|
||||||
|
|
||||||
class Prompt(Protocol):
|
class Prompt(Protocol):
|
||||||
v1: PromptVersion
|
v1: PromptVersion
|
||||||
|
v2: PromptVersion
|
||||||
|
|
||||||
|
|
||||||
class Versions(TypedDict):
|
class Versions(TypedDict):
|
||||||
v1: PromptFunction
|
v1: PromptFunction
|
||||||
|
v2: PromptFunction
|
||||||
|
|
||||||
|
|
||||||
def v1(context: dict[str, Any]) -> list[Message]:
|
def v1(context: dict[str, Any]) -> list[Message]:
|
||||||
|
|
@ -71,4 +73,38 @@ def v1(context: dict[str, Any]) -> list[Message]:
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
versions: Versions = {'v1': v1}
|
def v2(context: dict[str, Any]) -> list[Message]:
    """Build the version-2 prompt asking which existing edges a new edge invalidates.

    Args:
        context: Prompt inputs. Two keys are read: ``'existing_edges'`` (the
            candidate edges) and ``'new_edge'`` (the edge they are checked
            against).

    Returns:
        A two-element list: a system message describing the task, followed by
        a user message that embeds both context values and specifies the JSON
        response format (an ``invalidated_edges`` list of ``uuid``/``fact``
        objects).
    """
    # Local import so this function does not rely on the module's import block.
    import json

    # Serialise the context as JSON rather than interpolating Python's repr
    # (single quotes, ``None``): the model is asked to answer in JSON, and the
    # edge-dedupe prompt formats its context the same way. ``default=str``
    # keeps the previous "never raises" behaviour for non-JSON values.
    existing_edges_json = json.dumps(context['existing_edges'], indent=2, default=str)
    new_edge_json = json.dumps(context['new_edge'], indent=2, default=str)

    return [
        Message(
            role='system',
            content='You are an AI assistant that helps determine which relationships in a knowledge graph should be invalidated based solely on explicit contradictions in newer information.',
        ),
        Message(
            role='user',
            content=f"""
        Based on the provided Existing Edges and a New Edge, determine which existing edges, if any, should be marked as invalidated due to contradictions with the New Edge.

        Existing Edges:
        {existing_edges_json}

        New Edge:
        {new_edge_json}


        For each existing edge that should be invalidated, respond with a JSON object in the following format:
        {{
            "invalidated_edges": [
                {{
                    "uuid": "The UUID of the edge to be invalidated",
                    "fact": "Updated fact of the edge"
                }}
            ]
        }}

        If no relationships need to be invalidated based on these strict criteria, return an empty list for "invalidated_edges".
        """,
        ),
    ]
|
||||||
|
|
||||||
|
|
||||||
|
versions: Versions = {'v1': v1, 'v2': v2}
|
||||||
|
|
|
||||||
|
|
@ -96,11 +96,11 @@ async def bfs(node_ids: list[str], driver: AsyncDriver):
|
||||||
|
|
||||||
|
|
||||||
async def edge_similarity_search(
|
async def edge_similarity_search(
|
||||||
driver: AsyncDriver,
|
driver: AsyncDriver,
|
||||||
search_vector: list[float],
|
search_vector: list[float],
|
||||||
source_node_uuid: str | None,
|
source_node_uuid: str | None,
|
||||||
target_node_uuid: str | None,
|
target_node_uuid: str | None,
|
||||||
limit: int = RELEVANT_SCHEMA_LIMIT,
|
limit: int = RELEVANT_SCHEMA_LIMIT,
|
||||||
) -> list[EntityEdge]:
|
) -> list[EntityEdge]:
|
||||||
# vector similarity search over embedded facts
|
# vector similarity search over embedded facts
|
||||||
query = Query("""
|
query = Query("""
|
||||||
|
|
@ -211,7 +211,7 @@ async def edge_similarity_search(
|
||||||
|
|
||||||
|
|
||||||
async def entity_similarity_search(
|
async def entity_similarity_search(
|
||||||
search_vector: list[float], driver: AsyncDriver, limit=RELEVANT_SCHEMA_LIMIT
|
search_vector: list[float], driver: AsyncDriver, limit=RELEVANT_SCHEMA_LIMIT
|
||||||
) -> list[EntityNode]:
|
) -> list[EntityNode]:
|
||||||
# vector similarity search over entity names
|
# vector similarity search over entity names
|
||||||
records, _, _ = await driver.execute_query(
|
records, _, _ = await driver.execute_query(
|
||||||
|
|
@ -247,7 +247,7 @@ async def entity_similarity_search(
|
||||||
|
|
||||||
|
|
||||||
async def entity_fulltext_search(
|
async def entity_fulltext_search(
|
||||||
query: str, driver: AsyncDriver, limit=RELEVANT_SCHEMA_LIMIT
|
query: str, driver: AsyncDriver, limit=RELEVANT_SCHEMA_LIMIT
|
||||||
) -> list[EntityNode]:
|
) -> list[EntityNode]:
|
||||||
# BM25 search to get top nodes
|
# BM25 search to get top nodes
|
||||||
fuzzy_query = re.sub(r'[^\w\s]', '', query) + '~'
|
fuzzy_query = re.sub(r'[^\w\s]', '', query) + '~'
|
||||||
|
|
@ -284,11 +284,11 @@ async def entity_fulltext_search(
|
||||||
|
|
||||||
|
|
||||||
async def edge_fulltext_search(
|
async def edge_fulltext_search(
|
||||||
driver: AsyncDriver,
|
driver: AsyncDriver,
|
||||||
query: str,
|
query: str,
|
||||||
source_node_uuid: str | None,
|
source_node_uuid: str | None,
|
||||||
target_node_uuid: str | None,
|
target_node_uuid: str | None,
|
||||||
limit=RELEVANT_SCHEMA_LIMIT,
|
limit=RELEVANT_SCHEMA_LIMIT,
|
||||||
) -> list[EntityEdge]:
|
) -> list[EntityEdge]:
|
||||||
# fulltext search over facts
|
# fulltext search over facts
|
||||||
cypher_query = Query("""
|
cypher_query = Query("""
|
||||||
|
|
@ -401,10 +401,10 @@ async def edge_fulltext_search(
|
||||||
|
|
||||||
|
|
||||||
async def hybrid_node_search(
|
async def hybrid_node_search(
|
||||||
queries: list[str],
|
queries: list[str],
|
||||||
embeddings: list[list[float]],
|
embeddings: list[list[float]],
|
||||||
driver: AsyncDriver,
|
driver: AsyncDriver,
|
||||||
limit: int = RELEVANT_SCHEMA_LIMIT,
|
limit: int = RELEVANT_SCHEMA_LIMIT,
|
||||||
) -> list[EntityNode]:
|
) -> list[EntityNode]:
|
||||||
"""
|
"""
|
||||||
Perform a hybrid search for nodes using both text queries and embeddings.
|
Perform a hybrid search for nodes using both text queries and embeddings.
|
||||||
|
|
@ -466,8 +466,8 @@ async def hybrid_node_search(
|
||||||
|
|
||||||
|
|
||||||
async def get_relevant_nodes(
|
async def get_relevant_nodes(
|
||||||
nodes: list[EntityNode],
|
nodes: list[EntityNode],
|
||||||
driver: AsyncDriver,
|
driver: AsyncDriver,
|
||||||
) -> list[EntityNode]:
|
) -> list[EntityNode]:
|
||||||
"""
|
"""
|
||||||
Retrieve relevant nodes based on the provided list of EntityNodes.
|
Retrieve relevant nodes based on the provided list of EntityNodes.
|
||||||
|
|
@ -503,11 +503,11 @@ async def get_relevant_nodes(
|
||||||
|
|
||||||
|
|
||||||
async def get_relevant_edges(
|
async def get_relevant_edges(
|
||||||
driver: AsyncDriver,
|
driver: AsyncDriver,
|
||||||
edges: list[EntityEdge],
|
edges: list[EntityEdge],
|
||||||
source_node_uuid: str | None,
|
source_node_uuid: str | None,
|
||||||
target_node_uuid: str | None,
|
target_node_uuid: str | None,
|
||||||
limit: int = RELEVANT_SCHEMA_LIMIT,
|
limit: int = RELEVANT_SCHEMA_LIMIT,
|
||||||
) -> list[EntityEdge]:
|
) -> list[EntityEdge]:
|
||||||
start = time()
|
start = time()
|
||||||
relevant_edges: list[EntityEdge] = []
|
relevant_edges: list[EntityEdge] = []
|
||||||
|
|
@ -557,7 +557,7 @@ def rrf(results: list[list[str]], rank_const=1) -> list[str]:
|
||||||
|
|
||||||
|
|
||||||
async def node_distance_reranker(
|
async def node_distance_reranker(
|
||||||
driver: AsyncDriver, results: list[list[str]], center_node_uuid: str
|
driver: AsyncDriver, results: list[list[str]], center_node_uuid: str
|
||||||
) -> list[str]:
|
) -> list[str]:
|
||||||
# use rrf as a preliminary ranker
|
# use rrf as a preliminary ranker
|
||||||
sorted_uuids = rrf(results)
|
sorted_uuids = rrf(results)
|
||||||
|
|
@ -579,8 +579,8 @@ async def node_distance_reranker(
|
||||||
|
|
||||||
for record in records:
|
for record in records:
|
||||||
if (
|
if (
|
||||||
record['source_uuid'] == center_node_uuid
|
record['source_uuid'] == center_node_uuid
|
||||||
or record['target_uuid'] == center_node_uuid
|
or record['target_uuid'] == center_node_uuid
|
||||||
):
|
):
|
||||||
continue
|
continue
|
||||||
distance = record['score']
|
distance = record['score']
|
||||||
|
|
|
||||||
|
|
@ -24,6 +24,10 @@ from graphiti_core.edges import EntityEdge, EpisodicEdge
|
||||||
from graphiti_core.llm_client import LLMClient
|
from graphiti_core.llm_client import LLMClient
|
||||||
from graphiti_core.nodes import EntityNode, EpisodicNode
|
from graphiti_core.nodes import EntityNode, EpisodicNode
|
||||||
from graphiti_core.prompts import prompt_library
|
from graphiti_core.prompts import prompt_library
|
||||||
|
from graphiti_core.utils.maintenance.temporal_operations import (
|
||||||
|
extract_edge_dates,
|
||||||
|
get_edge_contradictions,
|
||||||
|
)
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
@ -149,28 +153,110 @@ async def dedupe_extracted_edges(
|
||||||
async def resolve_extracted_edges(
    llm_client: LLMClient,
    extracted_edges: list[EntityEdge],
    related_edges_lists: list[list[EntityEdge]],
    existing_edges_lists: list[list[EntityEdge]],
    current_episode: EpisodicNode,
    previous_episodes: list[EpisodicNode],
) -> tuple[list[EntityEdge], list[EntityEdge]]:
    """Resolve every extracted edge concurrently and collect the edges they invalidate.

    The three list arguments are walked in lockstep: the i-th extracted edge is
    resolved against the i-th list of related edges and the i-th list of
    existing edges. ``zip`` stops at the shortest of the three.

    Args:
        llm_client: Client forwarded to ``resolve_extracted_edge``.
        extracted_edges: Newly extracted edges to resolve.
        related_edges_lists: Per-edge lists of related edges.
        existing_edges_lists: Per-edge lists of existing edges.
        current_episode: Episode forwarded to ``resolve_extracted_edge``.
        previous_episodes: Earlier episodes forwarded to ``resolve_extracted_edge``.

    Returns:
        ``(resolved_edges, invalidated_edges)``: one resolved edge per input
        edge, in input order, and the concatenation of every per-edge
        invalidated list.
    """
    # One coroutine per (edge, related, existing) triple; all run under a single gather.
    pending = [
        resolve_extracted_edge(
            llm_client,
            edge,
            related,
            existing,
            current_episode,
            previous_episodes,
        )
        for edge, related, existing in zip(
            extracted_edges, related_edges_lists, existing_edges_lists
        )
    ]
    outcomes: list[tuple[EntityEdge, list[EntityEdge]]] = list(await asyncio.gather(*pending))

    resolved_edges: list[EntityEdge] = [resolved for resolved, _ in outcomes]
    invalidated_edges: list[EntityEdge] = [
        invalidated for _, chunk in outcomes for invalidated in chunk
    ]

    return resolved_edges, invalidated_edges
|
||||||
|
|
||||||
|
|
||||||
async def resolve_extracted_edge(
    llm_client: LLMClient,
    extracted_edge: EntityEdge,
    related_edges: list[EntityEdge],
    existing_edges: list[EntityEdge],
    current_episode: EpisodicNode,
    previous_episodes: list[EpisodicNode],
) -> tuple[EntityEdge, list[EntityEdge]]:
    """Resolve one extracted edge and work out which contradicting edges it expires.

    Three calls run concurrently: ``dedupe_extracted_edge`` against
    ``related_edges``, ``extract_edge_dates`` for the edge's validity window,
    and ``get_edge_contradictions`` against ``existing_edges``. The results are
    then reconciled by comparing ``valid_at`` / ``invalid_at`` values.

    Args:
        llm_client: Client forwarded to the three helper calls.
        extracted_edge: The edge being resolved.
        related_edges: Edges passed to the dedupe step.
        existing_edges: Edges passed to the contradiction step.
        current_episode: Episode passed to date extraction.
        previous_episodes: Earlier episodes passed to date extraction.

    Returns:
        ``(resolved_edge, invalidated_edges)``: the resolved edge with its
        temporal fields updated, and the contradiction candidates this
        function marked invalid.
    """
    resolved_edge, edge_dates, invalidation_candidates = await asyncio.gather(
        dedupe_extracted_edge(llm_client, extracted_edge, related_edges),
        extract_edge_dates(llm_client, extracted_edge, current_episode, previous_episodes),
        get_edge_contradictions(llm_client, extracted_edge, existing_edges),
    )
    valid_at, invalid_at = edge_dates

    # Single timestamp (naive, local clock) shared by every expiry set below.
    now = datetime.now()

    # Extracted dates only overwrite the resolved edge's values when present.
    if valid_at is not None:
        resolved_edge.valid_at = valid_at
    if invalid_at is not None:
        resolved_edge.invalid_at = invalid_at
        if resolved_edge.expired_at is None:
            resolved_edge.expired_at = now

    # Determine if the new edge itself needs to be expired.
    if resolved_edge.expired_at is None:
        # Ascending by valid_at, undated candidates last. The sort is in place,
        # so it also fixes the iteration order of the loop further down.
        invalidation_candidates.sort(key=lambda c: (c.valid_at is None, c.valid_at))
        newer_candidate = next(
            (
                candidate
                for candidate in invalidation_candidates
                if candidate.valid_at is not None
                and resolved_edge.valid_at is not None
                and candidate.valid_at > resolved_edge.valid_at
            ),
            None,
        )
        if newer_candidate is not None:
            # Expire the new edge since we have information about more recent events.
            resolved_edge.invalid_at = newer_candidate.valid_at
            resolved_edge.expired_at = now

    # Determine which contradictory edges need to be expired.
    invalidated_edges: list[EntityEdge] = []
    for edge in invalidation_candidates:
        # (Edge invalid before new edge becomes valid) or
        # (new edge invalid before edge becomes valid): nothing to do.
        windows_disjoint = (
            edge.invalid_at is not None
            and resolved_edge.valid_at is not None
            and edge.invalid_at < resolved_edge.valid_at
        ) or (
            edge.valid_at is not None
            and resolved_edge.invalid_at is not None
            and resolved_edge.invalid_at < edge.valid_at
        )
        if windows_disjoint:
            continue

        # New edge invalidates an edge that became valid before it.
        if (
            edge.valid_at is not None
            and resolved_edge.valid_at is not None
            and edge.valid_at < resolved_edge.valid_at
        ):
            edge.invalid_at = resolved_edge.valid_at
            if edge.expired_at is None:
                edge.expired_at = now
            invalidated_edges.append(edge)

    return resolved_edge, invalidated_edges
|
||||||
|
|
||||||
|
|
||||||
|
async def dedupe_extracted_edge(
|
||||||
|
llm_client: LLMClient, extracted_edge: EntityEdge, related_edges: list[EntityEdge]
|
||||||
) -> EntityEdge:
|
) -> EntityEdge:
|
||||||
start = time()
|
start = time()
|
||||||
|
|
||||||
# Prepare context for LLM
|
# Prepare context for LLM
|
||||||
existing_edges_context = [
|
related_edges_context = [
|
||||||
{'uuid': edge.uuid, 'name': edge.name, 'fact': edge.fact} for edge in existing_edges
|
{'uuid': edge.uuid, 'name': edge.name, 'fact': edge.fact} for edge in related_edges
|
||||||
]
|
]
|
||||||
|
|
||||||
extracted_edge_context = {
|
extracted_edge_context = {
|
||||||
|
|
@ -180,7 +266,7 @@ async def resolve_extracted_edge(
|
||||||
}
|
}
|
||||||
|
|
||||||
context = {
|
context = {
|
||||||
'existing_edges': existing_edges_context,
|
'related_edges': related_edges_context,
|
||||||
'extracted_edges': extracted_edge_context,
|
'extracted_edges': extracted_edge_context,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -191,14 +277,14 @@ async def resolve_extracted_edge(
|
||||||
|
|
||||||
edge = extracted_edge
|
edge = extracted_edge
|
||||||
if is_duplicate:
|
if is_duplicate:
|
||||||
for existing_edge in existing_edges:
|
for existing_edge in related_edges:
|
||||||
if existing_edge.uuid != uuid:
|
if existing_edge.uuid != uuid:
|
||||||
continue
|
continue
|
||||||
edge = existing_edge
|
edge = existing_edge
|
||||||
|
|
||||||
end = time()
|
end = time()
|
||||||
logger.info(
|
logger.info(
|
||||||
f'Resolved node: {extracted_edge.name} is {edge.name}, in {(end - start) * 1000} ms'
|
f'Resolved Edge: {extracted_edge.name} is {edge.name}, in {(end - start) * 1000} ms'
|
||||||
)
|
)
|
||||||
|
|
||||||
return edge
|
return edge
|
||||||
|
|
|
||||||
|
|
@ -16,6 +16,7 @@ limitations under the License.
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
from time import time
|
||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
from graphiti_core.edges import EntityEdge
|
from graphiti_core.edges import EntityEdge
|
||||||
|
|
@ -181,3 +182,36 @@ async def extract_edge_dates(
|
||||||
logger.info(f'Edge date extraction explanation: {explanation}')
|
logger.info(f'Edge date extraction explanation: {explanation}')
|
||||||
|
|
||||||
return valid_at_datetime, invalid_at_datetime
|
return valid_at_datetime, invalid_at_datetime
|
||||||
|
|
||||||
|
|
||||||
|
async def get_edge_contradictions(
    llm_client: LLMClient, new_edge: EntityEdge, existing_edges: list[EntityEdge]
) -> list[EntityEdge]:
    """Ask the LLM which existing edges are contradicted by ``new_edge``.

    Args:
        llm_client: Client whose ``generate_response`` is awaited with the
            ``invalidate_edges.v2`` prompt.
        new_edge: The edge to check; its ``uuid``, ``name`` and ``fact`` are
            sent to the model.
        existing_edges: Candidate edges; the same three fields of each are sent.

    Returns:
        The members of ``existing_edges`` whose uuid the model returned under
        ``'invalidated_edges'``, each at most once, in response order. When the
        response carries a ``fact`` for an edge, that edge's ``fact`` attribute
        is overwritten in place. Returns ``[]`` without calling the model when
        ``existing_edges`` is empty.
    """
    # With no candidates there is nothing the model could select; skip the call.
    if not existing_edges:
        return []

    start = time()
    existing_edge_map = {edge.uuid: edge for edge in existing_edges}

    new_edge_context = {'uuid': new_edge.uuid, 'name': new_edge.name, 'fact': new_edge.fact}
    existing_edge_context = [
        {'uuid': existing_edge.uuid, 'name': existing_edge.name, 'fact': existing_edge.fact}
        for existing_edge in existing_edges
    ]

    context = {'new_edge': new_edge_context, 'existing_edges': existing_edge_context}

    llm_response = await llm_client.generate_response(prompt_library.invalidate_edges.v2(context))

    # The response is model output: tolerate a missing or null list.
    contradicted_edge_data = llm_response.get('invalidated_edges') or []

    contradicted_edges: list[EntityEdge] = []
    seen_uuids: set[str] = set()
    for edge_data in contradicted_edge_data:
        # Skip malformed items instead of raising KeyError/TypeError on them.
        if not isinstance(edge_data, dict):
            continue
        edge_uuid = edge_data.get('uuid')
        if not isinstance(edge_uuid, str) or edge_uuid not in existing_edge_map:
            continue
        # A repeated uuid must not yield the same edge twice.
        if edge_uuid in seen_uuids:
            continue
        seen_uuids.add(edge_uuid)

        contradicted_edge = existing_edge_map[edge_uuid]
        updated_fact = edge_data.get('fact')
        if updated_fact:
            # NOTE(review): only `fact` is rewritten here; confirm whether any
            # value derived from the old fact (e.g. an embedding) needs a refresh.
            contradicted_edge.fact = updated_fact
        contradicted_edges.append(contradicted_edge)

    end = time()
    logger.info(
        f'Found invalidated edge candidates from {new_edge.fact}, in {(end - start) * 1000} ms'
    )

    return contradicted_edges
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue