add search filters and similarity search

This commit is contained in:
prestonrasmussen 2025-09-07 22:34:15 -04:00
parent a5e69dc8a7
commit 6441be9934
3 changed files with 211 additions and 146 deletions

View file

@ -52,20 +52,16 @@ aoss_indices = [
'dims': 1024,
'index': True,
'similarity': 'cosine',
'method': {
'engine': 'faiss',
'space_type': 'cosinesimil',
'name': 'hnsw',
'parameters': {'ef_construction': 128, 'm': 16},
},
},
}
}
},
'query': {
'query': {'multi_match': {'query': '', 'fields': ['name', 'summary', 'group_id']}},
'size': DEFAULT_SIZE,
'knn': {
'field': 'name_embedding',
'query_vector': [],
'k': DEFAULT_SIZE,
'num_candidates': 100,
},
},
},
{
'index_name': 'communities',
@ -78,10 +74,6 @@ aoss_indices = [
}
}
},
'query': {
'query': {'multi_match': {'query': '', 'fields': ['name', 'group_id']}},
'size': DEFAULT_SIZE,
},
},
{
'index_name': 'episodes',
@ -98,15 +90,6 @@ aoss_indices = [
}
}
},
'query': {
'query': {
'multi_match': {
'query': '',
'fields': ['content', 'source', 'source_description', 'group_id'],
}
},
'size': DEFAULT_SIZE,
},
},
{
'index_name': 'entity_edges',
@ -126,20 +109,16 @@ aoss_indices = [
'dims': 1024,
'index': True,
'similarity': 'cosine',
'method': {
'engine': 'faiss',
'space_type': 'cosinesimil',
'name': 'hnsw',
'parameters': {'ef_construction': 128, 'm': 16},
},
},
}
}
},
'query': {
'query': {'multi_match': {'query': '', 'fields': ['name', 'fact', 'group_id']}},
'size': DEFAULT_SIZE,
'knn': {
'field': 'fact_embedding',
'query_vector': [], # supply vector at runtime
'k': DEFAULT_SIZE,
'num_candidates': 100,
},
},
},
]

View file

@ -234,3 +234,34 @@ def edge_search_filter_query_constructor(
filter_queries.append(expired_at_filter)
return filter_queries, filter_params
def build_aoss_node_filters(group_ids: list[str], search_filters: SearchFilters) -> list[dict]:
filters = [{'term': {'group_id': group_ids}}]
if search_filters.node_labels:
filters.append({'terms': {'node_labels': search_filters.node_labels}})
return filters
def build_aoss_edge_filters(group_ids: list[str], search_filters: SearchFilters) -> list[dict]:
filters = [{'term': {'group_id': group_ids}}]
if search_filters.edge_types:
filters.append({'terms': {'edge_types': search_filters.edge_types}})
for field in ['valid_at', 'invalid_at', 'created_at', 'expired_at']:
ranges = getattr(search_filters, field)
if ranges:
# OR of ANDs
should_clauses = []
for and_group in ranges:
and_filters = []
for df in and_group: # df is a DateFilter
range_query = {'range': {field: {df.op: df.value}}}
and_filters.append(range_query)
should_clauses.append({'bool': {'filter': and_filters}})
filters.append({'bool': {'should': should_clauses, 'minimum_should_match': 1}})
return filters

View file

@ -51,6 +51,8 @@ from graphiti_core.nodes import (
)
from graphiti_core.search.search_filters import (
SearchFilters,
build_aoss_edge_filters,
build_aoss_node_filters,
edge_search_filter_query_constructor,
node_search_filter_query_constructor,
)
@ -200,7 +202,6 @@ async def edge_fulltext_search(
if driver.provider == GraphProvider.NEPTUNE:
res = driver.run_aoss_query('edge_name_and_fact', query) # pyright: ignore reportAttributeAccessIssue
if res['hits']['total']['value'] > 0:
# Calculate Cosine similarity then return the edge ids
input_ids = []
for r in res['hits']['hits']:
input_ids.append({'id': r['_source']['uuid'], 'score': r['_score']})
@ -208,11 +209,11 @@ async def edge_fulltext_search(
# Match the edge ids and return the values
query = (
"""
UNWIND $ids as id
MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
WHERE e.group_id IN $group_ids
AND id(e)=id
"""
UNWIND $ids as id
MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
WHERE e.group_id IN $group_ids
AND id(e)=id
"""
+ filter_query
+ """
AND id(e)=id
@ -245,38 +246,26 @@ async def edge_fulltext_search(
else:
return []
elif driver.aoss_client:
res = driver.run_aoss_query('edge_name_and_fact', query) # pyright: ignore reportAttributeAccessIssue
filters = build_aoss_edge_filters(group_ids, search_filter)
res = driver.aoss_client.search(
index='entity_edges',
routing=group_ids,
_source=['uuid'],
query={
'bool': {
'filter': filters,
'must': [{'match': {'fact': {'query': query, 'operator': 'or'}}}],
}
},
)
if res['hits']['total']['value'] > 0:
# Calculate Cosine similarity then return the edge ids
input_uuids = []
input_uuids = {}
for r in res['hits']['hits']:
input_uuids.append({'uuid': r['_source']['uuid'], 'score': r['_score']})
input_uuids[r['_source']['uuid']] = r['_score']
# Match the edge ids and return the values
query = (
"""
UNWIND $uuids as uuid
MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
WHERE e.group_id IN $group_ids
AND e.uuid=uuid
"""
+ filter_query
+ """
AND e.uuid=uuid
WITH e, id.score as score, startNode(e) AS n, endNode(e) AS m"""
+ get_entity_edge_return_query(driver.provider)
+ """ORDER BY score DESC LIMIT $limit
"""
)
records, _, _ = await driver.execute_query(
query,
query=fuzzy_query,
uuids=input_uuids,
limit=limit,
routing_='r',
**filter_params,
)
# Get edges
entity_edges = await EntityEdge.get_by_uuids(driver, list(input_uuids.keys()))
return entity_edges.sort(key=lambda e: input_uuids.get(e, 0), reverse=True)
else:
return []
else:
@ -353,8 +342,8 @@ async def edge_similarity_search(
if driver.provider == GraphProvider.NEPTUNE:
query = (
"""
MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
"""
MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
"""
+ filter_query
+ """
RETURN DISTINCT id(e) as id, e.fact_embedding as embedding
@ -412,6 +401,30 @@ async def edge_similarity_search(
)
else:
return []
elif driver.aoss_client:
filters = build_aoss_edge_filters(group_ids, search_filter)
res = driver.aoss_client.search(
index='entity_edges',
routing=group_ids,
_source=['uuid'],
knn={
'field': 'fact_embedding',
'query_vector': search_vector,
'k': limit,
'num_candidates': 1000,
},
query={'bool': {'filter': filters}},
)
if res['hits']['total']['value'] > 0:
input_uuids = {}
for r in res['hits']['hits']:
input_uuids[r['_source']['uuid']] = r['_score']
# Get edges
entity_edges = await EntityEdge.get_by_uuids(driver, list(input_uuids.keys()))
return entity_edges.sort(key=lambda e: input_uuids.get(e, 0), reverse=True)
else:
query = (
match_query
@ -598,7 +611,6 @@ async def node_fulltext_search(
if driver.provider == GraphProvider.NEPTUNE:
res = driver.run_aoss_query('node_name_and_summary', query, limit=limit) # pyright: ignore reportAttributeAccessIssue
if res['hits']['total']['value'] > 0:
# Calculate Cosine similarity then return the edge ids
input_ids = []
for r in res['hits']['hits']:
input_ids.append({'id': r['_source']['uuid'], 'score': r['_score']})
@ -606,11 +618,11 @@ async def node_fulltext_search(
# Match the edge ides and return the values
query = (
"""
UNWIND $ids as i
MATCH (n:Entity)
WHERE n.uuid=i.id
RETURN
"""
UNWIND $ids as i
MATCH (n:Entity)
WHERE n.uuid=i.id
RETURN
"""
+ get_entity_node_return_query(driver.provider)
+ """
ORDER BY i.score DESC
@ -628,35 +640,36 @@ async def node_fulltext_search(
else:
return []
elif driver.aoss_client:
res = driver.run_aoss_query('node_name_and_summary', query, limit=limit) # pyright: ignore reportAttributeAccessIssue
if res['hits']['total']['value'] > 0:
# Calculate Cosine similarity then return the edge ids
input_uuids = []
for r in res['hits']['hits']:
input_uuids.append({'uuid': r['_source']['uuid'], 'score': r['_score']})
filters = build_aoss_node_filters(group_ids, search_filter)
res = driver.aoss_client.search(
'entities',
routing=group_ids,
_source=['uuid'],
query={
'bool': {
'filter': filters,
'must': [
{
'multi_match': {
'query': query,
'field': ['name', 'summary'],
'operator': 'or',
}
}
],
}
},
limit=limit,
)
# Match the edge ides and return the values
query = (
"""
UNWIND $uuids as i
MATCH (n:Entity)
WHERE n.uuid=i.uuid
RETURN
"""
+ get_entity_node_return_query(driver.provider)
+ """
ORDER BY i.score DESC
LIMIT $limit
"""
)
records, _, _ = await driver.execute_query(
query,
uuids=input_uuids,
query=fuzzy_query,
limit=limit,
routing_='r',
**filter_params,
)
if res['hits']['total']['value'] > 0:
input_uuids = {}
for r in res['hits']['hits']:
input_uuids[r['_source']['uuid']] = r['_score']
# Get nodes
entities = await EntityNode.get_by_uuids(driver, list(input_uuids.keys()))
return entities.sort(key=lambda e: input_uuids.get(e, 0), reverse=True)
else:
return []
else:
@ -715,8 +728,8 @@ async def node_similarity_search(
if driver.provider == GraphProvider.NEPTUNE:
query = (
"""
MATCH (n:Entity)
"""
MATCH (n:Entity)
"""
+ filter_query
+ """
RETURN DISTINCT id(n) as id, n.name_embedding as embedding
@ -745,11 +758,11 @@ async def node_similarity_search(
# Match the edge ides and return the values
query = (
"""
UNWIND $ids as i
MATCH (n:Entity)
WHERE id(n)=i.id
RETURN
"""
UNWIND $ids as i
MATCH (n:Entity)
WHERE id(n)=i.id
RETURN
"""
+ get_entity_node_return_query(driver.provider)
+ """
ORDER BY i.score DESC
@ -767,11 +780,34 @@ async def node_similarity_search(
)
else:
return []
elif driver.aoss_client:
filters = build_aoss_node_filters(group_ids, search_filter)
res = driver.aoss_client.search(
index='entities',
routing=group_ids,
_source=['uuid'],
knn={
'field': 'fact_embedding',
'query_vector': search_vector,
'k': limit,
'num_candidates': 1000,
},
query={'bool': {'filter': filters}},
)
if res['hits']['total']['value'] > 0:
input_uuids = {}
for r in res['hits']['hits']:
input_uuids[r['_source']['uuid']] = r['_score']
# Get edges
entity_edges = await EntityEdge.get_by_uuids(driver, list(input_uuids.keys()))
return entity_edges.sort(key=lambda e: input_uuids.get(e, 0), reverse=True)
else:
query = (
"""
MATCH (n:Entity)
"""
MATCH (n:Entity)
"""
+ filter_query
+ """
WITH n, """
@ -910,7 +946,6 @@ async def episode_fulltext_search(
if driver.provider == GraphProvider.NEPTUNE:
res = driver.run_aoss_query('episode_content', query, limit=limit) # pyright: ignore reportAttributeAccessIssue
if res['hits']['total']['value'] > 0:
# Calculate Cosine similarity then return the edge ids
input_ids = []
for r in res['hits']['hits']:
input_ids.append({'id': r['_source']['uuid'], 'score': r['_score']})
@ -944,7 +979,27 @@ async def episode_fulltext_search(
else:
return []
elif driver.aoss_client:
res = driver.run_aoss_query('episode_content', query, limit=limit) # pyright: ignore reportAttributeAccessIssue
res = driver.aoss_client.search(
'episodes',
routing=group_ids,
_source=['uuid'],
query={
'bool': {
'filter': [{'term': {'group_id': group_ids}}],
'must': [
{
'multi_match': {
'query': query,
'field': ['name', 'content'],
'operator': 'or',
}
}
],
}
},
limit=limit,
)
if res['hits']['total']['value'] > 0:
# Calculate Cosine similarity then return the edge ids
input_uuids = []
@ -1106,8 +1161,8 @@ async def community_similarity_search(
if driver.provider == GraphProvider.NEPTUNE:
query = (
"""
MATCH (n:Community)
"""
MATCH (n:Community)
"""
+ group_filter_query
+ """
RETURN DISTINCT id(n) as id, n.name_embedding as embedding
@ -1166,8 +1221,8 @@ async def community_similarity_search(
query = (
"""
MATCH (c:Community)
"""
MATCH (c:Community)
"""
+ group_filter_query
+ """
WITH c,
@ -1309,9 +1364,9 @@ async def get_relevant_nodes(
# FIXME: Kuzu currently does not support using variables such as `node.fulltext_query` as an input to FTS, which means `get_relevant_nodes()` won't work with Kuzu as the graph driver.
query = (
"""
UNWIND $nodes AS node
MATCH (n:Entity {group_id: $group_id})
"""
UNWIND $nodes AS node
MATCH (n:Entity {group_id: $group_id})
"""
+ filter_query
+ """
WITH node, n, """
@ -1356,9 +1411,9 @@ async def get_relevant_nodes(
else:
query = (
"""
UNWIND $nodes AS node
MATCH (n:Entity {group_id: $group_id})
"""
UNWIND $nodes AS node
MATCH (n:Entity {group_id: $group_id})
"""
+ filter_query
+ """
WITH node, n, """
@ -1447,9 +1502,9 @@ async def get_relevant_edges(
if driver.provider == GraphProvider.NEPTUNE:
query = (
"""
UNWIND $edges AS edge
MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid})
"""
UNWIND $edges AS edge
MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid})
"""
+ filter_query
+ """
WITH e, edge
@ -1519,9 +1574,9 @@ async def get_relevant_edges(
query = (
"""
UNWIND $edges AS edge
MATCH (n:Entity {uuid: edge.source_node_uuid})-[:RELATES_TO]-(e:RelatesToNode_ {group_id: edge.group_id})-[:RELATES_TO]-(m:Entity {uuid: edge.target_node_uuid})
"""
UNWIND $edges AS edge
MATCH (n:Entity {uuid: edge.source_node_uuid})-[:RELATES_TO]-(e:RelatesToNode_ {group_id: edge.group_id})-[:RELATES_TO]-(m:Entity {uuid: edge.target_node_uuid})
"""
+ filter_query
+ """
WITH e, edge, n, m, """
@ -1557,9 +1612,9 @@ async def get_relevant_edges(
else:
query = (
"""
UNWIND $edges AS edge
MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid})
"""
UNWIND $edges AS edge
MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid})
"""
+ filter_query
+ """
WITH e, edge, """
@ -1632,10 +1687,10 @@ async def get_edge_invalidation_candidates(
if driver.provider == GraphProvider.NEPTUNE:
query = (
"""
UNWIND $edges AS edge
MATCH (n:Entity)-[e:RELATES_TO {group_id: edge.group_id}]->(m:Entity)
WHERE n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid]
"""
UNWIND $edges AS edge
MATCH (n:Entity)-[e:RELATES_TO {group_id: edge.group_id}]->(m:Entity)
WHERE n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid]
"""
+ filter_query
+ """
WITH e, edge
@ -1705,10 +1760,10 @@ async def get_edge_invalidation_candidates(
query = (
"""
UNWIND $edges AS edge
MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_ {group_id: edge.group_id})-[:RELATES_TO]->(m:Entity)
WHERE (n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid])
"""
UNWIND $edges AS edge
MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_ {group_id: edge.group_id})-[:RELATES_TO]->(m:Entity)
WHERE (n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid])
"""
+ filter_query
+ """
WITH edge, e, n, m, """
@ -1744,10 +1799,10 @@ async def get_edge_invalidation_candidates(
else:
query = (
"""
UNWIND $edges AS edge
MATCH (n:Entity)-[e:RELATES_TO {group_id: edge.group_id}]->(m:Entity)
WHERE n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid]
"""
UNWIND $edges AS edge
MATCH (n:Entity)-[e:RELATES_TO {group_id: edge.group_id}]->(m:Entity)
WHERE n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid]
"""
+ filter_query
+ """
WITH edge, e, """