Bulk add nodes and edges (#205)
* test
* only use parallel runtime if set to true
* add and test bulk add
* remove group_ids
* format
* bump version
* update readme
Parent: 63a1b11142
Commit: b8f52670ce
8 changed files with 135 additions and 32 deletions
README.md (15 changed lines)
@@ -173,6 +173,17 @@ The `server` directory contains an API service for interacting with the Graphiti
 Please see the [server README](./server/README.md) for more information.
 
+## Optional Environment Variables
+
+In addition to the Neo4j and OpenAi-compatible credentials, Graphiti also has a few optional environment variables.
+If you are using one of our supported models, such as Anthropic or Voyage models, the necessary environment variables
+must be set.
+
+`USE_PARALLEL_RUNTIME` is an optional boolean variable that can be set to true if you wish
+to enable Neo4j's parallel runtime feature for several of our search queries.
+Note that this feature is not supported for Neo4j Community edition or for smaller AuraDB instances,
+as such this feature is off by default.
+
 ## Documentation
 
 - [Guides and API documentation](https://help.getzep.com/graphiti).
@@ -186,11 +197,11 @@ Graphiti is under active development. We aim to maintain API stability while wor
 - [x] Implementing node and edge CRUD operations
 - [ ] Improving performance and scalability
 - [ ] Achieving good performance with different LLM and embedding models
-- [ ] Creating a dedicated embedder interface
+- [x] Creating a dedicated embedder interface
 - [ ] Supporting custom graph schemas:
   - Allow developers to provide their own defined node and edge classes when ingesting episodes
   - Enable more flexible knowledge representation tailored to specific use cases
-- [ ] Enhancing retrieval capabilities with more robust and configurable options
+- [x] Enhancing retrieval capabilities with more robust and configurable options
 - [ ] Expanding test coverage to ensure reliability and catch edge cases
 
 ## Contributing
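For readers skimming the diff, enabling the new flag described in the README section above could look like the sketch below. This is illustrative only: the flag name and the `bool(os.getenv(...))` parsing come from the `helpers.py` hunk further down, while the connection details and the `from graphiti_core import Graphiti` entry point are assumptions here.

```python
import os

# Any non-empty value enables the flag, since helpers.py parses it with
# bool(os.getenv('USE_PARALLEL_RUNTIME', False)); a .env entry also works,
# because helpers.py calls load_dotenv().
os.environ['USE_PARALLEL_RUNTIME'] = 'true'

from graphiti_core import Graphiti  # assumed public entry point; import after setting the flag

# Placeholder Neo4j connection details.
client = Graphiti('bolt://localhost:7687', 'neo4j', 'password')
```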
graphiti_core/graphiti.py

@@ -51,6 +51,7 @@ from graphiti_core.utils import (
 )
 from graphiti_core.utils.bulk_utils import (
     RawEpisode,
+    add_nodes_and_edges_bulk,
     dedupe_edges_bulk,
     dedupe_nodes_bulk,
     extract_edge_dates_bulk,
@@ -451,10 +452,9 @@ class Graphiti:
         if not self.store_raw_episode_content:
             episode.content = ''
 
-        await episode.save(self.driver)
-        await asyncio.gather(*[node.save(self.driver) for node in nodes])
-        await asyncio.gather(*[edge.save(self.driver) for edge in episodic_edges])
-        await asyncio.gather(*[edge.save(self.driver) for edge in entity_edges])
+        await add_nodes_and_edges_bulk(
+            self.driver, [episode], episodic_edges, nodes, entity_edges
+        )
 
         # Update any communities
         if update_communities:
graphiti_core/helpers.py

@@ -24,6 +24,7 @@ from neo4j import time as neo4j_time
 load_dotenv()
 
 DEFAULT_DATABASE = os.getenv('DEFAULT_DATABASE', None)
+USE_PARALLEL_RUNTIME = bool(os.getenv('USE_PARALLEL_RUNTIME', False))
 
 
 def parse_db_date(neo_date: neo4j_time.DateTime | None) -> datetime | None:
graphiti_core/models/edges/edge_db_queries.py

@@ -5,6 +5,15 @@ EPISODIC_EDGE_SAVE = """
     SET r = {uuid: $uuid, group_id: $group_id, created_at: $created_at}
     RETURN r.uuid AS uuid"""

+EPISODIC_EDGE_SAVE_BULK = """
+    UNWIND $episodic_edges AS edge
+    MATCH (episode:Episodic {uuid: edge.source_node_uuid})
+    MATCH (node:Entity {uuid: edge.target_node_uuid})
+    MERGE (episode)-[r:MENTIONS {uuid: edge.uuid}]->(node)
+    SET r = {uuid: edge.uuid, group_id: edge.group_id, created_at: edge.created_at}
+    RETURN r.uuid AS uuid
+"""
+
 ENTITY_EDGE_SAVE = """
     MATCH (source:Entity {uuid: $source_uuid})
     MATCH (target:Entity {uuid: $target_uuid})

@@ -14,6 +23,17 @@ ENTITY_EDGE_SAVE = """
     WITH r CALL db.create.setRelationshipVectorProperty(r, "fact_embedding", $fact_embedding)
     RETURN r.uuid AS uuid"""

+ENTITY_EDGE_SAVE_BULK = """
+    UNWIND $entity_edges AS edge
+    MATCH (source:Entity {uuid: edge.source_node_uuid})
+    MATCH (target:Entity {uuid: edge.target_node_uuid})
+    MERGE (source)-[r:RELATES_TO {uuid: edge.uuid}]->(target)
+    SET r = {uuid: edge.uuid, name: edge.name, group_id: edge.group_id, fact: edge.fact, episodes: edge.episodes,
+        created_at: edge.created_at, expired_at: edge.expired_at, valid_at: edge.valid_at, invalid_at: edge.invalid_at}
+    WITH r, edge CALL db.create.setRelationshipVectorProperty(r, "fact_embedding", edge.fact_embedding)
+    RETURN r.uuid AS uuid
+"""
+
 COMMUNITY_EDGE_SAVE = """
     MATCH (community:Community {uuid: $community_uuid})
     MATCH (node:Entity | Community {uuid: $entity_uuid})
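The new `*_SAVE_BULK` queries above (and the node variants below) each take a single list parameter and `UNWIND` it into rows, so one round trip can MERGE many nodes or relationships. A minimal sketch of the expected parameter shape follows; the field values are placeholders, the embedding length depends on the embedder in use, and the import path matches the bulk_utils.py hunk later in this diff.

```python
from datetime import datetime, timezone

from graphiti_core.models.edges.edge_db_queries import ENTITY_EDGE_SAVE_BULK

# Each dict becomes one `edge` row in the UNWIND; keys mirror the properties
# referenced by the query (edge.uuid, edge.source_node_uuid, ...).
example_entity_edges = [
    {
        'uuid': 'edge-1',
        'source_node_uuid': 'entity-a',
        'target_node_uuid': 'entity-b',
        'name': 'WORKS_AT',
        'group_id': 'group-1',
        'fact': 'Alice works at Acme',
        'episodes': ['episode-1'],
        'created_at': datetime.now(timezone.utc),
        'expired_at': None,
        'valid_at': None,
        'invalid_at': None,
        'fact_embedding': [0.0] * 1024,  # dimensionality depends on the embedder
    }
]

# Inside a write transaction (see add_nodes_and_edges_bulk_tx below):
#     await tx.run(ENTITY_EDGE_SAVE_BULK, entity_edges=example_entity_edges)
```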
graphiti_core/models/nodes/node_db_queries.py

@@ -4,12 +4,29 @@ EPISODIC_NODE_SAVE = """
         entity_edges: $entity_edges, created_at: $created_at, valid_at: $valid_at}
     RETURN n.uuid AS uuid"""

+EPISODIC_NODE_SAVE_BULK = """
+    UNWIND $episodes AS episode
+    MERGE (n:Episodic {uuid: episode.uuid})
+    SET n = {uuid: episode.uuid, name: episode.name, group_id: episode.group_id, source_description: episode.source_description,
+        source: episode.source, content: episode.content,
+        entity_edges: episode.entity_edges, created_at: episode.created_at, valid_at: episode.valid_at}
+    RETURN n.uuid AS uuid
+"""
+
 ENTITY_NODE_SAVE = """
     MERGE (n:Entity {uuid: $uuid})
     SET n = {uuid: $uuid, name: $name, group_id: $group_id, summary: $summary, created_at: $created_at}
     WITH n CALL db.create.setNodeVectorProperty(n, "name_embedding", $name_embedding)
     RETURN n.uuid AS uuid"""

+ENTITY_NODE_SAVE_BULK = """
+    UNWIND $nodes AS node
+    MERGE (n:Entity {uuid: node.uuid})
+    SET n = {uuid: node.uuid, name: node.name, group_id: node.group_id, summary: node.summary, created_at: node.created_at}
+    WITH n, node CALL db.create.setNodeVectorProperty(n, "name_embedding", node.name_embedding)
+    RETURN n.uuid AS uuid
+"""
+
 COMMUNITY_NODE_SAVE = """
     MERGE (n:Community {uuid: $uuid})
     SET n = {uuid: $uuid, name: $name, group_id: $group_id, summary: $summary, created_at: $created_at}
graphiti_core/search/search_utils.py

@@ -21,9 +21,15 @@ from time import time
 
 import numpy as np
 from neo4j import AsyncDriver, Query
+from typing_extensions import LiteralString
 
 from graphiti_core.edges import EntityEdge, get_entity_edge_from_record
-from graphiti_core.helpers import DEFAULT_DATABASE, lucene_sanitize, normalize_l2
+from graphiti_core.helpers import (
+    DEFAULT_DATABASE,
+    USE_PARALLEL_RUNTIME,
+    lucene_sanitize,
+    normalize_l2,
+)
 from graphiti_core.nodes import (
     CommunityNode,
     EntityNode,
@@ -38,7 +44,7 @@ RELEVANT_SCHEMA_LIMIT = 3
 DEFAULT_MIN_SCORE = 0.6
 DEFAULT_MMR_LAMBDA = 0.5
 MAX_SEARCH_DEPTH = 3
-MAX_QUERY_LENGTH = 128
+MAX_QUERY_LENGTH = 32
 
 
 def fulltext_query(query: str, group_ids: list[str] | None = None):
@@ -187,8 +193,11 @@ async def edge_similarity_search(
     min_score: float = DEFAULT_MIN_SCORE,
 ) -> list[EntityEdge]:
     # vector similarity search over embedded facts
-    query = Query("""
-    CYPHER runtime = parallel parallelRuntimeSupport=all
+    runtime_query: LiteralString = (
+        'CYPHER runtime = parallel parallelRuntimeSupport=all\n' if USE_PARALLEL_RUNTIME else ''
+    )
+
+    query: LiteralString = """
     MATCH (n:Entity)-[r:RELATES_TO]->(m:Entity)
     WHERE ($group_ids IS NULL OR r.group_id IN $group_ids)
     AND ($source_uuid IS NULL OR n.uuid IN [$source_uuid, $target_uuid])
@@ -210,10 +219,10 @@ async def edge_similarity_search(
         r.invalid_at AS invalid_at
     ORDER BY score DESC
     LIMIT $limit
-    """)
+    """
 
     records, _, _ = await driver.execute_query(
-        query,
+        runtime_query + query,
         search_vector=search_vector,
         source_uuid=source_node_uuid,
         target_uuid=target_node_uuid,
@@ -318,9 +327,13 @@ async def node_similarity_search(
     min_score: float = DEFAULT_MIN_SCORE,
 ) -> list[EntityNode]:
     # vector similarity search over entity names
+    runtime_query: LiteralString = (
+        'CYPHER runtime = parallel parallelRuntimeSupport=all\n' if USE_PARALLEL_RUNTIME else ''
+    )
+
     records, _, _ = await driver.execute_query(
-        """
-        CYPHER runtime = parallel parallelRuntimeSupport=all
+        runtime_query
+        + """
         MATCH (n:Entity)
         WHERE $group_ids IS NULL OR n.group_id IN $group_ids
         WITH n, vector.similarity.cosine(n.name_embedding, $search_vector) AS score
@@ -425,23 +438,27 @@ async def community_similarity_search(
     min_score=DEFAULT_MIN_SCORE,
 ) -> list[CommunityNode]:
     # vector similarity search over entity names
+    runtime_query: LiteralString = (
+        'CYPHER runtime = parallel parallelRuntimeSupport=all\n' if USE_PARALLEL_RUNTIME else ''
+    )
+
     records, _, _ = await driver.execute_query(
-        """
-        CYPHER runtime = parallel parallelRuntimeSupport=all
-        MATCH (comm:Community)
-        WHERE ($group_ids IS NULL OR comm.group_id IN $group_ids)
-        WITH comm, vector.similarity.cosine(comm.name_embedding, $search_vector) AS score
-        WHERE score > $min_score
-        RETURN
-            comm.uuid As uuid,
-            comm.group_id AS group_id,
-            comm.name AS name,
-            comm.name_embedding AS name_embedding,
-            comm.created_at AS created_at,
-            comm.summary AS summary
-        ORDER BY score DESC
-        LIMIT $limit
-        """,
+        runtime_query
+        + """
+        MATCH (comm:Community)
+        WHERE ($group_ids IS NULL OR comm.group_id IN $group_ids)
+        WITH comm, vector.similarity.cosine(comm.name_embedding, $search_vector) AS score
+        WHERE score > $min_score
+        RETURN
+            comm.uuid As uuid,
+            comm.group_id AS group_id,
+            comm.name AS name,
+            comm.name_embedding AS name_embedding,
+            comm.created_at AS created_at,
+            comm.summary AS summary
+        ORDER BY score DESC
+        LIMIT $limit
+        """,
         search_vector=search_vector,
         group_ids=group_ids,
         limit=limit,
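The pattern repeated in each search function above is plain string concatenation: an optional `CYPHER runtime` directive prefixed onto an otherwise unchanged query. A trivial sketch of what the driver receives in each case (the query body here is a placeholder, not one of the library's queries):

```python
USE_PARALLEL_RUNTIME = True  # stand-in for the flag parsed in helpers.py

runtime_query = (
    'CYPHER runtime = parallel parallelRuntimeSupport=all\n' if USE_PARALLEL_RUNTIME else ''
)
query = 'MATCH (n:Entity) RETURN n.uuid AS uuid LIMIT 1'  # placeholder body

# With the flag on, the text sent to Neo4j starts with the runtime directive;
# with it off, runtime_query is '' and the query is sent unchanged.
print(runtime_query + query)
```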
graphiti_core/utils/bulk_utils.py

@@ -21,12 +21,20 @@ from collections import defaultdict
 from datetime import datetime
 from math import ceil
 
-from neo4j import AsyncDriver
+from neo4j import AsyncDriver, AsyncManagedTransaction
 from numpy import dot, sqrt
 from pydantic import BaseModel
 
 from graphiti_core.edges import Edge, EntityEdge, EpisodicEdge
 from graphiti_core.llm_client import LLMClient
+from graphiti_core.models.edges.edge_db_queries import (
+    ENTITY_EDGE_SAVE_BULK,
+    EPISODIC_EDGE_SAVE_BULK,
+)
+from graphiti_core.models.nodes.node_db_queries import (
+    ENTITY_NODE_SAVE_BULK,
+    EPISODIC_NODE_SAVE_BULK,
+)
 from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode
 from graphiti_core.search.search_utils import get_relevant_edges, get_relevant_nodes
 from graphiti_core.utils import retrieve_episodes
@@ -75,6 +83,35 @@ async def retrieve_previous_episodes_bulk(
     return episode_tuples
 
 
+async def add_nodes_and_edges_bulk(
+    driver: AsyncDriver,
+    episodic_nodes: list[EpisodicNode],
+    episodic_edges: list[EpisodicEdge],
+    entity_nodes: list[EntityNode],
+    entity_edges: list[EntityEdge],
+):
+    async with driver.session() as session:
+        await session.execute_write(
+            add_nodes_and_edges_bulk_tx, episodic_nodes, episodic_edges, entity_nodes, entity_edges
+        )
+
+
+async def add_nodes_and_edges_bulk_tx(
+    tx: AsyncManagedTransaction,
+    episodic_nodes: list[EpisodicNode],
+    episodic_edges: list[EpisodicEdge],
+    entity_nodes: list[EntityNode],
+    entity_edges: list[EntityEdge],
+):
+    episodes = [dict(episode) for episode in episodic_nodes]
+    for episode in episodes:
+        episode['source'] = str(episode['source'].value)
+    await tx.run(EPISODIC_NODE_SAVE_BULK, episodes=episodes)
+    await tx.run(ENTITY_NODE_SAVE_BULK, nodes=[dict(entity) for entity in entity_nodes])
+    await tx.run(EPISODIC_EDGE_SAVE_BULK, episodic_edges=[dict(edge) for edge in episodic_edges])
+    await tx.run(ENTITY_EDGE_SAVE_BULK, entity_edges=[dict(edge) for edge in entity_edges])
+
+
 async def extract_nodes_and_edges_bulk(
     llm_client: LLMClient, episode_tuples: list[tuple[EpisodicNode, list[EpisodicNode]]]
 ) -> tuple[list[EntityNode], list[EntityEdge], list[EpisodicEdge]]:
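Putting the pieces together, calling the new bulk writer directly could look roughly like the following. The `add_nodes_and_edges_bulk` import path matches the graphiti.py hunk above; the driver setup and the already-built node/edge lists are assumptions for the sketch.

```python
from neo4j import AsyncGraphDatabase

from graphiti_core.utils.bulk_utils import add_nodes_and_edges_bulk


async def persist_bulk(episodes, episodic_edges, entity_nodes, entity_edges):
    # Placeholder connection details; within Graphiti itself, self.driver is reused.
    driver = AsyncGraphDatabase.driver('bolt://localhost:7687', auth=('neo4j', 'password'))
    try:
        # One managed write transaction instead of one save() call per node and edge.
        await add_nodes_and_edges_bulk(
            driver, episodes, episodic_edges, entity_nodes, entity_edges
        )
    finally:
        await driver.close()
```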
pyproject.toml

@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "graphiti-core"
-version = "0.3.19"
+version = "0.3.20"
 description = "A temporal graph building library"
 authors = [
     "Paul Paliychuk <paul@getzep.com>",