add search filters and similarity search
This commit is contained in:
parent
a5e69dc8a7
commit
6441be9934
3 changed files with 211 additions and 146 deletions
|
|
@ -52,20 +52,16 @@ aoss_indices = [
|
||||||
'dims': 1024,
|
'dims': 1024,
|
||||||
'index': True,
|
'index': True,
|
||||||
'similarity': 'cosine',
|
'similarity': 'cosine',
|
||||||
|
'method': {
|
||||||
|
'engine': 'faiss',
|
||||||
|
'space_type': 'cosinesimil',
|
||||||
|
'name': 'hnsw',
|
||||||
|
'parameters': {'ef_construction': 128, 'm': 16},
|
||||||
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
'query': {
|
|
||||||
'query': {'multi_match': {'query': '', 'fields': ['name', 'summary', 'group_id']}},
|
|
||||||
'size': DEFAULT_SIZE,
|
|
||||||
'knn': {
|
|
||||||
'field': 'name_embedding',
|
|
||||||
'query_vector': [],
|
|
||||||
'k': DEFAULT_SIZE,
|
|
||||||
'num_candidates': 100,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
'index_name': 'communities',
|
'index_name': 'communities',
|
||||||
|
|
@ -78,10 +74,6 @@ aoss_indices = [
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
'query': {
|
|
||||||
'query': {'multi_match': {'query': '', 'fields': ['name', 'group_id']}},
|
|
||||||
'size': DEFAULT_SIZE,
|
|
||||||
},
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
'index_name': 'episodes',
|
'index_name': 'episodes',
|
||||||
|
|
@ -98,15 +90,6 @@ aoss_indices = [
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
'query': {
|
|
||||||
'query': {
|
|
||||||
'multi_match': {
|
|
||||||
'query': '',
|
|
||||||
'fields': ['content', 'source', 'source_description', 'group_id'],
|
|
||||||
}
|
|
||||||
},
|
|
||||||
'size': DEFAULT_SIZE,
|
|
||||||
},
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
'index_name': 'entity_edges',
|
'index_name': 'entity_edges',
|
||||||
|
|
@ -126,20 +109,16 @@ aoss_indices = [
|
||||||
'dims': 1024,
|
'dims': 1024,
|
||||||
'index': True,
|
'index': True,
|
||||||
'similarity': 'cosine',
|
'similarity': 'cosine',
|
||||||
|
'method': {
|
||||||
|
'engine': 'faiss',
|
||||||
|
'space_type': 'cosinesimil',
|
||||||
|
'name': 'hnsw',
|
||||||
|
'parameters': {'ef_construction': 128, 'm': 16},
|
||||||
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
'query': {
|
|
||||||
'query': {'multi_match': {'query': '', 'fields': ['name', 'fact', 'group_id']}},
|
|
||||||
'size': DEFAULT_SIZE,
|
|
||||||
'knn': {
|
|
||||||
'field': 'fact_embedding',
|
|
||||||
'query_vector': [], # supply vector at runtime
|
|
||||||
'k': DEFAULT_SIZE,
|
|
||||||
'num_candidates': 100,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
},
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -234,3 +234,34 @@ def edge_search_filter_query_constructor(
|
||||||
filter_queries.append(expired_at_filter)
|
filter_queries.append(expired_at_filter)
|
||||||
|
|
||||||
return filter_queries, filter_params
|
return filter_queries, filter_params
|
||||||
|
|
||||||
|
|
||||||
|
def build_aoss_node_filters(group_ids: list[str], search_filters: SearchFilters) -> list[dict]:
    """Build the ``bool.filter`` clauses for an entity-node search.

    Args:
        group_ids: Group ids a hit may belong to; a document matches when its
            ``group_id`` equals any one of them.
        search_filters: Filter settings; only ``node_labels`` is read here.

    Returns:
        A list of query clauses: always a ``terms`` clause on ``group_id``,
        followed by a ``terms`` clause on ``node_labels`` when labels are set.
    """
    # ``term`` accepts a single value only; matching a list of ids needs ``terms``.
    filters: list[dict] = [{'terms': {'group_id': group_ids}}]

    if search_filters.node_labels:
        filters.append({'terms': {'node_labels': search_filters.node_labels}})

    return filters
|
||||||
|
|
||||||
|
|
||||||
|
def build_aoss_edge_filters(group_ids: list[str], search_filters: SearchFilters) -> list[dict]:
    """Build the ``bool.filter`` clauses for an entity-edge search.

    Args:
        group_ids: Group ids a hit may belong to; a document matches when its
            ``group_id`` equals any one of them.
        search_filters: Filter settings. ``edge_types`` and the four date
            attributes ``valid_at``, ``invalid_at``, ``created_at`` and
            ``expired_at`` are read here.

    Returns:
        A list of query clauses: a ``terms`` clause on ``group_id``, an optional
        ``terms`` clause on ``edge_types``, and one ``bool.should`` clause per
        date attribute that has filters set.
    """
    # ``term`` accepts a single value only; matching a list of ids needs ``terms``.
    filters: list[dict] = [{'terms': {'group_id': group_ids}}]

    if search_filters.edge_types:
        filters.append({'terms': {'edge_types': search_filters.edge_types}})

    for field in ('valid_at', 'invalid_at', 'created_at', 'expired_at'):
        ranges = getattr(search_filters, field)
        if not ranges:
            continue

        # Each date attribute is an OR of AND-groups: every inner group becomes
        # a ``bool.filter`` (all ranges must hold) and the groups are combined
        # with ``should`` so that at least one group has to match.
        # NOTE(review): assumes each date filter exposes ``op`` (a range
        # operator key such as 'gte') and ``value`` -- confirm against the
        # DateFilter definition.
        should_clauses = [
            {'bool': {'filter': [{'range': {field: {df.op: df.value}}} for df in and_group]}}
            for and_group in ranges
        ]
        filters.append({'bool': {'should': should_clauses, 'minimum_should_match': 1}})

    return filters
|
||||||
|
|
|
||||||
|
|
@ -51,6 +51,8 @@ from graphiti_core.nodes import (
|
||||||
)
|
)
|
||||||
from graphiti_core.search.search_filters import (
|
from graphiti_core.search.search_filters import (
|
||||||
SearchFilters,
|
SearchFilters,
|
||||||
|
build_aoss_edge_filters,
|
||||||
|
build_aoss_node_filters,
|
||||||
edge_search_filter_query_constructor,
|
edge_search_filter_query_constructor,
|
||||||
node_search_filter_query_constructor,
|
node_search_filter_query_constructor,
|
||||||
)
|
)
|
||||||
|
|
@ -200,7 +202,6 @@ async def edge_fulltext_search(
|
||||||
if driver.provider == GraphProvider.NEPTUNE:
|
if driver.provider == GraphProvider.NEPTUNE:
|
||||||
res = driver.run_aoss_query('edge_name_and_fact', query) # pyright: ignore reportAttributeAccessIssue
|
res = driver.run_aoss_query('edge_name_and_fact', query) # pyright: ignore reportAttributeAccessIssue
|
||||||
if res['hits']['total']['value'] > 0:
|
if res['hits']['total']['value'] > 0:
|
||||||
# Calculate Cosine similarity then return the edge ids
|
|
||||||
input_ids = []
|
input_ids = []
|
||||||
for r in res['hits']['hits']:
|
for r in res['hits']['hits']:
|
||||||
input_ids.append({'id': r['_source']['uuid'], 'score': r['_score']})
|
input_ids.append({'id': r['_source']['uuid'], 'score': r['_score']})
|
||||||
|
|
@ -245,38 +246,26 @@ async def edge_fulltext_search(
|
||||||
else:
|
else:
|
||||||
return []
|
return []
|
||||||
elif driver.aoss_client:
|
elif driver.aoss_client:
|
||||||
res = driver.run_aoss_query('edge_name_and_fact', query) # pyright: ignore reportAttributeAccessIssue
|
filters = build_aoss_edge_filters(group_ids, search_filter)
|
||||||
|
res = driver.aoss_client.search(
|
||||||
|
index='entity_edges',
|
||||||
|
routing=group_ids,
|
||||||
|
_source=['uuid'],
|
||||||
|
query={
|
||||||
|
'bool': {
|
||||||
|
'filter': filters,
|
||||||
|
'must': [{'match': {'fact': {'query': query, 'operator': 'or'}}}],
|
||||||
|
}
|
||||||
|
},
|
||||||
|
)
|
||||||
if res['hits']['total']['value'] > 0:
|
if res['hits']['total']['value'] > 0:
|
||||||
# Calculate Cosine similarity then return the edge ids
|
input_uuids = {}
|
||||||
input_uuids = []
|
|
||||||
for r in res['hits']['hits']:
|
for r in res['hits']['hits']:
|
||||||
input_uuids.append({'uuid': r['_source']['uuid'], 'score': r['_score']})
|
input_uuids[r['_source']['uuid']] = r['_score']
|
||||||
|
|
||||||
# Match the edge ids and return the values
|
# Get edges
|
||||||
query = (
|
entity_edges = await EntityEdge.get_by_uuids(driver, list(input_uuids.keys()))
|
||||||
"""
|
return entity_edges.sort(key=lambda e: input_uuids.get(e, 0), reverse=True)
|
||||||
UNWIND $uuids as uuid
|
|
||||||
MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
|
|
||||||
WHERE e.group_id IN $group_ids
|
|
||||||
AND e.uuid=uuid
|
|
||||||
"""
|
|
||||||
+ filter_query
|
|
||||||
+ """
|
|
||||||
AND e.uuid=uuid
|
|
||||||
WITH e, id.score as score, startNode(e) AS n, endNode(e) AS m"""
|
|
||||||
+ get_entity_edge_return_query(driver.provider)
|
|
||||||
+ """ORDER BY score DESC LIMIT $limit
|
|
||||||
"""
|
|
||||||
)
|
|
||||||
|
|
||||||
records, _, _ = await driver.execute_query(
|
|
||||||
query,
|
|
||||||
query=fuzzy_query,
|
|
||||||
uuids=input_uuids,
|
|
||||||
limit=limit,
|
|
||||||
routing_='r',
|
|
||||||
**filter_params,
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
return []
|
return []
|
||||||
else:
|
else:
|
||||||
|
|
@ -412,6 +401,30 @@ async def edge_similarity_search(
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
return []
|
return []
|
||||||
|
elif driver.aoss_client:
|
||||||
|
filters = build_aoss_edge_filters(group_ids, search_filter)
|
||||||
|
res = driver.aoss_client.search(
|
||||||
|
index='entity_edges',
|
||||||
|
routing=group_ids,
|
||||||
|
_source=['uuid'],
|
||||||
|
knn={
|
||||||
|
'field': 'fact_embedding',
|
||||||
|
'query_vector': search_vector,
|
||||||
|
'k': limit,
|
||||||
|
'num_candidates': 1000,
|
||||||
|
},
|
||||||
|
query={'bool': {'filter': filters}},
|
||||||
|
)
|
||||||
|
|
||||||
|
if res['hits']['total']['value'] > 0:
|
||||||
|
input_uuids = {}
|
||||||
|
for r in res['hits']['hits']:
|
||||||
|
input_uuids[r['_source']['uuid']] = r['_score']
|
||||||
|
|
||||||
|
# Get edges
|
||||||
|
entity_edges = await EntityEdge.get_by_uuids(driver, list(input_uuids.keys()))
|
||||||
|
return entity_edges.sort(key=lambda e: input_uuids.get(e, 0), reverse=True)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
query = (
|
query = (
|
||||||
match_query
|
match_query
|
||||||
|
|
@ -598,7 +611,6 @@ async def node_fulltext_search(
|
||||||
if driver.provider == GraphProvider.NEPTUNE:
|
if driver.provider == GraphProvider.NEPTUNE:
|
||||||
res = driver.run_aoss_query('node_name_and_summary', query, limit=limit) # pyright: ignore reportAttributeAccessIssue
|
res = driver.run_aoss_query('node_name_and_summary', query, limit=limit) # pyright: ignore reportAttributeAccessIssue
|
||||||
if res['hits']['total']['value'] > 0:
|
if res['hits']['total']['value'] > 0:
|
||||||
# Calculate Cosine similarity then return the edge ids
|
|
||||||
input_ids = []
|
input_ids = []
|
||||||
for r in res['hits']['hits']:
|
for r in res['hits']['hits']:
|
||||||
input_ids.append({'id': r['_source']['uuid'], 'score': r['_score']})
|
input_ids.append({'id': r['_source']['uuid'], 'score': r['_score']})
|
||||||
|
|
@ -628,35 +640,36 @@ async def node_fulltext_search(
|
||||||
else:
|
else:
|
||||||
return []
|
return []
|
||||||
elif driver.aoss_client:
|
elif driver.aoss_client:
|
||||||
res = driver.run_aoss_query('node_name_and_summary', query, limit=limit) # pyright: ignore reportAttributeAccessIssue
|
filters = build_aoss_node_filters(group_ids, search_filter)
|
||||||
if res['hits']['total']['value'] > 0:
|
res = driver.aoss_client.search(
|
||||||
# Calculate Cosine similarity then return the edge ids
|
'entities',
|
||||||
input_uuids = []
|
routing=group_ids,
|
||||||
for r in res['hits']['hits']:
|
_source=['uuid'],
|
||||||
input_uuids.append({'uuid': r['_source']['uuid'], 'score': r['_score']})
|
query={
|
||||||
|
'bool': {
|
||||||
# Match the edge ides and return the values
|
'filter': filters,
|
||||||
query = (
|
'must': [
|
||||||
"""
|
{
|
||||||
UNWIND $uuids as i
|
'multi_match': {
|
||||||
MATCH (n:Entity)
|
'query': query,
|
||||||
WHERE n.uuid=i.uuid
|
'field': ['name', 'summary'],
|
||||||
RETURN
|
'operator': 'or',
|
||||||
"""
|
}
|
||||||
+ get_entity_node_return_query(driver.provider)
|
}
|
||||||
+ """
|
],
|
||||||
ORDER BY i.score DESC
|
}
|
||||||
LIMIT $limit
|
},
|
||||||
"""
|
|
||||||
)
|
|
||||||
records, _, _ = await driver.execute_query(
|
|
||||||
query,
|
|
||||||
uuids=input_uuids,
|
|
||||||
query=fuzzy_query,
|
|
||||||
limit=limit,
|
limit=limit,
|
||||||
routing_='r',
|
|
||||||
**filter_params,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if res['hits']['total']['value'] > 0:
|
||||||
|
input_uuids = {}
|
||||||
|
for r in res['hits']['hits']:
|
||||||
|
input_uuids[r['_source']['uuid']] = r['_score']
|
||||||
|
|
||||||
|
# Get nodes
|
||||||
|
entities = await EntityNode.get_by_uuids(driver, list(input_uuids.keys()))
|
||||||
|
return entities.sort(key=lambda e: input_uuids.get(e, 0), reverse=True)
|
||||||
else:
|
else:
|
||||||
return []
|
return []
|
||||||
else:
|
else:
|
||||||
|
|
@ -767,6 +780,29 @@ async def node_similarity_search(
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
return []
|
return []
|
||||||
|
elif driver.aoss_client:
|
||||||
|
filters = build_aoss_node_filters(group_ids, search_filter)
|
||||||
|
res = driver.aoss_client.search(
|
||||||
|
index='entities',
|
||||||
|
routing=group_ids,
|
||||||
|
_source=['uuid'],
|
||||||
|
knn={
|
||||||
|
'field': 'fact_embedding',
|
||||||
|
'query_vector': search_vector,
|
||||||
|
'k': limit,
|
||||||
|
'num_candidates': 1000,
|
||||||
|
},
|
||||||
|
query={'bool': {'filter': filters}},
|
||||||
|
)
|
||||||
|
|
||||||
|
if res['hits']['total']['value'] > 0:
|
||||||
|
input_uuids = {}
|
||||||
|
for r in res['hits']['hits']:
|
||||||
|
input_uuids[r['_source']['uuid']] = r['_score']
|
||||||
|
|
||||||
|
# Get edges
|
||||||
|
entity_edges = await EntityEdge.get_by_uuids(driver, list(input_uuids.keys()))
|
||||||
|
return entity_edges.sort(key=lambda e: input_uuids.get(e, 0), reverse=True)
|
||||||
else:
|
else:
|
||||||
query = (
|
query = (
|
||||||
"""
|
"""
|
||||||
|
|
@ -910,7 +946,6 @@ async def episode_fulltext_search(
|
||||||
if driver.provider == GraphProvider.NEPTUNE:
|
if driver.provider == GraphProvider.NEPTUNE:
|
||||||
res = driver.run_aoss_query('episode_content', query, limit=limit) # pyright: ignore reportAttributeAccessIssue
|
res = driver.run_aoss_query('episode_content', query, limit=limit) # pyright: ignore reportAttributeAccessIssue
|
||||||
if res['hits']['total']['value'] > 0:
|
if res['hits']['total']['value'] > 0:
|
||||||
# Calculate Cosine similarity then return the edge ids
|
|
||||||
input_ids = []
|
input_ids = []
|
||||||
for r in res['hits']['hits']:
|
for r in res['hits']['hits']:
|
||||||
input_ids.append({'id': r['_source']['uuid'], 'score': r['_score']})
|
input_ids.append({'id': r['_source']['uuid'], 'score': r['_score']})
|
||||||
|
|
@ -944,7 +979,27 @@ async def episode_fulltext_search(
|
||||||
else:
|
else:
|
||||||
return []
|
return []
|
||||||
elif driver.aoss_client:
|
elif driver.aoss_client:
|
||||||
res = driver.run_aoss_query('episode_content', query, limit=limit) # pyright: ignore reportAttributeAccessIssue
|
res = driver.aoss_client.search(
|
||||||
|
'episodes',
|
||||||
|
routing=group_ids,
|
||||||
|
_source=['uuid'],
|
||||||
|
query={
|
||||||
|
'bool': {
|
||||||
|
'filter': [{'term': {'group_id': group_ids}}],
|
||||||
|
'must': [
|
||||||
|
{
|
||||||
|
'multi_match': {
|
||||||
|
'query': query,
|
||||||
|
'field': ['name', 'content'],
|
||||||
|
'operator': 'or',
|
||||||
|
}
|
||||||
|
}
|
||||||
|
],
|
||||||
|
}
|
||||||
|
},
|
||||||
|
limit=limit,
|
||||||
|
)
|
||||||
|
|
||||||
if res['hits']['total']['value'] > 0:
|
if res['hits']['total']['value'] > 0:
|
||||||
# Calculate Cosine similarity then return the edge ids
|
# Calculate Cosine similarity then return the edge ids
|
||||||
input_uuids = []
|
input_uuids = []
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue