add extract nodes from text prompt (#106)
This commit is contained in:
parent
b214baa85f
commit
4122d350a5
2 changed files with 69 additions and 2 deletions
|
|
@ -24,12 +24,14 @@ class Prompt(Protocol):
|
||||||
v1: PromptVersion
|
v1: PromptVersion
|
||||||
v2: PromptVersion
|
v2: PromptVersion
|
||||||
extract_json: PromptVersion
|
extract_json: PromptVersion
|
||||||
|
extract_text: PromptVersion
|
||||||
|
|
||||||
|
|
||||||
class Versions(TypedDict):
|
class Versions(TypedDict):
|
||||||
v1: PromptFunction
|
v1: PromptFunction
|
||||||
v2: PromptFunction
|
v2: PromptFunction
|
||||||
extract_json: PromptFunction
|
extract_json: PromptFunction
|
||||||
|
extract_text: PromptFunction
|
||||||
|
|
||||||
|
|
||||||
def v1(context: dict[str, Any]) -> list[Message]:
|
def v1(context: dict[str, Any]) -> list[Message]:
|
||||||
|
|
@ -144,4 +146,44 @@ Respond with a JSON object in the following format:
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
versions: Versions = {'v1': v1, 'v2': v2, 'extract_json': extract_json}
|
def extract_text(context: dict[str, Any]) -> list[Message]:
    """Build the chat messages asking the LLM to extract entity nodes from an episode.

    Args:
        context: Prompt inputs. This function reads
            ``context['previous_episodes']`` (an iterable of mappings that
            each expose a ``'content'`` key) and ``context['episode_content']``.

    Returns:
        A two-element list: the system message followed by the user message.

    Raises:
        KeyError: If ``context`` lacks ``'previous_episodes'`` or
            ``'episode_content'``, or a previous episode lacks ``'content'``.
    """
    sys_prompt = """You are an AI assistant that extracts entity nodes from conversational text. Your primary task is to identify and extract the speaker and other significant entities mentioned in the conversation."""

    # Doubled braces ({{ and }}) are literal braces in the f-string; they keep
    # the JSON response template from being parsed as replacement fields.
    # The guidelines are numbered from 1 (the list previously began at 2).
    user_prompt = f"""
Given the following conversation, extract entity nodes from the CURRENT MESSAGE that are explicitly or implicitly mentioned:

Conversation:
{json.dumps([ep['content'] for ep in context['previous_episodes']], indent=2)}
<CURRENT MESSAGE>
{context["episode_content"]}

Guidelines:
1. Extract significant entities, concepts, or actors mentioned in the conversation.
2. Provide concise but informative summaries for each extracted node.
3. Avoid creating nodes for relationships or actions.
4. Avoid creating nodes for temporal information like dates, times or years (these will be added to edges later).
5. Be as explicit as possible in your node names, using full names and avoiding abbreviations.

Respond with a JSON object in the following format:
{{
    "extracted_nodes": [
        {{
            "name": "Unique identifier for the node (use the speaker's name for speaker nodes)",
            "labels": ["Entity", "OptionalAdditionalLabel"],
            "summary": "Brief summary of the node's role or significance"
        }}
    ]
}}
"""
    return [
        Message(role='system', content=sys_prompt),
        Message(role='user', content=user_prompt),
    ]
|
||||||
|
|
||||||
|
|
||||||
|
# Lookup table from prompt-version name to the function that builds its messages.
versions: Versions = {
    'v1': v1,
    'v2': v2,
    'extract_json': extract_json,
    'extract_text': extract_text,
}
|
||||||
|
|
|
||||||
|
|
@ -48,6 +48,29 @@ async def extract_message_nodes(
|
||||||
return extracted_node_data
|
return extracted_node_data
|
||||||
|
|
||||||
|
|
||||||
|
async def extract_text_nodes(
    llm_client: LLMClient, episode: EpisodicNode, previous_episodes: list[EpisodicNode]
) -> list[dict[str, Any]]:
    """Request entity-node data for an episode using the ``extract_text`` prompt.

    Args:
        llm_client: Client whose ``generate_response`` coroutine is awaited
            with the rendered prompt messages.
        episode: Episode to extract from; its ``content`` and
            ``valid_at.isoformat()`` are placed in the prompt context.
        previous_episodes: Earlier episodes; each contributes its ``content``
            and ``valid_at.isoformat()`` to the prompt context.

    Returns:
        The value stored under ``'extracted_nodes'`` in the LLM response, or
        an empty list when that key is absent.
    """
    # Assemble the prompt context one key at a time.
    prompt_context: dict[str, Any] = {}
    prompt_context['episode_content'] = episode.content
    prompt_context['episode_timestamp'] = episode.valid_at.isoformat()
    prompt_context['previous_episodes'] = [
        {'content': prior.content, 'timestamp': prior.valid_at.isoformat()}
        for prior in previous_episodes
    ]

    messages = prompt_library.extract_nodes.extract_text(prompt_context)
    response = await llm_client.generate_response(messages)
    return response.get('extracted_nodes', [])
|
||||||
|
|
||||||
|
|
||||||
async def extract_json_nodes(
|
async def extract_json_nodes(
|
||||||
llm_client: LLMClient,
|
llm_client: LLMClient,
|
||||||
episode: EpisodicNode,
|
episode: EpisodicNode,
|
||||||
|
|
@ -73,8 +96,10 @@ async def extract_nodes(
|
||||||
) -> list[EntityNode]:
|
) -> list[EntityNode]:
|
||||||
start = time()
|
start = time()
|
||||||
extracted_node_data: list[dict[str, Any]] = []
|
extracted_node_data: list[dict[str, Any]] = []
|
||||||
if episode.source in [EpisodeType.message, EpisodeType.text]:
|
if episode.source == EpisodeType.message:
|
||||||
extracted_node_data = await extract_message_nodes(llm_client, episode, previous_episodes)
|
extracted_node_data = await extract_message_nodes(llm_client, episode, previous_episodes)
|
||||||
|
elif episode.source == EpisodeType.text:
|
||||||
|
extracted_node_data = await extract_text_nodes(llm_client, episode, previous_episodes)
|
||||||
elif episode.source == EpisodeType.json:
|
elif episode.source == EpisodeType.json:
|
||||||
extracted_node_data = await extract_json_nodes(llm_client, episode)
|
extracted_node_data = await extract_json_nodes(llm_client, episode)
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue