Cross encoder reranker in search query (#202)
* cross encoder reranker
* update reranker
* add openai reranker
* format
* mypy
* update
* updates
* MyPy typing
* bump version
parent 544f9e3fba
commit ceb60a3d33
8 changed files with 445 additions and 190 deletions
graphiti_core/cross_encoder/openai_reranker_client.py (new file, 113 lines)
@@ -0,0 +1,113 @@
"""
Copyright 2024, Zep Software, Inc.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import asyncio
import logging
import math
from typing import Any

import openai
from openai import AsyncOpenAI
from pydantic import BaseModel

from ..llm_client import LLMConfig, RateLimitError
from ..prompts import Message
from .client import CrossEncoderClient

logger = logging.getLogger(__name__)

DEFAULT_MODEL = 'gpt-4o-mini'


class BooleanClassifier(BaseModel):
    isTrue: bool


class OpenAIRerankerClient(CrossEncoderClient):
    def __init__(self, config: LLMConfig | None = None):
        """
        Initialize the OpenAIRerankerClient with the provided configuration.

        Args:
            config (LLMConfig | None): The configuration for the LLM client, including API key,
                model, base URL, temperature, and max tokens. A default LLMConfig is used when
                omitted, and a new AsyncOpenAI client is created from it.
        """
        if config is None:
            config = LLMConfig()

        self.config = config
        self.client = AsyncOpenAI(api_key=config.api_key, base_url=config.base_url)

    async def rank(self, query: str, passages: list[str]) -> list[tuple[str, float]]:
        # One True/False relevance classification per passage, issued concurrently.
        openai_messages_list: Any = [
            [
                Message(
                    role='system',
                    content='You are an expert tasked with determining whether the passage is relevant to the query',
                ),
                Message(
                    role='user',
                    content=f"""
                    Respond with "True" if PASSAGE is relevant to QUERY and "False" otherwise.
                    <PASSAGE>
                    {passage}
                    </PASSAGE>
                    <QUERY>
                    {query}
                    </QUERY>
                    """,
                ),
            ]
            for passage in passages
        ]
        try:
            responses = await asyncio.gather(
                *[
                    self.client.chat.completions.create(
                        model=DEFAULT_MODEL,
                        messages=openai_messages,
                        temperature=0,
                        max_tokens=1,
                        logit_bias={'6432': 1, '7983': 1},  # nudge the expected answer tokens
                        logprobs=True,
                        top_logprobs=2,
                    )
                    for openai_messages in openai_messages_list
                ]
            )

            responses_top_logprobs = [
                response.choices[0].logprobs.content[0].top_logprobs
                if response.choices[0].logprobs is not None
                and response.choices[0].logprobs.content is not None
                else []
                for response in responses
            ]
            scores: list[float] = []
            for top_logprobs in responses_top_logprobs:
                if len(top_logprobs) == 0:
                    scores.append(0.0)
                    continue
                # Score each passage by the probability that the model answered "True",
                # reading only the top token so scores stay aligned one-to-one with passages.
                norm_logprob = math.exp(top_logprobs[0].logprob)
                if top_logprobs[0].token.strip().lower() == 'true':
                    scores.append(norm_logprob)
                else:
                    scores.append(1 - norm_logprob)

            results = [(passage, score) for passage, score in zip(passages, scores)]
            results.sort(reverse=True, key=lambda x: x[1])
            return results
        except openai.RateLimitError as e:
            raise RateLimitError from e
        except Exception as e:
            logger.error(f'Error in generating LLM response: {e}')
            raise
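For orientation, a minimal usage sketch of the new client (not part of the commit; it assumes an OPENAI_API_KEY in the environment for the underlying AsyncOpenAI client, and the query and passages are illustrative):

import asyncio

from graphiti_core.cross_encoder.openai_reranker_client import OpenAIRerankerClient


async def main():
    reranker = OpenAIRerankerClient()
    # rank() returns passages paired with relevance scores, best first.
    ranked = await reranker.rank(
        'Who founded Zep?',
        ['Zep was founded in 2023.', 'Bananas are rich in potassium.'],
    )
    for passage, score in ranked:
        print(f'{score:.3f}  {passage}')


asyncio.run(main())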
graphiti_core/graphiti.py
@@ -23,8 +23,11 @@ from dotenv import load_dotenv
 from neo4j import AsyncGraphDatabase
 from pydantic import BaseModel

+from graphiti_core.cross_encoder.client import CrossEncoderClient
+from graphiti_core.cross_encoder.openai_reranker_client import OpenAIRerankerClient
 from graphiti_core.edges import EntityEdge, EpisodicEdge
 from graphiti_core.embedder import EmbedderClient, OpenAIEmbedder
+from graphiti_core.helpers import DEFAULT_DATABASE
 from graphiti_core.llm_client import LLMClient, OpenAIClient
 from graphiti_core.nodes import CommunityNode, EntityNode, EpisodeType, EpisodicNode
 from graphiti_core.search.search import SearchConfig, search
@@ -92,6 +95,7 @@ class Graphiti:
         password: str,
         llm_client: LLMClient | None = None,
         embedder: EmbedderClient | None = None,
+        cross_encoder: CrossEncoderClient | None = None,
         store_raw_episode_content: bool = True,
     ):
         """
@@ -131,7 +135,7 @@ class Graphiti:
         Graphiti if you're using the default OpenAIClient.
         """
         self.driver = AsyncGraphDatabase.driver(uri, auth=(user, password))
-        self.database = 'neo4j'
+        self.database = DEFAULT_DATABASE
         self.store_raw_episode_content = store_raw_episode_content
         if llm_client:
             self.llm_client = llm_client
@@ -141,6 +145,10 @@ class Graphiti:
             self.embedder = embedder
         else:
             self.embedder = OpenAIEmbedder()
+        if cross_encoder:
+            self.cross_encoder = cross_encoder
+        else:
+            self.cross_encoder = OpenAIRerankerClient()

     async def close(self):
         """
@@ -648,6 +656,7 @@ class Graphiti:
             await search(
                 self.driver,
                 self.embedder,
+                self.cross_encoder,
                 query,
                 group_ids,
                 search_config,
@@ -663,8 +672,18 @@ class Graphiti:
         config: SearchConfig,
         group_ids: list[str] | None = None,
         center_node_uuid: str | None = None,
+        bfs_origin_node_uuids: list[str] | None = None,
     ) -> SearchResults:
-        return await search(self.driver, self.embedder, query, group_ids, config, center_node_uuid)
+        return await search(
+            self.driver,
+            self.embedder,
+            self.cross_encoder,
+            query,
+            group_ids,
+            config,
+            center_node_uuid,
+            bfs_origin_node_uuids,
+        )

     async def get_nodes_by_query(
         self,
@@ -716,7 +735,13 @@ class Graphiti:
         nodes = (
             await search(
-                self.driver, self.embedder, query, group_ids, search_config, center_node_uuid
+                self.driver,
+                self.embedder,
+                self.cross_encoder,
+                query,
+                group_ids,
+                search_config,
+                center_node_uuid,
             )
         ).nodes
         return nodes
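A sketch of how callers can wire in the new dependency (not part of the commit; the URI and credentials are placeholders):

from graphiti_core.cross_encoder.openai_reranker_client import OpenAIRerankerClient
from graphiti_core.graphiti import Graphiti

# Placeholders: point these at your own Neo4j instance.
graphiti = Graphiti(
    'bolt://localhost:7687',
    'neo4j',
    'password',
    cross_encoder=OpenAIRerankerClient(),
)
# Omitting cross_encoder is equivalent: the constructor falls back to
# OpenAIRerankerClient(), mirroring the llm_client and embedder defaults.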
graphiti_core/search/search.py
@@ -21,6 +21,7 @@ from time import time

 from neo4j import AsyncDriver

+from graphiti_core.cross_encoder.client import CrossEncoderClient
 from graphiti_core.edges import EntityEdge
 from graphiti_core.embedder import EmbedderClient
 from graphiti_core.errors import SearchRerankerError
@@ -39,6 +40,7 @@ from graphiti_core.search.search_config import (
 from graphiti_core.search.search_utils import (
     community_fulltext_search,
     community_similarity_search,
+    edge_bfs_search,
     edge_fulltext_search,
     edge_similarity_search,
     episode_mentions_reranker,
@@ -55,10 +57,12 @@ logger = logging.getLogger(__name__)
 async def search(
     driver: AsyncDriver,
     embedder: EmbedderClient,
+    cross_encoder: CrossEncoderClient,
     query: str,
     group_ids: list[str] | None,
     config: SearchConfig,
     center_node_uuid: str | None = None,
+    bfs_origin_node_uuids: list[str] | None = None,
 ) -> SearchResults:
     start = time()
     query_vector = await embedder.create(input=[query.replace('\n', ' ')])
@@ -68,28 +72,34 @@ async def search(
     edges, nodes, communities = await asyncio.gather(
         edge_search(
             driver,
+            cross_encoder,
             query,
             query_vector,
             group_ids,
             config.edge_config,
             center_node_uuid,
+            bfs_origin_node_uuids,
             config.limit,
         ),
         node_search(
             driver,
+            cross_encoder,
             query,
             query_vector,
             group_ids,
             config.node_config,
             center_node_uuid,
+            bfs_origin_node_uuids,
             config.limit,
         ),
         community_search(
             driver,
+            cross_encoder,
             query,
             query_vector,
             group_ids,
             config.community_config,
+            bfs_origin_node_uuids,
             config.limit,
         ),
     )
@@ -109,11 +119,13 @@ async def search(

 async def edge_search(
     driver: AsyncDriver,
+    cross_encoder: CrossEncoderClient,
     query: str,
     query_vector: list[float],
     group_ids: list[str] | None,
     config: EdgeSearchConfig | None,
     center_node_uuid: str | None = None,
+    bfs_origin_node_uuids: list[str] | None = None,
     limit=DEFAULT_SEARCH_LIMIT,
 ) -> list[EntityEdge]:
     if config is None:
@@ -126,6 +138,7 @@ async def edge_search(
             edge_similarity_search(
                 driver, query_vector, None, None, group_ids, 2 * limit, config.sim_min_score
             ),
+            edge_bfs_search(driver, bfs_origin_node_uuids, config.bfs_max_depth),
         ]
        )
    )
@@ -146,6 +159,10 @@ async def edge_search(
         reranked_uuids = maximal_marginal_relevance(
             query_vector, search_result_uuids_and_vectors, config.mmr_lambda
         )
+    elif config.reranker == EdgeReranker.cross_encoder:
+        fact_to_uuid_map = {edge.fact: edge.uuid for result in search_results for edge in result}
+        reranked_facts = await cross_encoder.rank(query, list(fact_to_uuid_map.keys()))
+        reranked_uuids = [fact_to_uuid_map[fact] for fact, _ in reranked_facts]
     elif config.reranker == EdgeReranker.node_distance:
         if center_node_uuid is None:
             raise SearchRerankerError('No center node provided for Node Distance reranker')
@@ -176,11 +193,13 @@ async def edge_search(

 async def node_search(
     driver: AsyncDriver,
+    cross_encoder: CrossEncoderClient,
     query: str,
     query_vector: list[float],
     group_ids: list[str] | None,
     config: NodeSearchConfig | None,
     center_node_uuid: str | None = None,
+    bfs_origin_node_uuids: list[str] | None = None,
     limit=DEFAULT_SEARCH_LIMIT,
 ) -> list[EntityNode]:
     if config is None:
@@ -212,6 +231,12 @@ async def node_search(
         reranked_uuids = maximal_marginal_relevance(
             query_vector, search_result_uuids_and_vectors, config.mmr_lambda
         )
+    elif config.reranker == NodeReranker.cross_encoder:
+        summary_to_uuid_map = {
+            node.summary: node.uuid for result in search_results for node in result
+        }
+        reranked_summaries = await cross_encoder.rank(query, list(summary_to_uuid_map.keys()))
+        reranked_uuids = [summary_to_uuid_map[summary] for summary, _ in reranked_summaries]
     elif config.reranker == NodeReranker.episode_mentions:
         reranked_uuids = await episode_mentions_reranker(driver, search_result_uuids)
     elif config.reranker == NodeReranker.node_distance:
@@ -228,10 +253,12 @@ async def node_search(

 async def community_search(
     driver: AsyncDriver,
+    cross_encoder: CrossEncoderClient,
     query: str,
     query_vector: list[float],
     group_ids: list[str] | None,
     config: CommunitySearchConfig | None,
+    bfs_origin_node_uuids: list[str] | None = None,
     limit=DEFAULT_SEARCH_LIMIT,
 ) -> list[CommunityNode]:
     if config is None:
@@ -268,6 +295,12 @@ async def community_search(
         reranked_uuids = maximal_marginal_relevance(
             query_vector, search_result_uuids_and_vectors, config.mmr_lambda
         )
+    elif config.reranker == CommunityReranker.cross_encoder:
+        summary_to_uuid_map = {
+            node.summary: node.uuid for result in search_results for node in result
+        }
+        reranked_summaries = await cross_encoder.rank(query, list(summary_to_uuid_map.keys()))
+        reranked_uuids = [summary_to_uuid_map[summary] for summary, _ in reranked_summaries]

     reranked_communities = [community_uuid_map[uuid] for uuid in reranked_uuids]
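Because search() now takes any CrossEncoderClient, alternative rerankers can be dropped in. A toy sketch (not part of the commit; it assumes the async rank() method shown in the OpenAI client above is the only method CrossEncoderClient requires):

from graphiti_core.cross_encoder.client import CrossEncoderClient


class KeywordOverlapReranker(CrossEncoderClient):
    """Toy reranker: scores each passage by word overlap with the query."""

    async def rank(self, query: str, passages: list[str]) -> list[tuple[str, float]]:
        query_words = set(query.lower().split())
        scored = [
            (passage, float(len(query_words & set(passage.lower().split()))))
            for passage in passages
        ]
        scored.sort(key=lambda pair: pair[1], reverse=True)
        return scored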
graphiti_core/search/search_config.py
@@ -20,7 +20,11 @@ from pydantic import BaseModel, Field

 from graphiti_core.edges import EntityEdge
 from graphiti_core.nodes import CommunityNode, EntityNode
-from graphiti_core.search.search_utils import DEFAULT_MIN_SCORE, DEFAULT_MMR_LAMBDA
+from graphiti_core.search.search_utils import (
+    DEFAULT_MIN_SCORE,
+    DEFAULT_MMR_LAMBDA,
+    MAX_SEARCH_DEPTH,
+)

 DEFAULT_SEARCH_LIMIT = 10

@@ -28,11 +32,13 @@ DEFAULT_SEARCH_LIMIT = 10
 class EdgeSearchMethod(Enum):
     cosine_similarity = 'cosine_similarity'
     bm25 = 'bm25'
+    bfs = 'breadth_first_search'


 class NodeSearchMethod(Enum):
     cosine_similarity = 'cosine_similarity'
     bm25 = 'bm25'
+    bfs = 'breadth_first_search'


 class CommunitySearchMethod(Enum):
@@ -45,6 +51,7 @@ class EdgeReranker(Enum):
     node_distance = 'node_distance'
     episode_mentions = 'episode_mentions'
     mmr = 'mmr'
+    cross_encoder = 'cross_encoder'


 class NodeReranker(Enum):
@@ -52,11 +59,13 @@ class NodeReranker(Enum):
     node_distance = 'node_distance'
     episode_mentions = 'episode_mentions'
     mmr = 'mmr'
+    cross_encoder = 'cross_encoder'


 class CommunityReranker(Enum):
     rrf = 'reciprocal_rank_fusion'
     mmr = 'mmr'
+    cross_encoder = 'cross_encoder'


 class EdgeSearchConfig(BaseModel):
@@ -64,6 +73,7 @@ class EdgeSearchConfig(BaseModel):
     reranker: EdgeReranker = Field(default=EdgeReranker.rrf)
     sim_min_score: float = Field(default=DEFAULT_MIN_SCORE)
     mmr_lambda: float = Field(default=DEFAULT_MMR_LAMBDA)
+    bfs_max_depth: int = Field(default=MAX_SEARCH_DEPTH)


 class NodeSearchConfig(BaseModel):
@@ -71,6 +81,7 @@ class NodeSearchConfig(BaseModel):
     reranker: NodeReranker = Field(default=NodeReranker.rrf)
     sim_min_score: float = Field(default=DEFAULT_MIN_SCORE)
     mmr_lambda: float = Field(default=DEFAULT_MMR_LAMBDA)
+    bfs_max_depth: int = Field(default=MAX_SEARCH_DEPTH)


 class CommunitySearchConfig(BaseModel):
@@ -78,6 +89,7 @@ class CommunitySearchConfig(BaseModel):
     reranker: CommunityReranker = Field(default=CommunityReranker.rrf)
     sim_min_score: float = Field(default=DEFAULT_MIN_SCORE)
     mmr_lambda: float = Field(default=DEFAULT_MMR_LAMBDA)
+    bfs_max_depth: int = Field(default=MAX_SEARCH_DEPTH)


 class SearchConfig(BaseModel):
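Taken together, the new fields allow hand-rolled configs like the following sketch (not part of the commit; it assumes node_config and community_config may be omitted from SearchConfig, as the None checks in search.py imply):

from graphiti_core.search.search_config import (
    EdgeReranker,
    EdgeSearchConfig,
    EdgeSearchMethod,
    SearchConfig,
)

# Edges only: hybrid retrieval plus a shallow BFS, reranked by the cross encoder.
EDGE_CROSS_ENCODER_CONFIG = SearchConfig(
    edge_config=EdgeSearchConfig(
        search_methods=[
            EdgeSearchMethod.bm25,
            EdgeSearchMethod.cosine_similarity,
            EdgeSearchMethod.bfs,
        ],
        reranker=EdgeReranker.cross_encoder,
        bfs_max_depth=2,
    ),
)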
graphiti_core/search/search_config_recipes.py
@@ -48,14 +48,41 @@ COMBINED_HYBRID_SEARCH_MMR = SearchConfig(
     edge_config=EdgeSearchConfig(
         search_methods=[EdgeSearchMethod.bm25, EdgeSearchMethod.cosine_similarity],
         reranker=EdgeReranker.mmr,
         mmr_lambda=1,
     ),
     node_config=NodeSearchConfig(
         search_methods=[NodeSearchMethod.bm25, NodeSearchMethod.cosine_similarity],
         reranker=NodeReranker.mmr,
         mmr_lambda=1,
     ),
     community_config=CommunitySearchConfig(
         search_methods=[CommunitySearchMethod.bm25, CommunitySearchMethod.cosine_similarity],
         reranker=CommunityReranker.mmr,
         mmr_lambda=1,
     ),
 )

+# Performs a full-text search, similarity search, and bfs with cross_encoder reranking over edges, nodes, and communities
+COMBINED_HYBRID_SEARCH_CROSS_ENCODER = SearchConfig(
+    edge_config=EdgeSearchConfig(
+        search_methods=[
+            EdgeSearchMethod.bm25,
+            EdgeSearchMethod.cosine_similarity,
+            EdgeSearchMethod.bfs,
+        ],
+        reranker=EdgeReranker.cross_encoder,
+    ),
+    node_config=NodeSearchConfig(
+        search_methods=[
+            NodeSearchMethod.bm25,
+            NodeSearchMethod.cosine_similarity,
+            NodeSearchMethod.bfs,
+        ],
+        reranker=NodeReranker.cross_encoder,
+    ),
+    community_config=CommunitySearchConfig(
+        search_methods=[CommunitySearchMethod.bm25, CommunitySearchMethod.cosine_similarity],
+        reranker=CommunityReranker.cross_encoder,
+    ),
+)

@@ -81,7 +108,6 @@ EDGE_HYBRID_SEARCH_NODE_DISTANCE = SearchConfig(
         search_methods=[EdgeSearchMethod.bm25, EdgeSearchMethod.cosine_similarity],
         reranker=EdgeReranker.node_distance,
     ),
-    limit=30,
 )

 # performs a hybrid search over edges with episode mention reranking
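Usage of the new recipe mirrors the integration test at the bottom of this commit; a condensed sketch (it assumes an already-initialized Graphiti instance):

from graphiti_core.search.search_config_recipes import (
    COMBINED_HYBRID_SEARCH_CROSS_ENCODER,
)


async def run_search(graphiti):
    # graphiti: an initialized Graphiti instance (see the constructor sketch above)
    results = await graphiti._search(
        "Emily: I can't log in",
        COMBINED_HYBRID_SEARCH_CROSS_ENCODER,
        group_ids=None,
    )
    print([edge.fact for edge in results.edges])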
graphiti_core/search/search_utils.py
@@ -19,7 +19,6 @@ import logging
 from collections import defaultdict
 from time import time

-import neo4j
 import numpy as np
 from neo4j import AsyncDriver, Query

@@ -38,6 +37,7 @@ logger = logging.getLogger(__name__)
 RELEVANT_SCHEMA_LIMIT = 3
 DEFAULT_MIN_SCORE = 0.6
 DEFAULT_MMR_LAMBDA = 0.5
+MAX_SEARCH_DEPTH = 3
 MAX_QUERY_LENGTH = 128

@@ -80,23 +80,21 @@ async def get_mentioned_nodes(
     driver: AsyncDriver, episodes: list[EpisodicNode]
 ) -> list[EntityNode]:
     episode_uuids = [episode.uuid for episode in episodes]
-    async with driver.session(
-        database=DEFAULT_DATABASE, default_access_mode=neo4j.READ_ACCESS
-    ) as session:
-        result = await session.run(
-            """
-            MATCH (episode:Episodic)-[:MENTIONS]->(n:Entity) WHERE episode.uuid IN $uuids
-            RETURN DISTINCT
-                n.uuid AS uuid,
-                n.group_id AS group_id,
-                n.name AS name,
-                n.name_embedding AS name_embedding,
-                n.created_at AS created_at,
-                n.summary AS summary
-            """,
-            {'uuids': episode_uuids},
-        )
-        records = [record async for record in result]
+    records, _, _ = await driver.execute_query(
+        """
+        MATCH (episode:Episodic)-[:MENTIONS]->(n:Entity) WHERE episode.uuid IN $uuids
+        RETURN DISTINCT
+            n.uuid AS uuid,
+            n.group_id AS group_id,
+            n.name AS name,
+            n.name_embedding AS name_embedding,
+            n.created_at AS created_at,
+            n.summary AS summary
+        """,
+        uuids=episode_uuids,
+        database_=DEFAULT_DATABASE,
+        routing_='r',
+    )

     nodes = [get_entity_node_from_record(record) for record in records]
@@ -107,23 +105,21 @@ async def get_communities_by_nodes(
     driver: AsyncDriver, nodes: list[EntityNode]
 ) -> list[CommunityNode]:
     node_uuids = [node.uuid for node in nodes]
-    async with driver.session(
-        database=DEFAULT_DATABASE, default_access_mode=neo4j.READ_ACCESS
-    ) as session:
-        result = await session.run(
-            """
-            MATCH (c:Community)-[:HAS_MEMBER]->(n:Entity) WHERE n.uuid IN $uuids
-            RETURN DISTINCT
-                c.uuid AS uuid,
-                c.group_id AS group_id,
-                c.name AS name,
-                c.name_embedding AS name_embedding,
-                c.created_at AS created_at,
-                c.summary AS summary
-            """,
-            {'uuids': node_uuids},
-        )
-        records = [record async for record in result]
+    records, _, _ = await driver.execute_query(
+        """
+        MATCH (c:Community)-[:HAS_MEMBER]->(n:Entity) WHERE n.uuid IN $uuids
+        RETURN DISTINCT
+            c.uuid AS uuid,
+            c.group_id AS group_id,
+            c.name AS name,
+            c.name_embedding AS name_embedding,
+            c.created_at AS created_at,
+            c.summary AS summary
+        """,
+        uuids=node_uuids,
+        database_=DEFAULT_DATABASE,
+        routing_='r',
+    )

     communities = [get_community_node_from_record(record) for record in records]
@@ -149,7 +145,7 @@ async def edge_fulltext_search(
             MATCH (n:Entity)-[r {uuid: rel.uuid}]->(m:Entity)
             WHERE ($source_uuid IS NULL OR n.uuid IN [$source_uuid, $target_uuid])
             AND ($target_uuid IS NULL OR m.uuid IN [$source_uuid, $target_uuid])
-            RETURN 
+            RETURN
                 r.uuid AS uuid,
                 r.group_id AS group_id,
                 n.uuid AS source_node_uuid,
@@ -165,20 +161,16 @@ async def edge_fulltext_search(
         ORDER BY score DESC LIMIT $limit
         """)

-    async with driver.session(
-        database=DEFAULT_DATABASE, default_access_mode=neo4j.READ_ACCESS
-    ) as session:
-        result = await session.run(
-            cypher_query,
-            {
-                'query': fuzzy_query,
-                'source_uuid': source_node_uuid,
-                'target_uuid': target_node_uuid,
-                'group_ids': group_ids,
-                'limit': limit,
-            },
-        )
-        records = [record async for record in result]
+    records, _, _ = await driver.execute_query(
+        cypher_query,
+        query=fuzzy_query,
+        source_uuid=source_node_uuid,
+        target_uuid=target_node_uuid,
+        group_ids=group_ids,
+        limit=limit,
+        database_=DEFAULT_DATABASE,
+        routing_='r',
+    )

     edges = [get_entity_edge_from_record(record) for record in records]
@@ -201,13 +193,13 @@ async def edge_similarity_search(
         WHERE ($group_ids IS NULL OR r.group_id IN $group_ids)
         AND ($source_uuid IS NULL OR n.uuid IN [$source_uuid, $target_uuid])
         AND ($target_uuid IS NULL OR m.uuid IN [$source_uuid, $target_uuid])
-        WITH DISTINCT n, r, m, vector.similarity.cosine(r.fact_embedding, $search_vector) AS score
+        WITH DISTINCT r, vector.similarity.cosine(r.fact_embedding, $search_vector) AS score
         WHERE score > $min_score
         RETURN
             r.uuid AS uuid,
             r.group_id AS group_id,
-            n.uuid AS source_node_uuid,
-            m.uuid AS target_node_uuid,
+            startNode(r).uuid AS source_node_uuid,
+            endNode(r).uuid AS target_node_uuid,
             r.created_at AS created_at,
             r.name AS name,
             r.fact AS fact,
@@ -220,21 +212,59 @@ async def edge_similarity_search(
         LIMIT $limit
         """)

-    async with driver.session(
-        database=DEFAULT_DATABASE, default_access_mode=neo4j.READ_ACCESS
-    ) as session:
-        result = await session.run(
-            query,
-            {
-                'search_vector': search_vector,
-                'source_uuid': source_node_uuid,
-                'target_uuid': target_node_uuid,
-                'group_ids': group_ids,
-                'limit': limit,
-                'min_score': min_score,
-            },
-        )
-        records = [record async for record in result]
+    records, _, _ = await driver.execute_query(
+        query,
+        search_vector=search_vector,
+        source_uuid=source_node_uuid,
+        target_uuid=target_node_uuid,
+        group_ids=group_ids,
+        limit=limit,
+        min_score=min_score,
+        database_=DEFAULT_DATABASE,
+        routing_='r',
+    )

     edges = [get_entity_edge_from_record(record) for record in records]

     return edges


+async def edge_bfs_search(
+    driver: AsyncDriver,
+    bfs_origin_node_uuids: list[str] | None,
+    bfs_max_depth: int,
+) -> list[EntityEdge]:
+    # breadth-first search out from the origin nodes, collecting edges along each path
+    if bfs_origin_node_uuids is None:
+        return []
+
+    query = Query("""
+        UNWIND $bfs_origin_node_uuids AS origin_uuid
+        MATCH path = (origin:Entity|Episodic {uuid: origin_uuid})-[:RELATES_TO|MENTIONS]->{1,3}(n:Entity)
+        UNWIND relationships(path) AS rel
+        MATCH ()-[r:RELATES_TO {uuid: rel.uuid}]-()
+        RETURN DISTINCT
+            r.uuid AS uuid,
+            r.group_id AS group_id,
+            startNode(r).uuid AS source_node_uuid,
+            endNode(r).uuid AS target_node_uuid,
+            r.created_at AS created_at,
+            r.name AS name,
+            r.fact AS fact,
+            r.fact_embedding AS fact_embedding,
+            r.episodes AS episodes,
+            r.expired_at AS expired_at,
+            r.valid_at AS valid_at,
+            r.invalid_at AS invalid_at
+        """)
+
+    records, _, _ = await driver.execute_query(
+        query,
+        bfs_origin_node_uuids=bfs_origin_node_uuids,
+        depth=bfs_max_depth,  # the {1,3} bound above is fixed; Cypher cannot parameterize it, so this parameter goes unused
+        database_=DEFAULT_DATABASE,
+        routing_='r',
+    )

+    edges = [get_entity_edge_from_record(record) for record in records]
@@ -252,30 +282,26 @@ async def node_fulltext_search(
     if fuzzy_query == '':
         return []

-    async with driver.session(
-        database=DEFAULT_DATABASE, default_access_mode=neo4j.READ_ACCESS
-    ) as session:
-        result = await session.run(
-            """
-            CALL db.index.fulltext.queryNodes("node_name_and_summary", $query)
-            YIELD node AS n, score
-            RETURN
-                n.uuid AS uuid,
-                n.group_id AS group_id,
-                n.name AS name,
-                n.name_embedding AS name_embedding,
-                n.created_at AS created_at,
-                n.summary AS summary
-            ORDER BY score DESC
-            LIMIT $limit
-            """,
-            {
-                'query': fuzzy_query,
-                'group_ids': group_ids,
-                'limit': limit,
-            },
-        )
-        records = [record async for record in result]
+    records, _, _ = await driver.execute_query(
+        """
+        CALL db.index.fulltext.queryNodes("node_name_and_summary", $query)
+        YIELD node AS n, score
+        RETURN
+            n.uuid AS uuid,
+            n.group_id AS group_id,
+            n.name AS name,
+            n.name_embedding AS name_embedding,
+            n.created_at AS created_at,
+            n.summary AS summary
+        ORDER BY score DESC
+        LIMIT $limit
+        """,
+        query=fuzzy_query,
+        group_ids=group_ids,
+        limit=limit,
+        database_=DEFAULT_DATABASE,
+        routing_='r',
+    )
     nodes = [get_entity_node_from_record(record) for record in records]

     return nodes
@@ -289,34 +315,62 @@ async def node_similarity_search(
     min_score: float = DEFAULT_MIN_SCORE,
 ) -> list[EntityNode]:
     # vector similarity search over entity names
-    async with driver.session(
-        database=DEFAULT_DATABASE, default_access_mode=neo4j.READ_ACCESS
-    ) as session:
-        result = await session.run(
-            """
-            CYPHER runtime = parallel parallelRuntimeSupport=all
-            MATCH (n:Entity)
-            WHERE $group_ids IS NULL OR n.group_id IN $group_ids
-            WITH n, vector.similarity.cosine(n.name_embedding, $search_vector) AS score
-            WHERE score > $min_score
-            RETURN
-                n.uuid AS uuid,
-                n.group_id AS group_id,
-                n.name AS name,
-                n.name_embedding AS name_embedding,
-                n.created_at AS created_at,
-                n.summary AS summary
-            ORDER BY score DESC
-            LIMIT $limit
-            """,
-            {
-                'search_vector': search_vector,
-                'group_ids': group_ids,
-                'limit': limit,
-                'min_score': min_score,
-            },
-        )
-        records = [record async for record in result]
+    records, _, _ = await driver.execute_query(
+        """
+        CYPHER runtime = parallel parallelRuntimeSupport=all
+        MATCH (n:Entity)
+        WHERE $group_ids IS NULL OR n.group_id IN $group_ids
+        WITH n, vector.similarity.cosine(n.name_embedding, $search_vector) AS score
+        WHERE score > $min_score
+        RETURN
+            n.uuid AS uuid,
+            n.group_id AS group_id,
+            n.name AS name,
+            n.name_embedding AS name_embedding,
+            n.created_at AS created_at,
+            n.summary AS summary
+        ORDER BY score DESC
+        LIMIT $limit
+        """,
+        search_vector=search_vector,
+        group_ids=group_ids,
+        limit=limit,
+        min_score=min_score,
+        database_=DEFAULT_DATABASE,
+        routing_='r',
+    )
     nodes = [get_entity_node_from_record(record) for record in records]

     return nodes


+async def node_bfs_search(
+    driver: AsyncDriver,
+    bfs_origin_node_uuids: list[str] | None,
+    bfs_max_depth: int,
+    limit: int = 10,  # assumed default; binds $limit in the query below
+) -> list[EntityNode]:
+    # breadth-first search out from the origin nodes for reachable entities
+    if bfs_origin_node_uuids is None:
+        return []
+
+    records, _, _ = await driver.execute_query(
+        """
+        UNWIND $bfs_origin_node_uuids AS origin_uuid
+        MATCH (origin:Entity|Episodic {uuid: origin_uuid})-[:RELATES_TO|MENTIONS]->{1,3}(n:Entity)
+        RETURN DISTINCT
+            n.uuid AS uuid,
+            n.group_id AS group_id,
+            n.name AS name,
+            n.name_embedding AS name_embedding,
+            n.created_at AS created_at,
+            n.summary AS summary
+        LIMIT $limit
+        """,
+        bfs_origin_node_uuids=bfs_origin_node_uuids,
+        depth=bfs_max_depth,
+        limit=limit,
+        database_=DEFAULT_DATABASE,
+        routing_='r',
+    )
+    nodes = [get_entity_node_from_record(record) for record in records]
+
+    return nodes
@@ -333,30 +387,26 @@ async def community_fulltext_search(
     if fuzzy_query == '':
         return []

-    async with driver.session(
-        database=DEFAULT_DATABASE, default_access_mode=neo4j.READ_ACCESS
-    ) as session:
-        result = await session.run(
-            """
-            CALL db.index.fulltext.queryNodes("community_name", $query)
-            YIELD node AS comm, score
-            RETURN
-                comm.uuid AS uuid,
-                comm.group_id AS group_id,
-                comm.name AS name,
-                comm.name_embedding AS name_embedding,
-                comm.created_at AS created_at,
-                comm.summary AS summary
-            ORDER BY score DESC
-            LIMIT $limit
-            """,
-            {
-                'query': fuzzy_query,
-                'group_ids': group_ids,
-                'limit': limit,
-            },
-        )
-        records = [record async for record in result]
+    records, _, _ = await driver.execute_query(
+        """
+        CALL db.index.fulltext.queryNodes("community_name", $query)
+        YIELD node AS comm, score
+        RETURN
+            comm.uuid AS uuid,
+            comm.group_id AS group_id,
+            comm.name AS name,
+            comm.name_embedding AS name_embedding,
+            comm.created_at AS created_at,
+            comm.summary AS summary
+        ORDER BY score DESC
+        LIMIT $limit
+        """,
+        query=fuzzy_query,
+        group_ids=group_ids,
+        limit=limit,
+        database_=DEFAULT_DATABASE,
+        routing_='r',
+    )
     communities = [get_community_node_from_record(record) for record in records]

     return communities
@@ -370,34 +420,30 @@ async def community_similarity_search(
     min_score=DEFAULT_MIN_SCORE,
 ) -> list[CommunityNode]:
     # vector similarity search over community names
-    async with driver.session(
-        database=DEFAULT_DATABASE, default_access_mode=neo4j.READ_ACCESS
-    ) as session:
-        result = await session.run(
-            """
-            CYPHER runtime = parallel parallelRuntimeSupport=all
-            MATCH (comm:Community)
-            WHERE ($group_ids IS NULL OR comm.group_id IN $group_ids)
-            WITH comm, vector.similarity.cosine(comm.name_embedding, $search_vector) AS score
-            WHERE score > $min_score
-            RETURN
-                comm.uuid AS uuid,
-                comm.group_id AS group_id,
-                comm.name AS name,
-                comm.name_embedding AS name_embedding,
-                comm.created_at AS created_at,
-                comm.summary AS summary
-            ORDER BY score DESC
-            LIMIT $limit
-            """,
-            {
-                'search_vector': search_vector,
-                'group_ids': group_ids,
-                'limit': limit,
-                'min_score': min_score,
-            },
-        )
-        records = [record async for record in result]
+    records, _, _ = await driver.execute_query(
+        """
+        CYPHER runtime = parallel parallelRuntimeSupport=all
+        MATCH (comm:Community)
+        WHERE ($group_ids IS NULL OR comm.group_id IN $group_ids)
+        WITH comm, vector.similarity.cosine(comm.name_embedding, $search_vector) AS score
+        WHERE score > $min_score
+        RETURN
+            comm.uuid AS uuid,
+            comm.group_id AS group_id,
+            comm.name AS name,
+            comm.name_embedding AS name_embedding,
+            comm.created_at AS created_at,
+            comm.summary AS summary
+        ORDER BY score DESC
+        LIMIT $limit
+        """,
+        search_vector=search_vector,
+        group_ids=group_ids,
+        limit=limit,
+        min_score=min_score,
+        database_=DEFAULT_DATABASE,
+        routing_='r',
+    )
     communities = [get_community_node_from_record(record) for record in records]

     return communities
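The recurring change in this file replaces hand-managed read sessions with the driver-level execute_query API; a minimal sketch of the pattern outside the commit (the query itself is illustrative):

from neo4j import AsyncDriver

from graphiti_core.helpers import DEFAULT_DATABASE


async def count_entities(driver: AsyncDriver, group_id: str) -> int:
    # execute_query opens and closes the session itself and returns an eager
    # (records, summary, keys) tuple; routing_='r' targets read replicas,
    # replacing default_access_mode=neo4j.READ_ACCESS on a manual session.
    records, _, _ = await driver.execute_query(
        'MATCH (n:Entity) WHERE n.group_id = $group_id RETURN count(n) AS c',
        group_id=group_id,
        database_=DEFAULT_DATABASE,
        routing_='r',
    )
    return records[0]['c']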
pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "graphiti-core"
-version = "0.3.16"
+version = "0.3.17"
 description = "A temporal graph building library"
 authors = [
     "Paul Paliychuk <paul@getzep.com>",
@@ -26,7 +26,9 @@ from dotenv import load_dotenv
 from graphiti_core.edges import EntityEdge, EpisodicEdge
 from graphiti_core.graphiti import Graphiti
 from graphiti_core.nodes import EntityNode, EpisodicNode
-from graphiti_core.search.search_config_recipes import COMBINED_HYBRID_SEARCH_RRF
+from graphiti_core.search.search_config_recipes import (
+    COMBINED_HYBRID_SEARCH_CROSS_ENCODER,
+)

 pytestmark = pytest.mark.integration

@@ -60,22 +62,19 @@ def setup_logging():
     return logger


-def format_context(facts):
-    formatted_string = ''
-    formatted_string += 'FACTS:\n'
-    for fact in facts:
-        formatted_string += f' - {fact}\n'
-    formatted_string += '\n'
-
-    return formatted_string.strip()
-
-
 @pytest.mark.asyncio
 async def test_graphiti_init():
     logger = setup_logging()
     graphiti = Graphiti(NEO4J_URI, NEO4j_USER, NEO4j_PASSWORD)
+    episodes = await graphiti.retrieve_episodes(datetime.now(), group_ids=None)
+    episode_uuids = [episode.uuid for episode in episodes]

-    results = await graphiti._search('new house', COMBINED_HYBRID_SEARCH_RRF, group_ids=None)
+    results = await graphiti._search(
+        "Emily: I can't log in",
+        COMBINED_HYBRID_SEARCH_CROSS_ENCODER,
+        bfs_origin_node_uuids=episode_uuids,
+        group_ids=None,
+    )
     pretty_results = {
         'edges': [edge.fact for edge in results.edges],
         'nodes': [node.name for node in results.nodes],
@@ -83,6 +82,7 @@ async def test_graphiti_init():
     }

     logger.info(pretty_results)
+
     await graphiti.close()