[Bug Fix] Fix the Group ID usage with FalkorDB (#733)

* groupid-none

* groupid-def-fulltext

* lint

* Update graphiti_core/helpers.py

Co-authored-by: ellipsis-dev[bot] <65095814+ellipsis-dev[bot]@users.noreply.github.com>

---------

Co-authored-by: ellipsis-dev[bot] <65095814+ellipsis-dev[bot]@users.noreply.github.com>
This commit is contained in:
Gal Shubeli 2025-07-17 19:35:08 +03:00 committed by GitHub
parent d96f362875
commit 35e0692328
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 28 additions and 9 deletions

View file

@ -46,6 +46,7 @@ class GraphDriverSession(ABC):
class GraphDriver(ABC):
provider: str
fulltext_syntax: str = '' # Neo4j (default) syntax does not require a prefix for fulltext queries
@abstractmethod
def execute_query(self, cypher_query_: str, **kwargs: Any) -> Coroutine:

View file

@ -97,6 +97,9 @@ class FalkorDriver(GraphDriver):
self.client = FalkorDB(host=host, port=port, username=username, password=password)
self._database = database
self.fulltext_syntax = '@' # FalkorDB uses a redisearch-like syntax for fulltext queries see https://redis.io/docs/latest/develop/ai/search-and-query/query/full-text/
def _get_graph(self, graph_name: str | None) -> FalkorGraph:
# FalkorDB requires a non-None database name for multi-tenant graphs; the default is "default_db"
if graph_name is None:

View file

@ -30,6 +30,7 @@ from graphiti_core.edges import EntityEdge, EpisodicEdge
from graphiti_core.embedder import EmbedderClient, OpenAIEmbedder
from graphiti_core.graphiti_types import GraphitiClients
from graphiti_core.helpers import (
get_default_group_id,
semaphore_gather,
validate_excluded_entity_types,
validate_group_id,
@ -352,7 +353,7 @@ class Graphiti:
source_description: str,
reference_time: datetime,
source: EpisodeType = EpisodeType.message,
group_id: str = '',
group_id: str | None = None,
uuid: str | None = None,
update_communities: bool = False,
entity_types: dict[str, BaseModel] | None = None,
@ -420,7 +421,10 @@ class Graphiti:
start = time()
now = utc_now()
# if group_id is None or empty, fall back to the provider's default group id
group_id = group_id or get_default_group_id(self.driver.provider)
validate_entity_types(entity_types)
validate_excluded_entity_types(excluded_entity_types, entity_types)
validate_group_id(group_id)
@ -537,7 +541,7 @@ class Graphiti:
async def add_episode_bulk(
self,
bulk_episodes: list[RawEpisode],
group_id: str = '',
group_id: str | None = None,
entity_types: dict[str, BaseModel] | None = None,
excluded_entity_types: list[str] | None = None,
edge_types: dict[str, BaseModel] | None = None,
@ -583,6 +587,8 @@ class Graphiti:
start = time()
now = utc_now()
# if group_id is None or empty, fall back to the provider's default group id
group_id = group_id or get_default_group_id(self.driver.provider)
validate_group_id(group_id)
# Create default edge type map

View file

@ -51,6 +51,15 @@ def parse_db_date(neo_date: neo4j_time.DateTime | str | None) -> datetime | None
else None
)
def get_default_group_id(db_type: str) -> str:
    """Return the default group id to use for the given database type.

    FalkorDB requires a specific default group id (``'_'``); every other
    database type uses the empty string.
    """
    # Exact, case-sensitive match on the provider name.
    return '_' if db_type == 'falkordb' else ''
def lucene_sanitize(query: str) -> str:
# Escape special characters from a query before passing into Lucene

View file

@ -60,9 +60,9 @@ MAX_SEARCH_DEPTH = 3
MAX_QUERY_LENGTH = 32
def fulltext_query(query: str, group_ids: list[str] | None = None):
def fulltext_query(query: str, group_ids: list[str] | None = None, fulltext_syntax: str = ''):
group_ids_filter_list = (
[f'group_id:"{lucene_sanitize(g)}"' for g in group_ids] if group_ids is not None else []
[fulltext_syntax + f"group_id:'{lucene_sanitize(g)}'" for g in group_ids] if group_ids is not None else []
)
group_ids_filter = ''
for f in group_ids_filter_list:
@ -157,7 +157,7 @@ async def edge_fulltext_search(
limit=RELEVANT_SCHEMA_LIMIT,
) -> list[EntityEdge]:
# fulltext search over facts
fuzzy_query = fulltext_query(query, group_ids)
fuzzy_query = fulltext_query(query, group_ids, driver.fulltext_syntax)
if fuzzy_query == '':
return []
@ -340,7 +340,7 @@ async def node_fulltext_search(
limit=RELEVANT_SCHEMA_LIMIT,
) -> list[EntityNode]:
# BM25 search to get top nodes
fuzzy_query = fulltext_query(query, group_ids)
fuzzy_query = fulltext_query(query, group_ids, driver.fulltext_syntax)
if fuzzy_query == '':
return []
filter_query, filter_params = node_search_filter_query_constructor(search_filter)
@ -472,7 +472,7 @@ async def episode_fulltext_search(
limit=RELEVANT_SCHEMA_LIMIT,
) -> list[EpisodicNode]:
# BM25 search to get top episodes
fuzzy_query = fulltext_query(query, group_ids)
fuzzy_query = fulltext_query(query, group_ids, driver.fulltext_syntax)
if fuzzy_query == '':
return []
@ -516,7 +516,7 @@ async def community_fulltext_search(
limit=RELEVANT_SCHEMA_LIMIT,
) -> list[CommunityNode]:
# BM25 search to get top communities
fuzzy_query = fulltext_query(query, group_ids)
fuzzy_query = fulltext_query(query, group_ids, driver.fulltext_syntax)
if fuzzy_query == '':
return []
@ -740,7 +740,7 @@ async def get_relevant_nodes(
'uuid': node.uuid,
'name': node.name,
'name_embedding': node.name_embedding,
'fulltext_query': fulltext_query(node.name, [node.group_id]),
'fulltext_query': fulltext_query(node.name, [node.group_id], driver.fulltext_syntax),
}
for node in nodes
]