diff --git a/graphiti_core/driver/driver.py b/graphiti_core/driver/driver.py index 03a58631..7be689b4 100644 --- a/graphiti_core/driver/driver.py +++ b/graphiti_core/driver/driver.py @@ -46,6 +46,7 @@ class GraphDriverSession(ABC): class GraphDriver(ABC): provider: str + fulltext_syntax: str = '' # Neo4j (default) syntax does not require a prefix for fulltext queries @abstractmethod def execute_query(self, cypher_query_: str, **kwargs: Any) -> Coroutine: diff --git a/graphiti_core/driver/falkordb_driver.py b/graphiti_core/driver/falkordb_driver.py index d43d75ec..ed7431e9 100644 --- a/graphiti_core/driver/falkordb_driver.py +++ b/graphiti_core/driver/falkordb_driver.py @@ -97,6 +97,9 @@ class FalkorDriver(GraphDriver): self.client = FalkorDB(host=host, port=port, username=username, password=password) self._database = database + self.fulltext_syntax = '@' # FalkorDB uses a redisearch-like syntax for fulltext queries see https://redis.io/docs/latest/develop/ai/search-and-query/query/full-text/ + + def _get_graph(self, graph_name: str | None) -> FalkorGraph: # FalkorDB requires a non-None database name for multi-tenant graphs; the default is "default_db" if graph_name is None: diff --git a/graphiti_core/graphiti.py b/graphiti_core/graphiti.py index 2ba8e6da..47d72343 100644 --- a/graphiti_core/graphiti.py +++ b/graphiti_core/graphiti.py @@ -30,6 +30,7 @@ from graphiti_core.edges import EntityEdge, EpisodicEdge from graphiti_core.embedder import EmbedderClient, OpenAIEmbedder from graphiti_core.graphiti_types import GraphitiClients from graphiti_core.helpers import ( + get_default_group_id, semaphore_gather, validate_excluded_entity_types, validate_group_id, @@ -352,7 +353,7 @@ class Graphiti: source_description: str, reference_time: datetime, source: EpisodeType = EpisodeType.message, - group_id: str = '', + group_id: str | None = None, uuid: str | None = None, update_communities: bool = False, entity_types: dict[str, BaseModel] | None = None, @@ -420,7 +421,10 @@ class Graphiti: start = time() now = utc_now() + # if group_id is None, use the default group id by the provider + group_id = group_id or get_default_group_id(self.driver.provider) validate_entity_types(entity_types) + validate_excluded_entity_types(excluded_entity_types, entity_types) validate_group_id(group_id) @@ -537,7 +541,7 @@ class Graphiti: async def add_episode_bulk( self, bulk_episodes: list[RawEpisode], - group_id: str = '', + group_id: str | None = None, entity_types: dict[str, BaseModel] | None = None, excluded_entity_types: list[str] | None = None, edge_types: dict[str, BaseModel] | None = None, @@ -583,6 +587,8 @@ class Graphiti: start = time() now = utc_now() + # if group_id is None, use the default group id by the provider + group_id = group_id or get_default_group_id(self.driver.provider) validate_group_id(group_id) # Create default edge type map diff --git a/graphiti_core/helpers.py b/graphiti_core/helpers.py index 2d13ff44..855b364d 100644 --- a/graphiti_core/helpers.py +++ b/graphiti_core/helpers.py @@ -51,6 +51,15 @@ def parse_db_date(neo_date: neo4j_time.DateTime | str | None) -> datetime | None else None ) +def get_default_group_id(db_type: str) -> str: + """ + This function differentiates the default group id based on the database type. + For most databases, the default group id is an empty string, while there are database types that require a specific default group id. + """ + if db_type == 'falkordb': + return '_' + else: + return '' def lucene_sanitize(query: str) -> str: # Escape special characters from a query before passing into Lucene diff --git a/graphiti_core/search/search_utils.py b/graphiti_core/search/search_utils.py index 272c43c5..f5eab407 100644 --- a/graphiti_core/search/search_utils.py +++ b/graphiti_core/search/search_utils.py @@ -60,9 +60,9 @@ MAX_SEARCH_DEPTH = 3 MAX_QUERY_LENGTH = 32 -def fulltext_query(query: str, group_ids: list[str] | None = None): +def fulltext_query(query: str, group_ids: list[str] | None = None, fulltext_syntax: str = ''): group_ids_filter_list = ( - [f'group_id:"{lucene_sanitize(g)}"' for g in group_ids] if group_ids is not None else [] + [fulltext_syntax + f"group_id:'{lucene_sanitize(g)}'" for g in group_ids] if group_ids is not None else [] ) group_ids_filter = '' for f in group_ids_filter_list: @@ -157,7 +157,7 @@ async def edge_fulltext_search( limit=RELEVANT_SCHEMA_LIMIT, ) -> list[EntityEdge]: # fulltext search over facts - fuzzy_query = fulltext_query(query, group_ids) + fuzzy_query = fulltext_query(query, group_ids, driver.fulltext_syntax) if fuzzy_query == '': return [] @@ -340,7 +340,7 @@ async def node_fulltext_search( limit=RELEVANT_SCHEMA_LIMIT, ) -> list[EntityNode]: # BM25 search to get top nodes - fuzzy_query = fulltext_query(query, group_ids) + fuzzy_query = fulltext_query(query, group_ids, driver.fulltext_syntax) if fuzzy_query == '': return [] filter_query, filter_params = node_search_filter_query_constructor(search_filter) @@ -472,7 +472,7 @@ async def episode_fulltext_search( limit=RELEVANT_SCHEMA_LIMIT, ) -> list[EpisodicNode]: # BM25 search to get top episodes - fuzzy_query = fulltext_query(query, group_ids) + fuzzy_query = fulltext_query(query, group_ids, driver.fulltext_syntax) if fuzzy_query == '': return [] @@ -516,7 +516,7 @@ async def community_fulltext_search( limit=RELEVANT_SCHEMA_LIMIT, ) -> list[CommunityNode]: # BM25 search to get top communities - fuzzy_query = fulltext_query(query, group_ids) + fuzzy_query = fulltext_query(query, group_ids, driver.fulltext_syntax) if fuzzy_query == '': return [] @@ -740,7 +740,7 @@ async def get_relevant_nodes( 'uuid': node.uuid, 'name': node.name, 'name_embedding': node.name_embedding, - 'fulltext_query': fulltext_query(node.name, [node.group_id]), + 'fulltext_query': fulltext_query(node.name, [node.group_id], driver.fulltext_syntax), } for node in nodes ]