[Bug Fix] Fix the Group ID usage with FalkorDB (#733)
* groupid-none * groupid-def-fulltext * lint * Update graphiti_core/helpers.py Co-authored-by: ellipsis-dev[bot] <65095814+ellipsis-dev[bot]@users.noreply.github.com> --------- Co-authored-by: ellipsis-dev[bot] <65095814+ellipsis-dev[bot]@users.noreply.github.com>
This commit is contained in:
parent
d96f362875
commit
35e0692328
5 changed files with 28 additions and 9 deletions
|
|
@ -46,6 +46,7 @@ class GraphDriverSession(ABC):
|
|||
|
||||
class GraphDriver(ABC):
|
||||
provider: str
|
||||
fulltext_syntax: str = '' # Neo4j (default) syntax does not require a prefix for fulltext queries
|
||||
|
||||
@abstractmethod
|
||||
def execute_query(self, cypher_query_: str, **kwargs: Any) -> Coroutine:
|
||||
|
|
|
|||
|
|
@ -97,6 +97,9 @@ class FalkorDriver(GraphDriver):
|
|||
self.client = FalkorDB(host=host, port=port, username=username, password=password)
|
||||
self._database = database
|
||||
|
||||
self.fulltext_syntax = '@' # FalkorDB uses a RediSearch-like fulltext syntax in which field filters (e.g. group_id) take an '@' prefix; see https://redis.io/docs/latest/develop/ai/search-and-query/query/full-text/
|
||||
|
||||
|
||||
def _get_graph(self, graph_name: str | None) -> FalkorGraph:
|
||||
# FalkorDB requires a non-None database name for multi-tenant graphs; the default is "default_db"
|
||||
if graph_name is None:
|
||||
|
|
|
|||
|
|
@ -30,6 +30,7 @@ from graphiti_core.edges import EntityEdge, EpisodicEdge
|
|||
from graphiti_core.embedder import EmbedderClient, OpenAIEmbedder
|
||||
from graphiti_core.graphiti_types import GraphitiClients
|
||||
from graphiti_core.helpers import (
|
||||
get_default_group_id,
|
||||
semaphore_gather,
|
||||
validate_excluded_entity_types,
|
||||
validate_group_id,
|
||||
|
|
@ -352,7 +353,7 @@ class Graphiti:
|
|||
source_description: str,
|
||||
reference_time: datetime,
|
||||
source: EpisodeType = EpisodeType.message,
|
||||
group_id: str = '',
|
||||
group_id: str | None = None,
|
||||
uuid: str | None = None,
|
||||
update_communities: bool = False,
|
||||
entity_types: dict[str, BaseModel] | None = None,
|
||||
|
|
@ -420,7 +421,10 @@ class Graphiti:
|
|||
start = time()
|
||||
now = utc_now()
|
||||
|
||||
# If group_id is None or empty, fall back to the provider's default group id
|
||||
group_id = group_id or get_default_group_id(self.driver.provider)
|
||||
validate_entity_types(entity_types)
|
||||
|
||||
validate_excluded_entity_types(excluded_entity_types, entity_types)
|
||||
validate_group_id(group_id)
|
||||
|
||||
|
|
@ -537,7 +541,7 @@ class Graphiti:
|
|||
async def add_episode_bulk(
|
||||
self,
|
||||
bulk_episodes: list[RawEpisode],
|
||||
group_id: str = '',
|
||||
group_id: str | None = None,
|
||||
entity_types: dict[str, BaseModel] | None = None,
|
||||
excluded_entity_types: list[str] | None = None,
|
||||
edge_types: dict[str, BaseModel] | None = None,
|
||||
|
|
@ -583,6 +587,8 @@ class Graphiti:
|
|||
start = time()
|
||||
now = utc_now()
|
||||
|
||||
# If group_id is None or empty, fall back to the provider's default group id
|
||||
group_id = group_id or get_default_group_id(self.driver.provider)
|
||||
validate_group_id(group_id)
|
||||
|
||||
# Create default edge type map
|
||||
|
|
|
|||
|
|
@ -51,6 +51,15 @@ def parse_db_date(neo_date: neo4j_time.DateTime | str | None) -> datetime | None
|
|||
else None
|
||||
)
|
||||
|
||||
def get_default_group_id(db_type: str) -> str:
|
||||
"""
|
||||
Return the default group id for the given database type.
|
||||
Most databases use an empty string as the default group id, but some require a specific non-empty value (FalkorDB uses '_').
|
||||
"""
|
||||
if db_type == 'falkordb':
|
||||
return '_'
|
||||
else:
|
||||
return ''
|
||||
|
||||
def lucene_sanitize(query: str) -> str:
|
||||
# Escape special characters from a query before passing into Lucene
|
||||
|
|
|
|||
|
|
@ -60,9 +60,9 @@ MAX_SEARCH_DEPTH = 3
|
|||
MAX_QUERY_LENGTH = 32
|
||||
|
||||
|
||||
def fulltext_query(query: str, group_ids: list[str] | None = None):
|
||||
def fulltext_query(query: str, group_ids: list[str] | None = None, fulltext_syntax: str = ''):
|
||||
group_ids_filter_list = (
|
||||
[f'group_id:"{lucene_sanitize(g)}"' for g in group_ids] if group_ids is not None else []
|
||||
[fulltext_syntax + f"group_id:'{lucene_sanitize(g)}'" for g in group_ids] if group_ids is not None else []
|
||||
)
|
||||
group_ids_filter = ''
|
||||
for f in group_ids_filter_list:
|
||||
|
|
@ -157,7 +157,7 @@ async def edge_fulltext_search(
|
|||
limit=RELEVANT_SCHEMA_LIMIT,
|
||||
) -> list[EntityEdge]:
|
||||
# fulltext search over facts
|
||||
fuzzy_query = fulltext_query(query, group_ids)
|
||||
fuzzy_query = fulltext_query(query, group_ids, driver.fulltext_syntax)
|
||||
if fuzzy_query == '':
|
||||
return []
|
||||
|
||||
|
|
@ -340,7 +340,7 @@ async def node_fulltext_search(
|
|||
limit=RELEVANT_SCHEMA_LIMIT,
|
||||
) -> list[EntityNode]:
|
||||
# BM25 search to get top nodes
|
||||
fuzzy_query = fulltext_query(query, group_ids)
|
||||
fuzzy_query = fulltext_query(query, group_ids, driver.fulltext_syntax)
|
||||
if fuzzy_query == '':
|
||||
return []
|
||||
filter_query, filter_params = node_search_filter_query_constructor(search_filter)
|
||||
|
|
@ -472,7 +472,7 @@ async def episode_fulltext_search(
|
|||
limit=RELEVANT_SCHEMA_LIMIT,
|
||||
) -> list[EpisodicNode]:
|
||||
# BM25 search to get top episodes
|
||||
fuzzy_query = fulltext_query(query, group_ids)
|
||||
fuzzy_query = fulltext_query(query, group_ids, driver.fulltext_syntax)
|
||||
if fuzzy_query == '':
|
||||
return []
|
||||
|
||||
|
|
@ -516,7 +516,7 @@ async def community_fulltext_search(
|
|||
limit=RELEVANT_SCHEMA_LIMIT,
|
||||
) -> list[CommunityNode]:
|
||||
# BM25 search to get top communities
|
||||
fuzzy_query = fulltext_query(query, group_ids)
|
||||
fuzzy_query = fulltext_query(query, group_ids, driver.fulltext_syntax)
|
||||
if fuzzy_query == '':
|
||||
return []
|
||||
|
||||
|
|
@ -740,7 +740,7 @@ async def get_relevant_nodes(
|
|||
'uuid': node.uuid,
|
||||
'name': node.name,
|
||||
'name_embedding': node.name_embedding,
|
||||
'fulltext_query': fulltext_query(node.name, [node.group_id]),
|
||||
'fulltext_query': fulltext_query(node.name, [node.group_id], driver.fulltext_syntax),
|
||||
}
|
||||
for node in nodes
|
||||
]
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue