add search filters and similarity search

This commit is contained in:
prestonrasmussen 2025-09-07 22:34:15 -04:00
parent a5e69dc8a7
commit 6441be9934
3 changed files with 211 additions and 146 deletions

View file

@ -52,20 +52,16 @@ aoss_indices = [
'dims': 1024, 'dims': 1024,
'index': True, 'index': True,
'similarity': 'cosine', 'similarity': 'cosine',
'method': {
'engine': 'faiss',
'space_type': 'cosinesimil',
'name': 'hnsw',
'parameters': {'ef_construction': 128, 'm': 16},
},
}, },
} }
} }
}, },
'query': {
'query': {'multi_match': {'query': '', 'fields': ['name', 'summary', 'group_id']}},
'size': DEFAULT_SIZE,
'knn': {
'field': 'name_embedding',
'query_vector': [],
'k': DEFAULT_SIZE,
'num_candidates': 100,
},
},
}, },
{ {
'index_name': 'communities', 'index_name': 'communities',
@ -78,10 +74,6 @@ aoss_indices = [
} }
} }
}, },
'query': {
'query': {'multi_match': {'query': '', 'fields': ['name', 'group_id']}},
'size': DEFAULT_SIZE,
},
}, },
{ {
'index_name': 'episodes', 'index_name': 'episodes',
@ -98,15 +90,6 @@ aoss_indices = [
} }
} }
}, },
'query': {
'query': {
'multi_match': {
'query': '',
'fields': ['content', 'source', 'source_description', 'group_id'],
}
},
'size': DEFAULT_SIZE,
},
}, },
{ {
'index_name': 'entity_edges', 'index_name': 'entity_edges',
@ -126,20 +109,16 @@ aoss_indices = [
'dims': 1024, 'dims': 1024,
'index': True, 'index': True,
'similarity': 'cosine', 'similarity': 'cosine',
'method': {
'engine': 'faiss',
'space_type': 'cosinesimil',
'name': 'hnsw',
'parameters': {'ef_construction': 128, 'm': 16},
},
}, },
} }
} }
}, },
'query': {
'query': {'multi_match': {'query': '', 'fields': ['name', 'fact', 'group_id']}},
'size': DEFAULT_SIZE,
'knn': {
'field': 'fact_embedding',
'query_vector': [], # supply vector at runtime
'k': DEFAULT_SIZE,
'num_candidates': 100,
},
},
}, },
] ]

View file

@ -234,3 +234,34 @@ def edge_search_filter_query_constructor(
filter_queries.append(expired_at_filter) filter_queries.append(expired_at_filter)
return filter_queries, filter_params return filter_queries, filter_params
def build_aoss_node_filters(group_ids: list[str], search_filters: SearchFilters) -> list[dict]:
    """Build the filter clauses for an OpenSearch query over entity nodes.

    The returned list is meant to be placed under ``bool.filter`` of a search
    request body, so every clause must match for a document to be returned.

    Args:
        group_ids: Group identifiers a node may belong to; a node matches when
            its ``group_id`` equals any one of them.
        search_filters: Filter settings; only ``node_labels`` is read here.

    Returns:
        A list of OpenSearch query clauses (plain dicts).
    """
    # `terms` (plural) is the multi-value form of the query. The singular
    # `term` query accepts a single value and cannot take a list.
    filters: list[dict] = [{'terms': {'group_id': group_ids}}]

    # Only restrict by label when the caller actually asked for some.
    if search_filters.node_labels:
        filters.append({'terms': {'node_labels': search_filters.node_labels}})

    return filters
def build_aoss_edge_filters(group_ids: list[str], search_filters: SearchFilters) -> list[dict]:
    """Build the filter clauses for an OpenSearch query over entity edges.

    The returned list is meant to be placed under ``bool.filter`` of a search
    request body, so every clause must match for a document to be returned.

    Args:
        group_ids: Group identifiers an edge may belong to; an edge matches
            when its ``group_id`` equals any one of them.
        search_filters: Filter settings. ``edge_types`` and the four date
            attributes ``valid_at``, ``invalid_at``, ``created_at`` and
            ``expired_at`` are read. Each date attribute, when set, is a list
            of groups of date filters: filters inside one group are ANDed and
            the groups are ORed.

    Returns:
        A list of OpenSearch query clauses (plain dicts).
    """
    # `terms` (plural) is the multi-value form of the query. The singular
    # `term` query accepts a single value and cannot take a list.
    filters: list[dict] = [{'terms': {'group_id': group_ids}}]

    if search_filters.edge_types:
        filters.append({'terms': {'edge_types': search_filters.edge_types}})

    for field in ('valid_at', 'invalid_at', 'created_at', 'expired_at'):
        ranges = getattr(search_filters, field)
        if not ranges:
            continue

        # OR of ANDs: one `bool.filter` per inner group (all of its range
        # conditions must hold), combined by `should` with at least one match.
        # NOTE(review): assumes each date filter exposes `op` as a range-query
        # operator key (e.g. 'gte', 'lt') and `value` as the bound -- confirm
        # against the DateFilter model.
        should_clauses = [
            {'bool': {'filter': [{'range': {field: {df.op: df.value}}} for df in and_group]}}
            for and_group in ranges
        ]
        filters.append({'bool': {'should': should_clauses, 'minimum_should_match': 1}})

    return filters

View file

@ -51,6 +51,8 @@ from graphiti_core.nodes import (
) )
from graphiti_core.search.search_filters import ( from graphiti_core.search.search_filters import (
SearchFilters, SearchFilters,
build_aoss_edge_filters,
build_aoss_node_filters,
edge_search_filter_query_constructor, edge_search_filter_query_constructor,
node_search_filter_query_constructor, node_search_filter_query_constructor,
) )
@ -200,7 +202,6 @@ async def edge_fulltext_search(
if driver.provider == GraphProvider.NEPTUNE: if driver.provider == GraphProvider.NEPTUNE:
res = driver.run_aoss_query('edge_name_and_fact', query) # pyright: ignore reportAttributeAccessIssue res = driver.run_aoss_query('edge_name_and_fact', query) # pyright: ignore reportAttributeAccessIssue
if res['hits']['total']['value'] > 0: if res['hits']['total']['value'] > 0:
# Calculate Cosine similarity then return the edge ids
input_ids = [] input_ids = []
for r in res['hits']['hits']: for r in res['hits']['hits']:
input_ids.append({'id': r['_source']['uuid'], 'score': r['_score']}) input_ids.append({'id': r['_source']['uuid'], 'score': r['_score']})
@ -245,38 +246,26 @@ async def edge_fulltext_search(
else: else:
return [] return []
elif driver.aoss_client: elif driver.aoss_client:
res = driver.run_aoss_query('edge_name_and_fact', query) # pyright: ignore reportAttributeAccessIssue filters = build_aoss_edge_filters(group_ids, search_filter)
res = driver.aoss_client.search(
index='entity_edges',
routing=group_ids,
_source=['uuid'],
query={
'bool': {
'filter': filters,
'must': [{'match': {'fact': {'query': query, 'operator': 'or'}}}],
}
},
)
if res['hits']['total']['value'] > 0: if res['hits']['total']['value'] > 0:
# Calculate Cosine similarity then return the edge ids input_uuids = {}
input_uuids = []
for r in res['hits']['hits']: for r in res['hits']['hits']:
input_uuids.append({'uuid': r['_source']['uuid'], 'score': r['_score']}) input_uuids[r['_source']['uuid']] = r['_score']
# Match the edge ids and return the values # Get edges
query = ( entity_edges = await EntityEdge.get_by_uuids(driver, list(input_uuids.keys()))
""" return entity_edges.sort(key=lambda e: input_uuids.get(e, 0), reverse=True)
UNWIND $uuids as uuid
MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
WHERE e.group_id IN $group_ids
AND e.uuid=uuid
"""
+ filter_query
+ """
AND e.uuid=uuid
WITH e, id.score as score, startNode(e) AS n, endNode(e) AS m"""
+ get_entity_edge_return_query(driver.provider)
+ """ORDER BY score DESC LIMIT $limit
"""
)
records, _, _ = await driver.execute_query(
query,
query=fuzzy_query,
uuids=input_uuids,
limit=limit,
routing_='r',
**filter_params,
)
else: else:
return [] return []
else: else:
@ -412,6 +401,30 @@ async def edge_similarity_search(
) )
else: else:
return [] return []
elif driver.aoss_client:
filters = build_aoss_edge_filters(group_ids, search_filter)
res = driver.aoss_client.search(
index='entity_edges',
routing=group_ids,
_source=['uuid'],
knn={
'field': 'fact_embedding',
'query_vector': search_vector,
'k': limit,
'num_candidates': 1000,
},
query={'bool': {'filter': filters}},
)
if res['hits']['total']['value'] > 0:
input_uuids = {}
for r in res['hits']['hits']:
input_uuids[r['_source']['uuid']] = r['_score']
# Get edges
entity_edges = await EntityEdge.get_by_uuids(driver, list(input_uuids.keys()))
return entity_edges.sort(key=lambda e: input_uuids.get(e, 0), reverse=True)
else: else:
query = ( query = (
match_query match_query
@ -598,7 +611,6 @@ async def node_fulltext_search(
if driver.provider == GraphProvider.NEPTUNE: if driver.provider == GraphProvider.NEPTUNE:
res = driver.run_aoss_query('node_name_and_summary', query, limit=limit) # pyright: ignore reportAttributeAccessIssue res = driver.run_aoss_query('node_name_and_summary', query, limit=limit) # pyright: ignore reportAttributeAccessIssue
if res['hits']['total']['value'] > 0: if res['hits']['total']['value'] > 0:
# Calculate Cosine similarity then return the edge ids
input_ids = [] input_ids = []
for r in res['hits']['hits']: for r in res['hits']['hits']:
input_ids.append({'id': r['_source']['uuid'], 'score': r['_score']}) input_ids.append({'id': r['_source']['uuid'], 'score': r['_score']})
@ -628,35 +640,36 @@ async def node_fulltext_search(
else: else:
return [] return []
elif driver.aoss_client: elif driver.aoss_client:
res = driver.run_aoss_query('node_name_and_summary', query, limit=limit) # pyright: ignore reportAttributeAccessIssue filters = build_aoss_node_filters(group_ids, search_filter)
if res['hits']['total']['value'] > 0: res = driver.aoss_client.search(
# Calculate Cosine similarity then return the edge ids 'entities',
input_uuids = [] routing=group_ids,
for r in res['hits']['hits']: _source=['uuid'],
input_uuids.append({'uuid': r['_source']['uuid'], 'score': r['_score']}) query={
'bool': {
# Match the edge ides and return the values 'filter': filters,
query = ( 'must': [
""" {
UNWIND $uuids as i 'multi_match': {
MATCH (n:Entity) 'query': query,
WHERE n.uuid=i.uuid 'field': ['name', 'summary'],
RETURN 'operator': 'or',
""" }
+ get_entity_node_return_query(driver.provider) }
+ """ ],
ORDER BY i.score DESC }
LIMIT $limit },
"""
)
records, _, _ = await driver.execute_query(
query,
uuids=input_uuids,
query=fuzzy_query,
limit=limit, limit=limit,
routing_='r',
**filter_params,
) )
if res['hits']['total']['value'] > 0:
input_uuids = {}
for r in res['hits']['hits']:
input_uuids[r['_source']['uuid']] = r['_score']
# Get nodes
entities = await EntityNode.get_by_uuids(driver, list(input_uuids.keys()))
return entities.sort(key=lambda e: input_uuids.get(e, 0), reverse=True)
else: else:
return [] return []
else: else:
@ -767,6 +780,29 @@ async def node_similarity_search(
) )
else: else:
return [] return []
elif driver.aoss_client:
filters = build_aoss_node_filters(group_ids, search_filter)
res = driver.aoss_client.search(
index='entities',
routing=group_ids,
_source=['uuid'],
knn={
'field': 'fact_embedding',
'query_vector': search_vector,
'k': limit,
'num_candidates': 1000,
},
query={'bool': {'filter': filters}},
)
if res['hits']['total']['value'] > 0:
input_uuids = {}
for r in res['hits']['hits']:
input_uuids[r['_source']['uuid']] = r['_score']
# Get edges
entity_edges = await EntityEdge.get_by_uuids(driver, list(input_uuids.keys()))
return entity_edges.sort(key=lambda e: input_uuids.get(e, 0), reverse=True)
else: else:
query = ( query = (
""" """
@ -910,7 +946,6 @@ async def episode_fulltext_search(
if driver.provider == GraphProvider.NEPTUNE: if driver.provider == GraphProvider.NEPTUNE:
res = driver.run_aoss_query('episode_content', query, limit=limit) # pyright: ignore reportAttributeAccessIssue res = driver.run_aoss_query('episode_content', query, limit=limit) # pyright: ignore reportAttributeAccessIssue
if res['hits']['total']['value'] > 0: if res['hits']['total']['value'] > 0:
# Calculate Cosine similarity then return the edge ids
input_ids = [] input_ids = []
for r in res['hits']['hits']: for r in res['hits']['hits']:
input_ids.append({'id': r['_source']['uuid'], 'score': r['_score']}) input_ids.append({'id': r['_source']['uuid'], 'score': r['_score']})
@ -944,7 +979,27 @@ async def episode_fulltext_search(
else: else:
return [] return []
elif driver.aoss_client: elif driver.aoss_client:
res = driver.run_aoss_query('episode_content', query, limit=limit) # pyright: ignore reportAttributeAccessIssue res = driver.aoss_client.search(
'episodes',
routing=group_ids,
_source=['uuid'],
query={
'bool': {
'filter': [{'term': {'group_id': group_ids}}],
'must': [
{
'multi_match': {
'query': query,
'field': ['name', 'content'],
'operator': 'or',
}
}
],
}
},
limit=limit,
)
if res['hits']['total']['value'] > 0: if res['hits']['total']['value'] > 0:
# Calculate Cosine similarity then return the edge ids # Calculate Cosine similarity then return the edge ids
input_uuids = [] input_uuids = []