diff --git a/graphiti_core/driver/driver.py b/graphiti_core/driver/driver.py index 4efe230a..663b7859 100644 --- a/graphiti_core/driver/driver.py +++ b/graphiti_core/driver/driver.py @@ -77,3 +77,6 @@ class GraphDriver(ABC): cloned._database = database return cloned + + def sanitize(self, query: str) -> str: + raise NotImplementedError() diff --git a/graphiti_core/driver/falkordb_driver.py b/graphiti_core/driver/falkordb_driver.py index acf2c66f..46e14c7e 100644 --- a/graphiti_core/driver/falkordb_driver.py +++ b/graphiti_core/driver/falkordb_driver.py @@ -164,6 +164,48 @@ class FalkorDriver(GraphDriver): cloned = FalkorDriver(falkor_db=self.client, database=database) return cloned + + def sanitize(self, query: str) -> str: + """ + Replace FalkorDB special characters with whitespace. + Based on FalkorDB tokenization rules: ,.<>{}[]"':;!@#$%^&*()-+=~ + """ + # FalkorDB separator characters that break text into tokens + separator_map = str.maketrans( + { + ',': ' ', + '.': ' ', + '<': ' ', + '>': ' ', + '{': ' ', + '}': ' ', + '[': ' ', + ']': ' ', + '"': ' ', + "'": ' ', + ':': ' ', + ';': ' ', + '!': ' ', + '@': ' ', + '#': ' ', + '$': ' ', + '%': ' ', + '^': ' ', + '&': ' ', + '*': ' ', + '(': ' ', + ')': ' ', + '-': ' ', + '+': ' ', + '=': ' ', + '~': ' ', + } + ) + sanitized = query.translate(separator_map) + # Clean up multiple spaces + sanitized = ' '.join(sanitized.split()) + return sanitized + def convert_datetimes_to_strings(obj): diff --git a/graphiti_core/driver/neo4j_driver.py b/graphiti_core/driver/neo4j_driver.py index bd82e8d9..95c7f9ce 100644 --- a/graphiti_core/driver/neo4j_driver.py +++ b/graphiti_core/driver/neo4j_driver.py @@ -60,3 +60,39 @@ class Neo4jDriver(GraphDriver): return self.client.execute_query( 'CALL db.indexes() YIELD name DROP INDEX name', ) + + def sanitize(self, query: str) -> str: + # Escape special 
characters from a query before passing into Lucene + # + - && || ! ( ) { } [ ] ^ " ~ * ? : \ / + escape_map = str.maketrans( + { + '+': r'\+', + '-': r'\-', + '&': r'\&', + '|': r'\|', + '!': r'\!', + '(': r'\(', + ')': r'\)', + '{': r'\{', + '}': r'\}', + '[': r'\[', + ']': r'\]', + '^': r'\^', + '"': r'\"', + '~': r'\~', + '*': r'\*', + '?': r'\?', + ':': r'\:', + '\\': r'\\', + '/': r'\/', + 'O': r'\O', + 'R': r'\R', + 'N': r'\N', + 'T': r'\T', + 'A': r'\A', + 'D': r'\D', + } + ) + + sanitized = query.translate(escape_map) + return sanitized diff --git a/graphiti_core/helpers.py b/graphiti_core/helpers.py index 7792d90c..64ec9eaa 100644 --- a/graphiti_core/helpers.py +++ b/graphiti_core/helpers.py @@ -62,47 +62,6 @@ def get_default_group_id(db_type: str) -> str: else: return '' - -def lucene_sanitize(query: str) -> str: - # Escape special characters from a query before passing into Lucene - # + - && || ! ( ) { } [ ] ^ " ~ * ? : \ / - escape_map = str.maketrans( - { - '+': r'\+', - '-': r'\-', - '&': r'\&', - '|': r'\|', - '!': r'\!', - '(': r'\(', - ')': r'\)', - '{': r'\{', - '}': r'\}', - '[': r'\[', - ']': r'\]', - '^': r'\^', - '"': r'\"', - "'": r"\'", - '~': r'\~', - '*': r'\*', - '?': r'\?', - ':': r'\:', - '\\': r'\\', - '/': r'\/', - '@': r'\@', - '%': r'\%', - 'O': r'\O', - 'R': r'\R', - 'N': r'\N', - 'T': r'\T', - 'A': r'\A', - 'D': r'\D', - } - ) - - sanitized = query.translate(escape_map) - return sanitized - - def normalize_l2(embedding: list[float]) -> NDArray: embedding_array = np.array(embedding) norm = np.linalg.norm(embedding_array, 2, axis=0, keepdims=True) diff --git a/graphiti_core/search/search_utils.py b/graphiti_core/search/search_utils.py index cca7e3b8..1ac67396 100644 --- a/graphiti_core/search/search_utils.py +++ b/graphiti_core/search/search_utils.py @@ -32,7 +32,6 @@ from graphiti_core.graph_queries import ( ) from graphiti_core.helpers import ( RUNTIME_QUERY, - lucene_sanitize, normalize_l2, semaphore_gather, ) @@ -60,10 +59,10 @@ 
MAX_SEARCH_DEPTH = 3 MAX_QUERY_LENGTH = 128 -def fulltext_query(query: str, group_ids: list[str] | None = None, fulltext_syntax: str = ''): +def fulltext_query(query: str, driver: GraphDriver, group_ids: list[str] | None = None): # Lucene expects string values (e.g. group_id) to be wrapped in single quotes group_ids_filter_list = ( - [fulltext_syntax + f"group_id:'{g}'" for g in group_ids] if group_ids is not None else [] + [driver.fulltext_syntax + f"group_id:'{g}'" for g in group_ids] if group_ids is not None else [] ) group_ids_filter = '' for f in group_ids_filter_list: @@ -71,7 +70,7 @@ def fulltext_query(query: str, group_ids: list[str] | None = None, fulltext_synt group_ids_filter += ' AND ' if group_ids_filter else '' - lucene_query = lucene_sanitize(query) + lucene_query = driver.sanitize(query) # If the lucene query is too long return no query if len(lucene_query.split(' ')) + len(group_ids or '') >= MAX_QUERY_LENGTH: return '' @@ -158,7 +157,7 @@ async def edge_fulltext_search( limit=RELEVANT_SCHEMA_LIMIT, ) -> list[EntityEdge]: # fulltext search over facts - fuzzy_query = fulltext_query(query, group_ids, driver.fulltext_syntax) + fuzzy_query = fulltext_query(query, driver, group_ids) if fuzzy_query == '': return [] @@ -344,7 +343,7 @@ async def node_fulltext_search( limit=RELEVANT_SCHEMA_LIMIT, ) -> list[EntityNode]: # BM25 search to get top nodes - fuzzy_query = fulltext_query(query, group_ids, driver.fulltext_syntax) + fuzzy_query = fulltext_query(query, driver, group_ids) if fuzzy_query == '': return [] filter_query, filter_params = node_search_filter_query_constructor(search_filter) @@ -479,7 +478,7 @@ async def episode_fulltext_search( limit=RELEVANT_SCHEMA_LIMIT, ) -> list[EpisodicNode]: # BM25 search to get top episodes - fuzzy_query = fulltext_query(query, group_ids, driver.fulltext_syntax) + fuzzy_query = fulltext_query(query, driver, group_ids) if fuzzy_query == '': return [] @@ -524,7 +523,7 @@ async def community_fulltext_search( 
limit=RELEVANT_SCHEMA_LIMIT, ) -> list[CommunityNode]: # BM25 search to get top communities - fuzzy_query = fulltext_query(query, group_ids, driver.fulltext_syntax) + fuzzy_query = fulltext_query(query, driver, group_ids) if fuzzy_query == '': return [] @@ -749,7 +748,7 @@ async def get_relevant_nodes( 'uuid': node.uuid, 'name': node.name, 'name_embedding': node.name_embedding, - 'fulltext_query': fulltext_query(node.name, [node.group_id], driver.fulltext_syntax), + 'fulltext_query': fulltext_query(node.name, driver, [node.group_id]), } for node in nodes ]