Merge branch 'main' of github.com:getzep/graphiti into nba-langgraph

This commit is contained in:
paulpaliychuk 2024-08-26 20:04:40 -04:00
commit c304758791
9 changed files with 148 additions and 26 deletions

View file

@ -26,7 +26,7 @@ jobs:
path-to-document: "https://github.com/getzep/graphiti/blob/main/Zep-CLA.md" # e.g. a CLA or a DCO document path-to-document: "https://github.com/getzep/graphiti/blob/main/Zep-CLA.md" # e.g. a CLA or a DCO document
# branch should not be protected # branch should not be protected
branch: "main" branch: "main"
allowlist: paul-paliychuk,prasmussen15,danielchalef,dependabot,ellipsisdev allowlist: paul-paliychuk,prasmussen15,danielchalef,dependabot[bot],ellipsisdev
# the followings are the optional inputs - If the optional inputs are not given, then default values will be taken # the followings are the optional inputs - If the optional inputs are not given, then default values will be taken
#remote-organization-name: enter the remote organization name where the signatures should be stored (Default is storing the signatures in the same repository) #remote-organization-name: enter the remote organization name where the signatures should be stored (Default is storing the signatures in the same repository)

View file

@ -26,7 +26,7 @@ from neo4j import AsyncGraphDatabase
from graphiti_core.edges import EntityEdge, EpisodicEdge from graphiti_core.edges import EntityEdge, EpisodicEdge
from graphiti_core.llm_client import LLMClient, OpenAIClient from graphiti_core.llm_client import LLMClient, OpenAIClient
from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode
from graphiti_core.search.search import SearchConfig, hybrid_search from graphiti_core.search.search import Reranker, SearchConfig, SearchMethod, hybrid_search
from graphiti_core.search.search_utils import ( from graphiti_core.search.search_utils import (
get_relevant_edges, get_relevant_edges,
get_relevant_nodes, get_relevant_nodes,
@ -515,7 +515,7 @@ class Graphiti:
except Exception as e: except Exception as e:
raise e raise e
async def search(self, query: str, num_results=10): async def search(self, query: str, center_node_uuid: str | None = None, num_results=10):
""" """
Perform a hybrid search on the knowledge graph. Perform a hybrid search on the knowledge graph.
@ -526,6 +526,8 @@ class Graphiti:
---------- ----------
query : str query : str
The search query string. The search query string.
center_node_uuid: str, optional
Facts will be reranked based on proximity to this node
num_results : int, optional num_results : int, optional
The maximum number of results to return. Defaults to 10. The maximum number of results to return. Defaults to 10.
@ -543,7 +545,14 @@ class Graphiti:
The search is performed using the current date and time as the reference The search is performed using the current date and time as the reference
point for temporal relevance. point for temporal relevance.
""" """
search_config = SearchConfig(num_episodes=0, num_results=num_results) reranker = Reranker.rrf if center_node_uuid is None else Reranker.node_distance
search_config = SearchConfig(
num_episodes=0,
num_edges=num_results,
num_nodes=0,
search_methods=[SearchMethod.bm25, SearchMethod.cosine_similarity],
reranker=reranker,
)
edges = ( edges = (
await hybrid_search( await hybrid_search(
self.driver, self.driver,
@ -551,6 +560,7 @@ class Graphiti:
query, query,
datetime.now(), datetime.now(),
search_config, search_config,
center_node_uuid,
) )
).edges ).edges
@ -558,7 +568,13 @@ class Graphiti:
return facts return facts
async def _search(self, query: str, timestamp: datetime, config: SearchConfig): async def _search(
self,
query: str,
timestamp: datetime,
config: SearchConfig,
center_node_uuid: str | None = None,
):
return await hybrid_search( return await hybrid_search(
self.driver, self.llm_client.get_embedder(), query, timestamp, config self.driver, self.llm_client.get_embedder(), query, timestamp, config, center_node_uuid
) )

View file

@ -20,7 +20,9 @@ import logging
import typing import typing
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
import httpx
from diskcache import Cache from diskcache import Cache
from tenacity import retry, retry_if_exception, stop_after_attempt, wait_exponential
from ..prompts.models import Message from ..prompts.models import Message
from .config import LLMConfig from .config import LLMConfig
@ -31,6 +33,12 @@ DEFAULT_CACHE_DIR = './llm_cache'
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def is_server_error(exception):
return (
isinstance(exception, httpx.HTTPStatusError) and 500 <= exception.response.status_code < 600
)
class LLMClient(ABC): class LLMClient(ABC):
def __init__(self, config: LLMConfig | None, cache: bool = False): def __init__(self, config: LLMConfig | None, cache: bool = False):
if config is None: if config is None:
@ -47,6 +55,20 @@ class LLMClient(ABC):
def get_embedder(self) -> typing.Any: def get_embedder(self) -> typing.Any:
pass pass
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=10),
retry=retry_if_exception(is_server_error),
)
async def _generate_response_with_retry(self, messages: list[Message]) -> dict[str, typing.Any]:
try:
return await self._generate_response(messages)
except httpx.HTTPStatusError as e:
if not is_server_error(e):
raise Exception(f'LLM request error: {e}') from e
else:
raise
@abstractmethod @abstractmethod
async def _generate_response(self, messages: list[Message]) -> dict[str, typing.Any]: async def _generate_response(self, messages: list[Message]) -> dict[str, typing.Any]:
pass pass
@ -66,7 +88,7 @@ class LLMClient(ABC):
logger.debug(f'Cache hit for {cache_key}') logger.debug(f'Cache hit for {cache_key}')
return cached_response return cached_response
response = await self._generate_response(messages) response = await self._generate_response_with_retry(messages)
if self.cache_enabled: if self.cache_enabled:
self.cache_dir.set(cache_key, response) self.cache_dir.set(cache_key, response)

View file

@ -46,10 +46,13 @@ class EpisodeType(Enum):
or "assistant: I'm doing well, thank you for asking." or "assistant: I'm doing well, thank you for asking."
json : str json : str
Represents an episode containing a JSON string object with structured data. Represents an episode containing a JSON string object with structured data.
text : str
Represents a plain text episode.
""" """
message = 'message' message = 'message'
json = 'json' json = 'json'
text = 'text'
@staticmethod @staticmethod
def from_str(episode_type: str): def from_str(episode_type: str):
@ -57,6 +60,8 @@ class EpisodeType(Enum):
return EpisodeType.message return EpisodeType.message
if episode_type == 'json': if episode_type == 'json':
return EpisodeType.json return EpisodeType.json
if episode_type == 'text':
return EpisodeType.text
logger.error(f'Episode type: {episode_type} not implemented') logger.error(f'Episode type: {episode_type} not implemented')
raise NotImplementedError raise NotImplementedError

View file

@ -16,6 +16,7 @@ limitations under the License.
import logging import logging
from datetime import datetime from datetime import datetime
from enum import Enum
from time import time from time import time
from neo4j import AsyncDriver from neo4j import AsyncDriver
@ -28,6 +29,7 @@ from graphiti_core.search.search_utils import (
edge_fulltext_search, edge_fulltext_search,
edge_similarity_search, edge_similarity_search,
get_mentioned_nodes, get_mentioned_nodes,
node_distance_reranker,
rrf, rrf,
) )
from graphiti_core.utils import retrieve_episodes from graphiti_core.utils import retrieve_episodes
@ -36,12 +38,22 @@ from graphiti_core.utils.maintenance.graph_data_operations import EPISODE_WINDOW
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class SearchMethod(Enum):
cosine_similarity = 'cosine_similarity'
bm25 = 'bm25'
class Reranker(Enum):
rrf = 'reciprocal_rank_fusion'
node_distance = 'node_distance'
class SearchConfig(BaseModel): class SearchConfig(BaseModel):
num_results: int = 10 num_edges: int = 10
num_nodes: int = 10
num_episodes: int = EPISODE_WINDOW_LEN num_episodes: int = EPISODE_WINDOW_LEN
similarity_search: str = 'cosine' search_methods: list[SearchMethod]
text_search: str = 'BM25' reranker: Reranker | None
reranker: str = 'rrf'
class SearchResults(BaseModel): class SearchResults(BaseModel):
@ -51,7 +63,12 @@ class SearchResults(BaseModel):
async def hybrid_search( async def hybrid_search(
driver: AsyncDriver, embedder, query: str, timestamp: datetime, config: SearchConfig driver: AsyncDriver,
embedder,
query: str,
timestamp: datetime,
config: SearchConfig,
center_node_uuid: str | None = None,
) -> SearchResults: ) -> SearchResults:
start = time() start = time()
@ -65,11 +82,11 @@ async def hybrid_search(
episodes.extend(await retrieve_episodes(driver, timestamp)) episodes.extend(await retrieve_episodes(driver, timestamp))
nodes.extend(await get_mentioned_nodes(driver, episodes)) nodes.extend(await get_mentioned_nodes(driver, episodes))
if config.text_search == 'BM25': if SearchMethod.bm25 in config.search_methods:
text_search = await edge_fulltext_search(query, driver) text_search = await edge_fulltext_search(query, driver)
search_results.append(text_search) search_results.append(text_search)
if config.similarity_search == 'cosine': if SearchMethod.cosine_similarity in config.search_methods:
query_text = query.replace('\n', ' ') query_text = query.replace('\n', ' ')
search_vector = ( search_vector = (
(await embedder.create(input=[query_text], model='text-embedding-3-small')) (await embedder.create(input=[query_text], model='text-embedding-3-small'))
@ -80,19 +97,14 @@ async def hybrid_search(
similarity_search = await edge_similarity_search(search_vector, driver) similarity_search = await edge_similarity_search(search_vector, driver)
search_results.append(similarity_search) search_results.append(similarity_search)
if len(search_results) == 1: if len(search_results) > 1 and config.reranker is None:
edges = search_results[0]
elif len(search_results) > 1 and config.reranker != 'rrf':
logger.exception('Multiple searches enabled without a reranker') logger.exception('Multiple searches enabled without a reranker')
raise Exception('Multiple searches enabled without a reranker') raise Exception('Multiple searches enabled without a reranker')
elif config.reranker == 'rrf': else:
edge_uuid_map = {} edge_uuid_map = {}
search_result_uuids = [] search_result_uuids = []
logger.info([[edge.fact for edge in result] for result in search_results])
for result in search_results: for result in search_results:
result_uuids = [] result_uuids = []
for edge in result: for edge in result:
@ -103,12 +115,23 @@ async def hybrid_search(
search_result_uuids = [[edge.uuid for edge in result] for result in search_results] search_result_uuids = [[edge.uuid for edge in result] for result in search_results]
reranked_uuids = rrf(search_result_uuids) reranked_uuids: list[str] = []
if config.reranker == Reranker.rrf:
reranked_uuids = rrf(search_result_uuids)
elif config.reranker == Reranker.node_distance:
if center_node_uuid is None:
logger.exception('No center node provided for Node Distance reranker')
raise Exception('No center node provided for Node Distance reranker')
reranked_uuids = await node_distance_reranker(
driver, search_result_uuids, center_node_uuid
)
reranked_edges = [edge_uuid_map[uuid] for uuid in reranked_uuids] reranked_edges = [edge_uuid_map[uuid] for uuid in reranked_uuids]
edges.extend(reranked_edges) edges.extend(reranked_edges)
context = SearchResults(episodes=episodes, nodes=nodes, edges=edges) context = SearchResults(
episodes=episodes, nodes=nodes[: config.num_nodes], edges=edges[: config.num_edges]
)
end = time() end = time()

View file

@ -333,7 +333,7 @@ async def get_relevant_edges(
# takes in a list of rankings of uuids # takes in a list of rankings of uuids
def rrf(results: list[list[str]], rank_const=1) -> list[str]: def rrf(results: list[list[str]], rank_const=1) -> list[str]:
scores: dict[str, int] = defaultdict(int) scores: dict[str, float] = defaultdict(float)
for result in results: for result in results:
for i, uuid in enumerate(result): for i, uuid in enumerate(result):
scores[uuid] += 1 / (i + rank_const) scores[uuid] += 1 / (i + rank_const)
@ -344,3 +344,43 @@ def rrf(results: list[list[str]], rank_const=1) -> list[str]:
sorted_uuids = [term[0] for term in scored_uuids] sorted_uuids = [term[0] for term in scored_uuids]
return sorted_uuids return sorted_uuids
async def node_distance_reranker(
driver: AsyncDriver, results: list[list[str]], center_node_uuid: str
) -> list[str]:
# use rrf as a preliminary ranker
sorted_uuids = rrf(results)
scores: dict[str, float] = {}
for uuid in sorted_uuids:
# Find shortest path to center node
records, _, _ = await driver.execute_query(
"""
MATCH (source:Entity)-[r:RELATES_TO {uuid: $edge_uuid}]->(target:Entity)
MATCH p = SHORTEST 1 (center:Entity)-[:RELATES_TO]-+(n:Entity)
WHERE center.uuid = $center_uuid AND n.uuid IN [source.uuid, target.uuid]
RETURN min(length(p)) AS score, source.uuid AS source_uuid, target.uuid AS target_uuid
""",
edge_uuid=uuid,
center_uuid=center_node_uuid,
)
distance = 0.01
for record in records:
if (
record['source_uuid'] == center_node_uuid
or record['target_uuid'] == center_node_uuid
):
continue
distance = record['score']
if uuid in scores:
scores[uuid] = min(1 / distance, scores[uuid])
else:
scores[uuid] = 1 / distance
# rerank on shortest distance
sorted_uuids.sort(reverse=True, key=lambda cur_uuid: scores[cur_uuid])
return sorted_uuids

View file

@ -72,7 +72,7 @@ async def extract_nodes(
) -> list[EntityNode]: ) -> list[EntityNode]:
start = time() start = time()
extracted_node_data: list[dict[str, Any]] = [] extracted_node_data: list[dict[str, Any]] = []
if episode.source == EpisodeType.message: if episode.source in [EpisodeType.message, EpisodeType.text]:
extracted_node_data = await extract_message_nodes(llm_client, episode, previous_episodes) extracted_node_data = await extract_message_nodes(llm_client, episode, previous_episodes)
elif episode.source == EpisodeType.json: elif episode.source == EpisodeType.json:
extracted_node_data = await extract_json_nodes(llm_client, episode) extracted_node_data = await extract_json_nodes(llm_client, episode)

17
poetry.lock generated
View file

@ -3253,6 +3253,21 @@ mpmath = ">=1.1.0,<1.4"
[package.extras] [package.extras]
dev = ["hypothesis (>=6.70.0)", "pytest (>=7.1.0)"] dev = ["hypothesis (>=6.70.0)", "pytest (>=7.1.0)"]
[[package]]
name = "tenacity"
version = "9.0.0"
description = "Retry code until it succeeds"
optional = false
python-versions = ">=3.8"
files = [
{file = "tenacity-9.0.0-py3-none-any.whl", hash = "sha256:93de0c98785b27fcf659856aa9f54bfbd399e29969b0621bc7f762bd441b4539"},
{file = "tenacity-9.0.0.tar.gz", hash = "sha256:807f37ca97d62aa361264d497b0e31e92b8027044942bfa756160d908320d73b"},
]
[package.extras]
doc = ["reno", "sphinx"]
test = ["pytest", "tornado (>=4.5)", "typeguard"]
[[package]] [[package]]
name = "terminado" name = "terminado"
version = "0.18.1" version = "0.18.1"
@ -3743,4 +3758,4 @@ test = ["websockets"]
[metadata] [metadata]
lock-version = "2.0" lock-version = "2.0"
python-versions = "^3.10" python-versions = "^3.10"
content-hash = "5b90bb6d58d36a2553f5410c418b179aa1c86b55078567c33aaa6fddf6a8c6c6" content-hash = "001663dfc8078ad473675c994b15191db1f53a844e23f40ffa4a704379a61132"

View file

@ -23,6 +23,7 @@ diskcache = "^5.6.3"
arrow = "^1.3.0" arrow = "^1.3.0"
openai = "^1.38.0" openai = "^1.38.0"
anthropic = "^0.34.1" anthropic = "^0.34.1"
tenacity = "^9.0.0"
[tool.poetry.dev-dependencies] [tool.poetry.dev-dependencies]
pytest = "^8.3.2" pytest = "^8.3.2"