Node dedupe efficiency (#490)
* update resolve extracted edge * updated edge resolution * dedupe nodes update * single pass node resolution * updates * mypy updates * Update graphiti_core/prompts/dedupe_nodes.py Co-authored-by: ellipsis-dev[bot] <65095814+ellipsis-dev[bot]@users.noreply.github.com> * remove unused imports * mypy --------- Co-authored-by: ellipsis-dev[bot] <65095814+ellipsis-dev[bot]@users.noreply.github.com>
This commit is contained in:
parent
f096c8770c
commit
9422b6f5fb
9 changed files with 236 additions and 141 deletions
|
|
@ -41,6 +41,7 @@ from graphiti_core.search.search_config_recipes import (
|
||||||
from graphiti_core.search.search_filters import SearchFilters
|
from graphiti_core.search.search_filters import SearchFilters
|
||||||
from graphiti_core.search.search_utils import (
|
from graphiti_core.search.search_utils import (
|
||||||
RELEVANT_SCHEMA_LIMIT,
|
RELEVANT_SCHEMA_LIMIT,
|
||||||
|
get_edge_invalidation_candidates,
|
||||||
get_mentioned_nodes,
|
get_mentioned_nodes,
|
||||||
get_relevant_edges,
|
get_relevant_edges,
|
||||||
)
|
)
|
||||||
|
|
@ -62,9 +63,8 @@ from graphiti_core.utils.maintenance.community_operations import (
|
||||||
)
|
)
|
||||||
from graphiti_core.utils.maintenance.edge_operations import (
|
from graphiti_core.utils.maintenance.edge_operations import (
|
||||||
build_episodic_edges,
|
build_episodic_edges,
|
||||||
dedupe_extracted_edge,
|
|
||||||
extract_edges,
|
extract_edges,
|
||||||
resolve_edge_contradictions,
|
resolve_extracted_edge,
|
||||||
resolve_extracted_edges,
|
resolve_extracted_edges,
|
||||||
)
|
)
|
||||||
from graphiti_core.utils.maintenance.graph_data_operations import (
|
from graphiti_core.utils.maintenance.graph_data_operations import (
|
||||||
|
|
@ -77,7 +77,6 @@ from graphiti_core.utils.maintenance.node_operations import (
|
||||||
extract_nodes,
|
extract_nodes,
|
||||||
resolve_extracted_nodes,
|
resolve_extracted_nodes,
|
||||||
)
|
)
|
||||||
from graphiti_core.utils.maintenance.temporal_operations import get_edge_contradictions
|
|
||||||
from graphiti_core.utils.ontology_utils.entity_types_utils import validate_entity_types
|
from graphiti_core.utils.ontology_utils.entity_types_utils import validate_entity_types
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
@ -681,17 +680,15 @@ class Graphiti:
|
||||||
|
|
||||||
updated_edge = resolve_edge_pointers([edge], uuid_map)[0]
|
updated_edge = resolve_edge_pointers([edge], uuid_map)[0]
|
||||||
|
|
||||||
related_edges = await get_relevant_edges(self.driver, [updated_edge], SearchFilters(), 0.8)
|
related_edges = (await get_relevant_edges(self.driver, [updated_edge], SearchFilters()))[0]
|
||||||
|
existing_edges = (
|
||||||
|
await get_edge_invalidation_candidates(self.driver, [updated_edge], SearchFilters())
|
||||||
|
)[0]
|
||||||
|
|
||||||
resolved_edge = await dedupe_extracted_edge(
|
resolved_edge, invalidated_edges = await resolve_extracted_edge(
|
||||||
self.llm_client,
|
self.llm_client, updated_edge, related_edges, existing_edges
|
||||||
updated_edge,
|
|
||||||
related_edges[0],
|
|
||||||
)
|
)
|
||||||
|
|
||||||
contradicting_edges = await get_edge_contradictions(self.llm_client, edge, related_edges[0])
|
|
||||||
invalidated_edges = resolve_edge_contradictions(resolved_edge, contradicting_edges)
|
|
||||||
|
|
||||||
await add_nodes_and_edges_bulk(
|
await add_nodes_and_edges_bulk(
|
||||||
self.driver, [], [], resolved_nodes, [resolved_edge] + invalidated_edges, self.embedder
|
self.driver, [], [], resolved_nodes, [resolved_edge] + invalidated_edges, self.embedder
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -87,8 +87,8 @@ def normalize_l2(embedding: list[float]) -> NDArray:
|
||||||
|
|
||||||
# Use this instead of asyncio.gather() to bound coroutines
|
# Use this instead of asyncio.gather() to bound coroutines
|
||||||
async def semaphore_gather(
|
async def semaphore_gather(
|
||||||
*coroutines: Coroutine,
|
*coroutines: Coroutine,
|
||||||
max_coroutines: int = SEMAPHORE_LIMIT,
|
max_coroutines: int = SEMAPHORE_LIMIT,
|
||||||
):
|
):
|
||||||
semaphore = asyncio.Semaphore(max_coroutines)
|
semaphore = asyncio.Semaphore(max_coroutines)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -27,6 +27,10 @@ class EdgeDuplicate(BaseModel):
|
||||||
...,
|
...,
|
||||||
description='id of the duplicate fact. If no duplicate facts are found, default to -1.',
|
description='id of the duplicate fact. If no duplicate facts are found, default to -1.',
|
||||||
)
|
)
|
||||||
|
contradicted_facts: list[int] = Field(
|
||||||
|
...,
|
||||||
|
description='List of ids of facts that should be invalidated. If no facts should be invalidated, the list should be empty.',
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class UniqueFact(BaseModel):
|
class UniqueFact(BaseModel):
|
||||||
|
|
@ -41,11 +45,13 @@ class UniqueFacts(BaseModel):
|
||||||
class Prompt(Protocol):
|
class Prompt(Protocol):
|
||||||
edge: PromptVersion
|
edge: PromptVersion
|
||||||
edge_list: PromptVersion
|
edge_list: PromptVersion
|
||||||
|
resolve_edge: PromptVersion
|
||||||
|
|
||||||
|
|
||||||
class Versions(TypedDict):
|
class Versions(TypedDict):
|
||||||
edge: PromptFunction
|
edge: PromptFunction
|
||||||
edge_list: PromptFunction
|
edge_list: PromptFunction
|
||||||
|
resolve_edge: PromptFunction
|
||||||
|
|
||||||
|
|
||||||
def edge(context: dict[str, Any]) -> list[Message]:
|
def edge(context: dict[str, Any]) -> list[Message]:
|
||||||
|
|
@ -106,4 +112,41 @@ def edge_list(context: dict[str, Any]) -> list[Message]:
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
versions: Versions = {'edge': edge, 'edge_list': edge_list}
|
def resolve_edge(context: dict[str, Any]) -> list[Message]:
    """Build the single-pass edge resolution prompt.

    The prompt asks the model to answer two questions about a newly extracted
    fact in one call: which existing fact (if any) it duplicates, and which
    invalidation candidates it contradicts.

    Args:
        context: Prompt inputs. The keys ``new_edge``, ``existing_edges`` and
            ``edge_invalidation_candidates`` are read and interpolated into the
            user message using their default string formatting.

    Returns:
        A two-element list: the system message followed by the user message.

    Raises:
        KeyError: If any of the three context keys is missing.
    """
    system_message = Message(
        role='system',
        content='You are a helpful assistant that de-duplicates facts from fact lists and determines which existing '
        'facts are contradicted by the new fact.',
    )

    # The task text refers to "id" (not "idx") so that it matches the key under
    # which each fact is presented and the wording of the response-model field
    # descriptions ("id of the duplicate fact", "List of ids of facts ...").
    user_message = Message(
        role='user',
        content=f"""
        <NEW FACT>
        {context['new_edge']}
        </NEW FACT>

        <EXISTING FACTS>
        {context['existing_edges']}
        </EXISTING FACTS>
        <FACT INVALIDATION CANDIDATES>
        {context['edge_invalidation_candidates']}
        </FACT INVALIDATION CANDIDATES>


        Task:
        If the NEW FACT represents the same factual information as any fact in EXISTING FACTS, return the id of the duplicate fact.
        If the NEW FACT is not a duplicate of any of the EXISTING FACTS, return -1.

        Based on the provided FACT INVALIDATION CANDIDATES and NEW FACT, determine which existing facts the new fact contradicts.
        Return a list containing all ids of the facts that are contradicted by the NEW FACT.
        If there are no contradicted facts, return an empty list.

        Guidelines:
        1. The facts do not need to be completely identical to be duplicates, they just need to express the same information.
        """,
    )

    return [system_message, user_message]
|
||||||
|
|
||||||
|
|
||||||
|
versions: Versions = {'edge': edge, 'edge_list': edge_list, 'resolve_edge': resolve_edge}
|
||||||
|
|
|
||||||
|
|
@ -23,21 +23,31 @@ from .models import Message, PromptFunction, PromptVersion
|
||||||
|
|
||||||
|
|
||||||
class NodeDuplicate(BaseModel):
|
class NodeDuplicate(BaseModel):
|
||||||
duplicate_node_id: int = Field(
|
id: int = Field(..., description='integer id of the entity')
|
||||||
|
duplicate_idx: int = Field(
|
||||||
...,
|
...,
|
||||||
description='id of the duplicate node. If no duplicate nodes are found, default to -1.',
|
description='idx of the duplicate node. If no duplicate nodes are found, default to -1.',
|
||||||
)
|
)
|
||||||
name: str = Field(..., description='Name of the entity.')
|
name: str = Field(
|
||||||
|
...,
|
||||||
|
description='Name of the entity. Should be the most complete and descriptive name possible.',
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# Structured-response model for the batch node dedupe prompt: a single object
# wrapping the per-entity resolutions. Documented with comments rather than a
# class docstring so the generated JSON schema is left unchanged.
class NodeResolutions(BaseModel):
    # One NodeDuplicate entry per resolved entity; required (no default).
    entity_resolutions: list[NodeDuplicate] = Field(..., description='List of resolved nodes')
|
||||||
|
|
||||||
|
|
||||||
class Prompt(Protocol):
|
class Prompt(Protocol):
|
||||||
node: PromptVersion
|
node: PromptVersion
|
||||||
node_list: PromptVersion
|
node_list: PromptVersion
|
||||||
|
nodes: PromptVersion
|
||||||
|
|
||||||
|
|
||||||
class Versions(TypedDict):
|
class Versions(TypedDict):
|
||||||
node: PromptFunction
|
node: PromptFunction
|
||||||
node_list: PromptFunction
|
node_list: PromptFunction
|
||||||
|
nodes: PromptFunction
|
||||||
|
|
||||||
|
|
||||||
def node(context: dict[str, Any]) -> list[Message]:
|
def node(context: dict[str, Any]) -> list[Message]:
|
||||||
|
|
@ -89,6 +99,67 @@ def node(context: dict[str, Any]) -> list[Message]:
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def nodes(context: dict[str, Any]) -> list[Message]:
    """Build the batch node de-duplication prompt.

    All extracted entities are resolved in one request: each entity carries its
    own list of duplication candidates, and the model answers with one
    resolution per entity.

    Args:
        context: Prompt inputs. Reads ``previous_episodes`` and
            ``extracted_nodes`` (both serialized with ``json.dumps``) and
            ``episode_content`` (interpolated as-is).

    Returns:
        A two-element list: the system message followed by the user message.

    Raises:
        KeyError: If any of the three context keys is missing.
        TypeError: If ``previous_episodes`` or ``extracted_nodes`` contains a
            value that ``json.dumps`` cannot serialize.
    """
    # Note the leading space on the second literal: adjacent string literals are
    # concatenated verbatim, so without it the prompt reads "duplicatesof".
    system_message = Message(
        role='system',
        content='You are a helpful assistant that determines whether or not ENTITIES extracted from a conversation are duplicates'
        ' of existing entities.',
    )

    previous_messages = json.dumps(list(context['previous_episodes']), indent=2)
    entities = json.dumps(context['extracted_nodes'], indent=2)

    # The final bullet names the field "duplicate_idx" so that it matches the
    # field name requested earlier in the task text.
    user_message = Message(
        role='user',
        content=f"""
        <PREVIOUS MESSAGES>
        {previous_messages}
        </PREVIOUS MESSAGES>
        <CURRENT MESSAGE>
        {context['episode_content']}
        </CURRENT MESSAGE>


        Each of the following ENTITIES were extracted from the CURRENT MESSAGE.
        Each entity in ENTITIES is represented as a JSON object with the following structure:
        {{
            id: integer id of the entity,
            name: "name of the entity",
            entity_type: "ontological classification of the entity",
            entity_type_description: "Description of what the entity type represents",
            duplication_candidates: [
                {{
                    idx: integer index of the candidate entity,
                    name: "name of the candidate entity",
                    entity_type: "ontological classification of the candidate entity",
                    ...<additional attributes>
                }}
            ]
        }}

        <ENTITIES>
        {entities}
        </ENTITIES>

        For each of the above ENTITIES, determine if the entity is a duplicate of any of its duplication candidates.

        Entities should only be considered duplicates if they refer to the *same real-world object or concept*.

        Do NOT mark entities as duplicates if:
        - They are related but distinct.
        - They have similar names or purposes but refer to separate instances or concepts.

        Task:
        Your response will be a list called entity_resolutions which contains one entry for each entity.

        For each entity, return the id of the entity as id, the name of the entity as name, and the duplicate_idx
        as an integer.

        - If an entity is a duplicate of one of its duplication_candidates, return the idx of the candidate it is a
        duplicate of.
        - If an entity is not a duplicate of one of its duplication candidates, return -1 as the duplicate_idx
        """,
    )

    return [system_message, user_message]
|
||||||
|
|
||||||
|
|
||||||
def node_list(context: dict[str, Any]) -> list[Message]:
|
def node_list(context: dict[str, Any]) -> list[Message]:
|
||||||
return [
|
return [
|
||||||
Message(
|
Message(
|
||||||
|
|
@ -126,4 +197,4 @@ def node_list(context: dict[str, Any]) -> list[Message]:
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
versions: Versions = {'node': node, 'node_list': node_list}
|
versions: Versions = {'node': node, 'node_list': node_list, 'nodes': nodes}
|
||||||
|
|
|
||||||
|
|
@ -24,7 +24,7 @@ from .models import Message, PromptFunction, PromptVersion
|
||||||
class InvalidatedEdges(BaseModel):
|
class InvalidatedEdges(BaseModel):
|
||||||
contradicted_facts: list[int] = Field(
|
contradicted_facts: list[int] = Field(
|
||||||
...,
|
...,
|
||||||
description='List of ids of facts that be should invalidated. If no facts should be invalidated, the list should be empty.',
|
description='List of ids of facts that should be invalidated. If no facts should be invalidated, the list should be empty.',
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -35,9 +35,6 @@ from graphiti_core.prompts.extract_edges import ExtractedEdges, MissingFacts
|
||||||
from graphiti_core.search.search_filters import SearchFilters
|
from graphiti_core.search.search_filters import SearchFilters
|
||||||
from graphiti_core.search.search_utils import get_edge_invalidation_candidates, get_relevant_edges
|
from graphiti_core.search.search_utils import get_edge_invalidation_candidates, get_relevant_edges
|
||||||
from graphiti_core.utils.datetime_utils import ensure_utc, utc_now
|
from graphiti_core.utils.datetime_utils import ensure_utc, utc_now
|
||||||
from graphiti_core.utils.maintenance.temporal_operations import (
|
|
||||||
get_edge_contradictions,
|
|
||||||
)
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
@ -245,7 +242,7 @@ async def resolve_extracted_edges(
|
||||||
|
|
||||||
search_results: tuple[list[list[EntityEdge]], list[list[EntityEdge]]] = await semaphore_gather(
|
search_results: tuple[list[list[EntityEdge]], list[list[EntityEdge]]] = await semaphore_gather(
|
||||||
get_relevant_edges(driver, extracted_edges, SearchFilters()),
|
get_relevant_edges(driver, extracted_edges, SearchFilters()),
|
||||||
get_edge_invalidation_candidates(driver, extracted_edges, SearchFilters()),
|
get_edge_invalidation_candidates(driver, extracted_edges, SearchFilters(), 0.2),
|
||||||
)
|
)
|
||||||
|
|
||||||
related_edges_lists, edge_invalidation_candidates = search_results
|
related_edges_lists, edge_invalidation_candidates = search_results
|
||||||
|
|
@ -325,11 +322,52 @@ async def resolve_extracted_edge(
|
||||||
extracted_edge: EntityEdge,
|
extracted_edge: EntityEdge,
|
||||||
related_edges: list[EntityEdge],
|
related_edges: list[EntityEdge],
|
||||||
existing_edges: list[EntityEdge],
|
existing_edges: list[EntityEdge],
|
||||||
episode: EpisodicNode,
|
episode: EpisodicNode | None = None,
|
||||||
) -> tuple[EntityEdge, list[EntityEdge]]:
|
) -> tuple[EntityEdge, list[EntityEdge]]:
|
||||||
resolved_edge, invalidation_candidates = await semaphore_gather(
|
if len(related_edges) == 0 and len(existing_edges) == 0:
|
||||||
dedupe_extracted_edge(llm_client, extracted_edge, related_edges, episode),
|
return extracted_edge, []
|
||||||
get_edge_contradictions(llm_client, extracted_edge, existing_edges),
|
|
||||||
|
start = time()
|
||||||
|
|
||||||
|
# Prepare context for LLM
|
||||||
|
related_edges_context = [
|
||||||
|
{'id': edge.uuid, 'fact': edge.fact} for i, edge in enumerate(related_edges)
|
||||||
|
]
|
||||||
|
|
||||||
|
invalidation_edge_candidates_context = [
|
||||||
|
{'id': i, 'fact': existing_edge.fact} for i, existing_edge in enumerate(existing_edges)
|
||||||
|
]
|
||||||
|
|
||||||
|
context = {
|
||||||
|
'existing_edges': related_edges_context,
|
||||||
|
'new_edge': extracted_edge.fact,
|
||||||
|
'edge_invalidation_candidates': invalidation_edge_candidates_context,
|
||||||
|
}
|
||||||
|
|
||||||
|
llm_response = await llm_client.generate_response(
|
||||||
|
prompt_library.dedupe_edges.resolve_edge(context),
|
||||||
|
response_model=EdgeDuplicate,
|
||||||
|
model_size=ModelSize.small,
|
||||||
|
)
|
||||||
|
|
||||||
|
duplicate_fact_id: int = llm_response.get('duplicate_fact_id', -1)
|
||||||
|
|
||||||
|
resolved_edge = (
|
||||||
|
related_edges[duplicate_fact_id]
|
||||||
|
if 0 <= duplicate_fact_id < len(related_edges)
|
||||||
|
else extracted_edge
|
||||||
|
)
|
||||||
|
|
||||||
|
if duplicate_fact_id >= 0 and episode is not None:
|
||||||
|
resolved_edge.episodes.append(episode.uuid)
|
||||||
|
|
||||||
|
contradicted_facts: list[int] = llm_response.get('contradicted_facts', [])
|
||||||
|
|
||||||
|
invalidation_candidates: list[EntityEdge] = [existing_edges[i] for i in contradicted_facts]
|
||||||
|
|
||||||
|
end = time()
|
||||||
|
logger.debug(
|
||||||
|
f'Resolved Edge: {extracted_edge.name} is {resolved_edge.name}, in {(end - start) * 1000} ms'
|
||||||
)
|
)
|
||||||
|
|
||||||
now = utc_now()
|
now = utc_now()
|
||||||
|
|
|
||||||
|
|
@ -29,7 +29,7 @@ from graphiti_core.llm_client import LLMClient
|
||||||
from graphiti_core.llm_client.config import ModelSize
|
from graphiti_core.llm_client.config import ModelSize
|
||||||
from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode, create_entity_node_embeddings
|
from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode, create_entity_node_embeddings
|
||||||
from graphiti_core.prompts import prompt_library
|
from graphiti_core.prompts import prompt_library
|
||||||
from graphiti_core.prompts.dedupe_nodes import NodeDuplicate
|
from graphiti_core.prompts.dedupe_nodes import NodeDuplicate, NodeResolutions
|
||||||
from graphiti_core.prompts.extract_nodes import (
|
from graphiti_core.prompts.extract_nodes import (
|
||||||
ExtractedEntities,
|
ExtractedEntities,
|
||||||
ExtractedEntity,
|
ExtractedEntity,
|
||||||
|
|
@ -243,28 +243,65 @@ async def resolve_extracted_nodes(
|
||||||
|
|
||||||
existing_nodes_lists: list[list[EntityNode]] = [result.nodes for result in search_results]
|
existing_nodes_lists: list[list[EntityNode]] = [result.nodes for result in search_results]
|
||||||
|
|
||||||
resolved_nodes: list[EntityNode] = await semaphore_gather(
|
entity_types_dict: dict[str, BaseModel] = entity_types if entity_types is not None else {}
|
||||||
*[
|
|
||||||
resolve_extracted_node(
|
# Prepare context for LLM
|
||||||
llm_client,
|
extracted_nodes_context = [
|
||||||
extracted_node,
|
{
|
||||||
existing_nodes,
|
'id': i,
|
||||||
episode,
|
'name': node.name,
|
||||||
previous_episodes,
|
'entity_type': node.labels,
|
||||||
entity_types.get(
|
'entity_type_description': entity_types_dict.get(
|
||||||
next((item for item in extracted_node.labels if item != 'Entity'), '')
|
next((item for item in node.labels if item != 'Entity'), '')
|
||||||
)
|
).__doc__
|
||||||
if entity_types is not None
|
or 'Default Entity Type',
|
||||||
else None,
|
'duplication_candidates': [
|
||||||
)
|
{
|
||||||
for extracted_node, existing_nodes in zip(
|
**{
|
||||||
extracted_nodes, existing_nodes_lists, strict=True
|
'idx': j,
|
||||||
)
|
'name': candidate.name,
|
||||||
]
|
'entity_types': candidate.labels,
|
||||||
|
},
|
||||||
|
**candidate.attributes,
|
||||||
|
}
|
||||||
|
for j, candidate in enumerate(existing_nodes_lists[i])
|
||||||
|
],
|
||||||
|
}
|
||||||
|
for i, node in enumerate(extracted_nodes)
|
||||||
|
]
|
||||||
|
|
||||||
|
context = {
|
||||||
|
'extracted_nodes': extracted_nodes_context,
|
||||||
|
'episode_content': episode.content if episode is not None else '',
|
||||||
|
'previous_episodes': [ep.content for ep in previous_episodes]
|
||||||
|
if previous_episodes is not None
|
||||||
|
else [],
|
||||||
|
}
|
||||||
|
|
||||||
|
llm_response = await llm_client.generate_response(
|
||||||
|
prompt_library.dedupe_nodes.nodes(context),
|
||||||
|
response_model=NodeResolutions,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
node_resolutions: list = llm_response.get('entity_resolutions', [])
|
||||||
|
|
||||||
|
resolved_nodes: list[EntityNode] = []
|
||||||
uuid_map: dict[str, str] = {}
|
uuid_map: dict[str, str] = {}
|
||||||
for extracted_node, resolved_node in zip(extracted_nodes, resolved_nodes, strict=True):
|
for resolution in node_resolutions:
|
||||||
|
resolution_id = resolution.get('id', -1)
|
||||||
|
duplicate_idx = resolution.get('duplicate_idx', -1)
|
||||||
|
|
||||||
|
extracted_node = extracted_nodes[resolution_id]
|
||||||
|
|
||||||
|
resolved_node = (
|
||||||
|
existing_nodes_lists[resolution_id][duplicate_idx]
|
||||||
|
if 0 <= duplicate_idx < len(existing_nodes_lists[resolution_id])
|
||||||
|
else extracted_node
|
||||||
|
)
|
||||||
|
|
||||||
|
resolved_node.name = resolution.get('name')
|
||||||
|
|
||||||
|
resolved_nodes.append(resolved_node)
|
||||||
uuid_map[extracted_node.uuid] = resolved_node.uuid
|
uuid_map[extracted_node.uuid] = resolved_node.uuid
|
||||||
|
|
||||||
logger.debug(f'Resolved nodes: {[(n.name, n.uuid) for n in resolved_nodes]}')
|
logger.debug(f'Resolved nodes: {[(n.name, n.uuid) for n in resolved_nodes]}')
|
||||||
|
|
@ -410,6 +447,7 @@ async def extract_attributes_from_node(
|
||||||
llm_response = await llm_client.generate_response(
|
llm_response = await llm_client.generate_response(
|
||||||
prompt_library.extract_nodes.extract_attributes(summary_context),
|
prompt_library.extract_nodes.extract_attributes(summary_context),
|
||||||
response_model=entity_attributes_model,
|
response_model=entity_attributes_model,
|
||||||
|
model_size=ModelSize.small,
|
||||||
)
|
)
|
||||||
|
|
||||||
node.summary = llm_response.get('summary', node.summary)
|
node.summary = llm_response.get('summary', node.summary)
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,7 @@
|
||||||
[project]
|
[project]
|
||||||
name = "graphiti-core"
|
name = "graphiti-core"
|
||||||
description = "A temporal graph building library"
|
description = "A temporal graph building library"
|
||||||
version = "0.11.6pre9"
|
version = "0.11.6"
|
||||||
authors = [
|
authors = [
|
||||||
{ "name" = "Paul Paliychuk", "email" = "paul@getzep.com" },
|
{ "name" = "Paul Paliychuk", "email" = "paul@getzep.com" },
|
||||||
{ "name" = "Preston Rasmussen", "email" = "preston@getzep.com" },
|
{ "name" = "Preston Rasmussen", "email" = "preston@getzep.com" },
|
||||||
|
|
|
||||||
|
|
@ -1,12 +1,10 @@
|
||||||
from datetime import datetime, timedelta, timezone
|
from datetime import datetime, timedelta, timezone
|
||||||
from unittest.mock import AsyncMock, MagicMock
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from pytest import MonkeyPatch
|
|
||||||
|
|
||||||
from graphiti_core.edges import EntityEdge
|
from graphiti_core.edges import EntityEdge
|
||||||
from graphiti_core.nodes import EpisodicNode
|
from graphiti_core.nodes import EpisodicNode
|
||||||
from graphiti_core.utils.maintenance.edge_operations import resolve_extracted_edge
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
|
|
@ -91,96 +89,6 @@ def mock_previous_episodes():
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_resolve_extracted_edge_no_changes(
|
|
||||||
mock_llm_client,
|
|
||||||
mock_extracted_edge,
|
|
||||||
mock_related_edges,
|
|
||||||
mock_existing_edges,
|
|
||||||
mock_current_episode,
|
|
||||||
mock_previous_episodes,
|
|
||||||
monkeypatch: MonkeyPatch,
|
|
||||||
):
|
|
||||||
# Mock the function calls
|
|
||||||
dedupe_mock = AsyncMock(return_value=mock_extracted_edge)
|
|
||||||
get_contradictions_mock = AsyncMock(return_value=[])
|
|
||||||
|
|
||||||
# Patch the function calls
|
|
||||||
monkeypatch.setattr(
|
|
||||||
'graphiti_core.utils.maintenance.edge_operations.dedupe_extracted_edge', dedupe_mock
|
|
||||||
)
|
|
||||||
monkeypatch.setattr(
|
|
||||||
'graphiti_core.utils.maintenance.edge_operations.get_edge_contradictions',
|
|
||||||
get_contradictions_mock,
|
|
||||||
)
|
|
||||||
|
|
||||||
resolved_edge, invalidated_edges = await resolve_extracted_edge(
|
|
||||||
mock_llm_client,
|
|
||||||
mock_extracted_edge,
|
|
||||||
mock_related_edges,
|
|
||||||
mock_existing_edges,
|
|
||||||
mock_current_episode,
|
|
||||||
)
|
|
||||||
|
|
||||||
assert resolved_edge.uuid == mock_extracted_edge.uuid
|
|
||||||
assert invalidated_edges == []
|
|
||||||
dedupe_mock.assert_called_once()
|
|
||||||
get_contradictions_mock.assert_called_once()
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_resolve_extracted_edge_with_invalidation(
|
|
||||||
mock_llm_client,
|
|
||||||
mock_extracted_edge,
|
|
||||||
mock_related_edges,
|
|
||||||
mock_existing_edges,
|
|
||||||
mock_current_episode,
|
|
||||||
mock_previous_episodes,
|
|
||||||
monkeypatch: MonkeyPatch,
|
|
||||||
):
|
|
||||||
valid_at = datetime.now(timezone.utc) - timedelta(days=1)
|
|
||||||
mock_extracted_edge.valid_at = valid_at
|
|
||||||
|
|
||||||
invalidation_candidate = EntityEdge(
|
|
||||||
source_node_uuid='source_uuid_4',
|
|
||||||
target_node_uuid='target_uuid_4',
|
|
||||||
name='invalidation_candidate',
|
|
||||||
group_id='group_1',
|
|
||||||
fact='Invalidation candidate fact',
|
|
||||||
episodes=['episode_4'],
|
|
||||||
created_at=datetime.now(timezone.utc),
|
|
||||||
valid_at=datetime.now(timezone.utc) - timedelta(days=2),
|
|
||||||
invalid_at=None,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Mock the function calls
|
|
||||||
dedupe_mock = AsyncMock(return_value=mock_extracted_edge)
|
|
||||||
get_contradictions_mock = AsyncMock(return_value=[invalidation_candidate])
|
|
||||||
|
|
||||||
# Patch the function calls
|
|
||||||
monkeypatch.setattr(
|
|
||||||
'graphiti_core.utils.maintenance.edge_operations.dedupe_extracted_edge', dedupe_mock
|
|
||||||
)
|
|
||||||
monkeypatch.setattr(
|
|
||||||
'graphiti_core.utils.maintenance.edge_operations.get_edge_contradictions',
|
|
||||||
get_contradictions_mock,
|
|
||||||
)
|
|
||||||
|
|
||||||
resolved_edge, invalidated_edges = await resolve_extracted_edge(
|
|
||||||
mock_llm_client,
|
|
||||||
mock_extracted_edge,
|
|
||||||
mock_related_edges,
|
|
||||||
mock_existing_edges,
|
|
||||||
mock_current_episode,
|
|
||||||
)
|
|
||||||
|
|
||||||
assert resolved_edge.uuid == mock_extracted_edge.uuid
|
|
||||||
assert len(invalidated_edges) == 1
|
|
||||||
assert invalidated_edges[0].uuid == invalidation_candidate.uuid
|
|
||||||
assert invalidated_edges[0].invalid_at == valid_at
|
|
||||||
assert invalidated_edges[0].expired_at is not None
|
|
||||||
|
|
||||||
|
|
||||||
# Run the tests
|
# Run the tests
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
pytest.main([__file__])
|
pytest.main([__file__])
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue