update-driver

This commit is contained in:
Gal Shubeli 2025-07-30 18:18:44 +03:00
parent 88f7fbc4a5
commit 9ce7f1b8f1
5 changed files with 90 additions and 50 deletions

View file

@ -77,3 +77,6 @@ class GraphDriver(ABC):
cloned._database = database
return cloned
def sanitize(self, query: str) -> str:
    """Sanitize *query* before it is handed to the driver's fulltext engine.

    Abstract hook: each concrete driver overrides this with its own
    engine-specific escaping/stripping (e.g. Lucene escaping for Neo4j,
    separator stripping for FalkorDB).

    Args:
        query: raw fulltext query string.

    Returns:
        The sanitized query string (in overrides).

    Raises:
        NotImplementedError: always, on the abstract base class.
    """
    raise NotImplementedError()

View file

@ -164,6 +164,48 @@ class FalkorDriver(GraphDriver):
cloned = FalkorDriver(falkor_db=self.client, database=database)
return cloned
def sanitize(self, query: str) -> str:
    """
    Replace FalkorDB special characters with whitespace.
    Based on FalkorDB tokenization rules: ,.<>{}[]"':;!@#$%^&*()-+=~
    """
    # Characters the FalkorDB fulltext tokenizer treats as separators.
    separators = ',.<>{}[]"\':;!@#$%^&*()-+=~'
    # Map every separator to a single space in one C-level pass.
    blank_out = str.maketrans(separators, ' ' * len(separators))
    stripped = query.translate(blank_out)
    # Collapse the runs of whitespace the substitution introduced.
    return ' '.join(stripped.split())
def convert_datetimes_to_strings(obj):
@ -177,3 +219,4 @@ def convert_datetimes_to_strings(obj):
return obj.isoformat()
else:
return obj

View file

@ -60,3 +60,39 @@ class Neo4jDriver(GraphDriver):
return self.client.execute_query(
'CALL db.indexes() YIELD name DROP INDEX name',
)
def sanitize(self, query: str) -> str:
    """Escape special characters from a query before passing into Lucene.

    Special characters: + - && || ! ( ) { } [ ] ^ " ~ * ? : \ /
    The capital letters O R N T A D are escaped as well — presumably to
    neutralize the Lucene boolean operators OR / AND / NOT; confirm
    against the fulltext index semantics before changing.
    """
    specials = '+-&|!(){}[]^"~*?:\\/ORNTAD'
    # Prefix every special character with a backslash in one pass.
    escape_table = str.maketrans({ch: '\\' + ch for ch in specials})
    return query.translate(escape_table)

View file

@ -62,47 +62,6 @@ def get_default_group_id(db_type: str) -> str:
else:
return ''
def lucene_sanitize(query: str) -> str:
    """Escape special characters from a query before passing into Lucene.

    Special characters: + - && || ! ( ) { } [ ] ^ " ~ * ? : \ / @ %
    plus single quotes, and the capital letters O R N T A D —
    presumably to neutralize the OR / AND / NOT boolean operators;
    confirm before altering.
    """
    specials = '+-&|!(){}[]^"\'~*?:\\/@%ORNTAD'
    # One translate() pass prefixes every special character with a backslash.
    return query.translate(str.maketrans({ch: '\\' + ch for ch in specials}))
def normalize_l2(embedding: list[float]) -> NDArray:
embedding_array = np.array(embedding)
norm = np.linalg.norm(embedding_array, 2, axis=0, keepdims=True)

View file

@ -32,7 +32,6 @@ from graphiti_core.graph_queries import (
)
from graphiti_core.helpers import (
RUNTIME_QUERY,
lucene_sanitize,
normalize_l2,
semaphore_gather,
)
@ -60,10 +59,10 @@ MAX_SEARCH_DEPTH = 3
MAX_QUERY_LENGTH = 128
def fulltext_query(query: str, group_ids: list[str] | None = None, fulltext_syntax: str = ''):
def fulltext_query(query: str, driver: GraphDriver, group_ids: list[str] | None = None):
# Lucene expects string values (e.g. group_id) to be wrapped in single quotes
group_ids_filter_list = (
[fulltext_syntax + f"group_id:'{g}'" for g in group_ids] if group_ids is not None else []
[driver.fulltext_syntax + f"group_id:'{g}'" for g in group_ids] if group_ids is not None else []
)
group_ids_filter = ''
for f in group_ids_filter_list:
@ -71,7 +70,7 @@ def fulltext_query(query: str, group_ids: list[str] | None = None, fulltext_synt
group_ids_filter += ' AND ' if group_ids_filter else ''
lucene_query = lucene_sanitize(query)
lucene_query = driver.sanitize(query)
# If the lucene query is too long return no query
if len(lucene_query.split(' ')) + len(group_ids or '') >= MAX_QUERY_LENGTH:
return ''
@ -158,7 +157,7 @@ async def edge_fulltext_search(
limit=RELEVANT_SCHEMA_LIMIT,
) -> list[EntityEdge]:
# fulltext search over facts
fuzzy_query = fulltext_query(query, group_ids, driver.fulltext_syntax)
fuzzy_query = fulltext_query(query, driver, group_ids)
if fuzzy_query == '':
return []
@ -344,7 +343,7 @@ async def node_fulltext_search(
limit=RELEVANT_SCHEMA_LIMIT,
) -> list[EntityNode]:
# BM25 search to get top nodes
fuzzy_query = fulltext_query(query, group_ids, driver.fulltext_syntax)
fuzzy_query = fulltext_query(query, driver, group_ids)
if fuzzy_query == '':
return []
filter_query, filter_params = node_search_filter_query_constructor(search_filter)
@ -479,7 +478,7 @@ async def episode_fulltext_search(
limit=RELEVANT_SCHEMA_LIMIT,
) -> list[EpisodicNode]:
# BM25 search to get top episodes
fuzzy_query = fulltext_query(query, group_ids, driver.fulltext_syntax)
fuzzy_query = fulltext_query(query, driver, group_ids)
if fuzzy_query == '':
return []
@ -524,7 +523,7 @@ async def community_fulltext_search(
limit=RELEVANT_SCHEMA_LIMIT,
) -> list[CommunityNode]:
# BM25 search to get top communities
fuzzy_query = fulltext_query(query, group_ids, driver.fulltext_syntax)
fuzzy_query = fulltext_query(query, driver, group_ids)
if fuzzy_query == '':
return []
@ -749,7 +748,7 @@ async def get_relevant_nodes(
'uuid': node.uuid,
'name': node.name,
'name_embedding': node.name_embedding,
'fulltext_query': fulltext_query(node.name, [node.group_id], driver.fulltext_syntax),
'fulltext_query': fulltext_query(node.name, driver, [node.group_id]),
}
for node in nodes
]