update mmr to use bulk load overrides

This commit is contained in:
prestonrasmussen 2025-10-29 09:22:25 -04:00
parent 71f1f66d11
commit c4f5c34598
3 changed files with 59 additions and 59 deletions

View file

@ -77,9 +77,7 @@ class GraphOperationsInterface(BaseModel):
async def node_load_embeddings_bulk( async def node_load_embeddings_bulk(
self, self,
_cls: Any,
driver: Any, driver: Any,
transaction: Any,
nodes: list[Any], nodes: list[Any],
batch_size: int = 100, batch_size: int = 100,
) -> None: ) -> None:
@ -183,9 +181,7 @@ class GraphOperationsInterface(BaseModel):
async def edge_load_embeddings_bulk( async def edge_load_embeddings_bulk(
self, self,
_cls: Any,
driver: Any, driver: Any,
transaction: Any,
edges: list[Any], edges: list[Any],
batch_size: int = 100, batch_size: int = 100,
) -> None: ) -> None:

View file

@ -1879,7 +1879,9 @@ def maximal_marginal_relevance(
async def get_embeddings_for_nodes( async def get_embeddings_for_nodes(
driver: GraphDriver, nodes: list[EntityNode] driver: GraphDriver, nodes: list[EntityNode]
) -> dict[str, list[float]]: ) -> dict[str, list[float]]:
if driver.provider == GraphProvider.NEPTUNE: if driver.graph_operations_interface:
await driver.graph_operations_interface.node_load_embeddings_bulk(driver, nodes)
elif driver.provider == GraphProvider.NEPTUNE:
query = """ query = """
MATCH (n:Entity) MATCH (n:Entity)
WHERE n.uuid IN $node_uuids WHERE n.uuid IN $node_uuids
@ -1949,7 +1951,9 @@ async def get_embeddings_for_communities(
async def get_embeddings_for_edges( async def get_embeddings_for_edges(
driver: GraphDriver, edges: list[EntityEdge] driver: GraphDriver, edges: list[EntityEdge]
) -> dict[str, list[float]]: ) -> dict[str, list[float]]:
if driver.provider == GraphProvider.NEPTUNE: if driver.graph_operations_interface:
await driver.graph_operations_interface.edge_load_embeddings_bulk(driver, edges)
elif driver.provider == GraphProvider.NEPTUNE:
query = """ query = """
MATCH (n:Entity)-[e:RELATES_TO]-(m:Entity) MATCH (n:Entity)-[e:RELATES_TO]-(m:Entity)
WHERE e.uuid IN $edge_uuids WHERE e.uuid IN $edge_uuids

View file

@ -1,7 +1,7 @@
[project] [project]
name = "graphiti-core" name = "graphiti-core"
description = "A temporal graph building library" description = "A temporal graph building library"
version = "0.22.1pre1" version = "0.22.1pre2"
authors = [ authors = [
{ name = "Paul Paliychuk", email = "paul@getzep.com" }, { name = "Paul Paliychuk", email = "paul@getzep.com" },
{ name = "Preston Rasmussen", email = "preston@getzep.com" }, { name = "Preston Rasmussen", email = "preston@getzep.com" },