Edge types (#501)
* update entity edge attributes
* Adding prompts
* extract fact attributes
* edge types
* edge types no regressions
* mypy
* mypy update
* Update graphiti_core/prompts/dedupe_edges.py
  Co-authored-by: ellipsis-dev[bot] <65095814+ellipsis-dev[bot]@users.noreply.github.com>
* Update graphiti_core/prompts/dedupe_edges.py
  Co-authored-by: ellipsis-dev[bot] <65095814+ellipsis-dev[bot]@users.noreply.github.com>
* mypy

---------

Co-authored-by: ellipsis-dev[bot] <65095814+ellipsis-dev[bot]@users.noreply.github.com>
This commit is contained in:
parent
619c84e98b
commit
db7595fe63
8 changed files with 252 additions and 38 deletions
|
|
@ -49,7 +49,9 @@ ENTITY_EDGE_RETURN: LiteralString = """
|
||||||
e.episodes AS episodes,
|
e.episodes AS episodes,
|
||||||
e.expired_at AS expired_at,
|
e.expired_at AS expired_at,
|
||||||
e.valid_at AS valid_at,
|
e.valid_at AS valid_at,
|
||||||
e.invalid_at AS invalid_at"""
|
e.invalid_at AS invalid_at,
|
||||||
|
properties(e) AS attributes
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
class Edge(BaseModel, ABC):
|
class Edge(BaseModel, ABC):
|
||||||
|
|
@ -209,6 +211,9 @@ class EntityEdge(Edge):
|
||||||
invalid_at: datetime | None = Field(
|
invalid_at: datetime | None = Field(
|
||||||
default=None, description='datetime of when the fact stopped being true'
|
default=None, description='datetime of when the fact stopped being true'
|
||||||
)
|
)
|
||||||
|
attributes: dict[str, Any] = Field(
|
||||||
|
default={}, description='Additional attributes of the edge. Dependent on edge name'
|
||||||
|
)
|
||||||
|
|
||||||
async def generate_embedding(self, embedder: EmbedderClient):
|
async def generate_embedding(self, embedder: EmbedderClient):
|
||||||
start = time()
|
start = time()
|
||||||
|
|
@ -236,20 +241,26 @@ class EntityEdge(Edge):
|
||||||
self.fact_embedding = records[0]['fact_embedding']
|
self.fact_embedding = records[0]['fact_embedding']
|
||||||
|
|
||||||
async def save(self, driver: AsyncDriver):
|
async def save(self, driver: AsyncDriver):
|
||||||
|
edge_data: dict[str, Any] = {
|
||||||
|
'source_uuid': self.source_node_uuid,
|
||||||
|
'target_uuid': self.target_node_uuid,
|
||||||
|
'uuid': self.uuid,
|
||||||
|
'name': self.name,
|
||||||
|
'group_id': self.group_id,
|
||||||
|
'fact': self.fact,
|
||||||
|
'fact_embedding': self.fact_embedding,
|
||||||
|
'episodes': self.episodes,
|
||||||
|
'created_at': self.created_at,
|
||||||
|
'expired_at': self.expired_at,
|
||||||
|
'valid_at': self.valid_at,
|
||||||
|
'invalid_at': self.invalid_at,
|
||||||
|
}
|
||||||
|
|
||||||
|
edge_data.update(self.attributes or {})
|
||||||
|
|
||||||
result = await driver.execute_query(
|
result = await driver.execute_query(
|
||||||
ENTITY_EDGE_SAVE,
|
ENTITY_EDGE_SAVE,
|
||||||
source_uuid=self.source_node_uuid,
|
edge_data=edge_data,
|
||||||
target_uuid=self.target_node_uuid,
|
|
||||||
uuid=self.uuid,
|
|
||||||
name=self.name,
|
|
||||||
group_id=self.group_id,
|
|
||||||
fact=self.fact,
|
|
||||||
fact_embedding=self.fact_embedding,
|
|
||||||
episodes=self.episodes,
|
|
||||||
created_at=self.created_at,
|
|
||||||
expired_at=self.expired_at,
|
|
||||||
valid_at=self.valid_at,
|
|
||||||
invalid_at=self.invalid_at,
|
|
||||||
database_=DEFAULT_DATABASE,
|
database_=DEFAULT_DATABASE,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -334,8 +345,8 @@ class EntityEdge(Edge):
|
||||||
async def get_by_node_uuid(cls, driver: AsyncDriver, node_uuid: str):
|
async def get_by_node_uuid(cls, driver: AsyncDriver, node_uuid: str):
|
||||||
query: LiteralString = (
|
query: LiteralString = (
|
||||||
"""
|
"""
|
||||||
MATCH (n:Entity {uuid: $node_uuid})-[e:RELATES_TO]-(m:Entity)
|
MATCH (n:Entity {uuid: $node_uuid})-[e:RELATES_TO]-(m:Entity)
|
||||||
"""
|
"""
|
||||||
+ ENTITY_EDGE_RETURN
|
+ ENTITY_EDGE_RETURN
|
||||||
)
|
)
|
||||||
records, _, _ = await driver.execute_query(
|
records, _, _ = await driver.execute_query(
|
||||||
|
|
@ -457,7 +468,7 @@ def get_episodic_edge_from_record(record: Any) -> EpisodicEdge:
|
||||||
|
|
||||||
|
|
||||||
def get_entity_edge_from_record(record: Any) -> EntityEdge:
|
def get_entity_edge_from_record(record: Any) -> EntityEdge:
|
||||||
return EntityEdge(
|
edge = EntityEdge(
|
||||||
uuid=record['uuid'],
|
uuid=record['uuid'],
|
||||||
source_node_uuid=record['source_node_uuid'],
|
source_node_uuid=record['source_node_uuid'],
|
||||||
target_node_uuid=record['target_node_uuid'],
|
target_node_uuid=record['target_node_uuid'],
|
||||||
|
|
@ -469,8 +480,23 @@ def get_entity_edge_from_record(record: Any) -> EntityEdge:
|
||||||
expired_at=parse_db_date(record['expired_at']),
|
expired_at=parse_db_date(record['expired_at']),
|
||||||
valid_at=parse_db_date(record['valid_at']),
|
valid_at=parse_db_date(record['valid_at']),
|
||||||
invalid_at=parse_db_date(record['invalid_at']),
|
invalid_at=parse_db_date(record['invalid_at']),
|
||||||
|
attributes=record['attributes'],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
edge.attributes.pop('uuid', None)
|
||||||
|
edge.attributes.pop('source_node_uuid', None)
|
||||||
|
edge.attributes.pop('target_node_uuid', None)
|
||||||
|
edge.attributes.pop('fact', None)
|
||||||
|
edge.attributes.pop('name', None)
|
||||||
|
edge.attributes.pop('group_id', None)
|
||||||
|
edge.attributes.pop('episodes', None)
|
||||||
|
edge.attributes.pop('created_at', None)
|
||||||
|
edge.attributes.pop('expired_at', None)
|
||||||
|
edge.attributes.pop('valid_at', None)
|
||||||
|
edge.attributes.pop('invalid_at', None)
|
||||||
|
|
||||||
|
return edge
|
||||||
|
|
||||||
|
|
||||||
def get_community_edge_from_record(record: Any):
|
def get_community_edge_from_record(record: Any):
|
||||||
return CommunityEdge(
|
return CommunityEdge(
|
||||||
|
|
|
||||||
|
|
@ -273,6 +273,8 @@ class Graphiti:
|
||||||
update_communities: bool = False,
|
update_communities: bool = False,
|
||||||
entity_types: dict[str, BaseModel] | None = None,
|
entity_types: dict[str, BaseModel] | None = None,
|
||||||
previous_episode_uuids: list[str] | None = None,
|
previous_episode_uuids: list[str] | None = None,
|
||||||
|
edge_types: dict[str, BaseModel] | None = None,
|
||||||
|
edge_type_map: dict[tuple[str, str], list[str]] | None = None,
|
||||||
) -> AddEpisodeResults:
|
) -> AddEpisodeResults:
|
||||||
"""
|
"""
|
||||||
Process an episode and update the graph.
|
Process an episode and update the graph.
|
||||||
|
|
@ -355,6 +357,13 @@ class Graphiti:
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Create default edge type map
|
||||||
|
edge_type_map_default = (
|
||||||
|
{('Entity', 'Entity'): list(edge_types.keys())}
|
||||||
|
if edge_types is not None
|
||||||
|
else {('Entity', 'Entity'): []}
|
||||||
|
)
|
||||||
|
|
||||||
# Extract entities as nodes
|
# Extract entities as nodes
|
||||||
|
|
||||||
extracted_nodes = await extract_nodes(
|
extracted_nodes = await extract_nodes(
|
||||||
|
|
@ -370,7 +379,9 @@ class Graphiti:
|
||||||
previous_episodes,
|
previous_episodes,
|
||||||
entity_types,
|
entity_types,
|
||||||
),
|
),
|
||||||
extract_edges(self.clients, episode, extracted_nodes, previous_episodes, group_id),
|
extract_edges(
|
||||||
|
self.clients, episode, extracted_nodes, previous_episodes, group_id, edge_types
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
edges = resolve_edge_pointers(extracted_edges, uuid_map)
|
edges = resolve_edge_pointers(extracted_edges, uuid_map)
|
||||||
|
|
@ -380,6 +391,9 @@ class Graphiti:
|
||||||
self.clients,
|
self.clients,
|
||||||
edges,
|
edges,
|
||||||
episode,
|
episode,
|
||||||
|
nodes,
|
||||||
|
edge_types or {},
|
||||||
|
edge_type_map or edge_type_map_default,
|
||||||
),
|
),
|
||||||
extract_attributes_from_nodes(
|
extract_attributes_from_nodes(
|
||||||
self.clients, nodes, episode, previous_episodes, entity_types
|
self.clients, nodes, episode, previous_episodes, entity_types
|
||||||
|
|
@ -686,7 +700,19 @@ class Graphiti:
|
||||||
)[0]
|
)[0]
|
||||||
|
|
||||||
resolved_edge, invalidated_edges = await resolve_extracted_edge(
|
resolved_edge, invalidated_edges = await resolve_extracted_edge(
|
||||||
self.llm_client, updated_edge, related_edges, existing_edges
|
self.llm_client,
|
||||||
|
updated_edge,
|
||||||
|
related_edges,
|
||||||
|
existing_edges,
|
||||||
|
EpisodicNode(
|
||||||
|
name='',
|
||||||
|
source=EpisodeType.text,
|
||||||
|
source_description='',
|
||||||
|
content='',
|
||||||
|
valid_at=edge.valid_at or utc_now(),
|
||||||
|
entity_edges=[],
|
||||||
|
group_id=edge.group_id,
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
await add_nodes_and_edges_bulk(
|
await add_nodes_and_edges_bulk(
|
||||||
|
|
|
||||||
|
|
@ -34,8 +34,7 @@ ENTITY_EDGE_SAVE = """
|
||||||
MATCH (source:Entity {uuid: $source_uuid})
|
MATCH (source:Entity {uuid: $source_uuid})
|
||||||
MATCH (target:Entity {uuid: $target_uuid})
|
MATCH (target:Entity {uuid: $target_uuid})
|
||||||
MERGE (source)-[r:RELATES_TO {uuid: $uuid}]->(target)
|
MERGE (source)-[r:RELATES_TO {uuid: $uuid}]->(target)
|
||||||
SET r = {uuid: $uuid, name: $name, group_id: $group_id, fact: $fact, episodes: $episodes,
|
SET r = $edge_data
|
||||||
created_at: $created_at, expired_at: $expired_at, valid_at: $valid_at, invalid_at: $invalid_at}
|
|
||||||
WITH r CALL db.create.setRelationshipVectorProperty(r, "fact_embedding", $fact_embedding)
|
WITH r CALL db.create.setRelationshipVectorProperty(r, "fact_embedding", $fact_embedding)
|
||||||
RETURN r.uuid AS uuid"""
|
RETURN r.uuid AS uuid"""
|
||||||
|
|
||||||
|
|
@ -44,8 +43,7 @@ ENTITY_EDGE_SAVE_BULK = """
|
||||||
MATCH (source:Entity {uuid: edge.source_node_uuid})
|
MATCH (source:Entity {uuid: edge.source_node_uuid})
|
||||||
MATCH (target:Entity {uuid: edge.target_node_uuid})
|
MATCH (target:Entity {uuid: edge.target_node_uuid})
|
||||||
MERGE (source)-[r:RELATES_TO {uuid: edge.uuid}]->(target)
|
MERGE (source)-[r:RELATES_TO {uuid: edge.uuid}]->(target)
|
||||||
SET r = {uuid: edge.uuid, name: edge.name, group_id: edge.group_id, fact: edge.fact, episodes: edge.episodes,
|
SET r = edge
|
||||||
created_at: edge.created_at, expired_at: edge.expired_at, valid_at: edge.valid_at, invalid_at: edge.invalid_at}
|
|
||||||
WITH r, edge CALL db.create.setRelationshipVectorProperty(r, "fact_embedding", edge.fact_embedding)
|
WITH r, edge CALL db.create.setRelationshipVectorProperty(r, "fact_embedding", edge.fact_embedding)
|
||||||
RETURN edge.uuid AS uuid
|
RETURN edge.uuid AS uuid
|
||||||
"""
|
"""
|
||||||
|
|
|
||||||
|
|
@ -31,6 +31,7 @@ class EdgeDuplicate(BaseModel):
|
||||||
...,
|
...,
|
||||||
description='List of ids of facts that should be invalidated. If no facts should be invalidated, the list should be empty.',
|
description='List of ids of facts that should be invalidated. If no facts should be invalidated, the list should be empty.',
|
||||||
)
|
)
|
||||||
|
fact_type: str = Field(..., description='One of the provided fact types or DEFAULT')
|
||||||
|
|
||||||
|
|
||||||
class UniqueFact(BaseModel):
|
class UniqueFact(BaseModel):
|
||||||
|
|
@ -133,11 +134,18 @@ def resolve_edge(context: dict[str, Any]) -> list[Message]:
|
||||||
{context['edge_invalidation_candidates']}
|
{context['edge_invalidation_candidates']}
|
||||||
</FACT INVALIDATION CANDIDATES>
|
</FACT INVALIDATION CANDIDATES>
|
||||||
|
|
||||||
|
<FACT TYPES>
|
||||||
|
{context['edge_types']}
|
||||||
|
</FACT TYPES>
|
||||||
|
|
||||||
|
|
||||||
Task:
|
Task:
|
||||||
If the NEW FACT represents the same factual information as any fact in EXISTING FACTS, return the idx of the duplicate fact.
|
If the NEW FACT represents the same factual information as any fact in EXISTING FACTS, return the idx of the duplicate fact.
|
||||||
If the NEW FACT is not a duplicate of any of the EXISTING FACTS, return -1.
|
If the NEW FACT is not a duplicate of any of the EXISTING FACTS, return -1.
|
||||||
|
|
||||||
|
Given the predefined FACT TYPES, determine if the NEW FACT should be classified as one of these types.
|
||||||
|
Return the fact type as fact_type or DEFAULT if NEW FACT is not one of the FACT TYPES.
|
||||||
|
|
||||||
Based on the provided FACT INVALIDATION CANDIDATES and NEW FACT, determine which existing facts the new fact contradicts.
|
Based on the provided FACT INVALIDATION CANDIDATES and NEW FACT, determine which existing facts the new fact contradicts.
|
||||||
Return a list containing all idx's of the facts that are contradicted by the NEW FACT.
|
Return a list containing all idx's of the facts that are contradicted by the NEW FACT.
|
||||||
If there are no contradicted facts, return an empty list.
|
If there are no contradicted facts, return an empty list.
|
||||||
|
|
|
||||||
|
|
@ -48,11 +48,13 @@ class MissingFacts(BaseModel):
|
||||||
class Prompt(Protocol):
|
class Prompt(Protocol):
|
||||||
edge: PromptVersion
|
edge: PromptVersion
|
||||||
reflexion: PromptVersion
|
reflexion: PromptVersion
|
||||||
|
extract_attributes: PromptVersion
|
||||||
|
|
||||||
|
|
||||||
class Versions(TypedDict):
|
class Versions(TypedDict):
|
||||||
edge: PromptFunction
|
edge: PromptFunction
|
||||||
reflexion: PromptFunction
|
reflexion: PromptFunction
|
||||||
|
extract_attributes: PromptFunction
|
||||||
|
|
||||||
|
|
||||||
def edge(context: dict[str, Any]) -> list[Message]:
|
def edge(context: dict[str, Any]) -> list[Message]:
|
||||||
|
|
@ -82,12 +84,18 @@ def edge(context: dict[str, Any]) -> list[Message]:
|
||||||
{context['reference_time']} # ISO 8601 (UTC); used to resolve relative time mentions
|
{context['reference_time']} # ISO 8601 (UTC); used to resolve relative time mentions
|
||||||
</REFERENCE_TIME>
|
</REFERENCE_TIME>
|
||||||
|
|
||||||
|
<FACT TYPES>
|
||||||
|
{context['edge_types']}
|
||||||
|
</FACT TYPES>
|
||||||
|
|
||||||
# TASK
|
# TASK
|
||||||
Extract all factual relationships between the given ENTITIES based on the CURRENT MESSAGE.
|
Extract all factual relationships between the given ENTITIES based on the CURRENT MESSAGE.
|
||||||
Only extract facts that:
|
Only extract facts that:
|
||||||
- involve two DISTINCT ENTITIES from the ENTITIES list,
|
- involve two DISTINCT ENTITIES from the ENTITIES list,
|
||||||
- are clearly stated or unambiguously implied in the CURRENT MESSAGE,
|
- are clearly stated or unambiguously implied in the CURRENT MESSAGE,
|
||||||
- and can be represented as edges in a knowledge graph.
|
and can be represented as edges in a knowledge graph.
|
||||||
|
- The FACT TYPES provide a list of the most important types of facts, make sure to extract any facts that
|
||||||
|
could be classified into one of the provided fact types
|
||||||
|
|
||||||
You may use information from the PREVIOUS MESSAGES only to disambiguate references or support continuity.
|
You may use information from the PREVIOUS MESSAGES only to disambiguate references or support continuity.
|
||||||
|
|
||||||
|
|
@ -145,4 +153,40 @@ determine if any facts haven't been extracted.
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
versions: Versions = {'edge': edge, 'reflexion': reflexion}
|
def extract_attributes(context: dict[str, Any]) -> list[Message]:
|
||||||
|
return [
|
||||||
|
Message(
|
||||||
|
role='system',
|
||||||
|
content='You are a helpful assistant that extracts fact properties from the provided text.',
|
||||||
|
),
|
||||||
|
Message(
|
||||||
|
role='user',
|
||||||
|
content=f"""
|
||||||
|
|
||||||
|
<MESSAGE>
|
||||||
|
{json.dumps(context['episode_content'], indent=2)}
|
||||||
|
</MESSAGE>
|
||||||
|
<REFERENCE TIME>
|
||||||
|
{context['reference_time']}
|
||||||
|
</REFERENCE TIME>
|
||||||
|
|
||||||
|
Given the above MESSAGE, its REFERENCE TIME, and the following FACT, update any of its attributes based on the information provided
|
||||||
|
in MESSAGE. Use the provided attribute descriptions to better understand how each attribute should be determined.
|
||||||
|
|
||||||
|
Guidelines:
|
||||||
|
1. Do not hallucinate entity property values if they cannot be found in the current context.
|
||||||
|
2. Only use the provided MESSAGES and FACT to set attribute values.
|
||||||
|
|
||||||
|
<FACT>
|
||||||
|
{context['fact']}
|
||||||
|
</FACT>
|
||||||
|
""",
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
versions: Versions = {
|
||||||
|
'edge': edge,
|
||||||
|
'reflexion': reflexion,
|
||||||
|
'extract_attributes': extract_attributes,
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -174,7 +174,8 @@ async def edge_fulltext_search(
|
||||||
r.episodes AS episodes,
|
r.episodes AS episodes,
|
||||||
r.expired_at AS expired_at,
|
r.expired_at AS expired_at,
|
||||||
r.valid_at AS valid_at,
|
r.valid_at AS valid_at,
|
||||||
r.invalid_at AS invalid_at
|
r.invalid_at AS invalid_at,
|
||||||
|
properties(r) AS attributes
|
||||||
ORDER BY score DESC LIMIT $limit
|
ORDER BY score DESC LIMIT $limit
|
||||||
"""
|
"""
|
||||||
)
|
)
|
||||||
|
|
@ -243,7 +244,8 @@ async def edge_similarity_search(
|
||||||
r.episodes AS episodes,
|
r.episodes AS episodes,
|
||||||
r.expired_at AS expired_at,
|
r.expired_at AS expired_at,
|
||||||
r.valid_at AS valid_at,
|
r.valid_at AS valid_at,
|
||||||
r.invalid_at AS invalid_at
|
r.invalid_at AS invalid_at,
|
||||||
|
properties(r) AS attributes
|
||||||
ORDER BY score DESC
|
ORDER BY score DESC
|
||||||
LIMIT $limit
|
LIMIT $limit
|
||||||
"""
|
"""
|
||||||
|
|
@ -301,7 +303,8 @@ async def edge_bfs_search(
|
||||||
r.episodes AS episodes,
|
r.episodes AS episodes,
|
||||||
r.expired_at AS expired_at,
|
r.expired_at AS expired_at,
|
||||||
r.valid_at AS valid_at,
|
r.valid_at AS valid_at,
|
||||||
r.invalid_at AS invalid_at
|
r.invalid_at AS invalid_at,
|
||||||
|
properties(r) AS attributes
|
||||||
LIMIT $limit
|
LIMIT $limit
|
||||||
"""
|
"""
|
||||||
)
|
)
|
||||||
|
|
@ -337,10 +340,10 @@ async def node_fulltext_search(
|
||||||
|
|
||||||
query = (
|
query = (
|
||||||
"""
|
"""
|
||||||
CALL db.index.fulltext.queryNodes("node_name_and_summary", $query, {limit: $limit})
|
CALL db.index.fulltext.queryNodes("node_name_and_summary", $query, {limit: $limit})
|
||||||
YIELD node AS n, score
|
YIELD node AS n, score
|
||||||
WHERE n:Entity
|
WHERE n:Entity
|
||||||
"""
|
"""
|
||||||
+ filter_query
|
+ filter_query
|
||||||
+ ENTITY_NODE_RETURN
|
+ ENTITY_NODE_RETURN
|
||||||
+ """
|
+ """
|
||||||
|
|
@ -771,7 +774,8 @@ async def get_relevant_edges(
|
||||||
episodes: e.episodes,
|
episodes: e.episodes,
|
||||||
expired_at: e.expired_at,
|
expired_at: e.expired_at,
|
||||||
valid_at: e.valid_at,
|
valid_at: e.valid_at,
|
||||||
invalid_at: e.invalid_at
|
invalid_at: e.invalid_at,
|
||||||
|
attributes: properties(e)
|
||||||
})[..$limit] AS matches
|
})[..$limit] AS matches
|
||||||
"""
|
"""
|
||||||
)
|
)
|
||||||
|
|
@ -837,7 +841,8 @@ async def get_edge_invalidation_candidates(
|
||||||
episodes: e.episodes,
|
episodes: e.episodes,
|
||||||
expired_at: e.expired_at,
|
expired_at: e.expired_at,
|
||||||
valid_at: e.valid_at,
|
valid_at: e.valid_at,
|
||||||
invalid_at: e.invalid_at
|
invalid_at: e.invalid_at,
|
||||||
|
attributes: properties(e)
|
||||||
})[..$limit] AS matches
|
})[..$limit] AS matches
|
||||||
"""
|
"""
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -137,16 +137,34 @@ async def add_nodes_and_edges_bulk_tx(
|
||||||
entity_data['labels'] = list(set(node.labels + ['Entity']))
|
entity_data['labels'] = list(set(node.labels + ['Entity']))
|
||||||
nodes.append(entity_data)
|
nodes.append(entity_data)
|
||||||
|
|
||||||
|
edges: list[dict[str, Any]] = []
|
||||||
for edge in entity_edges:
|
for edge in entity_edges:
|
||||||
if edge.fact_embedding is None:
|
if edge.fact_embedding is None:
|
||||||
await edge.generate_embedding(embedder)
|
await edge.generate_embedding(embedder)
|
||||||
|
edge_data: dict[str, Any] = {
|
||||||
|
'uuid': edge.uuid,
|
||||||
|
'source_node_uuid': edge.source_node_uuid,
|
||||||
|
'target_node_uuid': edge.target_node_uuid,
|
||||||
|
'name': edge.name,
|
||||||
|
'fact': edge.fact,
|
||||||
|
'fact_embedding': edge.fact_embedding,
|
||||||
|
'group_id': edge.group_id,
|
||||||
|
'episodes': edge.episodes,
|
||||||
|
'created_at': edge.created_at,
|
||||||
|
'expired_at': edge.expired_at,
|
||||||
|
'valid_at': edge.valid_at,
|
||||||
|
'invalid_at': edge.invalid_at,
|
||||||
|
}
|
||||||
|
|
||||||
|
edge_data.update(edge.attributes or {})
|
||||||
|
edges.append(edge_data)
|
||||||
|
|
||||||
await tx.run(EPISODIC_NODE_SAVE_BULK, episodes=episodes)
|
await tx.run(EPISODIC_NODE_SAVE_BULK, episodes=episodes)
|
||||||
await tx.run(ENTITY_NODE_SAVE_BULK, nodes=nodes)
|
await tx.run(ENTITY_NODE_SAVE_BULK, nodes=nodes)
|
||||||
await tx.run(
|
await tx.run(
|
||||||
EPISODIC_EDGE_SAVE_BULK, episodic_edges=[edge.model_dump() for edge in episodic_edges]
|
EPISODIC_EDGE_SAVE_BULK, episodic_edges=[edge.model_dump() for edge in episodic_edges]
|
||||||
)
|
)
|
||||||
await tx.run(ENTITY_EDGE_SAVE_BULK, entity_edges=[edge.model_dump() for edge in entity_edges])
|
await tx.run(ENTITY_EDGE_SAVE_BULK, entity_edges=edges)
|
||||||
|
|
||||||
|
|
||||||
async def extract_nodes_and_edges_bulk(
|
async def extract_nodes_and_edges_bulk(
|
||||||
|
|
|
||||||
|
|
@ -18,6 +18,8 @@ import logging
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from time import time
|
from time import time
|
||||||
|
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
from graphiti_core.edges import (
|
from graphiti_core.edges import (
|
||||||
CommunityEdge,
|
CommunityEdge,
|
||||||
EntityEdge,
|
EntityEdge,
|
||||||
|
|
@ -83,6 +85,7 @@ async def extract_edges(
|
||||||
nodes: list[EntityNode],
|
nodes: list[EntityNode],
|
||||||
previous_episodes: list[EpisodicNode],
|
previous_episodes: list[EpisodicNode],
|
||||||
group_id: str = '',
|
group_id: str = '',
|
||||||
|
edge_types: dict[str, BaseModel] | None = None,
|
||||||
) -> list[EntityEdge]:
|
) -> list[EntityEdge]:
|
||||||
start = time()
|
start = time()
|
||||||
|
|
||||||
|
|
@ -91,12 +94,25 @@ async def extract_edges(
|
||||||
|
|
||||||
node_uuids_by_name_map = {node.name: node.uuid for node in nodes}
|
node_uuids_by_name_map = {node.name: node.uuid for node in nodes}
|
||||||
|
|
||||||
|
edge_types_context = (
|
||||||
|
[
|
||||||
|
{
|
||||||
|
'fact_type_name': type_name,
|
||||||
|
'fact_type_description': type_model.__doc__,
|
||||||
|
}
|
||||||
|
for type_name, type_model in edge_types.items()
|
||||||
|
]
|
||||||
|
if edge_types is not None
|
||||||
|
else []
|
||||||
|
)
|
||||||
|
|
||||||
# Prepare context for LLM
|
# Prepare context for LLM
|
||||||
context = {
|
context = {
|
||||||
'episode_content': episode.content,
|
'episode_content': episode.content,
|
||||||
'nodes': [node.name for node in nodes],
|
'nodes': [node.name for node in nodes],
|
||||||
'previous_episodes': [ep.content for ep in previous_episodes],
|
'previous_episodes': [ep.content for ep in previous_episodes],
|
||||||
'reference_time': episode.valid_at,
|
'reference_time': episode.valid_at,
|
||||||
|
'edge_types': edge_types_context,
|
||||||
'custom_prompt': '',
|
'custom_prompt': '',
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -233,6 +249,9 @@ async def resolve_extracted_edges(
|
||||||
clients: GraphitiClients,
|
clients: GraphitiClients,
|
||||||
extracted_edges: list[EntityEdge],
|
extracted_edges: list[EntityEdge],
|
||||||
episode: EpisodicNode,
|
episode: EpisodicNode,
|
||||||
|
entities: list[EntityNode],
|
||||||
|
edge_types: dict[str, BaseModel],
|
||||||
|
edge_type_map: dict[tuple[str, str], list[str]],
|
||||||
) -> tuple[list[EntityEdge], list[EntityEdge]]:
|
) -> tuple[list[EntityEdge], list[EntityEdge]]:
|
||||||
driver = clients.driver
|
driver = clients.driver
|
||||||
llm_client = clients.llm_client
|
llm_client = clients.llm_client
|
||||||
|
|
@ -251,15 +270,50 @@ async def resolve_extracted_edges(
|
||||||
f'Related edges lists: {[(e.name, e.uuid) for edges_lst in related_edges_lists for e in edges_lst]}'
|
f'Related edges lists: {[(e.name, e.uuid) for edges_lst in related_edges_lists for e in edges_lst]}'
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Build entity hash table
|
||||||
|
uuid_entity_map: dict[str, EntityNode] = {entity.uuid: entity for entity in entities}
|
||||||
|
|
||||||
|
# Determine which edge types are relevant for each edge
|
||||||
|
edge_types_lst: list[dict[str, BaseModel]] = []
|
||||||
|
for extracted_edge in extracted_edges:
|
||||||
|
source_node_labels = uuid_entity_map[extracted_edge.source_node_uuid].labels
|
||||||
|
target_node_labels = uuid_entity_map[extracted_edge.target_node_uuid].labels
|
||||||
|
label_tuples = [
|
||||||
|
(source_label, target_label)
|
||||||
|
for source_label in source_node_labels
|
||||||
|
for target_label in target_node_labels
|
||||||
|
]
|
||||||
|
|
||||||
|
extracted_edge_types = {}
|
||||||
|
for label_tuple in label_tuples:
|
||||||
|
type_names = edge_type_map.get(label_tuple, [])
|
||||||
|
for type_name in type_names:
|
||||||
|
type_model = edge_types.get(type_name)
|
||||||
|
if type_model is None:
|
||||||
|
continue
|
||||||
|
|
||||||
|
extracted_edge_types[type_name] = type_model
|
||||||
|
|
||||||
|
edge_types_lst.append(extracted_edge_types)
|
||||||
|
|
||||||
# resolve edges with related edges in the graph and find invalidation candidates
|
# resolve edges with related edges in the graph and find invalidation candidates
|
||||||
results: list[tuple[EntityEdge, list[EntityEdge]]] = list(
|
results: list[tuple[EntityEdge, list[EntityEdge]]] = list(
|
||||||
await semaphore_gather(
|
await semaphore_gather(
|
||||||
*[
|
*[
|
||||||
resolve_extracted_edge(
|
resolve_extracted_edge(
|
||||||
llm_client, extracted_edge, related_edges, existing_edges, episode
|
llm_client,
|
||||||
|
extracted_edge,
|
||||||
|
related_edges,
|
||||||
|
existing_edges,
|
||||||
|
episode,
|
||||||
|
extracted_edge_types,
|
||||||
)
|
)
|
||||||
for extracted_edge, related_edges, existing_edges in zip(
|
for extracted_edge, related_edges, existing_edges, extracted_edge_types in zip(
|
||||||
extracted_edges, related_edges_lists, edge_invalidation_candidates, strict=True
|
extracted_edges,
|
||||||
|
related_edges_lists,
|
||||||
|
edge_invalidation_candidates,
|
||||||
|
edge_types_lst,
|
||||||
|
strict=True,
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
@ -322,7 +376,8 @@ async def resolve_extracted_edge(
|
||||||
extracted_edge: EntityEdge,
|
extracted_edge: EntityEdge,
|
||||||
related_edges: list[EntityEdge],
|
related_edges: list[EntityEdge],
|
||||||
existing_edges: list[EntityEdge],
|
existing_edges: list[EntityEdge],
|
||||||
episode: EpisodicNode | None = None,
|
episode: EpisodicNode,
|
||||||
|
edge_types: dict[str, BaseModel] | None = None,
|
||||||
) -> tuple[EntityEdge, list[EntityEdge]]:
|
) -> tuple[EntityEdge, list[EntityEdge]]:
|
||||||
if len(related_edges) == 0 and len(existing_edges) == 0:
|
if len(related_edges) == 0 and len(existing_edges) == 0:
|
||||||
return extracted_edge, []
|
return extracted_edge, []
|
||||||
|
|
@ -338,10 +393,24 @@ async def resolve_extracted_edge(
|
||||||
{'id': i, 'fact': existing_edge.fact} for i, existing_edge in enumerate(existing_edges)
|
{'id': i, 'fact': existing_edge.fact} for i, existing_edge in enumerate(existing_edges)
|
||||||
]
|
]
|
||||||
|
|
||||||
|
edge_types_context = (
|
||||||
|
[
|
||||||
|
{
|
||||||
|
'fact_type_id': i,
|
||||||
|
'fact_type_name': type_name,
|
||||||
|
'fact_type_description': type_model.__doc__,
|
||||||
|
}
|
||||||
|
for i, (type_name, type_model) in enumerate(edge_types.items())
|
||||||
|
]
|
||||||
|
if edge_types is not None
|
||||||
|
else []
|
||||||
|
)
|
||||||
|
|
||||||
context = {
|
context = {
|
||||||
'existing_edges': related_edges_context,
|
'existing_edges': related_edges_context,
|
||||||
'new_edge': extracted_edge.fact,
|
'new_edge': extracted_edge.fact,
|
||||||
'edge_invalidation_candidates': invalidation_edge_candidates_context,
|
'edge_invalidation_candidates': invalidation_edge_candidates_context,
|
||||||
|
'edge_types': edge_types_context,
|
||||||
}
|
}
|
||||||
|
|
||||||
llm_response = await llm_client.generate_response(
|
llm_response = await llm_client.generate_response(
|
||||||
|
|
@ -365,6 +434,26 @@ async def resolve_extracted_edge(
|
||||||
|
|
||||||
invalidation_candidates: list[EntityEdge] = [existing_edges[i] for i in contradicted_facts]
|
invalidation_candidates: list[EntityEdge] = [existing_edges[i] for i in contradicted_facts]
|
||||||
|
|
||||||
|
fact_type: str = str(llm_response.get('fact_type'))
|
||||||
|
if fact_type.upper() != 'DEFAULT' and edge_types is not None:
|
||||||
|
resolved_edge.name = fact_type
|
||||||
|
|
||||||
|
edge_attributes_context = {
|
||||||
|
'message': episode.content,
|
||||||
|
'reference_time': episode.valid_at,
|
||||||
|
'fact': resolved_edge.fact,
|
||||||
|
}
|
||||||
|
|
||||||
|
edge_model = edge_types.get(fact_type)
|
||||||
|
|
||||||
|
edge_attributes_response = await llm_client.generate_response(
|
||||||
|
prompt_library.extract_edges.extract_attributes(edge_attributes_context),
|
||||||
|
response_model=edge_model, # type: ignore
|
||||||
|
model_size=ModelSize.small,
|
||||||
|
)
|
||||||
|
|
||||||
|
resolved_edge.attributes = edge_attributes_response
|
||||||
|
|
||||||
end = time()
|
end = time()
|
||||||
logger.debug(
|
logger.debug(
|
||||||
f'Resolved Edge: {extracted_edge.name} is {resolved_edge.name}, in {(end - start) * 1000} ms'
|
f'Resolved Edge: {extracted_edge.name} is {resolved_edge.name}, in {(end - start) * 1000} ms'
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue