update mmr to use bulk load overrides
This commit is contained in:
parent
71f1f66d11
commit
c4f5c34598
3 changed files with 59 additions and 59 deletions
|
|
@ -77,9 +77,7 @@ class GraphOperationsInterface(BaseModel):
|
||||||
|
|
||||||
async def node_load_embeddings_bulk(
|
async def node_load_embeddings_bulk(
|
||||||
self,
|
self,
|
||||||
_cls: Any,
|
|
||||||
driver: Any,
|
driver: Any,
|
||||||
transaction: Any,
|
|
||||||
nodes: list[Any],
|
nodes: list[Any],
|
||||||
batch_size: int = 100,
|
batch_size: int = 100,
|
||||||
) -> None:
|
) -> None:
|
||||||
|
|
@ -183,9 +181,7 @@ class GraphOperationsInterface(BaseModel):
|
||||||
|
|
||||||
async def edge_load_embeddings_bulk(
|
async def edge_load_embeddings_bulk(
|
||||||
self,
|
self,
|
||||||
_cls: Any,
|
|
||||||
driver: Any,
|
driver: Any,
|
||||||
transaction: Any,
|
|
||||||
edges: list[Any],
|
edges: list[Any],
|
||||||
batch_size: int = 100,
|
batch_size: int = 100,
|
||||||
) -> None:
|
) -> None:
|
||||||
|
|
|
||||||
|
|
@ -1879,7 +1879,9 @@ def maximal_marginal_relevance(
|
||||||
async def get_embeddings_for_nodes(
|
async def get_embeddings_for_nodes(
|
||||||
driver: GraphDriver, nodes: list[EntityNode]
|
driver: GraphDriver, nodes: list[EntityNode]
|
||||||
) -> dict[str, list[float]]:
|
) -> dict[str, list[float]]:
|
||||||
if driver.provider == GraphProvider.NEPTUNE:
|
if driver.graph_operations_interface:
|
||||||
|
await driver.graph_operations_interface.node_load_embeddings_bulk(driver, nodes)
|
||||||
|
elif driver.provider == GraphProvider.NEPTUNE:
|
||||||
query = """
|
query = """
|
||||||
MATCH (n:Entity)
|
MATCH (n:Entity)
|
||||||
WHERE n.uuid IN $node_uuids
|
WHERE n.uuid IN $node_uuids
|
||||||
|
|
@ -1949,7 +1951,9 @@ async def get_embeddings_for_communities(
|
||||||
async def get_embeddings_for_edges(
|
async def get_embeddings_for_edges(
|
||||||
driver: GraphDriver, edges: list[EntityEdge]
|
driver: GraphDriver, edges: list[EntityEdge]
|
||||||
) -> dict[str, list[float]]:
|
) -> dict[str, list[float]]:
|
||||||
if driver.provider == GraphProvider.NEPTUNE:
|
if driver.graph_operations_interface:
|
||||||
|
await driver.graph_operations_interface.edge_load_embeddings_bulk(driver, edges)
|
||||||
|
elif driver.provider == GraphProvider.NEPTUNE:
|
||||||
query = """
|
query = """
|
||||||
MATCH (n:Entity)-[e:RELATES_TO]-(m:Entity)
|
MATCH (n:Entity)-[e:RELATES_TO]-(m:Entity)
|
||||||
WHERE e.uuid IN $edge_uuids
|
WHERE e.uuid IN $edge_uuids
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,7 @@
|
||||||
[project]
|
[project]
|
||||||
name = "graphiti-core"
|
name = "graphiti-core"
|
||||||
description = "A temporal graph building library"
|
description = "A temporal graph building library"
|
||||||
version = "0.22.1pre1"
|
version = "0.22.1pre2"
|
||||||
authors = [
|
authors = [
|
||||||
{ name = "Paul Paliychuk", email = "paul@getzep.com" },
|
{ name = "Paul Paliychuk", email = "paul@getzep.com" },
|
||||||
{ name = "Preston Rasmussen", email = "preston@getzep.com" },
|
{ name = "Preston Rasmussen", email = "preston@getzep.com" },
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue