update-driver
This commit is contained in:
parent
88f7fbc4a5
commit
9ce7f1b8f1
5 changed files with 90 additions and 50 deletions
|
|
@ -77,3 +77,6 @@ class GraphDriver(ABC):
|
|||
cloned._database = database
|
||||
|
||||
return cloned
|
||||
|
||||
def sanitize(self, query: str) -> str:
|
||||
raise NotImplementedError()
|
||||
|
|
|
|||
|
|
@ -164,6 +164,48 @@ class FalkorDriver(GraphDriver):
|
|||
cloned = FalkorDriver(falkor_db=self.client, database=database)
|
||||
|
||||
return cloned
|
||||
|
||||
def sanitize(self, query: str) -> str:
    """Strip FalkorDB full-text separator characters from *query*.

    FalkorDB tokenizes text on ,.<>{}[]"':;!@#$%^&*()-+=~ so every one of
    those characters is turned into whitespace, then runs of whitespace are
    collapsed to a single space.
    """
    # Separator characters that break text into tokens, replaced in one
    # C-level pass via str.translate.
    separators = ',.<>{}[]"\':;!@#$%^&*()-+=~'
    table = str.maketrans({ch: ' ' for ch in separators})
    # Collapse any consecutive whitespace left behind by the replacement.
    return ' '.join(query.translate(table).split())
|
||||
|
||||
|
||||
|
||||
def convert_datetimes_to_strings(obj):
|
||||
|
|
@ -177,3 +219,4 @@ def convert_datetimes_to_strings(obj):
|
|||
return obj.isoformat()
|
||||
else:
|
||||
return obj
|
||||
|
||||
|
|
@ -60,3 +60,39 @@ class Neo4jDriver(GraphDriver):
|
|||
return self.client.execute_query(
|
||||
'CALL db.indexes() YIELD name DROP INDEX name',
|
||||
)
|
||||
|
||||
def sanitize(self, query: str) -> str:
|
||||
# Escape special characters from a query before passing into Lucene
|
||||
# + - && || ! ( ) { } [ ] ^ " ~ * ? : \ /
|
||||
escape_map = str.maketrans(
|
||||
{
|
||||
'+': r'\+',
|
||||
'-': r'\-',
|
||||
'&': r'\&',
|
||||
'|': r'\|',
|
||||
'!': r'\!',
|
||||
'(': r'\(',
|
||||
')': r'\)',
|
||||
'{': r'\{',
|
||||
'}': r'\}',
|
||||
'[': r'\[',
|
||||
']': r'\]',
|
||||
'^': r'\^',
|
||||
'"': r'\"',
|
||||
'~': r'\~',
|
||||
'*': r'\*',
|
||||
'?': r'\?',
|
||||
':': r'\:',
|
||||
'\\': r'\\',
|
||||
'/': r'\/',
|
||||
'O': r'\O',
|
||||
'R': r'\R',
|
||||
'N': r'\N',
|
||||
'T': r'\T',
|
||||
'A': r'\A',
|
||||
'D': r'\D',
|
||||
}
|
||||
)
|
||||
|
||||
sanitized = query.translate(escape_map)
|
||||
return sanitized
|
||||
|
|
|
|||
|
|
@ -62,47 +62,6 @@ def get_default_group_id(db_type: str) -> str:
|
|||
else:
|
||||
return ''
|
||||
|
||||
|
||||
def lucene_sanitize(query: str) -> str:
    """Escape characters that Lucene treats specially before querying.

    Covers the Lucene special set + - && || ! ( ) { } [ ] ^ " ~ * ? : \\ /
    plus ' @ %, and the uppercase letters that form AND / OR / NOT so those
    boolean operator keywords lose their meaning.
    """
    specials = '+-&|!(){}[]^"\'~*?:\\/@%ORNTAD'
    # Map every special character c to '\' + c in one translation table.
    table = str.maketrans({ch: '\\' + ch for ch in specials})
    return query.translate(table)
|
||||
|
||||
|
||||
def normalize_l2(embedding: list[float]) -> NDArray:
|
||||
embedding_array = np.array(embedding)
|
||||
norm = np.linalg.norm(embedding_array, 2, axis=0, keepdims=True)
|
||||
|
|
|
|||
|
|
@ -32,7 +32,6 @@ from graphiti_core.graph_queries import (
|
|||
)
|
||||
from graphiti_core.helpers import (
|
||||
RUNTIME_QUERY,
|
||||
lucene_sanitize,
|
||||
normalize_l2,
|
||||
semaphore_gather,
|
||||
)
|
||||
|
|
@ -60,10 +59,10 @@ MAX_SEARCH_DEPTH = 3
|
|||
MAX_QUERY_LENGTH = 128
|
||||
|
||||
|
||||
def fulltext_query(query: str, group_ids: list[str] | None = None, fulltext_syntax: str = ''):
|
||||
def fulltext_query(query: str, driver: GraphDriver, group_ids: list[str] | None = None):
|
||||
# Lucene expects string values (e.g. group_id) to be wrapped in single quotes
|
||||
group_ids_filter_list = (
|
||||
[fulltext_syntax + f"group_id:'{g}'" for g in group_ids] if group_ids is not None else []
|
||||
[driver.fulltext_syntax + f"group_id:'{g}'" for g in group_ids] if group_ids is not None else []
|
||||
)
|
||||
group_ids_filter = ''
|
||||
for f in group_ids_filter_list:
|
||||
|
|
@ -71,7 +70,7 @@ def fulltext_query(query: str, group_ids: list[str] | None = None, fulltext_synt
|
|||
|
||||
group_ids_filter += ' AND ' if group_ids_filter else ''
|
||||
|
||||
lucene_query = lucene_sanitize(query)
|
||||
lucene_query = driver.sanitize(query)
|
||||
# If the lucene query is too long return no query
|
||||
if len(lucene_query.split(' ')) + len(group_ids or '') >= MAX_QUERY_LENGTH:
|
||||
return ''
|
||||
|
|
@ -158,7 +157,7 @@ async def edge_fulltext_search(
|
|||
limit=RELEVANT_SCHEMA_LIMIT,
|
||||
) -> list[EntityEdge]:
|
||||
# fulltext search over facts
|
||||
fuzzy_query = fulltext_query(query, group_ids, driver.fulltext_syntax)
|
||||
fuzzy_query = fulltext_query(query, driver, group_ids)
|
||||
if fuzzy_query == '':
|
||||
return []
|
||||
|
||||
|
|
@ -344,7 +343,7 @@ async def node_fulltext_search(
|
|||
limit=RELEVANT_SCHEMA_LIMIT,
|
||||
) -> list[EntityNode]:
|
||||
# BM25 search to get top nodes
|
||||
fuzzy_query = fulltext_query(query, group_ids, driver.fulltext_syntax)
|
||||
fuzzy_query = fulltext_query(query, driver, group_ids)
|
||||
if fuzzy_query == '':
|
||||
return []
|
||||
filter_query, filter_params = node_search_filter_query_constructor(search_filter)
|
||||
|
|
@ -479,7 +478,7 @@ async def episode_fulltext_search(
|
|||
limit=RELEVANT_SCHEMA_LIMIT,
|
||||
) -> list[EpisodicNode]:
|
||||
# BM25 search to get top episodes
|
||||
fuzzy_query = fulltext_query(query, group_ids, driver.fulltext_syntax)
|
||||
fuzzy_query = fulltext_query(query, driver, group_ids)
|
||||
if fuzzy_query == '':
|
||||
return []
|
||||
|
||||
|
|
@ -524,7 +523,7 @@ async def community_fulltext_search(
|
|||
limit=RELEVANT_SCHEMA_LIMIT,
|
||||
) -> list[CommunityNode]:
|
||||
# BM25 search to get top communities
|
||||
fuzzy_query = fulltext_query(query, group_ids, driver.fulltext_syntax)
|
||||
fuzzy_query = fulltext_query(query, driver, group_ids)
|
||||
if fuzzy_query == '':
|
||||
return []
|
||||
|
||||
|
|
@ -749,7 +748,7 @@ async def get_relevant_nodes(
|
|||
'uuid': node.uuid,
|
||||
'name': node.name,
|
||||
'name_embedding': node.name_embedding,
|
||||
'fulltext_query': fulltext_query(node.name, [node.group_id], driver.fulltext_syntax),
|
||||
'fulltext_query': fulltext_query(node.name, driver, [node.group_id]),
|
||||
}
|
||||
for node in nodes
|
||||
]
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue