normalize string formatting in dedupe_nodes.py to use single quotes

This commit is contained in:
Daniel Chalef 2025-09-27 14:01:47 -07:00
parent 23511f3b5e
commit ad384372a7

View file

@@ -23,25 +23,23 @@ from .prompt_helpers import to_prompt_json
class NodeDuplicate(BaseModel):
    # Duplicate-resolution record for one entity. All four fields are
    # required: `...` is passed as the default of every Field.
    #
    # Documented with comments rather than a class docstring on purpose:
    # pydantic copies a model's docstring into its generated JSON schema,
    # so adding one would alter the schema text.

    # Integer id of the entity this record is about.
    id: int = Field(..., description='integer id of the entity')

    # Index of the duplicate entity; the description tells the producer to
    # use -1 when no duplicate is found.
    duplicate_idx: int = Field(
        ...,
        description='idx of the duplicate entity. If no duplicate entities are found, default to -1.',
    )

    # Most complete, descriptive name for the entity (no JSON formatting).
    name: str = Field(
        ...,
        description='Name of the entity. Should be the most complete and descriptive name of the entity. Do not include any JSON formatting in the Entity name such as {}.',
    )

    # Indices of every entity that duplicates the entity identified by `id`.
    duplicates: list[int] = Field(
        ...,
        description='idx of all entities that are a duplicate of the entity with the above id.',
    )
class NodeResolutions(BaseModel):
    # Wrapper model holding a list of NodeDuplicate records. The single
    # field is required (`...` default). A comment is used instead of a
    # docstring so the generated JSON schema text is unchanged.
    entity_resolutions: list[NodeDuplicate] = Field(
        ...,
        description='List of resolved nodes',
    )
class Prompt(Protocol): class Prompt(Protocol):
@@ -59,11 +57,11 @@ class Versions(TypedDict):
def node(context: dict[str, Any]) -> list[Message]: def node(context: dict[str, Any]) -> list[Message]:
return [ return [
Message( Message(
role="system", role='system',
content="You are a helpful assistant that determines whether or not a NEW ENTITY is a duplicate of any EXISTING ENTITIES.", content='You are a helpful assistant that determines whether or not a NEW ENTITY is a duplicate of any EXISTING ENTITIES.',
), ),
Message( Message(
role="user", role='user',
content=f""" content=f"""
<PREVIOUS MESSAGES> <PREVIOUS MESSAGES>
{to_prompt_json([ep for ep in context['previous_episodes']], ensure_ascii=context.get('ensure_ascii', False), indent=2)} {to_prompt_json([ep for ep in context['previous_episodes']], ensure_ascii=context.get('ensure_ascii', False), indent=2)}
@@ -119,12 +117,12 @@ def node(context: dict[str, Any]) -> list[Message]:
def nodes(context: dict[str, Any]) -> list[Message]: def nodes(context: dict[str, Any]) -> list[Message]:
return [ return [
Message( Message(
role="system", role='system',
content="You are a helpful assistant that determines whether or not ENTITIES extracted from a conversation are duplicates" content='You are a helpful assistant that determines whether or not ENTITIES extracted from a conversation are duplicates'
" of existing entities.", ' of existing entities.',
), ),
Message( Message(
role="user", role='user',
content=f""" content=f"""
<PREVIOUS MESSAGES> <PREVIOUS MESSAGES>
{to_prompt_json([ep for ep in context['previous_episodes']], ensure_ascii=context.get('ensure_ascii', True), indent=2)} {to_prompt_json([ep for ep in context['previous_episodes']], ensure_ascii=context.get('ensure_ascii', True), indent=2)}
@@ -189,11 +187,11 @@ def nodes(context: dict[str, Any]) -> list[Message]:
def node_list(context: dict[str, Any]) -> list[Message]: def node_list(context: dict[str, Any]) -> list[Message]:
return [ return [
Message( Message(
role="system", role='system',
content="You are a helpful assistant that de-duplicates nodes from node lists.", content='You are a helpful assistant that de-duplicates nodes from node lists.',
), ),
Message( Message(
role="user", role='user',
content=f""" content=f"""
Given the following context, deduplicate a list of nodes: Given the following context, deduplicate a list of nodes:
@@ -223,4 +221,4 @@ def node_list(context: dict[str, Any]) -> list[Message]:
] ]
# Registry mapping each prompt name to the function that builds its
# message list.
versions: Versions = {
    'node': node,
    'node_list': node_list,
    'nodes': nodes,
}