Cross encoder reranker in search query (#202)

* cross encoder reranker

* update reranker

* add openai reranker

* format

* mypy

* update

* updates

* MyPy typing

* bump version
This commit is contained in:
Preston Rasmussen 2024-10-25 12:29:27 -04:00 committed by GitHub
parent 544f9e3fba
commit ceb60a3d33
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
8 changed files with 445 additions and 190 deletions

View file

@ -0,0 +1,113 @@
"""
Copyright 2024, Zep Software, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import asyncio
import logging
from typing import Any
import openai
from openai import AsyncOpenAI
from pydantic import BaseModel
from ..llm_client import LLMConfig, RateLimitError
from ..prompts import Message
from .client import CrossEncoderClient
logger = logging.getLogger(__name__)

# Model used for the per-passage relevance calls. Each call emits a single
# token (max_tokens=1 below), so a small, inexpensive model is sufficient.
DEFAULT_MODEL = 'gpt-4o-mini'
class BooleanClassifier(BaseModel):
    """Structured-output schema for a boolean relevance judgement."""

    # NOTE(review): not referenced anywhere in this module — presumably
    # intended for structured-output parsing of the True/False response;
    # confirm intent before removing.
    isTrue: bool
class OpenAIRerankerClient(CrossEncoderClient):
    """Cross-encoder reranker backed by the OpenAI chat-completions API.

    Each (query, passage) pair is independently classified as relevant or not
    by a small chat model constrained to answer "True"/"False"; the
    probability mass the model assigns to "True" is used as the relevance
    score.
    """

    def __init__(self, config: LLMConfig | None = None):
        """
        Initialize the OpenAIRerankerClient with the provided configuration.

        Args:
            config (LLMConfig | None): The configuration for the LLM client,
                including API key, model, base URL, temperature, and max
                tokens. A default LLMConfig is used when omitted.
        """
        if config is None:
            config = LLMConfig()
        self.config = config
        self.client = AsyncOpenAI(api_key=config.api_key, base_url=config.base_url)

    async def rank(self, query: str, passages: list[str]) -> list[tuple[str, float]]:
        """Rank *passages* by relevance to *query*, most relevant first.

        All passages are classified concurrently via ``asyncio.gather``.

        Args:
            query (str): The search query.
            passages (list[str]): Candidate passages to score.

        Returns:
            list[tuple[str, float]]: (passage, score) pairs sorted by score
                descending, where score is the model's probability that the
                passage is relevant to the query.

        Raises:
            RateLimitError: When the OpenAI API reports a rate limit.
        """
        openai_messages_list: Any = [
            [
                Message(
                    role='system',
                    content='You are an expert tasked with determining whether the passage is relevant to the query',
                ),
                Message(
                    role='user',
                    # BUG FIX: the original prompt placed {query} inside the
                    # <PASSAGE> tags, left {passage} outside any tag, and sent
                    # empty <QUERY></QUERY> tags. Passage and query now sit in
                    # their proper tags.
                    content=f"""
                           Respond with "True" if PASSAGE is relevant to QUERY and "False" otherwise.
                           <PASSAGE>
                           {passage}
                           </PASSAGE>
                           <QUERY>
                           {query}
                           </QUERY>
                           """,
                ),
            ]
            for passage in passages
        ]
        try:
            responses = await asyncio.gather(
                *[
                    self.client.chat.completions.create(
                        model=DEFAULT_MODEL,
                        messages=openai_messages,
                        temperature=0,
                        max_tokens=1,
                        # Bias generation toward the "True"/"False" tokens so
                        # the single sampled token is one of the two labels.
                        logit_bias={'6432': 1, '7983': 1},
                        logprobs=True,
                        top_logprobs=2,
                    )
                    for openai_messages in openai_messages_list
                ]
            )
            responses_top_logprobs = [
                response.choices[0].logprobs.content[0].top_logprobs
                if response.choices[0].logprobs is not None
                and response.choices[0].logprobs.content is not None
                else []
                for response in responses
            ]
            scores: list[float] = []
            for top_logprobs in responses_top_logprobs:
                # BUG FIX: the original did
                #     for logprob in top_logprobs:
                #         if bool(logprob.token):
                #             scores.append(logprob.logprob)
                # which is wrong twice over: bool('False') is True, so the
                # check never distinguished the labels, and with
                # top_logprobs=2 it could append two scores per response,
                # misaligning `scores` with `passages` in the zip below.
                # Score exactly once per response, as P("True").
                if not top_logprobs:
                    scores.append(0.0)
                    continue
                top = top_logprobs[0]
                probability = math.exp(top.logprob)
                if top.token.strip().lower() == 'true':
                    scores.append(probability)
                else:
                    scores.append(1 - probability)
            results = list(zip(passages, scores))
            results.sort(reverse=True, key=lambda x: x[1])
            return results
        except openai.RateLimitError as e:
            raise RateLimitError from e
        except Exception as e:
            logger.error(f'Error in generating LLM response: {e}')
            raise

View file

@ -23,8 +23,11 @@ from dotenv import load_dotenv
from neo4j import AsyncGraphDatabase from neo4j import AsyncGraphDatabase
from pydantic import BaseModel from pydantic import BaseModel
from graphiti_core.cross_encoder.client import CrossEncoderClient
from graphiti_core.cross_encoder.openai_reranker_client import OpenAIRerankerClient
from graphiti_core.edges import EntityEdge, EpisodicEdge from graphiti_core.edges import EntityEdge, EpisodicEdge
from graphiti_core.embedder import EmbedderClient, OpenAIEmbedder from graphiti_core.embedder import EmbedderClient, OpenAIEmbedder
from graphiti_core.helpers import DEFAULT_DATABASE
from graphiti_core.llm_client import LLMClient, OpenAIClient from graphiti_core.llm_client import LLMClient, OpenAIClient
from graphiti_core.nodes import CommunityNode, EntityNode, EpisodeType, EpisodicNode from graphiti_core.nodes import CommunityNode, EntityNode, EpisodeType, EpisodicNode
from graphiti_core.search.search import SearchConfig, search from graphiti_core.search.search import SearchConfig, search
@ -92,6 +95,7 @@ class Graphiti:
password: str, password: str,
llm_client: LLMClient | None = None, llm_client: LLMClient | None = None,
embedder: EmbedderClient | None = None, embedder: EmbedderClient | None = None,
cross_encoder: CrossEncoderClient | None = None,
store_raw_episode_content: bool = True, store_raw_episode_content: bool = True,
): ):
""" """
@ -131,7 +135,7 @@ class Graphiti:
Graphiti if you're using the default OpenAIClient. Graphiti if you're using the default OpenAIClient.
""" """
self.driver = AsyncGraphDatabase.driver(uri, auth=(user, password)) self.driver = AsyncGraphDatabase.driver(uri, auth=(user, password))
self.database = 'neo4j' self.database = DEFAULT_DATABASE
self.store_raw_episode_content = store_raw_episode_content self.store_raw_episode_content = store_raw_episode_content
if llm_client: if llm_client:
self.llm_client = llm_client self.llm_client = llm_client
@ -141,6 +145,10 @@ class Graphiti:
self.embedder = embedder self.embedder = embedder
else: else:
self.embedder = OpenAIEmbedder() self.embedder = OpenAIEmbedder()
if cross_encoder:
self.cross_encoder = cross_encoder
else:
self.cross_encoder = OpenAIRerankerClient()
async def close(self): async def close(self):
""" """
@ -648,6 +656,7 @@ class Graphiti:
await search( await search(
self.driver, self.driver,
self.embedder, self.embedder,
self.cross_encoder,
query, query,
group_ids, group_ids,
search_config, search_config,
@ -663,8 +672,18 @@ class Graphiti:
config: SearchConfig, config: SearchConfig,
group_ids: list[str] | None = None, group_ids: list[str] | None = None,
center_node_uuid: str | None = None, center_node_uuid: str | None = None,
bfs_origin_node_uuids: list[str] | None = None,
) -> SearchResults: ) -> SearchResults:
return await search(self.driver, self.embedder, query, group_ids, config, center_node_uuid) return await search(
self.driver,
self.embedder,
self.cross_encoder,
query,
group_ids,
config,
center_node_uuid,
bfs_origin_node_uuids,
)
async def get_nodes_by_query( async def get_nodes_by_query(
self, self,
@ -716,7 +735,13 @@ class Graphiti:
nodes = ( nodes = (
await search( await search(
self.driver, self.embedder, query, group_ids, search_config, center_node_uuid self.driver,
self.embedder,
self.cross_encoder,
query,
group_ids,
search_config,
center_node_uuid,
) )
).nodes ).nodes
return nodes return nodes

View file

@ -21,6 +21,7 @@ from time import time
from neo4j import AsyncDriver from neo4j import AsyncDriver
from graphiti_core.cross_encoder.client import CrossEncoderClient
from graphiti_core.edges import EntityEdge from graphiti_core.edges import EntityEdge
from graphiti_core.embedder import EmbedderClient from graphiti_core.embedder import EmbedderClient
from graphiti_core.errors import SearchRerankerError from graphiti_core.errors import SearchRerankerError
@ -39,6 +40,7 @@ from graphiti_core.search.search_config import (
from graphiti_core.search.search_utils import ( from graphiti_core.search.search_utils import (
community_fulltext_search, community_fulltext_search,
community_similarity_search, community_similarity_search,
edge_bfs_search,
edge_fulltext_search, edge_fulltext_search,
edge_similarity_search, edge_similarity_search,
episode_mentions_reranker, episode_mentions_reranker,
@ -55,10 +57,12 @@ logger = logging.getLogger(__name__)
async def search( async def search(
driver: AsyncDriver, driver: AsyncDriver,
embedder: EmbedderClient, embedder: EmbedderClient,
cross_encoder: CrossEncoderClient,
query: str, query: str,
group_ids: list[str] | None, group_ids: list[str] | None,
config: SearchConfig, config: SearchConfig,
center_node_uuid: str | None = None, center_node_uuid: str | None = None,
bfs_origin_node_uuids: list[str] | None = None,
) -> SearchResults: ) -> SearchResults:
start = time() start = time()
query_vector = await embedder.create(input=[query.replace('\n', ' ')]) query_vector = await embedder.create(input=[query.replace('\n', ' ')])
@ -68,28 +72,34 @@ async def search(
edges, nodes, communities = await asyncio.gather( edges, nodes, communities = await asyncio.gather(
edge_search( edge_search(
driver, driver,
cross_encoder,
query, query,
query_vector, query_vector,
group_ids, group_ids,
config.edge_config, config.edge_config,
center_node_uuid, center_node_uuid,
bfs_origin_node_uuids,
config.limit, config.limit,
), ),
node_search( node_search(
driver, driver,
cross_encoder,
query, query,
query_vector, query_vector,
group_ids, group_ids,
config.node_config, config.node_config,
center_node_uuid, center_node_uuid,
bfs_origin_node_uuids,
config.limit, config.limit,
), ),
community_search( community_search(
driver, driver,
cross_encoder,
query, query,
query_vector, query_vector,
group_ids, group_ids,
config.community_config, config.community_config,
bfs_origin_node_uuids,
config.limit, config.limit,
), ),
) )
@ -109,11 +119,13 @@ async def search(
async def edge_search( async def edge_search(
driver: AsyncDriver, driver: AsyncDriver,
cross_encoder: CrossEncoderClient,
query: str, query: str,
query_vector: list[float], query_vector: list[float],
group_ids: list[str] | None, group_ids: list[str] | None,
config: EdgeSearchConfig | None, config: EdgeSearchConfig | None,
center_node_uuid: str | None = None, center_node_uuid: str | None = None,
bfs_origin_node_uuids: list[str] | None = None,
limit=DEFAULT_SEARCH_LIMIT, limit=DEFAULT_SEARCH_LIMIT,
) -> list[EntityEdge]: ) -> list[EntityEdge]:
if config is None: if config is None:
@ -126,6 +138,7 @@ async def edge_search(
edge_similarity_search( edge_similarity_search(
driver, query_vector, None, None, group_ids, 2 * limit, config.sim_min_score driver, query_vector, None, None, group_ids, 2 * limit, config.sim_min_score
), ),
edge_bfs_search(driver, bfs_origin_node_uuids, config.bfs_max_depth),
] ]
) )
) )
@ -146,6 +159,10 @@ async def edge_search(
reranked_uuids = maximal_marginal_relevance( reranked_uuids = maximal_marginal_relevance(
query_vector, search_result_uuids_and_vectors, config.mmr_lambda query_vector, search_result_uuids_and_vectors, config.mmr_lambda
) )
elif config.reranker == EdgeReranker.cross_encoder:
fact_to_uuid_map = {edge.fact: edge.uuid for result in search_results for edge in result}
reranked_facts = await cross_encoder.rank(query, list(fact_to_uuid_map.keys()))
reranked_uuids = [fact_to_uuid_map[fact] for fact, _ in reranked_facts]
elif config.reranker == EdgeReranker.node_distance: elif config.reranker == EdgeReranker.node_distance:
if center_node_uuid is None: if center_node_uuid is None:
raise SearchRerankerError('No center node provided for Node Distance reranker') raise SearchRerankerError('No center node provided for Node Distance reranker')
@ -176,11 +193,13 @@ async def edge_search(
async def node_search( async def node_search(
driver: AsyncDriver, driver: AsyncDriver,
cross_encoder: CrossEncoderClient,
query: str, query: str,
query_vector: list[float], query_vector: list[float],
group_ids: list[str] | None, group_ids: list[str] | None,
config: NodeSearchConfig | None, config: NodeSearchConfig | None,
center_node_uuid: str | None = None, center_node_uuid: str | None = None,
bfs_origin_node_uuids: list[str] | None = None,
limit=DEFAULT_SEARCH_LIMIT, limit=DEFAULT_SEARCH_LIMIT,
) -> list[EntityNode]: ) -> list[EntityNode]:
if config is None: if config is None:
@ -212,6 +231,12 @@ async def node_search(
reranked_uuids = maximal_marginal_relevance( reranked_uuids = maximal_marginal_relevance(
query_vector, search_result_uuids_and_vectors, config.mmr_lambda query_vector, search_result_uuids_and_vectors, config.mmr_lambda
) )
elif config.reranker == NodeReranker.cross_encoder:
summary_to_uuid_map = {
node.summary: node.uuid for result in search_results for node in result
}
reranked_summaries = await cross_encoder.rank(query, list(summary_to_uuid_map.keys()))
reranked_uuids = [summary_to_uuid_map[fact] for fact, _ in reranked_summaries]
elif config.reranker == NodeReranker.episode_mentions: elif config.reranker == NodeReranker.episode_mentions:
reranked_uuids = await episode_mentions_reranker(driver, search_result_uuids) reranked_uuids = await episode_mentions_reranker(driver, search_result_uuids)
elif config.reranker == NodeReranker.node_distance: elif config.reranker == NodeReranker.node_distance:
@ -228,10 +253,12 @@ async def node_search(
async def community_search( async def community_search(
driver: AsyncDriver, driver: AsyncDriver,
cross_encoder: CrossEncoderClient,
query: str, query: str,
query_vector: list[float], query_vector: list[float],
group_ids: list[str] | None, group_ids: list[str] | None,
config: CommunitySearchConfig | None, config: CommunitySearchConfig | None,
bfs_origin_node_uuids: list[str] | None = None,
limit=DEFAULT_SEARCH_LIMIT, limit=DEFAULT_SEARCH_LIMIT,
) -> list[CommunityNode]: ) -> list[CommunityNode]:
if config is None: if config is None:
@ -268,6 +295,12 @@ async def community_search(
reranked_uuids = maximal_marginal_relevance( reranked_uuids = maximal_marginal_relevance(
query_vector, search_result_uuids_and_vectors, config.mmr_lambda query_vector, search_result_uuids_and_vectors, config.mmr_lambda
) )
elif config.reranker == CommunityReranker.cross_encoder:
summary_to_uuid_map = {
node.summary: node.uuid for result in search_results for node in result
}
reranked_summaries = await cross_encoder.rank(query, list(summary_to_uuid_map.keys()))
reranked_uuids = [summary_to_uuid_map[fact] for fact, _ in reranked_summaries]
reranked_communities = [community_uuid_map[uuid] for uuid in reranked_uuids] reranked_communities = [community_uuid_map[uuid] for uuid in reranked_uuids]

View file

@ -20,7 +20,11 @@ from pydantic import BaseModel, Field
from graphiti_core.edges import EntityEdge from graphiti_core.edges import EntityEdge
from graphiti_core.nodes import CommunityNode, EntityNode from graphiti_core.nodes import CommunityNode, EntityNode
from graphiti_core.search.search_utils import DEFAULT_MIN_SCORE, DEFAULT_MMR_LAMBDA from graphiti_core.search.search_utils import (
DEFAULT_MIN_SCORE,
DEFAULT_MMR_LAMBDA,
MAX_SEARCH_DEPTH,
)
DEFAULT_SEARCH_LIMIT = 10 DEFAULT_SEARCH_LIMIT = 10
@ -28,11 +32,13 @@ DEFAULT_SEARCH_LIMIT = 10
class EdgeSearchMethod(Enum): class EdgeSearchMethod(Enum):
cosine_similarity = 'cosine_similarity' cosine_similarity = 'cosine_similarity'
bm25 = 'bm25' bm25 = 'bm25'
bfs = 'breadth_first_search'
class NodeSearchMethod(Enum): class NodeSearchMethod(Enum):
cosine_similarity = 'cosine_similarity' cosine_similarity = 'cosine_similarity'
bm25 = 'bm25' bm25 = 'bm25'
bfs = 'breadth_first_search'
class CommunitySearchMethod(Enum): class CommunitySearchMethod(Enum):
@ -45,6 +51,7 @@ class EdgeReranker(Enum):
node_distance = 'node_distance' node_distance = 'node_distance'
episode_mentions = 'episode_mentions' episode_mentions = 'episode_mentions'
mmr = 'mmr' mmr = 'mmr'
cross_encoder = 'cross_encoder'
class NodeReranker(Enum): class NodeReranker(Enum):
@ -52,11 +59,13 @@ class NodeReranker(Enum):
node_distance = 'node_distance' node_distance = 'node_distance'
episode_mentions = 'episode_mentions' episode_mentions = 'episode_mentions'
mmr = 'mmr' mmr = 'mmr'
cross_encoder = 'cross_encoder'
class CommunityReranker(Enum): class CommunityReranker(Enum):
rrf = 'reciprocal_rank_fusion' rrf = 'reciprocal_rank_fusion'
mmr = 'mmr' mmr = 'mmr'
cross_encoder = 'cross_encoder'
class EdgeSearchConfig(BaseModel): class EdgeSearchConfig(BaseModel):
@ -64,6 +73,7 @@ class EdgeSearchConfig(BaseModel):
reranker: EdgeReranker = Field(default=EdgeReranker.rrf) reranker: EdgeReranker = Field(default=EdgeReranker.rrf)
sim_min_score: float = Field(default=DEFAULT_MIN_SCORE) sim_min_score: float = Field(default=DEFAULT_MIN_SCORE)
mmr_lambda: float = Field(default=DEFAULT_MMR_LAMBDA) mmr_lambda: float = Field(default=DEFAULT_MMR_LAMBDA)
bfs_max_depth: int = Field(default=MAX_SEARCH_DEPTH)
class NodeSearchConfig(BaseModel): class NodeSearchConfig(BaseModel):
@ -71,6 +81,7 @@ class NodeSearchConfig(BaseModel):
reranker: NodeReranker = Field(default=NodeReranker.rrf) reranker: NodeReranker = Field(default=NodeReranker.rrf)
sim_min_score: float = Field(default=DEFAULT_MIN_SCORE) sim_min_score: float = Field(default=DEFAULT_MIN_SCORE)
mmr_lambda: float = Field(default=DEFAULT_MMR_LAMBDA) mmr_lambda: float = Field(default=DEFAULT_MMR_LAMBDA)
bfs_max_depth: int = Field(default=MAX_SEARCH_DEPTH)
class CommunitySearchConfig(BaseModel): class CommunitySearchConfig(BaseModel):
@ -78,6 +89,7 @@ class CommunitySearchConfig(BaseModel):
reranker: CommunityReranker = Field(default=CommunityReranker.rrf) reranker: CommunityReranker = Field(default=CommunityReranker.rrf)
sim_min_score: float = Field(default=DEFAULT_MIN_SCORE) sim_min_score: float = Field(default=DEFAULT_MIN_SCORE)
mmr_lambda: float = Field(default=DEFAULT_MMR_LAMBDA) mmr_lambda: float = Field(default=DEFAULT_MMR_LAMBDA)
bfs_max_depth: int = Field(default=MAX_SEARCH_DEPTH)
class SearchConfig(BaseModel): class SearchConfig(BaseModel):

View file

@ -48,14 +48,41 @@ COMBINED_HYBRID_SEARCH_MMR = SearchConfig(
edge_config=EdgeSearchConfig( edge_config=EdgeSearchConfig(
search_methods=[EdgeSearchMethod.bm25, EdgeSearchMethod.cosine_similarity], search_methods=[EdgeSearchMethod.bm25, EdgeSearchMethod.cosine_similarity],
reranker=EdgeReranker.mmr, reranker=EdgeReranker.mmr,
mmr_lambda=1,
), ),
node_config=NodeSearchConfig( node_config=NodeSearchConfig(
search_methods=[NodeSearchMethod.bm25, NodeSearchMethod.cosine_similarity], search_methods=[NodeSearchMethod.bm25, NodeSearchMethod.cosine_similarity],
reranker=NodeReranker.mmr, reranker=NodeReranker.mmr,
mmr_lambda=1,
), ),
community_config=CommunitySearchConfig( community_config=CommunitySearchConfig(
search_methods=[CommunitySearchMethod.bm25, CommunitySearchMethod.cosine_similarity], search_methods=[CommunitySearchMethod.bm25, CommunitySearchMethod.cosine_similarity],
reranker=CommunityReranker.mmr, reranker=CommunityReranker.mmr,
mmr_lambda=1,
),
)
# Performs a full-text search, similarity search, and bfs with cross_encoder reranking over edges, nodes, and communities
COMBINED_HYBRID_SEARCH_CROSS_ENCODER = SearchConfig(
edge_config=EdgeSearchConfig(
search_methods=[
EdgeSearchMethod.bm25,
EdgeSearchMethod.cosine_similarity,
EdgeSearchMethod.bfs,
],
reranker=EdgeReranker.cross_encoder,
),
node_config=NodeSearchConfig(
search_methods=[
NodeSearchMethod.bm25,
NodeSearchMethod.cosine_similarity,
NodeSearchMethod.bfs,
],
reranker=NodeReranker.cross_encoder,
),
community_config=CommunitySearchConfig(
search_methods=[CommunitySearchMethod.bm25, CommunitySearchMethod.cosine_similarity],
reranker=CommunityReranker.cross_encoder,
), ),
) )
@ -81,7 +108,6 @@ EDGE_HYBRID_SEARCH_NODE_DISTANCE = SearchConfig(
search_methods=[EdgeSearchMethod.bm25, EdgeSearchMethod.cosine_similarity], search_methods=[EdgeSearchMethod.bm25, EdgeSearchMethod.cosine_similarity],
reranker=EdgeReranker.node_distance, reranker=EdgeReranker.node_distance,
), ),
limit=30,
) )
# performs a hybrid search over edges with episode mention reranking # performs a hybrid search over edges with episode mention reranking

View file

@ -19,7 +19,6 @@ import logging
from collections import defaultdict from collections import defaultdict
from time import time from time import time
import neo4j
import numpy as np import numpy as np
from neo4j import AsyncDriver, Query from neo4j import AsyncDriver, Query
@ -38,6 +37,7 @@ logger = logging.getLogger(__name__)
RELEVANT_SCHEMA_LIMIT = 3 RELEVANT_SCHEMA_LIMIT = 3
DEFAULT_MIN_SCORE = 0.6 DEFAULT_MIN_SCORE = 0.6
DEFAULT_MMR_LAMBDA = 0.5 DEFAULT_MMR_LAMBDA = 0.5
MAX_SEARCH_DEPTH = 3
MAX_QUERY_LENGTH = 128 MAX_QUERY_LENGTH = 128
@ -80,23 +80,21 @@ async def get_mentioned_nodes(
driver: AsyncDriver, episodes: list[EpisodicNode] driver: AsyncDriver, episodes: list[EpisodicNode]
) -> list[EntityNode]: ) -> list[EntityNode]:
episode_uuids = [episode.uuid for episode in episodes] episode_uuids = [episode.uuid for episode in episodes]
async with driver.session( records, _, _ = await driver.execute_query(
database=DEFAULT_DATABASE, default_access_mode=neo4j.READ_ACCESS """
) as session: MATCH (episode:Episodic)-[:MENTIONS]->(n:Entity) WHERE episode.uuid IN $uuids
result = await session.run( RETURN DISTINCT
""" n.uuid As uuid,
MATCH (episode:Episodic)-[:MENTIONS]->(n:Entity) WHERE episode.uuid IN $uuids n.group_id AS group_id,
RETURN DISTINCT n.name AS name,
n.uuid As uuid, n.name_embedding AS name_embedding,
n.group_id AS group_id, n.created_at AS created_at,
n.name AS name, n.summary AS summary
n.name_embedding AS name_embedding, """,
n.created_at AS created_at, uuids=episode_uuids,
n.summary AS summary database_=DEFAULT_DATABASE,
""", routing_='r',
{'uuids': episode_uuids}, )
)
records = [record async for record in result]
nodes = [get_entity_node_from_record(record) for record in records] nodes = [get_entity_node_from_record(record) for record in records]
@ -107,23 +105,21 @@ async def get_communities_by_nodes(
driver: AsyncDriver, nodes: list[EntityNode] driver: AsyncDriver, nodes: list[EntityNode]
) -> list[CommunityNode]: ) -> list[CommunityNode]:
node_uuids = [node.uuid for node in nodes] node_uuids = [node.uuid for node in nodes]
async with driver.session( records, _, _ = await driver.execute_query(
database=DEFAULT_DATABASE, default_access_mode=neo4j.READ_ACCESS """
) as session: MATCH (c:Community)-[:HAS_MEMBER]->(n:Entity) WHERE n.uuid IN $uuids
result = await session.run( RETURN DISTINCT
""" c.uuid As uuid,
MATCH (c:Community)-[:HAS_MEMBER]->(n:Entity) WHERE n.uuid IN $uuids c.group_id AS group_id,
RETURN DISTINCT c.name AS name,
c.uuid As uuid, c.name_embedding AS name_embedding
c.group_id AS group_id, c.created_at AS created_at,
c.name AS name, c.summary AS summary
c.name_embedding AS name_embedding """,
c.created_at AS created_at, uuids=node_uuids,
c.summary AS summary database_=DEFAULT_DATABASE,
""", routing_='r',
{'uuids': node_uuids}, )
)
records = [record async for record in result]
communities = [get_community_node_from_record(record) for record in records] communities = [get_community_node_from_record(record) for record in records]
@ -149,7 +145,7 @@ async def edge_fulltext_search(
MATCH (n:Entity)-[r {uuid: rel.uuid}]->(m:Entity) MATCH (n:Entity)-[r {uuid: rel.uuid}]->(m:Entity)
WHERE ($source_uuid IS NULL OR n.uuid IN [$source_uuid, $target_uuid]) WHERE ($source_uuid IS NULL OR n.uuid IN [$source_uuid, $target_uuid])
AND ($target_uuid IS NULL OR m.uuid IN [$source_uuid, $target_uuid]) AND ($target_uuid IS NULL OR m.uuid IN [$source_uuid, $target_uuid])
RETURN RETURN
r.uuid AS uuid, r.uuid AS uuid,
r.group_id AS group_id, r.group_id AS group_id,
n.uuid AS source_node_uuid, n.uuid AS source_node_uuid,
@ -165,20 +161,16 @@ async def edge_fulltext_search(
ORDER BY score DESC LIMIT $limit ORDER BY score DESC LIMIT $limit
""") """)
async with driver.session( records, _, _ = await driver.execute_query(
database=DEFAULT_DATABASE, default_access_mode=neo4j.READ_ACCESS cypher_query,
) as session: query=fuzzy_query,
result = await session.run( source_uuid=source_node_uuid,
cypher_query, target_uuid=target_node_uuid,
{ group_ids=group_ids,
'query': fuzzy_query, limit=limit,
'source_uuid': source_node_uuid, database_=DEFAULT_DATABASE,
'target_uuid': target_node_uuid, routing_='r',
'group_ids': group_ids, )
'limit': limit,
},
)
records = [record async for record in result]
edges = [get_entity_edge_from_record(record) for record in records] edges = [get_entity_edge_from_record(record) for record in records]
@ -201,13 +193,13 @@ async def edge_similarity_search(
WHERE ($group_ids IS NULL OR r.group_id IN $group_ids) WHERE ($group_ids IS NULL OR r.group_id IN $group_ids)
AND ($source_uuid IS NULL OR n.uuid IN [$source_uuid, $target_uuid]) AND ($source_uuid IS NULL OR n.uuid IN [$source_uuid, $target_uuid])
AND ($target_uuid IS NULL OR m.uuid IN [$source_uuid, $target_uuid]) AND ($target_uuid IS NULL OR m.uuid IN [$source_uuid, $target_uuid])
WITH DISTINCT n, r, m, vector.similarity.cosine(r.fact_embedding, $search_vector) AS score WITH DISTINCT r, vector.similarity.cosine(r.fact_embedding, $search_vector) AS score
WHERE score > $min_score WHERE score > $min_score
RETURN RETURN
r.uuid AS uuid, r.uuid AS uuid,
r.group_id AS group_id, r.group_id AS group_id,
n.uuid AS source_node_uuid, startNode(r).uuid AS source_node_uuid,
m.uuid AS target_node_uuid, endNode(r).uuid AS target_node_uuid,
r.created_at AS created_at, r.created_at AS created_at,
r.name AS name, r.name AS name,
r.fact AS fact, r.fact AS fact,
@ -220,21 +212,59 @@ async def edge_similarity_search(
LIMIT $limit LIMIT $limit
""") """)
async with driver.session( records, _, _ = await driver.execute_query(
database=DEFAULT_DATABASE, default_access_mode=neo4j.READ_ACCESS query,
) as session: search_vector=search_vector,
result = await session.run( source_uuid=source_node_uuid,
query, target_uuid=target_node_uuid,
{ group_ids=group_ids,
'search_vector': search_vector, limit=limit,
'source_uuid': source_node_uuid, min_score=min_score,
'target_uuid': target_node_uuid, database_=DEFAULT_DATABASE,
'group_ids': group_ids, routing_='r',
'limit': limit, )
'min_score': min_score,
}, edges = [get_entity_edge_from_record(record) for record in records]
)
records = [record async for record in result] return edges
async def edge_bfs_search(
driver: AsyncDriver,
bfs_origin_node_uuids: list[str] | None,
bfs_max_depth: int,
) -> list[EntityEdge]:
# vector similarity search over embedded facts
if bfs_origin_node_uuids is None:
return []
query = Query("""
UNWIND $bfs_origin_node_uuids AS origin_uuid
MATCH path = (origin:Entity|Episodic {uuid: origin_uuid})-[:RELATES_TO|MENTIONS]->{1,3}(n:Entity)
UNWIND relationships(path) AS rel
MATCH ()-[r:RELATES_TO {uuid: rel.uuid}]-()
RETURN DISTINCT
r.uuid AS uuid,
r.group_id AS group_id,
startNode(r).uuid AS source_node_uuid,
endNode(r).uuid AS target_node_uuid,
r.created_at AS created_at,
r.name AS name,
r.fact AS fact,
r.fact_embedding AS fact_embedding,
r.episodes AS episodes,
r.expired_at AS expired_at,
r.valid_at AS valid_at,
r.invalid_at AS invalid_at
""")
records, _, _ = await driver.execute_query(
query,
bfs_origin_node_uuids=bfs_origin_node_uuids,
depth=bfs_max_depth,
database_=DEFAULT_DATABASE,
routing_='r',
)
edges = [get_entity_edge_from_record(record) for record in records] edges = [get_entity_edge_from_record(record) for record in records]
@ -252,30 +282,26 @@ async def node_fulltext_search(
if fuzzy_query == '': if fuzzy_query == '':
return [] return []
async with driver.session( records, _, _ = await driver.execute_query(
database=DEFAULT_DATABASE, default_access_mode=neo4j.READ_ACCESS """
) as session: CALL db.index.fulltext.queryNodes("node_name_and_summary", $query)
result = await session.run( YIELD node AS n, score
""" RETURN
CALL db.index.fulltext.queryNodes("node_name_and_summary", $query) n.uuid AS uuid,
YIELD node AS n, score n.group_id AS group_id,
RETURN n.name AS name,
n.uuid AS uuid, n.name_embedding AS name_embedding,
n.group_id AS group_id, n.created_at AS created_at,
n.name AS name, n.summary AS summary
n.name_embedding AS name_embedding, ORDER BY score DESC
n.created_at AS created_at, LIMIT $limit
n.summary AS summary """,
ORDER BY score DESC query=fuzzy_query,
LIMIT $limit group_ids=group_ids,
""", limit=limit,
{ database_=DEFAULT_DATABASE,
'query': fuzzy_query, routing_='r',
'group_ids': group_ids, )
'limit': limit,
},
)
records = [record async for record in result]
nodes = [get_entity_node_from_record(record) for record in records] nodes = [get_entity_node_from_record(record) for record in records]
return nodes return nodes
@ -289,34 +315,62 @@ async def node_similarity_search(
min_score: float = DEFAULT_MIN_SCORE, min_score: float = DEFAULT_MIN_SCORE,
) -> list[EntityNode]: ) -> list[EntityNode]:
# vector similarity search over entity names # vector similarity search over entity names
async with driver.session( records, _, _ = await driver.execute_query(
database=DEFAULT_DATABASE, default_access_mode=neo4j.READ_ACCESS """
) as session: CYPHER runtime = parallel parallelRuntimeSupport=all
result = await session.run( MATCH (n:Entity)
""" WHERE $group_ids IS NULL OR n.group_id IN $group_ids
CYPHER runtime = parallel parallelRuntimeSupport=all WITH n, vector.similarity.cosine(n.name_embedding, $search_vector) AS score
MATCH (n:Entity) WHERE score > $min_score
WHERE $group_ids IS NULL OR n.group_id IN $group_ids RETURN
WITH n, vector.similarity.cosine(n.name_embedding, $search_vector) AS score n.uuid As uuid,
WHERE score > $min_score n.group_id AS group_id,
RETURN n.name AS name,
n.uuid As uuid, n.name_embedding AS name_embedding,
n.group_id AS group_id, n.created_at AS created_at,
n.name AS name, n.summary AS summary
n.name_embedding AS name_embedding, ORDER BY score DESC
n.created_at AS created_at, LIMIT $limit
n.summary AS summary """,
ORDER BY score DESC search_vector=search_vector,
LIMIT $limit group_ids=group_ids,
""", limit=limit,
{ min_score=min_score,
'search_vector': search_vector, database_=DEFAULT_DATABASE,
'group_ids': group_ids, routing_='r',
'limit': limit, )
'min_score': min_score, nodes = [get_entity_node_from_record(record) for record in records]
},
) return nodes
records = [record async for record in result]
async def node_bfs_search(
driver: AsyncDriver,
bfs_origin_node_uuids: list[str] | None,
bfs_max_depth: int,
) -> list[EntityNode]:
# vector similarity search over entity names
if bfs_origin_node_uuids is None:
return []
records, _, _ = await driver.execute_query(
"""
UNWIND $bfs_origin_node_uuids AS origin_uuid
MATCH (origin:Entity|Episodic {uuid: origin_uuid})-[:RELATES_TO|MENTIONS]->{1,3}(n:Entity)
RETURN DISTINCT
n.uuid As uuid,
n.group_id AS group_id,
n.name AS name,
n.name_embedding AS name_embedding,
n.created_at AS created_at,
n.summary AS summary
LIMIT $limit
""",
bfs_origin_node_uuids=bfs_origin_node_uuids,
depth=bfs_max_depth,
database_=DEFAULT_DATABASE,
routing_='r',
)
nodes = [get_entity_node_from_record(record) for record in records] nodes = [get_entity_node_from_record(record) for record in records]
return nodes return nodes
@ -333,30 +387,26 @@ async def community_fulltext_search(
if fuzzy_query == '': if fuzzy_query == '':
return [] return []
async with driver.session( records, _, _ = await driver.execute_query(
database=DEFAULT_DATABASE, default_access_mode=neo4j.READ_ACCESS """
) as session: CALL db.index.fulltext.queryNodes("community_name", $query)
result = await session.run( YIELD node AS comm, score
""" RETURN
CALL db.index.fulltext.queryNodes("community_name", $query) comm.uuid AS uuid,
YIELD node AS comm, score comm.group_id AS group_id,
RETURN comm.name AS name,
comm.uuid AS uuid, comm.name_embedding AS name_embedding,
comm.group_id AS group_id, comm.created_at AS created_at,
comm.name AS name, comm.summary AS summary
comm.name_embedding AS name_embedding, ORDER BY score DESC
comm.created_at AS created_at, LIMIT $limit
comm.summary AS summary """,
ORDER BY score DESC query=fuzzy_query,
LIMIT $limit group_ids=group_ids,
""", limit=limit,
{ database_=DEFAULT_DATABASE,
'query': fuzzy_query, routing_='r',
'group_ids': group_ids, )
'limit': limit,
},
)
records = [record async for record in result]
communities = [get_community_node_from_record(record) for record in records] communities = [get_community_node_from_record(record) for record in records]
return communities return communities
@ -370,34 +420,30 @@ async def community_similarity_search(
min_score=DEFAULT_MIN_SCORE, min_score=DEFAULT_MIN_SCORE,
) -> list[CommunityNode]: ) -> list[CommunityNode]:
# vector similarity search over entity names # vector similarity search over entity names
async with driver.session( records, _, _ = await driver.execute_query(
database=DEFAULT_DATABASE, default_access_mode=neo4j.READ_ACCESS """
) as session: CYPHER runtime = parallel parallelRuntimeSupport=all
result = await session.run( MATCH (comm:Community)
""" WHERE ($group_ids IS NULL OR comm.group_id IN $group_ids)
CYPHER runtime = parallel parallelRuntimeSupport=all WITH comm, vector.similarity.cosine(comm.name_embedding, $search_vector) AS score
MATCH (comm:Community) WHERE score > $min_score
WHERE ($group_ids IS NULL OR comm.group_id IN $group_ids) RETURN
WITH comm, vector.similarity.cosine(comm.name_embedding, $search_vector) AS score comm.uuid As uuid,
WHERE score > $min_score comm.group_id AS group_id,
RETURN comm.name AS name,
comm.uuid As uuid, comm.name_embedding AS name_embedding,
comm.group_id AS group_id, comm.created_at AS created_at,
comm.name AS name, comm.summary AS summary
comm.name_embedding AS name_embedding, ORDER BY score DESC
comm.created_at AS created_at, LIMIT $limit
comm.summary AS summary """,
ORDER BY score DESC search_vector=search_vector,
LIMIT $limit group_ids=group_ids,
""", limit=limit,
{ min_score=min_score,
'search_vector': search_vector, database_=DEFAULT_DATABASE,
'group_ids': group_ids, routing_='r',
'limit': limit, )
'min_score': min_score,
},
)
records = [record async for record in result]
communities = [get_community_node_from_record(record) for record in records] communities = [get_community_node_from_record(record) for record in records]
return communities return communities

View file

@ -1,6 +1,6 @@
[tool.poetry] [tool.poetry]
name = "graphiti-core" name = "graphiti-core"
version = "0.3.16" version = "0.3.17"
description = "A temporal graph building library" description = "A temporal graph building library"
authors = [ authors = [
"Paul Paliychuk <paul@getzep.com>", "Paul Paliychuk <paul@getzep.com>",

View file

@ -26,7 +26,9 @@ from dotenv import load_dotenv
from graphiti_core.edges import EntityEdge, EpisodicEdge from graphiti_core.edges import EntityEdge, EpisodicEdge
from graphiti_core.graphiti import Graphiti from graphiti_core.graphiti import Graphiti
from graphiti_core.nodes import EntityNode, EpisodicNode from graphiti_core.nodes import EntityNode, EpisodicNode
from graphiti_core.search.search_config_recipes import COMBINED_HYBRID_SEARCH_RRF from graphiti_core.search.search_config_recipes import (
COMBINED_HYBRID_SEARCH_CROSS_ENCODER,
)
pytestmark = pytest.mark.integration pytestmark = pytest.mark.integration
@ -60,22 +62,19 @@ def setup_logging():
return logger return logger
def format_context(facts):
formatted_string = ''
formatted_string += 'FACTS:\n'
for fact in facts:
formatted_string += f' - {fact}\n'
formatted_string += '\n'
return formatted_string.strip()
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_graphiti_init(): async def test_graphiti_init():
logger = setup_logging() logger = setup_logging()
graphiti = Graphiti(NEO4J_URI, NEO4j_USER, NEO4j_PASSWORD) graphiti = Graphiti(NEO4J_URI, NEO4j_USER, NEO4j_PASSWORD)
episodes = await graphiti.retrieve_episodes(datetime.now(), group_ids=None)
episode_uuids = [episode.uuid for episode in episodes]
results = await graphiti._search('new house', COMBINED_HYBRID_SEARCH_RRF, group_ids=None) results = await graphiti._search(
"Emily: I can't log in",
COMBINED_HYBRID_SEARCH_CROSS_ENCODER,
bfs_origin_node_uuids=episode_uuids,
group_ids=None,
)
pretty_results = { pretty_results = {
'edges': [edge.fact for edge in results.edges], 'edges': [edge.fact for edge in results.edges],
'nodes': [node.name for node in results.nodes], 'nodes': [node.name for node in results.nodes],
@ -83,6 +82,7 @@ async def test_graphiti_init():
} }
logger.info(pretty_results) logger.info(pretty_results)
await graphiti.close() await graphiti.close()