From fc4bf3bde20d263c228edee26a44be2887849a52 Mon Sep 17 00:00:00 2001 From: Daniel Chalef <131175+danielchalef@users.noreply.github.com> Date: Mon, 26 Aug 2024 12:53:16 -0700 Subject: [PATCH 1/4] Implement retry for LLMClient (#44) * implement retry * chore: Refactor tenacity retry logic and improve LLMClient error handling * poetry * remove unnecessary try --- graphiti_core/llm_client/client.py | 24 +++++++++++++++++++++++- poetry.lock | 17 ++++++++++++++++- pyproject.toml | 1 + 3 files changed, 40 insertions(+), 2 deletions(-) diff --git a/graphiti_core/llm_client/client.py b/graphiti_core/llm_client/client.py index 02bd6f4f..5de06d76 100644 --- a/graphiti_core/llm_client/client.py +++ b/graphiti_core/llm_client/client.py @@ -20,7 +20,9 @@ import logging import typing from abc import ABC, abstractmethod +import httpx from diskcache import Cache +from tenacity import retry, retry_if_exception, stop_after_attempt, wait_exponential from ..prompts.models import Message from .config import LLMConfig @@ -31,6 +33,12 @@ DEFAULT_CACHE_DIR = './llm_cache' logger = logging.getLogger(__name__) +def is_server_error(exception): + return ( + isinstance(exception, httpx.HTTPStatusError) and 500 <= exception.response.status_code < 600 + ) + + class LLMClient(ABC): def __init__(self, config: LLMConfig | None, cache: bool = False): if config is None: @@ -47,6 +55,20 @@ class LLMClient(ABC): def get_embedder(self) -> typing.Any: pass + @retry( + stop=stop_after_attempt(3), + wait=wait_exponential(multiplier=1, min=4, max=10), + retry=retry_if_exception(is_server_error), + ) + async def _generate_response_with_retry(self, messages: list[Message]) -> dict[str, typing.Any]: + try: + return await self._generate_response(messages) + except httpx.HTTPStatusError as e: + if not is_server_error(e): + raise Exception(f'LLM request error: {e}') from e + else: + raise + @abstractmethod async def _generate_response(self, messages: list[Message]) -> dict[str, typing.Any]: pass @@ -66,7 +88,7 @@ class LLMClient(ABC): logger.debug(f'Cache hit for {cache_key}') return cached_response - response = await self._generate_response(messages) + response = await self._generate_response_with_retry(messages) if self.cache_enabled: self.cache_dir.set(cache_key, response) diff --git a/poetry.lock b/poetry.lock index 6d964a5f..22653da0 100644 --- a/poetry.lock +++ b/poetry.lock @@ -3253,6 +3253,21 @@ mpmath = ">=1.1.0,<1.4" [package.extras] dev = ["hypothesis (>=6.70.0)", "pytest (>=7.1.0)"] +[[package]] +name = "tenacity" +version = "9.0.0" +description = "Retry code until it succeeds" +optional = false +python-versions = ">=3.8" +files = [ + {file = "tenacity-9.0.0-py3-none-any.whl", hash = "sha256:93de0c98785b27fcf659856aa9f54bfbd399e29969b0621bc7f762bd441b4539"}, + {file = "tenacity-9.0.0.tar.gz", hash = "sha256:807f37ca97d62aa361264d497b0e31e92b8027044942bfa756160d908320d73b"}, +] + +[package.extras] +doc = ["reno", "sphinx"] +test = ["pytest", "tornado (>=4.5)", "typeguard"] + [[package]] name = "terminado" version = "0.18.1" @@ -3743,4 +3758,4 @@ test = ["websockets"] [metadata] lock-version = "2.0" python-versions = "^3.10" -content-hash = "5b90bb6d58d36a2553f5410c418b179aa1c86b55078567c33aaa6fddf6a8c6c6" +content-hash = "001663dfc8078ad473675c994b15191db1f53a844e23f40ffa4a704379a61132" diff --git a/pyproject.toml b/pyproject.toml index 2456f13d..956e48d4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,6 +23,7 @@ diskcache = "^5.6.3" arrow = "^1.3.0" openai = "^1.38.0" anthropic = "^0.34.1" +tenacity = "^9.0.0" [tool.poetry.dev-dependencies] pytest = "^8.3.2" From 2d01e5d7b7d2e91c0835f7d53f88597c97f7230b Mon Sep 17 00:00:00 2001 From: Preston Rasmussen <109292228+prasmussen15@users.noreply.github.com> Date: Mon, 26 Aug 2024 18:34:57 -0400 Subject: [PATCH 2/4] Search node centering (#45) * add new search reranker and update search * node distance reranking * format * rebase * no need for enumerate * mypy typing * defaultdict update * rrf prelim ranking --- graphiti_core/graphiti.py | 26 ++++++++++--- graphiti_core/search/search.py | 55 ++++++++++++++++++++-------- graphiti_core/search/search_utils.py | 42 ++++++++++++++++++++- 3 files changed, 101 insertions(+), 22 deletions(-) diff --git a/graphiti_core/graphiti.py b/graphiti_core/graphiti.py index ec038f1e..6ff5b52b 100644 --- a/graphiti_core/graphiti.py +++ b/graphiti_core/graphiti.py @@ -26,7 +26,7 @@ from neo4j import AsyncGraphDatabase from graphiti_core.edges import EntityEdge, EpisodicEdge from graphiti_core.llm_client import LLMClient, OpenAIClient from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode -from graphiti_core.search.search import SearchConfig, hybrid_search +from graphiti_core.search.search import Reranker, SearchConfig, SearchMethod, hybrid_search from graphiti_core.search.search_utils import ( get_relevant_edges, get_relevant_nodes, @@ -515,7 +515,7 @@ class Graphiti: except Exception as e: raise e - async def search(self, query: str, num_results=10): + async def search(self, query: str, center_node_uuid: str | None = None, num_results=10): """ Perform a hybrid search on the knowledge graph. @@ -526,6 +526,8 @@ class Graphiti: ---------- query : str The search query string. + center_node_uuid: str, optional + Facts will be reranked based on proximity to this node num_results : int, optional The maximum number of results to return. Defaults to 10. @@ -543,7 +545,14 @@ class Graphiti: The search is performed using the current date and time as the reference point for temporal relevance. """ - search_config = SearchConfig(num_episodes=0, num_results=num_results) + reranker = Reranker.rrf if center_node_uuid is None else Reranker.node_distance + search_config = SearchConfig( + num_episodes=0, + num_edges=num_results, + num_nodes=0, + search_methods=[SearchMethod.bm25, SearchMethod.cosine_similarity], + reranker=reranker, + ) edges = ( await hybrid_search( self.driver, @@ -551,6 +560,7 @@ class Graphiti: query, datetime.now(), search_config, + center_node_uuid, ) ).edges @@ -558,7 +568,13 @@ class Graphiti: return facts - async def _search(self, query: str, timestamp: datetime, config: SearchConfig): + async def _search( + self, + query: str, + timestamp: datetime, + config: SearchConfig, + center_node_uuid: str | None = None, + ): return await hybrid_search( - self.driver, self.llm_client.get_embedder(), query, timestamp, config + self.driver, self.llm_client.get_embedder(), query, timestamp, config, center_node_uuid ) diff --git a/graphiti_core/search/search.py b/graphiti_core/search/search.py index 956ae65d..03111225 100644 --- a/graphiti_core/search/search.py +++ b/graphiti_core/search/search.py @@ -16,6 +16,7 @@ limitations under the License. import logging from datetime import datetime +from enum import Enum from time import time from neo4j import AsyncDriver @@ -28,6 +29,7 @@ from graphiti_core.search.search_utils import ( edge_fulltext_search, edge_similarity_search, get_mentioned_nodes, + node_distance_reranker, rrf, ) from graphiti_core.utils import retrieve_episodes @@ -36,12 +38,22 @@ from graphiti_core.utils.maintenance.graph_data_operations import EPISODE_WINDOW logger = logging.getLogger(__name__) +class SearchMethod(Enum): + cosine_similarity = 'cosine_similarity' + bm25 = 'bm25' + + +class Reranker(Enum): + rrf = 'reciprocal_rank_fusion' + node_distance = 'node_distance' + + class SearchConfig(BaseModel): - num_results: int = 10 + num_edges: int = 10 + num_nodes: int = 10 num_episodes: int = EPISODE_WINDOW_LEN - similarity_search: str = 'cosine' - text_search: str = 'BM25' - reranker: str = 'rrf' + search_methods: list[SearchMethod] + reranker: Reranker | None class SearchResults(BaseModel): @@ -51,7 +63,12 @@ class SearchResults(BaseModel): async def hybrid_search( - driver: AsyncDriver, embedder, query: str, timestamp: datetime, config: SearchConfig + driver: AsyncDriver, + embedder, + query: str, + timestamp: datetime, + config: SearchConfig, + center_node_uuid: str | None = None, ) -> SearchResults: start = time() @@ -65,11 +82,11 @@ async def hybrid_search( episodes.extend(await retrieve_episodes(driver, timestamp)) nodes.extend(await get_mentioned_nodes(driver, episodes)) - if config.text_search == 'BM25': + if SearchMethod.bm25 in config.search_methods: text_search = await edge_fulltext_search(query, driver) search_results.append(text_search) - if config.similarity_search == 'cosine': + if SearchMethod.cosine_similarity in config.search_methods: query_text = query.replace('\n', ' ') search_vector = ( (await embedder.create(input=[query_text], model='text-embedding-3-small')) @@ -80,19 +97,14 @@ async def hybrid_search( similarity_search = await edge_similarity_search(search_vector, driver) search_results.append(similarity_search) - if len(search_results) == 1: - edges = search_results[0] - - elif len(search_results) > 1 and config.reranker != 'rrf': + if len(search_results) > 1 and config.reranker is None: logger.exception('Multiple searches enabled without a reranker') raise Exception('Multiple searches enabled without a reranker') - elif config.reranker == 'rrf': + else: edge_uuid_map = {} search_result_uuids = [] - logger.info([[edge.fact for edge in result] for result in search_results]) - for result in search_results: result_uuids = [] for edge in result: @@ -103,12 +115,23 @@ async def hybrid_search( search_result_uuids = [[edge.uuid for edge in result] for result in search_results] - reranked_uuids = rrf(search_result_uuids) + reranked_uuids: list[str] = [] + if config.reranker == Reranker.rrf: + reranked_uuids = rrf(search_result_uuids) + elif config.reranker == Reranker.node_distance: + if center_node_uuid is None: + logger.exception('No center node provided for Node Distance reranker') + raise Exception('No center node provided for Node Distance reranker') + reranked_uuids = await node_distance_reranker( + driver, search_result_uuids, center_node_uuid + ) reranked_edges = [edge_uuid_map[uuid] for uuid in reranked_uuids] edges.extend(reranked_edges) - context = SearchResults(episodes=episodes, nodes=nodes, edges=edges) + context = SearchResults( + episodes=episodes, nodes=nodes[: config.num_nodes], edges=edges[: config.num_edges] + ) end = time() diff --git a/graphiti_core/search/search_utils.py b/graphiti_core/search/search_utils.py index d73ea5e6..e9d658e0 100644 --- a/graphiti_core/search/search_utils.py +++ b/graphiti_core/search/search_utils.py @@ -333,7 +333,7 @@ async def get_relevant_edges( # takes in a list of rankings of uuids def rrf(results: list[list[str]], rank_const=1) -> list[str]: - scores: dict[str, int] = defaultdict(int) + scores: dict[str, float] = defaultdict(float) for result in results: for i, uuid in enumerate(result): scores[uuid] += 1 / (i + rank_const) @@ -344,3 +344,43 @@ def rrf(results: list[list[str]], rank_const=1) -> list[str]: sorted_uuids = [term[0] for term in scored_uuids] return sorted_uuids + + +async def node_distance_reranker( + driver: AsyncDriver, results: list[list[str]], center_node_uuid: str +) -> list[str]: + # use rrf as a preliminary ranker + sorted_uuids = rrf(results) + scores: dict[str, float] = {} + + for uuid in sorted_uuids: + # Find shortest path to center node + records, _, _ = await driver.execute_query( + """ + MATCH (source:Entity)-[r:RELATES_TO {uuid: $edge_uuid}]->(target:Entity) + MATCH p = SHORTEST 1 (center:Entity)-[:RELATES_TO]-+(n:Entity) + WHERE center.uuid = $center_uuid AND n.uuid IN [source.uuid, target.uuid] + RETURN min(length(p)) AS score, source.uuid AS source_uuid, target.uuid AS target_uuid + """, + edge_uuid=uuid, + center_uuid=center_node_uuid, + ) + distance = 0.01 + + for record in records: + if ( + record['source_uuid'] == center_node_uuid + or record['target_uuid'] == center_node_uuid + ): + continue + distance = record['score'] + + if uuid in scores: + scores[uuid] = min(1 / distance, scores[uuid]) + else: + scores[uuid] = 1 / distance + + # rerank on shortest distance + sorted_uuids.sort(reverse=True, key=lambda cur_uuid: scores[cur_uuid]) + + return sorted_uuids From a6d63f0c0d4ec0ac0d11070664cf3a2116338949 Mon Sep 17 00:00:00 2001 From: Daniel Chalef <131175+danielchalef@users.noreply.github.com> Date: Mon, 26 Aug 2024 15:51:13 -0700 Subject: [PATCH 3/4] Add text episode type (#46) Add a new `text` episode type and update the `extract_nodes` function to handle it. * **EpisodeType Enum:** - Add `text` to the `EpisodeType` enum in `graphiti_core/nodes.py`. - Update the `from_str` method to handle the `text` episode type. * **extract_nodes Function:** - Update the `extract_nodes` function in `graphiti_core/utils/maintenance/node_operations.py` to handle the `text` episode type. - Use the `message` type prompt for both `message` and `text` episodes. --- For more details, open the [Copilot Workspace session](https://copilot-workspace.githubnext.com/getzep/graphiti?shareId=XXXX-XXXX-XXXX-XXXX). --- graphiti_core/nodes.py | 5 +++++ graphiti_core/utils/maintenance/node_operations.py | 2 +- 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/graphiti_core/nodes.py b/graphiti_core/nodes.py index 6a4df2bb..bf053569 100644 --- a/graphiti_core/nodes.py +++ b/graphiti_core/nodes.py @@ -46,10 +46,13 @@ class EpisodeType(Enum): or "assistant: I'm doing well, thank you for asking." json : str Represents an episode containing a JSON string object with structured data. + text : str + Represents a plain text episode. """ message = 'message' json = 'json' + text = 'text' @staticmethod def from_str(episode_type: str): @@ -57,6 +60,8 @@ class EpisodeType(Enum): return EpisodeType.message if episode_type == 'json': return EpisodeType.json + if episode_type == 'text': + return EpisodeType.text logger.error(f'Episode type: {episode_type} not implemented') raise NotImplementedError diff --git a/graphiti_core/utils/maintenance/node_operations.py b/graphiti_core/utils/maintenance/node_operations.py index ab9ffc51..2dfdaccb 100644 --- a/graphiti_core/utils/maintenance/node_operations.py +++ b/graphiti_core/utils/maintenance/node_operations.py @@ -72,7 +72,7 @@ async def extract_nodes( ) -> list[EntityNode]: start = time() extracted_node_data: list[dict[str, Any]] = [] - if episode.source == EpisodeType.message: + if episode.source in [EpisodeType.message, EpisodeType.text]: extracted_node_data = await extract_message_nodes(llm_client, episode, previous_episodes) elif episode.source == EpisodeType.json: extracted_node_data = await extract_json_nodes(llm_client, episode) From 598e9fd0c58bbb5b29a1b09fec15d5e338074466 Mon Sep 17 00:00:00 2001 From: Daniel Chalef <131175+danielchalef@users.noreply.github.com> Date: Mon, 26 Aug 2024 16:04:41 -0700 Subject: [PATCH 4/4] Update cla.yml for dependabot[bot] whitelist (#47) --- .github/workflows/cla.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/cla.yml b/.github/workflows/cla.yml index b9f52751..1ec1ddd4 100644 --- a/.github/workflows/cla.yml +++ b/.github/workflows/cla.yml @@ -26,7 +26,7 @@ jobs: path-to-document: "https://github.com/getzep/graphiti/blob/main/Zep-CLA.md" # e.g. a CLA or a DCO document # branch should not be protected branch: "main" - allowlist: paul-paliychuk,prasmussen15,danielchalef,dependabot,ellipsisdev + allowlist: paul-paliychuk,prasmussen15,danielchalef,dependabot[bot],ellipsisdev # the followings are the optional inputs - If the optional inputs are not given, then default values will be taken #remote-organization-name: enter the remote organization name where the signatures should be stored (Default is storing the signatures in the same repository)