From a6d63f0c0d4ec0ac0d11070664cf3a2116338949 Mon Sep 17 00:00:00 2001 From: Daniel Chalef <131175+danielchalef@users.noreply.github.com> Date: Mon, 26 Aug 2024 15:51:13 -0700 Subject: [PATCH] Add text episode type (#46) Add a new `text` episode type and update the `extract_nodes` function to handle it. * **EpisodeType Enum:** - Add `text` to the `EpisodeType` enum in `graphiti_core/nodes.py`. - Update the `from_str` method to handle the `text` episode type. * **extract_nodes Function:** - Update the `extract_nodes` function in `graphiti_core/utils/maintenance/node_operations.py` to handle the `text` episode type. - Use the `message` type prompt for both `message` and `text` episodes. --- For more details, open the [Copilot Workspace session](https://copilot-workspace.githubnext.com/getzep/graphiti?shareId=XXXX-XXXX-XXXX-XXXX). --- graphiti_core/nodes.py | 5 +++++ graphiti_core/utils/maintenance/node_operations.py | 2 +- 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/graphiti_core/nodes.py b/graphiti_core/nodes.py index 6a4df2bb..bf053569 100644 --- a/graphiti_core/nodes.py +++ b/graphiti_core/nodes.py @@ -46,10 +46,13 @@ class EpisodeType(Enum): or "assistant: I'm doing well, thank you for asking." json : str Represents an episode containing a JSON string object with structured data. + text : str + Represents a plain text episode. """ message = 'message' json = 'json' + text = 'text' @staticmethod def from_str(episode_type: str): @@ -57,6 +60,8 @@ class EpisodeType(Enum): return EpisodeType.message if episode_type == 'json': return EpisodeType.json + if episode_type == 'text': + return EpisodeType.text logger.error(f'Episode type: {episode_type} not implemented') raise NotImplementedError diff --git a/graphiti_core/utils/maintenance/node_operations.py b/graphiti_core/utils/maintenance/node_operations.py index ab9ffc51..2dfdaccb 100644 --- a/graphiti_core/utils/maintenance/node_operations.py +++ b/graphiti_core/utils/maintenance/node_operations.py @@ -72,7 +72,7 @@ async def extract_nodes( ) -> list[EntityNode]: start = time() extracted_node_data: list[dict[str, Any]] = [] - if episode.source == EpisodeType.message: + if episode.source in [EpisodeType.message, EpisodeType.text]: extracted_node_data = await extract_message_nodes(llm_client, episode, previous_episodes) elif episode.source == EpisodeType.json: extracted_node_data = await extract_json_nodes(llm_client, episode)