update-fulltext-syntax
This commit is contained in:
parent
f4cf5e07e1
commit
f2824fbabd
7 changed files with 107 additions and 33 deletions
|
|
@ -86,3 +86,18 @@ class GraphDriver(ABC):
|
||||||
|
|
||||||
def sanitize(self, query: str) -> str:
|
def sanitize(self, query: str) -> str:
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def build_fulltext_query(self, query: str, group_ids: list[str] | None = None, max_query_length: int = 128) -> str:
|
||||||
|
"""
|
||||||
|
Build a fulltext query string appropriate for this graph provider.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query: The search query text
|
||||||
|
group_ids: Optional list of group IDs to filter by
|
||||||
|
max_query_length: Maximum allowed query length
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A formatted fulltext query string for this provider
|
||||||
|
"""
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
|
||||||
|
|
@ -36,6 +36,11 @@ from graphiti_core.driver.driver import GraphDriver, GraphDriverSession, GraphPr
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Words ignored by FalkorDB fulltext search. This must remain a plain list in
# this exact order: its repr is interpolated verbatim into the
# `stopwords:` option of the fulltext index creation statements.
STOPWORDS = (
    'a is the an and are as at be but by for '
    'if in into it no not of on or such that their '
    'then there these they this to was will with'
).split()
|
||||||
|
|
||||||
class FalkorDriverSession(GraphDriverSession):
|
class FalkorDriverSession(GraphDriverSession):
|
||||||
def __init__(self, graph: FalkorGraph):
|
def __init__(self, graph: FalkorGraph):
|
||||||
|
|
@ -199,6 +204,7 @@ class FalkorDriver(GraphDriver):
|
||||||
'+': ' ',
|
'+': ' ',
|
||||||
'=': ' ',
|
'=': ' ',
|
||||||
'~': ' ',
|
'~': ' ',
|
||||||
|
'?': ' ',
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
sanitized = query.translate(separator_map)
|
sanitized = query.translate(separator_map)
|
||||||
|
|
@ -206,6 +212,37 @@ class FalkorDriver(GraphDriver):
|
||||||
sanitized = ' '.join(sanitized.split())
|
sanitized = ' '.join(sanitized.split())
|
||||||
return sanitized
|
return sanitized
|
||||||
|
|
||||||
|
def build_fulltext_query(
    self, query: str, group_ids: list[str] | None = None, max_query_length: int = 128
) -> str:
    """
    Build a fulltext query string for FalkorDB using RedisSearch syntax.

    FalkorDB uses RedisSearch-like syntax where:
    - Field queries use @ prefix: @field:value
    - Multiple values for same field: (@field:value1|value2)
    - Text search doesn't need @ prefix for content fields
    - AND is implicit with space: (@group_id:value) (text)
    - OR uses pipe within parentheses: (@group_id:value1|value2)

    Args:
        query: The search query text. It is passed through ``self.sanitize``
            and stripped of stopwords before being used.
        group_ids: Optional list of group IDs to filter by.
        max_query_length: Maximum allowed number of search terms plus group IDs.

    Returns:
        The fulltext query string, or '' (the "no query" value callers check
        for) when no search term is left or the query is too long.
    """
    sanitized_query = self.sanitize(query)

    # Remove stopwords from the sanitized query (compared case-insensitively)
    search_terms = [word for word in sanitized_query.split() if word.lower() not in STOPWORDS]

    # An empty or all-stopword query would otherwise produce an empty '()' group,
    # which is not a meaningful search expression; report "no query" instead.
    if not search_terms:
        return ''

    # If the query is too long return no query. Count the terms themselves:
    # splitting the ' | '-joined string would count every separator as a word.
    if len(search_terms) + len(group_ids or []) >= max_query_length:
        return ''

    # Any of the given groups may match; no filter at all when none were given
    group_filter = f"(@group_id:{'|'.join(group_ids)})" if group_ids else ''

    return group_filter + ' (' + ' | '.join(search_terms) + ')'
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def convert_datetimes_to_strings(obj):
|
def convert_datetimes_to_strings(obj):
|
||||||
|
|
|
||||||
|
|
@ -100,3 +100,27 @@ class Neo4jDriver(GraphDriver):
|
||||||
|
|
||||||
sanitized = query.translate(escape_map)
|
sanitized = query.translate(escape_map)
|
||||||
return sanitized
|
return sanitized
|
||||||
|
|
||||||
|
def build_fulltext_query(self, query: str, group_ids: list[str] | None = None, max_query_length: int = 128) -> str:
|
||||||
|
"""
|
||||||
|
Build a fulltext query string for Neo4j.
|
||||||
|
Neo4j uses Lucene syntax where string values need to be wrapped in single quotes.
|
||||||
|
"""
|
||||||
|
# Lucene expects string values (e.g. group_id) to be wrapped in single quotes
|
||||||
|
group_ids_filter_list = (
|
||||||
|
[self.fulltext_syntax + f'group_id:"{g}"' for g in group_ids] if group_ids is not None else []
|
||||||
|
)
|
||||||
|
group_ids_filter = ''
|
||||||
|
for f in group_ids_filter_list:
|
||||||
|
group_ids_filter += f if not group_ids_filter else f' OR {f}'
|
||||||
|
|
||||||
|
group_ids_filter += ' AND ' if group_ids_filter else ''
|
||||||
|
|
||||||
|
lucene_query = self.sanitize(query)
|
||||||
|
# If the lucene query is too long return no query
|
||||||
|
if len(lucene_query.split(' ')) + len(group_ids or '') >= max_query_length:
|
||||||
|
return ''
|
||||||
|
|
||||||
|
full_query = group_ids_filter + '(' + lucene_query + ')'
|
||||||
|
|
||||||
|
return full_query
|
||||||
|
|
|
||||||
|
|
@ -58,12 +58,31 @@ def get_range_indices(provider: GraphProvider) -> list[LiteralString]:
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
def get_fulltext_indices(provider: GraphProvider) -> list[LiteralString]:
|
def get_fulltext_indices(provider: GraphProvider) -> list[str]:
|
||||||
if provider == GraphProvider.FALKORDB:
|
if provider == GraphProvider.FALKORDB:
|
||||||
|
from graphiti_core.driver.falkordb_driver import STOPWORDS
|
||||||
return [
|
return [
|
||||||
"""CREATE FULLTEXT INDEX FOR (e:Episodic) ON (e.content, e.source, e.source_description, e.group_id)""",
|
f"""CALL db.idx.fulltext.createNodeIndex(
|
||||||
"""CREATE FULLTEXT INDEX FOR (n:Entity) ON (n.name, n.summary, n.group_id)""",
|
{{
|
||||||
"""CREATE FULLTEXT INDEX FOR (n:Community) ON (n.name, n.group_id)""",
|
label: 'Episodic',
|
||||||
|
stopwords: {STOPWORDS}
|
||||||
|
}},
|
||||||
|
'content', 'source', 'source_description', 'group_id'
|
||||||
|
)""",
|
||||||
|
f"""CALL db.idx.fulltext.createNodeIndex(
|
||||||
|
{{
|
||||||
|
label: 'Entity',
|
||||||
|
stopwords: {STOPWORDS}
|
||||||
|
}},
|
||||||
|
'name', 'summary', 'group_id'
|
||||||
|
)""",
|
||||||
|
f"""CALL db.idx.fulltext.createNodeIndex(
|
||||||
|
{{
|
||||||
|
label: 'Community',
|
||||||
|
stopwords: {STOPWORDS}
|
||||||
|
}},
|
||||||
|
'name', 'group_id'
|
||||||
|
)""",
|
||||||
"""CREATE FULLTEXT INDEX FOR ()-[e:RELATES_TO]-() ON (e.name, e.fact, e.group_id)""",
|
"""CREATE FULLTEXT INDEX FOR ()-[e:RELATES_TO]-() ON (e.name, e.fact, e.group_id)""",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -424,12 +424,12 @@ class Graphiti:
|
||||||
start = time()
|
start = time()
|
||||||
now = utc_now()
|
now = utc_now()
|
||||||
|
|
||||||
# if group_id is None, use the default group id by the provider
|
|
||||||
group_id = group_id or get_default_group_id(self.driver.provider)
|
|
||||||
validate_entity_types(entity_types)
|
validate_entity_types(entity_types)
|
||||||
|
|
||||||
validate_excluded_entity_types(excluded_entity_types, entity_types)
|
validate_excluded_entity_types(excluded_entity_types, entity_types)
|
||||||
validate_group_id(group_id)
|
validate_group_id(group_id)
|
||||||
|
# if group_id is None, use the default group id by the provider
|
||||||
|
group_id = group_id or get_default_group_id(self.driver.provider)
|
||||||
|
|
||||||
previous_episodes = (
|
previous_episodes = (
|
||||||
await self.retrieve_episodes(
|
await self.retrieve_episodes(
|
||||||
|
|
|
||||||
|
|
@ -59,7 +59,7 @@ def get_default_group_id(provider: GraphProvider) -> str:
|
||||||
For most databases, the default group id is an empty string, while there are database types that require a specific default group id.
|
For most databases, the default group id is an empty string, while there are database types that require a specific default group id.
|
||||||
"""
|
"""
|
||||||
if provider == GraphProvider.FALKORDB:
|
if provider == GraphProvider.FALKORDB:
|
||||||
return '_'
|
return '\\_'
|
||||||
else:
|
else:
|
||||||
return ''
|
return ''
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -61,27 +61,6 @@ MAX_SEARCH_DEPTH = 3
|
||||||
MAX_QUERY_LENGTH = 128
|
MAX_QUERY_LENGTH = 128
|
||||||
|
|
||||||
|
|
||||||
def fulltext_query(query: str, driver: GraphDriver, group_ids: list[str] | None = None):
|
|
||||||
# Lucene expects string values (e.g. group_id) to be wrapped in single quotes
|
|
||||||
group_ids_filter_list = (
|
|
||||||
[driver.fulltext_syntax + f"group_id:'{g}'" for g in group_ids] if group_ids is not None else []
|
|
||||||
)
|
|
||||||
group_ids_filter = ''
|
|
||||||
for f in group_ids_filter_list:
|
|
||||||
group_ids_filter += f if not group_ids_filter else f' OR {f}'
|
|
||||||
|
|
||||||
group_ids_filter += ' AND ' if group_ids_filter else ''
|
|
||||||
|
|
||||||
lucene_query = driver.sanitize(query)
|
|
||||||
# If the lucene query is too long return no query
|
|
||||||
if len(lucene_query.split(' ')) + len(group_ids or '') >= MAX_QUERY_LENGTH:
|
|
||||||
return ''
|
|
||||||
|
|
||||||
full_query = group_ids_filter + '(' + lucene_query + ')'
|
|
||||||
|
|
||||||
return full_query
|
|
||||||
|
|
||||||
|
|
||||||
async def get_episodes_by_mentions(
|
async def get_episodes_by_mentions(
|
||||||
driver: GraphDriver,
|
driver: GraphDriver,
|
||||||
nodes: list[EntityNode],
|
nodes: list[EntityNode],
|
||||||
|
|
@ -147,7 +126,7 @@ async def edge_fulltext_search(
|
||||||
limit=RELEVANT_SCHEMA_LIMIT,
|
limit=RELEVANT_SCHEMA_LIMIT,
|
||||||
) -> list[EntityEdge]:
|
) -> list[EntityEdge]:
|
||||||
# fulltext search over facts
|
# fulltext search over facts
|
||||||
fuzzy_query = fulltext_query(query, driver, group_ids)
|
fuzzy_query = driver.build_fulltext_query(query, group_ids, MAX_QUERY_LENGTH)
|
||||||
if fuzzy_query == '':
|
if fuzzy_query == '':
|
||||||
return []
|
return []
|
||||||
|
|
||||||
|
|
@ -304,7 +283,7 @@ async def node_fulltext_search(
|
||||||
limit=RELEVANT_SCHEMA_LIMIT,
|
limit=RELEVANT_SCHEMA_LIMIT,
|
||||||
) -> list[EntityNode]:
|
) -> list[EntityNode]:
|
||||||
# BM25 search to get top nodes
|
# BM25 search to get top nodes
|
||||||
fuzzy_query = fulltext_query(query, driver, group_ids)
|
fuzzy_query = driver.build_fulltext_query(query, group_ids, MAX_QUERY_LENGTH)
|
||||||
if fuzzy_query == '':
|
if fuzzy_query == '':
|
||||||
return []
|
return []
|
||||||
filter_query, filter_params = node_search_filter_query_constructor(search_filter)
|
filter_query, filter_params = node_search_filter_query_constructor(search_filter)
|
||||||
|
|
@ -445,7 +424,7 @@ async def episode_fulltext_search(
|
||||||
limit=RELEVANT_SCHEMA_LIMIT,
|
limit=RELEVANT_SCHEMA_LIMIT,
|
||||||
) -> list[EpisodicNode]:
|
) -> list[EpisodicNode]:
|
||||||
# BM25 search to get top episodes
|
# BM25 search to get top episodes
|
||||||
fuzzy_query = fulltext_query(query, driver, group_ids)
|
fuzzy_query = driver.build_fulltext_query(query, group_ids, MAX_QUERY_LENGTH)
|
||||||
if fuzzy_query == '':
|
if fuzzy_query == '':
|
||||||
return []
|
return []
|
||||||
|
|
||||||
|
|
@ -484,7 +463,7 @@ async def community_fulltext_search(
|
||||||
limit=RELEVANT_SCHEMA_LIMIT,
|
limit=RELEVANT_SCHEMA_LIMIT,
|
||||||
) -> list[CommunityNode]:
|
) -> list[CommunityNode]:
|
||||||
# BM25 search to get top communities
|
# BM25 search to get top communities
|
||||||
fuzzy_query = fulltext_query(query, driver, group_ids)
|
fuzzy_query = driver.build_fulltext_query(query, group_ids, MAX_QUERY_LENGTH)
|
||||||
if fuzzy_query == '':
|
if fuzzy_query == '':
|
||||||
return []
|
return []
|
||||||
|
|
||||||
|
|
@ -704,7 +683,7 @@ async def get_relevant_nodes(
|
||||||
'uuid': node.uuid,
|
'uuid': node.uuid,
|
||||||
'name': node.name,
|
'name': node.name,
|
||||||
'name_embedding': node.name_embedding,
|
'name_embedding': node.name_embedding,
|
||||||
'fulltext_query': fulltext_query(node.name, driver, [node.group_id]),
|
'fulltext_query': driver.build_fulltext_query(node.name, [node.group_id], MAX_QUERY_LENGTH),
|
||||||
}
|
}
|
||||||
for node in nodes
|
for node in nodes
|
||||||
]
|
]
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue