Merge branch 'main' of github.com:getzep/graphiti into nba-langgraph
This commit is contained in:
commit
c304758791
9 changed files with 148 additions and 26 deletions
2
.github/workflows/cla.yml
vendored
2
.github/workflows/cla.yml
vendored
|
|
@ -26,7 +26,7 @@ jobs:
|
|||
path-to-document: "https://github.com/getzep/graphiti/blob/main/Zep-CLA.md" # e.g. a CLA or a DCO document
|
||||
# branch should not be protected
|
||||
branch: "main"
|
||||
allowlist: paul-paliychuk,prasmussen15,danielchalef,dependabot,ellipsisdev
|
||||
allowlist: paul-paliychuk,prasmussen15,danielchalef,dependabot[bot],ellipsisdev
|
||||
|
||||
# the followings are the optional inputs - If the optional inputs are not given, then default values will be taken
|
||||
#remote-organization-name: enter the remote organization name where the signatures should be stored (Default is storing the signatures in the same repository)
|
||||
|
|
|
|||
|
|
@ -26,7 +26,7 @@ from neo4j import AsyncGraphDatabase
|
|||
from graphiti_core.edges import EntityEdge, EpisodicEdge
|
||||
from graphiti_core.llm_client import LLMClient, OpenAIClient
|
||||
from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode
|
||||
from graphiti_core.search.search import SearchConfig, hybrid_search
|
||||
from graphiti_core.search.search import Reranker, SearchConfig, SearchMethod, hybrid_search
|
||||
from graphiti_core.search.search_utils import (
|
||||
get_relevant_edges,
|
||||
get_relevant_nodes,
|
||||
|
|
@ -515,7 +515,7 @@ class Graphiti:
|
|||
except Exception as e:
|
||||
raise e
|
||||
|
||||
async def search(self, query: str, num_results=10):
|
||||
async def search(self, query: str, center_node_uuid: str | None = None, num_results=10):
|
||||
"""
|
||||
Perform a hybrid search on the knowledge graph.
|
||||
|
||||
|
|
@ -526,6 +526,8 @@ class Graphiti:
|
|||
----------
|
||||
query : str
|
||||
The search query string.
|
||||
center_node_uuid: str, optional
|
||||
Facts will be reranked based on proximity to this node
|
||||
num_results : int, optional
|
||||
The maximum number of results to return. Defaults to 10.
|
||||
|
||||
|
|
@ -543,7 +545,14 @@ class Graphiti:
|
|||
The search is performed using the current date and time as the reference
|
||||
point for temporal relevance.
|
||||
"""
|
||||
search_config = SearchConfig(num_episodes=0, num_results=num_results)
|
||||
reranker = Reranker.rrf if center_node_uuid is None else Reranker.node_distance
|
||||
search_config = SearchConfig(
|
||||
num_episodes=0,
|
||||
num_edges=num_results,
|
||||
num_nodes=0,
|
||||
search_methods=[SearchMethod.bm25, SearchMethod.cosine_similarity],
|
||||
reranker=reranker,
|
||||
)
|
||||
edges = (
|
||||
await hybrid_search(
|
||||
self.driver,
|
||||
|
|
@ -551,6 +560,7 @@ class Graphiti:
|
|||
query,
|
||||
datetime.now(),
|
||||
search_config,
|
||||
center_node_uuid,
|
||||
)
|
||||
).edges
|
||||
|
||||
|
|
@ -558,7 +568,13 @@ class Graphiti:
|
|||
|
||||
return facts
|
||||
|
||||
async def _search(self, query: str, timestamp: datetime, config: SearchConfig):
|
||||
async def _search(
|
||||
self,
|
||||
query: str,
|
||||
timestamp: datetime,
|
||||
config: SearchConfig,
|
||||
center_node_uuid: str | None = None,
|
||||
):
|
||||
return await hybrid_search(
|
||||
self.driver, self.llm_client.get_embedder(), query, timestamp, config
|
||||
self.driver, self.llm_client.get_embedder(), query, timestamp, config, center_node_uuid
|
||||
)
|
||||
|
|
|
|||
|
|
@ -20,7 +20,9 @@ import logging
|
|||
import typing
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
import httpx
|
||||
from diskcache import Cache
|
||||
from tenacity import retry, retry_if_exception, stop_after_attempt, wait_exponential
|
||||
|
||||
from ..prompts.models import Message
|
||||
from .config import LLMConfig
|
||||
|
|
@ -31,6 +33,12 @@ DEFAULT_CACHE_DIR = './llm_cache'
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def is_server_error(exception):
|
||||
return (
|
||||
isinstance(exception, httpx.HTTPStatusError) and 500 <= exception.response.status_code < 600
|
||||
)
|
||||
|
||||
|
||||
class LLMClient(ABC):
|
||||
def __init__(self, config: LLMConfig | None, cache: bool = False):
|
||||
if config is None:
|
||||
|
|
@ -47,6 +55,20 @@ class LLMClient(ABC):
|
|||
def get_embedder(self) -> typing.Any:
|
||||
pass
|
||||
|
||||
@retry(
|
||||
stop=stop_after_attempt(3),
|
||||
wait=wait_exponential(multiplier=1, min=4, max=10),
|
||||
retry=retry_if_exception(is_server_error),
|
||||
)
|
||||
async def _generate_response_with_retry(self, messages: list[Message]) -> dict[str, typing.Any]:
|
||||
try:
|
||||
return await self._generate_response(messages)
|
||||
except httpx.HTTPStatusError as e:
|
||||
if not is_server_error(e):
|
||||
raise Exception(f'LLM request error: {e}') from e
|
||||
else:
|
||||
raise
|
||||
|
||||
@abstractmethod
|
||||
async def _generate_response(self, messages: list[Message]) -> dict[str, typing.Any]:
|
||||
pass
|
||||
|
|
@ -66,7 +88,7 @@ class LLMClient(ABC):
|
|||
logger.debug(f'Cache hit for {cache_key}')
|
||||
return cached_response
|
||||
|
||||
response = await self._generate_response(messages)
|
||||
response = await self._generate_response_with_retry(messages)
|
||||
|
||||
if self.cache_enabled:
|
||||
self.cache_dir.set(cache_key, response)
|
||||
|
|
|
|||
|
|
@ -46,10 +46,13 @@ class EpisodeType(Enum):
|
|||
or "assistant: I'm doing well, thank you for asking."
|
||||
json : str
|
||||
Represents an episode containing a JSON string object with structured data.
|
||||
text : str
|
||||
Represents a plain text episode.
|
||||
"""
|
||||
|
||||
message = 'message'
|
||||
json = 'json'
|
||||
text = 'text'
|
||||
|
||||
@staticmethod
|
||||
def from_str(episode_type: str):
|
||||
|
|
@ -57,6 +60,8 @@ class EpisodeType(Enum):
|
|||
return EpisodeType.message
|
||||
if episode_type == 'json':
|
||||
return EpisodeType.json
|
||||
if episode_type == 'text':
|
||||
return EpisodeType.text
|
||||
logger.error(f'Episode type: {episode_type} not implemented')
|
||||
raise NotImplementedError
|
||||
|
||||
|
|
|
|||
|
|
@ -16,6 +16,7 @@ limitations under the License.
|
|||
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from time import time
|
||||
|
||||
from neo4j import AsyncDriver
|
||||
|
|
@ -28,6 +29,7 @@ from graphiti_core.search.search_utils import (
|
|||
edge_fulltext_search,
|
||||
edge_similarity_search,
|
||||
get_mentioned_nodes,
|
||||
node_distance_reranker,
|
||||
rrf,
|
||||
)
|
||||
from graphiti_core.utils import retrieve_episodes
|
||||
|
|
@ -36,12 +38,22 @@ from graphiti_core.utils.maintenance.graph_data_operations import EPISODE_WINDOW
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SearchMethod(Enum):
|
||||
cosine_similarity = 'cosine_similarity'
|
||||
bm25 = 'bm25'
|
||||
|
||||
|
||||
class Reranker(Enum):
|
||||
rrf = 'reciprocal_rank_fusion'
|
||||
node_distance = 'node_distance'
|
||||
|
||||
|
||||
class SearchConfig(BaseModel):
|
||||
num_results: int = 10
|
||||
num_edges: int = 10
|
||||
num_nodes: int = 10
|
||||
num_episodes: int = EPISODE_WINDOW_LEN
|
||||
similarity_search: str = 'cosine'
|
||||
text_search: str = 'BM25'
|
||||
reranker: str = 'rrf'
|
||||
search_methods: list[SearchMethod]
|
||||
reranker: Reranker | None
|
||||
|
||||
|
||||
class SearchResults(BaseModel):
|
||||
|
|
@ -51,7 +63,12 @@ class SearchResults(BaseModel):
|
|||
|
||||
|
||||
async def hybrid_search(
|
||||
driver: AsyncDriver, embedder, query: str, timestamp: datetime, config: SearchConfig
|
||||
driver: AsyncDriver,
|
||||
embedder,
|
||||
query: str,
|
||||
timestamp: datetime,
|
||||
config: SearchConfig,
|
||||
center_node_uuid: str | None = None,
|
||||
) -> SearchResults:
|
||||
start = time()
|
||||
|
||||
|
|
@ -65,11 +82,11 @@ async def hybrid_search(
|
|||
episodes.extend(await retrieve_episodes(driver, timestamp))
|
||||
nodes.extend(await get_mentioned_nodes(driver, episodes))
|
||||
|
||||
if config.text_search == 'BM25':
|
||||
if SearchMethod.bm25 in config.search_methods:
|
||||
text_search = await edge_fulltext_search(query, driver)
|
||||
search_results.append(text_search)
|
||||
|
||||
if config.similarity_search == 'cosine':
|
||||
if SearchMethod.cosine_similarity in config.search_methods:
|
||||
query_text = query.replace('\n', ' ')
|
||||
search_vector = (
|
||||
(await embedder.create(input=[query_text], model='text-embedding-3-small'))
|
||||
|
|
@ -80,19 +97,14 @@ async def hybrid_search(
|
|||
similarity_search = await edge_similarity_search(search_vector, driver)
|
||||
search_results.append(similarity_search)
|
||||
|
||||
if len(search_results) == 1:
|
||||
edges = search_results[0]
|
||||
|
||||
elif len(search_results) > 1 and config.reranker != 'rrf':
|
||||
if len(search_results) > 1 and config.reranker is None:
|
||||
logger.exception('Multiple searches enabled without a reranker')
|
||||
raise Exception('Multiple searches enabled without a reranker')
|
||||
|
||||
elif config.reranker == 'rrf':
|
||||
else:
|
||||
edge_uuid_map = {}
|
||||
search_result_uuids = []
|
||||
|
||||
logger.info([[edge.fact for edge in result] for result in search_results])
|
||||
|
||||
for result in search_results:
|
||||
result_uuids = []
|
||||
for edge in result:
|
||||
|
|
@ -103,12 +115,23 @@ async def hybrid_search(
|
|||
|
||||
search_result_uuids = [[edge.uuid for edge in result] for result in search_results]
|
||||
|
||||
reranked_uuids = rrf(search_result_uuids)
|
||||
reranked_uuids: list[str] = []
|
||||
if config.reranker == Reranker.rrf:
|
||||
reranked_uuids = rrf(search_result_uuids)
|
||||
elif config.reranker == Reranker.node_distance:
|
||||
if center_node_uuid is None:
|
||||
logger.exception('No center node provided for Node Distance reranker')
|
||||
raise Exception('No center node provided for Node Distance reranker')
|
||||
reranked_uuids = await node_distance_reranker(
|
||||
driver, search_result_uuids, center_node_uuid
|
||||
)
|
||||
|
||||
reranked_edges = [edge_uuid_map[uuid] for uuid in reranked_uuids]
|
||||
edges.extend(reranked_edges)
|
||||
|
||||
context = SearchResults(episodes=episodes, nodes=nodes, edges=edges)
|
||||
context = SearchResults(
|
||||
episodes=episodes, nodes=nodes[: config.num_nodes], edges=edges[: config.num_edges]
|
||||
)
|
||||
|
||||
end = time()
|
||||
|
||||
|
|
|
|||
|
|
@ -333,7 +333,7 @@ async def get_relevant_edges(
|
|||
|
||||
# takes in a list of rankings of uuids
|
||||
def rrf(results: list[list[str]], rank_const=1) -> list[str]:
|
||||
scores: dict[str, int] = defaultdict(int)
|
||||
scores: dict[str, float] = defaultdict(float)
|
||||
for result in results:
|
||||
for i, uuid in enumerate(result):
|
||||
scores[uuid] += 1 / (i + rank_const)
|
||||
|
|
@ -344,3 +344,43 @@ def rrf(results: list[list[str]], rank_const=1) -> list[str]:
|
|||
sorted_uuids = [term[0] for term in scored_uuids]
|
||||
|
||||
return sorted_uuids
|
||||
|
||||
|
||||
async def node_distance_reranker(
|
||||
driver: AsyncDriver, results: list[list[str]], center_node_uuid: str
|
||||
) -> list[str]:
|
||||
# use rrf as a preliminary ranker
|
||||
sorted_uuids = rrf(results)
|
||||
scores: dict[str, float] = {}
|
||||
|
||||
for uuid in sorted_uuids:
|
||||
# Find shortest path to center node
|
||||
records, _, _ = await driver.execute_query(
|
||||
"""
|
||||
MATCH (source:Entity)-[r:RELATES_TO {uuid: $edge_uuid}]->(target:Entity)
|
||||
MATCH p = SHORTEST 1 (center:Entity)-[:RELATES_TO]-+(n:Entity)
|
||||
WHERE center.uuid = $center_uuid AND n.uuid IN [source.uuid, target.uuid]
|
||||
RETURN min(length(p)) AS score, source.uuid AS source_uuid, target.uuid AS target_uuid
|
||||
""",
|
||||
edge_uuid=uuid,
|
||||
center_uuid=center_node_uuid,
|
||||
)
|
||||
distance = 0.01
|
||||
|
||||
for record in records:
|
||||
if (
|
||||
record['source_uuid'] == center_node_uuid
|
||||
or record['target_uuid'] == center_node_uuid
|
||||
):
|
||||
continue
|
||||
distance = record['score']
|
||||
|
||||
if uuid in scores:
|
||||
scores[uuid] = min(1 / distance, scores[uuid])
|
||||
else:
|
||||
scores[uuid] = 1 / distance
|
||||
|
||||
# rerank on shortest distance
|
||||
sorted_uuids.sort(reverse=True, key=lambda cur_uuid: scores[cur_uuid])
|
||||
|
||||
return sorted_uuids
|
||||
|
|
|
|||
|
|
@ -72,7 +72,7 @@ async def extract_nodes(
|
|||
) -> list[EntityNode]:
|
||||
start = time()
|
||||
extracted_node_data: list[dict[str, Any]] = []
|
||||
if episode.source == EpisodeType.message:
|
||||
if episode.source in [EpisodeType.message, EpisodeType.text]:
|
||||
extracted_node_data = await extract_message_nodes(llm_client, episode, previous_episodes)
|
||||
elif episode.source == EpisodeType.json:
|
||||
extracted_node_data = await extract_json_nodes(llm_client, episode)
|
||||
|
|
|
|||
17
poetry.lock
generated
17
poetry.lock
generated
|
|
@ -3253,6 +3253,21 @@ mpmath = ">=1.1.0,<1.4"
|
|||
[package.extras]
|
||||
dev = ["hypothesis (>=6.70.0)", "pytest (>=7.1.0)"]
|
||||
|
||||
[[package]]
|
||||
name = "tenacity"
|
||||
version = "9.0.0"
|
||||
description = "Retry code until it succeeds"
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
files = [
|
||||
{file = "tenacity-9.0.0-py3-none-any.whl", hash = "sha256:93de0c98785b27fcf659856aa9f54bfbd399e29969b0621bc7f762bd441b4539"},
|
||||
{file = "tenacity-9.0.0.tar.gz", hash = "sha256:807f37ca97d62aa361264d497b0e31e92b8027044942bfa756160d908320d73b"},
|
||||
]
|
||||
|
||||
[package.extras]
|
||||
doc = ["reno", "sphinx"]
|
||||
test = ["pytest", "tornado (>=4.5)", "typeguard"]
|
||||
|
||||
[[package]]
|
||||
name = "terminado"
|
||||
version = "0.18.1"
|
||||
|
|
@ -3743,4 +3758,4 @@ test = ["websockets"]
|
|||
[metadata]
|
||||
lock-version = "2.0"
|
||||
python-versions = "^3.10"
|
||||
content-hash = "5b90bb6d58d36a2553f5410c418b179aa1c86b55078567c33aaa6fddf6a8c6c6"
|
||||
content-hash = "001663dfc8078ad473675c994b15191db1f53a844e23f40ffa4a704379a61132"
|
||||
|
|
|
|||
|
|
@ -23,6 +23,7 @@ diskcache = "^5.6.3"
|
|||
arrow = "^1.3.0"
|
||||
openai = "^1.38.0"
|
||||
anthropic = "^0.34.1"
|
||||
tenacity = "^9.0.0"
|
||||
|
||||
[tool.poetry.dev-dependencies]
|
||||
pytest = "^8.3.2"
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue