Bulk add nodes and edges (#205)
* test * only use parallel runtime if set to true * add and test bulk add * remove group_ids * format * bump version * update readme
This commit is contained in:
parent
63a1b11142
commit
b8f52670ce
8 changed files with 135 additions and 32 deletions
15
README.md
15
README.md
|
|
@ -173,6 +173,17 @@ The `server` directory contains an API service for interacting with the Graphiti
|
||||||
|
|
||||||
Please see the [server README](./server/README.md) for more information.
|
Please see the [server README](./server/README.md) for more information.
|
||||||
|
|
||||||
|
## Optional Environment Variables
|
||||||
|
|
||||||
|
In addition to the Neo4j and OpenAi-compatible credentials, Graphiti also has a few optional environment variables.
|
||||||
|
If you are using one of our supported models, such as Anthropic or Voyage models, the necessary environment variables
|
||||||
|
must be set.
|
||||||
|
|
||||||
|
`USE_PARALLEL_RUNTIME` is an optional boolean variable that can be set to true if you wish
|
||||||
|
to enable Neo4j's parallel runtime feature for several of our search queries.
|
||||||
|
Note that this feature is not supported for Neo4j Community edition or for smaller AuraDB instances,
|
||||||
|
as such this feature is off by default.
|
||||||
|
|
||||||
## Documentation
|
## Documentation
|
||||||
|
|
||||||
- [Guides and API documentation](https://help.getzep.com/graphiti).
|
- [Guides and API documentation](https://help.getzep.com/graphiti).
|
||||||
|
|
@ -186,11 +197,11 @@ Graphiti is under active development. We aim to maintain API stability while wor
|
||||||
- [x] Implementing node and edge CRUD operations
|
- [x] Implementing node and edge CRUD operations
|
||||||
- [ ] Improving performance and scalability
|
- [ ] Improving performance and scalability
|
||||||
- [ ] Achieving good performance with different LLM and embedding models
|
- [ ] Achieving good performance with different LLM and embedding models
|
||||||
- [ ] Creating a dedicated embedder interface
|
- [x] Creating a dedicated embedder interface
|
||||||
- [ ] Supporting custom graph schemas:
|
- [ ] Supporting custom graph schemas:
|
||||||
- Allow developers to provide their own defined node and edge classes when ingesting episodes
|
- Allow developers to provide their own defined node and edge classes when ingesting episodes
|
||||||
- Enable more flexible knowledge representation tailored to specific use cases
|
- Enable more flexible knowledge representation tailored to specific use cases
|
||||||
- [ ] Enhancing retrieval capabilities with more robust and configurable options
|
- [x] Enhancing retrieval capabilities with more robust and configurable options
|
||||||
- [ ] Expanding test coverage to ensure reliability and catch edge cases
|
- [ ] Expanding test coverage to ensure reliability and catch edge cases
|
||||||
|
|
||||||
## Contributing
|
## Contributing
|
||||||
|
|
|
||||||
|
|
@ -51,6 +51,7 @@ from graphiti_core.utils import (
|
||||||
)
|
)
|
||||||
from graphiti_core.utils.bulk_utils import (
|
from graphiti_core.utils.bulk_utils import (
|
||||||
RawEpisode,
|
RawEpisode,
|
||||||
|
add_nodes_and_edges_bulk,
|
||||||
dedupe_edges_bulk,
|
dedupe_edges_bulk,
|
||||||
dedupe_nodes_bulk,
|
dedupe_nodes_bulk,
|
||||||
extract_edge_dates_bulk,
|
extract_edge_dates_bulk,
|
||||||
|
|
@ -451,10 +452,9 @@ class Graphiti:
|
||||||
if not self.store_raw_episode_content:
|
if not self.store_raw_episode_content:
|
||||||
episode.content = ''
|
episode.content = ''
|
||||||
|
|
||||||
await episode.save(self.driver)
|
await add_nodes_and_edges_bulk(
|
||||||
await asyncio.gather(*[node.save(self.driver) for node in nodes])
|
self.driver, [episode], episodic_edges, nodes, entity_edges
|
||||||
await asyncio.gather(*[edge.save(self.driver) for edge in episodic_edges])
|
)
|
||||||
await asyncio.gather(*[edge.save(self.driver) for edge in entity_edges])
|
|
||||||
|
|
||||||
# Update any communities
|
# Update any communities
|
||||||
if update_communities:
|
if update_communities:
|
||||||
|
|
|
||||||
|
|
@ -24,6 +24,7 @@ from neo4j import time as neo4j_time
|
||||||
load_dotenv()
|
load_dotenv()
|
||||||
|
|
||||||
DEFAULT_DATABASE = os.getenv('DEFAULT_DATABASE', None)
|
DEFAULT_DATABASE = os.getenv('DEFAULT_DATABASE', None)
|
||||||
|
USE_PARALLEL_RUNTIME = bool(os.getenv('USE_PARALLEL_RUNTIME', False))
|
||||||
|
|
||||||
|
|
||||||
def parse_db_date(neo_date: neo4j_time.DateTime | None) -> datetime | None:
|
def parse_db_date(neo_date: neo4j_time.DateTime | None) -> datetime | None:
|
||||||
|
|
|
||||||
|
|
@ -5,6 +5,15 @@ EPISODIC_EDGE_SAVE = """
|
||||||
SET r = {uuid: $uuid, group_id: $group_id, created_at: $created_at}
|
SET r = {uuid: $uuid, group_id: $group_id, created_at: $created_at}
|
||||||
RETURN r.uuid AS uuid"""
|
RETURN r.uuid AS uuid"""
|
||||||
|
|
||||||
|
EPISODIC_EDGE_SAVE_BULK = """
|
||||||
|
UNWIND $episodic_edges AS edge
|
||||||
|
MATCH (episode:Episodic {uuid: edge.source_node_uuid})
|
||||||
|
MATCH (node:Entity {uuid: edge.target_node_uuid})
|
||||||
|
MERGE (episode)-[r:MENTIONS {uuid: edge.uuid}]->(node)
|
||||||
|
SET r = {uuid: edge.uuid, group_id: edge.group_id, created_at: edge.created_at}
|
||||||
|
RETURN r.uuid AS uuid
|
||||||
|
"""
|
||||||
|
|
||||||
ENTITY_EDGE_SAVE = """
|
ENTITY_EDGE_SAVE = """
|
||||||
MATCH (source:Entity {uuid: $source_uuid})
|
MATCH (source:Entity {uuid: $source_uuid})
|
||||||
MATCH (target:Entity {uuid: $target_uuid})
|
MATCH (target:Entity {uuid: $target_uuid})
|
||||||
|
|
@ -14,6 +23,17 @@ ENTITY_EDGE_SAVE = """
|
||||||
WITH r CALL db.create.setRelationshipVectorProperty(r, "fact_embedding", $fact_embedding)
|
WITH r CALL db.create.setRelationshipVectorProperty(r, "fact_embedding", $fact_embedding)
|
||||||
RETURN r.uuid AS uuid"""
|
RETURN r.uuid AS uuid"""
|
||||||
|
|
||||||
|
ENTITY_EDGE_SAVE_BULK = """
|
||||||
|
UNWIND $entity_edges AS edge
|
||||||
|
MATCH (source:Entity {uuid: edge.source_node_uuid})
|
||||||
|
MATCH (target:Entity {uuid: edge.target_node_uuid})
|
||||||
|
MERGE (source)-[r:RELATES_TO {uuid: edge.uuid}]->(target)
|
||||||
|
SET r = {uuid: edge.uuid, name: edge.name, group_id: edge.group_id, fact: edge.fact, episodes: edge.episodes,
|
||||||
|
created_at: edge.created_at, expired_at: edge.expired_at, valid_at: edge.valid_at, invalid_at: edge.invalid_at}
|
||||||
|
WITH r, edge CALL db.create.setRelationshipVectorProperty(r, "fact_embedding", edge.fact_embedding)
|
||||||
|
RETURN r.uuid AS uuid
|
||||||
|
"""
|
||||||
|
|
||||||
COMMUNITY_EDGE_SAVE = """
|
COMMUNITY_EDGE_SAVE = """
|
||||||
MATCH (community:Community {uuid: $community_uuid})
|
MATCH (community:Community {uuid: $community_uuid})
|
||||||
MATCH (node:Entity | Community {uuid: $entity_uuid})
|
MATCH (node:Entity | Community {uuid: $entity_uuid})
|
||||||
|
|
|
||||||
|
|
@ -4,12 +4,29 @@ EPISODIC_NODE_SAVE = """
|
||||||
entity_edges: $entity_edges, created_at: $created_at, valid_at: $valid_at}
|
entity_edges: $entity_edges, created_at: $created_at, valid_at: $valid_at}
|
||||||
RETURN n.uuid AS uuid"""
|
RETURN n.uuid AS uuid"""
|
||||||
|
|
||||||
|
EPISODIC_NODE_SAVE_BULK = """
|
||||||
|
UNWIND $episodes AS episode
|
||||||
|
MERGE (n:Episodic {uuid: episode.uuid})
|
||||||
|
SET n = {uuid: episode.uuid, name: episode.name, group_id: episode.group_id, source_description: episode.source_description,
|
||||||
|
source: episode.source, content: episode.content,
|
||||||
|
entity_edges: episode.entity_edges, created_at: episode.created_at, valid_at: episode.valid_at}
|
||||||
|
RETURN n.uuid AS uuid
|
||||||
|
"""
|
||||||
|
|
||||||
ENTITY_NODE_SAVE = """
|
ENTITY_NODE_SAVE = """
|
||||||
MERGE (n:Entity {uuid: $uuid})
|
MERGE (n:Entity {uuid: $uuid})
|
||||||
SET n = {uuid: $uuid, name: $name, group_id: $group_id, summary: $summary, created_at: $created_at}
|
SET n = {uuid: $uuid, name: $name, group_id: $group_id, summary: $summary, created_at: $created_at}
|
||||||
WITH n CALL db.create.setNodeVectorProperty(n, "name_embedding", $name_embedding)
|
WITH n CALL db.create.setNodeVectorProperty(n, "name_embedding", $name_embedding)
|
||||||
RETURN n.uuid AS uuid"""
|
RETURN n.uuid AS uuid"""
|
||||||
|
|
||||||
|
ENTITY_NODE_SAVE_BULK = """
|
||||||
|
UNWIND $nodes AS node
|
||||||
|
MERGE (n:Entity {uuid: node.uuid})
|
||||||
|
SET n = {uuid: node.uuid, name: node.name, group_id: node.group_id, summary: node.summary, created_at: node.created_at}
|
||||||
|
WITH n, node CALL db.create.setNodeVectorProperty(n, "name_embedding", node.name_embedding)
|
||||||
|
RETURN n.uuid AS uuid
|
||||||
|
"""
|
||||||
|
|
||||||
COMMUNITY_NODE_SAVE = """
|
COMMUNITY_NODE_SAVE = """
|
||||||
MERGE (n:Community {uuid: $uuid})
|
MERGE (n:Community {uuid: $uuid})
|
||||||
SET n = {uuid: $uuid, name: $name, group_id: $group_id, summary: $summary, created_at: $created_at}
|
SET n = {uuid: $uuid, name: $name, group_id: $group_id, summary: $summary, created_at: $created_at}
|
||||||
|
|
|
||||||
|
|
@ -21,9 +21,15 @@ from time import time
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from neo4j import AsyncDriver, Query
|
from neo4j import AsyncDriver, Query
|
||||||
|
from typing_extensions import LiteralString
|
||||||
|
|
||||||
from graphiti_core.edges import EntityEdge, get_entity_edge_from_record
|
from graphiti_core.edges import EntityEdge, get_entity_edge_from_record
|
||||||
from graphiti_core.helpers import DEFAULT_DATABASE, lucene_sanitize, normalize_l2
|
from graphiti_core.helpers import (
|
||||||
|
DEFAULT_DATABASE,
|
||||||
|
USE_PARALLEL_RUNTIME,
|
||||||
|
lucene_sanitize,
|
||||||
|
normalize_l2,
|
||||||
|
)
|
||||||
from graphiti_core.nodes import (
|
from graphiti_core.nodes import (
|
||||||
CommunityNode,
|
CommunityNode,
|
||||||
EntityNode,
|
EntityNode,
|
||||||
|
|
@ -38,7 +44,7 @@ RELEVANT_SCHEMA_LIMIT = 3
|
||||||
DEFAULT_MIN_SCORE = 0.6
|
DEFAULT_MIN_SCORE = 0.6
|
||||||
DEFAULT_MMR_LAMBDA = 0.5
|
DEFAULT_MMR_LAMBDA = 0.5
|
||||||
MAX_SEARCH_DEPTH = 3
|
MAX_SEARCH_DEPTH = 3
|
||||||
MAX_QUERY_LENGTH = 128
|
MAX_QUERY_LENGTH = 32
|
||||||
|
|
||||||
|
|
||||||
def fulltext_query(query: str, group_ids: list[str] | None = None):
|
def fulltext_query(query: str, group_ids: list[str] | None = None):
|
||||||
|
|
@ -187,8 +193,11 @@ async def edge_similarity_search(
|
||||||
min_score: float = DEFAULT_MIN_SCORE,
|
min_score: float = DEFAULT_MIN_SCORE,
|
||||||
) -> list[EntityEdge]:
|
) -> list[EntityEdge]:
|
||||||
# vector similarity search over embedded facts
|
# vector similarity search over embedded facts
|
||||||
query = Query("""
|
runtime_query: LiteralString = (
|
||||||
CYPHER runtime = parallel parallelRuntimeSupport=all
|
'CYPHER runtime = parallel parallelRuntimeSupport=all\n' if USE_PARALLEL_RUNTIME else ''
|
||||||
|
)
|
||||||
|
|
||||||
|
query: LiteralString = """
|
||||||
MATCH (n:Entity)-[r:RELATES_TO]->(m:Entity)
|
MATCH (n:Entity)-[r:RELATES_TO]->(m:Entity)
|
||||||
WHERE ($group_ids IS NULL OR r.group_id IN $group_ids)
|
WHERE ($group_ids IS NULL OR r.group_id IN $group_ids)
|
||||||
AND ($source_uuid IS NULL OR n.uuid IN [$source_uuid, $target_uuid])
|
AND ($source_uuid IS NULL OR n.uuid IN [$source_uuid, $target_uuid])
|
||||||
|
|
@ -210,10 +219,10 @@ async def edge_similarity_search(
|
||||||
r.invalid_at AS invalid_at
|
r.invalid_at AS invalid_at
|
||||||
ORDER BY score DESC
|
ORDER BY score DESC
|
||||||
LIMIT $limit
|
LIMIT $limit
|
||||||
""")
|
"""
|
||||||
|
|
||||||
records, _, _ = await driver.execute_query(
|
records, _, _ = await driver.execute_query(
|
||||||
query,
|
runtime_query + query,
|
||||||
search_vector=search_vector,
|
search_vector=search_vector,
|
||||||
source_uuid=source_node_uuid,
|
source_uuid=source_node_uuid,
|
||||||
target_uuid=target_node_uuid,
|
target_uuid=target_node_uuid,
|
||||||
|
|
@ -318,9 +327,13 @@ async def node_similarity_search(
|
||||||
min_score: float = DEFAULT_MIN_SCORE,
|
min_score: float = DEFAULT_MIN_SCORE,
|
||||||
) -> list[EntityNode]:
|
) -> list[EntityNode]:
|
||||||
# vector similarity search over entity names
|
# vector similarity search over entity names
|
||||||
|
runtime_query: LiteralString = (
|
||||||
|
'CYPHER runtime = parallel parallelRuntimeSupport=all\n' if USE_PARALLEL_RUNTIME else ''
|
||||||
|
)
|
||||||
|
|
||||||
records, _, _ = await driver.execute_query(
|
records, _, _ = await driver.execute_query(
|
||||||
"""
|
runtime_query
|
||||||
CYPHER runtime = parallel parallelRuntimeSupport=all
|
+ """
|
||||||
MATCH (n:Entity)
|
MATCH (n:Entity)
|
||||||
WHERE $group_ids IS NULL OR n.group_id IN $group_ids
|
WHERE $group_ids IS NULL OR n.group_id IN $group_ids
|
||||||
WITH n, vector.similarity.cosine(n.name_embedding, $search_vector) AS score
|
WITH n, vector.similarity.cosine(n.name_embedding, $search_vector) AS score
|
||||||
|
|
@ -425,23 +438,27 @@ async def community_similarity_search(
|
||||||
min_score=DEFAULT_MIN_SCORE,
|
min_score=DEFAULT_MIN_SCORE,
|
||||||
) -> list[CommunityNode]:
|
) -> list[CommunityNode]:
|
||||||
# vector similarity search over entity names
|
# vector similarity search over entity names
|
||||||
|
runtime_query: LiteralString = (
|
||||||
|
'CYPHER runtime = parallel parallelRuntimeSupport=all\n' if USE_PARALLEL_RUNTIME else ''
|
||||||
|
)
|
||||||
|
|
||||||
records, _, _ = await driver.execute_query(
|
records, _, _ = await driver.execute_query(
|
||||||
"""
|
runtime_query
|
||||||
CYPHER runtime = parallel parallelRuntimeSupport=all
|
+ """
|
||||||
MATCH (comm:Community)
|
MATCH (comm:Community)
|
||||||
WHERE ($group_ids IS NULL OR comm.group_id IN $group_ids)
|
WHERE ($group_ids IS NULL OR comm.group_id IN $group_ids)
|
||||||
WITH comm, vector.similarity.cosine(comm.name_embedding, $search_vector) AS score
|
WITH comm, vector.similarity.cosine(comm.name_embedding, $search_vector) AS score
|
||||||
WHERE score > $min_score
|
WHERE score > $min_score
|
||||||
RETURN
|
RETURN
|
||||||
comm.uuid As uuid,
|
comm.uuid As uuid,
|
||||||
comm.group_id AS group_id,
|
comm.group_id AS group_id,
|
||||||
comm.name AS name,
|
comm.name AS name,
|
||||||
comm.name_embedding AS name_embedding,
|
comm.name_embedding AS name_embedding,
|
||||||
comm.created_at AS created_at,
|
comm.created_at AS created_at,
|
||||||
comm.summary AS summary
|
comm.summary AS summary
|
||||||
ORDER BY score DESC
|
ORDER BY score DESC
|
||||||
LIMIT $limit
|
LIMIT $limit
|
||||||
""",
|
""",
|
||||||
search_vector=search_vector,
|
search_vector=search_vector,
|
||||||
group_ids=group_ids,
|
group_ids=group_ids,
|
||||||
limit=limit,
|
limit=limit,
|
||||||
|
|
|
||||||
|
|
@ -21,12 +21,20 @@ from collections import defaultdict
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from math import ceil
|
from math import ceil
|
||||||
|
|
||||||
from neo4j import AsyncDriver
|
from neo4j import AsyncDriver, AsyncManagedTransaction
|
||||||
from numpy import dot, sqrt
|
from numpy import dot, sqrt
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
from graphiti_core.edges import Edge, EntityEdge, EpisodicEdge
|
from graphiti_core.edges import Edge, EntityEdge, EpisodicEdge
|
||||||
from graphiti_core.llm_client import LLMClient
|
from graphiti_core.llm_client import LLMClient
|
||||||
|
from graphiti_core.models.edges.edge_db_queries import (
|
||||||
|
ENTITY_EDGE_SAVE_BULK,
|
||||||
|
EPISODIC_EDGE_SAVE_BULK,
|
||||||
|
)
|
||||||
|
from graphiti_core.models.nodes.node_db_queries import (
|
||||||
|
ENTITY_NODE_SAVE_BULK,
|
||||||
|
EPISODIC_NODE_SAVE_BULK,
|
||||||
|
)
|
||||||
from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode
|
from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode
|
||||||
from graphiti_core.search.search_utils import get_relevant_edges, get_relevant_nodes
|
from graphiti_core.search.search_utils import get_relevant_edges, get_relevant_nodes
|
||||||
from graphiti_core.utils import retrieve_episodes
|
from graphiti_core.utils import retrieve_episodes
|
||||||
|
|
@ -75,6 +83,35 @@ async def retrieve_previous_episodes_bulk(
|
||||||
return episode_tuples
|
return episode_tuples
|
||||||
|
|
||||||
|
|
||||||
|
async def add_nodes_and_edges_bulk(
|
||||||
|
driver: AsyncDriver,
|
||||||
|
episodic_nodes: list[EpisodicNode],
|
||||||
|
episodic_edges: list[EpisodicEdge],
|
||||||
|
entity_nodes: list[EntityNode],
|
||||||
|
entity_edges: list[EntityEdge],
|
||||||
|
):
|
||||||
|
async with driver.session() as session:
|
||||||
|
await session.execute_write(
|
||||||
|
add_nodes_and_edges_bulk_tx, episodic_nodes, episodic_edges, entity_nodes, entity_edges
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def add_nodes_and_edges_bulk_tx(
|
||||||
|
tx: AsyncManagedTransaction,
|
||||||
|
episodic_nodes: list[EpisodicNode],
|
||||||
|
episodic_edges: list[EpisodicEdge],
|
||||||
|
entity_nodes: list[EntityNode],
|
||||||
|
entity_edges: list[EntityEdge],
|
||||||
|
):
|
||||||
|
episodes = [dict(episode) for episode in episodic_nodes]
|
||||||
|
for episode in episodes:
|
||||||
|
episode['source'] = str(episode['source'].value)
|
||||||
|
await tx.run(EPISODIC_NODE_SAVE_BULK, episodes=episodes)
|
||||||
|
await tx.run(ENTITY_NODE_SAVE_BULK, nodes=[dict(entity) for entity in entity_nodes])
|
||||||
|
await tx.run(EPISODIC_EDGE_SAVE_BULK, episodic_edges=[dict(edge) for edge in episodic_edges])
|
||||||
|
await tx.run(ENTITY_EDGE_SAVE_BULK, entity_edges=[dict(edge) for edge in entity_edges])
|
||||||
|
|
||||||
|
|
||||||
async def extract_nodes_and_edges_bulk(
|
async def extract_nodes_and_edges_bulk(
|
||||||
llm_client: LLMClient, episode_tuples: list[tuple[EpisodicNode, list[EpisodicNode]]]
|
llm_client: LLMClient, episode_tuples: list[tuple[EpisodicNode, list[EpisodicNode]]]
|
||||||
) -> tuple[list[EntityNode], list[EntityEdge], list[EpisodicEdge]]:
|
) -> tuple[list[EntityNode], list[EntityEdge], list[EpisodicEdge]]:
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,6 @@
|
||||||
[tool.poetry]
|
[tool.poetry]
|
||||||
name = "graphiti-core"
|
name = "graphiti-core"
|
||||||
version = "0.3.19"
|
version = "0.3.20"
|
||||||
description = "A temporal graph building library"
|
description = "A temporal graph building library"
|
||||||
authors = [
|
authors = [
|
||||||
"Paul Paliychuk <paul@getzep.com>",
|
"Paul Paliychuk <paul@getzep.com>",
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue