Search refactor + Community search (#111)

* WIP

* WIP

* WIP

* community search

* WIP

* WIP

* integration tested

* tests

* tests

* mypy

* mypy

* format
This commit is contained in:
Preston Rasmussen 2024-09-16 14:03:05 -04:00 committed by GitHub
parent e4ee8d62fa
commit d7c20c1f59
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
13 changed files with 780 additions and 329 deletions

View file

@ -63,7 +63,7 @@ async def main(use_bulk: bool = True):
messages = parse_podcast_messages()
if not use_bulk:
for i, message in enumerate(messages[3:130]):
for i, message in enumerate(messages[3:20]):
await client.add_episode(
name=f'Message {i}',
episode_body=f'{message.speaker_name} ({message.role}): {message.content}',

View file

@ -1,3 +1,20 @@
"""
Copyright 2024, Zep Software, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
class GraphitiError(Exception):
    """Base exception class for Graphiti Core.

    Library-specific exceptions (e.g. NodeNotFoundError, SearchRerankerError)
    subclass this so callers can catch all Graphiti errors with one type.
    """
@ -16,3 +33,11 @@ class NodeNotFoundError(GraphitiError):
def __init__(self, uuid: str):
self.message = f'node {uuid} not found'
super().__init__(self.message)
class SearchRerankerError(GraphitiError):
    """Raised when a search reranker is misconfigured or cannot run.

    Examples: multiple search methods enabled without a reranker, or the
    node-distance reranker invoked without a center node.
    """

    def __init__(self, text: str):
        # The caller supplies the full human-readable error text.
        self.message = text
        super().__init__(self.message)

View file

@ -24,14 +24,19 @@ from neo4j import AsyncGraphDatabase
from graphiti_core.edges import EntityEdge, EpisodicEdge
from graphiti_core.llm_client import LLMClient, OpenAIClient
from graphiti_core.llm_client.utils import generate_embedding
from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode
from graphiti_core.search.search import Reranker, SearchConfig, SearchMethod, hybrid_search
from graphiti_core.search.search import SearchConfig, search
from graphiti_core.search.search_config import DEFAULT_SEARCH_LIMIT, SearchResults
from graphiti_core.search.search_config_recipes import (
EDGE_HYBRID_SEARCH_NODE_DISTANCE,
EDGE_HYBRID_SEARCH_RRF,
NODE_HYBRID_SEARCH_NODE_DISTANCE,
NODE_HYBRID_SEARCH_RRF,
)
from graphiti_core.search.search_utils import (
RELEVANT_SCHEMA_LIMIT,
get_relevant_edges,
get_relevant_nodes,
hybrid_node_search,
)
from graphiti_core.utils import (
build_episodic_edges,
@ -548,7 +553,7 @@ class Graphiti:
query: str,
center_node_uuid: str | None = None,
group_ids: list[str | None] | None = None,
num_results=10,
num_results=DEFAULT_SEARCH_LIMIT,
):
"""
Perform a hybrid search on the knowledge graph.
@ -564,7 +569,7 @@ class Graphiti:
Facts will be reranked based on proximity to this node
group_ids : list[str | None] | None, optional
The graph partitions to return data from.
num_results : int, optional
num_results : int, optional
The maximum number of results to return. Defaults to 10.
Returns
@ -581,21 +586,17 @@ class Graphiti:
The search is performed using the current date and time as the reference
point for temporal relevance.
"""
reranker = Reranker.rrf if center_node_uuid is None else Reranker.node_distance
search_config = SearchConfig(
num_episodes=0,
num_edges=num_results,
num_nodes=0,
group_ids=group_ids,
search_methods=[SearchMethod.bm25, SearchMethod.cosine_similarity],
reranker=reranker,
search_config = (
EDGE_HYBRID_SEARCH_RRF if center_node_uuid is None else EDGE_HYBRID_SEARCH_NODE_DISTANCE
)
search_config.limit = num_results
edges = (
await hybrid_search(
await search(
self.driver,
self.llm_client.get_embedder(),
query,
datetime.now(),
group_ids,
search_config,
center_node_uuid,
)
@ -606,19 +607,20 @@ class Graphiti:
async def _search(
self,
query: str,
timestamp: datetime,
config: SearchConfig,
group_ids: list[str | None] | None = None,
center_node_uuid: str | None = None,
):
return await hybrid_search(
self.driver, self.llm_client.get_embedder(), query, timestamp, config, center_node_uuid
) -> SearchResults:
return await search(
self.driver, self.llm_client.get_embedder(), query, group_ids, config, center_node_uuid
)
async def get_nodes_by_query(
self,
query: str,
center_node_uuid: str | None = None,
group_ids: list[str | None] | None = None,
limit: int = RELEVANT_SCHEMA_LIMIT,
limit: int = DEFAULT_SEARCH_LIMIT,
) -> list[EntityNode]:
"""
Retrieve nodes from the graph database based on a text query.
@ -629,7 +631,9 @@ class Graphiti:
Parameters
----------
query : str
The text query to search for in the graph.
The text query to search for in the graph
center_node_uuid: str, optional
Facts will be reranked based on proximity to this node.
group_ids : list[str | None] | None, optional
The graph partitions to return data from.
limit : int | None, optional
@ -655,8 +659,12 @@ class Graphiti:
If not specified, a default limit (defined in the search functions) will be used.
"""
embedder = self.llm_client.get_embedder()
query_embedding = await generate_embedding(embedder, query)
relevant_nodes = await hybrid_node_search(
[query], [query_embedding], self.driver, group_ids, limit
search_config = (
NODE_HYBRID_SEARCH_RRF if center_node_uuid is None else NODE_HYBRID_SEARCH_NODE_DISTANCE
)
return relevant_nodes
search_config.limit = limit
nodes = (
await search(self.driver, embedder, query, group_ids, search_config, center_node_uuid)
).nodes
return nodes

View file

@ -1,3 +1,19 @@
"""
Copyright 2024, Zep Software, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
from datetime import datetime
from neo4j import time as neo4j_time

View file

@ -1,3 +1,20 @@
"""
Copyright 2024, Zep Software, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
class RateLimitError(Exception):
    """Exception raised when the rate limit is exceeded."""
    # NOTE(review): presumably raised by LLM client implementations when the
    # provider returns a rate-limit response — confirm against call sites.

View file

@ -1,3 +1,19 @@
"""
Copyright 2024, Zep Software, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import logging
import typing
from time import time
@ -17,6 +33,6 @@ async def generate_embedding(
embedding = embedding[:EMBEDDING_DIM]
end = time()
logger.debug(f'embedded text of length {len(text)} in {end-start} ms')
logger.debug(f'embedded text of length {len(text)} in {end - start} ms')
return embedding

View file

@ -1,3 +1,19 @@
"""
Copyright 2024, Zep Software, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
from typing import Any, Protocol, TypedDict
from .models import Message, PromptFunction, PromptVersion

View file

@ -15,131 +15,227 @@ limitations under the License.
"""
import logging
from datetime import datetime
from enum import Enum
from time import time
from neo4j import AsyncDriver
from pydantic import BaseModel, Field
from graphiti_core.edges import EntityEdge
from graphiti_core.errors import SearchRerankerError
from graphiti_core.llm_client.config import EMBEDDING_DIM
from graphiti_core.nodes import EntityNode, EpisodicNode
from graphiti_core.nodes import CommunityNode, EntityNode
from graphiti_core.search.search_config import (
DEFAULT_SEARCH_LIMIT,
CommunityReranker,
CommunitySearchConfig,
CommunitySearchMethod,
EdgeReranker,
EdgeSearchConfig,
EdgeSearchMethod,
NodeReranker,
NodeSearchConfig,
NodeSearchMethod,
SearchConfig,
SearchResults,
)
from graphiti_core.search.search_utils import (
community_fulltext_search,
community_similarity_search,
edge_fulltext_search,
edge_similarity_search,
get_mentioned_nodes,
node_distance_reranker,
node_fulltext_search,
node_similarity_search,
rrf,
)
from graphiti_core.utils import retrieve_episodes
from graphiti_core.utils.maintenance.graph_data_operations import EPISODE_WINDOW_LEN
logger = logging.getLogger(__name__)
class SearchMethod(Enum):
cosine_similarity = 'cosine_similarity'
bm25 = 'bm25'
class Reranker(Enum):
rrf = 'reciprocal_rank_fusion'
node_distance = 'node_distance'
class SearchConfig(BaseModel):
num_edges: int = Field(default=10)
num_nodes: int = Field(default=10)
num_episodes: int = EPISODE_WINDOW_LEN
group_ids: list[str | None] | None
search_methods: list[SearchMethod]
reranker: Reranker | None
class SearchResults(BaseModel):
episodes: list[EpisodicNode]
nodes: list[EntityNode]
edges: list[EntityEdge]
async def hybrid_search(
async def search(
driver: AsyncDriver,
embedder,
query: str,
timestamp: datetime,
group_ids: list[str | None] | None,
config: SearchConfig,
center_node_uuid: str | None = None,
) -> SearchResults:
start = time()
query = query.replace('\n', ' ')
episodes = []
nodes = []
edges = []
search_results = []
if config.num_episodes > 0:
episodes.extend(await retrieve_episodes(driver, timestamp, config.num_episodes))
nodes.extend(await get_mentioned_nodes(driver, episodes))
if SearchMethod.bm25 in config.search_methods:
text_search = await edge_fulltext_search(
driver, query, None, None, config.group_ids, 2 * config.num_edges
edges = (
await edge_search(
driver, embedder, query, group_ids, config.edge_config, center_node_uuid, config.limit
)
search_results.append(text_search)
if SearchMethod.cosine_similarity in config.search_methods:
query_text = query.replace('\n', ' ')
search_vector = (
(await embedder.create(input=[query_text], model='text-embedding-3-small'))
.data[0]
.embedding[:EMBEDDING_DIM]
if config.edge_config is not None
else []
)
nodes = (
await node_search(
driver, embedder, query, group_ids, config.node_config, center_node_uuid, config.limit
)
similarity_search = await edge_similarity_search(
driver, search_vector, None, None, config.group_ids, 2 * config.num_edges
if config.node_config is not None
else []
)
communities = (
await community_search(
driver, embedder, query, group_ids, config.community_config, config.limit
)
search_results.append(similarity_search)
if config.community_config is not None
else []
)
if len(search_results) > 1 and config.reranker is None:
logger.exception('Multiple searches enabled without a reranker')
raise Exception('Multiple searches enabled without a reranker')
else:
edge_uuid_map = {}
search_result_uuids = []
for result in search_results:
result_uuids = []
for edge in result:
result_uuids.append(edge.uuid)
edge_uuid_map[edge.uuid] = edge
search_result_uuids.append(result_uuids)
search_result_uuids = [[edge.uuid for edge in result] for result in search_results]
reranked_uuids: list[str] = []
if config.reranker == Reranker.rrf:
reranked_uuids = rrf(search_result_uuids)
elif config.reranker == Reranker.node_distance:
if center_node_uuid is None:
logger.exception('No center node provided for Node Distance reranker')
raise Exception('No center node provided for Node Distance reranker')
reranked_uuids = await node_distance_reranker(
driver, search_result_uuids, center_node_uuid
)
reranked_edges = [edge_uuid_map[uuid] for uuid in reranked_uuids]
edges.extend(reranked_edges)
context = SearchResults(
episodes=episodes, nodes=nodes[: config.num_nodes], edges=edges[: config.num_edges]
results = SearchResults(
edges=edges[: config.limit],
nodes=nodes[: config.limit],
communities=communities[: config.limit],
)
end = time()
logger.info(f'search returned context for query {query} in {(end - start) * 1000} ms')
return context
return results
async def edge_search(
    driver: AsyncDriver,
    embedder,
    query: str,
    group_ids: list[str | None] | None,
    config: EdgeSearchConfig,
    center_node_uuid: str | None = None,
    limit=DEFAULT_SEARCH_LIMIT,
) -> list[EntityEdge]:
    """Run the configured edge search methods, then rerank the merged candidates.

    Each enabled method retrieves up to 2 * limit candidates; the reranker
    produces the final ordering over the union of all candidate lists.
    """
    candidate_lists: list[list[EntityEdge]] = []

    if EdgeSearchMethod.bm25 in config.search_methods:
        candidate_lists.append(
            await edge_fulltext_search(driver, query, None, None, group_ids, 2 * limit)
        )

    if EdgeSearchMethod.cosine_similarity in config.search_methods:
        # Embed the query text and truncate to the configured embedding dimension.
        embedding_response = await embedder.create(input=[query], model='text-embedding-3-small')
        search_vector = embedding_response.data[0].embedding[:EMBEDDING_DIM]
        candidate_lists.append(
            await edge_similarity_search(driver, search_vector, None, None, group_ids, 2 * limit)
        )

    if config.reranker is None and len(candidate_lists) > 1:
        raise SearchRerankerError('Multiple edge searches enabled without a reranker')

    edges_by_uuid = {edge.uuid: edge for candidates in candidate_lists for edge in candidates}

    ordered_uuids: list[str] = []
    if config.reranker == EdgeReranker.rrf:
        ordered_uuids = rrf([[edge.uuid for edge in candidates] for candidates in candidate_lists])
    elif config.reranker == EdgeReranker.node_distance:
        if center_node_uuid is None:
            raise SearchRerankerError('No center node provided for Node Distance reranker')
        # Rerank by graph distance of each edge's source node from the center node,
        # then map the reranked node uuids back to their edges.
        edge_uuid_by_source = {
            edge.source_node_uuid: edge.uuid
            for candidates in candidate_lists
            for edge in candidates
        }
        source_uuid_lists = [
            [edge.source_node_uuid for edge in candidates] for candidates in candidate_lists
        ]
        ordered_node_uuids = await node_distance_reranker(
            driver, source_uuid_lists, center_node_uuid
        )
        ordered_uuids = [edge_uuid_by_source[node_uuid] for node_uuid in ordered_node_uuids]

    return [edges_by_uuid[uuid] for uuid in ordered_uuids]
async def node_search(
    driver: AsyncDriver,
    embedder,
    query: str,
    group_ids: list[str | None] | None,
    config: NodeSearchConfig,
    center_node_uuid: str | None = None,
    limit=DEFAULT_SEARCH_LIMIT,
) -> list[EntityNode]:
    """Run the configured node search methods, then rerank the merged candidates.

    Each enabled method retrieves up to 2 * limit candidates; the reranker
    produces the final ordering over the union of all candidate lists.
    """
    candidate_lists: list[list[EntityNode]] = []

    if NodeSearchMethod.bm25 in config.search_methods:
        candidate_lists.append(await node_fulltext_search(driver, query, group_ids, 2 * limit))

    if NodeSearchMethod.cosine_similarity in config.search_methods:
        # Embed the query text and truncate to the configured embedding dimension.
        embedding_response = await embedder.create(input=[query], model='text-embedding-3-small')
        search_vector = embedding_response.data[0].embedding[:EMBEDDING_DIM]
        candidate_lists.append(
            await node_similarity_search(driver, search_vector, group_ids, 2 * limit)
        )

    if config.reranker is None and len(candidate_lists) > 1:
        raise SearchRerankerError('Multiple node searches enabled without a reranker')

    uuid_lists = [[node.uuid for node in candidates] for candidates in candidate_lists]
    nodes_by_uuid = {node.uuid: node for candidates in candidate_lists for node in candidates}

    ordered_uuids: list[str] = []
    if config.reranker == NodeReranker.rrf:
        ordered_uuids = rrf(uuid_lists)
    elif config.reranker == NodeReranker.node_distance:
        if center_node_uuid is None:
            raise SearchRerankerError('No center node provided for Node Distance reranker')
        ordered_uuids = await node_distance_reranker(driver, uuid_lists, center_node_uuid)

    return [nodes_by_uuid[uuid] for uuid in ordered_uuids]
async def community_search(
    driver: AsyncDriver,
    embedder,
    query: str,
    group_ids: list[str | None] | None,
    config: CommunitySearchConfig,
    limit=DEFAULT_SEARCH_LIMIT,
) -> list[CommunityNode]:
    """Run the configured community search methods, then rerank the merged candidates.

    Each enabled method retrieves up to 2 * limit candidates; the reranker
    produces the final ordering over the union of all candidate lists.
    Communities have no node-distance reranker, so only rrf applies.
    """
    search_results: list[list[CommunityNode]] = []

    if CommunitySearchMethod.bm25 in config.search_methods:
        text_search = await community_fulltext_search(driver, query, group_ids, 2 * limit)
        search_results.append(text_search)

    if CommunitySearchMethod.cosine_similarity in config.search_methods:
        # Embed the query text and truncate to the configured embedding dimension.
        search_vector = (
            (await embedder.create(input=[query], model='text-embedding-3-small'))
            .data[0]
            .embedding[:EMBEDDING_DIM]
        )

        similarity_search = await community_similarity_search(
            driver, search_vector, group_ids, 2 * limit
        )
        search_results.append(similarity_search)

    if len(search_results) > 1 and config.reranker is None:
        # Fixed copy-paste: this is the community search, not the node search.
        raise SearchRerankerError('Multiple community searches enabled without a reranker')

    search_result_uuids = [[community.uuid for community in result] for result in search_results]
    community_uuid_map = {
        community.uuid: community for result in search_results for community in result
    }

    reranked_uuids: list[str] = []
    if config.reranker == CommunityReranker.rrf:
        reranked_uuids = rrf(search_result_uuids)

    reranked_communities = [community_uuid_map[uuid] for uuid in reranked_uuids]

    return reranked_communities

View file

@ -0,0 +1,81 @@
"""
Copyright 2024, Zep Software, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
from enum import Enum
from pydantic import BaseModel, Field
from graphiti_core.edges import EntityEdge
from graphiti_core.nodes import CommunityNode, EntityNode
# Default maximum number of results returned per result type when no explicit
# limit is supplied.
DEFAULT_SEARCH_LIMIT = 10


class EdgeSearchMethod(Enum):
    """Candidate-retrieval strategies available for edge search."""

    cosine_similarity = 'cosine_similarity'
    bm25 = 'bm25'


class NodeSearchMethod(Enum):
    """Candidate-retrieval strategies available for node search."""

    cosine_similarity = 'cosine_similarity'
    bm25 = 'bm25'


class CommunitySearchMethod(Enum):
    """Candidate-retrieval strategies available for community search."""

    cosine_similarity = 'cosine_similarity'
    bm25 = 'bm25'


class EdgeReranker(Enum):
    """Strategies for merging and ordering edge candidates."""

    rrf = 'reciprocal_rank_fusion'
    node_distance = 'node_distance'


class NodeReranker(Enum):
    """Strategies for merging and ordering node candidates."""

    rrf = 'reciprocal_rank_fusion'
    node_distance = 'node_distance'


class CommunityReranker(Enum):
    """Strategies for merging and ordering community candidates (rrf only)."""

    rrf = 'reciprocal_rank_fusion'


class EdgeSearchConfig(BaseModel):
    """Which search methods to run over edges and how to rerank their results."""

    search_methods: list[EdgeSearchMethod]
    reranker: EdgeReranker | None


class NodeSearchConfig(BaseModel):
    """Which search methods to run over nodes and how to rerank their results."""

    search_methods: list[NodeSearchMethod]
    reranker: NodeReranker | None


class CommunitySearchConfig(BaseModel):
    """Which search methods to run over communities and how to rerank their results."""

    search_methods: list[CommunitySearchMethod]
    reranker: CommunityReranker | None


class SearchConfig(BaseModel):
    """Top-level search configuration.

    A None sub-config disables that result type entirely; `limit` caps the
    number of results returned for each enabled type.
    """

    edge_config: EdgeSearchConfig | None = Field(default=None)
    node_config: NodeSearchConfig | None = Field(default=None)
    community_config: CommunitySearchConfig | None = Field(default=None)
    limit: int = Field(default=DEFAULT_SEARCH_LIMIT)


class SearchResults(BaseModel):
    """Container for the edges, nodes, and communities returned by search()."""

    edges: list[EntityEdge]
    nodes: list[EntityNode]
    communities: list[CommunityNode]

View file

@ -0,0 +1,84 @@
"""
Copyright 2024, Zep Software, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
from graphiti_core.search.search_config import (
CommunityReranker,
CommunitySearchConfig,
CommunitySearchMethod,
EdgeReranker,
EdgeSearchConfig,
EdgeSearchMethod,
NodeReranker,
NodeSearchConfig,
NodeSearchMethod,
SearchConfig,
)
# Performs a hybrid search with rrf reranking over edges, nodes, and communities
# NOTE(review): these recipes are shared module-level SearchConfig instances,
# and callers assign to `.limit` on them (see Graphiti.search /
# get_nodes_by_query), so that mutation persists across calls — consider
# handing out copies instead. TODO confirm intent.

# Performs a hybrid search with rrf reranking over edges, nodes, and communities.
COMBINED_HYBRID_SEARCH_RRF = SearchConfig(
    edge_config=EdgeSearchConfig(
        search_methods=[EdgeSearchMethod.bm25, EdgeSearchMethod.cosine_similarity],
        reranker=EdgeReranker.rrf,
    ),
    node_config=NodeSearchConfig(
        search_methods=[NodeSearchMethod.bm25, NodeSearchMethod.cosine_similarity],
        reranker=NodeReranker.rrf,
    ),
    community_config=CommunitySearchConfig(
        search_methods=[CommunitySearchMethod.bm25, CommunitySearchMethod.cosine_similarity],
        reranker=CommunityReranker.rrf,
    ),
)

# Performs a hybrid search over edges with rrf reranking.
EDGE_HYBRID_SEARCH_RRF = SearchConfig(
    edge_config=EdgeSearchConfig(
        search_methods=[EdgeSearchMethod.bm25, EdgeSearchMethod.cosine_similarity],
        reranker=EdgeReranker.rrf,
    )
)

# Performs a hybrid search over edges with node distance reranking.
EDGE_HYBRID_SEARCH_NODE_DISTANCE = SearchConfig(
    edge_config=EdgeSearchConfig(
        search_methods=[EdgeSearchMethod.bm25, EdgeSearchMethod.cosine_similarity],
        reranker=EdgeReranker.node_distance,
    )
)

# Performs a hybrid search over nodes with rrf reranking.
NODE_HYBRID_SEARCH_RRF = SearchConfig(
    node_config=NodeSearchConfig(
        search_methods=[NodeSearchMethod.bm25, NodeSearchMethod.cosine_similarity],
        reranker=NodeReranker.rrf,
    )
)

# Performs a hybrid search over nodes with node distance reranking.
NODE_HYBRID_SEARCH_NODE_DISTANCE = SearchConfig(
    node_config=NodeSearchConfig(
        search_methods=[NodeSearchMethod.bm25, NodeSearchMethod.cosine_similarity],
        reranker=NodeReranker.node_distance,
    )
)

# Performs a hybrid search over communities with rrf reranking.
COMMUNITY_HYBRID_SEARCH_RRF = SearchConfig(
    community_config=CommunitySearchConfig(
        search_methods=[CommunitySearchMethod.bm25, CommunitySearchMethod.cosine_similarity],
        reranker=CommunityReranker.rrf,
    )
)

View file

@ -1,3 +1,19 @@
"""
Copyright 2024, Zep Software, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import asyncio
import logging
import re
@ -7,7 +23,13 @@ from time import time
from neo4j import AsyncDriver, Query
from graphiti_core.edges import EntityEdge, get_entity_edge_from_record
from graphiti_core.nodes import EntityNode, EpisodicNode, get_entity_node_from_record
from graphiti_core.nodes import (
CommunityNode,
EntityNode,
EpisodicNode,
get_community_node_from_record,
get_entity_node_from_record,
)
logger = logging.getLogger(__name__)
@ -35,181 +57,6 @@ async def get_mentioned_nodes(driver: AsyncDriver, episodes: list[EpisodicNode])
return nodes
async def edge_similarity_search(
driver: AsyncDriver,
search_vector: list[float],
source_node_uuid: str | None,
target_node_uuid: str | None,
group_ids: list[str | None] | None = None,
limit: int = RELEVANT_SCHEMA_LIMIT,
) -> list[EntityEdge]:
group_ids = group_ids if group_ids is not None else [None]
# vector similarity search over embedded facts
query = Query("""
CALL db.index.vector.queryRelationships("fact_embedding", $limit, $search_vector)
YIELD relationship AS rel, score
MATCH (n:Entity {uuid: $source_uuid})-[r {uuid: rel.uuid}]-(m:Entity {uuid: $target_uuid})
WHERE r.group_id IN $group_ids
RETURN
r.uuid AS uuid,
r.group_id AS group_id,
n.uuid AS source_node_uuid,
m.uuid AS target_node_uuid,
r.created_at AS created_at,
r.name AS name,
r.fact AS fact,
r.fact_embedding AS fact_embedding,
r.episodes AS episodes,
r.expired_at AS expired_at,
r.valid_at AS valid_at,
r.invalid_at AS invalid_at
ORDER BY score DESC
""")
if source_node_uuid is None and target_node_uuid is None:
query = Query("""
CALL db.index.vector.queryRelationships("fact_embedding", $limit, $search_vector)
YIELD relationship AS rel, score
MATCH (n:Entity)-[r {uuid: rel.uuid}]-(m:Entity)
WHERE r.group_id IN $group_ids
RETURN
r.uuid AS uuid,
r.group_id AS group_id,
n.uuid AS source_node_uuid,
m.uuid AS target_node_uuid,
r.created_at AS created_at,
r.name AS name,
r.fact AS fact,
r.fact_embedding AS fact_embedding,
r.episodes AS episodes,
r.expired_at AS expired_at,
r.valid_at AS valid_at,
r.invalid_at AS invalid_at
ORDER BY score DESC
""")
elif source_node_uuid is None:
query = Query("""
CALL db.index.vector.queryRelationships("fact_embedding", $limit, $search_vector)
YIELD relationship AS rel, score
MATCH (n:Entity)-[r {uuid: rel.uuid}]-(m:Entity {uuid: $target_uuid})
WHERE r.group_id IN $group_ids
RETURN
r.uuid AS uuid,
r.group_id AS group_id,
n.uuid AS source_node_uuid,
m.uuid AS target_node_uuid,
r.created_at AS created_at,
r.name AS name,
r.fact AS fact,
r.fact_embedding AS fact_embedding,
r.episodes AS episodes,
r.expired_at AS expired_at,
r.valid_at AS valid_at,
r.invalid_at AS invalid_at
ORDER BY score DESC
""")
elif target_node_uuid is None:
query = Query("""
CALL db.index.vector.queryRelationships("fact_embedding", $limit, $search_vector)
YIELD relationship AS rel, score
MATCH (n:Entity {uuid: $source_uuid})-[r {uuid: rel.uuid}]-(m:Entity)
WHERE r.group_id IN $group_ids
RETURN
r.uuid AS uuid,
r.group_id AS group_id,
n.uuid AS source_node_uuid,
m.uuid AS target_node_uuid,
r.created_at AS created_at,
r.name AS name,
r.fact AS fact,
r.fact_embedding AS fact_embedding,
r.episodes AS episodes,
r.expired_at AS expired_at,
r.valid_at AS valid_at,
r.invalid_at AS invalid_at
ORDER BY score DESC
""")
records, _, _ = await driver.execute_query(
query,
search_vector=search_vector,
source_uuid=source_node_uuid,
target_uuid=target_node_uuid,
group_ids=group_ids,
limit=limit,
)
edges = [get_entity_edge_from_record(record) for record in records]
return edges
async def entity_similarity_search(
search_vector: list[float],
driver: AsyncDriver,
group_ids: list[str | None] | None = None,
limit=RELEVANT_SCHEMA_LIMIT,
) -> list[EntityNode]:
group_ids = group_ids if group_ids is not None else [None]
# vector similarity search over entity names
records, _, _ = await driver.execute_query(
"""
CALL db.index.vector.queryNodes("name_embedding", $limit, $search_vector)
YIELD node AS n, score
MATCH (n WHERE n.group_id IN $group_ids)
RETURN
n.uuid As uuid,
n.group_id AS group_id,
n.name AS name,
n.name_embedding AS name_embedding,
n.created_at AS created_at,
n.summary AS summary
ORDER BY score DESC
""",
search_vector=search_vector,
group_ids=group_ids,
limit=limit,
)
nodes = [get_entity_node_from_record(record) for record in records]
return nodes
async def entity_fulltext_search(
query: str,
driver: AsyncDriver,
group_ids: list[str | None] | None = None,
limit=RELEVANT_SCHEMA_LIMIT,
) -> list[EntityNode]:
group_ids = group_ids if group_ids is not None else [None]
# BM25 search to get top nodes
fuzzy_query = re.sub(r'[^\w\s]', '', query) + '~'
records, _, _ = await driver.execute_query(
"""
CALL db.index.fulltext.queryNodes("name_and_summary", $query)
YIELD node AS n, score
MATCH (n WHERE n.group_id in $group_ids)
RETURN
n.uuid AS uuid,
n.group_id AS group_id,
n.name AS name,
n.name_embedding AS name_embedding,
n.created_at AS created_at,
n.summary AS summary
ORDER BY score DESC
LIMIT $limit
""",
query=fuzzy_query,
group_ids=group_ids,
limit=limit,
)
nodes = [get_entity_node_from_record(record) for record in records]
return nodes
async def edge_fulltext_search(
driver: AsyncDriver,
query: str,
@ -322,6 +169,247 @@ async def edge_fulltext_search(
return edges
async def edge_similarity_search(
    driver: AsyncDriver,
    search_vector: list[float],
    source_node_uuid: str | None,
    target_node_uuid: str | None,
    group_ids: list[str | None] | None = None,
    limit: int = RELEVANT_SCHEMA_LIMIT,
) -> list[EntityEdge]:
    """Return entity edges whose fact embeddings are most similar to search_vector.

    Parameters
    ----------
    driver : AsyncDriver
        Neo4j driver used to execute the query.
    search_vector : list[float]
        Embedding compared against the "fact_embedding" vector index.
    source_node_uuid : str | None
        When given, restrict matches to edges incident to this node (bound as `n`).
    target_node_uuid : str | None
        When given, restrict matches to edges incident to this node (bound as `m`).
    group_ids : list[str | None] | None, optional
        Graph partitions to search; defaults to the null partition.
    limit : int, optional
        Maximum number of relationships the vector index lookup yields.

    Returns
    -------
    list[EntityEdge]
        Edges ordered by descending similarity score.
    """
    group_ids = group_ids if group_ids is not None else [None]

    # The original code hand-wrote four near-identical query variants that
    # differed only in whether each endpoint uuid was pinned in the MATCH
    # pattern; build that pattern once instead.
    source_pattern = (
        '(n:Entity {uuid: $source_uuid})' if source_node_uuid is not None else '(n:Entity)'
    )
    target_pattern = (
        '(m:Entity {uuid: $target_uuid})' if target_node_uuid is not None else '(m:Entity)'
    )

    # vector similarity search over embedded facts
    query = Query(f"""
        CALL db.index.vector.queryRelationships("fact_embedding", $limit, $search_vector)
        YIELD relationship AS rel, score
        MATCH {source_pattern}-[r {{uuid: rel.uuid}}]-{target_pattern}
        WHERE r.group_id IN $group_ids
        RETURN
            r.uuid AS uuid,
            r.group_id AS group_id,
            n.uuid AS source_node_uuid,
            m.uuid AS target_node_uuid,
            r.created_at AS created_at,
            r.name AS name,
            r.fact AS fact,
            r.fact_embedding AS fact_embedding,
            r.episodes AS episodes,
            r.expired_at AS expired_at,
            r.valid_at AS valid_at,
            r.invalid_at AS invalid_at
        ORDER BY score DESC
        """)

    records, _, _ = await driver.execute_query(
        query,
        search_vector=search_vector,
        source_uuid=source_node_uuid,
        target_uuid=target_node_uuid,
        group_ids=group_ids,
        limit=limit,
    )

    edges = [get_entity_edge_from_record(record) for record in records]

    return edges
async def node_fulltext_search(
    driver: AsyncDriver,
    query: str,
    group_ids: list[str | None] | None = None,
    limit=RELEVANT_SCHEMA_LIMIT,
) -> list[EntityNode]:
    """Full-text (BM25) search over the "name_and_summary" index for entity nodes."""
    if group_ids is None:
        group_ids = [None]

    # Strip punctuation and append '~' so the full-text index does a fuzzy match.
    fuzzy_query = re.sub(r'[^\w\s]', '', query) + '~'

    cypher = """
    CALL db.index.fulltext.queryNodes("name_and_summary", $query)
    YIELD node AS n, score
    MATCH (n WHERE n.group_id in $group_ids)
    RETURN
        n.uuid AS uuid,
        n.group_id AS group_id,
        n.name AS name,
        n.name_embedding AS name_embedding,
        n.created_at AS created_at,
        n.summary AS summary
    ORDER BY score DESC
    LIMIT $limit
    """

    records, _, _ = await driver.execute_query(
        cypher,
        query=fuzzy_query,
        group_ids=group_ids,
        limit=limit,
    )

    return [get_entity_node_from_record(record) for record in records]
async def node_similarity_search(
    driver: AsyncDriver,
    search_vector: list[float],
    group_ids: list[str | None] | None = None,
    limit=RELEVANT_SCHEMA_LIMIT,
) -> list[EntityNode]:
    """Vector-similarity search over entity name embeddings.

    Queries the "name_embedding" vector index for the `limit` nearest
    nodes to `search_vector`, then keeps only those in the given groups.
    """
    if group_ids is None:
        # A [None] filter matches nodes whose group_id is unset.
        group_ids = [None]

    records, _, _ = await driver.execute_query(
        """
    CALL db.index.vector.queryNodes("name_embedding", $limit, $search_vector)
    YIELD node AS n, score
    MATCH (n WHERE n.group_id IN $group_ids)
    RETURN
        n.uuid As uuid,
        n.group_id AS group_id,
        n.name AS name, 
        n.name_embedding AS name_embedding,
        n.created_at AS created_at, 
        n.summary AS summary
    ORDER BY score DESC
    """,
        search_vector=search_vector,
        group_ids=group_ids,
        limit=limit,
    )

    return [get_entity_node_from_record(record) for record in records]
async def community_fulltext_search(
    driver: AsyncDriver,
    query: str,
    group_ids: list[str | None] | None = None,
    limit=RELEVANT_SCHEMA_LIMIT,
) -> list[CommunityNode]:
    """Run a BM25 full-text search over community names.

    Punctuation is stripped from the query (preventing Lucene-syntax
    injection) and '~' is appended for fuzzy matching before querying the
    "community_name" full-text index, filtered to the given group ids.
    """
    if group_ids is None:
        # A [None] filter matches communities whose group_id is unset.
        group_ids = [None]

    # Sanitize, then append the Lucene fuzzy-match operator.
    fuzzy_query = re.sub(r'[^\w\s]', '', query) + '~'

    records, _, _ = await driver.execute_query(
        """
    CALL db.index.fulltext.queryNodes("community_name", $query) 
    YIELD node AS comm, score
    MATCH (comm WHERE comm.group_id in $group_ids)
    RETURN
        comm.uuid AS uuid, 
        comm.group_id AS group_id,
        comm.name AS name, 
        comm.name_embedding AS name_embedding,
        comm.created_at AS created_at, 
        comm.summary AS summary
    ORDER BY score DESC
    LIMIT $limit
    """,
        query=fuzzy_query,
        group_ids=group_ids,
        limit=limit,
    )

    return [get_community_node_from_record(record) for record in records]
async def community_similarity_search(
    driver: AsyncDriver,
    search_vector: list[float],
    group_ids: list[str | None] | None = None,
    limit=RELEVANT_SCHEMA_LIMIT,
) -> list[CommunityNode]:
    """Vector-similarity search over community name embeddings.

    Queries the "community_name_embedding" vector index for the `limit`
    nearest communities to `search_vector`, then keeps only those in the
    given groups.
    """
    if group_ids is None:
        # A [None] filter matches communities whose group_id is unset.
        group_ids = [None]

    records, _, _ = await driver.execute_query(
        """
    CALL db.index.vector.queryNodes("community_name_embedding", $limit, $search_vector)
    YIELD node AS comm, score
    MATCH (comm WHERE comm.group_id IN $group_ids)
    RETURN
        comm.uuid As uuid,
        comm.group_id AS group_id,
        comm.name AS name, 
        comm.name_embedding AS name_embedding,
        comm.created_at AS created_at, 
        comm.summary AS summary
    ORDER BY score DESC
    """,
        search_vector=search_vector,
        group_ids=group_ids,
        limit=limit,
    )

    return [get_community_node_from_record(record) for record in records]
async def hybrid_node_search(
queries: list[str],
embeddings: list[list[float]],
@ -371,8 +459,8 @@ async def hybrid_node_search(
results: list[list[EntityNode]] = list(
await asyncio.gather(
*[entity_fulltext_search(q, driver, group_ids, 2 * limit) for q in queries],
*[entity_similarity_search(e, driver, group_ids, 2 * limit) for e in embeddings],
*[node_fulltext_search(driver, q, group_ids, 2 * limit) for q in queries],
*[node_similarity_search(driver, e, group_ids, 2 * limit) for e in embeddings],
)
)
@ -490,24 +578,23 @@ def rrf(results: list[list[str]], rank_const=1) -> list[str]:
async def node_distance_reranker(
driver: AsyncDriver, results: list[list[str]], center_node_uuid: str
driver: AsyncDriver, node_uuids: list[list[str]], center_node_uuid: str
) -> list[str]:
# use rrf as a preliminary ranker
sorted_uuids = rrf(results)
sorted_uuids = rrf(node_uuids)
scores: dict[str, float] = {}
# Find the shortest path to center node
query = Query("""
MATCH (source:Entity)-[r:RELATES_TO {uuid: $edge_uuid}]->(target:Entity)
MATCH p = SHORTEST 1 (center:Entity {uuid: $center_uuid})-[:RELATES_TO]-+(n:Entity {uuid: source.uuid})
RETURN length(p) AS score, source.uuid AS source_uuid, target.uuid AS target_uuid
MATCH p = SHORTEST 1 (center:Entity {uuid: $center_uuid})-[:RELATES_TO]-+(n:Entity {uuid: $node_uuid})
RETURN length(p) AS score
""")
path_results = await asyncio.gather(
*[
driver.execute_query(
query,
edge_uuid=uuid,
node_uuid=uuid,
center_uuid=center_node_uuid,
)
for uuid in sorted_uuids
@ -518,15 +605,8 @@ async def node_distance_reranker(
records = result[0]
record = records[0] if len(records) > 0 else None
distance: float = record['score'] if record is not None else float('inf')
if record is not None and (
record['source_uuid'] == center_node_uuid or record['target_uuid'] == center_node_uuid
):
distance = 0
if uuid in scores:
scores[uuid] = min(distance, scores[uuid])
else:
scores[uuid] = distance
distance = 0 if uuid == center_node_uuid else distance
scores[uuid] = distance
# rerank on shortest distance
sorted_uuids.sort(key=lambda cur_uuid: scores[cur_uuid])

View file

@ -26,6 +26,7 @@ from dotenv import load_dotenv
from graphiti_core.edges import EntityEdge, EpisodicEdge
from graphiti_core.graphiti import Graphiti
from graphiti_core.nodes import EntityNode, EpisodicNode
from graphiti_core.search.search_config_recipes import COMBINED_HYBRID_SEARCH_RRF
pytestmark = pytest.mark.integration
@ -81,6 +82,17 @@ async def test_graphiti_init():
edges = await graphiti.search('issues with higher ed', group_ids=['1'])
logger.info('\nQUERY: issues with higher ed\n' + format_context([edge.fact for edge in edges]))
results = await graphiti._search(
'issues with higher ed', COMBINED_HYBRID_SEARCH_RRF, group_ids=['1']
)
pretty_results = {
'edges': [edge.fact for edge in results.edges],
'nodes': [node.name for node in results.nodes],
'communities': [community.name for community in results.communities],
}
logger.info(pretty_results)
graphiti.close()

View file

@ -11,11 +11,11 @@ async def test_hybrid_node_search_deduplication():
# Mock the database driver
mock_driver = AsyncMock()
# Mock the entity_fulltext_search and entity_similarity_search functions
# Mock the node_fulltext_search and entity_similarity_search functions
with patch(
'graphiti_core.search.search_utils.entity_fulltext_search'
'graphiti_core.search.search_utils.node_fulltext_search'
) as mock_fulltext_search, patch(
'graphiti_core.search.search_utils.entity_similarity_search'
'graphiti_core.search.search_utils.node_similarity_search'
) as mock_similarity_search:
# Set up mock return values
mock_fulltext_search.side_effect = [
@ -47,9 +47,9 @@ async def test_hybrid_node_search_empty_results():
mock_driver = AsyncMock()
with patch(
'graphiti_core.search.search_utils.entity_fulltext_search'
'graphiti_core.search.search_utils.node_fulltext_search'
) as mock_fulltext_search, patch(
'graphiti_core.search.search_utils.entity_similarity_search'
'graphiti_core.search.search_utils.node_similarity_search'
) as mock_similarity_search:
mock_fulltext_search.return_value = []
mock_similarity_search.return_value = []
@ -66,9 +66,9 @@ async def test_hybrid_node_search_only_fulltext():
mock_driver = AsyncMock()
with patch(
'graphiti_core.search.search_utils.entity_fulltext_search'
'graphiti_core.search.search_utils.node_fulltext_search'
) as mock_fulltext_search, patch(
'graphiti_core.search.search_utils.entity_similarity_search'
'graphiti_core.search.search_utils.node_similarity_search'
) as mock_similarity_search:
mock_fulltext_search.return_value = [
EntityNode(uuid='1', name='Alice', labels=['Entity'], group_id='1')
@ -90,9 +90,9 @@ async def test_hybrid_node_search_with_limit():
mock_driver = AsyncMock()
with patch(
'graphiti_core.search.search_utils.entity_fulltext_search'
'graphiti_core.search.search_utils.node_fulltext_search'
) as mock_fulltext_search, patch(
'graphiti_core.search.search_utils.entity_similarity_search'
'graphiti_core.search.search_utils.node_similarity_search'
) as mock_similarity_search:
mock_fulltext_search.return_value = [
EntityNode(uuid='1', name='Alice', labels=['Entity'], group_id='1'),
@ -120,8 +120,8 @@ async def test_hybrid_node_search_with_limit():
assert mock_fulltext_search.call_count == 1
assert mock_similarity_search.call_count == 1
# Verify that the limit was passed to the search functions
mock_fulltext_search.assert_called_with('Test', mock_driver, ['1'], 2)
mock_similarity_search.assert_called_with([0.1, 0.2, 0.3], mock_driver, ['1'], 2)
mock_fulltext_search.assert_called_with(mock_driver, 'Test', ['1'], 2)
mock_similarity_search.assert_called_with(mock_driver, [0.1, 0.2, 0.3], ['1'], 2)
@pytest.mark.asyncio
@ -129,9 +129,9 @@ async def test_hybrid_node_search_with_limit_and_duplicates():
mock_driver = AsyncMock()
with patch(
'graphiti_core.search.search_utils.entity_fulltext_search'
'graphiti_core.search.search_utils.node_fulltext_search'
) as mock_fulltext_search, patch(
'graphiti_core.search.search_utils.entity_similarity_search'
'graphiti_core.search.search_utils.node_similarity_search'
) as mock_similarity_search:
mock_fulltext_search.return_value = [
EntityNode(uuid='1', name='Alice', labels=['Entity'], group_id='1'),
@ -155,5 +155,5 @@ async def test_hybrid_node_search_with_limit_and_duplicates():
assert set(node.name for node in results) == {'Alice', 'Bob', 'Charlie'}
assert mock_fulltext_search.call_count == 1
assert mock_similarity_search.call_count == 1
mock_fulltext_search.assert_called_with('Test', mock_driver, ['1'], 4)
mock_similarity_search.assert_called_with([0.1, 0.2, 0.3], mock_driver, ['1'], 4)
mock_fulltext_search.assert_called_with(mock_driver, 'Test', ['1'], 4)
mock_similarity_search.assert_called_with(mock_driver, [0.1, 0.2, 0.3], ['1'], 4)