add search and graph operations interfaces (#984)
* add search and graph operations interfaces * update * update * update * update * update * update
This commit is contained in:
parent
73015e980e
commit
604e3199a3
12 changed files with 430 additions and 433 deletions
|
|
@ -24,6 +24,9 @@ from typing import Any
|
|||
|
||||
from dotenv import load_dotenv
|
||||
|
||||
from graphiti_core.driver.graph_operations.graph_operations import GraphOperationsInterface
|
||||
from graphiti_core.driver.search_interface.search_interface import SearchInterface
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
DEFAULT_SIZE = 10
|
||||
|
|
@ -73,7 +76,8 @@ class GraphDriver(ABC):
|
|||
'' # Neo4j (default) syntax does not require a prefix for fulltext queries
|
||||
)
|
||||
_database: str
|
||||
aoss_client: Any # type: ignore
|
||||
search_interface: SearchInterface | None = None
|
||||
graph_operations_interface: GraphOperationsInterface | None = None
|
||||
|
||||
@abstractmethod
|
||||
def execute_query(self, cypher_query_: str, **kwargs: Any) -> Coroutine:
|
||||
|
|
@ -109,9 +113,3 @@ class GraphDriver(ABC):
|
|||
Only implemented by providers that need custom fulltext query building.
|
||||
"""
|
||||
raise NotImplementedError(f'build_fulltext_query not implemented for {self.provider}')
|
||||
|
||||
async def save_to_aoss(self, name: str, data: list[dict]) -> int:
    """Placeholder hook that accepts an index name and a list of documents.

    This implementation ignores both arguments, performs no I/O and always
    returns ``0``.

    NOTE(review): the return value is presumably a count of documents
    written -- confirm against any overriding driver.
    """
    return 0
|
||||
|
||||
async def clear_aoss_indices(self):
    """Placeholder hook for clearing search indices.

    This implementation performs no I/O and always returns ``1``.

    NOTE(review): the meaning of the ``1`` result is not visible here;
    verify against callers before relying on it.
    """
    return 1
|
||||
|
|
|
|||
0
graphiti_core/driver/graph_operations/__init__.py
Normal file
0
graphiti_core/driver/graph_operations/__init__.py
Normal file
195
graphiti_core/driver/graph_operations/graph_operations.py
Normal file
195
graphiti_core/driver/graph_operations/graph_operations.py
Normal file
|
|
@ -0,0 +1,195 @@
|
|||
"""
|
||||
Copyright 2024, Zep Software, Inc.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class GraphOperationsInterface(BaseModel):
    """
    Interface for overriding graph mutation behavior.

    Node and edge model methods check ``driver.graph_operations_interface``
    and, when it is set, return the result of the matching method here
    instead of running their built-in queries. Every method in this base
    class raises ``NotImplementedError``, so an implementation must override
    each operation that its callers will reach; there is no fallback to the
    built-in behavior once the interface is attached to a driver.

    Parameters are typed ``Any`` throughout. Methods that take ``_cls``
    receive the calling model class as their first positional argument
    (the delegating classmethods pass ``cls``).
    """

    # -----------------
    # Node: Save/Delete
    # -----------------

    async def node_save(self, node: Any, driver: Any) -> None:
        """Persist (create or update) a single node."""
        raise NotImplementedError

    async def node_delete(self, node: Any, driver: Any) -> None:
        """Delete a single node."""
        raise NotImplementedError

    async def node_save_bulk(
        self,
        _cls: Any,  # calling model class; kept for signature parity
        driver: Any,
        transaction: Any,
        nodes: list[Any],
        batch_size: int = 100,
    ) -> None:
        """Persist (create or update) many nodes in batches."""
        raise NotImplementedError

    async def node_delete_by_group_id(
        self,
        _cls: Any,
        driver: Any,
        group_id: str,
        batch_size: int = 100,
    ) -> None:
        """Delete the nodes that belong to ``group_id``, in batches."""
        raise NotImplementedError

    async def node_delete_by_uuids(
        self,
        _cls: Any,
        driver: Any,
        uuids: list[str],
        group_id: str | None = None,
        batch_size: int = 100,
    ) -> None:
        """Delete the nodes identified by ``uuids``, in batches.

        ``group_id`` is optional; the delegating ``Node.delete_by_uuids``
        passes ``None``.
        """
        raise NotImplementedError

    # --------------------------
    # Node: Embeddings (load)
    # --------------------------

    async def node_load_embeddings(self, node: Any, driver: Any) -> None:
        """
        Load embedding vectors for a single node into the instance (e.g., set node.embedding or similar).
        """
        raise NotImplementedError

    async def node_load_embeddings_bulk(
        self,
        _cls: Any,
        driver: Any,
        transaction: Any,
        nodes: list[Any],
        batch_size: int = 100,
    ) -> None:
        """
        Load embedding vectors for many nodes in batches. Mutates the provided node instances.
        """
        raise NotImplementedError

    # --------------------------
    # EpisodicNode: Save/Delete
    # --------------------------

    async def episodic_node_save(self, node: Any, driver: Any) -> None:
        """Persist (create or update) a single episodic node."""
        raise NotImplementedError

    async def episodic_node_delete(self, node: Any, driver: Any) -> None:
        """Delete a single episodic node."""
        raise NotImplementedError

    async def episodic_node_save_bulk(
        self,
        _cls: Any,
        driver: Any,
        transaction: Any,
        nodes: list[Any],
        batch_size: int = 100,
    ) -> None:
        """Persist (create or update) many episodic nodes in batches."""
        raise NotImplementedError

    async def episodic_edge_save_bulk(
        self,
        _cls: Any,
        driver: Any,
        transaction: Any,
        episodic_edges: list[Any],
        batch_size: int = 100,
    ) -> None:
        """Persist (create or update) many episodic edges in batches."""
        raise NotImplementedError

    async def episodic_node_delete_by_group_id(
        self,
        _cls: Any,
        driver: Any,
        group_id: str,
        batch_size: int = 100,
    ) -> None:
        """Delete the episodic nodes that belong to ``group_id``, in batches."""
        raise NotImplementedError

    async def episodic_node_delete_by_uuids(
        self,
        _cls: Any,
        driver: Any,
        uuids: list[str],
        group_id: str | None = None,
        batch_size: int = 100,
    ) -> None:
        """Delete the episodic nodes identified by ``uuids``, in batches."""
        raise NotImplementedError

    # -----------------
    # Edge: Save/Delete
    # -----------------

    async def edge_save(self, edge: Any, driver: Any) -> None:
        """Persist (create or update) a single edge."""
        raise NotImplementedError

    async def edge_delete(self, edge: Any, driver: Any) -> None:
        """Delete a single edge."""
        raise NotImplementedError

    async def edge_save_bulk(
        self,
        _cls: Any,
        driver: Any,
        transaction: Any,
        edges: list[Any],
        batch_size: int = 100,
    ) -> None:
        """Persist (create or update) many edges in batches."""
        raise NotImplementedError

    async def edge_delete_by_uuids(
        self,
        _cls: Any,
        driver: Any,
        uuids: list[str],
        group_id: str | None = None,
    ) -> None:
        """Delete the edges identified by ``uuids``.

        Unlike the node variants, this method takes no ``batch_size``.
        """
        raise NotImplementedError

    # -----------------
    # Edge: Embeddings (load)
    # -----------------

    async def edge_load_embeddings(self, edge: Any, driver: Any) -> None:
        """
        Load embedding vectors for a single edge into the instance (e.g., set edge.embedding or similar).
        """
        raise NotImplementedError

    async def edge_load_embeddings_bulk(
        self,
        _cls: Any,
        driver: Any,
        transaction: Any,
        edges: list[Any],
        batch_size: int = 100,
    ) -> None:
        """
        Load embedding vectors for many edges in batches. Mutates the provided edge instances.
        """
        raise NotImplementedError
|
||||
0
graphiti_core/driver/search_interface/__init__.py
Normal file
0
graphiti_core/driver/search_interface/__init__.py
Normal file
89
graphiti_core/driver/search_interface/search_interface.py
Normal file
89
graphiti_core/driver/search_interface/search_interface.py
Normal file
|
|
@ -0,0 +1,89 @@
|
|||
"""
|
||||
Copyright 2024, Zep Software, Inc.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class SearchInterface(BaseModel):
    """
    This is an interface for implementing custom search logic.

    The module-level search functions check ``driver.search_interface`` and,
    when it is set, return the result of the matching method here instead of
    running their built-in queries. Every method in this base class raises
    ``NotImplementedError``, so an implementation must override each search
    that its callers will reach.

    The async search methods take the driver as their first argument and
    return a list of results; the filter builders are synchronous.
    """

    async def edge_fulltext_search(
        self,
        driver: Any,
        query: str,
        search_filter: Any,
        group_ids: list[str] | None = None,
        limit: int = 100,
    ) -> list[Any]:
        """Run a fulltext search over edges for ``query``, returning at most ``limit`` results."""
        raise NotImplementedError

    async def edge_similarity_search(
        self,
        driver: Any,
        search_vector: list[float],
        source_node_uuid: str | None,
        target_node_uuid: str | None,
        search_filter: Any,
        group_ids: list[str] | None = None,
        limit: int = 100,
        min_score: float = 0.7,
    ) -> list[Any]:
        """Run a vector-similarity search over edges using ``search_vector``.

        ``source_node_uuid`` and ``target_node_uuid`` are required positional
        parameters but may be ``None``.
        """
        raise NotImplementedError

    async def node_fulltext_search(
        self,
        driver: Any,
        query: str,
        search_filter: Any,
        group_ids: list[str] | None = None,
        limit: int = 100,
    ) -> list[Any]:
        """Run a fulltext search over nodes for ``query``, returning at most ``limit`` results."""
        raise NotImplementedError

    async def node_similarity_search(
        self,
        driver: Any,
        search_vector: list[float],
        search_filter: Any,
        group_ids: list[str] | None = None,
        limit: int = 100,
        min_score: float = 0.7,
    ) -> list[Any]:
        """Run a vector-similarity search over nodes using ``search_vector``."""
        raise NotImplementedError

    async def episode_fulltext_search(
        self,
        driver: Any,
        query: str,
        search_filter: Any,  # kept for parity even if unused in your impl
        group_ids: list[str] | None = None,
        limit: int = 100,
    ) -> list[Any]:
        """Run a fulltext search over episodes for ``query``, returning at most ``limit`` results."""
        raise NotImplementedError

    # ---------- SEARCH FILTERS (sync) ----------
    def build_node_search_filters(self, search_filters: Any) -> Any:
        """Translate ``search_filters`` into the backend's node filter representation."""
        raise NotImplementedError

    def build_edge_search_filters(self, search_filters: Any) -> Any:
        """Translate ``search_filters`` into the backend's edge filter representation."""
        raise NotImplementedError

    class Config:
        # Lets subclasses declare fields whose types pydantic cannot validate.
        arbitrary_types_allowed = True
|
||||
|
|
@ -25,7 +25,7 @@ from uuid import uuid4
|
|||
from pydantic import BaseModel, Field
|
||||
from typing_extensions import LiteralString
|
||||
|
||||
from graphiti_core.driver.driver import ENTITY_EDGE_INDEX_NAME, GraphDriver, GraphProvider
|
||||
from graphiti_core.driver.driver import GraphDriver, GraphProvider
|
||||
from graphiti_core.embedder import EmbedderClient
|
||||
from graphiti_core.errors import EdgeNotFoundError, GroupsEdgesNotFoundError
|
||||
from graphiti_core.helpers import parse_db_date
|
||||
|
|
@ -53,6 +53,9 @@ class Edge(BaseModel, ABC):
|
|||
async def save(self, driver: GraphDriver): ...
|
||||
|
||||
async def delete(self, driver: GraphDriver):
|
||||
if driver.graph_operations_interface:
|
||||
return await driver.graph_operations_interface.edge_delete(self, driver)
|
||||
|
||||
if driver.provider == GraphProvider.KUZU:
|
||||
await driver.execute_query(
|
||||
"""
|
||||
|
|
@ -77,17 +80,13 @@ class Edge(BaseModel, ABC):
|
|||
uuid=self.uuid,
|
||||
)
|
||||
|
||||
if driver.aoss_client:
|
||||
await driver.aoss_client.delete(
|
||||
index=ENTITY_EDGE_INDEX_NAME,
|
||||
id=self.uuid,
|
||||
params={'routing': self.group_id},
|
||||
)
|
||||
|
||||
logger.debug(f'Deleted Edge: {self.uuid}')
|
||||
|
||||
@classmethod
|
||||
async def delete_by_uuids(cls, driver: GraphDriver, uuids: list[str]):
|
||||
if driver.graph_operations_interface:
|
||||
return await driver.graph_operations_interface.edge_delete_by_uuids(cls, driver, uuids)
|
||||
|
||||
if driver.provider == GraphProvider.KUZU:
|
||||
await driver.execute_query(
|
||||
"""
|
||||
|
|
@ -115,12 +114,6 @@ class Edge(BaseModel, ABC):
|
|||
uuids=uuids,
|
||||
)
|
||||
|
||||
if driver.aoss_client:
|
||||
await driver.aoss_client.delete_by_query(
|
||||
index=ENTITY_EDGE_INDEX_NAME,
|
||||
body={'query': {'terms': {'uuid': uuids}}},
|
||||
)
|
||||
|
||||
logger.debug(f'Deleted Edges: {uuids}')
|
||||
|
||||
def __hash__(self):
|
||||
|
|
@ -258,6 +251,9 @@ class EntityEdge(Edge):
|
|||
return self.fact_embedding
|
||||
|
||||
async def load_fact_embedding(self, driver: GraphDriver):
|
||||
if driver.graph_operations_interface:
|
||||
return await driver.graph_operations_interface.edge_load_embeddings(self, driver)
|
||||
|
||||
query = """
|
||||
MATCH (n:Entity)-[e:RELATES_TO {uuid: $uuid}]->(m:Entity)
|
||||
RETURN e.fact_embedding AS fact_embedding
|
||||
|
|
@ -268,21 +264,6 @@ class EntityEdge(Edge):
|
|||
MATCH (n:Entity)-[e:RELATES_TO {uuid: $uuid}]->(m:Entity)
|
||||
RETURN [x IN split(e.fact_embedding, ",") | toFloat(x)] as fact_embedding
|
||||
"""
|
||||
elif driver.aoss_client:
|
||||
resp = await driver.aoss_client.search(
|
||||
body={
|
||||
'query': {'multi_match': {'query': self.uuid, 'fields': ['uuid']}},
|
||||
'size': 1,
|
||||
},
|
||||
index=ENTITY_EDGE_INDEX_NAME,
|
||||
params={'routing': self.group_id},
|
||||
)
|
||||
|
||||
if resp['hits']['hits']:
|
||||
self.fact_embedding = resp['hits']['hits'][0]['_source']['fact_embedding']
|
||||
return
|
||||
else:
|
||||
raise EdgeNotFoundError(self.uuid)
|
||||
|
||||
if driver.provider == GraphProvider.KUZU:
|
||||
query = """
|
||||
|
|
@ -320,15 +301,11 @@ class EntityEdge(Edge):
|
|||
if driver.provider == GraphProvider.KUZU:
|
||||
edge_data['attributes'] = json.dumps(self.attributes)
|
||||
result = await driver.execute_query(
|
||||
get_entity_edge_save_query(driver.provider, has_aoss=bool(driver.aoss_client)),
|
||||
get_entity_edge_save_query(driver.provider),
|
||||
**edge_data,
|
||||
)
|
||||
else:
|
||||
edge_data.update(self.attributes or {})
|
||||
|
||||
if driver.aoss_client:
|
||||
await driver.save_to_aoss(ENTITY_EDGE_INDEX_NAME, [edge_data]) # pyright: ignore reportAttributeAccessIssue
|
||||
|
||||
result = await driver.execute_query(
|
||||
get_entity_edge_save_query(driver.provider),
|
||||
edge_data=edge_data,
|
||||
|
|
|
|||
|
|
@ -27,10 +27,6 @@ from pydantic import BaseModel, Field
|
|||
from typing_extensions import LiteralString
|
||||
|
||||
from graphiti_core.driver.driver import (
|
||||
COMMUNITY_INDEX_NAME,
|
||||
ENTITY_EDGE_INDEX_NAME,
|
||||
ENTITY_INDEX_NAME,
|
||||
EPISODE_INDEX_NAME,
|
||||
GraphDriver,
|
||||
GraphProvider,
|
||||
)
|
||||
|
|
@ -99,6 +95,9 @@ class Node(BaseModel, ABC):
|
|||
async def save(self, driver: GraphDriver): ...
|
||||
|
||||
async def delete(self, driver: GraphDriver):
|
||||
if driver.graph_operations_interface:
|
||||
return await driver.graph_operations_interface.node_delete(self, driver)
|
||||
|
||||
match driver.provider:
|
||||
case GraphProvider.NEO4J:
|
||||
records, _, _ = await driver.execute_query(
|
||||
|
|
@ -113,27 +112,6 @@ class Node(BaseModel, ABC):
|
|||
uuid=self.uuid,
|
||||
)
|
||||
|
||||
edge_uuids: list[str] = records[0].get('edge_uuids', []) if records else []
|
||||
|
||||
if driver.aoss_client:
|
||||
# Delete the node from OpenSearch indices
|
||||
for index in (EPISODE_INDEX_NAME, ENTITY_INDEX_NAME, COMMUNITY_INDEX_NAME):
|
||||
await driver.aoss_client.delete(
|
||||
index=index,
|
||||
id=self.uuid,
|
||||
params={'routing': self.group_id},
|
||||
)
|
||||
|
||||
# Bulk delete the detached edges
|
||||
if edge_uuids:
|
||||
actions = []
|
||||
for eid in edge_uuids:
|
||||
actions.append(
|
||||
{'delete': {'_index': ENTITY_EDGE_INDEX_NAME, '_id': eid}}
|
||||
)
|
||||
|
||||
await driver.aoss_client.bulk(body=actions)
|
||||
|
||||
case GraphProvider.KUZU:
|
||||
for label in ['Episodic', 'Community']:
|
||||
await driver.execute_query(
|
||||
|
|
@ -181,6 +159,11 @@ class Node(BaseModel, ABC):
|
|||
|
||||
@classmethod
|
||||
async def delete_by_group_id(cls, driver: GraphDriver, group_id: str, batch_size: int = 100):
|
||||
if driver.graph_operations_interface:
|
||||
return await driver.graph_operations_interface.node_delete_by_group_id(
|
||||
cls, driver, group_id, batch_size
|
||||
)
|
||||
|
||||
match driver.provider:
|
||||
case GraphProvider.NEO4J:
|
||||
async with driver.session() as session:
|
||||
|
|
@ -196,31 +179,6 @@ class Node(BaseModel, ABC):
|
|||
batch_size=batch_size,
|
||||
)
|
||||
|
||||
if driver.aoss_client:
|
||||
await driver.aoss_client.delete_by_query(
|
||||
index=EPISODE_INDEX_NAME,
|
||||
body={'query': {'term': {'group_id': group_id}}},
|
||||
params={'routing': group_id},
|
||||
)
|
||||
|
||||
await driver.aoss_client.delete_by_query(
|
||||
index=ENTITY_INDEX_NAME,
|
||||
body={'query': {'term': {'group_id': group_id}}},
|
||||
params={'routing': group_id},
|
||||
)
|
||||
|
||||
await driver.aoss_client.delete_by_query(
|
||||
index=COMMUNITY_INDEX_NAME,
|
||||
body={'query': {'term': {'group_id': group_id}}},
|
||||
params={'routing': group_id},
|
||||
)
|
||||
|
||||
await driver.aoss_client.delete_by_query(
|
||||
index=ENTITY_EDGE_INDEX_NAME,
|
||||
body={'query': {'term': {'group_id': group_id}}},
|
||||
params={'routing': group_id},
|
||||
)
|
||||
|
||||
case GraphProvider.KUZU:
|
||||
for label in ['Episodic', 'Community']:
|
||||
await driver.execute_query(
|
||||
|
|
@ -258,6 +216,11 @@ class Node(BaseModel, ABC):
|
|||
|
||||
@classmethod
|
||||
async def delete_by_uuids(cls, driver: GraphDriver, uuids: list[str], batch_size: int = 100):
|
||||
if driver.graph_operations_interface:
|
||||
return await driver.graph_operations_interface.node_delete_by_uuids(
|
||||
cls, driver, uuids, group_id=None, batch_size=batch_size
|
||||
)
|
||||
|
||||
match driver.provider:
|
||||
case GraphProvider.FALKORDB:
|
||||
for label in ['Entity', 'Episodic', 'Community']:
|
||||
|
|
@ -300,7 +263,7 @@ class Node(BaseModel, ABC):
|
|||
case _: # Neo4J, Neptune
|
||||
async with driver.session() as session:
|
||||
# Collect all edge UUIDs before deleting nodes
|
||||
result = await session.run(
|
||||
await session.run(
|
||||
"""
|
||||
MATCH (n:Entity|Episodic|Community)
|
||||
WHERE n.uuid IN $uuids
|
||||
|
|
@ -310,11 +273,6 @@ class Node(BaseModel, ABC):
|
|||
uuids=uuids,
|
||||
)
|
||||
|
||||
record = await result.single()
|
||||
edge_uuids: list[str] = (
|
||||
record['edge_uuids'] if record and record['edge_uuids'] else []
|
||||
)
|
||||
|
||||
# Now delete the nodes in batches
|
||||
await session.run(
|
||||
"""
|
||||
|
|
@ -329,20 +287,6 @@ class Node(BaseModel, ABC):
|
|||
batch_size=batch_size,
|
||||
)
|
||||
|
||||
if driver.aoss_client:
|
||||
for index in (EPISODE_INDEX_NAME, ENTITY_INDEX_NAME, COMMUNITY_INDEX_NAME):
|
||||
await driver.aoss_client.delete_by_query(
|
||||
index=index,
|
||||
body={'query': {'terms': {'uuid': uuids}}},
|
||||
)
|
||||
|
||||
if edge_uuids:
|
||||
actions = [
|
||||
{'delete': {'_index': ENTITY_EDGE_INDEX_NAME, '_id': eid}}
|
||||
for eid in edge_uuids
|
||||
]
|
||||
await driver.aoss_client.bulk(body=actions)
|
||||
|
||||
@classmethod
|
||||
async def get_by_uuid(cls, driver: GraphDriver, uuid: str): ...
|
||||
|
||||
|
|
@ -363,6 +307,9 @@ class EpisodicNode(Node):
|
|||
)
|
||||
|
||||
async def save(self, driver: GraphDriver):
|
||||
if driver.graph_operations_interface:
|
||||
return await driver.graph_operations_interface.episodic_node_save(self, driver)
|
||||
|
||||
episode_args = {
|
||||
'uuid': self.uuid,
|
||||
'name': self.name,
|
||||
|
|
@ -375,12 +322,6 @@ class EpisodicNode(Node):
|
|||
'source': self.source.value,
|
||||
}
|
||||
|
||||
if driver.aoss_client:
|
||||
await driver.save_to_aoss( # pyright: ignore reportAttributeAccessIssue
|
||||
'episodes',
|
||||
[episode_args],
|
||||
)
|
||||
|
||||
result = await driver.execute_query(
|
||||
get_episode_node_save_query(driver.provider), **episode_args
|
||||
)
|
||||
|
|
@ -510,26 +451,14 @@ class EntityNode(Node):
|
|||
return self.name_embedding
|
||||
|
||||
async def load_name_embedding(self, driver: GraphDriver):
|
||||
if driver.graph_operations_interface:
|
||||
return await driver.graph_operations_interface.node_load_embeddings(self, driver)
|
||||
|
||||
if driver.provider == GraphProvider.NEPTUNE:
|
||||
query: LiteralString = """
|
||||
MATCH (n:Entity {uuid: $uuid})
|
||||
RETURN [x IN split(n.name_embedding, ",") | toFloat(x)] as name_embedding
|
||||
"""
|
||||
elif driver.aoss_client:
|
||||
resp = await driver.aoss_client.search(
|
||||
body={
|
||||
'query': {'multi_match': {'query': self.uuid, 'fields': ['uuid']}},
|
||||
'size': 1,
|
||||
},
|
||||
index=ENTITY_INDEX_NAME,
|
||||
params={'routing': self.group_id},
|
||||
)
|
||||
|
||||
if resp['hits']['hits']:
|
||||
self.name_embedding = resp['hits']['hits'][0]['_source']['name_embedding']
|
||||
return
|
||||
else:
|
||||
raise NodeNotFoundError(self.uuid)
|
||||
|
||||
else:
|
||||
query: LiteralString = """
|
||||
|
|
@ -548,6 +477,9 @@ class EntityNode(Node):
|
|||
self.name_embedding = records[0]['name_embedding']
|
||||
|
||||
async def save(self, driver: GraphDriver):
|
||||
if driver.graph_operations_interface:
|
||||
return await driver.graph_operations_interface.node_save(self, driver)
|
||||
|
||||
entity_data: dict[str, Any] = {
|
||||
'uuid': self.uuid,
|
||||
'name': self.name,
|
||||
|
|
@ -568,11 +500,8 @@ class EntityNode(Node):
|
|||
entity_data.update(self.attributes or {})
|
||||
labels = ':'.join(self.labels + ['Entity'])
|
||||
|
||||
if driver.aoss_client:
|
||||
await driver.save_to_aoss(ENTITY_INDEX_NAME, [entity_data]) # pyright: ignore reportAttributeAccessIssue
|
||||
|
||||
result = await driver.execute_query(
|
||||
get_entity_node_save_query(driver.provider, labels, bool(driver.aoss_client)),
|
||||
get_entity_node_save_query(driver.provider, labels),
|
||||
entity_data=entity_data,
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -249,41 +249,3 @@ def edge_search_filter_query_constructor(
|
|||
filter_queries.append(expired_at_filter)
|
||||
|
||||
return filter_queries, filter_params
|
||||
|
||||
|
||||
def build_aoss_node_filters(group_ids: list[str], search_filters: SearchFilters) -> list[dict]:
    """Build the list of filter clauses for a node search.

    The first clause is always a ``terms`` filter on ``group_id`` carrying
    ``group_ids`` exactly as given (including an empty list). A second
    ``terms`` clause on ``node_labels`` is appended only when
    ``search_filters.node_labels`` is truthy.

    Returns:
        The filter clauses as a list of dicts, in the order described above.
    """
    filters = [{'terms': {'group_id': group_ids}}]

    if search_filters.node_labels:
        filters.append({'terms': {'node_labels': search_filters.node_labels}})

    return filters
|
||||
|
||||
|
||||
def build_aoss_edge_filters(group_ids: list[str], search_filters: SearchFilters) -> list[dict]:
    """Build the list of filter clauses for an edge search.

    Clauses are appended in this order:

    1. A ``terms`` filter on ``group_id`` (always present, even when
       ``group_ids`` is empty).
    2. A ``terms`` filter on ``edge_types`` when ``search_filters.edge_types``
       is truthy.
    3. A ``terms`` filter on ``uuid`` when ``search_filters.edge_uuids`` is
       truthy.
    4. For each of ``valid_at``, ``invalid_at``, ``created_at`` and
       ``expired_at`` that is truthy on ``search_filters``: one ``bool``
       clause expressing an OR of AND-groups of ``range`` conditions.

    Returns:
        The filter clauses as a list of dicts.
    """
    filters: list[dict] = [{'terms': {'group_id': group_ids}}]

    if search_filters.edge_types:
        filters.append({'terms': {'edge_types': search_filters.edge_types}})

    if search_filters.edge_uuids:
        filters.append({'terms': {'uuid': search_filters.edge_uuids}})

    for field in ['valid_at', 'invalid_at', 'created_at', 'expired_at']:
        ranges = getattr(search_filters, field)
        if not ranges:
            continue

        # OR of ANDs: each inner group becomes a bool/filter (AND) clause, and
        # the groups are combined under should with minimum_should_match=1 (OR).
        should_clauses = []
        for and_group in ranges:
            # Each item in and_group is a DateFilter.
            and_filters = [
                {
                    'range': {
                        field: {cypher_to_opensearch_operator(df.comparison_operator): df.date}
                    }
                }
                for df in and_group
            ]
            should_clauses.append({'bool': {'filter': and_filters}})
        filters.append({'bool': {'should': should_clauses, 'minimum_should_match': 1}})

    return filters
|
||||
|
|
|
|||
|
|
@ -24,9 +24,6 @@ from numpy._typing import NDArray
|
|||
from typing_extensions import LiteralString
|
||||
|
||||
from graphiti_core.driver.driver import (
|
||||
ENTITY_EDGE_INDEX_NAME,
|
||||
ENTITY_INDEX_NAME,
|
||||
EPISODE_INDEX_NAME,
|
||||
GraphDriver,
|
||||
GraphProvider,
|
||||
)
|
||||
|
|
@ -57,8 +54,6 @@ from graphiti_core.nodes import (
|
|||
)
|
||||
from graphiti_core.search.search_filters import (
|
||||
SearchFilters,
|
||||
build_aoss_edge_filters,
|
||||
build_aoss_node_filters,
|
||||
edge_search_filter_query_constructor,
|
||||
node_search_filter_query_constructor,
|
||||
)
|
||||
|
|
@ -179,6 +174,11 @@ async def edge_fulltext_search(
|
|||
group_ids: list[str] | None = None,
|
||||
limit=RELEVANT_SCHEMA_LIMIT,
|
||||
) -> list[EntityEdge]:
|
||||
if driver.search_interface:
|
||||
return await driver.search_interface.edge_fulltext_search(
|
||||
driver, query, search_filter, group_ids, limit
|
||||
)
|
||||
|
||||
# fulltext search over facts
|
||||
fuzzy_query = fulltext_query(query, group_ids, driver)
|
||||
|
||||
|
|
@ -217,11 +217,11 @@ async def edge_fulltext_search(
|
|||
# Match the edge ids and return the values
|
||||
query = (
|
||||
"""
|
||||
UNWIND $ids as id
|
||||
MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
|
||||
WHERE e.group_id IN $group_ids
|
||||
AND id(e)=id
|
||||
"""
|
||||
UNWIND $ids as id
|
||||
MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
|
||||
WHERE e.group_id IN $group_ids
|
||||
AND id(e)=id
|
||||
"""
|
||||
+ filter_query
|
||||
+ """
|
||||
AND id(e)=id
|
||||
|
|
@ -253,35 +253,6 @@ async def edge_fulltext_search(
|
|||
)
|
||||
else:
|
||||
return []
|
||||
elif driver.aoss_client:
|
||||
route = group_ids[0] if group_ids else None
|
||||
filters = build_aoss_edge_filters(group_ids or [], search_filter)
|
||||
res = await driver.aoss_client.search(
|
||||
index=ENTITY_EDGE_INDEX_NAME,
|
||||
params={'routing': route},
|
||||
body={
|
||||
'size': limit,
|
||||
'_source': ['uuid'],
|
||||
'query': {
|
||||
'bool': {
|
||||
'filter': filters,
|
||||
'must': [{'match': {'fact': {'query': query, 'operator': 'or'}}}],
|
||||
}
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
if res['hits']['total']['value'] > 0:
|
||||
input_uuids = {}
|
||||
for r in res['hits']['hits']:
|
||||
input_uuids[r['_source']['uuid']] = r['_score']
|
||||
|
||||
# Get edges
|
||||
entity_edges = await EntityEdge.get_by_uuids(driver, list(input_uuids.keys()))
|
||||
entity_edges.sort(key=lambda e: input_uuids.get(e.uuid, 0), reverse=True)
|
||||
return entity_edges
|
||||
else:
|
||||
return []
|
||||
else:
|
||||
query = (
|
||||
get_relationships_query('edge_name_and_fact', limit=limit, provider=driver.provider)
|
||||
|
|
@ -321,6 +292,18 @@ async def edge_similarity_search(
|
|||
limit: int = RELEVANT_SCHEMA_LIMIT,
|
||||
min_score: float = DEFAULT_MIN_SCORE,
|
||||
) -> list[EntityEdge]:
|
||||
if driver.search_interface:
|
||||
return await driver.search_interface.edge_similarity_search(
|
||||
driver,
|
||||
search_vector,
|
||||
source_node_uuid,
|
||||
target_node_uuid,
|
||||
search_filter,
|
||||
group_ids,
|
||||
limit,
|
||||
min_score,
|
||||
)
|
||||
|
||||
match_query = """
|
||||
MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
|
||||
"""
|
||||
|
|
@ -356,8 +339,8 @@ async def edge_similarity_search(
|
|||
if driver.provider == GraphProvider.NEPTUNE:
|
||||
query = (
|
||||
"""
|
||||
MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
|
||||
"""
|
||||
MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
|
||||
"""
|
||||
+ filter_query
|
||||
+ """
|
||||
RETURN DISTINCT id(e) as id, e.fact_embedding as embedding
|
||||
|
|
@ -415,38 +398,6 @@ async def edge_similarity_search(
|
|||
)
|
||||
else:
|
||||
return []
|
||||
elif driver.aoss_client:
|
||||
route = group_ids[0] if group_ids else None
|
||||
filters = build_aoss_edge_filters(group_ids or [], search_filter)
|
||||
res = await driver.aoss_client.search(
|
||||
index=ENTITY_EDGE_INDEX_NAME,
|
||||
params={'routing': route},
|
||||
body={
|
||||
'size': limit,
|
||||
'_source': ['uuid'],
|
||||
'query': {
|
||||
'knn': {
|
||||
'fact_embedding': {
|
||||
'vector': list(map(float, search_vector)),
|
||||
'k': limit,
|
||||
'filter': {'bool': {'filter': filters}},
|
||||
}
|
||||
}
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
if res['hits']['total']['value'] > 0:
|
||||
input_uuids = {}
|
||||
for r in res['hits']['hits']:
|
||||
input_uuids[r['_source']['uuid']] = r['_score']
|
||||
|
||||
# Get edges
|
||||
entity_edges = await EntityEdge.get_by_uuids(driver, list(input_uuids.keys()))
|
||||
entity_edges.sort(key=lambda e: input_uuids.get(e.uuid, 0), reverse=True)
|
||||
return entity_edges
|
||||
return []
|
||||
|
||||
else:
|
||||
query = (
|
||||
match_query
|
||||
|
|
@ -609,6 +560,11 @@ async def node_fulltext_search(
|
|||
group_ids: list[str] | None = None,
|
||||
limit=RELEVANT_SCHEMA_LIMIT,
|
||||
) -> list[EntityNode]:
|
||||
if driver.search_interface:
|
||||
return await driver.search_interface.node_fulltext_search(
|
||||
driver, query, search_filter, group_ids, limit
|
||||
)
|
||||
|
||||
# BM25 search to get top nodes
|
||||
fuzzy_query = fulltext_query(query, group_ids, driver)
|
||||
if fuzzy_query == '':
|
||||
|
|
@ -640,11 +596,11 @@ async def node_fulltext_search(
|
|||
# Match the node ids and return the values
|
||||
query = (
|
||||
"""
|
||||
UNWIND $ids as i
|
||||
MATCH (n:Entity)
|
||||
WHERE n.uuid=i.id
|
||||
RETURN
|
||||
"""
|
||||
UNWIND $ids as i
|
||||
MATCH (n:Entity)
|
||||
WHERE n.uuid=i.id
|
||||
RETURN
|
||||
"""
|
||||
+ get_entity_node_return_query(driver.provider)
|
||||
+ """
|
||||
ORDER BY i.score DESC
|
||||
|
|
@ -661,43 +617,6 @@ async def node_fulltext_search(
|
|||
)
|
||||
else:
|
||||
return []
|
||||
elif driver.aoss_client:
|
||||
route = group_ids[0] if group_ids else None
|
||||
filters = build_aoss_node_filters(group_ids or [], search_filter)
|
||||
res = await driver.aoss_client.search(
|
||||
index=ENTITY_INDEX_NAME,
|
||||
params={'routing': route},
|
||||
body={
|
||||
'_source': ['uuid'],
|
||||
'size': limit,
|
||||
'query': {
|
||||
'bool': {
|
||||
'filter': filters,
|
||||
'must': [
|
||||
{
|
||||
'multi_match': {
|
||||
'query': query,
|
||||
'fields': ['name', 'summary'],
|
||||
'operator': 'or',
|
||||
}
|
||||
}
|
||||
],
|
||||
}
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
if res['hits']['total']['value'] > 0:
|
||||
input_uuids = {}
|
||||
for r in res['hits']['hits']:
|
||||
input_uuids[r['_source']['uuid']] = r['_score']
|
||||
|
||||
# Get nodes
|
||||
entities = await EntityNode.get_by_uuids(driver, list(input_uuids.keys()))
|
||||
entities.sort(key=lambda e: input_uuids.get(e.uuid, 0), reverse=True)
|
||||
return entities
|
||||
else:
|
||||
return []
|
||||
else:
|
||||
query = (
|
||||
get_nodes_query(
|
||||
|
|
@ -735,6 +654,11 @@ async def node_similarity_search(
|
|||
limit=RELEVANT_SCHEMA_LIMIT,
|
||||
min_score: float = DEFAULT_MIN_SCORE,
|
||||
) -> list[EntityNode]:
|
||||
if driver.search_interface:
|
||||
return await driver.search_interface.node_similarity_search(
|
||||
driver, search_vector, search_filter, group_ids, limit, min_score
|
||||
)
|
||||
|
||||
filter_queries, filter_params = node_search_filter_query_constructor(
|
||||
search_filter, driver.provider
|
||||
)
|
||||
|
|
@ -754,8 +678,8 @@ async def node_similarity_search(
|
|||
if driver.provider == GraphProvider.NEPTUNE:
|
||||
query = (
|
||||
"""
|
||||
MATCH (n:Entity)
|
||||
"""
|
||||
MATCH (n:Entity)
|
||||
"""
|
||||
+ filter_query
|
||||
+ """
|
||||
RETURN DISTINCT id(n) as id, n.name_embedding as embedding
|
||||
|
|
@ -784,11 +708,11 @@ async def node_similarity_search(
|
|||
# Match the node ids and return the values
|
||||
query = (
|
||||
"""
|
||||
UNWIND $ids as i
|
||||
MATCH (n:Entity)
|
||||
WHERE id(n)=i.id
|
||||
RETURN
|
||||
"""
|
||||
UNWIND $ids as i
|
||||
MATCH (n:Entity)
|
||||
WHERE id(n)=i.id
|
||||
RETURN
|
||||
"""
|
||||
+ get_entity_node_return_query(driver.provider)
|
||||
+ """
|
||||
ORDER BY i.score DESC
|
||||
|
|
@ -806,42 +730,11 @@ async def node_similarity_search(
|
|||
)
|
||||
else:
|
||||
return []
|
||||
elif driver.aoss_client:
|
||||
route = group_ids[0] if group_ids else None
|
||||
filters = build_aoss_node_filters(group_ids or [], search_filter)
|
||||
res = await driver.aoss_client.search(
|
||||
index=ENTITY_INDEX_NAME,
|
||||
params={'routing': route},
|
||||
body={
|
||||
'size': limit,
|
||||
'_source': ['uuid'],
|
||||
'query': {
|
||||
'knn': {
|
||||
'name_embedding': {
|
||||
'vector': list(map(float, search_vector)),
|
||||
'k': limit,
|
||||
'filter': {'bool': {'filter': filters}},
|
||||
}
|
||||
}
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
if res['hits']['total']['value'] > 0:
|
||||
input_uuids = {}
|
||||
for r in res['hits']['hits']:
|
||||
input_uuids[r['_source']['uuid']] = r['_score']
|
||||
|
||||
# Get nodes
|
||||
entity_nodes = await EntityNode.get_by_uuids(driver, list(input_uuids.keys()))
|
||||
entity_nodes.sort(key=lambda e: input_uuids.get(e.uuid, 0), reverse=True)
|
||||
return entity_nodes
|
||||
return []
|
||||
else:
|
||||
query = (
|
||||
"""
|
||||
MATCH (n:Entity)
|
||||
"""
|
||||
MATCH (n:Entity)
|
||||
"""
|
||||
+ filter_query
|
||||
+ """
|
||||
WITH n, """
|
||||
|
|
@ -966,6 +859,11 @@ async def episode_fulltext_search(
|
|||
group_ids: list[str] | None = None,
|
||||
limit=RELEVANT_SCHEMA_LIMIT,
|
||||
) -> list[EpisodicNode]:
|
||||
if driver.search_interface:
|
||||
return await driver.search_interface.episode_fulltext_search(
|
||||
driver, query, _search_filter, group_ids, limit
|
||||
)
|
||||
|
||||
# BM25 search to get top episodes
|
||||
fuzzy_query = fulltext_query(query, group_ids, driver)
|
||||
if fuzzy_query == '':
|
||||
|
|
@ -1012,40 +910,6 @@ async def episode_fulltext_search(
|
|||
)
|
||||
else:
|
||||
return []
|
||||
elif driver.aoss_client:
|
||||
route = group_ids[0] if group_ids else None
|
||||
res = await driver.aoss_client.search(
|
||||
index=EPISODE_INDEX_NAME,
|
||||
params={'routing': route},
|
||||
body={
|
||||
'size': limit,
|
||||
'_source': ['uuid'],
|
||||
'bool': {
|
||||
'filter': {'terms': group_ids},
|
||||
'must': [
|
||||
{
|
||||
'multi_match': {
|
||||
'query': query,
|
||||
'field': ['name', 'content'],
|
||||
'operator': 'or',
|
||||
}
|
||||
}
|
||||
],
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
if res['hits']['total']['value'] > 0:
|
||||
input_uuids = {}
|
||||
for r in res['hits']['hits']:
|
||||
input_uuids[r['_source']['uuid']] = r['_score']
|
||||
|
||||
# Get nodes
|
||||
episodes = await EpisodicNode.get_by_uuids(driver, list(input_uuids.keys()))
|
||||
episodes.sort(key=lambda e: input_uuids.get(e.uuid, 0), reverse=True)
|
||||
return episodes
|
||||
else:
|
||||
return []
|
||||
else:
|
||||
query = (
|
||||
get_nodes_query('episode_content', '$query', limit=limit, provider=driver.provider)
|
||||
|
|
@ -1173,8 +1037,8 @@ async def community_similarity_search(
|
|||
if driver.provider == GraphProvider.NEPTUNE:
|
||||
query = (
|
||||
"""
|
||||
MATCH (n:Community)
|
||||
"""
|
||||
MATCH (n:Community)
|
||||
"""
|
||||
+ group_filter_query
|
||||
+ """
|
||||
RETURN DISTINCT id(n) as id, n.name_embedding as embedding
|
||||
|
|
@ -1233,8 +1097,8 @@ async def community_similarity_search(
|
|||
|
||||
query = (
|
||||
"""
|
||||
MATCH (c:Community)
|
||||
"""
|
||||
MATCH (c:Community)
|
||||
"""
|
||||
+ group_filter_query
|
||||
+ """
|
||||
WITH c,
|
||||
|
|
@ -1376,9 +1240,9 @@ async def get_relevant_nodes(
|
|||
# FIXME: Kuzu currently does not support using variables such as `node.fulltext_query` as an input to FTS, which means `get_relevant_nodes()` won't work with Kuzu as the graph driver.
|
||||
query = (
|
||||
"""
|
||||
UNWIND $nodes AS node
|
||||
MATCH (n:Entity {group_id: $group_id})
|
||||
"""
|
||||
UNWIND $nodes AS node
|
||||
MATCH (n:Entity {group_id: $group_id})
|
||||
"""
|
||||
+ filter_query
|
||||
+ """
|
||||
WITH node, n, """
|
||||
|
|
@ -1423,9 +1287,9 @@ async def get_relevant_nodes(
|
|||
else:
|
||||
query = (
|
||||
"""
|
||||
UNWIND $nodes AS node
|
||||
MATCH (n:Entity {group_id: $group_id})
|
||||
"""
|
||||
UNWIND $nodes AS node
|
||||
MATCH (n:Entity {group_id: $group_id})
|
||||
"""
|
||||
+ filter_query
|
||||
+ """
|
||||
WITH node, n, """
|
||||
|
|
@ -1514,9 +1378,9 @@ async def get_relevant_edges(
|
|||
if driver.provider == GraphProvider.NEPTUNE:
|
||||
query = (
|
||||
"""
|
||||
UNWIND $edges AS edge
|
||||
MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid})
|
||||
"""
|
||||
UNWIND $edges AS edge
|
||||
MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid})
|
||||
"""
|
||||
+ filter_query
|
||||
+ """
|
||||
WITH e, edge
|
||||
|
|
@ -1586,9 +1450,9 @@ async def get_relevant_edges(
|
|||
|
||||
query = (
|
||||
"""
|
||||
UNWIND $edges AS edge
|
||||
MATCH (n:Entity {uuid: edge.source_node_uuid})-[:RELATES_TO]-(e:RelatesToNode_ {group_id: edge.group_id})-[:RELATES_TO]-(m:Entity {uuid: edge.target_node_uuid})
|
||||
"""
|
||||
UNWIND $edges AS edge
|
||||
MATCH (n:Entity {uuid: edge.source_node_uuid})-[:RELATES_TO]-(e:RelatesToNode_ {group_id: edge.group_id})-[:RELATES_TO]-(m:Entity {uuid: edge.target_node_uuid})
|
||||
"""
|
||||
+ filter_query
|
||||
+ """
|
||||
WITH e, edge, n, m, """
|
||||
|
|
@ -1624,9 +1488,9 @@ async def get_relevant_edges(
|
|||
else:
|
||||
query = (
|
||||
"""
|
||||
UNWIND $edges AS edge
|
||||
MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid})
|
||||
"""
|
||||
UNWIND $edges AS edge
|
||||
MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid})
|
||||
"""
|
||||
+ filter_query
|
||||
+ """
|
||||
WITH e, edge, """
|
||||
|
|
@ -1699,10 +1563,10 @@ async def get_edge_invalidation_candidates(
|
|||
if driver.provider == GraphProvider.NEPTUNE:
|
||||
query = (
|
||||
"""
|
||||
UNWIND $edges AS edge
|
||||
MATCH (n:Entity)-[e:RELATES_TO {group_id: edge.group_id}]->(m:Entity)
|
||||
WHERE n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid]
|
||||
"""
|
||||
UNWIND $edges AS edge
|
||||
MATCH (n:Entity)-[e:RELATES_TO {group_id: edge.group_id}]->(m:Entity)
|
||||
WHERE n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid]
|
||||
"""
|
||||
+ filter_query
|
||||
+ """
|
||||
WITH e, edge
|
||||
|
|
@ -1772,10 +1636,10 @@ async def get_edge_invalidation_candidates(
|
|||
|
||||
query = (
|
||||
"""
|
||||
UNWIND $edges AS edge
|
||||
MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_ {group_id: edge.group_id})-[:RELATES_TO]->(m:Entity)
|
||||
WHERE (n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid])
|
||||
"""
|
||||
UNWIND $edges AS edge
|
||||
MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_ {group_id: edge.group_id})-[:RELATES_TO]->(m:Entity)
|
||||
WHERE (n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid])
|
||||
"""
|
||||
+ filter_query
|
||||
+ """
|
||||
WITH edge, e, n, m, """
|
||||
|
|
@ -1811,10 +1675,10 @@ async def get_edge_invalidation_candidates(
|
|||
else:
|
||||
query = (
|
||||
"""
|
||||
UNWIND $edges AS edge
|
||||
MATCH (n:Entity)-[e:RELATES_TO {group_id: edge.group_id}]->(m:Entity)
|
||||
WHERE n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid]
|
||||
"""
|
||||
UNWIND $edges AS edge
|
||||
MATCH (n:Entity)-[e:RELATES_TO {group_id: edge.group_id}]->(m:Entity)
|
||||
WHERE n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid]
|
||||
"""
|
||||
+ filter_query
|
||||
+ """
|
||||
WITH edge, e, """
|
||||
|
|
|
|||
|
|
@ -24,9 +24,6 @@ from pydantic import BaseModel, Field
|
|||
from typing_extensions import Any
|
||||
|
||||
from graphiti_core.driver.driver import (
|
||||
ENTITY_EDGE_INDEX_NAME,
|
||||
ENTITY_INDEX_NAME,
|
||||
EPISODE_INDEX_NAME,
|
||||
GraphDriver,
|
||||
GraphDriverSession,
|
||||
GraphProvider,
|
||||
|
|
@ -177,12 +174,10 @@ async def add_nodes_and_edges_bulk_tx(
|
|||
'group_id': node.group_id,
|
||||
'summary': node.summary,
|
||||
'created_at': node.created_at,
|
||||
'name_embedding': node.name_embedding,
|
||||
'labels': list(set(node.labels + ['Entity'])),
|
||||
}
|
||||
|
||||
if not bool(driver.aoss_client):
|
||||
entity_data['name_embedding'] = node.name_embedding
|
||||
|
||||
entity_data['labels'] = list(set(node.labels + ['Entity']))
|
||||
if driver.provider == GraphProvider.KUZU:
|
||||
attributes = convert_datetimes_to_strings(node.attributes) if node.attributes else {}
|
||||
entity_data['attributes'] = json.dumps(attributes)
|
||||
|
|
@ -207,11 +202,9 @@ async def add_nodes_and_edges_bulk_tx(
|
|||
'expired_at': edge.expired_at,
|
||||
'valid_at': edge.valid_at,
|
||||
'invalid_at': edge.invalid_at,
|
||||
'fact_embedding': edge.fact_embedding,
|
||||
}
|
||||
|
||||
if not bool(driver.aoss_client):
|
||||
edge_data['fact_embedding'] = edge.fact_embedding
|
||||
|
||||
if driver.provider == GraphProvider.KUZU:
|
||||
attributes = convert_datetimes_to_strings(edge.attributes) if edge.attributes else {}
|
||||
edge_data['attributes'] = json.dumps(attributes)
|
||||
|
|
@ -220,7 +213,17 @@ async def add_nodes_and_edges_bulk_tx(
|
|||
|
||||
edges.append(edge_data)
|
||||
|
||||
if driver.provider == GraphProvider.KUZU:
|
||||
if driver.graph_operations_interface:
|
||||
await driver.graph_operations_interface.episodic_node_save_bulk(
|
||||
None, driver, tx, episodic_nodes
|
||||
)
|
||||
await driver.graph_operations_interface.node_save_bulk(None, driver, tx, nodes)
|
||||
await driver.graph_operations_interface.episodic_edge_save_bulk(
|
||||
None, driver, tx, episodic_edges
|
||||
)
|
||||
await driver.graph_operations_interface.edge_save_bulk(None, driver, tx, edges)
|
||||
|
||||
elif driver.provider == GraphProvider.KUZU:
|
||||
# FIXME: Kuzu's UNWIND does not currently support STRUCT[] type properly, so we insert the data one by one instead for now.
|
||||
episode_query = get_episode_node_save_bulk_query(driver.provider)
|
||||
for episode in episodes:
|
||||
|
|
@ -237,9 +240,7 @@ async def add_nodes_and_edges_bulk_tx(
|
|||
else:
|
||||
await tx.run(get_episode_node_save_bulk_query(driver.provider), episodes=episodes)
|
||||
await tx.run(
|
||||
get_entity_node_save_bulk_query(
|
||||
driver.provider, nodes, has_aoss=bool(driver.aoss_client)
|
||||
),
|
||||
get_entity_node_save_bulk_query(driver.provider, nodes),
|
||||
nodes=nodes,
|
||||
)
|
||||
await tx.run(
|
||||
|
|
@ -247,23 +248,10 @@ async def add_nodes_and_edges_bulk_tx(
|
|||
episodic_edges=[edge.model_dump() for edge in episodic_edges],
|
||||
)
|
||||
await tx.run(
|
||||
get_entity_edge_save_bulk_query(driver.provider, has_aoss=bool(driver.aoss_client)),
|
||||
get_entity_edge_save_bulk_query(driver.provider),
|
||||
entity_edges=edges,
|
||||
)
|
||||
|
||||
if bool(driver.aoss_client):
|
||||
for node_data, entity_node in zip(nodes, entity_nodes, strict=True):
|
||||
if node_data.get('uuid') == entity_node.uuid:
|
||||
node_data['name_embedding'] = entity_node.name_embedding
|
||||
|
||||
for edge_data, entity_edge in zip(edges, entity_edges, strict=True):
|
||||
if edge_data.get('uuid') == entity_edge.uuid:
|
||||
edge_data['fact_embedding'] = entity_edge.fact_embedding
|
||||
|
||||
await driver.save_to_aoss(EPISODE_INDEX_NAME, episodes)
|
||||
await driver.save_to_aoss(ENTITY_INDEX_NAME, nodes)
|
||||
await driver.save_to_aoss(ENTITY_EDGE_INDEX_NAME, edges)
|
||||
|
||||
|
||||
async def extract_nodes_and_edges_bulk(
|
||||
clients: GraphitiClients,
|
||||
|
|
|
|||
|
|
@ -34,9 +34,6 @@ logger = logging.getLogger(__name__)
|
|||
|
||||
|
||||
async def build_indices_and_constraints(driver: GraphDriver, delete_existing: bool = False):
|
||||
if driver.aoss_client:
|
||||
await driver.create_aoss_indices() # pyright: ignore[reportAttributeAccessIssue]
|
||||
return
|
||||
if delete_existing:
|
||||
records, _, _ = await driver.execute_query(
|
||||
"""
|
||||
|
|
@ -56,8 +53,8 @@ async def build_indices_and_constraints(driver: GraphDriver, delete_existing: bo
|
|||
|
||||
range_indices: list[LiteralString] = get_range_indices(driver.provider)
|
||||
|
||||
# Don't create fulltext indices if OpenSearch is being used
|
||||
if not driver.aoss_client:
|
||||
# Don't create fulltext indices if search_interface is being used
|
||||
if not driver.search_interface:
|
||||
fulltext_indices: list[LiteralString] = get_fulltext_indices(driver.provider)
|
||||
|
||||
if driver.provider == GraphProvider.KUZU:
|
||||
|
|
@ -95,8 +92,6 @@ async def clear_data(driver: GraphDriver, group_ids: list[str] | None = None):
|
|||
|
||||
async def delete_all(tx):
|
||||
await tx.run('MATCH (n) DETACH DELETE n')
|
||||
if driver.aoss_client:
|
||||
await driver.clear_aoss_indices()
|
||||
|
||||
async def delete_group_ids(tx):
|
||||
labels = ['Entity', 'Episodic', 'Community']
|
||||
|
|
@ -153,9 +148,9 @@ async def retrieve_episodes(
|
|||
|
||||
query: LiteralString = (
|
||||
"""
|
||||
MATCH (e:Episodic)
|
||||
WHERE e.valid_at <= $reference_time
|
||||
"""
|
||||
MATCH (e:Episodic)
|
||||
WHERE e.valid_at <= $reference_time
|
||||
"""
|
||||
+ query_filter
|
||||
+ """
|
||||
RETURN
|
||||
|
|
|
|||
2
uv.lock
generated
2
uv.lock
generated
|
|
@ -1,5 +1,5 @@
|
|||
version = 1
|
||||
revision = 3
|
||||
revision = 2
|
||||
requires-python = ">=3.10, <4"
|
||||
resolution-markers = [
|
||||
"python_full_version >= '3.14'",
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue