fix-fulltext-syntax-error (#914)
* fix-fulltext-syntax-error * update-abs-method
This commit is contained in:
parent
da71d118db
commit
d725fcdf8e
6 changed files with 123 additions and 9 deletions
|
|
@ -311,3 +311,10 @@ class GraphDriver(ABC):
|
|||
return success if failed == 0 else success
|
||||
|
||||
return 0
|
||||
|
||||
def build_fulltext_query(self, query: str, group_ids: list[str] | None = None, max_query_length: int = 128) -> str:
|
||||
"""
|
||||
Specific fulltext query builder for database providers.
|
||||
Only implemented by providers that need custom fulltext query building.
|
||||
"""
|
||||
raise NotImplementedError(f"build_fulltext_query not implemented for {self.provider}")
|
||||
|
|
|
|||
|
|
@ -36,6 +36,11 @@ from graphiti_core.utils.datetime_utils import convert_datetimes_to_strings
|
|||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
STOPWORDS = [
|
||||
'a', 'is', 'the', 'an', 'and', 'are', 'as', 'at', 'be', 'but', 'by', 'for',
|
||||
'if', 'in', 'into', 'it', 'no', 'not', 'of', 'on', 'or', 'such', 'that', 'their',
|
||||
'then', 'there', 'these', 'they', 'this', 'to', 'was', 'will', 'with'
|
||||
]
|
||||
|
||||
class FalkorDriverSession(GraphDriverSession):
|
||||
provider = GraphProvider.FALKORDB
|
||||
|
|
@ -167,3 +172,77 @@ class FalkorDriver(GraphDriver):
|
|||
cloned = FalkorDriver(falkor_db=self.client, database=database)
|
||||
|
||||
return cloned
|
||||
|
||||
|
||||
def sanitize(self, query: str) -> str:
|
||||
"""
|
||||
Replace FalkorDB special characters with whitespace.
|
||||
Based on FalkorDB tokenization rules: ,.<>{}[]"':;!@#$%^&*()-+=~
|
||||
"""
|
||||
# FalkorDB separator characters that break text into tokens
|
||||
separator_map = str.maketrans(
|
||||
{
|
||||
',': ' ',
|
||||
'.': ' ',
|
||||
'<': ' ',
|
||||
'>': ' ',
|
||||
'{': ' ',
|
||||
'}': ' ',
|
||||
'[': ' ',
|
||||
']': ' ',
|
||||
'"': ' ',
|
||||
"'": ' ',
|
||||
':': ' ',
|
||||
';': ' ',
|
||||
'!': ' ',
|
||||
'@': ' ',
|
||||
'#': ' ',
|
||||
'$': ' ',
|
||||
'%': ' ',
|
||||
'^': ' ',
|
||||
'&': ' ',
|
||||
'*': ' ',
|
||||
'(': ' ',
|
||||
')': ' ',
|
||||
'-': ' ',
|
||||
'+': ' ',
|
||||
'=': ' ',
|
||||
'~': ' ',
|
||||
'?': ' ',
|
||||
}
|
||||
)
|
||||
sanitized = query.translate(separator_map)
|
||||
# Clean up multiple spaces
|
||||
sanitized = ' '.join(sanitized.split())
|
||||
return sanitized
|
||||
|
||||
def build_fulltext_query(self, query: str, group_ids: list[str] | None = None, max_query_length: int = 128) -> str:
|
||||
"""
|
||||
Build a fulltext query string for FalkorDB using RedisSearch syntax.
|
||||
FalkorDB uses RedisSearch-like syntax where:
|
||||
- Field queries use @ prefix: @field:value
|
||||
- Multiple values for same field: (@field:value1|value2)
|
||||
- Text search doesn't need @ prefix for content fields
|
||||
- AND is implicit with space: (@group_id:value) (text)
|
||||
- OR uses pipe within parentheses: (@group_id:value1|value2)
|
||||
"""
|
||||
if group_ids is None or len(group_ids) == 0:
|
||||
group_filter = ''
|
||||
else:
|
||||
group_values = '|'.join(group_ids)
|
||||
group_filter = f"(@group_id:{group_values})"
|
||||
|
||||
sanitized_query = self.sanitize(query)
|
||||
|
||||
# Remove stopwords from the sanitized query
|
||||
query_words = sanitized_query.split()
|
||||
filtered_words = [word for word in query_words if word.lower() not in STOPWORDS]
|
||||
sanitized_query = ' | '.join(filtered_words)
|
||||
|
||||
# If the query is too long return no query
|
||||
if len(sanitized_query.split(' ')) + len(group_ids or '') >= max_query_length:
|
||||
return ''
|
||||
|
||||
full_query = group_filter + ' (' + sanitized_query + ')'
|
||||
|
||||
return full_query
|
||||
|
|
@ -71,12 +71,38 @@ def get_range_indices(provider: GraphProvider) -> list[LiteralString]:
|
|||
|
||||
def get_fulltext_indices(provider: GraphProvider) -> list[LiteralString]:
|
||||
if provider == GraphProvider.FALKORDB:
|
||||
return [
|
||||
"""CREATE FULLTEXT INDEX FOR (e:Episodic) ON (e.content, e.source, e.source_description, e.group_id)""",
|
||||
"""CREATE FULLTEXT INDEX FOR (n:Entity) ON (n.name, n.summary, n.group_id)""",
|
||||
"""CREATE FULLTEXT INDEX FOR (n:Community) ON (n.name, n.group_id)""",
|
||||
from typing import cast
|
||||
|
||||
from graphiti_core.driver.falkordb_driver import STOPWORDS
|
||||
|
||||
# Convert to string representation for embedding in queries
|
||||
stopwords_str = str(STOPWORDS)
|
||||
|
||||
# Use type: ignore to satisfy LiteralString requirement while maintaining single source of truth
|
||||
return cast(list[LiteralString], [
|
||||
f"""CALL db.idx.fulltext.createNodeIndex(
|
||||
{{
|
||||
label: 'Episodic',
|
||||
stopwords: {stopwords_str}
|
||||
}},
|
||||
'content', 'source', 'source_description', 'group_id'
|
||||
)""",
|
||||
f"""CALL db.idx.fulltext.createNodeIndex(
|
||||
{{
|
||||
label: 'Entity',
|
||||
stopwords: {stopwords_str}
|
||||
}},
|
||||
'name', 'summary', 'group_id'
|
||||
)""",
|
||||
f"""CALL db.idx.fulltext.createNodeIndex(
|
||||
{{
|
||||
label: 'Community',
|
||||
stopwords: {stopwords_str}
|
||||
}},
|
||||
'name', 'group_id'
|
||||
)""",
|
||||
"""CREATE FULLTEXT INDEX FOR ()-[e:RELATES_TO]-() ON (e.name, e.fact, e.group_id)""",
|
||||
]
|
||||
])
|
||||
|
||||
if provider == GraphProvider.KUZU:
|
||||
return [
|
||||
|
|
|
|||
|
|
@ -456,12 +456,12 @@ class Graphiti:
|
|||
start = time()
|
||||
now = utc_now()
|
||||
|
||||
# if group_id is None, use the default group id by the provider
|
||||
group_id = group_id or get_default_group_id(self.driver.provider)
|
||||
validate_entity_types(entity_types)
|
||||
|
||||
validate_excluded_entity_types(excluded_entity_types, entity_types)
|
||||
validate_group_id(group_id)
|
||||
# if group_id is None, use the default group id by the provider
|
||||
group_id = group_id or get_default_group_id(self.driver.provider)
|
||||
|
||||
previous_episodes = (
|
||||
await self.retrieve_episodes(
|
||||
|
|
|
|||
|
|
@ -54,7 +54,7 @@ def get_default_group_id(provider: GraphProvider) -> str:
|
|||
For most databases, the default group id is an empty string, while there are database types that require a specific default group id.
|
||||
"""
|
||||
if provider == GraphProvider.FALKORDB:
|
||||
return '_'
|
||||
return '\\_'
|
||||
else:
|
||||
return ''
|
||||
|
||||
|
|
@ -116,7 +116,7 @@ async def semaphore_gather(
|
|||
return await asyncio.gather(*(_wrap_coroutine(coroutine) for coroutine in coroutines))
|
||||
|
||||
|
||||
def validate_group_id(group_id: str) -> bool:
|
||||
def validate_group_id(group_id: str | None) -> bool:
|
||||
"""
|
||||
Validate that a group_id contains only ASCII alphanumeric characters, dashes, and underscores.
|
||||
|
||||
|
|
|
|||
|
|
@ -92,6 +92,8 @@ def fulltext_query(query: str, group_ids: list[str] | None, driver: GraphDriver)
|
|||
if len(query.split(' ')) > MAX_QUERY_LENGTH:
|
||||
return ''
|
||||
return query
|
||||
elif driver.provider == GraphProvider.FALKORDB:
|
||||
return driver.build_fulltext_query(query, group_ids, MAX_QUERY_LENGTH)
|
||||
group_ids_filter_list = (
|
||||
[driver.fulltext_syntax + f'group_id:"{g}"' for g in group_ids]
|
||||
if group_ids is not None
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue