Edge types (#501)
* update entity edge attributes * Adding prompts * extract fact attributes * edge types * edge types no regressions * mypy * mypy update * Update graphiti_core/prompts/dedupe_edges.py Co-authored-by: ellipsis-dev[bot] <65095814+ellipsis-dev[bot]@users.noreply.github.com> * Update graphiti_core/prompts/dedupe_edges.py Co-authored-by: ellipsis-dev[bot] <65095814+ellipsis-dev[bot]@users.noreply.github.com> * mypy --------- Co-authored-by: ellipsis-dev[bot] <65095814+ellipsis-dev[bot]@users.noreply.github.com>
This commit is contained in:
parent
619c84e98b
commit
db7595fe63
8 changed files with 252 additions and 38 deletions
|
|
@ -49,7 +49,9 @@ ENTITY_EDGE_RETURN: LiteralString = """
|
|||
e.episodes AS episodes,
|
||||
e.expired_at AS expired_at,
|
||||
e.valid_at AS valid_at,
|
||||
e.invalid_at AS invalid_at"""
|
||||
e.invalid_at AS invalid_at,
|
||||
properties(e) AS attributes
|
||||
"""
|
||||
|
||||
|
||||
class Edge(BaseModel, ABC):
|
||||
|
|
@ -209,6 +211,9 @@ class EntityEdge(Edge):
|
|||
invalid_at: datetime | None = Field(
|
||||
default=None, description='datetime of when the fact stopped being true'
|
||||
)
|
||||
attributes: dict[str, Any] = Field(
|
||||
default={}, description='Additional attributes of the edge. Dependent on edge name'
|
||||
)
|
||||
|
||||
async def generate_embedding(self, embedder: EmbedderClient):
|
||||
start = time()
|
||||
|
|
@ -236,20 +241,26 @@ class EntityEdge(Edge):
|
|||
self.fact_embedding = records[0]['fact_embedding']
|
||||
|
||||
async def save(self, driver: AsyncDriver):
|
||||
edge_data: dict[str, Any] = {
|
||||
'source_uuid': self.source_node_uuid,
|
||||
'target_uuid': self.target_node_uuid,
|
||||
'uuid': self.uuid,
|
||||
'name': self.name,
|
||||
'group_id': self.group_id,
|
||||
'fact': self.fact,
|
||||
'fact_embedding': self.fact_embedding,
|
||||
'episodes': self.episodes,
|
||||
'created_at': self.created_at,
|
||||
'expired_at': self.expired_at,
|
||||
'valid_at': self.valid_at,
|
||||
'invalid_at': self.invalid_at,
|
||||
}
|
||||
|
||||
edge_data.update(self.attributes or {})
|
||||
|
||||
result = await driver.execute_query(
|
||||
ENTITY_EDGE_SAVE,
|
||||
source_uuid=self.source_node_uuid,
|
||||
target_uuid=self.target_node_uuid,
|
||||
uuid=self.uuid,
|
||||
name=self.name,
|
||||
group_id=self.group_id,
|
||||
fact=self.fact,
|
||||
fact_embedding=self.fact_embedding,
|
||||
episodes=self.episodes,
|
||||
created_at=self.created_at,
|
||||
expired_at=self.expired_at,
|
||||
valid_at=self.valid_at,
|
||||
invalid_at=self.invalid_at,
|
||||
edge_data=edge_data,
|
||||
database_=DEFAULT_DATABASE,
|
||||
)
|
||||
|
||||
|
|
@ -334,8 +345,8 @@ class EntityEdge(Edge):
|
|||
async def get_by_node_uuid(cls, driver: AsyncDriver, node_uuid: str):
|
||||
query: LiteralString = (
|
||||
"""
|
||||
MATCH (n:Entity {uuid: $node_uuid})-[e:RELATES_TO]-(m:Entity)
|
||||
"""
|
||||
MATCH (n:Entity {uuid: $node_uuid})-[e:RELATES_TO]-(m:Entity)
|
||||
"""
|
||||
+ ENTITY_EDGE_RETURN
|
||||
)
|
||||
records, _, _ = await driver.execute_query(
|
||||
|
|
@ -457,7 +468,7 @@ def get_episodic_edge_from_record(record: Any) -> EpisodicEdge:
|
|||
|
||||
|
||||
def get_entity_edge_from_record(record: Any) -> EntityEdge:
|
||||
return EntityEdge(
|
||||
edge = EntityEdge(
|
||||
uuid=record['uuid'],
|
||||
source_node_uuid=record['source_node_uuid'],
|
||||
target_node_uuid=record['target_node_uuid'],
|
||||
|
|
@ -469,8 +480,23 @@ def get_entity_edge_from_record(record: Any) -> EntityEdge:
|
|||
expired_at=parse_db_date(record['expired_at']),
|
||||
valid_at=parse_db_date(record['valid_at']),
|
||||
invalid_at=parse_db_date(record['invalid_at']),
|
||||
attributes=record['attributes'],
|
||||
)
|
||||
|
||||
edge.attributes.pop('uuid', None)
|
||||
edge.attributes.pop('source_node_uuid', None)
|
||||
edge.attributes.pop('target_node_uuid', None)
|
||||
edge.attributes.pop('fact', None)
|
||||
edge.attributes.pop('name', None)
|
||||
edge.attributes.pop('group_id', None)
|
||||
edge.attributes.pop('episodes', None)
|
||||
edge.attributes.pop('created_at', None)
|
||||
edge.attributes.pop('expired_at', None)
|
||||
edge.attributes.pop('valid_at', None)
|
||||
edge.attributes.pop('invalid_at', None)
|
||||
|
||||
return edge
|
||||
|
||||
|
||||
def get_community_edge_from_record(record: Any):
|
||||
return CommunityEdge(
|
||||
|
|
|
|||
|
|
@ -273,6 +273,8 @@ class Graphiti:
|
|||
update_communities: bool = False,
|
||||
entity_types: dict[str, BaseModel] | None = None,
|
||||
previous_episode_uuids: list[str] | None = None,
|
||||
edge_types: dict[str, BaseModel] | None = None,
|
||||
edge_type_map: dict[tuple[str, str], list[str]] | None = None,
|
||||
) -> AddEpisodeResults:
|
||||
"""
|
||||
Process an episode and update the graph.
|
||||
|
|
@ -355,6 +357,13 @@ class Graphiti:
|
|||
)
|
||||
)
|
||||
|
||||
# Create default edge type map
|
||||
edge_type_map_default = (
|
||||
{('Entity', 'Entity'): list(edge_types.keys())}
|
||||
if edge_types is not None
|
||||
else {('Entity', 'Entity'): []}
|
||||
)
|
||||
|
||||
# Extract entities as nodes
|
||||
|
||||
extracted_nodes = await extract_nodes(
|
||||
|
|
@ -370,7 +379,9 @@ class Graphiti:
|
|||
previous_episodes,
|
||||
entity_types,
|
||||
),
|
||||
extract_edges(self.clients, episode, extracted_nodes, previous_episodes, group_id),
|
||||
extract_edges(
|
||||
self.clients, episode, extracted_nodes, previous_episodes, group_id, edge_types
|
||||
),
|
||||
)
|
||||
|
||||
edges = resolve_edge_pointers(extracted_edges, uuid_map)
|
||||
|
|
@ -380,6 +391,9 @@ class Graphiti:
|
|||
self.clients,
|
||||
edges,
|
||||
episode,
|
||||
nodes,
|
||||
edge_types or {},
|
||||
edge_type_map or edge_type_map_default,
|
||||
),
|
||||
extract_attributes_from_nodes(
|
||||
self.clients, nodes, episode, previous_episodes, entity_types
|
||||
|
|
@ -686,7 +700,19 @@ class Graphiti:
|
|||
)[0]
|
||||
|
||||
resolved_edge, invalidated_edges = await resolve_extracted_edge(
|
||||
self.llm_client, updated_edge, related_edges, existing_edges
|
||||
self.llm_client,
|
||||
updated_edge,
|
||||
related_edges,
|
||||
existing_edges,
|
||||
EpisodicNode(
|
||||
name='',
|
||||
source=EpisodeType.text,
|
||||
source_description='',
|
||||
content='',
|
||||
valid_at=edge.valid_at or utc_now(),
|
||||
entity_edges=[],
|
||||
group_id=edge.group_id,
|
||||
),
|
||||
)
|
||||
|
||||
await add_nodes_and_edges_bulk(
|
||||
|
|
|
|||
|
|
@ -34,8 +34,7 @@ ENTITY_EDGE_SAVE = """
|
|||
MATCH (source:Entity {uuid: $source_uuid})
|
||||
MATCH (target:Entity {uuid: $target_uuid})
|
||||
MERGE (source)-[r:RELATES_TO {uuid: $uuid}]->(target)
|
||||
SET r = {uuid: $uuid, name: $name, group_id: $group_id, fact: $fact, episodes: $episodes,
|
||||
created_at: $created_at, expired_at: $expired_at, valid_at: $valid_at, invalid_at: $invalid_at}
|
||||
SET r = $edge_data
|
||||
WITH r CALL db.create.setRelationshipVectorProperty(r, "fact_embedding", $fact_embedding)
|
||||
RETURN r.uuid AS uuid"""
|
||||
|
||||
|
|
@ -44,8 +43,7 @@ ENTITY_EDGE_SAVE_BULK = """
|
|||
MATCH (source:Entity {uuid: edge.source_node_uuid})
|
||||
MATCH (target:Entity {uuid: edge.target_node_uuid})
|
||||
MERGE (source)-[r:RELATES_TO {uuid: edge.uuid}]->(target)
|
||||
SET r = {uuid: edge.uuid, name: edge.name, group_id: edge.group_id, fact: edge.fact, episodes: edge.episodes,
|
||||
created_at: edge.created_at, expired_at: edge.expired_at, valid_at: edge.valid_at, invalid_at: edge.invalid_at}
|
||||
SET r = edge
|
||||
WITH r, edge CALL db.create.setRelationshipVectorProperty(r, "fact_embedding", edge.fact_embedding)
|
||||
RETURN edge.uuid AS uuid
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -31,6 +31,7 @@ class EdgeDuplicate(BaseModel):
|
|||
...,
|
||||
description='List of ids of facts that should be invalidated. If no facts should be invalidated, the list should be empty.',
|
||||
)
|
||||
fact_type: str = Field(..., description='One of the provided fact types or DEFAULT')
|
||||
|
||||
|
||||
class UniqueFact(BaseModel):
|
||||
|
|
@ -133,11 +134,18 @@ def resolve_edge(context: dict[str, Any]) -> list[Message]:
|
|||
{context['edge_invalidation_candidates']}
|
||||
</FACT INVALIDATION CANDIDATES>
|
||||
|
||||
<FACT TYPES>
|
||||
{context['edge_types']}
|
||||
</FACT TYPES>
|
||||
|
||||
|
||||
Task:
|
||||
If the NEW FACT represents the same factual information as any fact in EXISTING FACTS, return the idx of the duplicate fact.
|
||||
If the NEW FACT is not a duplicate of any of the EXISTING FACTS, return -1.
|
||||
|
||||
Given the predefined FACT TYPES, determine if the NEW FACT should be classified as one of these types.
|
||||
Return the fact type as fact_type or DEFAULT if NEW FACT is not one of the FACT TYPES.
|
||||
|
||||
Based on the provided FACT INVALIDATION CANDIDATES and NEW FACT, determine which existing facts the new fact contradicts.
|
||||
Return a list containing all idx's of the facts that are contradicted by the NEW FACT.
|
||||
If there are no contradicted facts, return an empty list.
|
||||
|
|
|
|||
|
|
@ -48,11 +48,13 @@ class MissingFacts(BaseModel):
|
|||
class Prompt(Protocol):
|
||||
edge: PromptVersion
|
||||
reflexion: PromptVersion
|
||||
extract_attributes: PromptVersion
|
||||
|
||||
|
||||
class Versions(TypedDict):
|
||||
edge: PromptFunction
|
||||
reflexion: PromptFunction
|
||||
extract_attributes: PromptFunction
|
||||
|
||||
|
||||
def edge(context: dict[str, Any]) -> list[Message]:
|
||||
|
|
@ -82,12 +84,18 @@ def edge(context: dict[str, Any]) -> list[Message]:
|
|||
{context['reference_time']} # ISO 8601 (UTC); used to resolve relative time mentions
|
||||
</REFERENCE_TIME>
|
||||
|
||||
<FACT TYPES>
|
||||
{context['edge_types']}
|
||||
</FACT TYPES>
|
||||
|
||||
# TASK
|
||||
Extract all factual relationships between the given ENTITIES based on the CURRENT MESSAGE.
|
||||
Only extract facts that:
|
||||
- involve two DISTINCT ENTITIES from the ENTITIES list,
|
||||
- are clearly stated or unambiguously implied in the CURRENT MESSAGE,
|
||||
- and can be represented as edges in a knowledge graph.
|
||||
and can be represented as edges in a knowledge graph.
|
||||
- The FACT TYPES provide a list of the most important types of facts, make sure to extract any facts that
|
||||
could be classified into one of the provided fact types
|
||||
|
||||
You may use information from the PREVIOUS MESSAGES only to disambiguate references or support continuity.
|
||||
|
||||
|
|
@ -145,4 +153,40 @@ determine if any facts haven't been extracted.
|
|||
]
|
||||
|
||||
|
||||
versions: Versions = {'edge': edge, 'reflexion': reflexion}
|
||||
def extract_attributes(context: dict[str, Any]) -> list[Message]:
|
||||
return [
|
||||
Message(
|
||||
role='system',
|
||||
content='You are a helpful assistant that extracts fact properties from the provided text.',
|
||||
),
|
||||
Message(
|
||||
role='user',
|
||||
content=f"""
|
||||
|
||||
<MESSAGE>
|
||||
{json.dumps(context['episode_content'], indent=2)}
|
||||
</MESSAGE>
|
||||
<REFERENCE TIME>
|
||||
{context['reference_time']}
|
||||
</REFERENCE TIME>
|
||||
|
||||
Given the above MESSAGE, its REFERENCE TIME, and the following FACT, update any of its attributes based on the information provided
|
||||
in MESSAGE. Use the provided attribute descriptions to better understand how each attribute should be determined.
|
||||
|
||||
Guidelines:
|
||||
1. Do not hallucinate entity property values if they cannot be found in the current context.
|
||||
2. Only use the provided MESSAGES and FACT to set attribute values.
|
||||
|
||||
<FACT>
|
||||
{context['fact']}
|
||||
</FACT>
|
||||
""",
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
versions: Versions = {
|
||||
'edge': edge,
|
||||
'reflexion': reflexion,
|
||||
'extract_attributes': extract_attributes,
|
||||
}
|
||||
|
|
|
|||
|
|
@ -174,7 +174,8 @@ async def edge_fulltext_search(
|
|||
r.episodes AS episodes,
|
||||
r.expired_at AS expired_at,
|
||||
r.valid_at AS valid_at,
|
||||
r.invalid_at AS invalid_at
|
||||
r.invalid_at AS invalid_at,
|
||||
properties(r) AS attributes
|
||||
ORDER BY score DESC LIMIT $limit
|
||||
"""
|
||||
)
|
||||
|
|
@ -243,7 +244,8 @@ async def edge_similarity_search(
|
|||
r.episodes AS episodes,
|
||||
r.expired_at AS expired_at,
|
||||
r.valid_at AS valid_at,
|
||||
r.invalid_at AS invalid_at
|
||||
r.invalid_at AS invalid_at,
|
||||
properties(r) AS attributes
|
||||
ORDER BY score DESC
|
||||
LIMIT $limit
|
||||
"""
|
||||
|
|
@ -301,7 +303,8 @@ async def edge_bfs_search(
|
|||
r.episodes AS episodes,
|
||||
r.expired_at AS expired_at,
|
||||
r.valid_at AS valid_at,
|
||||
r.invalid_at AS invalid_at
|
||||
r.invalid_at AS invalid_at,
|
||||
properties(r) AS attributes
|
||||
LIMIT $limit
|
||||
"""
|
||||
)
|
||||
|
|
@ -337,10 +340,10 @@ async def node_fulltext_search(
|
|||
|
||||
query = (
|
||||
"""
|
||||
CALL db.index.fulltext.queryNodes("node_name_and_summary", $query, {limit: $limit})
|
||||
YIELD node AS n, score
|
||||
WHERE n:Entity
|
||||
"""
|
||||
CALL db.index.fulltext.queryNodes("node_name_and_summary", $query, {limit: $limit})
|
||||
YIELD node AS n, score
|
||||
WHERE n:Entity
|
||||
"""
|
||||
+ filter_query
|
||||
+ ENTITY_NODE_RETURN
|
||||
+ """
|
||||
|
|
@ -771,7 +774,8 @@ async def get_relevant_edges(
|
|||
episodes: e.episodes,
|
||||
expired_at: e.expired_at,
|
||||
valid_at: e.valid_at,
|
||||
invalid_at: e.invalid_at
|
||||
invalid_at: e.invalid_at,
|
||||
attributes: properties(e)
|
||||
})[..$limit] AS matches
|
||||
"""
|
||||
)
|
||||
|
|
@ -837,7 +841,8 @@ async def get_edge_invalidation_candidates(
|
|||
episodes: e.episodes,
|
||||
expired_at: e.expired_at,
|
||||
valid_at: e.valid_at,
|
||||
invalid_at: e.invalid_at
|
||||
invalid_at: e.invalid_at,
|
||||
attributes: properties(e)
|
||||
})[..$limit] AS matches
|
||||
"""
|
||||
)
|
||||
|
|
|
|||
|
|
@ -137,16 +137,34 @@ async def add_nodes_and_edges_bulk_tx(
|
|||
entity_data['labels'] = list(set(node.labels + ['Entity']))
|
||||
nodes.append(entity_data)
|
||||
|
||||
edges: list[dict[str, Any]] = []
|
||||
for edge in entity_edges:
|
||||
if edge.fact_embedding is None:
|
||||
await edge.generate_embedding(embedder)
|
||||
edge_data: dict[str, Any] = {
|
||||
'uuid': edge.uuid,
|
||||
'source_node_uuid': edge.source_node_uuid,
|
||||
'target_node_uuid': edge.target_node_uuid,
|
||||
'name': edge.name,
|
||||
'fact': edge.fact,
|
||||
'fact_embedding': edge.fact_embedding,
|
||||
'group_id': edge.group_id,
|
||||
'episodes': edge.episodes,
|
||||
'created_at': edge.created_at,
|
||||
'expired_at': edge.expired_at,
|
||||
'valid_at': edge.valid_at,
|
||||
'invalid_at': edge.invalid_at,
|
||||
}
|
||||
|
||||
edge_data.update(edge.attributes or {})
|
||||
edges.append(edge_data)
|
||||
|
||||
await tx.run(EPISODIC_NODE_SAVE_BULK, episodes=episodes)
|
||||
await tx.run(ENTITY_NODE_SAVE_BULK, nodes=nodes)
|
||||
await tx.run(
|
||||
EPISODIC_EDGE_SAVE_BULK, episodic_edges=[edge.model_dump() for edge in episodic_edges]
|
||||
)
|
||||
await tx.run(ENTITY_EDGE_SAVE_BULK, entity_edges=[edge.model_dump() for edge in entity_edges])
|
||||
await tx.run(ENTITY_EDGE_SAVE_BULK, entity_edges=edges)
|
||||
|
||||
|
||||
async def extract_nodes_and_edges_bulk(
|
||||
|
|
|
|||
|
|
@ -18,6 +18,8 @@ import logging
|
|||
from datetime import datetime
|
||||
from time import time
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from graphiti_core.edges import (
|
||||
CommunityEdge,
|
||||
EntityEdge,
|
||||
|
|
@ -83,6 +85,7 @@ async def extract_edges(
|
|||
nodes: list[EntityNode],
|
||||
previous_episodes: list[EpisodicNode],
|
||||
group_id: str = '',
|
||||
edge_types: dict[str, BaseModel] | None = None,
|
||||
) -> list[EntityEdge]:
|
||||
start = time()
|
||||
|
||||
|
|
@ -91,12 +94,25 @@ async def extract_edges(
|
|||
|
||||
node_uuids_by_name_map = {node.name: node.uuid for node in nodes}
|
||||
|
||||
edge_types_context = (
|
||||
[
|
||||
{
|
||||
'fact_type_name': type_name,
|
||||
'fact_type_description': type_model.__doc__,
|
||||
}
|
||||
for type_name, type_model in edge_types.items()
|
||||
]
|
||||
if edge_types is not None
|
||||
else []
|
||||
)
|
||||
|
||||
# Prepare context for LLM
|
||||
context = {
|
||||
'episode_content': episode.content,
|
||||
'nodes': [node.name for node in nodes],
|
||||
'previous_episodes': [ep.content for ep in previous_episodes],
|
||||
'reference_time': episode.valid_at,
|
||||
'edge_types': edge_types_context,
|
||||
'custom_prompt': '',
|
||||
}
|
||||
|
||||
|
|
@ -233,6 +249,9 @@ async def resolve_extracted_edges(
|
|||
clients: GraphitiClients,
|
||||
extracted_edges: list[EntityEdge],
|
||||
episode: EpisodicNode,
|
||||
entities: list[EntityNode],
|
||||
edge_types: dict[str, BaseModel],
|
||||
edge_type_map: dict[tuple[str, str], list[str]],
|
||||
) -> tuple[list[EntityEdge], list[EntityEdge]]:
|
||||
driver = clients.driver
|
||||
llm_client = clients.llm_client
|
||||
|
|
@ -251,15 +270,50 @@ async def resolve_extracted_edges(
|
|||
f'Related edges lists: {[(e.name, e.uuid) for edges_lst in related_edges_lists for e in edges_lst]}'
|
||||
)
|
||||
|
||||
# Build entity hash table
|
||||
uuid_entity_map: dict[str, EntityNode] = {entity.uuid: entity for entity in entities}
|
||||
|
||||
# Determine which edge types are relevant for each edge
|
||||
edge_types_lst: list[dict[str, BaseModel]] = []
|
||||
for extracted_edge in extracted_edges:
|
||||
source_node_labels = uuid_entity_map[extracted_edge.source_node_uuid].labels
|
||||
target_node_labels = uuid_entity_map[extracted_edge.target_node_uuid].labels
|
||||
label_tuples = [
|
||||
(source_label, target_label)
|
||||
for source_label in source_node_labels
|
||||
for target_label in target_node_labels
|
||||
]
|
||||
|
||||
extracted_edge_types = {}
|
||||
for label_tuple in label_tuples:
|
||||
type_names = edge_type_map.get(label_tuple, [])
|
||||
for type_name in type_names:
|
||||
type_model = edge_types.get(type_name)
|
||||
if type_model is None:
|
||||
continue
|
||||
|
||||
extracted_edge_types[type_name] = type_model
|
||||
|
||||
edge_types_lst.append(extracted_edge_types)
|
||||
|
||||
# resolve edges with related edges in the graph and find invalidation candidates
|
||||
results: list[tuple[EntityEdge, list[EntityEdge]]] = list(
|
||||
await semaphore_gather(
|
||||
*[
|
||||
resolve_extracted_edge(
|
||||
llm_client, extracted_edge, related_edges, existing_edges, episode
|
||||
llm_client,
|
||||
extracted_edge,
|
||||
related_edges,
|
||||
existing_edges,
|
||||
episode,
|
||||
extracted_edge_types,
|
||||
)
|
||||
for extracted_edge, related_edges, existing_edges in zip(
|
||||
extracted_edges, related_edges_lists, edge_invalidation_candidates, strict=True
|
||||
for extracted_edge, related_edges, existing_edges, extracted_edge_types in zip(
|
||||
extracted_edges,
|
||||
related_edges_lists,
|
||||
edge_invalidation_candidates,
|
||||
edge_types_lst,
|
||||
strict=True,
|
||||
)
|
||||
]
|
||||
)
|
||||
|
|
@ -322,7 +376,8 @@ async def resolve_extracted_edge(
|
|||
extracted_edge: EntityEdge,
|
||||
related_edges: list[EntityEdge],
|
||||
existing_edges: list[EntityEdge],
|
||||
episode: EpisodicNode | None = None,
|
||||
episode: EpisodicNode,
|
||||
edge_types: dict[str, BaseModel] | None = None,
|
||||
) -> tuple[EntityEdge, list[EntityEdge]]:
|
||||
if len(related_edges) == 0 and len(existing_edges) == 0:
|
||||
return extracted_edge, []
|
||||
|
|
@ -338,10 +393,24 @@ async def resolve_extracted_edge(
|
|||
{'id': i, 'fact': existing_edge.fact} for i, existing_edge in enumerate(existing_edges)
|
||||
]
|
||||
|
||||
edge_types_context = (
|
||||
[
|
||||
{
|
||||
'fact_type_id': i,
|
||||
'fact_type_name': type_name,
|
||||
'fact_type_description': type_model.__doc__,
|
||||
}
|
||||
for i, (type_name, type_model) in enumerate(edge_types.items())
|
||||
]
|
||||
if edge_types is not None
|
||||
else []
|
||||
)
|
||||
|
||||
context = {
|
||||
'existing_edges': related_edges_context,
|
||||
'new_edge': extracted_edge.fact,
|
||||
'edge_invalidation_candidates': invalidation_edge_candidates_context,
|
||||
'edge_types': edge_types_context,
|
||||
}
|
||||
|
||||
llm_response = await llm_client.generate_response(
|
||||
|
|
@ -365,6 +434,26 @@ async def resolve_extracted_edge(
|
|||
|
||||
invalidation_candidates: list[EntityEdge] = [existing_edges[i] for i in contradicted_facts]
|
||||
|
||||
fact_type: str = str(llm_response.get('fact_type'))
|
||||
if fact_type.upper() != 'DEFAULT' and edge_types is not None:
|
||||
resolved_edge.name = fact_type
|
||||
|
||||
edge_attributes_context = {
|
||||
'message': episode.content,
|
||||
'reference_time': episode.valid_at,
|
||||
'fact': resolved_edge.fact,
|
||||
}
|
||||
|
||||
edge_model = edge_types.get(fact_type)
|
||||
|
||||
edge_attributes_response = await llm_client.generate_response(
|
||||
prompt_library.extract_edges.extract_attributes(edge_attributes_context),
|
||||
response_model=edge_model, # type: ignore
|
||||
model_size=ModelSize.small,
|
||||
)
|
||||
|
||||
resolved_edge.attributes = edge_attributes_response
|
||||
|
||||
end = time()
|
||||
logger.debug(
|
||||
f'Resolved Edge: {extracted_edge.name} is {resolved_edge.name}, in {(end - start) * 1000} ms'
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue