Cross encoder reranker in search query (#202)
* cross encoder reranker * update reranker * add openai reranker * format * mypy * update * updates * MyPy typing * bump version
This commit is contained in:
parent
544f9e3fba
commit
ceb60a3d33
8 changed files with 445 additions and 190 deletions
113
graphiti_core/cross_encoder/openai_reranker_client.py
Normal file
113
graphiti_core/cross_encoder/openai_reranker_client.py
Normal file
|
|
@ -0,0 +1,113 @@
|
||||||
|
"""
|
||||||
|
Copyright 2024, Zep Software, Inc.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import openai
|
||||||
|
from openai import AsyncOpenAI
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from ..llm_client import LLMConfig, RateLimitError
|
||||||
|
from ..prompts import Message
|
||||||
|
from .client import CrossEncoderClient
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
DEFAULT_MODEL = 'gpt-4o-mini'
|
||||||
|
|
||||||
|
|
||||||
|
class BooleanClassifier(BaseModel):
    """Structured yes/no verdict carrying a single boolean field."""

    # True when the classified statement holds, False otherwise.
    isTrue: bool
|
||||||
|
|
||||||
|
|
||||||
|
class OpenAIRerankerClient(CrossEncoderClient):
    """Rerank passages by asking an OpenAI chat model a True/False relevance question.

    Every passage is judged in its own chat-completion request. The relevance score
    is derived from the log-probabilities of the single token the model may emit.
    """

    def __init__(self, config: LLMConfig | None = None):
        """
        Initialize the OpenAIRerankerClient with the provided configuration.

        Args:
            config (LLMConfig | None): Configuration supplying the API key and base URL
                used to build the AsyncOpenAI client. A default ``LLMConfig()`` is
                created when omitted.
        """
        if config is None:
            config = LLMConfig()

        self.config = config
        self.client = AsyncOpenAI(api_key=config.api_key, base_url=config.base_url)

    @staticmethod
    def _relevance_score(top_logprobs: Any) -> float:
        """Turn the top-logprob candidates of one response into P("True").

        Args:
            top_logprobs: Iterable of objects exposing ``token`` and ``logprob``.

        Returns:
            The probability mass of the "True" token when it is among the candidates,
            otherwise ``1 - P("False")`` when only "False" is present, otherwise 0.0.
        """
        # Local import so this block does not depend on a file-level import change.
        import math

        false_probability: float | None = None
        for candidate in top_logprobs:
            token = candidate.token.strip().lower()
            if token == 'true':
                return math.exp(candidate.logprob)
            if token == 'false' and false_probability is None:
                false_probability = math.exp(candidate.logprob)

        if false_probability is not None:
            return 1.0 - false_probability
        return 0.0

    async def rank(self, query: str, passages: list[str]) -> list[tuple[str, float]]:
        """Score each passage for relevance to ``query`` and sort best-first.

        Args:
            query: The search query the passages are judged against.
            passages: Candidate passages to rank.

        Returns:
            ``(passage, score)`` tuples sorted by descending score, exactly one per
            input passage. ``score`` is the model's probability of answering "True";
            a response without usable logprobs scores 0.0.

        Raises:
            RateLimitError: When the OpenAI client reports a rate limit.
            Exception: Any other failure is logged and re-raised unchanged.
        """
        if not passages:
            return []

        openai_messages_list: Any = [
            [
                Message(
                    role='system',
                    content='You are an expert tasked with determining whether the passage is relevant to the query',
                ),
                Message(
                    role='user',
                    content=f"""
                            Respond with "True" if PASSAGE is relevant to QUERY and "False" otherwise.
                            <PASSAGE>
                            {passage}
                            </PASSAGE>
                            <QUERY>
                            {query}
                            </QUERY>
                            """,
                ),
            ]
            for passage in passages
        ]
        try:
            responses = await asyncio.gather(
                *[
                    self.client.chat.completions.create(
                        model=DEFAULT_MODEL,
                        messages=openai_messages,
                        temperature=0,
                        max_tokens=1,
                        # NOTE(review): presumably the token ids of "True"/"False" for
                        # DEFAULT_MODEL's tokenizer - confirm before changing the model.
                        logit_bias={'6432': 1, '7983': 1},
                        logprobs=True,
                        top_logprobs=2,
                    )
                    for openai_messages in openai_messages_list
                ]
            )

            # Exactly one score per response keeps scores aligned with passages.
            scores: list[float] = []
            for response in responses:
                logprobs = response.choices[0].logprobs
                if logprobs is not None and logprobs.content:
                    scores.append(self._relevance_score(logprobs.content[0].top_logprobs))
                else:
                    scores.append(0.0)

            results = list(zip(passages, scores))
            results.sort(reverse=True, key=lambda x: x[1])
            return results
        except openai.RateLimitError as e:
            raise RateLimitError from e
        except Exception as e:
            logger.error(f'Error in generating LLM response: {e}')
            raise
|
||||||
|
|
@ -23,8 +23,11 @@ from dotenv import load_dotenv
|
||||||
from neo4j import AsyncGraphDatabase
|
from neo4j import AsyncGraphDatabase
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from graphiti_core.cross_encoder.client import CrossEncoderClient
|
||||||
|
from graphiti_core.cross_encoder.openai_reranker_client import OpenAIRerankerClient
|
||||||
from graphiti_core.edges import EntityEdge, EpisodicEdge
|
from graphiti_core.edges import EntityEdge, EpisodicEdge
|
||||||
from graphiti_core.embedder import EmbedderClient, OpenAIEmbedder
|
from graphiti_core.embedder import EmbedderClient, OpenAIEmbedder
|
||||||
|
from graphiti_core.helpers import DEFAULT_DATABASE
|
||||||
from graphiti_core.llm_client import LLMClient, OpenAIClient
|
from graphiti_core.llm_client import LLMClient, OpenAIClient
|
||||||
from graphiti_core.nodes import CommunityNode, EntityNode, EpisodeType, EpisodicNode
|
from graphiti_core.nodes import CommunityNode, EntityNode, EpisodeType, EpisodicNode
|
||||||
from graphiti_core.search.search import SearchConfig, search
|
from graphiti_core.search.search import SearchConfig, search
|
||||||
|
|
@ -92,6 +95,7 @@ class Graphiti:
|
||||||
password: str,
|
password: str,
|
||||||
llm_client: LLMClient | None = None,
|
llm_client: LLMClient | None = None,
|
||||||
embedder: EmbedderClient | None = None,
|
embedder: EmbedderClient | None = None,
|
||||||
|
cross_encoder: CrossEncoderClient | None = None,
|
||||||
store_raw_episode_content: bool = True,
|
store_raw_episode_content: bool = True,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
|
|
@ -131,7 +135,7 @@ class Graphiti:
|
||||||
Graphiti if you're using the default OpenAIClient.
|
Graphiti if you're using the default OpenAIClient.
|
||||||
"""
|
"""
|
||||||
self.driver = AsyncGraphDatabase.driver(uri, auth=(user, password))
|
self.driver = AsyncGraphDatabase.driver(uri, auth=(user, password))
|
||||||
self.database = 'neo4j'
|
self.database = DEFAULT_DATABASE
|
||||||
self.store_raw_episode_content = store_raw_episode_content
|
self.store_raw_episode_content = store_raw_episode_content
|
||||||
if llm_client:
|
if llm_client:
|
||||||
self.llm_client = llm_client
|
self.llm_client = llm_client
|
||||||
|
|
@ -141,6 +145,10 @@ class Graphiti:
|
||||||
self.embedder = embedder
|
self.embedder = embedder
|
||||||
else:
|
else:
|
||||||
self.embedder = OpenAIEmbedder()
|
self.embedder = OpenAIEmbedder()
|
||||||
|
if cross_encoder:
|
||||||
|
self.cross_encoder = cross_encoder
|
||||||
|
else:
|
||||||
|
self.cross_encoder = OpenAIRerankerClient()
|
||||||
|
|
||||||
async def close(self):
|
async def close(self):
|
||||||
"""
|
"""
|
||||||
|
|
@ -648,6 +656,7 @@ class Graphiti:
|
||||||
await search(
|
await search(
|
||||||
self.driver,
|
self.driver,
|
||||||
self.embedder,
|
self.embedder,
|
||||||
|
self.cross_encoder,
|
||||||
query,
|
query,
|
||||||
group_ids,
|
group_ids,
|
||||||
search_config,
|
search_config,
|
||||||
|
|
@ -663,8 +672,18 @@ class Graphiti:
|
||||||
config: SearchConfig,
|
config: SearchConfig,
|
||||||
group_ids: list[str] | None = None,
|
group_ids: list[str] | None = None,
|
||||||
center_node_uuid: str | None = None,
|
center_node_uuid: str | None = None,
|
||||||
|
bfs_origin_node_uuids: list[str] | None = None,
|
||||||
) -> SearchResults:
|
) -> SearchResults:
|
||||||
return await search(self.driver, self.embedder, query, group_ids, config, center_node_uuid)
|
return await search(
|
||||||
|
self.driver,
|
||||||
|
self.embedder,
|
||||||
|
self.cross_encoder,
|
||||||
|
query,
|
||||||
|
group_ids,
|
||||||
|
config,
|
||||||
|
center_node_uuid,
|
||||||
|
bfs_origin_node_uuids,
|
||||||
|
)
|
||||||
|
|
||||||
async def get_nodes_by_query(
|
async def get_nodes_by_query(
|
||||||
self,
|
self,
|
||||||
|
|
@ -716,7 +735,13 @@ class Graphiti:
|
||||||
|
|
||||||
nodes = (
|
nodes = (
|
||||||
await search(
|
await search(
|
||||||
self.driver, self.embedder, query, group_ids, search_config, center_node_uuid
|
self.driver,
|
||||||
|
self.embedder,
|
||||||
|
self.cross_encoder,
|
||||||
|
query,
|
||||||
|
group_ids,
|
||||||
|
search_config,
|
||||||
|
center_node_uuid,
|
||||||
)
|
)
|
||||||
).nodes
|
).nodes
|
||||||
return nodes
|
return nodes
|
||||||
|
|
|
||||||
|
|
@ -21,6 +21,7 @@ from time import time
|
||||||
|
|
||||||
from neo4j import AsyncDriver
|
from neo4j import AsyncDriver
|
||||||
|
|
||||||
|
from graphiti_core.cross_encoder.client import CrossEncoderClient
|
||||||
from graphiti_core.edges import EntityEdge
|
from graphiti_core.edges import EntityEdge
|
||||||
from graphiti_core.embedder import EmbedderClient
|
from graphiti_core.embedder import EmbedderClient
|
||||||
from graphiti_core.errors import SearchRerankerError
|
from graphiti_core.errors import SearchRerankerError
|
||||||
|
|
@ -39,6 +40,7 @@ from graphiti_core.search.search_config import (
|
||||||
from graphiti_core.search.search_utils import (
|
from graphiti_core.search.search_utils import (
|
||||||
community_fulltext_search,
|
community_fulltext_search,
|
||||||
community_similarity_search,
|
community_similarity_search,
|
||||||
|
edge_bfs_search,
|
||||||
edge_fulltext_search,
|
edge_fulltext_search,
|
||||||
edge_similarity_search,
|
edge_similarity_search,
|
||||||
episode_mentions_reranker,
|
episode_mentions_reranker,
|
||||||
|
|
@ -55,10 +57,12 @@ logger = logging.getLogger(__name__)
|
||||||
async def search(
|
async def search(
|
||||||
driver: AsyncDriver,
|
driver: AsyncDriver,
|
||||||
embedder: EmbedderClient,
|
embedder: EmbedderClient,
|
||||||
|
cross_encoder: CrossEncoderClient,
|
||||||
query: str,
|
query: str,
|
||||||
group_ids: list[str] | None,
|
group_ids: list[str] | None,
|
||||||
config: SearchConfig,
|
config: SearchConfig,
|
||||||
center_node_uuid: str | None = None,
|
center_node_uuid: str | None = None,
|
||||||
|
bfs_origin_node_uuids: list[str] | None = None,
|
||||||
) -> SearchResults:
|
) -> SearchResults:
|
||||||
start = time()
|
start = time()
|
||||||
query_vector = await embedder.create(input=[query.replace('\n', ' ')])
|
query_vector = await embedder.create(input=[query.replace('\n', ' ')])
|
||||||
|
|
@ -68,28 +72,34 @@ async def search(
|
||||||
edges, nodes, communities = await asyncio.gather(
|
edges, nodes, communities = await asyncio.gather(
|
||||||
edge_search(
|
edge_search(
|
||||||
driver,
|
driver,
|
||||||
|
cross_encoder,
|
||||||
query,
|
query,
|
||||||
query_vector,
|
query_vector,
|
||||||
group_ids,
|
group_ids,
|
||||||
config.edge_config,
|
config.edge_config,
|
||||||
center_node_uuid,
|
center_node_uuid,
|
||||||
|
bfs_origin_node_uuids,
|
||||||
config.limit,
|
config.limit,
|
||||||
),
|
),
|
||||||
node_search(
|
node_search(
|
||||||
driver,
|
driver,
|
||||||
|
cross_encoder,
|
||||||
query,
|
query,
|
||||||
query_vector,
|
query_vector,
|
||||||
group_ids,
|
group_ids,
|
||||||
config.node_config,
|
config.node_config,
|
||||||
center_node_uuid,
|
center_node_uuid,
|
||||||
|
bfs_origin_node_uuids,
|
||||||
config.limit,
|
config.limit,
|
||||||
),
|
),
|
||||||
community_search(
|
community_search(
|
||||||
driver,
|
driver,
|
||||||
|
cross_encoder,
|
||||||
query,
|
query,
|
||||||
query_vector,
|
query_vector,
|
||||||
group_ids,
|
group_ids,
|
||||||
config.community_config,
|
config.community_config,
|
||||||
|
bfs_origin_node_uuids,
|
||||||
config.limit,
|
config.limit,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
@ -109,11 +119,13 @@ async def search(
|
||||||
|
|
||||||
async def edge_search(
|
async def edge_search(
|
||||||
driver: AsyncDriver,
|
driver: AsyncDriver,
|
||||||
|
cross_encoder: CrossEncoderClient,
|
||||||
query: str,
|
query: str,
|
||||||
query_vector: list[float],
|
query_vector: list[float],
|
||||||
group_ids: list[str] | None,
|
group_ids: list[str] | None,
|
||||||
config: EdgeSearchConfig | None,
|
config: EdgeSearchConfig | None,
|
||||||
center_node_uuid: str | None = None,
|
center_node_uuid: str | None = None,
|
||||||
|
bfs_origin_node_uuids: list[str] | None = None,
|
||||||
limit=DEFAULT_SEARCH_LIMIT,
|
limit=DEFAULT_SEARCH_LIMIT,
|
||||||
) -> list[EntityEdge]:
|
) -> list[EntityEdge]:
|
||||||
if config is None:
|
if config is None:
|
||||||
|
|
@ -126,6 +138,7 @@ async def edge_search(
|
||||||
edge_similarity_search(
|
edge_similarity_search(
|
||||||
driver, query_vector, None, None, group_ids, 2 * limit, config.sim_min_score
|
driver, query_vector, None, None, group_ids, 2 * limit, config.sim_min_score
|
||||||
),
|
),
|
||||||
|
edge_bfs_search(driver, bfs_origin_node_uuids, config.bfs_max_depth),
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
@ -146,6 +159,10 @@ async def edge_search(
|
||||||
reranked_uuids = maximal_marginal_relevance(
|
reranked_uuids = maximal_marginal_relevance(
|
||||||
query_vector, search_result_uuids_and_vectors, config.mmr_lambda
|
query_vector, search_result_uuids_and_vectors, config.mmr_lambda
|
||||||
)
|
)
|
||||||
|
elif config.reranker == EdgeReranker.cross_encoder:
|
||||||
|
fact_to_uuid_map = {edge.fact: edge.uuid for result in search_results for edge in result}
|
||||||
|
reranked_facts = await cross_encoder.rank(query, list(fact_to_uuid_map.keys()))
|
||||||
|
reranked_uuids = [fact_to_uuid_map[fact] for fact, _ in reranked_facts]
|
||||||
elif config.reranker == EdgeReranker.node_distance:
|
elif config.reranker == EdgeReranker.node_distance:
|
||||||
if center_node_uuid is None:
|
if center_node_uuid is None:
|
||||||
raise SearchRerankerError('No center node provided for Node Distance reranker')
|
raise SearchRerankerError('No center node provided for Node Distance reranker')
|
||||||
|
|
@ -176,11 +193,13 @@ async def edge_search(
|
||||||
|
|
||||||
async def node_search(
|
async def node_search(
|
||||||
driver: AsyncDriver,
|
driver: AsyncDriver,
|
||||||
|
cross_encoder: CrossEncoderClient,
|
||||||
query: str,
|
query: str,
|
||||||
query_vector: list[float],
|
query_vector: list[float],
|
||||||
group_ids: list[str] | None,
|
group_ids: list[str] | None,
|
||||||
config: NodeSearchConfig | None,
|
config: NodeSearchConfig | None,
|
||||||
center_node_uuid: str | None = None,
|
center_node_uuid: str | None = None,
|
||||||
|
bfs_origin_node_uuids: list[str] | None = None,
|
||||||
limit=DEFAULT_SEARCH_LIMIT,
|
limit=DEFAULT_SEARCH_LIMIT,
|
||||||
) -> list[EntityNode]:
|
) -> list[EntityNode]:
|
||||||
if config is None:
|
if config is None:
|
||||||
|
|
@ -212,6 +231,12 @@ async def node_search(
|
||||||
reranked_uuids = maximal_marginal_relevance(
|
reranked_uuids = maximal_marginal_relevance(
|
||||||
query_vector, search_result_uuids_and_vectors, config.mmr_lambda
|
query_vector, search_result_uuids_and_vectors, config.mmr_lambda
|
||||||
)
|
)
|
||||||
|
elif config.reranker == NodeReranker.cross_encoder:
|
||||||
|
summary_to_uuid_map = {
|
||||||
|
node.summary: node.uuid for result in search_results for node in result
|
||||||
|
}
|
||||||
|
reranked_summaries = await cross_encoder.rank(query, list(summary_to_uuid_map.keys()))
|
||||||
|
reranked_uuids = [summary_to_uuid_map[fact] for fact, _ in reranked_summaries]
|
||||||
elif config.reranker == NodeReranker.episode_mentions:
|
elif config.reranker == NodeReranker.episode_mentions:
|
||||||
reranked_uuids = await episode_mentions_reranker(driver, search_result_uuids)
|
reranked_uuids = await episode_mentions_reranker(driver, search_result_uuids)
|
||||||
elif config.reranker == NodeReranker.node_distance:
|
elif config.reranker == NodeReranker.node_distance:
|
||||||
|
|
@ -228,10 +253,12 @@ async def node_search(
|
||||||
|
|
||||||
async def community_search(
|
async def community_search(
|
||||||
driver: AsyncDriver,
|
driver: AsyncDriver,
|
||||||
|
cross_encoder: CrossEncoderClient,
|
||||||
query: str,
|
query: str,
|
||||||
query_vector: list[float],
|
query_vector: list[float],
|
||||||
group_ids: list[str] | None,
|
group_ids: list[str] | None,
|
||||||
config: CommunitySearchConfig | None,
|
config: CommunitySearchConfig | None,
|
||||||
|
bfs_origin_node_uuids: list[str] | None = None,
|
||||||
limit=DEFAULT_SEARCH_LIMIT,
|
limit=DEFAULT_SEARCH_LIMIT,
|
||||||
) -> list[CommunityNode]:
|
) -> list[CommunityNode]:
|
||||||
if config is None:
|
if config is None:
|
||||||
|
|
@ -268,6 +295,12 @@ async def community_search(
|
||||||
reranked_uuids = maximal_marginal_relevance(
|
reranked_uuids = maximal_marginal_relevance(
|
||||||
query_vector, search_result_uuids_and_vectors, config.mmr_lambda
|
query_vector, search_result_uuids_and_vectors, config.mmr_lambda
|
||||||
)
|
)
|
||||||
|
elif config.reranker == CommunityReranker.cross_encoder:
|
||||||
|
summary_to_uuid_map = {
|
||||||
|
node.summary: node.uuid for result in search_results for node in result
|
||||||
|
}
|
||||||
|
reranked_summaries = await cross_encoder.rank(query, list(summary_to_uuid_map.keys()))
|
||||||
|
reranked_uuids = [summary_to_uuid_map[fact] for fact, _ in reranked_summaries]
|
||||||
|
|
||||||
reranked_communities = [community_uuid_map[uuid] for uuid in reranked_uuids]
|
reranked_communities = [community_uuid_map[uuid] for uuid in reranked_uuids]
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -20,7 +20,11 @@ from pydantic import BaseModel, Field
|
||||||
|
|
||||||
from graphiti_core.edges import EntityEdge
|
from graphiti_core.edges import EntityEdge
|
||||||
from graphiti_core.nodes import CommunityNode, EntityNode
|
from graphiti_core.nodes import CommunityNode, EntityNode
|
||||||
from graphiti_core.search.search_utils import DEFAULT_MIN_SCORE, DEFAULT_MMR_LAMBDA
|
from graphiti_core.search.search_utils import (
|
||||||
|
DEFAULT_MIN_SCORE,
|
||||||
|
DEFAULT_MMR_LAMBDA,
|
||||||
|
MAX_SEARCH_DEPTH,
|
||||||
|
)
|
||||||
|
|
||||||
DEFAULT_SEARCH_LIMIT = 10
|
DEFAULT_SEARCH_LIMIT = 10
|
||||||
|
|
||||||
|
|
@ -28,11 +32,13 @@ DEFAULT_SEARCH_LIMIT = 10
|
||||||
class EdgeSearchMethod(Enum):
|
class EdgeSearchMethod(Enum):
|
||||||
cosine_similarity = 'cosine_similarity'
|
cosine_similarity = 'cosine_similarity'
|
||||||
bm25 = 'bm25'
|
bm25 = 'bm25'
|
||||||
|
bfs = 'breadth_first_search'
|
||||||
|
|
||||||
|
|
||||||
class NodeSearchMethod(Enum):
|
class NodeSearchMethod(Enum):
|
||||||
cosine_similarity = 'cosine_similarity'
|
cosine_similarity = 'cosine_similarity'
|
||||||
bm25 = 'bm25'
|
bm25 = 'bm25'
|
||||||
|
bfs = 'breadth_first_search'
|
||||||
|
|
||||||
|
|
||||||
class CommunitySearchMethod(Enum):
|
class CommunitySearchMethod(Enum):
|
||||||
|
|
@ -45,6 +51,7 @@ class EdgeReranker(Enum):
|
||||||
node_distance = 'node_distance'
|
node_distance = 'node_distance'
|
||||||
episode_mentions = 'episode_mentions'
|
episode_mentions = 'episode_mentions'
|
||||||
mmr = 'mmr'
|
mmr = 'mmr'
|
||||||
|
cross_encoder = 'cross_encoder'
|
||||||
|
|
||||||
|
|
||||||
class NodeReranker(Enum):
|
class NodeReranker(Enum):
|
||||||
|
|
@ -52,11 +59,13 @@ class NodeReranker(Enum):
|
||||||
node_distance = 'node_distance'
|
node_distance = 'node_distance'
|
||||||
episode_mentions = 'episode_mentions'
|
episode_mentions = 'episode_mentions'
|
||||||
mmr = 'mmr'
|
mmr = 'mmr'
|
||||||
|
cross_encoder = 'cross_encoder'
|
||||||
|
|
||||||
|
|
||||||
class CommunityReranker(Enum):
|
class CommunityReranker(Enum):
|
||||||
rrf = 'reciprocal_rank_fusion'
|
rrf = 'reciprocal_rank_fusion'
|
||||||
mmr = 'mmr'
|
mmr = 'mmr'
|
||||||
|
cross_encoder = 'cross_encoder'
|
||||||
|
|
||||||
|
|
||||||
class EdgeSearchConfig(BaseModel):
|
class EdgeSearchConfig(BaseModel):
|
||||||
|
|
@ -64,6 +73,7 @@ class EdgeSearchConfig(BaseModel):
|
||||||
reranker: EdgeReranker = Field(default=EdgeReranker.rrf)
|
reranker: EdgeReranker = Field(default=EdgeReranker.rrf)
|
||||||
sim_min_score: float = Field(default=DEFAULT_MIN_SCORE)
|
sim_min_score: float = Field(default=DEFAULT_MIN_SCORE)
|
||||||
mmr_lambda: float = Field(default=DEFAULT_MMR_LAMBDA)
|
mmr_lambda: float = Field(default=DEFAULT_MMR_LAMBDA)
|
||||||
|
bfs_max_depth: int = Field(default=MAX_SEARCH_DEPTH)
|
||||||
|
|
||||||
|
|
||||||
class NodeSearchConfig(BaseModel):
|
class NodeSearchConfig(BaseModel):
|
||||||
|
|
@ -71,6 +81,7 @@ class NodeSearchConfig(BaseModel):
|
||||||
reranker: NodeReranker = Field(default=NodeReranker.rrf)
|
reranker: NodeReranker = Field(default=NodeReranker.rrf)
|
||||||
sim_min_score: float = Field(default=DEFAULT_MIN_SCORE)
|
sim_min_score: float = Field(default=DEFAULT_MIN_SCORE)
|
||||||
mmr_lambda: float = Field(default=DEFAULT_MMR_LAMBDA)
|
mmr_lambda: float = Field(default=DEFAULT_MMR_LAMBDA)
|
||||||
|
bfs_max_depth: int = Field(default=MAX_SEARCH_DEPTH)
|
||||||
|
|
||||||
|
|
||||||
class CommunitySearchConfig(BaseModel):
|
class CommunitySearchConfig(BaseModel):
|
||||||
|
|
@ -78,6 +89,7 @@ class CommunitySearchConfig(BaseModel):
|
||||||
reranker: CommunityReranker = Field(default=CommunityReranker.rrf)
|
reranker: CommunityReranker = Field(default=CommunityReranker.rrf)
|
||||||
sim_min_score: float = Field(default=DEFAULT_MIN_SCORE)
|
sim_min_score: float = Field(default=DEFAULT_MIN_SCORE)
|
||||||
mmr_lambda: float = Field(default=DEFAULT_MMR_LAMBDA)
|
mmr_lambda: float = Field(default=DEFAULT_MMR_LAMBDA)
|
||||||
|
bfs_max_depth: int = Field(default=MAX_SEARCH_DEPTH)
|
||||||
|
|
||||||
|
|
||||||
class SearchConfig(BaseModel):
|
class SearchConfig(BaseModel):
|
||||||
|
|
|
||||||
|
|
@ -48,14 +48,41 @@ COMBINED_HYBRID_SEARCH_MMR = SearchConfig(
|
||||||
edge_config=EdgeSearchConfig(
|
edge_config=EdgeSearchConfig(
|
||||||
search_methods=[EdgeSearchMethod.bm25, EdgeSearchMethod.cosine_similarity],
|
search_methods=[EdgeSearchMethod.bm25, EdgeSearchMethod.cosine_similarity],
|
||||||
reranker=EdgeReranker.mmr,
|
reranker=EdgeReranker.mmr,
|
||||||
|
mmr_lambda=1,
|
||||||
),
|
),
|
||||||
node_config=NodeSearchConfig(
|
node_config=NodeSearchConfig(
|
||||||
search_methods=[NodeSearchMethod.bm25, NodeSearchMethod.cosine_similarity],
|
search_methods=[NodeSearchMethod.bm25, NodeSearchMethod.cosine_similarity],
|
||||||
reranker=NodeReranker.mmr,
|
reranker=NodeReranker.mmr,
|
||||||
|
mmr_lambda=1,
|
||||||
),
|
),
|
||||||
community_config=CommunitySearchConfig(
|
community_config=CommunitySearchConfig(
|
||||||
search_methods=[CommunitySearchMethod.bm25, CommunitySearchMethod.cosine_similarity],
|
search_methods=[CommunitySearchMethod.bm25, CommunitySearchMethod.cosine_similarity],
|
||||||
reranker=CommunityReranker.mmr,
|
reranker=CommunityReranker.mmr,
|
||||||
|
mmr_lambda=1,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Performs a full-text search, similarity search, and bfs with cross_encoder reranking over edges, nodes, and communities
|
||||||
|
COMBINED_HYBRID_SEARCH_CROSS_ENCODER = SearchConfig(
|
||||||
|
edge_config=EdgeSearchConfig(
|
||||||
|
search_methods=[
|
||||||
|
EdgeSearchMethod.bm25,
|
||||||
|
EdgeSearchMethod.cosine_similarity,
|
||||||
|
EdgeSearchMethod.bfs,
|
||||||
|
],
|
||||||
|
reranker=EdgeReranker.cross_encoder,
|
||||||
|
),
|
||||||
|
node_config=NodeSearchConfig(
|
||||||
|
search_methods=[
|
||||||
|
NodeSearchMethod.bm25,
|
||||||
|
NodeSearchMethod.cosine_similarity,
|
||||||
|
NodeSearchMethod.bfs,
|
||||||
|
],
|
||||||
|
reranker=NodeReranker.cross_encoder,
|
||||||
|
),
|
||||||
|
community_config=CommunitySearchConfig(
|
||||||
|
search_methods=[CommunitySearchMethod.bm25, CommunitySearchMethod.cosine_similarity],
|
||||||
|
reranker=CommunityReranker.cross_encoder,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -81,7 +108,6 @@ EDGE_HYBRID_SEARCH_NODE_DISTANCE = SearchConfig(
|
||||||
search_methods=[EdgeSearchMethod.bm25, EdgeSearchMethod.cosine_similarity],
|
search_methods=[EdgeSearchMethod.bm25, EdgeSearchMethod.cosine_similarity],
|
||||||
reranker=EdgeReranker.node_distance,
|
reranker=EdgeReranker.node_distance,
|
||||||
),
|
),
|
||||||
limit=30,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# performs a hybrid search over edges with episode mention reranking
|
# performs a hybrid search over edges with episode mention reranking
|
||||||
|
|
|
||||||
|
|
@ -19,7 +19,6 @@ import logging
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from time import time
|
from time import time
|
||||||
|
|
||||||
import neo4j
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from neo4j import AsyncDriver, Query
|
from neo4j import AsyncDriver, Query
|
||||||
|
|
||||||
|
|
@ -38,6 +37,7 @@ logger = logging.getLogger(__name__)
|
||||||
RELEVANT_SCHEMA_LIMIT = 3
|
RELEVANT_SCHEMA_LIMIT = 3
|
||||||
DEFAULT_MIN_SCORE = 0.6
|
DEFAULT_MIN_SCORE = 0.6
|
||||||
DEFAULT_MMR_LAMBDA = 0.5
|
DEFAULT_MMR_LAMBDA = 0.5
|
||||||
|
MAX_SEARCH_DEPTH = 3
|
||||||
MAX_QUERY_LENGTH = 128
|
MAX_QUERY_LENGTH = 128
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -80,23 +80,21 @@ async def get_mentioned_nodes(
|
||||||
driver: AsyncDriver, episodes: list[EpisodicNode]
|
driver: AsyncDriver, episodes: list[EpisodicNode]
|
||||||
) -> list[EntityNode]:
|
) -> list[EntityNode]:
|
||||||
episode_uuids = [episode.uuid for episode in episodes]
|
episode_uuids = [episode.uuid for episode in episodes]
|
||||||
async with driver.session(
|
records, _, _ = await driver.execute_query(
|
||||||
database=DEFAULT_DATABASE, default_access_mode=neo4j.READ_ACCESS
|
"""
|
||||||
) as session:
|
MATCH (episode:Episodic)-[:MENTIONS]->(n:Entity) WHERE episode.uuid IN $uuids
|
||||||
result = await session.run(
|
RETURN DISTINCT
|
||||||
"""
|
n.uuid As uuid,
|
||||||
MATCH (episode:Episodic)-[:MENTIONS]->(n:Entity) WHERE episode.uuid IN $uuids
|
n.group_id AS group_id,
|
||||||
RETURN DISTINCT
|
n.name AS name,
|
||||||
n.uuid As uuid,
|
n.name_embedding AS name_embedding,
|
||||||
n.group_id AS group_id,
|
n.created_at AS created_at,
|
||||||
n.name AS name,
|
n.summary AS summary
|
||||||
n.name_embedding AS name_embedding,
|
""",
|
||||||
n.created_at AS created_at,
|
uuids=episode_uuids,
|
||||||
n.summary AS summary
|
database_=DEFAULT_DATABASE,
|
||||||
""",
|
routing_='r',
|
||||||
{'uuids': episode_uuids},
|
)
|
||||||
)
|
|
||||||
records = [record async for record in result]
|
|
||||||
|
|
||||||
nodes = [get_entity_node_from_record(record) for record in records]
|
nodes = [get_entity_node_from_record(record) for record in records]
|
||||||
|
|
||||||
|
|
@ -107,23 +105,21 @@ async def get_communities_by_nodes(
|
||||||
driver: AsyncDriver, nodes: list[EntityNode]
|
driver: AsyncDriver, nodes: list[EntityNode]
|
||||||
) -> list[CommunityNode]:
|
) -> list[CommunityNode]:
|
||||||
node_uuids = [node.uuid for node in nodes]
|
node_uuids = [node.uuid for node in nodes]
|
||||||
async with driver.session(
|
records, _, _ = await driver.execute_query(
|
||||||
database=DEFAULT_DATABASE, default_access_mode=neo4j.READ_ACCESS
|
"""
|
||||||
) as session:
|
MATCH (c:Community)-[:HAS_MEMBER]->(n:Entity) WHERE n.uuid IN $uuids
|
||||||
result = await session.run(
|
RETURN DISTINCT
|
||||||
"""
|
c.uuid As uuid,
|
||||||
MATCH (c:Community)-[:HAS_MEMBER]->(n:Entity) WHERE n.uuid IN $uuids
|
c.group_id AS group_id,
|
||||||
RETURN DISTINCT
|
c.name AS name,
|
||||||
c.uuid As uuid,
|
c.name_embedding AS name_embedding
|
||||||
c.group_id AS group_id,
|
c.created_at AS created_at,
|
||||||
c.name AS name,
|
c.summary AS summary
|
||||||
c.name_embedding AS name_embedding
|
""",
|
||||||
c.created_at AS created_at,
|
uuids=node_uuids,
|
||||||
c.summary AS summary
|
database_=DEFAULT_DATABASE,
|
||||||
""",
|
routing_='r',
|
||||||
{'uuids': node_uuids},
|
)
|
||||||
)
|
|
||||||
records = [record async for record in result]
|
|
||||||
|
|
||||||
communities = [get_community_node_from_record(record) for record in records]
|
communities = [get_community_node_from_record(record) for record in records]
|
||||||
|
|
||||||
|
|
@ -165,20 +161,16 @@ async def edge_fulltext_search(
|
||||||
ORDER BY score DESC LIMIT $limit
|
ORDER BY score DESC LIMIT $limit
|
||||||
""")
|
""")
|
||||||
|
|
||||||
async with driver.session(
|
records, _, _ = await driver.execute_query(
|
||||||
database=DEFAULT_DATABASE, default_access_mode=neo4j.READ_ACCESS
|
cypher_query,
|
||||||
) as session:
|
query=fuzzy_query,
|
||||||
result = await session.run(
|
source_uuid=source_node_uuid,
|
||||||
cypher_query,
|
target_uuid=target_node_uuid,
|
||||||
{
|
group_ids=group_ids,
|
||||||
'query': fuzzy_query,
|
limit=limit,
|
||||||
'source_uuid': source_node_uuid,
|
database_=DEFAULT_DATABASE,
|
||||||
'target_uuid': target_node_uuid,
|
routing_='r',
|
||||||
'group_ids': group_ids,
|
)
|
||||||
'limit': limit,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
records = [record async for record in result]
|
|
||||||
|
|
||||||
edges = [get_entity_edge_from_record(record) for record in records]
|
edges = [get_entity_edge_from_record(record) for record in records]
|
||||||
|
|
||||||
|
|
@ -201,13 +193,13 @@ async def edge_similarity_search(
|
||||||
WHERE ($group_ids IS NULL OR r.group_id IN $group_ids)
|
WHERE ($group_ids IS NULL OR r.group_id IN $group_ids)
|
||||||
AND ($source_uuid IS NULL OR n.uuid IN [$source_uuid, $target_uuid])
|
AND ($source_uuid IS NULL OR n.uuid IN [$source_uuid, $target_uuid])
|
||||||
AND ($target_uuid IS NULL OR m.uuid IN [$source_uuid, $target_uuid])
|
AND ($target_uuid IS NULL OR m.uuid IN [$source_uuid, $target_uuid])
|
||||||
WITH DISTINCT n, r, m, vector.similarity.cosine(r.fact_embedding, $search_vector) AS score
|
WITH DISTINCT r, vector.similarity.cosine(r.fact_embedding, $search_vector) AS score
|
||||||
WHERE score > $min_score
|
WHERE score > $min_score
|
||||||
RETURN
|
RETURN
|
||||||
r.uuid AS uuid,
|
r.uuid AS uuid,
|
||||||
r.group_id AS group_id,
|
r.group_id AS group_id,
|
||||||
n.uuid AS source_node_uuid,
|
startNode(r).uuid AS source_node_uuid,
|
||||||
m.uuid AS target_node_uuid,
|
endNode(r).uuid AS target_node_uuid,
|
||||||
r.created_at AS created_at,
|
r.created_at AS created_at,
|
||||||
r.name AS name,
|
r.name AS name,
|
||||||
r.fact AS fact,
|
r.fact AS fact,
|
||||||
|
|
@ -220,21 +212,59 @@ async def edge_similarity_search(
|
||||||
LIMIT $limit
|
LIMIT $limit
|
||||||
""")
|
""")
|
||||||
|
|
||||||
async with driver.session(
|
records, _, _ = await driver.execute_query(
|
||||||
database=DEFAULT_DATABASE, default_access_mode=neo4j.READ_ACCESS
|
query,
|
||||||
) as session:
|
search_vector=search_vector,
|
||||||
result = await session.run(
|
source_uuid=source_node_uuid,
|
||||||
query,
|
target_uuid=target_node_uuid,
|
||||||
{
|
group_ids=group_ids,
|
||||||
'search_vector': search_vector,
|
limit=limit,
|
||||||
'source_uuid': source_node_uuid,
|
min_score=min_score,
|
||||||
'target_uuid': target_node_uuid,
|
database_=DEFAULT_DATABASE,
|
||||||
'group_ids': group_ids,
|
routing_='r',
|
||||||
'limit': limit,
|
)
|
||||||
'min_score': min_score,
|
|
||||||
},
|
edges = [get_entity_edge_from_record(record) for record in records]
|
||||||
)
|
|
||||||
records = [record async for record in result]
|
return edges
|
||||||
|
|
||||||
|
|
||||||
|
async def edge_bfs_search(
|
||||||
|
driver: AsyncDriver,
|
||||||
|
bfs_origin_node_uuids: list[str] | None,
|
||||||
|
bfs_max_depth: int,
|
||||||
|
) -> list[EntityEdge]:
|
||||||
|
# vector similarity search over embedded facts
|
||||||
|
if bfs_origin_node_uuids is None:
|
||||||
|
return []
|
||||||
|
|
||||||
|
query = Query("""
|
||||||
|
UNWIND $bfs_origin_node_uuids AS origin_uuid
|
||||||
|
MATCH path = (origin:Entity|Episodic {uuid: origin_uuid})-[:RELATES_TO|MENTIONS]->{1,3}(n:Entity)
|
||||||
|
UNWIND relationships(path) AS rel
|
||||||
|
MATCH ()-[r:RELATES_TO {uuid: rel.uuid}]-()
|
||||||
|
RETURN DISTINCT
|
||||||
|
r.uuid AS uuid,
|
||||||
|
r.group_id AS group_id,
|
||||||
|
startNode(r).uuid AS source_node_uuid,
|
||||||
|
endNode(r).uuid AS target_node_uuid,
|
||||||
|
r.created_at AS created_at,
|
||||||
|
r.name AS name,
|
||||||
|
r.fact AS fact,
|
||||||
|
r.fact_embedding AS fact_embedding,
|
||||||
|
r.episodes AS episodes,
|
||||||
|
r.expired_at AS expired_at,
|
||||||
|
r.valid_at AS valid_at,
|
||||||
|
r.invalid_at AS invalid_at
|
||||||
|
""")
|
||||||
|
|
||||||
|
records, _, _ = await driver.execute_query(
|
||||||
|
query,
|
||||||
|
bfs_origin_node_uuids=bfs_origin_node_uuids,
|
||||||
|
depth=bfs_max_depth,
|
||||||
|
database_=DEFAULT_DATABASE,
|
||||||
|
routing_='r',
|
||||||
|
)
|
||||||
|
|
||||||
edges = [get_entity_edge_from_record(record) for record in records]
|
edges = [get_entity_edge_from_record(record) for record in records]
|
||||||
|
|
||||||
|
|
@ -252,30 +282,26 @@ async def node_fulltext_search(
|
||||||
if fuzzy_query == '':
|
if fuzzy_query == '':
|
||||||
return []
|
return []
|
||||||
|
|
||||||
async with driver.session(
|
records, _, _ = await driver.execute_query(
|
||||||
database=DEFAULT_DATABASE, default_access_mode=neo4j.READ_ACCESS
|
"""
|
||||||
) as session:
|
CALL db.index.fulltext.queryNodes("node_name_and_summary", $query)
|
||||||
result = await session.run(
|
YIELD node AS n, score
|
||||||
"""
|
RETURN
|
||||||
CALL db.index.fulltext.queryNodes("node_name_and_summary", $query)
|
n.uuid AS uuid,
|
||||||
YIELD node AS n, score
|
n.group_id AS group_id,
|
||||||
RETURN
|
n.name AS name,
|
||||||
n.uuid AS uuid,
|
n.name_embedding AS name_embedding,
|
||||||
n.group_id AS group_id,
|
n.created_at AS created_at,
|
||||||
n.name AS name,
|
n.summary AS summary
|
||||||
n.name_embedding AS name_embedding,
|
ORDER BY score DESC
|
||||||
n.created_at AS created_at,
|
LIMIT $limit
|
||||||
n.summary AS summary
|
""",
|
||||||
ORDER BY score DESC
|
query=fuzzy_query,
|
||||||
LIMIT $limit
|
group_ids=group_ids,
|
||||||
""",
|
limit=limit,
|
||||||
{
|
database_=DEFAULT_DATABASE,
|
||||||
'query': fuzzy_query,
|
routing_='r',
|
||||||
'group_ids': group_ids,
|
)
|
||||||
'limit': limit,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
records = [record async for record in result]
|
|
||||||
nodes = [get_entity_node_from_record(record) for record in records]
|
nodes = [get_entity_node_from_record(record) for record in records]
|
||||||
|
|
||||||
return nodes
|
return nodes
|
||||||
|
|
@ -289,34 +315,62 @@ async def node_similarity_search(
|
||||||
min_score: float = DEFAULT_MIN_SCORE,
|
min_score: float = DEFAULT_MIN_SCORE,
|
||||||
) -> list[EntityNode]:
|
) -> list[EntityNode]:
|
||||||
# vector similarity search over entity names
|
# vector similarity search over entity names
|
||||||
async with driver.session(
|
records, _, _ = await driver.execute_query(
|
||||||
database=DEFAULT_DATABASE, default_access_mode=neo4j.READ_ACCESS
|
"""
|
||||||
) as session:
|
CYPHER runtime = parallel parallelRuntimeSupport=all
|
||||||
result = await session.run(
|
MATCH (n:Entity)
|
||||||
"""
|
WHERE $group_ids IS NULL OR n.group_id IN $group_ids
|
||||||
CYPHER runtime = parallel parallelRuntimeSupport=all
|
WITH n, vector.similarity.cosine(n.name_embedding, $search_vector) AS score
|
||||||
MATCH (n:Entity)
|
WHERE score > $min_score
|
||||||
WHERE $group_ids IS NULL OR n.group_id IN $group_ids
|
RETURN
|
||||||
WITH n, vector.similarity.cosine(n.name_embedding, $search_vector) AS score
|
n.uuid As uuid,
|
||||||
WHERE score > $min_score
|
n.group_id AS group_id,
|
||||||
RETURN
|
n.name AS name,
|
||||||
n.uuid As uuid,
|
n.name_embedding AS name_embedding,
|
||||||
n.group_id AS group_id,
|
n.created_at AS created_at,
|
||||||
n.name AS name,
|
n.summary AS summary
|
||||||
n.name_embedding AS name_embedding,
|
ORDER BY score DESC
|
||||||
n.created_at AS created_at,
|
LIMIT $limit
|
||||||
n.summary AS summary
|
""",
|
||||||
ORDER BY score DESC
|
search_vector=search_vector,
|
||||||
LIMIT $limit
|
group_ids=group_ids,
|
||||||
""",
|
limit=limit,
|
||||||
{
|
min_score=min_score,
|
||||||
'search_vector': search_vector,
|
database_=DEFAULT_DATABASE,
|
||||||
'group_ids': group_ids,
|
routing_='r',
|
||||||
'limit': limit,
|
)
|
||||||
'min_score': min_score,
|
nodes = [get_entity_node_from_record(record) for record in records]
|
||||||
},
|
|
||||||
)
|
return nodes
|
||||||
records = [record async for record in result]
|
|
||||||
|
|
||||||
|
async def node_bfs_search(
|
||||||
|
driver: AsyncDriver,
|
||||||
|
bfs_origin_node_uuids: list[str] | None,
|
||||||
|
bfs_max_depth: int,
|
||||||
|
) -> list[EntityNode]:
|
||||||
|
# vector similarity search over entity names
|
||||||
|
if bfs_origin_node_uuids is None:
|
||||||
|
return []
|
||||||
|
|
||||||
|
records, _, _ = await driver.execute_query(
|
||||||
|
"""
|
||||||
|
UNWIND $bfs_origin_node_uuids AS origin_uuid
|
||||||
|
MATCH (origin:Entity|Episodic {uuid: origin_uuid})-[:RELATES_TO|MENTIONS]->{1,3}(n:Entity)
|
||||||
|
RETURN DISTINCT
|
||||||
|
n.uuid As uuid,
|
||||||
|
n.group_id AS group_id,
|
||||||
|
n.name AS name,
|
||||||
|
n.name_embedding AS name_embedding,
|
||||||
|
n.created_at AS created_at,
|
||||||
|
n.summary AS summary
|
||||||
|
LIMIT $limit
|
||||||
|
""",
|
||||||
|
bfs_origin_node_uuids=bfs_origin_node_uuids,
|
||||||
|
depth=bfs_max_depth,
|
||||||
|
database_=DEFAULT_DATABASE,
|
||||||
|
routing_='r',
|
||||||
|
)
|
||||||
nodes = [get_entity_node_from_record(record) for record in records]
|
nodes = [get_entity_node_from_record(record) for record in records]
|
||||||
|
|
||||||
return nodes
|
return nodes
|
||||||
|
|
@ -333,30 +387,26 @@ async def community_fulltext_search(
|
||||||
if fuzzy_query == '':
|
if fuzzy_query == '':
|
||||||
return []
|
return []
|
||||||
|
|
||||||
async with driver.session(
|
records, _, _ = await driver.execute_query(
|
||||||
database=DEFAULT_DATABASE, default_access_mode=neo4j.READ_ACCESS
|
"""
|
||||||
) as session:
|
CALL db.index.fulltext.queryNodes("community_name", $query)
|
||||||
result = await session.run(
|
YIELD node AS comm, score
|
||||||
"""
|
RETURN
|
||||||
CALL db.index.fulltext.queryNodes("community_name", $query)
|
comm.uuid AS uuid,
|
||||||
YIELD node AS comm, score
|
comm.group_id AS group_id,
|
||||||
RETURN
|
comm.name AS name,
|
||||||
comm.uuid AS uuid,
|
comm.name_embedding AS name_embedding,
|
||||||
comm.group_id AS group_id,
|
comm.created_at AS created_at,
|
||||||
comm.name AS name,
|
comm.summary AS summary
|
||||||
comm.name_embedding AS name_embedding,
|
ORDER BY score DESC
|
||||||
comm.created_at AS created_at,
|
LIMIT $limit
|
||||||
comm.summary AS summary
|
""",
|
||||||
ORDER BY score DESC
|
query=fuzzy_query,
|
||||||
LIMIT $limit
|
group_ids=group_ids,
|
||||||
""",
|
limit=limit,
|
||||||
{
|
database_=DEFAULT_DATABASE,
|
||||||
'query': fuzzy_query,
|
routing_='r',
|
||||||
'group_ids': group_ids,
|
)
|
||||||
'limit': limit,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
records = [record async for record in result]
|
|
||||||
communities = [get_community_node_from_record(record) for record in records]
|
communities = [get_community_node_from_record(record) for record in records]
|
||||||
|
|
||||||
return communities
|
return communities
|
||||||
|
|
@ -370,34 +420,30 @@ async def community_similarity_search(
|
||||||
min_score=DEFAULT_MIN_SCORE,
|
min_score=DEFAULT_MIN_SCORE,
|
||||||
) -> list[CommunityNode]:
|
) -> list[CommunityNode]:
|
||||||
# vector similarity search over entity names
|
# vector similarity search over entity names
|
||||||
async with driver.session(
|
records, _, _ = await driver.execute_query(
|
||||||
database=DEFAULT_DATABASE, default_access_mode=neo4j.READ_ACCESS
|
"""
|
||||||
) as session:
|
CYPHER runtime = parallel parallelRuntimeSupport=all
|
||||||
result = await session.run(
|
MATCH (comm:Community)
|
||||||
"""
|
WHERE ($group_ids IS NULL OR comm.group_id IN $group_ids)
|
||||||
CYPHER runtime = parallel parallelRuntimeSupport=all
|
WITH comm, vector.similarity.cosine(comm.name_embedding, $search_vector) AS score
|
||||||
MATCH (comm:Community)
|
WHERE score > $min_score
|
||||||
WHERE ($group_ids IS NULL OR comm.group_id IN $group_ids)
|
RETURN
|
||||||
WITH comm, vector.similarity.cosine(comm.name_embedding, $search_vector) AS score
|
comm.uuid As uuid,
|
||||||
WHERE score > $min_score
|
comm.group_id AS group_id,
|
||||||
RETURN
|
comm.name AS name,
|
||||||
comm.uuid As uuid,
|
comm.name_embedding AS name_embedding,
|
||||||
comm.group_id AS group_id,
|
comm.created_at AS created_at,
|
||||||
comm.name AS name,
|
comm.summary AS summary
|
||||||
comm.name_embedding AS name_embedding,
|
ORDER BY score DESC
|
||||||
comm.created_at AS created_at,
|
LIMIT $limit
|
||||||
comm.summary AS summary
|
""",
|
||||||
ORDER BY score DESC
|
search_vector=search_vector,
|
||||||
LIMIT $limit
|
group_ids=group_ids,
|
||||||
""",
|
limit=limit,
|
||||||
{
|
min_score=min_score,
|
||||||
'search_vector': search_vector,
|
database_=DEFAULT_DATABASE,
|
||||||
'group_ids': group_ids,
|
routing_='r',
|
||||||
'limit': limit,
|
)
|
||||||
'min_score': min_score,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
records = [record async for record in result]
|
|
||||||
communities = [get_community_node_from_record(record) for record in records]
|
communities = [get_community_node_from_record(record) for record in records]
|
||||||
|
|
||||||
return communities
|
return communities
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,6 @@
|
||||||
[tool.poetry]
|
[tool.poetry]
|
||||||
name = "graphiti-core"
|
name = "graphiti-core"
|
||||||
version = "0.3.16"
|
version = "0.3.17"
|
||||||
description = "A temporal graph building library"
|
description = "A temporal graph building library"
|
||||||
authors = [
|
authors = [
|
||||||
"Paul Paliychuk <paul@getzep.com>",
|
"Paul Paliychuk <paul@getzep.com>",
|
||||||
|
|
|
||||||
|
|
@ -26,7 +26,9 @@ from dotenv import load_dotenv
|
||||||
from graphiti_core.edges import EntityEdge, EpisodicEdge
|
from graphiti_core.edges import EntityEdge, EpisodicEdge
|
||||||
from graphiti_core.graphiti import Graphiti
|
from graphiti_core.graphiti import Graphiti
|
||||||
from graphiti_core.nodes import EntityNode, EpisodicNode
|
from graphiti_core.nodes import EntityNode, EpisodicNode
|
||||||
from graphiti_core.search.search_config_recipes import COMBINED_HYBRID_SEARCH_RRF
|
from graphiti_core.search.search_config_recipes import (
|
||||||
|
COMBINED_HYBRID_SEARCH_CROSS_ENCODER,
|
||||||
|
)
|
||||||
|
|
||||||
pytestmark = pytest.mark.integration
|
pytestmark = pytest.mark.integration
|
||||||
|
|
||||||
|
|
@ -60,22 +62,19 @@ def setup_logging():
|
||||||
return logger
|
return logger
|
||||||
|
|
||||||
|
|
||||||
def format_context(facts):
|
|
||||||
formatted_string = ''
|
|
||||||
formatted_string += 'FACTS:\n'
|
|
||||||
for fact in facts:
|
|
||||||
formatted_string += f' - {fact}\n'
|
|
||||||
formatted_string += '\n'
|
|
||||||
|
|
||||||
return formatted_string.strip()
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_graphiti_init():
|
async def test_graphiti_init():
|
||||||
logger = setup_logging()
|
logger = setup_logging()
|
||||||
graphiti = Graphiti(NEO4J_URI, NEO4j_USER, NEO4j_PASSWORD)
|
graphiti = Graphiti(NEO4J_URI, NEO4j_USER, NEO4j_PASSWORD)
|
||||||
|
episodes = await graphiti.retrieve_episodes(datetime.now(), group_ids=None)
|
||||||
|
episode_uuids = [episode.uuid for episode in episodes]
|
||||||
|
|
||||||
results = await graphiti._search('new house', COMBINED_HYBRID_SEARCH_RRF, group_ids=None)
|
results = await graphiti._search(
|
||||||
|
"Emily: I can't log in",
|
||||||
|
COMBINED_HYBRID_SEARCH_CROSS_ENCODER,
|
||||||
|
bfs_origin_node_uuids=episode_uuids,
|
||||||
|
group_ids=None,
|
||||||
|
)
|
||||||
pretty_results = {
|
pretty_results = {
|
||||||
'edges': [edge.fact for edge in results.edges],
|
'edges': [edge.fact for edge in results.edges],
|
||||||
'nodes': [node.name for node in results.nodes],
|
'nodes': [node.name for node in results.nodes],
|
||||||
|
|
@ -83,6 +82,7 @@ async def test_graphiti_init():
|
||||||
}
|
}
|
||||||
|
|
||||||
logger.info(pretty_results)
|
logger.info(pretty_results)
|
||||||
|
|
||||||
await graphiti.close()
|
await graphiti.close()
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue