update-driver
This commit is contained in:
parent
88f7fbc4a5
commit
9ce7f1b8f1
5 changed files with 90 additions and 50 deletions
|
|
@ -77,3 +77,6 @@ class GraphDriver(ABC):
|
||||||
cloned._database = database
|
cloned._database = database
|
||||||
|
|
||||||
return cloned
|
return cloned
|
||||||
|
|
||||||
|
def sanitize(self, query: str) -> str:
    """
    Prepare ``query`` for use in a backend fulltext search.

    The base implementation is only a placeholder: it always raises.
    Concrete drivers supply their own version with the escaping or
    stripping rules of their search engine.

    Args:
        query: Raw search text.

    Returns:
        The sanitized query string (in overriding implementations).

    Raises:
        NotImplementedError: Always, in this base implementation.
    """
    raise NotImplementedError()
|
||||||
|
|
|
||||||
|
|
@ -164,6 +164,48 @@ class FalkorDriver(GraphDriver):
|
||||||
cloned = FalkorDriver(falkor_db=self.client, database=database)
|
cloned = FalkorDriver(falkor_db=self.client, database=database)
|
||||||
|
|
||||||
return cloned
|
return cloned
|
||||||
|
|
||||||
|
def sanitize(self, query: str) -> str:
    """
    Replace FalkorDB special characters with whitespace.

    Based on FalkorDB tokenization rules: ,.<>{}[]"':;!@#$%^&*()-+=~

    Args:
        query: Raw search text.

    Returns:
        ``query`` with every separator character turned into a space,
        runs of whitespace collapsed to a single space, and leading and
        trailing whitespace removed.
    """
    # FalkorDB separator characters that break text into tokens
    separators = ',.<>{}[]"\':;!@#$%^&*()-+=~'
    # Two equal-length strings give a one-to-one char -> space table.
    separator_map = str.maketrans(separators, ' ' * len(separators))
    # split() with no argument discards empty fields, which cleans up the
    # multiple spaces left behind by adjacent separators.
    return ' '.join(query.translate(separator_map).split())
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def convert_datetimes_to_strings(obj):
|
def convert_datetimes_to_strings(obj):
|
||||||
|
|
@ -177,3 +219,4 @@ def convert_datetimes_to_strings(obj):
|
||||||
return obj.isoformat()
|
return obj.isoformat()
|
||||||
else:
|
else:
|
||||||
return obj
|
return obj
|
||||||
|
|
||||||
|
|
@ -60,3 +60,39 @@ class Neo4jDriver(GraphDriver):
|
||||||
return self.client.execute_query(
|
return self.client.execute_query(
|
||||||
'CALL db.indexes() YIELD name DROP INDEX name',
|
'CALL db.indexes() YIELD name DROP INDEX name',
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def sanitize(self, query: str) -> str:
    """
    Escape special characters in a query before passing it into Lucene.

    Each of the characters + - & | ! ( ) { } [ ] ^ " ~ * ? : \\ / is
    prefixed with a backslash. The capital letters O, R, N, T, A and D
    are prefixed as well (presumably so AND / OR / NOT in user text are
    not read as boolean operators -- confirm with the search backend).

    Args:
        query: Raw search text.

    Returns:
        ``query`` with a backslash inserted before every listed character;
        all other characters are left unchanged.
    """
    # Every replacement is simply "\" + the character, so the table is
    # derived from one string instead of being spelled out entry by entry.
    special_chars = '+-&|!(){}[]^"~*?:\\/ORNTAD'
    escape_map = str.maketrans({ch: '\\' + ch for ch in special_chars})
    return query.translate(escape_map)
|
||||||
|
|
|
||||||
|
|
@ -62,47 +62,6 @@ def get_default_group_id(db_type: str) -> str:
|
||||||
else:
|
else:
|
||||||
return ''
|
return ''
|
||||||
|
|
||||||
|
|
||||||
def lucene_sanitize(query: str) -> str:
    """
    Escape special characters in a query before passing it into Lucene.

    Each of the characters + - & | ! ( ) { } [ ] ^ " ' ~ * ? : \\ / @ % is
    prefixed with a backslash. The capital letters O, R, N, T, A and D
    are prefixed as well (presumably so AND / OR / NOT in user text are
    not read as boolean operators -- confirm with the search backend).

    Args:
        query: Raw search text.

    Returns:
        ``query`` with a backslash inserted before every listed character;
        all other characters are left unchanged.
    """
    # Every replacement is simply "\" + the character, so the table is
    # derived from one string instead of being spelled out entry by entry.
    special_chars = '+-&|!(){}[]^"\'~*?:\\/@%ORNTAD'
    escape_map = str.maketrans({ch: '\\' + ch for ch in special_chars})
    return query.translate(escape_map)
|
|
||||||
|
|
||||||
|
|
||||||
def normalize_l2(embedding: list[float]) -> NDArray:
|
def normalize_l2(embedding: list[float]) -> NDArray:
|
||||||
embedding_array = np.array(embedding)
|
embedding_array = np.array(embedding)
|
||||||
norm = np.linalg.norm(embedding_array, 2, axis=0, keepdims=True)
|
norm = np.linalg.norm(embedding_array, 2, axis=0, keepdims=True)
|
||||||
|
|
|
||||||
|
|
@ -32,7 +32,6 @@ from graphiti_core.graph_queries import (
|
||||||
)
|
)
|
||||||
from graphiti_core.helpers import (
|
from graphiti_core.helpers import (
|
||||||
RUNTIME_QUERY,
|
RUNTIME_QUERY,
|
||||||
lucene_sanitize,
|
|
||||||
normalize_l2,
|
normalize_l2,
|
||||||
semaphore_gather,
|
semaphore_gather,
|
||||||
)
|
)
|
||||||
|
|
@ -60,10 +59,10 @@ MAX_SEARCH_DEPTH = 3
|
||||||
MAX_QUERY_LENGTH = 128
|
MAX_QUERY_LENGTH = 128
|
||||||
|
|
||||||
|
|
||||||
def fulltext_query(query: str, group_ids: list[str] | None = None, fulltext_syntax: str = ''):
|
def fulltext_query(query: str, driver: GraphDriver, group_ids: list[str] | None = None):
|
||||||
# Lucene expects string values (e.g. group_id) to be wrapped in single quotes
|
# Lucene expects string values (e.g. group_id) to be wrapped in single quotes
|
||||||
group_ids_filter_list = (
|
group_ids_filter_list = (
|
||||||
[fulltext_syntax + f"group_id:'{g}'" for g in group_ids] if group_ids is not None else []
|
[driver.fulltext_syntax + f"group_id:'{g}'" for g in group_ids] if group_ids is not None else []
|
||||||
)
|
)
|
||||||
group_ids_filter = ''
|
group_ids_filter = ''
|
||||||
for f in group_ids_filter_list:
|
for f in group_ids_filter_list:
|
||||||
|
|
@ -71,7 +70,7 @@ def fulltext_query(query: str, group_ids: list[str] | None = None, fulltext_synt
|
||||||
|
|
||||||
group_ids_filter += ' AND ' if group_ids_filter else ''
|
group_ids_filter += ' AND ' if group_ids_filter else ''
|
||||||
|
|
||||||
lucene_query = lucene_sanitize(query)
|
lucene_query = driver.sanitize(query)
|
||||||
# If the lucene query is too long return no query
|
# If the lucene query is too long return no query
|
||||||
if len(lucene_query.split(' ')) + len(group_ids or '') >= MAX_QUERY_LENGTH:
|
if len(lucene_query.split(' ')) + len(group_ids or '') >= MAX_QUERY_LENGTH:
|
||||||
return ''
|
return ''
|
||||||
|
|
@ -158,7 +157,7 @@ async def edge_fulltext_search(
|
||||||
limit=RELEVANT_SCHEMA_LIMIT,
|
limit=RELEVANT_SCHEMA_LIMIT,
|
||||||
) -> list[EntityEdge]:
|
) -> list[EntityEdge]:
|
||||||
# fulltext search over facts
|
# fulltext search over facts
|
||||||
fuzzy_query = fulltext_query(query, group_ids, driver.fulltext_syntax)
|
fuzzy_query = fulltext_query(query, driver, group_ids)
|
||||||
if fuzzy_query == '':
|
if fuzzy_query == '':
|
||||||
return []
|
return []
|
||||||
|
|
||||||
|
|
@ -344,7 +343,7 @@ async def node_fulltext_search(
|
||||||
limit=RELEVANT_SCHEMA_LIMIT,
|
limit=RELEVANT_SCHEMA_LIMIT,
|
||||||
) -> list[EntityNode]:
|
) -> list[EntityNode]:
|
||||||
# BM25 search to get top nodes
|
# BM25 search to get top nodes
|
||||||
fuzzy_query = fulltext_query(query, group_ids, driver.fulltext_syntax)
|
fuzzy_query = fulltext_query(query, driver, group_ids)
|
||||||
if fuzzy_query == '':
|
if fuzzy_query == '':
|
||||||
return []
|
return []
|
||||||
filter_query, filter_params = node_search_filter_query_constructor(search_filter)
|
filter_query, filter_params = node_search_filter_query_constructor(search_filter)
|
||||||
|
|
@ -479,7 +478,7 @@ async def episode_fulltext_search(
|
||||||
limit=RELEVANT_SCHEMA_LIMIT,
|
limit=RELEVANT_SCHEMA_LIMIT,
|
||||||
) -> list[EpisodicNode]:
|
) -> list[EpisodicNode]:
|
||||||
# BM25 search to get top episodes
|
# BM25 search to get top episodes
|
||||||
fuzzy_query = fulltext_query(query, group_ids, driver.fulltext_syntax)
|
fuzzy_query = fulltext_query(query, driver, group_ids)
|
||||||
if fuzzy_query == '':
|
if fuzzy_query == '':
|
||||||
return []
|
return []
|
||||||
|
|
||||||
|
|
@ -524,7 +523,7 @@ async def community_fulltext_search(
|
||||||
limit=RELEVANT_SCHEMA_LIMIT,
|
limit=RELEVANT_SCHEMA_LIMIT,
|
||||||
) -> list[CommunityNode]:
|
) -> list[CommunityNode]:
|
||||||
# BM25 search to get top communities
|
# BM25 search to get top communities
|
||||||
fuzzy_query = fulltext_query(query, group_ids, driver.fulltext_syntax)
|
fuzzy_query = fulltext_query(query, driver, group_ids)
|
||||||
if fuzzy_query == '':
|
if fuzzy_query == '':
|
||||||
return []
|
return []
|
||||||
|
|
||||||
|
|
@ -749,7 +748,7 @@ async def get_relevant_nodes(
|
||||||
'uuid': node.uuid,
|
'uuid': node.uuid,
|
||||||
'name': node.name,
|
'name': node.name,
|
||||||
'name_embedding': node.name_embedding,
|
'name_embedding': node.name_embedding,
|
||||||
'fulltext_query': fulltext_query(node.name, [node.group_id], driver.fulltext_syntax),
|
'fulltext_query': fulltext_query(node.name, driver, [node.group_id]),
|
||||||
}
|
}
|
||||||
for node in nodes
|
for node in nodes
|
||||||
]
|
]
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue