Cross encoder reranker in search query (#202)

* cross encoder reranker

* update reranker

* add openai reranker

* format

* mypy

* update

* updates

* MyPy typing

* bump version
This commit is contained in:
Preston Rasmussen 2024-10-25 12:29:27 -04:00 committed by GitHub
parent 544f9e3fba
commit ceb60a3d33
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
8 changed files with 445 additions and 190 deletions

View file

@ -0,0 +1,113 @@
"""
Copyright 2024, Zep Software, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import asyncio
import logging
import math
from typing import Any

import openai
from openai import AsyncOpenAI
from pydantic import BaseModel

from ..llm_client import LLMConfig, RateLimitError
from ..prompts import Message
from .client import CrossEncoderClient
logger = logging.getLogger(__name__)
# Default chat model used for the boolean relevance-classification calls below.
DEFAULT_MODEL = 'gpt-4o-mini'
class BooleanClassifier(BaseModel):
    """Structured-output schema for a boolean relevance verdict."""

    # NOTE(review): not referenced by OpenAIRerankerClient in this file —
    # presumably intended for a structured-output variant; confirm before removing.
    isTrue: bool
class OpenAIRerankerClient(CrossEncoderClient):
    """Cross-encoder reranker backed by an OpenAI chat model.

    Each passage is classified independently as relevant/irrelevant to the
    query, and the log-probability of the first completion token is converted
    into a relevance score used for ranking.
    """

    def __init__(self, config: LLMConfig | None = None):
        """
        Initialize the OpenAIRerankerClient with the provided configuration.

        Args:
            config (LLMConfig | None): The configuration for the LLM client, including
                API key, model, base URL, temperature, and max tokens.
        """
        if config is None:
            config = LLMConfig()

        self.config = config
        self.client = AsyncOpenAI(api_key=config.api_key, base_url=config.base_url)

    async def rank(self, query: str, passages: list[str]) -> list[tuple[str, float]]:
        """Rank `passages` by relevance to `query`.

        Args:
            query (str): The search query.
            passages (list[str]): Candidate passages to score.

        Returns:
            list[tuple[str, float]]: (passage, score) pairs sorted by descending
            score, where score is the model's estimated probability of relevance.

        Raises:
            RateLimitError: If the OpenAI API rate limit is exceeded.
        """
        openai_messages_list: Any = [
            [
                Message(
                    role='system',
                    content='You are an expert tasked with determining whether the passage is relevant to the query',
                ),
                Message(
                    role='user',
                    # Fixed template: the passage belongs inside <PASSAGE> and the
                    # query inside <QUERY> (the original had them swapped, with the
                    # passage outside any tag and <QUERY> left empty).
                    content=f"""
                           Respond with "True" if PASSAGE is relevant to QUERY and "False" otherwise.
                           <PASSAGE>
                           {passage}
                           </PASSAGE>
                           <QUERY>
                           {query}
                           </QUERY>
                           """,
                ),
            ]
            for passage in passages
        ]
        try:
            responses = await asyncio.gather(
                *[
                    self.client.chat.completions.create(
                        model=DEFAULT_MODEL,
                        messages=openai_messages,
                        temperature=0,
                        max_tokens=1,
                        # Bias the sampler towards the verdict tokens so the single
                        # completion token is "True" or "False".
                        # NOTE(review): assumes token ids 6432/7983 map to
                        # "True"/"False" for this model's tokenizer — verify.
                        logit_bias={'6432': 1, '7983': 1},
                        logprobs=True,
                        top_logprobs=2,
                    )
                    for openai_messages in openai_messages_list
                ]
            )

            responses_top_logprobs = [
                response.choices[0].logprobs.content[0].top_logprobs
                if response.choices[0].logprobs is not None
                and response.choices[0].logprobs.content is not None
                else []
                for response in responses
            ]

            # Produce exactly one score per passage so zip(passages, scores) stays
            # aligned. (The original appended a score for every top_logprobs entry
            # — up to two per response — and gated on bool(logprob.token), which is
            # true for any non-empty token.)
            scores: list[float] = []
            for top_logprobs in responses_top_logprobs:
                if len(top_logprobs) == 0:
                    scores.append(0.0)
                    continue
                top = top_logprobs[0]
                # Convert the top token's logprob into P(relevant): use the
                # probability directly for a "True" verdict, its complement otherwise.
                norm_prob = math.exp(top.logprob)
                scores.append(norm_prob if top.token.strip().lower() == 'true' else 1 - norm_prob)

            results = list(zip(passages, scores))
            results.sort(reverse=True, key=lambda x: x[1])
            return results
        except openai.RateLimitError as e:
            raise RateLimitError from e
        except Exception as e:
            logger.error(f'Error in generating LLM response: {e}')
            raise

View file

@ -23,8 +23,11 @@ from dotenv import load_dotenv
from neo4j import AsyncGraphDatabase
from pydantic import BaseModel
from graphiti_core.cross_encoder.client import CrossEncoderClient
from graphiti_core.cross_encoder.openai_reranker_client import OpenAIRerankerClient
from graphiti_core.edges import EntityEdge, EpisodicEdge
from graphiti_core.embedder import EmbedderClient, OpenAIEmbedder
from graphiti_core.helpers import DEFAULT_DATABASE
from graphiti_core.llm_client import LLMClient, OpenAIClient
from graphiti_core.nodes import CommunityNode, EntityNode, EpisodeType, EpisodicNode
from graphiti_core.search.search import SearchConfig, search
@ -92,6 +95,7 @@ class Graphiti:
password: str,
llm_client: LLMClient | None = None,
embedder: EmbedderClient | None = None,
cross_encoder: CrossEncoderClient | None = None,
store_raw_episode_content: bool = True,
):
"""
@ -131,7 +135,7 @@ class Graphiti:
Graphiti if you're using the default OpenAIClient.
"""
self.driver = AsyncGraphDatabase.driver(uri, auth=(user, password))
self.database = 'neo4j'
self.database = DEFAULT_DATABASE
self.store_raw_episode_content = store_raw_episode_content
if llm_client:
self.llm_client = llm_client
@ -141,6 +145,10 @@ class Graphiti:
self.embedder = embedder
else:
self.embedder = OpenAIEmbedder()
if cross_encoder:
self.cross_encoder = cross_encoder
else:
self.cross_encoder = OpenAIRerankerClient()
async def close(self):
"""
@ -648,6 +656,7 @@ class Graphiti:
await search(
self.driver,
self.embedder,
self.cross_encoder,
query,
group_ids,
search_config,
@ -663,8 +672,18 @@ class Graphiti:
config: SearchConfig,
group_ids: list[str] | None = None,
center_node_uuid: str | None = None,
bfs_origin_node_uuids: list[str] | None = None,
) -> SearchResults:
return await search(self.driver, self.embedder, query, group_ids, config, center_node_uuid)
return await search(
self.driver,
self.embedder,
self.cross_encoder,
query,
group_ids,
config,
center_node_uuid,
bfs_origin_node_uuids,
)
async def get_nodes_by_query(
self,
@ -716,7 +735,13 @@ class Graphiti:
nodes = (
await search(
self.driver, self.embedder, query, group_ids, search_config, center_node_uuid
self.driver,
self.embedder,
self.cross_encoder,
query,
group_ids,
search_config,
center_node_uuid,
)
).nodes
return nodes

View file

@ -21,6 +21,7 @@ from time import time
from neo4j import AsyncDriver
from graphiti_core.cross_encoder.client import CrossEncoderClient
from graphiti_core.edges import EntityEdge
from graphiti_core.embedder import EmbedderClient
from graphiti_core.errors import SearchRerankerError
@ -39,6 +40,7 @@ from graphiti_core.search.search_config import (
from graphiti_core.search.search_utils import (
community_fulltext_search,
community_similarity_search,
edge_bfs_search,
edge_fulltext_search,
edge_similarity_search,
episode_mentions_reranker,
@ -55,10 +57,12 @@ logger = logging.getLogger(__name__)
async def search(
driver: AsyncDriver,
embedder: EmbedderClient,
cross_encoder: CrossEncoderClient,
query: str,
group_ids: list[str] | None,
config: SearchConfig,
center_node_uuid: str | None = None,
bfs_origin_node_uuids: list[str] | None = None,
) -> SearchResults:
start = time()
query_vector = await embedder.create(input=[query.replace('\n', ' ')])
@ -68,28 +72,34 @@ async def search(
edges, nodes, communities = await asyncio.gather(
edge_search(
driver,
cross_encoder,
query,
query_vector,
group_ids,
config.edge_config,
center_node_uuid,
bfs_origin_node_uuids,
config.limit,
),
node_search(
driver,
cross_encoder,
query,
query_vector,
group_ids,
config.node_config,
center_node_uuid,
bfs_origin_node_uuids,
config.limit,
),
community_search(
driver,
cross_encoder,
query,
query_vector,
group_ids,
config.community_config,
bfs_origin_node_uuids,
config.limit,
),
)
@ -109,11 +119,13 @@ async def search(
async def edge_search(
driver: AsyncDriver,
cross_encoder: CrossEncoderClient,
query: str,
query_vector: list[float],
group_ids: list[str] | None,
config: EdgeSearchConfig | None,
center_node_uuid: str | None = None,
bfs_origin_node_uuids: list[str] | None = None,
limit=DEFAULT_SEARCH_LIMIT,
) -> list[EntityEdge]:
if config is None:
@ -126,6 +138,7 @@ async def edge_search(
edge_similarity_search(
driver, query_vector, None, None, group_ids, 2 * limit, config.sim_min_score
),
edge_bfs_search(driver, bfs_origin_node_uuids, config.bfs_max_depth),
]
)
)
@ -146,6 +159,10 @@ async def edge_search(
reranked_uuids = maximal_marginal_relevance(
query_vector, search_result_uuids_and_vectors, config.mmr_lambda
)
elif config.reranker == EdgeReranker.cross_encoder:
fact_to_uuid_map = {edge.fact: edge.uuid for result in search_results for edge in result}
reranked_facts = await cross_encoder.rank(query, list(fact_to_uuid_map.keys()))
reranked_uuids = [fact_to_uuid_map[fact] for fact, _ in reranked_facts]
elif config.reranker == EdgeReranker.node_distance:
if center_node_uuid is None:
raise SearchRerankerError('No center node provided for Node Distance reranker')
@ -176,11 +193,13 @@ async def edge_search(
async def node_search(
driver: AsyncDriver,
cross_encoder: CrossEncoderClient,
query: str,
query_vector: list[float],
group_ids: list[str] | None,
config: NodeSearchConfig | None,
center_node_uuid: str | None = None,
bfs_origin_node_uuids: list[str] | None = None,
limit=DEFAULT_SEARCH_LIMIT,
) -> list[EntityNode]:
if config is None:
@ -212,6 +231,12 @@ async def node_search(
reranked_uuids = maximal_marginal_relevance(
query_vector, search_result_uuids_and_vectors, config.mmr_lambda
)
elif config.reranker == NodeReranker.cross_encoder:
summary_to_uuid_map = {
node.summary: node.uuid for result in search_results for node in result
}
reranked_summaries = await cross_encoder.rank(query, list(summary_to_uuid_map.keys()))
reranked_uuids = [summary_to_uuid_map[fact] for fact, _ in reranked_summaries]
elif config.reranker == NodeReranker.episode_mentions:
reranked_uuids = await episode_mentions_reranker(driver, search_result_uuids)
elif config.reranker == NodeReranker.node_distance:
@ -228,10 +253,12 @@ async def node_search(
async def community_search(
driver: AsyncDriver,
cross_encoder: CrossEncoderClient,
query: str,
query_vector: list[float],
group_ids: list[str] | None,
config: CommunitySearchConfig | None,
bfs_origin_node_uuids: list[str] | None = None,
limit=DEFAULT_SEARCH_LIMIT,
) -> list[CommunityNode]:
if config is None:
@ -268,6 +295,12 @@ async def community_search(
reranked_uuids = maximal_marginal_relevance(
query_vector, search_result_uuids_and_vectors, config.mmr_lambda
)
elif config.reranker == CommunityReranker.cross_encoder:
summary_to_uuid_map = {
node.summary: node.uuid for result in search_results for node in result
}
reranked_summaries = await cross_encoder.rank(query, list(summary_to_uuid_map.keys()))
reranked_uuids = [summary_to_uuid_map[fact] for fact, _ in reranked_summaries]
reranked_communities = [community_uuid_map[uuid] for uuid in reranked_uuids]

View file

@ -20,7 +20,11 @@ from pydantic import BaseModel, Field
from graphiti_core.edges import EntityEdge
from graphiti_core.nodes import CommunityNode, EntityNode
from graphiti_core.search.search_utils import DEFAULT_MIN_SCORE, DEFAULT_MMR_LAMBDA
from graphiti_core.search.search_utils import (
DEFAULT_MIN_SCORE,
DEFAULT_MMR_LAMBDA,
MAX_SEARCH_DEPTH,
)
DEFAULT_SEARCH_LIMIT = 10
@ -28,11 +32,13 @@ DEFAULT_SEARCH_LIMIT = 10
class EdgeSearchMethod(Enum):
cosine_similarity = 'cosine_similarity'
bm25 = 'bm25'
bfs = 'breadth_first_search'
class NodeSearchMethod(Enum):
cosine_similarity = 'cosine_similarity'
bm25 = 'bm25'
bfs = 'breadth_first_search'
class CommunitySearchMethod(Enum):
@ -45,6 +51,7 @@ class EdgeReranker(Enum):
node_distance = 'node_distance'
episode_mentions = 'episode_mentions'
mmr = 'mmr'
cross_encoder = 'cross_encoder'
class NodeReranker(Enum):
@ -52,11 +59,13 @@ class NodeReranker(Enum):
node_distance = 'node_distance'
episode_mentions = 'episode_mentions'
mmr = 'mmr'
cross_encoder = 'cross_encoder'
class CommunityReranker(Enum):
rrf = 'reciprocal_rank_fusion'
mmr = 'mmr'
cross_encoder = 'cross_encoder'
class EdgeSearchConfig(BaseModel):
@ -64,6 +73,7 @@ class EdgeSearchConfig(BaseModel):
reranker: EdgeReranker = Field(default=EdgeReranker.rrf)
sim_min_score: float = Field(default=DEFAULT_MIN_SCORE)
mmr_lambda: float = Field(default=DEFAULT_MMR_LAMBDA)
bfs_max_depth: int = Field(default=MAX_SEARCH_DEPTH)
class NodeSearchConfig(BaseModel):
@ -71,6 +81,7 @@ class NodeSearchConfig(BaseModel):
reranker: NodeReranker = Field(default=NodeReranker.rrf)
sim_min_score: float = Field(default=DEFAULT_MIN_SCORE)
mmr_lambda: float = Field(default=DEFAULT_MMR_LAMBDA)
bfs_max_depth: int = Field(default=MAX_SEARCH_DEPTH)
class CommunitySearchConfig(BaseModel):
@ -78,6 +89,7 @@ class CommunitySearchConfig(BaseModel):
reranker: CommunityReranker = Field(default=CommunityReranker.rrf)
sim_min_score: float = Field(default=DEFAULT_MIN_SCORE)
mmr_lambda: float = Field(default=DEFAULT_MMR_LAMBDA)
bfs_max_depth: int = Field(default=MAX_SEARCH_DEPTH)
class SearchConfig(BaseModel):

View file

@ -48,14 +48,41 @@ COMBINED_HYBRID_SEARCH_MMR = SearchConfig(
edge_config=EdgeSearchConfig(
search_methods=[EdgeSearchMethod.bm25, EdgeSearchMethod.cosine_similarity],
reranker=EdgeReranker.mmr,
mmr_lambda=1,
),
node_config=NodeSearchConfig(
search_methods=[NodeSearchMethod.bm25, NodeSearchMethod.cosine_similarity],
reranker=NodeReranker.mmr,
mmr_lambda=1,
),
community_config=CommunitySearchConfig(
search_methods=[CommunitySearchMethod.bm25, CommunitySearchMethod.cosine_similarity],
reranker=CommunityReranker.mmr,
mmr_lambda=1,
),
)
# Performs a full-text search, similarity search, and bfs with cross_encoder reranking over edges, nodes, and communities
COMBINED_HYBRID_SEARCH_CROSS_ENCODER = SearchConfig(
edge_config=EdgeSearchConfig(
search_methods=[
EdgeSearchMethod.bm25,
EdgeSearchMethod.cosine_similarity,
EdgeSearchMethod.bfs,
],
reranker=EdgeReranker.cross_encoder,
),
node_config=NodeSearchConfig(
search_methods=[
NodeSearchMethod.bm25,
NodeSearchMethod.cosine_similarity,
NodeSearchMethod.bfs,
],
reranker=NodeReranker.cross_encoder,
),
community_config=CommunitySearchConfig(
search_methods=[CommunitySearchMethod.bm25, CommunitySearchMethod.cosine_similarity],
reranker=CommunityReranker.cross_encoder,
),
)
@ -81,7 +108,6 @@ EDGE_HYBRID_SEARCH_NODE_DISTANCE = SearchConfig(
search_methods=[EdgeSearchMethod.bm25, EdgeSearchMethod.cosine_similarity],
reranker=EdgeReranker.node_distance,
),
limit=30,
)
# performs a hybrid search over edges with episode mention reranking

View file

@ -19,7 +19,6 @@ import logging
from collections import defaultdict
from time import time
import neo4j
import numpy as np
from neo4j import AsyncDriver, Query
@ -38,6 +37,7 @@ logger = logging.getLogger(__name__)
RELEVANT_SCHEMA_LIMIT = 3
DEFAULT_MIN_SCORE = 0.6
DEFAULT_MMR_LAMBDA = 0.5
MAX_SEARCH_DEPTH = 3
MAX_QUERY_LENGTH = 128
@ -80,23 +80,21 @@ async def get_mentioned_nodes(
driver: AsyncDriver, episodes: list[EpisodicNode]
) -> list[EntityNode]:
episode_uuids = [episode.uuid for episode in episodes]
async with driver.session(
database=DEFAULT_DATABASE, default_access_mode=neo4j.READ_ACCESS
) as session:
result = await session.run(
"""
MATCH (episode:Episodic)-[:MENTIONS]->(n:Entity) WHERE episode.uuid IN $uuids
RETURN DISTINCT
n.uuid As uuid,
n.group_id AS group_id,
n.name AS name,
n.name_embedding AS name_embedding,
n.created_at AS created_at,
n.summary AS summary
""",
{'uuids': episode_uuids},
)
records = [record async for record in result]
records, _, _ = await driver.execute_query(
"""
MATCH (episode:Episodic)-[:MENTIONS]->(n:Entity) WHERE episode.uuid IN $uuids
RETURN DISTINCT
n.uuid As uuid,
n.group_id AS group_id,
n.name AS name,
n.name_embedding AS name_embedding,
n.created_at AS created_at,
n.summary AS summary
""",
uuids=episode_uuids,
database_=DEFAULT_DATABASE,
routing_='r',
)
nodes = [get_entity_node_from_record(record) for record in records]
@ -107,23 +105,21 @@ async def get_communities_by_nodes(
driver: AsyncDriver, nodes: list[EntityNode]
) -> list[CommunityNode]:
node_uuids = [node.uuid for node in nodes]
async with driver.session(
database=DEFAULT_DATABASE, default_access_mode=neo4j.READ_ACCESS
) as session:
result = await session.run(
"""
MATCH (c:Community)-[:HAS_MEMBER]->(n:Entity) WHERE n.uuid IN $uuids
RETURN DISTINCT
c.uuid As uuid,
c.group_id AS group_id,
c.name AS name,
c.name_embedding AS name_embedding
c.created_at AS created_at,
c.summary AS summary
""",
{'uuids': node_uuids},
)
records = [record async for record in result]
records, _, _ = await driver.execute_query(
"""
MATCH (c:Community)-[:HAS_MEMBER]->(n:Entity) WHERE n.uuid IN $uuids
RETURN DISTINCT
c.uuid As uuid,
c.group_id AS group_id,
c.name AS name,
c.name_embedding AS name_embedding
c.created_at AS created_at,
c.summary AS summary
""",
uuids=node_uuids,
database_=DEFAULT_DATABASE,
routing_='r',
)
communities = [get_community_node_from_record(record) for record in records]
@ -149,7 +145,7 @@ async def edge_fulltext_search(
MATCH (n:Entity)-[r {uuid: rel.uuid}]->(m:Entity)
WHERE ($source_uuid IS NULL OR n.uuid IN [$source_uuid, $target_uuid])
AND ($target_uuid IS NULL OR m.uuid IN [$source_uuid, $target_uuid])
RETURN
RETURN
r.uuid AS uuid,
r.group_id AS group_id,
n.uuid AS source_node_uuid,
@ -165,20 +161,16 @@ async def edge_fulltext_search(
ORDER BY score DESC LIMIT $limit
""")
async with driver.session(
database=DEFAULT_DATABASE, default_access_mode=neo4j.READ_ACCESS
) as session:
result = await session.run(
cypher_query,
{
'query': fuzzy_query,
'source_uuid': source_node_uuid,
'target_uuid': target_node_uuid,
'group_ids': group_ids,
'limit': limit,
},
)
records = [record async for record in result]
records, _, _ = await driver.execute_query(
cypher_query,
query=fuzzy_query,
source_uuid=source_node_uuid,
target_uuid=target_node_uuid,
group_ids=group_ids,
limit=limit,
database_=DEFAULT_DATABASE,
routing_='r',
)
edges = [get_entity_edge_from_record(record) for record in records]
@ -201,13 +193,13 @@ async def edge_similarity_search(
WHERE ($group_ids IS NULL OR r.group_id IN $group_ids)
AND ($source_uuid IS NULL OR n.uuid IN [$source_uuid, $target_uuid])
AND ($target_uuid IS NULL OR m.uuid IN [$source_uuid, $target_uuid])
WITH DISTINCT n, r, m, vector.similarity.cosine(r.fact_embedding, $search_vector) AS score
WITH DISTINCT r, vector.similarity.cosine(r.fact_embedding, $search_vector) AS score
WHERE score > $min_score
RETURN
r.uuid AS uuid,
r.group_id AS group_id,
n.uuid AS source_node_uuid,
m.uuid AS target_node_uuid,
startNode(r).uuid AS source_node_uuid,
endNode(r).uuid AS target_node_uuid,
r.created_at AS created_at,
r.name AS name,
r.fact AS fact,
@ -220,21 +212,59 @@ async def edge_similarity_search(
LIMIT $limit
""")
async with driver.session(
database=DEFAULT_DATABASE, default_access_mode=neo4j.READ_ACCESS
) as session:
result = await session.run(
query,
{
'search_vector': search_vector,
'source_uuid': source_node_uuid,
'target_uuid': target_node_uuid,
'group_ids': group_ids,
'limit': limit,
'min_score': min_score,
},
)
records = [record async for record in result]
records, _, _ = await driver.execute_query(
query,
search_vector=search_vector,
source_uuid=source_node_uuid,
target_uuid=target_node_uuid,
group_ids=group_ids,
limit=limit,
min_score=min_score,
database_=DEFAULT_DATABASE,
routing_='r',
)
edges = [get_entity_edge_from_record(record) for record in records]
return edges
async def edge_bfs_search(
driver: AsyncDriver,
bfs_origin_node_uuids: list[str] | None,
bfs_max_depth: int,
) -> list[EntityEdge]:
# vector similarity search over embedded facts
if bfs_origin_node_uuids is None:
return []
query = Query("""
UNWIND $bfs_origin_node_uuids AS origin_uuid
MATCH path = (origin:Entity|Episodic {uuid: origin_uuid})-[:RELATES_TO|MENTIONS]->{1,3}(n:Entity)
UNWIND relationships(path) AS rel
MATCH ()-[r:RELATES_TO {uuid: rel.uuid}]-()
RETURN DISTINCT
r.uuid AS uuid,
r.group_id AS group_id,
startNode(r).uuid AS source_node_uuid,
endNode(r).uuid AS target_node_uuid,
r.created_at AS created_at,
r.name AS name,
r.fact AS fact,
r.fact_embedding AS fact_embedding,
r.episodes AS episodes,
r.expired_at AS expired_at,
r.valid_at AS valid_at,
r.invalid_at AS invalid_at
""")
records, _, _ = await driver.execute_query(
query,
bfs_origin_node_uuids=bfs_origin_node_uuids,
depth=bfs_max_depth,
database_=DEFAULT_DATABASE,
routing_='r',
)
edges = [get_entity_edge_from_record(record) for record in records]
@ -252,30 +282,26 @@ async def node_fulltext_search(
if fuzzy_query == '':
return []
async with driver.session(
database=DEFAULT_DATABASE, default_access_mode=neo4j.READ_ACCESS
) as session:
result = await session.run(
"""
CALL db.index.fulltext.queryNodes("node_name_and_summary", $query)
YIELD node AS n, score
RETURN
n.uuid AS uuid,
n.group_id AS group_id,
n.name AS name,
n.name_embedding AS name_embedding,
n.created_at AS created_at,
n.summary AS summary
ORDER BY score DESC
LIMIT $limit
""",
{
'query': fuzzy_query,
'group_ids': group_ids,
'limit': limit,
},
)
records = [record async for record in result]
records, _, _ = await driver.execute_query(
"""
CALL db.index.fulltext.queryNodes("node_name_and_summary", $query)
YIELD node AS n, score
RETURN
n.uuid AS uuid,
n.group_id AS group_id,
n.name AS name,
n.name_embedding AS name_embedding,
n.created_at AS created_at,
n.summary AS summary
ORDER BY score DESC
LIMIT $limit
""",
query=fuzzy_query,
group_ids=group_ids,
limit=limit,
database_=DEFAULT_DATABASE,
routing_='r',
)
nodes = [get_entity_node_from_record(record) for record in records]
return nodes
@ -289,34 +315,62 @@ async def node_similarity_search(
min_score: float = DEFAULT_MIN_SCORE,
) -> list[EntityNode]:
# vector similarity search over entity names
async with driver.session(
database=DEFAULT_DATABASE, default_access_mode=neo4j.READ_ACCESS
) as session:
result = await session.run(
"""
CYPHER runtime = parallel parallelRuntimeSupport=all
MATCH (n:Entity)
WHERE $group_ids IS NULL OR n.group_id IN $group_ids
WITH n, vector.similarity.cosine(n.name_embedding, $search_vector) AS score
WHERE score > $min_score
RETURN
n.uuid As uuid,
n.group_id AS group_id,
n.name AS name,
n.name_embedding AS name_embedding,
n.created_at AS created_at,
n.summary AS summary
ORDER BY score DESC
LIMIT $limit
""",
{
'search_vector': search_vector,
'group_ids': group_ids,
'limit': limit,
'min_score': min_score,
},
)
records = [record async for record in result]
records, _, _ = await driver.execute_query(
"""
CYPHER runtime = parallel parallelRuntimeSupport=all
MATCH (n:Entity)
WHERE $group_ids IS NULL OR n.group_id IN $group_ids
WITH n, vector.similarity.cosine(n.name_embedding, $search_vector) AS score
WHERE score > $min_score
RETURN
n.uuid As uuid,
n.group_id AS group_id,
n.name AS name,
n.name_embedding AS name_embedding,
n.created_at AS created_at,
n.summary AS summary
ORDER BY score DESC
LIMIT $limit
""",
search_vector=search_vector,
group_ids=group_ids,
limit=limit,
min_score=min_score,
database_=DEFAULT_DATABASE,
routing_='r',
)
nodes = [get_entity_node_from_record(record) for record in records]
return nodes
async def node_bfs_search(
    driver: AsyncDriver,
    bfs_origin_node_uuids: list[str] | None,
    bfs_max_depth: int,
    limit: int = 2 * RELEVANT_SCHEMA_LIMIT,
) -> list[EntityNode]:
    """Breadth-first search for Entity nodes reachable from the origin nodes.

    Args:
        driver: Neo4j async driver.
        bfs_origin_node_uuids: UUIDs of Entity/Episodic origin nodes; an empty
            list is returned when None.
        bfs_max_depth: Maximum traversal depth. (Previously ignored: it was
            bound as an unused $depth parameter while the pattern hard-coded {1,3}.)
        limit: Maximum number of nodes returned. (Previously the query referenced
            $limit without ever binding it, raising a missing-parameter error.)

    Returns:
        list[EntityNode]: Distinct entity nodes reachable within `bfs_max_depth` hops.
    """
    if bfs_origin_node_uuids is None:
        return []

    # Cypher does not accept parameters inside a relationship quantifier, so the
    # depth is validated to a positive int and interpolated into the pattern.
    depth = max(1, int(bfs_max_depth))
    records, _, _ = await driver.execute_query(
        f"""
        UNWIND $bfs_origin_node_uuids AS origin_uuid
        MATCH (origin:Entity|Episodic {{uuid: origin_uuid}})-[:RELATES_TO|MENTIONS]->{{1,{depth}}}(n:Entity)
        RETURN DISTINCT
            n.uuid As uuid,
            n.group_id AS group_id,
            n.name AS name,
            n.name_embedding AS name_embedding,
            n.created_at AS created_at,
            n.summary AS summary
        LIMIT $limit
        """,
        bfs_origin_node_uuids=bfs_origin_node_uuids,
        limit=limit,
        database_=DEFAULT_DATABASE,
        routing_='r',
    )

    nodes = [get_entity_node_from_record(record) for record in records]
    return nodes
@ -333,30 +387,26 @@ async def community_fulltext_search(
if fuzzy_query == '':
return []
async with driver.session(
database=DEFAULT_DATABASE, default_access_mode=neo4j.READ_ACCESS
) as session:
result = await session.run(
"""
CALL db.index.fulltext.queryNodes("community_name", $query)
YIELD node AS comm, score
RETURN
comm.uuid AS uuid,
comm.group_id AS group_id,
comm.name AS name,
comm.name_embedding AS name_embedding,
comm.created_at AS created_at,
comm.summary AS summary
ORDER BY score DESC
LIMIT $limit
""",
{
'query': fuzzy_query,
'group_ids': group_ids,
'limit': limit,
},
)
records = [record async for record in result]
records, _, _ = await driver.execute_query(
"""
CALL db.index.fulltext.queryNodes("community_name", $query)
YIELD node AS comm, score
RETURN
comm.uuid AS uuid,
comm.group_id AS group_id,
comm.name AS name,
comm.name_embedding AS name_embedding,
comm.created_at AS created_at,
comm.summary AS summary
ORDER BY score DESC
LIMIT $limit
""",
query=fuzzy_query,
group_ids=group_ids,
limit=limit,
database_=DEFAULT_DATABASE,
routing_='r',
)
communities = [get_community_node_from_record(record) for record in records]
return communities
@ -370,34 +420,30 @@ async def community_similarity_search(
min_score=DEFAULT_MIN_SCORE,
) -> list[CommunityNode]:
# vector similarity search over entity names
async with driver.session(
database=DEFAULT_DATABASE, default_access_mode=neo4j.READ_ACCESS
) as session:
result = await session.run(
"""
CYPHER runtime = parallel parallelRuntimeSupport=all
MATCH (comm:Community)
WHERE ($group_ids IS NULL OR comm.group_id IN $group_ids)
WITH comm, vector.similarity.cosine(comm.name_embedding, $search_vector) AS score
WHERE score > $min_score
RETURN
comm.uuid As uuid,
comm.group_id AS group_id,
comm.name AS name,
comm.name_embedding AS name_embedding,
comm.created_at AS created_at,
comm.summary AS summary
ORDER BY score DESC
LIMIT $limit
""",
{
'search_vector': search_vector,
'group_ids': group_ids,
'limit': limit,
'min_score': min_score,
},
)
records = [record async for record in result]
records, _, _ = await driver.execute_query(
"""
CYPHER runtime = parallel parallelRuntimeSupport=all
MATCH (comm:Community)
WHERE ($group_ids IS NULL OR comm.group_id IN $group_ids)
WITH comm, vector.similarity.cosine(comm.name_embedding, $search_vector) AS score
WHERE score > $min_score
RETURN
comm.uuid As uuid,
comm.group_id AS group_id,
comm.name AS name,
comm.name_embedding AS name_embedding,
comm.created_at AS created_at,
comm.summary AS summary
ORDER BY score DESC
LIMIT $limit
""",
search_vector=search_vector,
group_ids=group_ids,
limit=limit,
min_score=min_score,
database_=DEFAULT_DATABASE,
routing_='r',
)
communities = [get_community_node_from_record(record) for record in records]
return communities

View file

@ -1,6 +1,6 @@
[tool.poetry]
name = "graphiti-core"
version = "0.3.16"
version = "0.3.17"
description = "A temporal graph building library"
authors = [
"Paul Paliychuk <paul@getzep.com>",

View file

@ -26,7 +26,9 @@ from dotenv import load_dotenv
from graphiti_core.edges import EntityEdge, EpisodicEdge
from graphiti_core.graphiti import Graphiti
from graphiti_core.nodes import EntityNode, EpisodicNode
from graphiti_core.search.search_config_recipes import COMBINED_HYBRID_SEARCH_RRF
from graphiti_core.search.search_config_recipes import (
COMBINED_HYBRID_SEARCH_CROSS_ENCODER,
)
pytestmark = pytest.mark.integration
@ -60,22 +62,19 @@ def setup_logging():
return logger
def format_context(facts):
formatted_string = ''
formatted_string += 'FACTS:\n'
for fact in facts:
formatted_string += f' - {fact}\n'
formatted_string += '\n'
return formatted_string.strip()
@pytest.mark.asyncio
async def test_graphiti_init():
logger = setup_logging()
graphiti = Graphiti(NEO4J_URI, NEO4j_USER, NEO4j_PASSWORD)
episodes = await graphiti.retrieve_episodes(datetime.now(), group_ids=None)
episode_uuids = [episode.uuid for episode in episodes]
results = await graphiti._search('new house', COMBINED_HYBRID_SEARCH_RRF, group_ids=None)
results = await graphiti._search(
"Emily: I can't log in",
COMBINED_HYBRID_SEARCH_CROSS_ENCODER,
bfs_origin_node_uuids=episode_uuids,
group_ids=None,
)
pretty_results = {
'edges': [edge.fact for edge in results.edges],
'nodes': [node.name for node in results.nodes],
@ -83,6 +82,7 @@ async def test_graphiti_init():
}
logger.info(pretty_results)
await graphiti.close()