Bulk add nodes and edges (#205)

* test

* only use parallel runtime if set to true

* add and test bulk add

* remove group_ids

* format

* bump version

* update readme
Preston Rasmussen 2024-10-31 12:31:37 -04:00 committed by GitHub
parent 63a1b11142
commit b8f52670ce
8 changed files with 135 additions and 32 deletions

README.md

@@ -173,6 +173,17 @@ The `server` directory contains an API service for interacting with the Graphiti
Please see the [server README](./server/README.md) for more information.
## Optional Environment Variables
In addition to the Neo4j and OpenAI-compatible credentials, Graphiti also has a few optional environment variables.
If you are using one of our supported models, such as Anthropic or Voyage models, the necessary environment variables
must be set.
`USE_PARALLEL_RUNTIME` is an optional boolean variable that can be set to true if you wish
to enable Neo4j's parallel runtime feature for several of our search queries.
Note that this feature is not supported for Neo4j Community edition or for smaller AuraDB instances;
as such, it is off by default.
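
For illustration, a minimal sketch of enabling the flag from Python, assuming the package's top-level `Graphiti` import and placeholder connection details; the flag is read from the environment once, at import time:

```python
import os

# USE_PARALLEL_RUNTIME is read from the environment when graphiti_core is
# imported, so set it before the import.
os.environ['USE_PARALLEL_RUNTIME'] = 'true'

from graphiti_core import Graphiti  # noqa: E402

# Placeholder connection details.
graphiti = Graphiti('bolt://localhost:7687', 'neo4j', 'password')
```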
## Documentation
- [Guides and API documentation](https://help.getzep.com/graphiti).
@@ -186,11 +197,11 @@ Graphiti is under active development. We aim to maintain API stability while wor
- [x] Implementing node and edge CRUD operations
- [ ] Improving performance and scalability
- [ ] Achieving good performance with different LLM and embedding models
- [ ] Creating a dedicated embedder interface
- [x] Creating a dedicated embedder interface
- [ ] Supporting custom graph schemas:
- Allow developers to provide their own defined node and edge classes when ingesting episodes
- Enable more flexible knowledge representation tailored to specific use cases
- [ ] Enhancing retrieval capabilities with more robust and configurable options
- [x] Enhancing retrieval capabilities with more robust and configurable options
- [ ] Expanding test coverage to ensure reliability and catch edge cases
## Contributing

graphiti_core/graphiti.py

@@ -51,6 +51,7 @@ from graphiti_core.utils import (
)
from graphiti_core.utils.bulk_utils import (
RawEpisode,
add_nodes_and_edges_bulk,
dedupe_edges_bulk,
dedupe_nodes_bulk,
extract_edge_dates_bulk,
@@ -451,10 +452,9 @@ class Graphiti:
if not self.store_raw_episode_content:
episode.content = ''
await episode.save(self.driver)
await asyncio.gather(*[node.save(self.driver) for node in nodes])
await asyncio.gather(*[edge.save(self.driver) for edge in episodic_edges])
await asyncio.gather(*[edge.save(self.driver) for edge in entity_edges])
await add_nodes_and_edges_bulk(
self.driver, [episode], episodic_edges, nodes, entity_edges
)
# Update any communities
if update_communities:

graphiti_core/helpers.py

@@ -24,6 +24,7 @@ from neo4j import time as neo4j_time
load_dotenv()
DEFAULT_DATABASE = os.getenv('DEFAULT_DATABASE', None)
USE_PARALLEL_RUNTIME = bool(os.getenv('USE_PARALLEL_RUNTIME', False))
def parse_db_date(neo_date: neo4j_time.DateTime | None) -> datetime | None:
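
Since `bool()` of any non-empty string is `True`, this parse only distinguishes "unset or empty" from "set to anything"; even the string `'false'` would enable the flag. A small illustration of that behavior (the helper name is hypothetical, added only for the example):

```python
import os


def parallel_runtime_flag() -> bool:
    # Mirrors the parsing above: truthy for any non-empty value, falsy otherwise.
    return bool(os.getenv('USE_PARALLEL_RUNTIME', False))


os.environ.pop('USE_PARALLEL_RUNTIME', None)
assert parallel_runtime_flag() is False   # unset -> off

os.environ['USE_PARALLEL_RUNTIME'] = 'true'
assert parallel_runtime_flag() is True    # set -> on

os.environ['USE_PARALLEL_RUNTIME'] = 'false'
assert parallel_runtime_flag() is True    # non-empty string is still truthy
```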

graphiti_core/models/edges/edge_db_queries.py

@@ -5,6 +5,15 @@ EPISODIC_EDGE_SAVE = """
SET r = {uuid: $uuid, group_id: $group_id, created_at: $created_at}
RETURN r.uuid AS uuid"""
EPISODIC_EDGE_SAVE_BULK = """
UNWIND $episodic_edges AS edge
MATCH (episode:Episodic {uuid: edge.source_node_uuid})
MATCH (node:Entity {uuid: edge.target_node_uuid})
MERGE (episode)-[r:MENTIONS {uuid: edge.uuid}]->(node)
SET r = {uuid: edge.uuid, group_id: edge.group_id, created_at: edge.created_at}
RETURN r.uuid AS uuid
"""
ENTITY_EDGE_SAVE = """
MATCH (source:Entity {uuid: $source_uuid})
MATCH (target:Entity {uuid: $target_uuid})
@@ -14,6 +23,17 @@ ENTITY_EDGE_SAVE = """
WITH r CALL db.create.setRelationshipVectorProperty(r, "fact_embedding", $fact_embedding)
RETURN r.uuid AS uuid"""
ENTITY_EDGE_SAVE_BULK = """
UNWIND $entity_edges AS edge
MATCH (source:Entity {uuid: edge.source_node_uuid})
MATCH (target:Entity {uuid: edge.target_node_uuid})
MERGE (source)-[r:RELATES_TO {uuid: edge.uuid}]->(target)
SET r = {uuid: edge.uuid, name: edge.name, group_id: edge.group_id, fact: edge.fact, episodes: edge.episodes,
created_at: edge.created_at, expired_at: edge.expired_at, valid_at: edge.valid_at, invalid_at: edge.invalid_at}
WITH r, edge CALL db.create.setRelationshipVectorProperty(r, "fact_embedding", edge.fact_embedding)
RETURN r.uuid AS uuid
"""
COMMUNITY_EDGE_SAVE = """
MATCH (community:Community {uuid: $community_uuid})
MATCH (node:Entity | Community {uuid: $entity_uuid})
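
Both bulk queries UNWIND a list of plain dictionaries and MERGE one relationship per element on its `uuid`. A hypothetical payload shape matching the field names above (values are placeholders; in the library these dicts come from `dict(edge)` on the edge models):

```python
# Placeholder payloads illustrating the parameter shape the UNWIND queries expect.
episodic_edges = [
    {
        'uuid': 'episodic-edge-uuid',
        'group_id': 'group-1',
        'source_node_uuid': 'episode-uuid',   # matched as (:Episodic)
        'target_node_uuid': 'entity-uuid',    # matched as (:Entity)
        'created_at': '2024-10-31T12:00:00+00:00',
    }
]

entity_edges = [
    {
        'uuid': 'entity-edge-uuid',
        'name': 'WORKS_AT',
        'group_id': 'group-1',
        'fact': 'Alice works at Acme.',
        'fact_embedding': [0.1, 0.2, 0.3],    # applied via setRelationshipVectorProperty
        'episodes': ['episode-uuid'],
        'source_node_uuid': 'alice-uuid',
        'target_node_uuid': 'acme-uuid',
        'created_at': '2024-10-31T12:00:00+00:00',
        'expired_at': None,
        'valid_at': None,
        'invalid_at': None,
    }
]

# e.g. await tx.run(EPISODIC_EDGE_SAVE_BULK, episodic_edges=episodic_edges)
#      await tx.run(ENTITY_EDGE_SAVE_BULK, entity_edges=entity_edges)
```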

graphiti_core/models/nodes/node_db_queries.py

@@ -4,12 +4,29 @@ EPISODIC_NODE_SAVE = """
entity_edges: $entity_edges, created_at: $created_at, valid_at: $valid_at}
RETURN n.uuid AS uuid"""
EPISODIC_NODE_SAVE_BULK = """
UNWIND $episodes AS episode
MERGE (n:Episodic {uuid: episode.uuid})
SET n = {uuid: episode.uuid, name: episode.name, group_id: episode.group_id, source_description: episode.source_description,
source: episode.source, content: episode.content,
entity_edges: episode.entity_edges, created_at: episode.created_at, valid_at: episode.valid_at}
RETURN n.uuid AS uuid
"""
ENTITY_NODE_SAVE = """
MERGE (n:Entity {uuid: $uuid})
SET n = {uuid: $uuid, name: $name, group_id: $group_id, summary: $summary, created_at: $created_at}
WITH n CALL db.create.setNodeVectorProperty(n, "name_embedding", $name_embedding)
RETURN n.uuid AS uuid"""
ENTITY_NODE_SAVE_BULK = """
UNWIND $nodes AS node
MERGE (n:Entity {uuid: node.uuid})
SET n = {uuid: node.uuid, name: node.name, group_id: node.group_id, summary: node.summary, created_at: node.created_at}
WITH n, node CALL db.create.setNodeVectorProperty(n, "name_embedding", node.name_embedding)
RETURN n.uuid AS uuid
"""
COMMUNITY_NODE_SAVE = """
MERGE (n:Community {uuid: $uuid})
SET n = {uuid: $uuid, name: $name, group_id: $group_id, summary: $summary, created_at: $created_at}

graphiti_core/search/search_utils.py

@@ -21,9 +21,15 @@ from time import time
import numpy as np
from neo4j import AsyncDriver, Query
from typing_extensions import LiteralString
from graphiti_core.edges import EntityEdge, get_entity_edge_from_record
from graphiti_core.helpers import DEFAULT_DATABASE, lucene_sanitize, normalize_l2
from graphiti_core.helpers import (
DEFAULT_DATABASE,
USE_PARALLEL_RUNTIME,
lucene_sanitize,
normalize_l2,
)
from graphiti_core.nodes import (
CommunityNode,
EntityNode,
@@ -38,7 +44,7 @@ RELEVANT_SCHEMA_LIMIT = 3
DEFAULT_MIN_SCORE = 0.6
DEFAULT_MMR_LAMBDA = 0.5
MAX_SEARCH_DEPTH = 3
MAX_QUERY_LENGTH = 128
MAX_QUERY_LENGTH = 32
def fulltext_query(query: str, group_ids: list[str] | None = None):
@@ -187,8 +193,11 @@ async def edge_similarity_search(
min_score: float = DEFAULT_MIN_SCORE,
) -> list[EntityEdge]:
# vector similarity search over embedded facts
query = Query("""
CYPHER runtime = parallel parallelRuntimeSupport=all
runtime_query: LiteralString = (
'CYPHER runtime = parallel parallelRuntimeSupport=all\n' if USE_PARALLEL_RUNTIME else ''
)
query: LiteralString = """
MATCH (n:Entity)-[r:RELATES_TO]->(m:Entity)
WHERE ($group_ids IS NULL OR r.group_id IN $group_ids)
AND ($source_uuid IS NULL OR n.uuid IN [$source_uuid, $target_uuid])
@@ -210,10 +219,10 @@ async def edge_similarity_search(
r.invalid_at AS invalid_at
ORDER BY score DESC
LIMIT $limit
""")
"""
records, _, _ = await driver.execute_query(
query,
runtime_query + query,
search_vector=search_vector,
source_uuid=source_node_uuid,
target_uuid=target_node_uuid,
@@ -318,9 +327,13 @@ async def node_similarity_search(
min_score: float = DEFAULT_MIN_SCORE,
) -> list[EntityNode]:
# vector similarity search over entity names
runtime_query: LiteralString = (
'CYPHER runtime = parallel parallelRuntimeSupport=all\n' if USE_PARALLEL_RUNTIME else ''
)
records, _, _ = await driver.execute_query(
"""
CYPHER runtime = parallel parallelRuntimeSupport=all
runtime_query
+ """
MATCH (n:Entity)
WHERE $group_ids IS NULL OR n.group_id IN $group_ids
WITH n, vector.similarity.cosine(n.name_embedding, $search_vector) AS score
@@ -425,23 +438,27 @@ async def community_similarity_search(
min_score=DEFAULT_MIN_SCORE,
) -> list[CommunityNode]:
# vector similarity search over entity names
runtime_query: LiteralString = (
'CYPHER runtime = parallel parallelRuntimeSupport=all\n' if USE_PARALLEL_RUNTIME else ''
)
records, _, _ = await driver.execute_query(
"""
CYPHER runtime = parallel parallelRuntimeSupport=all
MATCH (comm:Community)
WHERE ($group_ids IS NULL OR comm.group_id IN $group_ids)
WITH comm, vector.similarity.cosine(comm.name_embedding, $search_vector) AS score
WHERE score > $min_score
RETURN
comm.uuid As uuid,
comm.group_id AS group_id,
comm.name AS name,
comm.name_embedding AS name_embedding,
comm.created_at AS created_at,
comm.summary AS summary
ORDER BY score DESC
LIMIT $limit
""",
runtime_query
+ """
MATCH (comm:Community)
WHERE ($group_ids IS NULL OR comm.group_id IN $group_ids)
WITH comm, vector.similarity.cosine(comm.name_embedding, $search_vector) AS score
WHERE score > $min_score
RETURN
comm.uuid As uuid,
comm.group_id AS group_id,
comm.name AS name,
comm.name_embedding AS name_embedding,
comm.created_at AS created_at,
comm.summary AS summary
ORDER BY score DESC
LIMIT $limit
""",
search_vector=search_vector,
group_ids=group_ids,
limit=limit,
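
Each search helper now builds the same optional prefix: the `CYPHER runtime = parallel parallelRuntimeSupport=all` directive when `USE_PARALLEL_RUNTIME` is set, or an empty string otherwise, prepended to an otherwise unchanged query. A reduced sketch of that pattern (the helper name here is hypothetical, not part of the library):

```python
from typing_extensions import LiteralString

from graphiti_core.helpers import USE_PARALLEL_RUNTIME


def with_runtime_prefix(query: LiteralString) -> LiteralString:
    # Only prepend the parallel-runtime directive when the flag is enabled;
    # the directive is not supported on Neo4j Community edition or smaller
    # AuraDB instances.
    runtime_query: LiteralString = (
        'CYPHER runtime = parallel parallelRuntimeSupport=all\n'
        if USE_PARALLEL_RUNTIME
        else ''
    )
    return runtime_query + query
```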

graphiti_core/utils/bulk_utils.py

@@ -21,12 +21,20 @@ from collections import defaultdict
from datetime import datetime
from math import ceil
from neo4j import AsyncDriver
from neo4j import AsyncDriver, AsyncManagedTransaction
from numpy import dot, sqrt
from pydantic import BaseModel
from graphiti_core.edges import Edge, EntityEdge, EpisodicEdge
from graphiti_core.llm_client import LLMClient
from graphiti_core.models.edges.edge_db_queries import (
ENTITY_EDGE_SAVE_BULK,
EPISODIC_EDGE_SAVE_BULK,
)
from graphiti_core.models.nodes.node_db_queries import (
ENTITY_NODE_SAVE_BULK,
EPISODIC_NODE_SAVE_BULK,
)
from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode
from graphiti_core.search.search_utils import get_relevant_edges, get_relevant_nodes
from graphiti_core.utils import retrieve_episodes
@@ -75,6 +83,35 @@ async def retrieve_previous_episodes_bulk(
return episode_tuples
async def add_nodes_and_edges_bulk(
driver: AsyncDriver,
episodic_nodes: list[EpisodicNode],
episodic_edges: list[EpisodicEdge],
entity_nodes: list[EntityNode],
entity_edges: list[EntityEdge],
):
async with driver.session() as session:
await session.execute_write(
add_nodes_and_edges_bulk_tx, episodic_nodes, episodic_edges, entity_nodes, entity_edges
)
async def add_nodes_and_edges_bulk_tx(
tx: AsyncManagedTransaction,
episodic_nodes: list[EpisodicNode],
episodic_edges: list[EpisodicEdge],
entity_nodes: list[EntityNode],
entity_edges: list[EntityEdge],
):
episodes = [dict(episode) for episode in episodic_nodes]
for episode in episodes:
episode['source'] = str(episode['source'].value)
await tx.run(EPISODIC_NODE_SAVE_BULK, episodes=episodes)
await tx.run(ENTITY_NODE_SAVE_BULK, nodes=[dict(entity) for entity in entity_nodes])
await tx.run(EPISODIC_EDGE_SAVE_BULK, episodic_edges=[dict(edge) for edge in episodic_edges])
await tx.run(ENTITY_EDGE_SAVE_BULK, entity_edges=[dict(edge) for edge in entity_edges])
async def extract_nodes_and_edges_bulk(
llm_client: LLMClient, episode_tuples: list[tuple[EpisodicNode, list[EpisodicNode]]]
) -> tuple[list[EntityNode], list[EntityEdge], list[EpisodicEdge]]:
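
A hypothetical caller-side sketch of the new helper, persisting a prepared batch in a single managed write transaction; the connection details are placeholders and the wrapper function is not part of the library:

```python
from neo4j import AsyncGraphDatabase

from graphiti_core.edges import EntityEdge, EpisodicEdge
from graphiti_core.nodes import EntityNode, EpisodicNode
from graphiti_core.utils.bulk_utils import add_nodes_and_edges_bulk


async def persist_batch(
    episodes: list[EpisodicNode],
    episodic_edges: list[EpisodicEdge],
    entity_nodes: list[EntityNode],
    entity_edges: list[EntityEdge],
) -> None:
    # One session and one write transaction for the whole batch, instead of
    # one query per node or edge.
    driver = AsyncGraphDatabase.driver(
        'bolt://localhost:7687', auth=('neo4j', 'password')
    )
    try:
        await add_nodes_and_edges_bulk(
            driver, episodes, episodic_edges, entity_nodes, entity_edges
        )
    finally:
        await driver.close()
```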

pyproject.toml

@@ -1,6 +1,6 @@
[tool.poetry]
name = "graphiti-core"
version = "0.3.19"
version = "0.3.20"
description = "A temporal graph building library"
authors = [
"Paul Paliychuk <paul@getzep.com>",