add search and graph operations interfaces (#984)

* add search and graph operations interfaces

* update

* update

* update

* update

* update

* update
This commit is contained in:
Preston Rasmussen 2025-10-07 13:34:37 -04:00 committed by GitHub
parent 73015e980e
commit 604e3199a3
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
12 changed files with 430 additions and 433 deletions

View file

@ -24,6 +24,9 @@ from typing import Any
from dotenv import load_dotenv
from graphiti_core.driver.graph_operations.graph_operations import GraphOperationsInterface
from graphiti_core.driver.search_interface.search_interface import SearchInterface
logger = logging.getLogger(__name__)
DEFAULT_SIZE = 10
@ -73,7 +76,8 @@ class GraphDriver(ABC):
'' # Neo4j (default) syntax does not require a prefix for fulltext queries
)
_database: str
aoss_client: Any # type: ignore
search_interface: SearchInterface | None = None
graph_operations_interface: GraphOperationsInterface | None = None
@abstractmethod
def execute_query(self, cypher_query_: str, **kwargs: Any) -> Coroutine:
@ -109,9 +113,3 @@ class GraphDriver(ABC):
Only implemented by providers that need custom fulltext query building.
"""
raise NotImplementedError(f'build_fulltext_query not implemented for {self.provider}')
async def save_to_aoss(self, name: str, data: list[dict]) -> int:
return 0
async def clear_aoss_indices(self):
return 1

View file

@ -0,0 +1,195 @@
"""
Copyright 2024, Zep Software, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
from typing import Any
from pydantic import BaseModel
class GraphOperationsInterface(BaseModel):
    """
    Pluggable hooks for persisting, deleting and hydrating graph objects.

    A driver may expose an instance of a subclass so that node and edge
    save/delete/load operations are routed to it instead of the built-in
    queries. Every method on this base class raises ``NotImplementedError``;
    subclasses override the operations they support.

    Conventions shared by the methods below:
      * ``driver`` is the graph driver handed in by the caller.
      * ``_cls`` is a class placeholder kept for signature parity with the
        classmethod-style entry points; implementations should tolerate
        ``None`` here.
      * ``transaction`` is the caller's open transaction object.
      * ``batch_size`` is the number of items to process per batch.
    """

    # ------------------------------------------------------------------
    # Entity nodes: persistence and removal
    # ------------------------------------------------------------------

    async def node_save(self, node: Any, driver: Any) -> None:
        """Persist (create or update) a single node."""
        raise NotImplementedError

    async def node_delete(self, node: Any, driver: Any) -> None:
        """Delete a single node."""
        raise NotImplementedError

    async def node_save_bulk(
        self,
        _cls: Any,  # kept for signature parity; may be None
        driver: Any,
        transaction: Any,
        nodes: list[Any],
        batch_size: int = 100,
    ) -> None:
        """Persist (create or update) many nodes in batches."""
        raise NotImplementedError

    async def node_delete_by_group_id(
        self,
        _cls: Any,
        driver: Any,
        group_id: str,
        batch_size: int = 100,
    ) -> None:
        """Delete the nodes that belong to ``group_id``, in batches."""
        raise NotImplementedError

    async def node_delete_by_uuids(
        self,
        _cls: Any,
        driver: Any,
        uuids: list[str],
        group_id: str | None = None,
        batch_size: int = 100,
    ) -> None:
        """Delete the nodes identified by ``uuids``, in batches.

        ``group_id`` is optional; how it narrows the deletion is up to the
        implementation.
        """
        raise NotImplementedError

    # ------------------------------------------------------------------
    # Entity nodes: embedding hydration
    # ------------------------------------------------------------------

    async def node_load_embeddings(self, node: Any, driver: Any) -> None:
        """
        Load embedding vectors for one node and store them on the instance
        (e.g., set node.embedding or similar).
        """
        raise NotImplementedError

    async def node_load_embeddings_bulk(
        self,
        _cls: Any,
        driver: Any,
        transaction: Any,
        nodes: list[Any],
        batch_size: int = 100,
    ) -> None:
        """
        Load embedding vectors for many nodes in batches, mutating the
        node instances that were passed in.
        """
        raise NotImplementedError

    # ------------------------------------------------------------------
    # Episodic nodes and episodic edges: persistence and removal
    # ------------------------------------------------------------------

    async def episodic_node_save(self, node: Any, driver: Any) -> None:
        """Persist (create or update) a single episodic node."""
        raise NotImplementedError

    async def episodic_node_delete(self, node: Any, driver: Any) -> None:
        """Delete a single episodic node."""
        raise NotImplementedError

    async def episodic_node_save_bulk(
        self,
        _cls: Any,
        driver: Any,
        transaction: Any,
        nodes: list[Any],
        batch_size: int = 100,
    ) -> None:
        """Persist (create or update) many episodic nodes in batches."""
        raise NotImplementedError

    async def episodic_edge_save_bulk(
        self,
        _cls: Any,
        driver: Any,
        transaction: Any,
        episodic_edges: list[Any],
        batch_size: int = 100,
    ) -> None:
        """Persist (create or update) many episodic edges in batches."""
        raise NotImplementedError

    async def episodic_node_delete_by_group_id(
        self,
        _cls: Any,
        driver: Any,
        group_id: str,
        batch_size: int = 100,
    ) -> None:
        """Delete the episodic nodes that belong to ``group_id``, in batches."""
        raise NotImplementedError

    async def episodic_node_delete_by_uuids(
        self,
        _cls: Any,
        driver: Any,
        uuids: list[str],
        group_id: str | None = None,
        batch_size: int = 100,
    ) -> None:
        """Delete the episodic nodes identified by ``uuids``, in batches.

        ``group_id`` is optional; how it narrows the deletion is up to the
        implementation.
        """
        raise NotImplementedError

    # ------------------------------------------------------------------
    # Entity edges: persistence and removal
    # ------------------------------------------------------------------

    async def edge_save(self, edge: Any, driver: Any) -> None:
        """Persist (create or update) a single edge."""
        raise NotImplementedError

    async def edge_delete(self, edge: Any, driver: Any) -> None:
        """Delete a single edge."""
        raise NotImplementedError

    async def edge_save_bulk(
        self,
        _cls: Any,
        driver: Any,
        transaction: Any,
        edges: list[Any],
        batch_size: int = 100,
    ) -> None:
        """Persist (create or update) many edges in batches."""
        raise NotImplementedError

    async def edge_delete_by_uuids(
        self,
        _cls: Any,
        driver: Any,
        uuids: list[str],
        group_id: str | None = None,
    ) -> None:
        """Delete the edges identified by ``uuids``.

        Unlike the node variants this hook takes no ``batch_size``.
        ``group_id`` is optional; how it narrows the deletion is up to the
        implementation.
        """
        raise NotImplementedError

    # ------------------------------------------------------------------
    # Entity edges: embedding hydration
    # ------------------------------------------------------------------

    async def edge_load_embeddings(self, edge: Any, driver: Any) -> None:
        """
        Load embedding vectors for one edge and store them on the instance
        (e.g., set edge.embedding or similar).
        """
        raise NotImplementedError

    async def edge_load_embeddings_bulk(
        self,
        _cls: Any,
        driver: Any,
        transaction: Any,
        edges: list[Any],
        batch_size: int = 100,
    ) -> None:
        """
        Load embedding vectors for many edges in batches, mutating the
        edge instances that were passed in.
        """
        raise NotImplementedError

View file

@ -0,0 +1,89 @@
"""
Copyright 2024, Zep Software, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
from typing import Any
from pydantic import BaseModel
class SearchInterface(BaseModel):
    """
    Interface for implementing custom search logic.

    A driver may expose an instance of a subclass so that the search
    helpers route fulltext and similarity queries to it instead of
    building provider-specific queries themselves. Every method on this
    base class raises ``NotImplementedError``; subclasses override the
    searches they support.

    Parameters shared by the async search methods:
      * ``driver``: the graph driver handed in by the caller.
      * ``search_filter``: the caller's filter object, passed through as-is.
      * ``group_ids``: optional list of group ids to restrict the search to.
      * ``limit``: maximum number of results to return (default 100).
      * ``min_score``: similarity threshold for vector searches (default 0.7).
    """

    # Pydantic v2 configuration. Replaces the deprecated nested ``class Config``
    # and keeps the same setting, so subclasses may still declare fields whose
    # types Pydantic has no validator for (e.g. client objects).
    model_config = {'arbitrary_types_allowed': True}

    async def edge_fulltext_search(
        self,
        driver: Any,
        query: str,
        search_filter: Any,
        group_ids: list[str] | None = None,
        limit: int = 100,
    ) -> list[Any]:
        """Run a fulltext search for edges matching ``query``."""
        raise NotImplementedError

    async def edge_similarity_search(
        self,
        driver: Any,
        search_vector: list[float],
        source_node_uuid: str | None,
        target_node_uuid: str | None,
        search_filter: Any,
        group_ids: list[str] | None = None,
        limit: int = 100,
        min_score: float = 0.7,
    ) -> list[Any]:
        """Run a vector-similarity search for edges against ``search_vector``.

        ``source_node_uuid`` and ``target_node_uuid`` are optional; when
        given, the caller intends the search to be constrained by them.
        """
        raise NotImplementedError

    async def node_fulltext_search(
        self,
        driver: Any,
        query: str,
        search_filter: Any,
        group_ids: list[str] | None = None,
        limit: int = 100,
    ) -> list[Any]:
        """Run a fulltext search for nodes matching ``query``."""
        raise NotImplementedError

    async def node_similarity_search(
        self,
        driver: Any,
        search_vector: list[float],
        search_filter: Any,
        group_ids: list[str] | None = None,
        limit: int = 100,
        min_score: float = 0.7,
    ) -> list[Any]:
        """Run a vector-similarity search for nodes against ``search_vector``."""
        raise NotImplementedError

    async def episode_fulltext_search(
        self,
        driver: Any,
        query: str,
        search_filter: Any,  # kept for parity with the other searches even if unused
        group_ids: list[str] | None = None,
        limit: int = 100,
    ) -> list[Any]:
        """Run a fulltext search for episodes matching ``query``."""
        raise NotImplementedError

    # ---------- SEARCH FILTERS (sync) ----------

    def build_node_search_filters(self, search_filters: Any) -> Any:
        """Translate ``search_filters`` into the backend's node filter form.

        The return shape is implementation-defined.
        """
        raise NotImplementedError

    def build_edge_search_filters(self, search_filters: Any) -> Any:
        """Translate ``search_filters`` into the backend's edge filter form.

        The return shape is implementation-defined.
        """
        raise NotImplementedError

View file

@ -25,7 +25,7 @@ from uuid import uuid4
from pydantic import BaseModel, Field
from typing_extensions import LiteralString
from graphiti_core.driver.driver import ENTITY_EDGE_INDEX_NAME, GraphDriver, GraphProvider
from graphiti_core.driver.driver import GraphDriver, GraphProvider
from graphiti_core.embedder import EmbedderClient
from graphiti_core.errors import EdgeNotFoundError, GroupsEdgesNotFoundError
from graphiti_core.helpers import parse_db_date
@ -53,6 +53,9 @@ class Edge(BaseModel, ABC):
async def save(self, driver: GraphDriver): ...
async def delete(self, driver: GraphDriver):
if driver.graph_operations_interface:
return await driver.graph_operations_interface.edge_delete(self, driver)
if driver.provider == GraphProvider.KUZU:
await driver.execute_query(
"""
@ -77,17 +80,13 @@ class Edge(BaseModel, ABC):
uuid=self.uuid,
)
if driver.aoss_client:
await driver.aoss_client.delete(
index=ENTITY_EDGE_INDEX_NAME,
id=self.uuid,
params={'routing': self.group_id},
)
logger.debug(f'Deleted Edge: {self.uuid}')
@classmethod
async def delete_by_uuids(cls, driver: GraphDriver, uuids: list[str]):
if driver.graph_operations_interface:
return await driver.graph_operations_interface.edge_delete_by_uuids(cls, driver, uuids)
if driver.provider == GraphProvider.KUZU:
await driver.execute_query(
"""
@ -115,12 +114,6 @@ class Edge(BaseModel, ABC):
uuids=uuids,
)
if driver.aoss_client:
await driver.aoss_client.delete_by_query(
index=ENTITY_EDGE_INDEX_NAME,
body={'query': {'terms': {'uuid': uuids}}},
)
logger.debug(f'Deleted Edges: {uuids}')
def __hash__(self):
@ -258,6 +251,9 @@ class EntityEdge(Edge):
return self.fact_embedding
async def load_fact_embedding(self, driver: GraphDriver):
if driver.graph_operations_interface:
return await driver.graph_operations_interface.edge_load_embeddings(self, driver)
query = """
MATCH (n:Entity)-[e:RELATES_TO {uuid: $uuid}]->(m:Entity)
RETURN e.fact_embedding AS fact_embedding
@ -268,21 +264,6 @@ class EntityEdge(Edge):
MATCH (n:Entity)-[e:RELATES_TO {uuid: $uuid}]->(m:Entity)
RETURN [x IN split(e.fact_embedding, ",") | toFloat(x)] as fact_embedding
"""
elif driver.aoss_client:
resp = await driver.aoss_client.search(
body={
'query': {'multi_match': {'query': self.uuid, 'fields': ['uuid']}},
'size': 1,
},
index=ENTITY_EDGE_INDEX_NAME,
params={'routing': self.group_id},
)
if resp['hits']['hits']:
self.fact_embedding = resp['hits']['hits'][0]['_source']['fact_embedding']
return
else:
raise EdgeNotFoundError(self.uuid)
if driver.provider == GraphProvider.KUZU:
query = """
@ -320,15 +301,11 @@ class EntityEdge(Edge):
if driver.provider == GraphProvider.KUZU:
edge_data['attributes'] = json.dumps(self.attributes)
result = await driver.execute_query(
get_entity_edge_save_query(driver.provider, has_aoss=bool(driver.aoss_client)),
get_entity_edge_save_query(driver.provider),
**edge_data,
)
else:
edge_data.update(self.attributes or {})
if driver.aoss_client:
await driver.save_to_aoss(ENTITY_EDGE_INDEX_NAME, [edge_data]) # pyright: ignore reportAttributeAccessIssue
result = await driver.execute_query(
get_entity_edge_save_query(driver.provider),
edge_data=edge_data,

View file

@ -27,10 +27,6 @@ from pydantic import BaseModel, Field
from typing_extensions import LiteralString
from graphiti_core.driver.driver import (
COMMUNITY_INDEX_NAME,
ENTITY_EDGE_INDEX_NAME,
ENTITY_INDEX_NAME,
EPISODE_INDEX_NAME,
GraphDriver,
GraphProvider,
)
@ -99,6 +95,9 @@ class Node(BaseModel, ABC):
async def save(self, driver: GraphDriver): ...
async def delete(self, driver: GraphDriver):
if driver.graph_operations_interface:
return await driver.graph_operations_interface.node_delete(self, driver)
match driver.provider:
case GraphProvider.NEO4J:
records, _, _ = await driver.execute_query(
@ -113,27 +112,6 @@ class Node(BaseModel, ABC):
uuid=self.uuid,
)
edge_uuids: list[str] = records[0].get('edge_uuids', []) if records else []
if driver.aoss_client:
# Delete the node from OpenSearch indices
for index in (EPISODE_INDEX_NAME, ENTITY_INDEX_NAME, COMMUNITY_INDEX_NAME):
await driver.aoss_client.delete(
index=index,
id=self.uuid,
params={'routing': self.group_id},
)
# Bulk delete the detached edges
if edge_uuids:
actions = []
for eid in edge_uuids:
actions.append(
{'delete': {'_index': ENTITY_EDGE_INDEX_NAME, '_id': eid}}
)
await driver.aoss_client.bulk(body=actions)
case GraphProvider.KUZU:
for label in ['Episodic', 'Community']:
await driver.execute_query(
@ -181,6 +159,11 @@ class Node(BaseModel, ABC):
@classmethod
async def delete_by_group_id(cls, driver: GraphDriver, group_id: str, batch_size: int = 100):
if driver.graph_operations_interface:
return await driver.graph_operations_interface.node_delete_by_group_id(
cls, driver, group_id, batch_size
)
match driver.provider:
case GraphProvider.NEO4J:
async with driver.session() as session:
@ -196,31 +179,6 @@ class Node(BaseModel, ABC):
batch_size=batch_size,
)
if driver.aoss_client:
await driver.aoss_client.delete_by_query(
index=EPISODE_INDEX_NAME,
body={'query': {'term': {'group_id': group_id}}},
params={'routing': group_id},
)
await driver.aoss_client.delete_by_query(
index=ENTITY_INDEX_NAME,
body={'query': {'term': {'group_id': group_id}}},
params={'routing': group_id},
)
await driver.aoss_client.delete_by_query(
index=COMMUNITY_INDEX_NAME,
body={'query': {'term': {'group_id': group_id}}},
params={'routing': group_id},
)
await driver.aoss_client.delete_by_query(
index=ENTITY_EDGE_INDEX_NAME,
body={'query': {'term': {'group_id': group_id}}},
params={'routing': group_id},
)
case GraphProvider.KUZU:
for label in ['Episodic', 'Community']:
await driver.execute_query(
@ -258,6 +216,11 @@ class Node(BaseModel, ABC):
@classmethod
async def delete_by_uuids(cls, driver: GraphDriver, uuids: list[str], batch_size: int = 100):
if driver.graph_operations_interface:
return await driver.graph_operations_interface.node_delete_by_uuids(
cls, driver, uuids, group_id=None, batch_size=batch_size
)
match driver.provider:
case GraphProvider.FALKORDB:
for label in ['Entity', 'Episodic', 'Community']:
@ -300,7 +263,7 @@ class Node(BaseModel, ABC):
case _: # Neo4J, Neptune
async with driver.session() as session:
# Collect all edge UUIDs before deleting nodes
result = await session.run(
await session.run(
"""
MATCH (n:Entity|Episodic|Community)
WHERE n.uuid IN $uuids
@ -310,11 +273,6 @@ class Node(BaseModel, ABC):
uuids=uuids,
)
record = await result.single()
edge_uuids: list[str] = (
record['edge_uuids'] if record and record['edge_uuids'] else []
)
# Now delete the nodes in batches
await session.run(
"""
@ -329,20 +287,6 @@ class Node(BaseModel, ABC):
batch_size=batch_size,
)
if driver.aoss_client:
for index in (EPISODE_INDEX_NAME, ENTITY_INDEX_NAME, COMMUNITY_INDEX_NAME):
await driver.aoss_client.delete_by_query(
index=index,
body={'query': {'terms': {'uuid': uuids}}},
)
if edge_uuids:
actions = [
{'delete': {'_index': ENTITY_EDGE_INDEX_NAME, '_id': eid}}
for eid in edge_uuids
]
await driver.aoss_client.bulk(body=actions)
@classmethod
async def get_by_uuid(cls, driver: GraphDriver, uuid: str): ...
@ -363,6 +307,9 @@ class EpisodicNode(Node):
)
async def save(self, driver: GraphDriver):
if driver.graph_operations_interface:
return await driver.graph_operations_interface.episodic_node_save(self, driver)
episode_args = {
'uuid': self.uuid,
'name': self.name,
@ -375,12 +322,6 @@ class EpisodicNode(Node):
'source': self.source.value,
}
if driver.aoss_client:
await driver.save_to_aoss( # pyright: ignore reportAttributeAccessIssue
'episodes',
[episode_args],
)
result = await driver.execute_query(
get_episode_node_save_query(driver.provider), **episode_args
)
@ -510,26 +451,14 @@ class EntityNode(Node):
return self.name_embedding
async def load_name_embedding(self, driver: GraphDriver):
if driver.graph_operations_interface:
return await driver.graph_operations_interface.node_load_embeddings(self, driver)
if driver.provider == GraphProvider.NEPTUNE:
query: LiteralString = """
MATCH (n:Entity {uuid: $uuid})
RETURN [x IN split(n.name_embedding, ",") | toFloat(x)] as name_embedding
"""
elif driver.aoss_client:
resp = await driver.aoss_client.search(
body={
'query': {'multi_match': {'query': self.uuid, 'fields': ['uuid']}},
'size': 1,
},
index=ENTITY_INDEX_NAME,
params={'routing': self.group_id},
)
if resp['hits']['hits']:
self.name_embedding = resp['hits']['hits'][0]['_source']['name_embedding']
return
else:
raise NodeNotFoundError(self.uuid)
else:
query: LiteralString = """
@ -548,6 +477,9 @@ class EntityNode(Node):
self.name_embedding = records[0]['name_embedding']
async def save(self, driver: GraphDriver):
if driver.graph_operations_interface:
return await driver.graph_operations_interface.node_save(self, driver)
entity_data: dict[str, Any] = {
'uuid': self.uuid,
'name': self.name,
@ -568,11 +500,8 @@ class EntityNode(Node):
entity_data.update(self.attributes or {})
labels = ':'.join(self.labels + ['Entity'])
if driver.aoss_client:
await driver.save_to_aoss(ENTITY_INDEX_NAME, [entity_data]) # pyright: ignore reportAttributeAccessIssue
result = await driver.execute_query(
get_entity_node_save_query(driver.provider, labels, bool(driver.aoss_client)),
get_entity_node_save_query(driver.provider, labels),
entity_data=entity_data,
)

View file

@ -249,41 +249,3 @@ def edge_search_filter_query_constructor(
filter_queries.append(expired_at_filter)
return filter_queries, filter_params
def build_aoss_node_filters(group_ids: list[str], search_filters: SearchFilters) -> list[dict]:
filters = [{'terms': {'group_id': group_ids}}]
if search_filters.node_labels:
filters.append({'terms': {'node_labels': search_filters.node_labels}})
return filters
def build_aoss_edge_filters(group_ids: list[str], search_filters: SearchFilters) -> list[dict]:
filters: list[dict] = [{'terms': {'group_id': group_ids}}]
if search_filters.edge_types:
filters.append({'terms': {'edge_types': search_filters.edge_types}})
if search_filters.edge_uuids:
filters.append({'terms': {'uuid': search_filters.edge_uuids}})
for field in ['valid_at', 'invalid_at', 'created_at', 'expired_at']:
ranges = getattr(search_filters, field)
if ranges:
# OR of ANDs
should_clauses = []
for and_group in ranges:
and_filters = []
for df in and_group: # df is a DateFilter
range_query = {
'range': {
field: {cypher_to_opensearch_operator(df.comparison_operator): df.date}
}
}
and_filters.append(range_query)
should_clauses.append({'bool': {'filter': and_filters}})
filters.append({'bool': {'should': should_clauses, 'minimum_should_match': 1}})
return filters

View file

@ -24,9 +24,6 @@ from numpy._typing import NDArray
from typing_extensions import LiteralString
from graphiti_core.driver.driver import (
ENTITY_EDGE_INDEX_NAME,
ENTITY_INDEX_NAME,
EPISODE_INDEX_NAME,
GraphDriver,
GraphProvider,
)
@ -57,8 +54,6 @@ from graphiti_core.nodes import (
)
from graphiti_core.search.search_filters import (
SearchFilters,
build_aoss_edge_filters,
build_aoss_node_filters,
edge_search_filter_query_constructor,
node_search_filter_query_constructor,
)
@ -179,6 +174,11 @@ async def edge_fulltext_search(
group_ids: list[str] | None = None,
limit=RELEVANT_SCHEMA_LIMIT,
) -> list[EntityEdge]:
if driver.search_interface:
return await driver.search_interface.edge_fulltext_search(
driver, query, search_filter, group_ids, limit
)
# fulltext search over facts
fuzzy_query = fulltext_query(query, group_ids, driver)
@ -217,11 +217,11 @@ async def edge_fulltext_search(
# Match the edge ids and return the values
query = (
"""
UNWIND $ids as id
MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
WHERE e.group_id IN $group_ids
AND id(e)=id
"""
UNWIND $ids as id
MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
WHERE e.group_id IN $group_ids
AND id(e)=id
"""
+ filter_query
+ """
AND id(e)=id
@ -253,35 +253,6 @@ async def edge_fulltext_search(
)
else:
return []
elif driver.aoss_client:
route = group_ids[0] if group_ids else None
filters = build_aoss_edge_filters(group_ids or [], search_filter)
res = await driver.aoss_client.search(
index=ENTITY_EDGE_INDEX_NAME,
params={'routing': route},
body={
'size': limit,
'_source': ['uuid'],
'query': {
'bool': {
'filter': filters,
'must': [{'match': {'fact': {'query': query, 'operator': 'or'}}}],
}
},
},
)
if res['hits']['total']['value'] > 0:
input_uuids = {}
for r in res['hits']['hits']:
input_uuids[r['_source']['uuid']] = r['_score']
# Get edges
entity_edges = await EntityEdge.get_by_uuids(driver, list(input_uuids.keys()))
entity_edges.sort(key=lambda e: input_uuids.get(e.uuid, 0), reverse=True)
return entity_edges
else:
return []
else:
query = (
get_relationships_query('edge_name_and_fact', limit=limit, provider=driver.provider)
@ -321,6 +292,18 @@ async def edge_similarity_search(
limit: int = RELEVANT_SCHEMA_LIMIT,
min_score: float = DEFAULT_MIN_SCORE,
) -> list[EntityEdge]:
if driver.search_interface:
return await driver.search_interface.edge_similarity_search(
driver,
search_vector,
source_node_uuid,
target_node_uuid,
search_filter,
group_ids,
limit,
min_score,
)
match_query = """
MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
"""
@ -356,8 +339,8 @@ async def edge_similarity_search(
if driver.provider == GraphProvider.NEPTUNE:
query = (
"""
MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
"""
MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
"""
+ filter_query
+ """
RETURN DISTINCT id(e) as id, e.fact_embedding as embedding
@ -415,38 +398,6 @@ async def edge_similarity_search(
)
else:
return []
elif driver.aoss_client:
route = group_ids[0] if group_ids else None
filters = build_aoss_edge_filters(group_ids or [], search_filter)
res = await driver.aoss_client.search(
index=ENTITY_EDGE_INDEX_NAME,
params={'routing': route},
body={
'size': limit,
'_source': ['uuid'],
'query': {
'knn': {
'fact_embedding': {
'vector': list(map(float, search_vector)),
'k': limit,
'filter': {'bool': {'filter': filters}},
}
}
},
},
)
if res['hits']['total']['value'] > 0:
input_uuids = {}
for r in res['hits']['hits']:
input_uuids[r['_source']['uuid']] = r['_score']
# Get edges
entity_edges = await EntityEdge.get_by_uuids(driver, list(input_uuids.keys()))
entity_edges.sort(key=lambda e: input_uuids.get(e.uuid, 0), reverse=True)
return entity_edges
return []
else:
query = (
match_query
@ -609,6 +560,11 @@ async def node_fulltext_search(
group_ids: list[str] | None = None,
limit=RELEVANT_SCHEMA_LIMIT,
) -> list[EntityNode]:
if driver.search_interface:
return await driver.search_interface.node_fulltext_search(
driver, query, search_filter, group_ids, limit
)
# BM25 search to get top nodes
fuzzy_query = fulltext_query(query, group_ids, driver)
if fuzzy_query == '':
@ -640,11 +596,11 @@ async def node_fulltext_search(
# Match the node ids and return the values
query = (
"""
UNWIND $ids as i
MATCH (n:Entity)
WHERE n.uuid=i.id
RETURN
"""
UNWIND $ids as i
MATCH (n:Entity)
WHERE n.uuid=i.id
RETURN
"""
+ get_entity_node_return_query(driver.provider)
+ """
ORDER BY i.score DESC
@ -661,43 +617,6 @@ async def node_fulltext_search(
)
else:
return []
elif driver.aoss_client:
route = group_ids[0] if group_ids else None
filters = build_aoss_node_filters(group_ids or [], search_filter)
res = await driver.aoss_client.search(
index=ENTITY_INDEX_NAME,
params={'routing': route},
body={
'_source': ['uuid'],
'size': limit,
'query': {
'bool': {
'filter': filters,
'must': [
{
'multi_match': {
'query': query,
'fields': ['name', 'summary'],
'operator': 'or',
}
}
],
}
},
},
)
if res['hits']['total']['value'] > 0:
input_uuids = {}
for r in res['hits']['hits']:
input_uuids[r['_source']['uuid']] = r['_score']
# Get nodes
entities = await EntityNode.get_by_uuids(driver, list(input_uuids.keys()))
entities.sort(key=lambda e: input_uuids.get(e.uuid, 0), reverse=True)
return entities
else:
return []
else:
query = (
get_nodes_query(
@ -735,6 +654,11 @@ async def node_similarity_search(
limit=RELEVANT_SCHEMA_LIMIT,
min_score: float = DEFAULT_MIN_SCORE,
) -> list[EntityNode]:
if driver.search_interface:
return await driver.search_interface.node_similarity_search(
driver, search_vector, search_filter, group_ids, limit, min_score
)
filter_queries, filter_params = node_search_filter_query_constructor(
search_filter, driver.provider
)
@ -754,8 +678,8 @@ async def node_similarity_search(
if driver.provider == GraphProvider.NEPTUNE:
query = (
"""
MATCH (n:Entity)
"""
MATCH (n:Entity)
"""
+ filter_query
+ """
RETURN DISTINCT id(n) as id, n.name_embedding as embedding
@ -784,11 +708,11 @@ async def node_similarity_search(
# Match the node ids and return the values
query = (
"""
UNWIND $ids as i
MATCH (n:Entity)
WHERE id(n)=i.id
RETURN
"""
UNWIND $ids as i
MATCH (n:Entity)
WHERE id(n)=i.id
RETURN
"""
+ get_entity_node_return_query(driver.provider)
+ """
ORDER BY i.score DESC
@ -806,42 +730,11 @@ async def node_similarity_search(
)
else:
return []
elif driver.aoss_client:
route = group_ids[0] if group_ids else None
filters = build_aoss_node_filters(group_ids or [], search_filter)
res = await driver.aoss_client.search(
index=ENTITY_INDEX_NAME,
params={'routing': route},
body={
'size': limit,
'_source': ['uuid'],
'query': {
'knn': {
'name_embedding': {
'vector': list(map(float, search_vector)),
'k': limit,
'filter': {'bool': {'filter': filters}},
}
}
},
},
)
if res['hits']['total']['value'] > 0:
input_uuids = {}
for r in res['hits']['hits']:
input_uuids[r['_source']['uuid']] = r['_score']
# Get edges
entity_nodes = await EntityNode.get_by_uuids(driver, list(input_uuids.keys()))
entity_nodes.sort(key=lambda e: input_uuids.get(e.uuid, 0), reverse=True)
return entity_nodes
return []
else:
query = (
"""
MATCH (n:Entity)
"""
MATCH (n:Entity)
"""
+ filter_query
+ """
WITH n, """
@ -966,6 +859,11 @@ async def episode_fulltext_search(
group_ids: list[str] | None = None,
limit=RELEVANT_SCHEMA_LIMIT,
) -> list[EpisodicNode]:
if driver.search_interface:
return await driver.search_interface.episode_fulltext_search(
driver, query, _search_filter, group_ids, limit
)
# BM25 search to get top episodes
fuzzy_query = fulltext_query(query, group_ids, driver)
if fuzzy_query == '':
@ -1012,40 +910,6 @@ async def episode_fulltext_search(
)
else:
return []
elif driver.aoss_client:
route = group_ids[0] if group_ids else None
res = await driver.aoss_client.search(
index=EPISODE_INDEX_NAME,
params={'routing': route},
body={
'size': limit,
'_source': ['uuid'],
'bool': {
'filter': {'terms': group_ids},
'must': [
{
'multi_match': {
'query': query,
'field': ['name', 'content'],
'operator': 'or',
}
}
],
},
},
)
if res['hits']['total']['value'] > 0:
input_uuids = {}
for r in res['hits']['hits']:
input_uuids[r['_source']['uuid']] = r['_score']
# Get nodes
episodes = await EpisodicNode.get_by_uuids(driver, list(input_uuids.keys()))
episodes.sort(key=lambda e: input_uuids.get(e.uuid, 0), reverse=True)
return episodes
else:
return []
else:
query = (
get_nodes_query('episode_content', '$query', limit=limit, provider=driver.provider)
@ -1173,8 +1037,8 @@ async def community_similarity_search(
if driver.provider == GraphProvider.NEPTUNE:
query = (
"""
MATCH (n:Community)
"""
MATCH (n:Community)
"""
+ group_filter_query
+ """
RETURN DISTINCT id(n) as id, n.name_embedding as embedding
@ -1233,8 +1097,8 @@ async def community_similarity_search(
query = (
"""
MATCH (c:Community)
"""
MATCH (c:Community)
"""
+ group_filter_query
+ """
WITH c,
@ -1376,9 +1240,9 @@ async def get_relevant_nodes(
# FIXME: Kuzu currently does not support using variables such as `node.fulltext_query` as an input to FTS, which means `get_relevant_nodes()` won't work with Kuzu as the graph driver.
query = (
"""
UNWIND $nodes AS node
MATCH (n:Entity {group_id: $group_id})
"""
UNWIND $nodes AS node
MATCH (n:Entity {group_id: $group_id})
"""
+ filter_query
+ """
WITH node, n, """
@ -1423,9 +1287,9 @@ async def get_relevant_nodes(
else:
query = (
"""
UNWIND $nodes AS node
MATCH (n:Entity {group_id: $group_id})
"""
UNWIND $nodes AS node
MATCH (n:Entity {group_id: $group_id})
"""
+ filter_query
+ """
WITH node, n, """
@ -1514,9 +1378,9 @@ async def get_relevant_edges(
if driver.provider == GraphProvider.NEPTUNE:
query = (
"""
UNWIND $edges AS edge
MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid})
"""
UNWIND $edges AS edge
MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid})
"""
+ filter_query
+ """
WITH e, edge
@ -1586,9 +1450,9 @@ async def get_relevant_edges(
query = (
"""
UNWIND $edges AS edge
MATCH (n:Entity {uuid: edge.source_node_uuid})-[:RELATES_TO]-(e:RelatesToNode_ {group_id: edge.group_id})-[:RELATES_TO]-(m:Entity {uuid: edge.target_node_uuid})
"""
UNWIND $edges AS edge
MATCH (n:Entity {uuid: edge.source_node_uuid})-[:RELATES_TO]-(e:RelatesToNode_ {group_id: edge.group_id})-[:RELATES_TO]-(m:Entity {uuid: edge.target_node_uuid})
"""
+ filter_query
+ """
WITH e, edge, n, m, """
@ -1624,9 +1488,9 @@ async def get_relevant_edges(
else:
query = (
"""
UNWIND $edges AS edge
MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid})
"""
UNWIND $edges AS edge
MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid})
"""
+ filter_query
+ """
WITH e, edge, """
@ -1699,10 +1563,10 @@ async def get_edge_invalidation_candidates(
if driver.provider == GraphProvider.NEPTUNE:
query = (
"""
UNWIND $edges AS edge
MATCH (n:Entity)-[e:RELATES_TO {group_id: edge.group_id}]->(m:Entity)
WHERE n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid]
"""
UNWIND $edges AS edge
MATCH (n:Entity)-[e:RELATES_TO {group_id: edge.group_id}]->(m:Entity)
WHERE n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid]
"""
+ filter_query
+ """
WITH e, edge
@ -1772,10 +1636,10 @@ async def get_edge_invalidation_candidates(
query = (
"""
UNWIND $edges AS edge
MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_ {group_id: edge.group_id})-[:RELATES_TO]->(m:Entity)
WHERE (n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid])
"""
UNWIND $edges AS edge
MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_ {group_id: edge.group_id})-[:RELATES_TO]->(m:Entity)
WHERE (n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid])
"""
+ filter_query
+ """
WITH edge, e, n, m, """
@ -1811,10 +1675,10 @@ async def get_edge_invalidation_candidates(
else:
query = (
"""
UNWIND $edges AS edge
MATCH (n:Entity)-[e:RELATES_TO {group_id: edge.group_id}]->(m:Entity)
WHERE n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid]
"""
UNWIND $edges AS edge
MATCH (n:Entity)-[e:RELATES_TO {group_id: edge.group_id}]->(m:Entity)
WHERE n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid]
"""
+ filter_query
+ """
WITH edge, e, """

View file

@ -24,9 +24,6 @@ from pydantic import BaseModel, Field
from typing_extensions import Any
from graphiti_core.driver.driver import (
ENTITY_EDGE_INDEX_NAME,
ENTITY_INDEX_NAME,
EPISODE_INDEX_NAME,
GraphDriver,
GraphDriverSession,
GraphProvider,
@ -177,12 +174,10 @@ async def add_nodes_and_edges_bulk_tx(
'group_id': node.group_id,
'summary': node.summary,
'created_at': node.created_at,
'name_embedding': node.name_embedding,
'labels': list(set(node.labels + ['Entity'])),
}
if not bool(driver.aoss_client):
entity_data['name_embedding'] = node.name_embedding
entity_data['labels'] = list(set(node.labels + ['Entity']))
if driver.provider == GraphProvider.KUZU:
attributes = convert_datetimes_to_strings(node.attributes) if node.attributes else {}
entity_data['attributes'] = json.dumps(attributes)
@ -207,11 +202,9 @@ async def add_nodes_and_edges_bulk_tx(
'expired_at': edge.expired_at,
'valid_at': edge.valid_at,
'invalid_at': edge.invalid_at,
'fact_embedding': edge.fact_embedding,
}
if not bool(driver.aoss_client):
edge_data['fact_embedding'] = edge.fact_embedding
if driver.provider == GraphProvider.KUZU:
attributes = convert_datetimes_to_strings(edge.attributes) if edge.attributes else {}
edge_data['attributes'] = json.dumps(attributes)
@ -220,7 +213,17 @@ async def add_nodes_and_edges_bulk_tx(
edges.append(edge_data)
if driver.provider == GraphProvider.KUZU:
if driver.graph_operations_interface:
await driver.graph_operations_interface.episodic_node_save_bulk(
None, driver, tx, episodic_nodes
)
await driver.graph_operations_interface.node_save_bulk(None, driver, tx, nodes)
await driver.graph_operations_interface.episodic_edge_save_bulk(
None, driver, tx, episodic_edges
)
await driver.graph_operations_interface.edge_save_bulk(None, driver, tx, edges)
elif driver.provider == GraphProvider.KUZU:
# FIXME: Kuzu's UNWIND does not currently support STRUCT[] type properly, so we insert the data one by one instead for now.
episode_query = get_episode_node_save_bulk_query(driver.provider)
for episode in episodes:
@ -237,9 +240,7 @@ async def add_nodes_and_edges_bulk_tx(
else:
await tx.run(get_episode_node_save_bulk_query(driver.provider), episodes=episodes)
await tx.run(
get_entity_node_save_bulk_query(
driver.provider, nodes, has_aoss=bool(driver.aoss_client)
),
get_entity_node_save_bulk_query(driver.provider, nodes),
nodes=nodes,
)
await tx.run(
@ -247,23 +248,10 @@ async def add_nodes_and_edges_bulk_tx(
episodic_edges=[edge.model_dump() for edge in episodic_edges],
)
await tx.run(
get_entity_edge_save_bulk_query(driver.provider, has_aoss=bool(driver.aoss_client)),
get_entity_edge_save_bulk_query(driver.provider),
entity_edges=edges,
)
if bool(driver.aoss_client):
for node_data, entity_node in zip(nodes, entity_nodes, strict=True):
if node_data.get('uuid') == entity_node.uuid:
node_data['name_embedding'] = entity_node.name_embedding
for edge_data, entity_edge in zip(edges, entity_edges, strict=True):
if edge_data.get('uuid') == entity_edge.uuid:
edge_data['fact_embedding'] = entity_edge.fact_embedding
await driver.save_to_aoss(EPISODE_INDEX_NAME, episodes)
await driver.save_to_aoss(ENTITY_INDEX_NAME, nodes)
await driver.save_to_aoss(ENTITY_EDGE_INDEX_NAME, edges)
async def extract_nodes_and_edges_bulk(
clients: GraphitiClients,

View file

@ -34,9 +34,6 @@ logger = logging.getLogger(__name__)
async def build_indices_and_constraints(driver: GraphDriver, delete_existing: bool = False):
if driver.aoss_client:
await driver.create_aoss_indices() # pyright: ignore[reportAttributeAccessIssue]
return
if delete_existing:
records, _, _ = await driver.execute_query(
"""
@ -56,8 +53,8 @@ async def build_indices_and_constraints(driver: GraphDriver, delete_existing: bo
range_indices: list[LiteralString] = get_range_indices(driver.provider)
# Don't create fulltext indices if OpenSearch is being used
if not driver.aoss_client:
# Don't create fulltext indices if search_interface is being used
if not driver.search_interface:
fulltext_indices: list[LiteralString] = get_fulltext_indices(driver.provider)
if driver.provider == GraphProvider.KUZU:
@ -95,8 +92,6 @@ async def clear_data(driver: GraphDriver, group_ids: list[str] | None = None):
async def delete_all(tx):
await tx.run('MATCH (n) DETACH DELETE n')
if driver.aoss_client:
await driver.clear_aoss_indices()
async def delete_group_ids(tx):
labels = ['Entity', 'Episodic', 'Community']
@ -153,9 +148,9 @@ async def retrieve_episodes(
query: LiteralString = (
"""
MATCH (e:Episodic)
WHERE e.valid_at <= $reference_time
"""
MATCH (e:Episodic)
WHERE e.valid_at <= $reference_time
"""
+ query_filter
+ """
RETURN

2
uv.lock generated
View file

@ -1,5 +1,5 @@
version = 1
revision = 3
revision = 2
requires-python = ">=3.10, <4"
resolution-markers = [
"python_full_version >= '3.14'",