fix: Address graph disconnect (#7)
* fix: Address graph disconnect * chore: Remove valid_to and valid_from setting in extract edges step (will be handled during invalidation step)
This commit is contained in:
parent
4db3906049
commit
40e74a2e97
6 changed files with 72 additions and 25 deletions
|
|
@ -113,12 +113,16 @@ class Graphiti:
|
||||||
await asyncio.gather(
|
await asyncio.gather(
|
||||||
*[node.generate_name_embedding(embedder) for node in extracted_nodes]
|
*[node.generate_name_embedding(embedder) for node in extracted_nodes]
|
||||||
)
|
)
|
||||||
|
|
||||||
existing_nodes = await get_relevant_nodes(extracted_nodes, self.driver)
|
existing_nodes = await get_relevant_nodes(extracted_nodes, self.driver)
|
||||||
|
logger.info(
|
||||||
|
f"Extracted nodes: {[(n.name, n.uuid) for n in extracted_nodes]}"
|
||||||
|
)
|
||||||
new_nodes = await dedupe_extracted_nodes(
|
new_nodes = await dedupe_extracted_nodes(
|
||||||
self.llm_client, extracted_nodes, existing_nodes
|
self.llm_client, extracted_nodes, existing_nodes
|
||||||
)
|
)
|
||||||
|
logger.info(
|
||||||
|
f"Deduped touched nodes: {[(n.name, n.uuid) for n in new_nodes]}"
|
||||||
|
)
|
||||||
nodes.extend(new_nodes)
|
nodes.extend(new_nodes)
|
||||||
|
|
||||||
extracted_edges = await extract_edges(
|
extracted_edges = await extract_edges(
|
||||||
|
|
@ -130,11 +134,17 @@ class Graphiti:
|
||||||
)
|
)
|
||||||
|
|
||||||
existing_edges = await get_relevant_edges(extracted_edges, self.driver)
|
existing_edges = await get_relevant_edges(extracted_edges, self.driver)
|
||||||
|
logger.info(f"Existing edges: {[(e.name, e.uuid) for e in existing_edges]}")
|
||||||
|
logger.info(
|
||||||
|
f"Extracted edges: {[(e.name, e.uuid) for e in extracted_edges]}"
|
||||||
|
)
|
||||||
|
|
||||||
new_edges = await dedupe_extracted_edges(
|
new_edges = await dedupe_extracted_edges(
|
||||||
self.llm_client, extracted_edges, existing_edges
|
self.llm_client, extracted_edges, existing_edges
|
||||||
)
|
)
|
||||||
|
|
||||||
|
logger.info(f"Deduped edges: {[(e.name, e.uuid) for e in new_edges]}")
|
||||||
|
|
||||||
entity_edges.extend(new_edges)
|
entity_edges.extend(new_edges)
|
||||||
episodic_edges.extend(
|
episodic_edges.extend(
|
||||||
build_episodic_edges(
|
build_episodic_edges(
|
||||||
|
|
|
||||||
|
|
@ -135,17 +135,18 @@ def v2(context: dict[str, any]) -> list[Message]:
|
||||||
Message(
|
Message(
|
||||||
role="user",
|
role="user",
|
||||||
content=f"""
|
content=f"""
|
||||||
Given the following context, extract new edges (relationships) that need to be added to the knowledge graph:
|
Given the following context, extract edges (relationships) that need to be added to the knowledge graph:
|
||||||
Nodes:
|
Nodes:
|
||||||
{json.dumps(context['nodes'], indent=2)}
|
{json.dumps(context['nodes'], indent=2)}
|
||||||
|
|
||||||
New Episode:
|
|
||||||
Content: {context['episode_content']}
|
|
||||||
|
|
||||||
Previous Episodes:
|
|
||||||
|
Episodes:
|
||||||
{json.dumps([ep['content'] for ep in context['previous_episodes']], indent=2)}
|
{json.dumps([ep['content'] for ep in context['previous_episodes']], indent=2)}
|
||||||
|
{context['episode_content']} <-- New Episode
|
||||||
|
|
||||||
Extract new entity edges based on the content of the current episode, the given nodes, and context from previous episodes.
|
|
||||||
|
Extract entity edges based on the content of the current episode, the given nodes, and context from previous episodes.
|
||||||
|
|
||||||
Guidelines:
|
Guidelines:
|
||||||
1. Create edges only between the provided nodes.
|
1. Create edges only between the provided nodes.
|
||||||
|
|
@ -168,7 +169,7 @@ def v2(context: dict[str, any]) -> list[Message]:
|
||||||
]
|
]
|
||||||
}}
|
}}
|
||||||
|
|
||||||
If no new edges need to be added, return an empty list for "new_edges".
|
If no edges need to be added, return an empty list for "edges".
|
||||||
""",
|
""",
|
||||||
),
|
),
|
||||||
]
|
]
|
||||||
|
|
|
||||||
|
|
@ -7,11 +7,13 @@ from .models import Message, PromptVersion, PromptFunction
|
||||||
class Prompt(Protocol):
|
class Prompt(Protocol):
|
||||||
v1: PromptVersion
|
v1: PromptVersion
|
||||||
v2: PromptVersion
|
v2: PromptVersion
|
||||||
|
v3: PromptVersion
|
||||||
|
|
||||||
|
|
||||||
class Versions(TypedDict):
|
class Versions(TypedDict):
|
||||||
v1: PromptFunction
|
v1: PromptFunction
|
||||||
v2: PromptFunction
|
v2: PromptFunction
|
||||||
|
v3: PromptFunction
|
||||||
|
|
||||||
|
|
||||||
def v1(context: dict[str, any]) -> list[Message]:
|
def v1(context: dict[str, any]) -> list[Message]:
|
||||||
|
|
@ -103,4 +105,37 @@ def v2(context: dict[str, any]) -> list[Message]:
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
versions: Versions = {"v1": v1, "v2": v2}
|
def v3(context: dict[str, any]) -> list[Message]:
|
||||||
|
sys_prompt = """You are an AI assistant that extracts entity nodes from conversational text. Your primary task is to identify and extract the speaker and other significant entities mentioned in the conversation."""
|
||||||
|
|
||||||
|
user_prompt = f"""
|
||||||
|
Given the following conversation, extract entity nodes that are explicitly or implicitly mentioned:
|
||||||
|
|
||||||
|
Conversation:
|
||||||
|
{json.dumps([ep['content'] for ep in context['previous_episodes']], indent=2)}
|
||||||
|
{context["episode_content"]}
|
||||||
|
|
||||||
|
Guidelines:
|
||||||
|
1. ALWAYS extract the speaker/actor as the first node. The speaker is the part before the colon in each line of dialogue.
|
||||||
|
2. Extract other significant entities, concepts, or actors mentioned in the conversation.
|
||||||
|
3. Provide concise but informative summaries for each extracted node.
|
||||||
|
4. Avoid creating nodes for relationships or actions.
|
||||||
|
|
||||||
|
Respond with a JSON object in the following format:
|
||||||
|
{{
|
||||||
|
"new_nodes": [
|
||||||
|
{{
|
||||||
|
"name": "Unique identifier for the node (use the speaker's name for speaker nodes)",
|
||||||
|
"labels": ["Entity", "Speaker" for speaker nodes, "OptionalAdditionalLabel"],
|
||||||
|
"summary": "Brief summary of the node's role or significance"
|
||||||
|
}}
|
||||||
|
]
|
||||||
|
}}
|
||||||
|
"""
|
||||||
|
return [
|
||||||
|
Message(role="system", content=sys_prompt),
|
||||||
|
Message(role="user", content=user_prompt),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
versions: Versions = {"v1": v1, "v2": v2, "v3": v3}
|
||||||
|
|
|
||||||
|
|
@ -170,20 +170,21 @@ async def extract_edges(
|
||||||
# Convert the extracted data into EntityEdge objects
|
# Convert the extracted data into EntityEdge objects
|
||||||
edges = []
|
edges = []
|
||||||
for edge_data in edges_data:
|
for edge_data in edges_data:
|
||||||
edge = EntityEdge(
|
if edge_data["target_node_uuid"] and edge_data["source_node_uuid"]:
|
||||||
source_node_uuid=edge_data["source_node_uuid"],
|
edge = EntityEdge(
|
||||||
target_node_uuid=edge_data["target_node_uuid"],
|
source_node_uuid=edge_data["source_node_uuid"],
|
||||||
name=edge_data["relation_type"],
|
target_node_uuid=edge_data["target_node_uuid"],
|
||||||
fact=edge_data["fact"],
|
name=edge_data["relation_type"],
|
||||||
episodes=[episode.uuid],
|
fact=edge_data["fact"],
|
||||||
created_at=datetime.now(),
|
episodes=[episode.uuid],
|
||||||
valid_at=edge_data["valid_at"],
|
created_at=datetime.now(),
|
||||||
invalid_at=edge_data["invalid_at"],
|
valid_at=None,
|
||||||
)
|
invalid_at=None,
|
||||||
edges.append(edge)
|
)
|
||||||
logger.info(
|
edges.append(edge)
|
||||||
f"Created new edge: {edge.name} from (UUID: {edge.source_node_uuid}) to (UUID: {edge.target_node_uuid})"
|
logger.info(
|
||||||
)
|
f"Created new edge: {edge.name} from (UUID: {edge.source_node_uuid}) to (UUID: {edge.target_node_uuid})"
|
||||||
|
)
|
||||||
|
|
||||||
return edges
|
return edges
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -84,7 +84,7 @@ async def extract_nodes(
|
||||||
}
|
}
|
||||||
|
|
||||||
llm_response = await llm_client.generate_response(
|
llm_response = await llm_client.generate_response(
|
||||||
prompt_library.extract_nodes.v2(context)
|
prompt_library.extract_nodes.v3(context)
|
||||||
)
|
)
|
||||||
new_nodes_data = llm_response.get("new_nodes", [])
|
new_nodes_data = llm_response.get("new_nodes", [])
|
||||||
logger.info(f"Extracted new nodes: {new_nodes_data}")
|
logger.info(f"Extracted new nodes: {new_nodes_data}")
|
||||||
|
|
|
||||||
|
|
@ -49,7 +49,7 @@ async def main():
|
||||||
)
|
)
|
||||||
await client.add_episode(
|
await client.add_episode(
|
||||||
name="Message 2",
|
name="Message 2",
|
||||||
episode_body="Paul: I love bananas",
|
episode_body="Paul: I own many bananas",
|
||||||
source_description="WhatsApp Message",
|
source_description="WhatsApp Message",
|
||||||
)
|
)
|
||||||
await client.add_episode(
|
await client.add_episode(
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue