normalize string formatting in dedupe_nodes.py to use single quotes
This commit is contained in:
parent
23511f3b5e
commit
ad384372a7
1 changed files with 16 additions and 18 deletions
|
|
@ -23,25 +23,23 @@ from .prompt_helpers import to_prompt_json
|
||||||
|
|
||||||
|
|
||||||
class NodeDuplicate(BaseModel):
|
class NodeDuplicate(BaseModel):
|
||||||
id: int = Field(..., description="integer id of the entity")
|
id: int = Field(..., description='integer id of the entity')
|
||||||
duplicate_idx: int = Field(
|
duplicate_idx: int = Field(
|
||||||
...,
|
...,
|
||||||
description="idx of the duplicate entity. If no duplicate entities are found, default to -1.",
|
description='idx of the duplicate entity. If no duplicate entities are found, default to -1.',
|
||||||
)
|
)
|
||||||
name: str = Field(
|
name: str = Field(
|
||||||
...,
|
...,
|
||||||
description="Name of the entity. Should be the most complete and descriptive name of the entity. Do not include any JSON formatting in the Entity name such as {}.",
|
description='Name of the entity. Should be the most complete and descriptive name of the entity. Do not include any JSON formatting in the Entity name such as {}.',
|
||||||
)
|
)
|
||||||
duplicates: list[int] = Field(
|
duplicates: list[int] = Field(
|
||||||
...,
|
...,
|
||||||
description="idx of all entities that are a duplicate of the entity with the above id.",
|
description='idx of all entities that are a duplicate of the entity with the above id.',
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class NodeResolutions(BaseModel):
|
class NodeResolutions(BaseModel):
|
||||||
entity_resolutions: list[NodeDuplicate] = Field(
|
entity_resolutions: list[NodeDuplicate] = Field(..., description='List of resolved nodes')
|
||||||
..., description="List of resolved nodes"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class Prompt(Protocol):
|
class Prompt(Protocol):
|
||||||
|
|
@ -59,11 +57,11 @@ class Versions(TypedDict):
|
||||||
def node(context: dict[str, Any]) -> list[Message]:
|
def node(context: dict[str, Any]) -> list[Message]:
|
||||||
return [
|
return [
|
||||||
Message(
|
Message(
|
||||||
role="system",
|
role='system',
|
||||||
content="You are a helpful assistant that determines whether or not a NEW ENTITY is a duplicate of any EXISTING ENTITIES.",
|
content='You are a helpful assistant that determines whether or not a NEW ENTITY is a duplicate of any EXISTING ENTITIES.',
|
||||||
),
|
),
|
||||||
Message(
|
Message(
|
||||||
role="user",
|
role='user',
|
||||||
content=f"""
|
content=f"""
|
||||||
<PREVIOUS MESSAGES>
|
<PREVIOUS MESSAGES>
|
||||||
{to_prompt_json([ep for ep in context['previous_episodes']], ensure_ascii=context.get('ensure_ascii', False), indent=2)}
|
{to_prompt_json([ep for ep in context['previous_episodes']], ensure_ascii=context.get('ensure_ascii', False), indent=2)}
|
||||||
|
|
@ -119,12 +117,12 @@ def node(context: dict[str, Any]) -> list[Message]:
|
||||||
def nodes(context: dict[str, Any]) -> list[Message]:
|
def nodes(context: dict[str, Any]) -> list[Message]:
|
||||||
return [
|
return [
|
||||||
Message(
|
Message(
|
||||||
role="system",
|
role='system',
|
||||||
content="You are a helpful assistant that determines whether or not ENTITIES extracted from a conversation are duplicates"
|
content='You are a helpful assistant that determines whether or not ENTITIES extracted from a conversation are duplicates'
|
||||||
" of existing entities.",
|
' of existing entities.',
|
||||||
),
|
),
|
||||||
Message(
|
Message(
|
||||||
role="user",
|
role='user',
|
||||||
content=f"""
|
content=f"""
|
||||||
<PREVIOUS MESSAGES>
|
<PREVIOUS MESSAGES>
|
||||||
{to_prompt_json([ep for ep in context['previous_episodes']], ensure_ascii=context.get('ensure_ascii', True), indent=2)}
|
{to_prompt_json([ep for ep in context['previous_episodes']], ensure_ascii=context.get('ensure_ascii', True), indent=2)}
|
||||||
|
|
@ -189,11 +187,11 @@ def nodes(context: dict[str, Any]) -> list[Message]:
|
||||||
def node_list(context: dict[str, Any]) -> list[Message]:
|
def node_list(context: dict[str, Any]) -> list[Message]:
|
||||||
return [
|
return [
|
||||||
Message(
|
Message(
|
||||||
role="system",
|
role='system',
|
||||||
content="You are a helpful assistant that de-duplicates nodes from node lists.",
|
content='You are a helpful assistant that de-duplicates nodes from node lists.',
|
||||||
),
|
),
|
||||||
Message(
|
Message(
|
||||||
role="user",
|
role='user',
|
||||||
content=f"""
|
content=f"""
|
||||||
Given the following context, deduplicate a list of nodes:
|
Given the following context, deduplicate a list of nodes:
|
||||||
|
|
||||||
|
|
@ -223,4 +221,4 @@ def node_list(context: dict[str, Any]) -> list[Message]:
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
versions: Versions = {"node": node, "node_list": node_list, "nodes": nodes}
|
versions: Versions = {'node': node, 'node_list': node_list, 'nodes': nodes}
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue