Add extraction of entity nodes from text prompts (#106)

This commit is contained in:
Preston Rasmussen 2024-09-11 12:06:08 -04:00 committed by GitHub
parent b214baa85f
commit 4122d350a5
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 69 additions and 2 deletions

View file

@ -24,12 +24,14 @@ class Prompt(Protocol):
v1: PromptVersion
v2: PromptVersion
extract_json: PromptVersion
extract_text: PromptVersion
class Versions(TypedDict):
    """Maps each prompt-version key to the function that renders that prompt.

    Keys mirror the ``Prompt`` protocol above; ``extract_text`` was added for
    extracting entity nodes from conversational text episodes.
    """

    v1: PromptFunction
    v2: PromptFunction
    extract_json: PromptFunction
    extract_text: PromptFunction
def v1(context: dict[str, Any]) -> list[Message]:
@ -144,4 +146,44 @@ Respond with a JSON object in the following format:
]
versions: Versions = {'v1': v1, 'v2': v2, 'extract_json': extract_json}
def extract_text(context: dict[str, Any]) -> list[Message]:
    """Build the system/user message pair asking the LLM to extract entity
    nodes from a conversational text episode.

    Args:
        context: Must contain ``episode_content`` (the current message text)
            and ``previous_episodes`` (a list of dicts, each with a
            ``content`` key, giving prior conversation turns).

    Returns:
        A two-element list: the system prompt followed by the user prompt.
    """
    sys_prompt = """You are an AI assistant that extracts entity nodes from conversational text. Your primary task is to identify and extract the speaker and other significant entities mentioned in the conversation."""

    # NOTE(review): the guideline list previously started at "2." with no
    # first item, even though both the system prompt and the "name" field
    # below instruct the model to extract the speaker. Restored guideline 1
    # to cover speaker extraction so the numbering and instructions agree.
    user_prompt = f"""
Given the following conversation, extract entity nodes from the CURRENT MESSAGE that are explicitly or implicitly mentioned:

Conversation:
{json.dumps([ep['content'] for ep in context['previous_episodes']], indent=2)}

<CURRENT MESSAGE>
{context["episode_content"]}

Guidelines:
1. ALWAYS extract the speaker/actor of the CURRENT MESSAGE as the first node.
2. Extract significant entities, concepts, or actors mentioned in the conversation.
3. Provide concise but informative summaries for each extracted node.
4. Avoid creating nodes for relationships or actions.
5. Avoid creating nodes for temporal information like dates, times or years (these will be added to edges later).
6. Be as explicit as possible in your node names, using full names and avoiding abbreviations.

Respond with a JSON object in the following format:
{{
    "extracted_nodes": [
        {{
            "name": "Unique identifier for the node (use the speaker's name for speaker nodes)",
            "labels": ["Entity", "OptionalAdditionalLabel"],
            "summary": "Brief summary of the node's role or significance"
        }}
    ]
}}
"""
    return [
        Message(role='system', content=sys_prompt),
        Message(role='user', content=user_prompt),
    ]
# Registry of every prompt version exposed by this module, keyed by name.
versions: Versions = dict(
    v1=v1,
    v2=v2,
    extract_json=extract_json,
    extract_text=extract_text,
)

View file

@ -48,6 +48,29 @@ async def extract_message_nodes(
return extracted_node_data
async def extract_text_nodes(
    llm_client: LLMClient, episode: EpisodicNode, previous_episodes: list[EpisodicNode]
) -> list[dict[str, Any]]:
    """Extract entity nodes from a text episode via the LLM.

    Builds the prompt context from the current episode and its preceding
    episodes, runs it through the ``extract_text`` prompt, and returns the
    raw node dicts from the response (an empty list when none were found).
    """

    def _episode_view(ep: EpisodicNode) -> dict[str, str]:
        # Compact representation of a prior episode for the prompt context.
        return {
            'content': ep.content,
            'timestamp': ep.valid_at.isoformat(),
        }

    prompt_context = {
        'episode_content': episode.content,
        'episode_timestamp': episode.valid_at.isoformat(),
        'previous_episodes': [_episode_view(ep) for ep in previous_episodes],
    }

    llm_response = await llm_client.generate_response(
        prompt_library.extract_nodes.extract_text(prompt_context)
    )
    return llm_response.get('extracted_nodes', [])
async def extract_json_nodes(
llm_client: LLMClient,
episode: EpisodicNode,
@ -73,8 +96,10 @@ async def extract_nodes(
) -> list[EntityNode]:
start = time()
extracted_node_data: list[dict[str, Any]] = []
if episode.source in [EpisodeType.message, EpisodeType.text]:
if episode.source == EpisodeType.message:
extracted_node_data = await extract_message_nodes(llm_client, episode, previous_episodes)
elif episode.source == EpisodeType.text:
extracted_node_data = await extract_text_nodes(llm_client, episode, previous_episodes)
elif episode.source == EpisodeType.json:
extracted_node_data = await extract_json_nodes(llm_client, episode)