From ceb60a3d335a9e14970b43b4c6aadb2a75961aa0 Mon Sep 17 00:00:00 2001
From: Preston Rasmussen <109292228+prasmussen15@users.noreply.github.com>
Date: Fri, 25 Oct 2024 12:29:27 -0400
Subject: [PATCH] Cross encoder reranker in search query (#202)
* cross encoder reranker
* update reranker
* add openai reranker
* format
* mypy
* update
* updates
* MyPy typing
* bump version
---
.../cross_encoder/openai_reranker_client.py | 113 +++++
graphiti_core/graphiti.py | 31 +-
graphiti_core/search/search.py | 33 ++
graphiti_core/search/search_config.py | 14 +-
graphiti_core/search/search_config_recipes.py | 28 +-
graphiti_core/search/search_utils.py | 390 ++++++++++--------
pyproject.toml | 2 +-
tests/test_graphiti_int.py | 24 +-
8 files changed, 445 insertions(+), 190 deletions(-)
create mode 100644 graphiti_core/cross_encoder/openai_reranker_client.py
diff --git a/graphiti_core/cross_encoder/openai_reranker_client.py b/graphiti_core/cross_encoder/openai_reranker_client.py
new file mode 100644
index 00000000..d1e37228
--- /dev/null
+++ b/graphiti_core/cross_encoder/openai_reranker_client.py
@@ -0,0 +1,113 @@
+"""
+Copyright 2024, Zep Software, Inc.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
+import asyncio
+import logging
+import math
+from typing import Any
+
+import openai
+from openai import AsyncOpenAI
+from pydantic import BaseModel
+
+from ..llm_client import LLMConfig, RateLimitError
+from ..prompts import Message
+from .client import CrossEncoderClient
+
+logger = logging.getLogger(__name__)
+
+DEFAULT_MODEL = 'gpt-4o-mini'
+
+
+class BooleanClassifier(BaseModel):
+ isTrue: bool
+
+
+class OpenAIRerankerClient(CrossEncoderClient):
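+    """
+    Reranker that uses an OpenAI chat model as a cross-encoder: each passage is
+    classified as relevant or not relevant to the query, and the passage is scored
+    from the log-probability of that boolean judgment.
+    """
+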
+ def __init__(self, config: LLMConfig | None = None):
+        """
+        Initialize the OpenAIRerankerClient with the provided configuration.
+
+        Args:
+            config (LLMConfig | None): The configuration for the underlying OpenAI client,
+                including API key, model, base URL, temperature, and max tokens.
+                Defaults to a new LLMConfig if not provided.
+        """
+ if config is None:
+ config = LLMConfig()
+
+ self.config = config
+ self.client = AsyncOpenAI(api_key=config.api_key, base_url=config.base_url)
+
+ async def rank(self, query: str, passages: list[str]) -> list[tuple[str, float]]:
+ openai_messages_list: Any = [
+ [
+ Message(
+ role='system',
+ content='You are an expert tasked with determining whether the passage is relevant to the query',
+ ),
+ Message(
+ role='user',
+                content=f"""
+                Respond with "True" if PASSAGE is relevant to QUERY and "False" otherwise.
+                <QUERY>
+                {query}
+                </QUERY>
+                <PASSAGE>
+                {passage}
+                </PASSAGE>
+                """,
+ ),
+ ]
+ for passage in passages
+ ]
+ try:
+ responses = await asyncio.gather(
+ *[
+ self.client.chat.completions.create(
+ model=DEFAULT_MODEL,
+ messages=openai_messages,
+ temperature=0,
+ max_tokens=1,
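+                        # Bias generation toward the boolean answer tokens and read the
+                        # top logprobs instead of the text reply; the token ids below are
+                        # intended to correspond to the "True" and "False" tokens of the
+                        # default model's tokenizer.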
+ logit_bias={'6432': 1, '7983': 1},
+ logprobs=True,
+ top_logprobs=2,
+ )
+ for openai_messages in openai_messages_list
+ ]
+ )
+
+ responses_top_logprobs = [
+ response.choices[0].logprobs.content[0].top_logprobs
+ if response.choices[0].logprobs is not None
+ and response.choices[0].logprobs.content is not None
+ else []
+ for response in responses
+ ]
+            scores: list[float] = []
+            for top_logprobs in responses_top_logprobs:
+                if len(top_logprobs) == 0:
+                    scores.append(0.0)
+                    continue
+                # Convert the top token's logprob to a probability; flip it when the
+                # top token is "False" so that higher scores always mean more relevant.
+                norm_logprob = math.exp(top_logprobs[0].logprob)
+                if top_logprobs[0].token.strip().lower() == 'true':
+                    scores.append(norm_logprob)
+                else:
+                    scores.append(1 - norm_logprob)
+
+            results = list(zip(passages, scores))
+            results.sort(reverse=True, key=lambda x: x[1])
+ return results
+ except openai.RateLimitError as e:
+ raise RateLimitError from e
+ except Exception as e:
+ logger.error(f'Error in generating LLM response: {e}')
+ raise
diff --git a/graphiti_core/graphiti.py b/graphiti_core/graphiti.py
index 9f9ec988..c7829c77 100644
--- a/graphiti_core/graphiti.py
+++ b/graphiti_core/graphiti.py
@@ -23,8 +23,11 @@ from dotenv import load_dotenv
from neo4j import AsyncGraphDatabase
from pydantic import BaseModel
+from graphiti_core.cross_encoder.client import CrossEncoderClient
+from graphiti_core.cross_encoder.openai_reranker_client import OpenAIRerankerClient
from graphiti_core.edges import EntityEdge, EpisodicEdge
from graphiti_core.embedder import EmbedderClient, OpenAIEmbedder
+from graphiti_core.helpers import DEFAULT_DATABASE
from graphiti_core.llm_client import LLMClient, OpenAIClient
from graphiti_core.nodes import CommunityNode, EntityNode, EpisodeType, EpisodicNode
from graphiti_core.search.search import SearchConfig, search
@@ -92,6 +95,7 @@ class Graphiti:
password: str,
llm_client: LLMClient | None = None,
embedder: EmbedderClient | None = None,
+ cross_encoder: CrossEncoderClient | None = None,
store_raw_episode_content: bool = True,
):
"""
@@ -131,7 +135,7 @@ class Graphiti:
Graphiti if you're using the default OpenAIClient.
"""
self.driver = AsyncGraphDatabase.driver(uri, auth=(user, password))
- self.database = 'neo4j'
+ self.database = DEFAULT_DATABASE
self.store_raw_episode_content = store_raw_episode_content
if llm_client:
self.llm_client = llm_client
@@ -141,6 +145,10 @@ class Graphiti:
self.embedder = embedder
else:
self.embedder = OpenAIEmbedder()
+ if cross_encoder:
+ self.cross_encoder = cross_encoder
+ else:
+ self.cross_encoder = OpenAIRerankerClient()
async def close(self):
"""
@@ -648,6 +656,7 @@ class Graphiti:
await search(
self.driver,
self.embedder,
+ self.cross_encoder,
query,
group_ids,
search_config,
@@ -663,8 +672,18 @@ class Graphiti:
config: SearchConfig,
group_ids: list[str] | None = None,
center_node_uuid: str | None = None,
+ bfs_origin_node_uuids: list[str] | None = None,
) -> SearchResults:
- return await search(self.driver, self.embedder, query, group_ids, config, center_node_uuid)
+ return await search(
+ self.driver,
+ self.embedder,
+ self.cross_encoder,
+ query,
+ group_ids,
+ config,
+ center_node_uuid,
+ bfs_origin_node_uuids,
+ )
async def get_nodes_by_query(
self,
@@ -716,7 +735,13 @@ class Graphiti:
nodes = (
await search(
- self.driver, self.embedder, query, group_ids, search_config, center_node_uuid
+ self.driver,
+ self.embedder,
+ self.cross_encoder,
+ query,
+ group_ids,
+ search_config,
+ center_node_uuid,
)
).nodes
return nodes
diff --git a/graphiti_core/search/search.py b/graphiti_core/search/search.py
index 1fa5eb7d..206ed5fd 100644
--- a/graphiti_core/search/search.py
+++ b/graphiti_core/search/search.py
@@ -21,6 +21,7 @@ from time import time
from neo4j import AsyncDriver
+from graphiti_core.cross_encoder.client import CrossEncoderClient
from graphiti_core.edges import EntityEdge
from graphiti_core.embedder import EmbedderClient
from graphiti_core.errors import SearchRerankerError
@@ -39,6 +40,7 @@ from graphiti_core.search.search_config import (
from graphiti_core.search.search_utils import (
community_fulltext_search,
community_similarity_search,
+ edge_bfs_search,
edge_fulltext_search,
edge_similarity_search,
episode_mentions_reranker,
@@ -55,10 +57,12 @@ logger = logging.getLogger(__name__)
async def search(
driver: AsyncDriver,
embedder: EmbedderClient,
+ cross_encoder: CrossEncoderClient,
query: str,
group_ids: list[str] | None,
config: SearchConfig,
center_node_uuid: str | None = None,
+ bfs_origin_node_uuids: list[str] | None = None,
) -> SearchResults:
start = time()
query_vector = await embedder.create(input=[query.replace('\n', ' ')])
@@ -68,28 +72,34 @@ async def search(
edges, nodes, communities = await asyncio.gather(
edge_search(
driver,
+ cross_encoder,
query,
query_vector,
group_ids,
config.edge_config,
center_node_uuid,
+ bfs_origin_node_uuids,
config.limit,
),
node_search(
driver,
+ cross_encoder,
query,
query_vector,
group_ids,
config.node_config,
center_node_uuid,
+ bfs_origin_node_uuids,
config.limit,
),
community_search(
driver,
+ cross_encoder,
query,
query_vector,
group_ids,
config.community_config,
+ bfs_origin_node_uuids,
config.limit,
),
)
@@ -109,11 +119,13 @@ async def search(
async def edge_search(
driver: AsyncDriver,
+ cross_encoder: CrossEncoderClient,
query: str,
query_vector: list[float],
group_ids: list[str] | None,
config: EdgeSearchConfig | None,
center_node_uuid: str | None = None,
+ bfs_origin_node_uuids: list[str] | None = None,
limit=DEFAULT_SEARCH_LIMIT,
) -> list[EntityEdge]:
if config is None:
@@ -126,6 +138,7 @@ async def edge_search(
edge_similarity_search(
driver, query_vector, None, None, group_ids, 2 * limit, config.sim_min_score
),
+ edge_bfs_search(driver, bfs_origin_node_uuids, config.bfs_max_depth),
]
)
)
@@ -146,6 +159,10 @@ async def edge_search(
reranked_uuids = maximal_marginal_relevance(
query_vector, search_result_uuids_and_vectors, config.mmr_lambda
)
+ elif config.reranker == EdgeReranker.cross_encoder:
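+        # Rank candidate edges by their fact text; keying the map on facts also
+        # collapses duplicate edges returned by the different search methods.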
+ fact_to_uuid_map = {edge.fact: edge.uuid for result in search_results for edge in result}
+ reranked_facts = await cross_encoder.rank(query, list(fact_to_uuid_map.keys()))
+ reranked_uuids = [fact_to_uuid_map[fact] for fact, _ in reranked_facts]
elif config.reranker == EdgeReranker.node_distance:
if center_node_uuid is None:
raise SearchRerankerError('No center node provided for Node Distance reranker')
@@ -176,11 +193,13 @@ async def edge_search(
async def node_search(
driver: AsyncDriver,
+ cross_encoder: CrossEncoderClient,
query: str,
query_vector: list[float],
group_ids: list[str] | None,
config: NodeSearchConfig | None,
center_node_uuid: str | None = None,
+ bfs_origin_node_uuids: list[str] | None = None,
limit=DEFAULT_SEARCH_LIMIT,
) -> list[EntityNode]:
if config is None:
@@ -212,6 +231,12 @@ async def node_search(
reranked_uuids = maximal_marginal_relevance(
query_vector, search_result_uuids_and_vectors, config.mmr_lambda
)
+ elif config.reranker == NodeReranker.cross_encoder:
+ summary_to_uuid_map = {
+ node.summary: node.uuid for result in search_results for node in result
+ }
+ reranked_summaries = await cross_encoder.rank(query, list(summary_to_uuid_map.keys()))
+        reranked_uuids = [summary_to_uuid_map[summary] for summary, _ in reranked_summaries]
elif config.reranker == NodeReranker.episode_mentions:
reranked_uuids = await episode_mentions_reranker(driver, search_result_uuids)
elif config.reranker == NodeReranker.node_distance:
@@ -228,10 +253,12 @@ async def node_search(
async def community_search(
driver: AsyncDriver,
+ cross_encoder: CrossEncoderClient,
query: str,
query_vector: list[float],
group_ids: list[str] | None,
config: CommunitySearchConfig | None,
+ bfs_origin_node_uuids: list[str] | None = None,
limit=DEFAULT_SEARCH_LIMIT,
) -> list[CommunityNode]:
if config is None:
@@ -268,6 +295,12 @@ async def community_search(
reranked_uuids = maximal_marginal_relevance(
query_vector, search_result_uuids_and_vectors, config.mmr_lambda
)
+ elif config.reranker == CommunityReranker.cross_encoder:
+ summary_to_uuid_map = {
+ node.summary: node.uuid for result in search_results for node in result
+ }
+ reranked_summaries = await cross_encoder.rank(query, list(summary_to_uuid_map.keys()))
+        reranked_uuids = [summary_to_uuid_map[summary] for summary, _ in reranked_summaries]
reranked_communities = [community_uuid_map[uuid] for uuid in reranked_uuids]
diff --git a/graphiti_core/search/search_config.py b/graphiti_core/search/search_config.py
index badee7c6..9aa23daa 100644
--- a/graphiti_core/search/search_config.py
+++ b/graphiti_core/search/search_config.py
@@ -20,7 +20,11 @@ from pydantic import BaseModel, Field
from graphiti_core.edges import EntityEdge
from graphiti_core.nodes import CommunityNode, EntityNode
-from graphiti_core.search.search_utils import DEFAULT_MIN_SCORE, DEFAULT_MMR_LAMBDA
+from graphiti_core.search.search_utils import (
+ DEFAULT_MIN_SCORE,
+ DEFAULT_MMR_LAMBDA,
+ MAX_SEARCH_DEPTH,
+)
DEFAULT_SEARCH_LIMIT = 10
@@ -28,11 +32,13 @@ DEFAULT_SEARCH_LIMIT = 10
class EdgeSearchMethod(Enum):
cosine_similarity = 'cosine_similarity'
bm25 = 'bm25'
+ bfs = 'breadth_first_search'
class NodeSearchMethod(Enum):
cosine_similarity = 'cosine_similarity'
bm25 = 'bm25'
+ bfs = 'breadth_first_search'
class CommunitySearchMethod(Enum):
@@ -45,6 +51,7 @@ class EdgeReranker(Enum):
node_distance = 'node_distance'
episode_mentions = 'episode_mentions'
mmr = 'mmr'
+ cross_encoder = 'cross_encoder'
class NodeReranker(Enum):
@@ -52,11 +59,13 @@ class NodeReranker(Enum):
node_distance = 'node_distance'
episode_mentions = 'episode_mentions'
mmr = 'mmr'
+ cross_encoder = 'cross_encoder'
class CommunityReranker(Enum):
rrf = 'reciprocal_rank_fusion'
mmr = 'mmr'
+ cross_encoder = 'cross_encoder'
class EdgeSearchConfig(BaseModel):
@@ -64,6 +73,7 @@ class EdgeSearchConfig(BaseModel):
reranker: EdgeReranker = Field(default=EdgeReranker.rrf)
sim_min_score: float = Field(default=DEFAULT_MIN_SCORE)
mmr_lambda: float = Field(default=DEFAULT_MMR_LAMBDA)
+ bfs_max_depth: int = Field(default=MAX_SEARCH_DEPTH)
class NodeSearchConfig(BaseModel):
@@ -71,6 +81,7 @@ class NodeSearchConfig(BaseModel):
reranker: NodeReranker = Field(default=NodeReranker.rrf)
sim_min_score: float = Field(default=DEFAULT_MIN_SCORE)
mmr_lambda: float = Field(default=DEFAULT_MMR_LAMBDA)
+ bfs_max_depth: int = Field(default=MAX_SEARCH_DEPTH)
class CommunitySearchConfig(BaseModel):
@@ -78,6 +89,7 @@ class CommunitySearchConfig(BaseModel):
reranker: CommunityReranker = Field(default=CommunityReranker.rrf)
sim_min_score: float = Field(default=DEFAULT_MIN_SCORE)
mmr_lambda: float = Field(default=DEFAULT_MMR_LAMBDA)
+ bfs_max_depth: int = Field(default=MAX_SEARCH_DEPTH)
class SearchConfig(BaseModel):
diff --git a/graphiti_core/search/search_config_recipes.py b/graphiti_core/search/search_config_recipes.py
index abba4c29..712793db 100644
--- a/graphiti_core/search/search_config_recipes.py
+++ b/graphiti_core/search/search_config_recipes.py
@@ -48,14 +48,41 @@ COMBINED_HYBRID_SEARCH_MMR = SearchConfig(
edge_config=EdgeSearchConfig(
search_methods=[EdgeSearchMethod.bm25, EdgeSearchMethod.cosine_similarity],
reranker=EdgeReranker.mmr,
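+        # Under the standard MMR tradeoff, mmr_lambda=1 scores purely on query
+        # relevance with no diversity penalty.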
+ mmr_lambda=1,
),
node_config=NodeSearchConfig(
search_methods=[NodeSearchMethod.bm25, NodeSearchMethod.cosine_similarity],
reranker=NodeReranker.mmr,
+ mmr_lambda=1,
),
community_config=CommunitySearchConfig(
search_methods=[CommunitySearchMethod.bm25, CommunitySearchMethod.cosine_similarity],
reranker=CommunityReranker.mmr,
+ mmr_lambda=1,
+ ),
+)
+
+# Performs full-text, similarity, and bfs searches with cross_encoder reranking over edges and nodes,
+# and full-text and similarity searches with cross_encoder reranking over communities
+COMBINED_HYBRID_SEARCH_CROSS_ENCODER = SearchConfig(
+ edge_config=EdgeSearchConfig(
+ search_methods=[
+ EdgeSearchMethod.bm25,
+ EdgeSearchMethod.cosine_similarity,
+ EdgeSearchMethod.bfs,
+ ],
+ reranker=EdgeReranker.cross_encoder,
+ ),
+ node_config=NodeSearchConfig(
+ search_methods=[
+ NodeSearchMethod.bm25,
+ NodeSearchMethod.cosine_similarity,
+ NodeSearchMethod.bfs,
+ ],
+ reranker=NodeReranker.cross_encoder,
+ ),
+ community_config=CommunitySearchConfig(
+ search_methods=[CommunitySearchMethod.bm25, CommunitySearchMethod.cosine_similarity],
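+        # CommunitySearchMethod has no bfs variant, so communities rely on
+        # full-text and similarity search only.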
+ reranker=CommunityReranker.cross_encoder,
),
)
@@ -81,7 +108,6 @@ EDGE_HYBRID_SEARCH_NODE_DISTANCE = SearchConfig(
search_methods=[EdgeSearchMethod.bm25, EdgeSearchMethod.cosine_similarity],
reranker=EdgeReranker.node_distance,
),
- limit=30,
)
# performs a hybrid search over edges with episode mention reranking
diff --git a/graphiti_core/search/search_utils.py b/graphiti_core/search/search_utils.py
index 57d8b58b..dd963e04 100644
--- a/graphiti_core/search/search_utils.py
+++ b/graphiti_core/search/search_utils.py
@@ -19,7 +19,6 @@ import logging
from collections import defaultdict
from time import time
-import neo4j
import numpy as np
from neo4j import AsyncDriver, Query
@@ -38,6 +37,7 @@ logger = logging.getLogger(__name__)
RELEVANT_SCHEMA_LIMIT = 3
DEFAULT_MIN_SCORE = 0.6
DEFAULT_MMR_LAMBDA = 0.5
+MAX_SEARCH_DEPTH = 3
MAX_QUERY_LENGTH = 128
@@ -80,23 +80,21 @@ async def get_mentioned_nodes(
driver: AsyncDriver, episodes: list[EpisodicNode]
) -> list[EntityNode]:
episode_uuids = [episode.uuid for episode in episodes]
- async with driver.session(
- database=DEFAULT_DATABASE, default_access_mode=neo4j.READ_ACCESS
- ) as session:
- result = await session.run(
- """
- MATCH (episode:Episodic)-[:MENTIONS]->(n:Entity) WHERE episode.uuid IN $uuids
- RETURN DISTINCT
- n.uuid As uuid,
- n.group_id AS group_id,
- n.name AS name,
- n.name_embedding AS name_embedding,
- n.created_at AS created_at,
- n.summary AS summary
- """,
- {'uuids': episode_uuids},
- )
- records = [record async for record in result]
+ records, _, _ = await driver.execute_query(
+ """
+ MATCH (episode:Episodic)-[:MENTIONS]->(n:Entity) WHERE episode.uuid IN $uuids
+ RETURN DISTINCT
+ n.uuid As uuid,
+ n.group_id AS group_id,
+ n.name AS name,
+ n.name_embedding AS name_embedding,
+ n.created_at AS created_at,
+ n.summary AS summary
+ """,
+ uuids=episode_uuids,
+ database_=DEFAULT_DATABASE,
+ routing_='r',
+ )
nodes = [get_entity_node_from_record(record) for record in records]
@@ -107,23 +105,21 @@ async def get_communities_by_nodes(
driver: AsyncDriver, nodes: list[EntityNode]
) -> list[CommunityNode]:
node_uuids = [node.uuid for node in nodes]
- async with driver.session(
- database=DEFAULT_DATABASE, default_access_mode=neo4j.READ_ACCESS
- ) as session:
- result = await session.run(
- """
- MATCH (c:Community)-[:HAS_MEMBER]->(n:Entity) WHERE n.uuid IN $uuids
- RETURN DISTINCT
- c.uuid As uuid,
- c.group_id AS group_id,
- c.name AS name,
- c.name_embedding AS name_embedding
- c.created_at AS created_at,
- c.summary AS summary
- """,
- {'uuids': node_uuids},
- )
- records = [record async for record in result]
+ records, _, _ = await driver.execute_query(
+ """
+ MATCH (c:Community)-[:HAS_MEMBER]->(n:Entity) WHERE n.uuid IN $uuids
+ RETURN DISTINCT
+ c.uuid As uuid,
+ c.group_id AS group_id,
+ c.name AS name,
+        c.name_embedding AS name_embedding,
+ c.created_at AS created_at,
+ c.summary AS summary
+ """,
+ uuids=node_uuids,
+ database_=DEFAULT_DATABASE,
+ routing_='r',
+ )
communities = [get_community_node_from_record(record) for record in records]
@@ -149,7 +145,7 @@ async def edge_fulltext_search(
MATCH (n:Entity)-[r {uuid: rel.uuid}]->(m:Entity)
WHERE ($source_uuid IS NULL OR n.uuid IN [$source_uuid, $target_uuid])
AND ($target_uuid IS NULL OR m.uuid IN [$source_uuid, $target_uuid])
- RETURN
+ RETURN
r.uuid AS uuid,
r.group_id AS group_id,
n.uuid AS source_node_uuid,
@@ -165,20 +161,16 @@ async def edge_fulltext_search(
ORDER BY score DESC LIMIT $limit
""")
- async with driver.session(
- database=DEFAULT_DATABASE, default_access_mode=neo4j.READ_ACCESS
- ) as session:
- result = await session.run(
- cypher_query,
- {
- 'query': fuzzy_query,
- 'source_uuid': source_node_uuid,
- 'target_uuid': target_node_uuid,
- 'group_ids': group_ids,
- 'limit': limit,
- },
- )
- records = [record async for record in result]
+ records, _, _ = await driver.execute_query(
+ cypher_query,
+ query=fuzzy_query,
+ source_uuid=source_node_uuid,
+ target_uuid=target_node_uuid,
+ group_ids=group_ids,
+ limit=limit,
+ database_=DEFAULT_DATABASE,
+ routing_='r',
+ )
edges = [get_entity_edge_from_record(record) for record in records]
@@ -201,13 +193,13 @@ async def edge_similarity_search(
WHERE ($group_ids IS NULL OR r.group_id IN $group_ids)
AND ($source_uuid IS NULL OR n.uuid IN [$source_uuid, $target_uuid])
AND ($target_uuid IS NULL OR m.uuid IN [$source_uuid, $target_uuid])
- WITH DISTINCT n, r, m, vector.similarity.cosine(r.fact_embedding, $search_vector) AS score
+ WITH DISTINCT r, vector.similarity.cosine(r.fact_embedding, $search_vector) AS score
WHERE score > $min_score
RETURN
r.uuid AS uuid,
r.group_id AS group_id,
- n.uuid AS source_node_uuid,
- m.uuid AS target_node_uuid,
+ startNode(r).uuid AS source_node_uuid,
+ endNode(r).uuid AS target_node_uuid,
r.created_at AS created_at,
r.name AS name,
r.fact AS fact,
@@ -220,21 +212,59 @@ async def edge_similarity_search(
LIMIT $limit
""")
- async with driver.session(
- database=DEFAULT_DATABASE, default_access_mode=neo4j.READ_ACCESS
- ) as session:
- result = await session.run(
- query,
- {
- 'search_vector': search_vector,
- 'source_uuid': source_node_uuid,
- 'target_uuid': target_node_uuid,
- 'group_ids': group_ids,
- 'limit': limit,
- 'min_score': min_score,
- },
- )
- records = [record async for record in result]
+ records, _, _ = await driver.execute_query(
+ query,
+ search_vector=search_vector,
+ source_uuid=source_node_uuid,
+ target_uuid=target_node_uuid,
+ group_ids=group_ids,
+ limit=limit,
+ min_score=min_score,
+ database_=DEFAULT_DATABASE,
+ routing_='r',
+ )
+
+ edges = [get_entity_edge_from_record(record) for record in records]
+
+ return edges
+
+
+async def edge_bfs_search(
+ driver: AsyncDriver,
+ bfs_origin_node_uuids: list[str] | None,
+ bfs_max_depth: int,
+) -> list[EntityEdge]:
+    # breadth-first search over edges reachable from the origin nodes
+ if bfs_origin_node_uuids is None:
+ return []
+
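+    # NOTE: the hop range in the pattern below is fixed at 1-3; the depth parameter
+    # passed to the query is not referenced by it.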
+ query = Query("""
+ UNWIND $bfs_origin_node_uuids AS origin_uuid
+ MATCH path = (origin:Entity|Episodic {uuid: origin_uuid})-[:RELATES_TO|MENTIONS]->{1,3}(n:Entity)
+ UNWIND relationships(path) AS rel
+ MATCH ()-[r:RELATES_TO {uuid: rel.uuid}]-()
+ RETURN DISTINCT
+ r.uuid AS uuid,
+ r.group_id AS group_id,
+ startNode(r).uuid AS source_node_uuid,
+ endNode(r).uuid AS target_node_uuid,
+ r.created_at AS created_at,
+ r.name AS name,
+ r.fact AS fact,
+ r.fact_embedding AS fact_embedding,
+ r.episodes AS episodes,
+ r.expired_at AS expired_at,
+ r.valid_at AS valid_at,
+ r.invalid_at AS invalid_at
+ """)
+
+ records, _, _ = await driver.execute_query(
+ query,
+ bfs_origin_node_uuids=bfs_origin_node_uuids,
+ depth=bfs_max_depth,
+ database_=DEFAULT_DATABASE,
+ routing_='r',
+ )
edges = [get_entity_edge_from_record(record) for record in records]
@@ -252,30 +282,26 @@ async def node_fulltext_search(
if fuzzy_query == '':
return []
- async with driver.session(
- database=DEFAULT_DATABASE, default_access_mode=neo4j.READ_ACCESS
- ) as session:
- result = await session.run(
- """
- CALL db.index.fulltext.queryNodes("node_name_and_summary", $query)
- YIELD node AS n, score
- RETURN
- n.uuid AS uuid,
- n.group_id AS group_id,
- n.name AS name,
- n.name_embedding AS name_embedding,
- n.created_at AS created_at,
- n.summary AS summary
- ORDER BY score DESC
- LIMIT $limit
- """,
- {
- 'query': fuzzy_query,
- 'group_ids': group_ids,
- 'limit': limit,
- },
- )
- records = [record async for record in result]
+ records, _, _ = await driver.execute_query(
+ """
+ CALL db.index.fulltext.queryNodes("node_name_and_summary", $query)
+ YIELD node AS n, score
+ RETURN
+ n.uuid AS uuid,
+ n.group_id AS group_id,
+ n.name AS name,
+ n.name_embedding AS name_embedding,
+ n.created_at AS created_at,
+ n.summary AS summary
+ ORDER BY score DESC
+ LIMIT $limit
+ """,
+ query=fuzzy_query,
+ group_ids=group_ids,
+ limit=limit,
+ database_=DEFAULT_DATABASE,
+ routing_='r',
+ )
nodes = [get_entity_node_from_record(record) for record in records]
return nodes
@@ -289,34 +315,62 @@ async def node_similarity_search(
min_score: float = DEFAULT_MIN_SCORE,
) -> list[EntityNode]:
# vector similarity search over entity names
- async with driver.session(
- database=DEFAULT_DATABASE, default_access_mode=neo4j.READ_ACCESS
- ) as session:
- result = await session.run(
- """
- CYPHER runtime = parallel parallelRuntimeSupport=all
- MATCH (n:Entity)
- WHERE $group_ids IS NULL OR n.group_id IN $group_ids
- WITH n, vector.similarity.cosine(n.name_embedding, $search_vector) AS score
- WHERE score > $min_score
- RETURN
- n.uuid As uuid,
- n.group_id AS group_id,
- n.name AS name,
- n.name_embedding AS name_embedding,
- n.created_at AS created_at,
- n.summary AS summary
- ORDER BY score DESC
- LIMIT $limit
- """,
- {
- 'search_vector': search_vector,
- 'group_ids': group_ids,
- 'limit': limit,
- 'min_score': min_score,
- },
- )
- records = [record async for record in result]
+ records, _, _ = await driver.execute_query(
+ """
+ CYPHER runtime = parallel parallelRuntimeSupport=all
+ MATCH (n:Entity)
+ WHERE $group_ids IS NULL OR n.group_id IN $group_ids
+ WITH n, vector.similarity.cosine(n.name_embedding, $search_vector) AS score
+ WHERE score > $min_score
+ RETURN
+ n.uuid As uuid,
+ n.group_id AS group_id,
+ n.name AS name,
+ n.name_embedding AS name_embedding,
+ n.created_at AS created_at,
+ n.summary AS summary
+ ORDER BY score DESC
+ LIMIT $limit
+ """,
+ search_vector=search_vector,
+ group_ids=group_ids,
+ limit=limit,
+ min_score=min_score,
+ database_=DEFAULT_DATABASE,
+ routing_='r',
+ )
+ nodes = [get_entity_node_from_record(record) for record in records]
+
+ return nodes
+
+
+async def node_bfs_search(
+    driver: AsyncDriver,
+    bfs_origin_node_uuids: list[str] | None,
+    bfs_max_depth: int,
+    limit: int,
+) -> list[EntityNode]:
+    # breadth-first search over nodes reachable from the origin nodes
+    if bfs_origin_node_uuids is None:
+        return []
+
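+    # NOTE: as in edge_bfs_search, the hop range below is fixed at 1-3 and the
+    # depth parameter passed to the query is not referenced by it.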
+ records, _, _ = await driver.execute_query(
+ """
+ UNWIND $bfs_origin_node_uuids AS origin_uuid
+ MATCH (origin:Entity|Episodic {uuid: origin_uuid})-[:RELATES_TO|MENTIONS]->{1,3}(n:Entity)
+ RETURN DISTINCT
+ n.uuid As uuid,
+ n.group_id AS group_id,
+ n.name AS name,
+ n.name_embedding AS name_embedding,
+ n.created_at AS created_at,
+ n.summary AS summary
+ LIMIT $limit
+ """,
+        bfs_origin_node_uuids=bfs_origin_node_uuids,
+        depth=bfs_max_depth,
+        limit=limit,
+        database_=DEFAULT_DATABASE,
+        routing_='r',
+ )
nodes = [get_entity_node_from_record(record) for record in records]
return nodes
@@ -333,30 +387,26 @@ async def community_fulltext_search(
if fuzzy_query == '':
return []
- async with driver.session(
- database=DEFAULT_DATABASE, default_access_mode=neo4j.READ_ACCESS
- ) as session:
- result = await session.run(
- """
- CALL db.index.fulltext.queryNodes("community_name", $query)
- YIELD node AS comm, score
- RETURN
- comm.uuid AS uuid,
- comm.group_id AS group_id,
- comm.name AS name,
- comm.name_embedding AS name_embedding,
- comm.created_at AS created_at,
- comm.summary AS summary
- ORDER BY score DESC
- LIMIT $limit
- """,
- {
- 'query': fuzzy_query,
- 'group_ids': group_ids,
- 'limit': limit,
- },
- )
- records = [record async for record in result]
+ records, _, _ = await driver.execute_query(
+ """
+ CALL db.index.fulltext.queryNodes("community_name", $query)
+ YIELD node AS comm, score
+ RETURN
+ comm.uuid AS uuid,
+ comm.group_id AS group_id,
+ comm.name AS name,
+ comm.name_embedding AS name_embedding,
+ comm.created_at AS created_at,
+ comm.summary AS summary
+ ORDER BY score DESC
+ LIMIT $limit
+ """,
+ query=fuzzy_query,
+ group_ids=group_ids,
+ limit=limit,
+ database_=DEFAULT_DATABASE,
+ routing_='r',
+ )
communities = [get_community_node_from_record(record) for record in records]
return communities
@@ -370,34 +420,30 @@ async def community_similarity_search(
min_score=DEFAULT_MIN_SCORE,
) -> list[CommunityNode]:
# vector similarity search over entity names
- async with driver.session(
- database=DEFAULT_DATABASE, default_access_mode=neo4j.READ_ACCESS
- ) as session:
- result = await session.run(
- """
- CYPHER runtime = parallel parallelRuntimeSupport=all
- MATCH (comm:Community)
- WHERE ($group_ids IS NULL OR comm.group_id IN $group_ids)
- WITH comm, vector.similarity.cosine(comm.name_embedding, $search_vector) AS score
- WHERE score > $min_score
- RETURN
- comm.uuid As uuid,
- comm.group_id AS group_id,
- comm.name AS name,
- comm.name_embedding AS name_embedding,
- comm.created_at AS created_at,
- comm.summary AS summary
- ORDER BY score DESC
- LIMIT $limit
- """,
- {
- 'search_vector': search_vector,
- 'group_ids': group_ids,
- 'limit': limit,
- 'min_score': min_score,
- },
- )
- records = [record async for record in result]
+ records, _, _ = await driver.execute_query(
+ """
+ CYPHER runtime = parallel parallelRuntimeSupport=all
+ MATCH (comm:Community)
+ WHERE ($group_ids IS NULL OR comm.group_id IN $group_ids)
+ WITH comm, vector.similarity.cosine(comm.name_embedding, $search_vector) AS score
+ WHERE score > $min_score
+ RETURN
+ comm.uuid As uuid,
+ comm.group_id AS group_id,
+ comm.name AS name,
+ comm.name_embedding AS name_embedding,
+ comm.created_at AS created_at,
+ comm.summary AS summary
+ ORDER BY score DESC
+ LIMIT $limit
+ """,
+ search_vector=search_vector,
+ group_ids=group_ids,
+ limit=limit,
+ min_score=min_score,
+ database_=DEFAULT_DATABASE,
+ routing_='r',
+ )
communities = [get_community_node_from_record(record) for record in records]
return communities
diff --git a/pyproject.toml b/pyproject.toml
index 45fcea33..d2458b2f 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
[tool.poetry]
name = "graphiti-core"
-version = "0.3.16"
+version = "0.3.17"
description = "A temporal graph building library"
authors = [
"Paul Paliychuk ",
diff --git a/tests/test_graphiti_int.py b/tests/test_graphiti_int.py
index aa6a95b4..f7296f39 100644
--- a/tests/test_graphiti_int.py
+++ b/tests/test_graphiti_int.py
@@ -26,7 +26,9 @@ from dotenv import load_dotenv
from graphiti_core.edges import EntityEdge, EpisodicEdge
from graphiti_core.graphiti import Graphiti
from graphiti_core.nodes import EntityNode, EpisodicNode
-from graphiti_core.search.search_config_recipes import COMBINED_HYBRID_SEARCH_RRF
+from graphiti_core.search.search_config_recipes import (
+ COMBINED_HYBRID_SEARCH_CROSS_ENCODER,
+)
pytestmark = pytest.mark.integration
@@ -60,22 +62,19 @@ def setup_logging():
return logger
-def format_context(facts):
- formatted_string = ''
- formatted_string += 'FACTS:\n'
- for fact in facts:
- formatted_string += f' - {fact}\n'
- formatted_string += '\n'
-
- return formatted_string.strip()
-
-
@pytest.mark.asyncio
async def test_graphiti_init():
logger = setup_logging()
graphiti = Graphiti(NEO4J_URI, NEO4j_USER, NEO4j_PASSWORD)
+ episodes = await graphiti.retrieve_episodes(datetime.now(), group_ids=None)
+ episode_uuids = [episode.uuid for episode in episodes]
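+    # Use the retrieved episodes as BFS origins so the search also expands outward
+    # through the graph from recently ingested episodes.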
- results = await graphiti._search('new house', COMBINED_HYBRID_SEARCH_RRF, group_ids=None)
+ results = await graphiti._search(
+ "Emily: I can't log in",
+ COMBINED_HYBRID_SEARCH_CROSS_ENCODER,
+ bfs_origin_node_uuids=episode_uuids,
+ group_ids=None,
+ )
pretty_results = {
'edges': [edge.fact for edge in results.edges],
'nodes': [node.name for node in results.nodes],
@@ -83,6 +82,7 @@ async def test_graphiti_init():
}
logger.info(pretty_results)
+
await graphiti.close()