Merge branch 'getzep:main' into main
This commit is contained in:
commit
4f8eb310f2
20 changed files with 176 additions and 69 deletions
|
|
@ -123,6 +123,7 @@ class Graphiti:
|
||||||
store_raw_episode_content: bool = True,
|
store_raw_episode_content: bool = True,
|
||||||
graph_driver: GraphDriver | None = None,
|
graph_driver: GraphDriver | None = None,
|
||||||
max_coroutines: int | None = None,
|
max_coroutines: int | None = None,
|
||||||
|
ensure_ascii: bool = False,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Initialize a Graphiti instance.
|
Initialize a Graphiti instance.
|
||||||
|
|
@ -155,6 +156,10 @@ class Graphiti:
|
||||||
max_coroutines : int | None, optional
|
max_coroutines : int | None, optional
|
||||||
The maximum number of concurrent operations allowed. Overrides SEMAPHORE_LIMIT set in the environment.
|
The maximum number of concurrent operations allowed. Overrides SEMAPHORE_LIMIT set in the environment.
|
||||||
If not set, the Graphiti default is used.
|
If not set, the Graphiti default is used.
|
||||||
|
ensure_ascii : bool, optional
|
||||||
|
Whether to escape non-ASCII characters in JSON serialization for prompts. Defaults to False.
|
||||||
|
Set to False to preserve non-ASCII characters (e.g., Korean, Japanese, Chinese) in their
|
||||||
|
original form, making them readable in LLM logs and improving model understanding.
|
||||||
|
|
||||||
Returns
|
Returns
|
||||||
-------
|
-------
|
||||||
|
|
@ -184,6 +189,7 @@ class Graphiti:
|
||||||
|
|
||||||
self.store_raw_episode_content = store_raw_episode_content
|
self.store_raw_episode_content = store_raw_episode_content
|
||||||
self.max_coroutines = max_coroutines
|
self.max_coroutines = max_coroutines
|
||||||
|
self.ensure_ascii = ensure_ascii
|
||||||
if llm_client:
|
if llm_client:
|
||||||
self.llm_client = llm_client
|
self.llm_client = llm_client
|
||||||
else:
|
else:
|
||||||
|
|
@ -202,6 +208,7 @@ class Graphiti:
|
||||||
llm_client=self.llm_client,
|
llm_client=self.llm_client,
|
||||||
embedder=self.embedder,
|
embedder=self.embedder,
|
||||||
cross_encoder=self.cross_encoder,
|
cross_encoder=self.cross_encoder,
|
||||||
|
ensure_ascii=self.ensure_ascii,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Capture telemetry event
|
# Capture telemetry event
|
||||||
|
|
@ -541,7 +548,9 @@ class Graphiti:
|
||||||
if update_communities:
|
if update_communities:
|
||||||
communities, community_edges = await semaphore_gather(
|
communities, community_edges = await semaphore_gather(
|
||||||
*[
|
*[
|
||||||
update_community(self.driver, self.llm_client, self.embedder, node)
|
update_community(
|
||||||
|
self.driver, self.llm_client, self.embedder, node, self.ensure_ascii
|
||||||
|
)
|
||||||
for node in nodes
|
for node in nodes
|
||||||
],
|
],
|
||||||
max_coroutines=self.max_coroutines,
|
max_coroutines=self.max_coroutines,
|
||||||
|
|
@ -1021,6 +1030,8 @@ class Graphiti:
|
||||||
entity_edges=[],
|
entity_edges=[],
|
||||||
group_id=edge.group_id,
|
group_id=edge.group_id,
|
||||||
),
|
),
|
||||||
|
None,
|
||||||
|
self.ensure_ascii,
|
||||||
)
|
)
|
||||||
|
|
||||||
edges: list[EntityEdge] = [resolved_edge] + invalidated_edges
|
edges: list[EntityEdge] = [resolved_edge] + invalidated_edges
|
||||||
|
|
|
||||||
|
|
@ -27,5 +27,6 @@ class GraphitiClients(BaseModel):
|
||||||
llm_client: LLMClient
|
llm_client: LLMClient
|
||||||
embedder: EmbedderClient
|
embedder: EmbedderClient
|
||||||
cross_encoder: CrossEncoderClient
|
cross_encoder: CrossEncoderClient
|
||||||
|
ensure_ascii: bool = False
|
||||||
|
|
||||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||||
|
|
|
||||||
|
|
@ -118,7 +118,7 @@ class Node(BaseModel, ABC):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def delete_by_group_id(cls, driver: GraphDriver, group_id: str):
|
async def delete_by_group_id(cls, driver: GraphDriver, group_id: str, batch_size: int = 100):
|
||||||
if driver.provider == GraphProvider.FALKORDB:
|
if driver.provider == GraphProvider.FALKORDB:
|
||||||
for label in ['Entity', 'Episodic', 'Community']:
|
for label in ['Entity', 'Episodic', 'Community']:
|
||||||
await driver.execute_query(
|
await driver.execute_query(
|
||||||
|
|
@ -129,13 +129,18 @@ class Node(BaseModel, ABC):
|
||||||
group_id=group_id,
|
group_id=group_id,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
await driver.execute_query(
|
async with driver.session() as session:
|
||||||
"""
|
await session.run(
|
||||||
MATCH (n:Entity|Episodic|Community {group_id: $group_id})
|
"""
|
||||||
DETACH DELETE n
|
MATCH (n:Entity|Episodic|Community {group_id: $group_id})
|
||||||
""",
|
CALL {
|
||||||
group_id=group_id,
|
WITH n
|
||||||
)
|
DETACH DELETE n
|
||||||
|
} IN TRANSACTIONS OF $batch_size ROWS
|
||||||
|
""",
|
||||||
|
group_id=group_id,
|
||||||
|
batch_size=batch_size,
|
||||||
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def get_by_uuid(cls, driver: GraphDriver, uuid: str): ...
|
async def get_by_uuid(cls, driver: GraphDriver, uuid: str): ...
|
||||||
|
|
|
||||||
|
|
@ -14,12 +14,12 @@ See the License for the specific language governing permissions and
|
||||||
limitations under the License.
|
limitations under the License.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import json
|
|
||||||
from typing import Any, Protocol, TypedDict
|
from typing import Any, Protocol, TypedDict
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
from .models import Message, PromptFunction, PromptVersion
|
from .models import Message, PromptFunction, PromptVersion
|
||||||
|
from .prompt_helpers import to_prompt_json
|
||||||
|
|
||||||
|
|
||||||
class EdgeDuplicate(BaseModel):
|
class EdgeDuplicate(BaseModel):
|
||||||
|
|
@ -67,11 +67,11 @@ def edge(context: dict[str, Any]) -> list[Message]:
|
||||||
Given the following context, determine whether the New Edge represents any of the edges in the list of Existing Edges.
|
Given the following context, determine whether the New Edge represents any of the edges in the list of Existing Edges.
|
||||||
|
|
||||||
<EXISTING EDGES>
|
<EXISTING EDGES>
|
||||||
{json.dumps(context['related_edges'], indent=2)}
|
{to_prompt_json(context['related_edges'], ensure_ascii=context.get('ensure_ascii', False), indent=2)}
|
||||||
</EXISTING EDGES>
|
</EXISTING EDGES>
|
||||||
|
|
||||||
<NEW EDGE>
|
<NEW EDGE>
|
||||||
{json.dumps(context['extracted_edges'], indent=2)}
|
{to_prompt_json(context['extracted_edges'], ensure_ascii=context.get('ensure_ascii', False), indent=2)}
|
||||||
</NEW EDGE>
|
</NEW EDGE>
|
||||||
|
|
||||||
Task:
|
Task:
|
||||||
|
|
@ -98,7 +98,7 @@ def edge_list(context: dict[str, Any]) -> list[Message]:
|
||||||
Given the following context, find all of the duplicates in a list of facts:
|
Given the following context, find all of the duplicates in a list of facts:
|
||||||
|
|
||||||
Facts:
|
Facts:
|
||||||
{json.dumps(context['edges'], indent=2)}
|
{to_prompt_json(context['edges'], ensure_ascii=context.get('ensure_ascii', False), indent=2)}
|
||||||
|
|
||||||
Task:
|
Task:
|
||||||
If any facts in Facts is a duplicate of another fact, return a new fact with one of their uuid's.
|
If any facts in Facts is a duplicate of another fact, return a new fact with one of their uuid's.
|
||||||
|
|
|
||||||
|
|
@ -14,12 +14,12 @@ See the License for the specific language governing permissions and
|
||||||
limitations under the License.
|
limitations under the License.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import json
|
|
||||||
from typing import Any, Protocol, TypedDict
|
from typing import Any, Protocol, TypedDict
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
from .models import Message, PromptFunction, PromptVersion
|
from .models import Message, PromptFunction, PromptVersion
|
||||||
|
from .prompt_helpers import to_prompt_json
|
||||||
|
|
||||||
|
|
||||||
class NodeDuplicate(BaseModel):
|
class NodeDuplicate(BaseModel):
|
||||||
|
|
@ -64,20 +64,20 @@ def node(context: dict[str, Any]) -> list[Message]:
|
||||||
role='user',
|
role='user',
|
||||||
content=f"""
|
content=f"""
|
||||||
<PREVIOUS MESSAGES>
|
<PREVIOUS MESSAGES>
|
||||||
{json.dumps([ep for ep in context['previous_episodes']], indent=2)}
|
{to_prompt_json([ep for ep in context['previous_episodes']], ensure_ascii=context.get('ensure_ascii', False), indent=2)}
|
||||||
</PREVIOUS MESSAGES>
|
</PREVIOUS MESSAGES>
|
||||||
<CURRENT MESSAGE>
|
<CURRENT MESSAGE>
|
||||||
{context['episode_content']}
|
{context['episode_content']}
|
||||||
</CURRENT MESSAGE>
|
</CURRENT MESSAGE>
|
||||||
<NEW ENTITY>
|
<NEW ENTITY>
|
||||||
{json.dumps(context['extracted_node'], indent=2)}
|
{to_prompt_json(context['extracted_node'], ensure_ascii=context.get('ensure_ascii', False), indent=2)}
|
||||||
</NEW ENTITY>
|
</NEW ENTITY>
|
||||||
<ENTITY TYPE DESCRIPTION>
|
<ENTITY TYPE DESCRIPTION>
|
||||||
{json.dumps(context['entity_type_description'], indent=2)}
|
{to_prompt_json(context['entity_type_description'], ensure_ascii=context.get('ensure_ascii', False), indent=2)}
|
||||||
</ENTITY TYPE DESCRIPTION>
|
</ENTITY TYPE DESCRIPTION>
|
||||||
|
|
||||||
<EXISTING ENTITIES>
|
<EXISTING ENTITIES>
|
||||||
{json.dumps(context['existing_nodes'], indent=2)}
|
{to_prompt_json(context['existing_nodes'], ensure_ascii=context.get('ensure_ascii', False), indent=2)}
|
||||||
</EXISTING ENTITIES>
|
</EXISTING ENTITIES>
|
||||||
|
|
||||||
Given the above EXISTING ENTITIES and their attributes, MESSAGE, and PREVIOUS MESSAGES; Determine if the NEW ENTITY extracted from the conversation
|
Given the above EXISTING ENTITIES and their attributes, MESSAGE, and PREVIOUS MESSAGES; Determine if the NEW ENTITY extracted from the conversation
|
||||||
|
|
@ -114,7 +114,7 @@ def nodes(context: dict[str, Any]) -> list[Message]:
|
||||||
role='user',
|
role='user',
|
||||||
content=f"""
|
content=f"""
|
||||||
<PREVIOUS MESSAGES>
|
<PREVIOUS MESSAGES>
|
||||||
{json.dumps([ep for ep in context['previous_episodes']], indent=2)}
|
{to_prompt_json([ep for ep in context['previous_episodes']], ensure_ascii=context.get('ensure_ascii', True), indent=2)}
|
||||||
</PREVIOUS MESSAGES>
|
</PREVIOUS MESSAGES>
|
||||||
<CURRENT MESSAGE>
|
<CURRENT MESSAGE>
|
||||||
{context['episode_content']}
|
{context['episode_content']}
|
||||||
|
|
@ -139,11 +139,11 @@ def nodes(context: dict[str, Any]) -> list[Message]:
|
||||||
}}
|
}}
|
||||||
|
|
||||||
<ENTITIES>
|
<ENTITIES>
|
||||||
{json.dumps(context['extracted_nodes'], indent=2)}
|
{to_prompt_json(context['extracted_nodes'], ensure_ascii=context.get('ensure_ascii', True), indent=2)}
|
||||||
</ENTITIES>
|
</ENTITIES>
|
||||||
|
|
||||||
<EXISTING ENTITIES>
|
<EXISTING ENTITIES>
|
||||||
{json.dumps(context['existing_nodes'], indent=2)}
|
{to_prompt_json(context['existing_nodes'], ensure_ascii=context.get('ensure_ascii', True), indent=2)}
|
||||||
</EXISTING ENTITIES>
|
</EXISTING ENTITIES>
|
||||||
|
|
||||||
For each of the above ENTITIES, determine if the entity is a duplicate of any of the EXISTING ENTITIES.
|
For each of the above ENTITIES, determine if the entity is a duplicate of any of the EXISTING ENTITIES.
|
||||||
|
|
@ -180,7 +180,7 @@ def node_list(context: dict[str, Any]) -> list[Message]:
|
||||||
Given the following context, deduplicate a list of nodes:
|
Given the following context, deduplicate a list of nodes:
|
||||||
|
|
||||||
Nodes:
|
Nodes:
|
||||||
{json.dumps(context['nodes'], indent=2)}
|
{to_prompt_json(context['nodes'], ensure_ascii=context.get('ensure_ascii', True), indent=2)}
|
||||||
|
|
||||||
Task:
|
Task:
|
||||||
1. Group nodes together such that all duplicate nodes are in the same list of uuids
|
1. Group nodes together such that all duplicate nodes are in the same list of uuids
|
||||||
|
|
|
||||||
|
|
@ -14,12 +14,12 @@ See the License for the specific language governing permissions and
|
||||||
limitations under the License.
|
limitations under the License.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import json
|
|
||||||
from typing import Any, Protocol, TypedDict
|
from typing import Any, Protocol, TypedDict
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
from .models import Message, PromptFunction, PromptVersion
|
from .models import Message, PromptFunction, PromptVersion
|
||||||
|
from .prompt_helpers import to_prompt_json
|
||||||
|
|
||||||
|
|
||||||
class QueryExpansion(BaseModel):
|
class QueryExpansion(BaseModel):
|
||||||
|
|
@ -68,7 +68,7 @@ def query_expansion(context: dict[str, Any]) -> list[Message]:
|
||||||
Bob is asking Alice a question, are you able to rephrase the question into a simpler one about Alice in the third person
|
Bob is asking Alice a question, are you able to rephrase the question into a simpler one about Alice in the third person
|
||||||
that maintains the relevant context?
|
that maintains the relevant context?
|
||||||
<QUESTION>
|
<QUESTION>
|
||||||
{json.dumps(context['query'])}
|
{to_prompt_json(context['query'], ensure_ascii=context.get('ensure_ascii', False))}
|
||||||
</QUESTION>
|
</QUESTION>
|
||||||
"""
|
"""
|
||||||
return [
|
return [
|
||||||
|
|
@ -84,10 +84,10 @@ def qa_prompt(context: dict[str, Any]) -> list[Message]:
|
||||||
Your task is to briefly answer the question in the way that you think Alice would answer the question.
|
Your task is to briefly answer the question in the way that you think Alice would answer the question.
|
||||||
You are given the following entity summaries and facts to help you determine the answer to your question.
|
You are given the following entity summaries and facts to help you determine the answer to your question.
|
||||||
<ENTITY_SUMMARIES>
|
<ENTITY_SUMMARIES>
|
||||||
{json.dumps(context['entity_summaries'])}
|
{to_prompt_json(context['entity_summaries'], ensure_ascii=context.get('ensure_ascii', False))}
|
||||||
</ENTITY_SUMMARIES>
|
</ENTITY_SUMMARIES>
|
||||||
<FACTS>
|
<FACTS>
|
||||||
{json.dumps(context['facts'])}
|
{to_prompt_json(context['facts'], ensure_ascii=context.get('ensure_ascii', False))}
|
||||||
</FACTS>
|
</FACTS>
|
||||||
<QUESTION>
|
<QUESTION>
|
||||||
{context['query']}
|
{context['query']}
|
||||||
|
|
|
||||||
|
|
@ -14,12 +14,12 @@ See the License for the specific language governing permissions and
|
||||||
limitations under the License.
|
limitations under the License.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import json
|
|
||||||
from typing import Any, Protocol, TypedDict
|
from typing import Any, Protocol, TypedDict
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
from .models import Message, PromptFunction, PromptVersion
|
from .models import Message, PromptFunction, PromptVersion
|
||||||
|
from .prompt_helpers import to_prompt_json
|
||||||
|
|
||||||
|
|
||||||
class Edge(BaseModel):
|
class Edge(BaseModel):
|
||||||
|
|
@ -73,7 +73,7 @@ def edge(context: dict[str, Any]) -> list[Message]:
|
||||||
</FACT TYPES>
|
</FACT TYPES>
|
||||||
|
|
||||||
<PREVIOUS_MESSAGES>
|
<PREVIOUS_MESSAGES>
|
||||||
{json.dumps([ep for ep in context['previous_episodes']], indent=2)}
|
{to_prompt_json([ep for ep in context['previous_episodes']], ensure_ascii=context.get('ensure_ascii', False), indent=2)}
|
||||||
</PREVIOUS_MESSAGES>
|
</PREVIOUS_MESSAGES>
|
||||||
|
|
||||||
<CURRENT_MESSAGE>
|
<CURRENT_MESSAGE>
|
||||||
|
|
@ -132,7 +132,7 @@ def reflexion(context: dict[str, Any]) -> list[Message]:
|
||||||
|
|
||||||
user_prompt = f"""
|
user_prompt = f"""
|
||||||
<PREVIOUS MESSAGES>
|
<PREVIOUS MESSAGES>
|
||||||
{json.dumps([ep for ep in context['previous_episodes']], indent=2)}
|
{to_prompt_json([ep for ep in context['previous_episodes']], ensure_ascii=context.get('ensure_ascii', False), indent=2)}
|
||||||
</PREVIOUS MESSAGES>
|
</PREVIOUS MESSAGES>
|
||||||
<CURRENT MESSAGE>
|
<CURRENT MESSAGE>
|
||||||
{context['episode_content']}
|
{context['episode_content']}
|
||||||
|
|
@ -166,7 +166,7 @@ def extract_attributes(context: dict[str, Any]) -> list[Message]:
|
||||||
content=f"""
|
content=f"""
|
||||||
|
|
||||||
<MESSAGE>
|
<MESSAGE>
|
||||||
{json.dumps(context['episode_content'], indent=2)}
|
{to_prompt_json(context['episode_content'], ensure_ascii=context.get('ensure_ascii', False), indent=2)}
|
||||||
</MESSAGE>
|
</MESSAGE>
|
||||||
<REFERENCE TIME>
|
<REFERENCE TIME>
|
||||||
{context['reference_time']}
|
{context['reference_time']}
|
||||||
|
|
|
||||||
|
|
@ -14,12 +14,12 @@ See the License for the specific language governing permissions and
|
||||||
limitations under the License.
|
limitations under the License.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import json
|
|
||||||
from typing import Any, Protocol, TypedDict
|
from typing import Any, Protocol, TypedDict
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
from .models import Message, PromptFunction, PromptVersion
|
from .models import Message, PromptFunction, PromptVersion
|
||||||
|
from .prompt_helpers import to_prompt_json
|
||||||
|
|
||||||
|
|
||||||
class ExtractedEntity(BaseModel):
|
class ExtractedEntity(BaseModel):
|
||||||
|
|
@ -89,7 +89,7 @@ def extract_message(context: dict[str, Any]) -> list[Message]:
|
||||||
</ENTITY TYPES>
|
</ENTITY TYPES>
|
||||||
|
|
||||||
<PREVIOUS MESSAGES>
|
<PREVIOUS MESSAGES>
|
||||||
{json.dumps([ep for ep in context['previous_episodes']], indent=2)}
|
{to_prompt_json([ep for ep in context['previous_episodes']], ensure_ascii=context.get('ensure_ascii', True), indent=2)}
|
||||||
</PREVIOUS MESSAGES>
|
</PREVIOUS MESSAGES>
|
||||||
|
|
||||||
<CURRENT MESSAGE>
|
<CURRENT MESSAGE>
|
||||||
|
|
@ -196,7 +196,7 @@ def reflexion(context: dict[str, Any]) -> list[Message]:
|
||||||
|
|
||||||
user_prompt = f"""
|
user_prompt = f"""
|
||||||
<PREVIOUS MESSAGES>
|
<PREVIOUS MESSAGES>
|
||||||
{json.dumps([ep for ep in context['previous_episodes']], indent=2)}
|
{to_prompt_json([ep for ep in context['previous_episodes']], ensure_ascii=context.get('ensure_ascii', True), indent=2)}
|
||||||
</PREVIOUS MESSAGES>
|
</PREVIOUS MESSAGES>
|
||||||
<CURRENT MESSAGE>
|
<CURRENT MESSAGE>
|
||||||
{context['episode_content']}
|
{context['episode_content']}
|
||||||
|
|
@ -220,7 +220,7 @@ def classify_nodes(context: dict[str, Any]) -> list[Message]:
|
||||||
|
|
||||||
user_prompt = f"""
|
user_prompt = f"""
|
||||||
<PREVIOUS MESSAGES>
|
<PREVIOUS MESSAGES>
|
||||||
{json.dumps([ep for ep in context['previous_episodes']], indent=2)}
|
{to_prompt_json([ep for ep in context['previous_episodes']], ensure_ascii=context.get('ensure_ascii', True), indent=2)}
|
||||||
</PREVIOUS MESSAGES>
|
</PREVIOUS MESSAGES>
|
||||||
<CURRENT MESSAGE>
|
<CURRENT MESSAGE>
|
||||||
{context['episode_content']}
|
{context['episode_content']}
|
||||||
|
|
@ -258,8 +258,8 @@ def extract_attributes(context: dict[str, Any]) -> list[Message]:
|
||||||
content=f"""
|
content=f"""
|
||||||
|
|
||||||
<MESSAGES>
|
<MESSAGES>
|
||||||
{json.dumps(context['previous_episodes'], indent=2)}
|
{to_prompt_json(context['previous_episodes'], ensure_ascii=context.get('ensure_ascii', True), indent=2)}
|
||||||
{json.dumps(context['episode_content'], indent=2)}
|
{to_prompt_json(context['episode_content'], ensure_ascii=context.get('ensure_ascii', True), indent=2)}
|
||||||
</MESSAGES>
|
</MESSAGES>
|
||||||
|
|
||||||
Given the above MESSAGES and the following ENTITY, update any of its attributes based on the information provided
|
Given the above MESSAGES and the following ENTITY, update any of its attributes based on the information provided
|
||||||
|
|
@ -288,8 +288,8 @@ def extract_summary(context: dict[str, Any]) -> list[Message]:
|
||||||
content=f"""
|
content=f"""
|
||||||
|
|
||||||
<MESSAGES>
|
<MESSAGES>
|
||||||
{json.dumps(context['previous_episodes'], indent=2)}
|
{to_prompt_json(context['previous_episodes'], ensure_ascii=context.get('ensure_ascii', True), indent=2)}
|
||||||
{json.dumps(context['episode_content'], indent=2)}
|
{to_prompt_json(context['episode_content'], ensure_ascii=context.get('ensure_ascii', True), indent=2)}
|
||||||
</MESSAGES>
|
</MESSAGES>
|
||||||
|
|
||||||
Given the above MESSAGES and the following ENTITY, update the summary that combines relevant information about the entity
|
Given the above MESSAGES and the following ENTITY, update the summary that combines relevant information about the entity
|
||||||
|
|
|
||||||
|
|
@ -1 +1,24 @@
|
||||||
|
import json
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
DO_NOT_ESCAPE_UNICODE = '\nDo not escape unicode characters.\n'
|
DO_NOT_ESCAPE_UNICODE = '\nDo not escape unicode characters.\n'
|
||||||
|
|
||||||
|
|
||||||
|
def to_prompt_json(data: Any, ensure_ascii: bool = True, indent: int = 2) -> str:
|
||||||
|
"""
|
||||||
|
Serialize data to JSON for use in prompts.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
data: The data to serialize
|
||||||
|
ensure_ascii: If True, escape non-ASCII characters. If False, preserve them.
|
||||||
|
indent: Number of spaces for indentation
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
JSON string representation of the data
|
||||||
|
|
||||||
|
Notes:
|
||||||
|
When ensure_ascii=False, non-ASCII characters (e.g., Korean, Japanese, Chinese)
|
||||||
|
are preserved in their original form in the prompt, making them readable
|
||||||
|
in LLM logs and improving model understanding.
|
||||||
|
"""
|
||||||
|
return json.dumps(data, ensure_ascii=ensure_ascii, indent=indent)
|
||||||
|
|
|
||||||
|
|
@ -14,12 +14,12 @@ See the License for the specific language governing permissions and
|
||||||
limitations under the License.
|
limitations under the License.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import json
|
|
||||||
from typing import Any, Protocol, TypedDict
|
from typing import Any, Protocol, TypedDict
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
from .models import Message, PromptFunction, PromptVersion
|
from .models import Message, PromptFunction, PromptVersion
|
||||||
|
from .prompt_helpers import to_prompt_json
|
||||||
|
|
||||||
|
|
||||||
class Summary(BaseModel):
|
class Summary(BaseModel):
|
||||||
|
|
@ -59,7 +59,7 @@ def summarize_pair(context: dict[str, Any]) -> list[Message]:
|
||||||
Summaries must be under 250 words.
|
Summaries must be under 250 words.
|
||||||
|
|
||||||
Summaries:
|
Summaries:
|
||||||
{json.dumps(context['node_summaries'], indent=2)}
|
{to_prompt_json(context['node_summaries'], ensure_ascii=context.get('ensure_ascii', True), indent=2)}
|
||||||
""",
|
""",
|
||||||
),
|
),
|
||||||
]
|
]
|
||||||
|
|
@ -76,8 +76,8 @@ def summarize_context(context: dict[str, Any]) -> list[Message]:
|
||||||
content=f"""
|
content=f"""
|
||||||
|
|
||||||
<MESSAGES>
|
<MESSAGES>
|
||||||
{json.dumps(context['previous_episodes'], indent=2)}
|
{to_prompt_json(context['previous_episodes'], ensure_ascii=context.get('ensure_ascii', True), indent=2)}
|
||||||
{json.dumps(context['episode_content'], indent=2)}
|
{to_prompt_json(context['episode_content'], ensure_ascii=context.get('ensure_ascii', True), indent=2)}
|
||||||
</MESSAGES>
|
</MESSAGES>
|
||||||
|
|
||||||
Given the above MESSAGES and the following ENTITY name, create a summary for the ENTITY. Your summary must only use
|
Given the above MESSAGES and the following ENTITY name, create a summary for the ENTITY. Your summary must only use
|
||||||
|
|
@ -100,7 +100,7 @@ def summarize_context(context: dict[str, Any]) -> list[Message]:
|
||||||
</ENTITY CONTEXT>
|
</ENTITY CONTEXT>
|
||||||
|
|
||||||
<ATTRIBUTES>
|
<ATTRIBUTES>
|
||||||
{json.dumps(context['attributes'], indent=2)}
|
{to_prompt_json(context['attributes'], ensure_ascii=context.get('ensure_ascii', True), indent=2)}
|
||||||
</ATTRIBUTES>
|
</ATTRIBUTES>
|
||||||
""",
|
""",
|
||||||
),
|
),
|
||||||
|
|
@ -120,7 +120,7 @@ def summary_description(context: dict[str, Any]) -> list[Message]:
|
||||||
Summaries must be under 250 words.
|
Summaries must be under 250 words.
|
||||||
|
|
||||||
Summary:
|
Summary:
|
||||||
{json.dumps(context['summary'], indent=2)}
|
{to_prompt_json(context['summary'], ensure_ascii=context.get('ensure_ascii', True), indent=2)}
|
||||||
""",
|
""",
|
||||||
),
|
),
|
||||||
]
|
]
|
||||||
|
|
|
||||||
|
|
@ -31,7 +31,7 @@ class ComparisonOperator(Enum):
|
||||||
|
|
||||||
|
|
||||||
class DateFilter(BaseModel):
|
class DateFilter(BaseModel):
|
||||||
date: datetime = Field(description='A datetime to filter on')
|
date: datetime | None = Field(description='A datetime to filter on')
|
||||||
comparison_operator: ComparisonOperator = Field(
|
comparison_operator: ComparisonOperator = Field(
|
||||||
description='Comparison operator for date filter'
|
description='Comparison operator for date filter'
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -14,9 +14,8 @@ See the License for the specific language governing permissions and
|
||||||
limitations under the License.
|
limitations under the License.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import json
|
|
||||||
|
|
||||||
from graphiti_core.edges import EntityEdge
|
from graphiti_core.edges import EntityEdge
|
||||||
|
from graphiti_core.prompts.prompt_helpers import to_prompt_json
|
||||||
from graphiti_core.search.search_config import SearchResults
|
from graphiti_core.search.search_config import SearchResults
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -25,7 +24,9 @@ def format_edge_date_range(edge: EntityEdge) -> str:
|
||||||
return f'{edge.valid_at if edge.valid_at else "date unknown"} - {(edge.invalid_at if edge.invalid_at else "present")}'
|
return f'{edge.valid_at if edge.valid_at else "date unknown"} - {(edge.invalid_at if edge.invalid_at else "present")}'
|
||||||
|
|
||||||
|
|
||||||
def search_results_to_context_string(search_results: SearchResults) -> str:
|
def search_results_to_context_string(
|
||||||
|
search_results: SearchResults, ensure_ascii: bool = False
|
||||||
|
) -> str:
|
||||||
"""Reformats a set of SearchResults into a single string to pass directly to an LLM as context"""
|
"""Reformats a set of SearchResults into a single string to pass directly to an LLM as context"""
|
||||||
fact_json = [
|
fact_json = [
|
||||||
{
|
{
|
||||||
|
|
@ -57,16 +58,16 @@ def search_results_to_context_string(search_results: SearchResults) -> str:
|
||||||
These are the most relevant facts and their valid and invalid dates. Facts are considered valid
|
These are the most relevant facts and their valid and invalid dates. Facts are considered valid
|
||||||
between their valid_at and invalid_at dates. Facts with an invalid_at date of "Present" are considered valid.
|
between their valid_at and invalid_at dates. Facts with an invalid_at date of "Present" are considered valid.
|
||||||
<FACTS>
|
<FACTS>
|
||||||
{json.dumps(fact_json, indent=12)}
|
{to_prompt_json(fact_json, ensure_ascii=ensure_ascii, indent=12)}
|
||||||
</FACTS>
|
</FACTS>
|
||||||
<ENTITIES>
|
<ENTITIES>
|
||||||
{json.dumps(entity_json, indent=12)}
|
{to_prompt_json(entity_json, ensure_ascii=ensure_ascii, indent=12)}
|
||||||
</ENTITIES>
|
</ENTITIES>
|
||||||
<EPISODES>
|
<EPISODES>
|
||||||
{json.dumps(episode_json, indent=12)}
|
{to_prompt_json(episode_json, ensure_ascii=ensure_ascii, indent=12)}
|
||||||
</EPISODES>
|
</EPISODES>
|
||||||
<COMMUNITIES>
|
<COMMUNITIES>
|
||||||
{json.dumps(community_json, indent=12)}
|
{to_prompt_json(community_json, ensure_ascii=ensure_ascii, indent=12)}
|
||||||
</COMMUNITIES>
|
</COMMUNITIES>
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -343,7 +343,13 @@ async def dedupe_edges_bulk(
|
||||||
] = await semaphore_gather(
|
] = await semaphore_gather(
|
||||||
*[
|
*[
|
||||||
resolve_extracted_edge(
|
resolve_extracted_edge(
|
||||||
clients.llm_client, edge, candidates, candidates, episode, edge_types
|
clients.llm_client,
|
||||||
|
edge,
|
||||||
|
candidates,
|
||||||
|
candidates,
|
||||||
|
episode,
|
||||||
|
edge_types,
|
||||||
|
clients.ensure_ascii,
|
||||||
)
|
)
|
||||||
for episode, edge, candidates in dedupe_tuples
|
for episode, edge, candidates in dedupe_tuples
|
||||||
]
|
]
|
||||||
|
|
|
||||||
|
|
@ -122,9 +122,14 @@ def label_propagation(projection: dict[str, list[Neighbor]]) -> list[list[str]]:
|
||||||
return clusters
|
return clusters
|
||||||
|
|
||||||
|
|
||||||
async def summarize_pair(llm_client: LLMClient, summary_pair: tuple[str, str]) -> str:
|
async def summarize_pair(
|
||||||
|
llm_client: LLMClient, summary_pair: tuple[str, str], ensure_ascii: bool = True
|
||||||
|
) -> str:
|
||||||
# Prepare context for LLM
|
# Prepare context for LLM
|
||||||
context = {'node_summaries': [{'summary': summary} for summary in summary_pair]}
|
context = {
|
||||||
|
'node_summaries': [{'summary': summary} for summary in summary_pair],
|
||||||
|
'ensure_ascii': ensure_ascii,
|
||||||
|
}
|
||||||
|
|
||||||
llm_response = await llm_client.generate_response(
|
llm_response = await llm_client.generate_response(
|
||||||
prompt_library.summarize_nodes.summarize_pair(context), response_model=Summary
|
prompt_library.summarize_nodes.summarize_pair(context), response_model=Summary
|
||||||
|
|
@ -135,8 +140,13 @@ async def summarize_pair(llm_client: LLMClient, summary_pair: tuple[str, str]) -
|
||||||
return pair_summary
|
return pair_summary
|
||||||
|
|
||||||
|
|
||||||
async def generate_summary_description(llm_client: LLMClient, summary: str) -> str:
|
async def generate_summary_description(
|
||||||
context = {'summary': summary}
|
llm_client: LLMClient, summary: str, ensure_ascii: bool = True
|
||||||
|
) -> str:
|
||||||
|
context = {
|
||||||
|
'summary': summary,
|
||||||
|
'ensure_ascii': ensure_ascii,
|
||||||
|
}
|
||||||
|
|
||||||
llm_response = await llm_client.generate_response(
|
llm_response = await llm_client.generate_response(
|
||||||
prompt_library.summarize_nodes.summary_description(context),
|
prompt_library.summarize_nodes.summary_description(context),
|
||||||
|
|
@ -149,7 +159,7 @@ async def generate_summary_description(llm_client: LLMClient, summary: str) -> s
|
||||||
|
|
||||||
|
|
||||||
async def build_community(
|
async def build_community(
|
||||||
llm_client: LLMClient, community_cluster: list[EntityNode]
|
llm_client: LLMClient, community_cluster: list[EntityNode], ensure_ascii: bool = True
|
||||||
) -> tuple[CommunityNode, list[CommunityEdge]]:
|
) -> tuple[CommunityNode, list[CommunityEdge]]:
|
||||||
summaries = [entity.summary for entity in community_cluster]
|
summaries = [entity.summary for entity in community_cluster]
|
||||||
length = len(summaries)
|
length = len(summaries)
|
||||||
|
|
@ -161,7 +171,9 @@ async def build_community(
|
||||||
new_summaries: list[str] = list(
|
new_summaries: list[str] = list(
|
||||||
await semaphore_gather(
|
await semaphore_gather(
|
||||||
*[
|
*[
|
||||||
summarize_pair(llm_client, (str(left_summary), str(right_summary)))
|
summarize_pair(
|
||||||
|
llm_client, (str(left_summary), str(right_summary)), ensure_ascii
|
||||||
|
)
|
||||||
for left_summary, right_summary in zip(
|
for left_summary, right_summary in zip(
|
||||||
summaries[: int(length / 2)], summaries[int(length / 2) :], strict=False
|
summaries[: int(length / 2)], summaries[int(length / 2) :], strict=False
|
||||||
)
|
)
|
||||||
|
|
@ -174,7 +186,7 @@ async def build_community(
|
||||||
length = len(summaries)
|
length = len(summaries)
|
||||||
|
|
||||||
summary = summaries[0]
|
summary = summaries[0]
|
||||||
name = await generate_summary_description(llm_client, summary)
|
name = await generate_summary_description(llm_client, summary, ensure_ascii)
|
||||||
now = utc_now()
|
now = utc_now()
|
||||||
community_node = CommunityNode(
|
community_node = CommunityNode(
|
||||||
name=name,
|
name=name,
|
||||||
|
|
@ -191,7 +203,10 @@ async def build_community(
|
||||||
|
|
||||||
|
|
||||||
async def build_communities(
|
async def build_communities(
|
||||||
driver: GraphDriver, llm_client: LLMClient, group_ids: list[str] | None
|
driver: GraphDriver,
|
||||||
|
llm_client: LLMClient,
|
||||||
|
group_ids: list[str] | None,
|
||||||
|
ensure_ascii: bool = True,
|
||||||
) -> tuple[list[CommunityNode], list[CommunityEdge]]:
|
) -> tuple[list[CommunityNode], list[CommunityEdge]]:
|
||||||
community_clusters = await get_community_clusters(driver, group_ids)
|
community_clusters = await get_community_clusters(driver, group_ids)
|
||||||
|
|
||||||
|
|
@ -199,7 +214,7 @@ async def build_communities(
|
||||||
|
|
||||||
async def limited_build_community(cluster):
|
async def limited_build_community(cluster):
|
||||||
async with semaphore:
|
async with semaphore:
|
||||||
return await build_community(llm_client, cluster)
|
return await build_community(llm_client, cluster, ensure_ascii)
|
||||||
|
|
||||||
communities: list[tuple[CommunityNode, list[CommunityEdge]]] = list(
|
communities: list[tuple[CommunityNode, list[CommunityEdge]]] = list(
|
||||||
await semaphore_gather(
|
await semaphore_gather(
|
||||||
|
|
@ -285,15 +300,21 @@ async def determine_entity_community(
|
||||||
|
|
||||||
|
|
||||||
async def update_community(
|
async def update_community(
|
||||||
driver: GraphDriver, llm_client: LLMClient, embedder: EmbedderClient, entity: EntityNode
|
driver: GraphDriver,
|
||||||
|
llm_client: LLMClient,
|
||||||
|
embedder: EmbedderClient,
|
||||||
|
entity: EntityNode,
|
||||||
|
ensure_ascii: bool = True,
|
||||||
) -> tuple[list[CommunityNode], list[CommunityEdge]]:
|
) -> tuple[list[CommunityNode], list[CommunityEdge]]:
|
||||||
community, is_new = await determine_entity_community(driver, entity)
|
community, is_new = await determine_entity_community(driver, entity)
|
||||||
|
|
||||||
if community is None:
|
if community is None:
|
||||||
return [], []
|
return [], []
|
||||||
|
|
||||||
new_summary = await summarize_pair(llm_client, (entity.summary, community.summary))
|
new_summary = await summarize_pair(
|
||||||
new_name = await generate_summary_description(llm_client, new_summary)
|
llm_client, (entity.summary, community.summary), ensure_ascii
|
||||||
|
)
|
||||||
|
new_name = await generate_summary_description(llm_client, new_summary, ensure_ascii)
|
||||||
|
|
||||||
community.summary = new_summary
|
community.summary = new_summary
|
||||||
community.name = new_name
|
community.name = new_name
|
||||||
|
|
|
||||||
|
|
@ -151,6 +151,7 @@ async def extract_edges(
|
||||||
'reference_time': episode.valid_at,
|
'reference_time': episode.valid_at,
|
||||||
'edge_types': edge_types_context,
|
'edge_types': edge_types_context,
|
||||||
'custom_prompt': '',
|
'custom_prompt': '',
|
||||||
|
'ensure_ascii': clients.ensure_ascii,
|
||||||
}
|
}
|
||||||
|
|
||||||
facts_missed = True
|
facts_missed = True
|
||||||
|
|
@ -311,6 +312,7 @@ async def resolve_extracted_edges(
|
||||||
existing_edges,
|
existing_edges,
|
||||||
episode,
|
episode,
|
||||||
extracted_edge_types,
|
extracted_edge_types,
|
||||||
|
clients.ensure_ascii,
|
||||||
)
|
)
|
||||||
for extracted_edge, related_edges, existing_edges, extracted_edge_types in zip(
|
for extracted_edge, related_edges, existing_edges, extracted_edge_types in zip(
|
||||||
extracted_edges,
|
extracted_edges,
|
||||||
|
|
@ -382,6 +384,7 @@ async def resolve_extracted_edge(
|
||||||
existing_edges: list[EntityEdge],
|
existing_edges: list[EntityEdge],
|
||||||
episode: EpisodicNode,
|
episode: EpisodicNode,
|
||||||
edge_types: dict[str, type[BaseModel]] | None = None,
|
edge_types: dict[str, type[BaseModel]] | None = None,
|
||||||
|
ensure_ascii: bool = True,
|
||||||
) -> tuple[EntityEdge, list[EntityEdge], list[EntityEdge]]:
|
) -> tuple[EntityEdge, list[EntityEdge], list[EntityEdge]]:
|
||||||
if len(related_edges) == 0 and len(existing_edges) == 0:
|
if len(related_edges) == 0 and len(existing_edges) == 0:
|
||||||
return extracted_edge, [], []
|
return extracted_edge, [], []
|
||||||
|
|
@ -415,6 +418,7 @@ async def resolve_extracted_edge(
|
||||||
'new_edge': extracted_edge.fact,
|
'new_edge': extracted_edge.fact,
|
||||||
'edge_invalidation_candidates': invalidation_edge_candidates_context,
|
'edge_invalidation_candidates': invalidation_edge_candidates_context,
|
||||||
'edge_types': edge_types_context,
|
'edge_types': edge_types_context,
|
||||||
|
'ensure_ascii': ensure_ascii,
|
||||||
}
|
}
|
||||||
|
|
||||||
llm_response = await llm_client.generate_response(
|
llm_response = await llm_client.generate_response(
|
||||||
|
|
@ -449,6 +453,7 @@ async def resolve_extracted_edge(
|
||||||
'episode_content': episode.content,
|
'episode_content': episode.content,
|
||||||
'reference_time': episode.valid_at,
|
'reference_time': episode.valid_at,
|
||||||
'fact': resolved_edge.fact,
|
'fact': resolved_edge.fact,
|
||||||
|
'ensure_ascii': ensure_ascii,
|
||||||
}
|
}
|
||||||
|
|
||||||
edge_model = edge_types.get(fact_type)
|
edge_model = edge_types.get(fact_type)
|
||||||
|
|
|
||||||
|
|
@ -48,12 +48,14 @@ async def extract_nodes_reflexion(
|
||||||
episode: EpisodicNode,
|
episode: EpisodicNode,
|
||||||
previous_episodes: list[EpisodicNode],
|
previous_episodes: list[EpisodicNode],
|
||||||
node_names: list[str],
|
node_names: list[str],
|
||||||
|
ensure_ascii: bool = False,
|
||||||
) -> list[str]:
|
) -> list[str]:
|
||||||
# Prepare context for LLM
|
# Prepare context for LLM
|
||||||
context = {
|
context = {
|
||||||
'episode_content': episode.content,
|
'episode_content': episode.content,
|
||||||
'previous_episodes': [ep.content for ep in previous_episodes],
|
'previous_episodes': [ep.content for ep in previous_episodes],
|
||||||
'extracted_entities': node_names,
|
'extracted_entities': node_names,
|
||||||
|
'ensure_ascii': ensure_ascii,
|
||||||
}
|
}
|
||||||
|
|
||||||
llm_response = await llm_client.generate_response(
|
llm_response = await llm_client.generate_response(
|
||||||
|
|
@ -106,6 +108,7 @@ async def extract_nodes(
|
||||||
'custom_prompt': custom_prompt,
|
'custom_prompt': custom_prompt,
|
||||||
'entity_types': entity_types_context,
|
'entity_types': entity_types_context,
|
||||||
'source_description': episode.source_description,
|
'source_description': episode.source_description,
|
||||||
|
'ensure_ascii': clients.ensure_ascii,
|
||||||
}
|
}
|
||||||
|
|
||||||
while entities_missed and reflexion_iterations <= MAX_REFLEXION_ITERATIONS:
|
while entities_missed and reflexion_iterations <= MAX_REFLEXION_ITERATIONS:
|
||||||
|
|
@ -134,6 +137,7 @@ async def extract_nodes(
|
||||||
episode,
|
episode,
|
||||||
previous_episodes,
|
previous_episodes,
|
||||||
[entity.name for entity in extracted_entities],
|
[entity.name for entity in extracted_entities],
|
||||||
|
clients.ensure_ascii,
|
||||||
)
|
)
|
||||||
|
|
||||||
entities_missed = len(missing_entities) != 0
|
entities_missed = len(missing_entities) != 0
|
||||||
|
|
@ -244,6 +248,7 @@ async def resolve_extracted_nodes(
|
||||||
'previous_episodes': [ep.content for ep in previous_episodes]
|
'previous_episodes': [ep.content for ep in previous_episodes]
|
||||||
if previous_episodes is not None
|
if previous_episodes is not None
|
||||||
else [],
|
else [],
|
||||||
|
'ensure_ascii': clients.ensure_ascii,
|
||||||
}
|
}
|
||||||
|
|
||||||
llm_response = await llm_client.generate_response(
|
llm_response = await llm_client.generate_response(
|
||||||
|
|
@ -309,6 +314,7 @@ async def extract_attributes_from_nodes(
|
||||||
entity_types.get(next((item for item in node.labels if item != 'Entity'), ''))
|
entity_types.get(next((item for item in node.labels if item != 'Entity'), ''))
|
||||||
if entity_types is not None
|
if entity_types is not None
|
||||||
else None,
|
else None,
|
||||||
|
clients.ensure_ascii,
|
||||||
)
|
)
|
||||||
for node in nodes
|
for node in nodes
|
||||||
]
|
]
|
||||||
|
|
@ -325,6 +331,7 @@ async def extract_attributes_from_node(
|
||||||
episode: EpisodicNode | None = None,
|
episode: EpisodicNode | None = None,
|
||||||
previous_episodes: list[EpisodicNode] | None = None,
|
previous_episodes: list[EpisodicNode] | None = None,
|
||||||
entity_type: type[BaseModel] | None = None,
|
entity_type: type[BaseModel] | None = None,
|
||||||
|
ensure_ascii: bool = False,
|
||||||
) -> EntityNode:
|
) -> EntityNode:
|
||||||
node_context: dict[str, Any] = {
|
node_context: dict[str, Any] = {
|
||||||
'name': node.name,
|
'name': node.name,
|
||||||
|
|
@ -339,6 +346,7 @@ async def extract_attributes_from_node(
|
||||||
'previous_episodes': [ep.content for ep in previous_episodes]
|
'previous_episodes': [ep.content for ep in previous_episodes]
|
||||||
if previous_episodes is not None
|
if previous_episodes is not None
|
||||||
else [],
|
else [],
|
||||||
|
'ensure_ascii': ensure_ascii,
|
||||||
}
|
}
|
||||||
|
|
||||||
summary_context: dict[str, Any] = {
|
summary_context: dict[str, Any] = {
|
||||||
|
|
@ -347,6 +355,7 @@ async def extract_attributes_from_node(
|
||||||
'previous_episodes': [ep.content for ep in previous_episodes]
|
'previous_episodes': [ep.content for ep in previous_episodes]
|
||||||
if previous_episodes is not None
|
if previous_episodes is not None
|
||||||
else [],
|
else [],
|
||||||
|
'ensure_ascii': ensure_ascii,
|
||||||
}
|
}
|
||||||
|
|
||||||
llm_response = (
|
llm_response = (
|
||||||
|
|
|
||||||
|
|
@ -35,12 +35,14 @@ async def extract_edge_dates(
|
||||||
edge: EntityEdge,
|
edge: EntityEdge,
|
||||||
current_episode: EpisodicNode,
|
current_episode: EpisodicNode,
|
||||||
previous_episodes: list[EpisodicNode],
|
previous_episodes: list[EpisodicNode],
|
||||||
|
ensure_ascii: bool = False,
|
||||||
) -> tuple[datetime | None, datetime | None]:
|
) -> tuple[datetime | None, datetime | None]:
|
||||||
context = {
|
context = {
|
||||||
'edge_fact': edge.fact,
|
'edge_fact': edge.fact,
|
||||||
'current_episode': current_episode.content,
|
'current_episode': current_episode.content,
|
||||||
'previous_episodes': [ep.content for ep in previous_episodes],
|
'previous_episodes': [ep.content for ep in previous_episodes],
|
||||||
'reference_timestamp': current_episode.valid_at.isoformat(),
|
'reference_timestamp': current_episode.valid_at.isoformat(),
|
||||||
|
'ensure_ascii': ensure_ascii,
|
||||||
}
|
}
|
||||||
llm_response = await llm_client.generate_response(
|
llm_response = await llm_client.generate_response(
|
||||||
prompt_library.extract_edge_dates.v1(context), response_model=EdgeDates
|
prompt_library.extract_edge_dates.v1(context), response_model=EdgeDates
|
||||||
|
|
@ -70,7 +72,10 @@ async def extract_edge_dates(
|
||||||
|
|
||||||
|
|
||||||
async def get_edge_contradictions(
|
async def get_edge_contradictions(
|
||||||
llm_client: LLMClient, new_edge: EntityEdge, existing_edges: list[EntityEdge]
|
llm_client: LLMClient,
|
||||||
|
new_edge: EntityEdge,
|
||||||
|
existing_edges: list[EntityEdge],
|
||||||
|
ensure_ascii: bool = False,
|
||||||
) -> list[EntityEdge]:
|
) -> list[EntityEdge]:
|
||||||
start = time()
|
start = time()
|
||||||
|
|
||||||
|
|
@ -79,7 +84,11 @@ async def get_edge_contradictions(
|
||||||
{'id': i, 'fact': existing_edge.fact} for i, existing_edge in enumerate(existing_edges)
|
{'id': i, 'fact': existing_edge.fact} for i, existing_edge in enumerate(existing_edges)
|
||||||
]
|
]
|
||||||
|
|
||||||
context = {'new_edge': new_edge_context, 'existing_edges': existing_edge_context}
|
context = {
|
||||||
|
'new_edge': new_edge_context,
|
||||||
|
'existing_edges': existing_edge_context,
|
||||||
|
'ensure_ascii': ensure_ascii,
|
||||||
|
}
|
||||||
|
|
||||||
llm_response = await llm_client.generate_response(
|
llm_response = await llm_client.generate_response(
|
||||||
prompt_library.invalidate_edges.v2(context),
|
prompt_library.invalidate_edges.v2(context),
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,7 @@
|
||||||
[project]
|
[project]
|
||||||
name = "graphiti-core"
|
name = "graphiti-core"
|
||||||
description = "A temporal graph building library"
|
description = "A temporal graph building library"
|
||||||
version = "0.18.3"
|
version = "0.18.5"
|
||||||
authors = [
|
authors = [
|
||||||
{ name = "Paul Paliychuk", email = "paul@getzep.com" },
|
{ name = "Paul Paliychuk", email = "paul@getzep.com" },
|
||||||
{ name = "Preston Rasmussen", email = "preston@getzep.com" },
|
{ name = "Preston Rasmussen", email = "preston@getzep.com" },
|
||||||
|
|
|
||||||
|
|
@ -287,6 +287,22 @@
|
||||||
"created_at": "2025-08-07T02:23:09Z",
|
"created_at": "2025-08-07T02:23:09Z",
|
||||||
"repoId": 840056306,
|
"repoId": 840056306,
|
||||||
"pullRequestNo": 812
|
"pullRequestNo": 812
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "vjeeva",
|
||||||
|
"id": 13189349,
|
||||||
|
"comment_id": 3165600173,
|
||||||
|
"created_at": "2025-08-07T20:24:08Z",
|
||||||
|
"repoId": 840056306,
|
||||||
|
"pullRequestNo": 814
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "liebertar",
|
||||||
|
"id": 99405438,
|
||||||
|
"comment_id": 3166905812,
|
||||||
|
"created_at": "2025-08-08T07:52:27Z",
|
||||||
|
"repoId": 840056306,
|
||||||
|
"pullRequestNo": 816
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
2
uv.lock
generated
2
uv.lock
generated
|
|
@ -746,7 +746,7 @@ wheels = [
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "graphiti-core"
|
name = "graphiti-core"
|
||||||
version = "0.18.3"
|
version = "0.18.4"
|
||||||
source = { editable = "." }
|
source = { editable = "." }
|
||||||
dependencies = [
|
dependencies = [
|
||||||
{ name = "diskcache" },
|
{ name = "diskcache" },
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue