Make default DB explicit (#195)

* add default database

* update

* init tests

* update test

* bump version

* removed unused imports
This commit is contained in:
Preston Rasmussen 2024-10-21 12:33:32 -04:00 committed by GitHub
parent 8b72250f0b
commit b217d1e51f
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
13 changed files with 142 additions and 58 deletions

View file

@ -26,7 +26,12 @@ from pydantic import BaseModel, Field
from graphiti_core.embedder import EmbedderClient
from graphiti_core.errors import EdgeNotFoundError, GroupsEdgesNotFoundError
from graphiti_core.helpers import parse_db_date
from graphiti_core.helpers import DEFAULT_DATABASE, parse_db_date
from graphiti_core.models.edges.edge_db_queries import (
COMMUNITY_EDGE_SAVE,
ENTITY_EDGE_SAVE,
EPISODIC_EDGE_SAVE,
)
from graphiti_core.nodes import Node
logger = logging.getLogger(__name__)
@ -49,6 +54,7 @@ class Edge(BaseModel, ABC):
DELETE e
""",
uuid=self.uuid,
_database=DEFAULT_DATABASE,
)
logger.debug(f'Deleted Edge: {self.uuid}')
@ -70,17 +76,13 @@ class Edge(BaseModel, ABC):
class EpisodicEdge(Edge):
async def save(self, driver: AsyncDriver):
result = await driver.execute_query(
"""
MATCH (episode:Episodic {uuid: $episode_uuid})
MATCH (node:Entity {uuid: $entity_uuid})
MERGE (episode)-[r:MENTIONS {uuid: $uuid}]->(node)
SET r = {uuid: $uuid, group_id: $group_id, created_at: $created_at}
RETURN r.uuid AS uuid""",
EPISODIC_EDGE_SAVE,
episode_uuid=self.source_node_uuid,
entity_uuid=self.target_node_uuid,
uuid=self.uuid,
group_id=self.group_id,
created_at=self.created_at,
_database=DEFAULT_DATABASE,
)
logger.debug(f'Saved edge to neo4j: {self.uuid}')
@ -100,6 +102,7 @@ class EpisodicEdge(Edge):
e.created_at AS created_at
""",
uuid=uuid,
_database=DEFAULT_DATABASE,
)
edges = [get_episodic_edge_from_record(record) for record in records]
@ -122,6 +125,7 @@ class EpisodicEdge(Edge):
e.created_at AS created_at
""",
uuids=uuids,
_database=DEFAULT_DATABASE,
)
edges = [get_episodic_edge_from_record(record) for record in records]
@ -144,6 +148,7 @@ class EpisodicEdge(Edge):
e.created_at AS created_at
""",
group_ids=group_ids,
_database=DEFAULT_DATABASE,
)
edges = [get_episodic_edge_from_record(record) for record in records]
@ -184,14 +189,7 @@ class EntityEdge(Edge):
async def save(self, driver: AsyncDriver):
result = await driver.execute_query(
"""
MATCH (source:Entity {uuid: $source_uuid})
MATCH (target:Entity {uuid: $target_uuid})
MERGE (source)-[r:RELATES_TO {uuid: $uuid}]->(target)
SET r = {uuid: $uuid, name: $name, group_id: $group_id, fact: $fact, episodes: $episodes,
created_at: $created_at, expired_at: $expired_at, valid_at: $valid_at, invalid_at: $invalid_at}
WITH r CALL db.create.setRelationshipVectorProperty(r, "fact_embedding", $fact_embedding)
RETURN r.uuid AS uuid""",
ENTITY_EDGE_SAVE,
source_uuid=self.source_node_uuid,
target_uuid=self.target_node_uuid,
uuid=self.uuid,
@ -204,6 +202,7 @@ class EntityEdge(Edge):
expired_at=self.expired_at,
valid_at=self.valid_at,
invalid_at=self.invalid_at,
_database=DEFAULT_DATABASE,
)
logger.debug(f'Saved edge to neo4j: {self.uuid}')
@ -230,6 +229,7 @@ class EntityEdge(Edge):
e.invalid_at AS invalid_at
""",
uuid=uuid,
_database=DEFAULT_DATABASE,
)
edges = [get_entity_edge_from_record(record) for record in records]
@ -259,6 +259,7 @@ class EntityEdge(Edge):
e.invalid_at AS invalid_at
""",
uuids=uuids,
_database=DEFAULT_DATABASE,
)
edges = [get_entity_edge_from_record(record) for record in records]
@ -288,6 +289,7 @@ class EntityEdge(Edge):
e.invalid_at AS invalid_at
""",
group_ids=group_ids,
_database=DEFAULT_DATABASE,
)
edges = [get_entity_edge_from_record(record) for record in records]
@ -300,17 +302,13 @@ class EntityEdge(Edge):
class CommunityEdge(Edge):
async def save(self, driver: AsyncDriver):
result = await driver.execute_query(
"""
MATCH (community:Community {uuid: $community_uuid})
MATCH (node:Entity | Community {uuid: $entity_uuid})
MERGE (community)-[r:HAS_MEMBER {uuid: $uuid}]->(node)
SET r = {uuid: $uuid, group_id: $group_id, created_at: $created_at}
RETURN r.uuid AS uuid""",
COMMUNITY_EDGE_SAVE,
community_uuid=self.source_node_uuid,
entity_uuid=self.target_node_uuid,
uuid=self.uuid,
group_id=self.group_id,
created_at=self.created_at,
_database=DEFAULT_DATABASE,
)
logger.debug(f'Saved edge to neo4j: {self.uuid}')
@ -330,6 +328,7 @@ class CommunityEdge(Edge):
e.created_at AS created_at
""",
uuid=uuid,
_database=DEFAULT_DATABASE,
)
edges = [get_community_edge_from_record(record) for record in records]
@ -350,6 +349,7 @@ class CommunityEdge(Edge):
e.created_at AS created_at
""",
uuids=uuids,
_database=DEFAULT_DATABASE,
)
edges = [get_community_edge_from_record(record) for record in records]
@ -370,6 +370,7 @@ class CommunityEdge(Edge):
e.created_at AS created_at
""",
group_ids=group_ids,
_database=DEFAULT_DATABASE,
)
edges = [get_community_edge_from_record(record) for record in records]

View file

@ -14,11 +14,14 @@ See the License for the specific language governing permissions and
limitations under the License.
"""
import os
from datetime import datetime
import numpy as np
from neo4j import time as neo4j_time
# Database name handed to the driver's query calls. Read once, at import time,
# from the DEFAULT_DATABASE environment variable; None when it is not set.
DEFAULT_DATABASE = os.environ.get('DEFAULT_DATABASE')
def parse_db_date(neo_date: neo4j_time.DateTime | None) -> datetime | None:
    """Convert a Neo4j temporal value into its native Python equivalent.

    Args:
        neo_date: Value read from the database, or ``None``.

    Returns:
        ``neo_date.to_native()`` when ``neo_date`` is truthy, otherwise ``None``.
    """
    if not neo_date:
        return None
    return neo_date.to_native()

View file

View file

View file

@ -0,0 +1,22 @@
# Cypher statements that upsert each kind of edge (relationship).
#
# EPISODIC_EDGE_SAVE: looks up an Episodic node ($episode_uuid) and an Entity
# node ($entity_uuid), MERGEs a MENTIONS relationship keyed on $uuid between
# them, overwrites the relationship's property map with uuid / group_id /
# created_at, and returns the relationship uuid.
# Parameters: $episode_uuid, $entity_uuid, $uuid, $group_id, $created_at.
EPISODIC_EDGE_SAVE = """
MATCH (episode:Episodic {uuid: $episode_uuid})
MATCH (node:Entity {uuid: $entity_uuid})
MERGE (episode)-[r:MENTIONS {uuid: $uuid}]->(node)
SET r = {uuid: $uuid, group_id: $group_id, created_at: $created_at}
RETURN r.uuid AS uuid"""
# ENTITY_EDGE_SAVE: looks up the source and target Entity nodes, MERGEs a
# RELATES_TO relationship keyed on $uuid, overwrites its property map, then
# stores $fact_embedding under "fact_embedding" via
# db.create.setRelationshipVectorProperty. The vector is written after the SET
# because `SET r = {...}` replaces the whole property map.
# Parameters: $source_uuid, $target_uuid, $uuid, $name, $group_id, $fact,
# $episodes, $created_at, $expired_at, $valid_at, $invalid_at, $fact_embedding.
ENTITY_EDGE_SAVE = """
MATCH (source:Entity {uuid: $source_uuid})
MATCH (target:Entity {uuid: $target_uuid})
MERGE (source)-[r:RELATES_TO {uuid: $uuid}]->(target)
SET r = {uuid: $uuid, name: $name, group_id: $group_id, fact: $fact, episodes: $episodes,
created_at: $created_at, expired_at: $expired_at, valid_at: $valid_at, invalid_at: $invalid_at}
WITH r CALL db.create.setRelationshipVectorProperty(r, "fact_embedding", $fact_embedding)
RETURN r.uuid AS uuid"""
# COMMUNITY_EDGE_SAVE: looks up a Community node ($community_uuid) and a member
# node labelled Entity or Community ($entity_uuid), MERGEs a HAS_MEMBER
# relationship keyed on $uuid, overwrites its property map with uuid /
# group_id / created_at, and returns the relationship uuid.
# Parameters: $community_uuid, $entity_uuid, $uuid, $group_id, $created_at.
COMMUNITY_EDGE_SAVE = """
MATCH (community:Community {uuid: $community_uuid})
MATCH (node:Entity | Community {uuid: $entity_uuid})
MERGE (community)-[r:HAS_MEMBER {uuid: $uuid}]->(node)
SET r = {uuid: $uuid, group_id: $group_id, created_at: $created_at}
RETURN r.uuid AS uuid"""

View file

View file

@ -0,0 +1,17 @@
# Cypher statements that upsert each kind of node.
#
# EPISODIC_NODE_SAVE: MERGEs an Episodic node keyed on $uuid, overwrites its
# property map with the supplied values, and returns the node uuid.
# Parameters: $uuid, $name, $group_id, $source_description, $source, $content,
# $entity_edges, $created_at, $valid_at.
EPISODIC_NODE_SAVE = """
MERGE (n:Episodic {uuid: $uuid})
SET n = {uuid: $uuid, name: $name, group_id: $group_id, source_description: $source_description, source: $source, content: $content,
entity_edges: $entity_edges, created_at: $created_at, valid_at: $valid_at}
RETURN n.uuid AS uuid"""
# ENTITY_NODE_SAVE: MERGEs an Entity node keyed on $uuid, overwrites its
# property map, then stores $name_embedding under "name_embedding" via
# db.create.setNodeVectorProperty. The vector is written after the SET because
# `SET n = {...}` replaces the whole property map.
# Parameters: $uuid, $name, $group_id, $summary, $created_at, $name_embedding.
ENTITY_NODE_SAVE = """
MERGE (n:Entity {uuid: $uuid})
SET n = {uuid: $uuid, name: $name, group_id: $group_id, summary: $summary, created_at: $created_at}
WITH n CALL db.create.setNodeVectorProperty(n, "name_embedding", $name_embedding)
RETURN n.uuid AS uuid"""
# COMMUNITY_NODE_SAVE: same shape as ENTITY_NODE_SAVE, but for nodes labelled
# Community.
# Parameters: $uuid, $name, $group_id, $summary, $created_at, $name_embedding.
COMMUNITY_NODE_SAVE = """
MERGE (n:Community {uuid: $uuid})
SET n = {uuid: $uuid, name: $name, group_id: $group_id, summary: $summary, created_at: $created_at}
WITH n CALL db.create.setNodeVectorProperty(n, "name_embedding", $name_embedding)
RETURN n.uuid AS uuid"""

View file

@ -27,6 +27,12 @@ from pydantic import BaseModel, Field
from graphiti_core.embedder import EmbedderClient
from graphiti_core.errors import NodeNotFoundError
from graphiti_core.helpers import DEFAULT_DATABASE
from graphiti_core.models.nodes.node_db_queries import (
COMMUNITY_NODE_SAVE,
ENTITY_NODE_SAVE,
EPISODIC_NODE_SAVE,
)
logger = logging.getLogger(__name__)
@ -84,6 +90,7 @@ class Node(BaseModel, ABC):
DETACH DELETE n
""",
uuid=self.uuid,
_database=DEFAULT_DATABASE,
)
logger.debug(f'Deleted Node: {self.uuid}')
@ -119,11 +126,7 @@ class EpisodicNode(Node):
async def save(self, driver: AsyncDriver):
result = await driver.execute_query(
"""
MERGE (n:Episodic {uuid: $uuid})
SET n = {uuid: $uuid, name: $name, group_id: $group_id, source_description: $source_description, source: $source, content: $content,
entity_edges: $entity_edges, created_at: $created_at, valid_at: $valid_at}
RETURN n.uuid AS uuid""",
EPISODIC_NODE_SAVE,
uuid=self.uuid,
name=self.name,
group_id=self.group_id,
@ -133,6 +136,7 @@ class EpisodicNode(Node):
created_at=self.created_at,
valid_at=self.valid_at,
source=self.source.value,
_database=DEFAULT_DATABASE,
)
logger.debug(f'Saved Node to neo4j: {self.uuid}')
@ -154,6 +158,7 @@ class EpisodicNode(Node):
e.source AS source
""",
uuid=uuid,
_database=DEFAULT_DATABASE,
)
episodes = [get_episodic_node_from_record(record) for record in records]
@ -179,6 +184,7 @@ class EpisodicNode(Node):
e.source AS source
""",
uuids=uuids,
_database=DEFAULT_DATABASE,
)
episodes = [get_episodic_node_from_record(record) for record in records]
@ -201,6 +207,7 @@ class EpisodicNode(Node):
e.source AS source
""",
group_ids=group_ids,
_database=DEFAULT_DATABASE,
)
episodes = [get_episodic_node_from_record(record) for record in records]
@ -223,17 +230,14 @@ class EntityNode(Node):
async def save(self, driver: AsyncDriver):
result = await driver.execute_query(
"""
MERGE (n:Entity {uuid: $uuid})
SET n = {uuid: $uuid, name: $name, group_id: $group_id, summary: $summary, created_at: $created_at}
WITH n CALL db.create.setNodeVectorProperty(n, "name_embedding", $name_embedding)
RETURN n.uuid AS uuid""",
ENTITY_NODE_SAVE,
uuid=self.uuid,
name=self.name,
group_id=self.group_id,
summary=self.summary,
name_embedding=self.name_embedding,
created_at=self.created_at,
_database=DEFAULT_DATABASE,
)
logger.debug(f'Saved Node to neo4j: {self.uuid}')
@ -254,6 +258,7 @@ class EntityNode(Node):
n.summary AS summary
""",
uuid=uuid,
_database=DEFAULT_DATABASE,
)
nodes = [get_entity_node_from_record(record) for record in records]
@ -277,6 +282,7 @@ class EntityNode(Node):
n.summary AS summary
""",
uuids=uuids,
_database=DEFAULT_DATABASE,
)
nodes = [get_entity_node_from_record(record) for record in records]
@ -297,6 +303,7 @@ class EntityNode(Node):
n.summary AS summary
""",
group_ids=group_ids,
_database=DEFAULT_DATABASE,
)
nodes = [get_entity_node_from_record(record) for record in records]
@ -310,17 +317,14 @@ class CommunityNode(Node):
async def save(self, driver: AsyncDriver):
result = await driver.execute_query(
"""
MERGE (n:Community {uuid: $uuid})
SET n = {uuid: $uuid, name: $name, group_id: $group_id, summary: $summary, created_at: $created_at}
WITH n CALL db.create.setNodeVectorProperty(n, "name_embedding", $name_embedding)
RETURN n.uuid AS uuid""",
COMMUNITY_NODE_SAVE,
uuid=self.uuid,
name=self.name,
group_id=self.group_id,
summary=self.summary,
name_embedding=self.name_embedding,
created_at=self.created_at,
_database=DEFAULT_DATABASE,
)
logger.debug(f'Saved Node to neo4j: {self.uuid}')
@ -350,6 +354,7 @@ class CommunityNode(Node):
n.summary AS summary
""",
uuid=uuid,
_database=DEFAULT_DATABASE,
)
nodes = [get_community_node_from_record(record) for record in records]
@ -373,6 +378,7 @@ class CommunityNode(Node):
n.summary AS summary
""",
uuids=uuids,
_database=DEFAULT_DATABASE,
)
communities = [get_community_node_from_record(record) for record in records]
@ -393,6 +399,7 @@ class CommunityNode(Node):
n.summary AS summary
""",
group_ids=group_ids,
_database=DEFAULT_DATABASE,
)
communities = [get_community_node_from_record(record) for record in records]

View file

@ -23,7 +23,7 @@ import numpy as np
from neo4j import AsyncDriver, Query
from graphiti_core.edges import EntityEdge, get_entity_edge_from_record
from graphiti_core.helpers import lucene_sanitize, normalize_l2
from graphiti_core.helpers import DEFAULT_DATABASE, lucene_sanitize, normalize_l2
from graphiti_core.nodes import (
CommunityNode,
EntityNode,
@ -37,7 +37,7 @@ logger = logging.getLogger(__name__)
RELEVANT_SCHEMA_LIMIT = 3
DEFAULT_MIN_SCORE = 0.6
DEFAULT_MMR_LAMBDA = 0.5
MAX_QUERY_LENGTH = 512
MAX_QUERY_LENGTH = 128
def fulltext_query(query: str, group_ids: list[str] | None = None):
@ -91,6 +91,7 @@ async def get_mentioned_nodes(
n.summary AS summary
""",
uuids=episode_uuids,
_database=DEFAULT_DATABASE,
)
nodes = [get_entity_node_from_record(record) for record in records]
@ -114,6 +115,7 @@ async def get_communities_by_nodes(
c.summary AS summary
""",
uuids=node_uuids,
_database=DEFAULT_DATABASE,
)
communities = [get_community_node_from_record(record) for record in records]
@ -161,6 +163,7 @@ async def edge_fulltext_search(
target_uuid=target_node_uuid,
group_ids=group_ids,
limit=limit,
_database=DEFAULT_DATABASE,
)
edges = [get_entity_edge_from_record(record) for record in records]
@ -211,6 +214,7 @@ async def edge_similarity_search(
group_ids=group_ids,
limit=limit,
min_score=min_score,
_database=DEFAULT_DATABASE,
)
edges = [get_entity_edge_from_record(record) for record in records]
@ -246,6 +250,7 @@ async def node_fulltext_search(
query=fuzzy_query,
group_ids=group_ids,
limit=limit,
_database=DEFAULT_DATABASE,
)
nodes = [get_entity_node_from_record(record) for record in records]
@ -281,6 +286,7 @@ async def node_similarity_search(
group_ids=group_ids,
limit=limit,
min_score=min_score,
_database=DEFAULT_DATABASE,
)
nodes = [get_entity_node_from_record(record) for record in records]
@ -315,6 +321,7 @@ async def community_fulltext_search(
query=fuzzy_query,
group_ids=group_ids,
limit=limit,
_database=DEFAULT_DATABASE,
)
communities = [get_community_node_from_record(record) for record in records]
@ -350,6 +357,7 @@ async def community_similarity_search(
group_ids=group_ids,
limit=limit,
min_score=min_score,
_database=DEFAULT_DATABASE,
)
communities = [get_community_node_from_record(record) for record in records]
@ -541,6 +549,7 @@ async def node_distance_reranker(
query,
node_uuid=uuid,
center_uuid=center_node_uuid,
_database=DEFAULT_DATABASE,
)
for uuid in filtered_uuids
]
@ -577,6 +586,7 @@ async def episode_mentions_reranker(driver: AsyncDriver, node_uuids: list[list[s
driver.execute_query(
query,
node_uuid=uuid,
_database=DEFAULT_DATABASE,
)
for uuid in sorted_uuids
]

View file

@ -8,8 +8,13 @@ from pydantic import BaseModel
from graphiti_core.edges import CommunityEdge
from graphiti_core.embedder import EmbedderClient
from graphiti_core.helpers import DEFAULT_DATABASE
from graphiti_core.llm_client import LLMClient
from graphiti_core.nodes import CommunityNode, EntityNode, get_community_node_from_record
from graphiti_core.nodes import (
CommunityNode,
EntityNode,
get_community_node_from_record,
)
from graphiti_core.prompts import prompt_library
from graphiti_core.utils.maintenance.edge_operations import build_community_edges
@ -29,11 +34,14 @@ async def get_community_clusters(
community_clusters: list[list[EntityNode]] = []
if group_ids is None:
group_id_values, _, _ = await driver.execute_query("""
group_id_values, _, _ = await driver.execute_query(
"""
MATCH (n:Entity WHERE n.group_id IS NOT NULL)
RETURN
collect(DISTINCT n.group_id) AS group_ids
""")
""",
_database=DEFAULT_DATABASE,
)
group_ids = group_id_values[0]['group_ids']
@ -51,6 +59,7 @@ async def get_community_clusters(
""",
uuid=node.uuid,
group_id=group_id,
_database=DEFAULT_DATABASE,
)
projection[node.uuid] = [
@ -209,10 +218,13 @@ async def build_communities(
async def remove_communities(driver: AsyncDriver):
    """Delete every Community node, together with its relationships.

    Args:
        driver: Async Neo4j driver used to run the delete statement.
    """
    await driver.execute_query(
        """
    MATCH (c:Community)
    DETACH DELETE c
    """,
        # The driver's own configuration keywords end with an underscore
        # (`database_`). A keyword spelled `_database` is forwarded to Cypher
        # as an ordinary query parameter and does not select the database.
        database_=DEFAULT_DATABASE,
    )
async def determine_entity_community(
@ -231,6 +243,7 @@ async def determine_entity_community(
c.summary AS summary
""",
entity_uuid=entity.uuid,
_database=DEFAULT_DATABASE,
)
if len(records) > 0:
@ -249,6 +262,7 @@ async def determine_entity_community(
c.summary AS summary
""",
entity_uuid=entity.uuid,
_database=DEFAULT_DATABASE,
)
communities: list[CommunityNode] = [

View file

@ -21,6 +21,7 @@ from datetime import datetime, timezone
from neo4j import AsyncDriver
from typing_extensions import LiteralString
from graphiti_core.helpers import DEFAULT_DATABASE
from graphiti_core.nodes import EpisodeType, EpisodicNode
EPISODE_WINDOW_LEN = 3
@ -30,12 +31,22 @@ logger = logging.getLogger(__name__)
async def build_indices_and_constraints(driver: AsyncDriver, delete_existing: bool = False):
if delete_existing:
records, _, _ = await driver.execute_query("""
records, _, _ = await driver.execute_query(
"""
SHOW INDEXES YIELD name
""")
""",
_database=DEFAULT_DATABASE,
)
index_names = [record['name'] for record in records]
await asyncio.gather(
*[driver.execute_query("""DROP INDEX $name""", name=name) for name in index_names]
*[
driver.execute_query(
"""DROP INDEX $name""",
name=name,
_database=DEFAULT_DATABASE,
)
for name in index_names
]
)
range_indices: list[LiteralString] = [
@ -71,7 +82,15 @@ async def build_indices_and_constraints(driver: AsyncDriver, delete_existing: bo
index_queries: list[LiteralString] = range_indices + fulltext_indices
await asyncio.gather(*[driver.execute_query(query) for query in index_queries])
await asyncio.gather(
*[
driver.execute_query(
query,
_database=DEFAULT_DATABASE,
)
for query in index_queries
]
)
async def clear_data(driver: AsyncDriver):
@ -121,6 +140,7 @@ async def retrieve_episodes(
reference_time=reference_time,
num_episodes=last_n,
group_ids=group_ids,
_database=DEFAULT_DATABASE,
)
episodes = [
EpisodicNode(

View file

@ -1,6 +1,6 @@
[tool.poetry]
name = "graphiti-core"
version = "0.3.13"
version = "0.3.14"
description = "A temporal graph building library"
authors = [
"Paul Paliychuk <paul@getzep.com>",

View file

@ -75,16 +75,6 @@ async def test_graphiti_init():
logger = setup_logging()
graphiti = Graphiti(NEO4J_URI, NEO4j_USER, NEO4j_PASSWORD)
edges = await graphiti.search(
'tania tetlow', center_node_uuid='4bf7ebb3-3a98-46c7-90a6-8e516c487961', group_ids=None
)
logger.info('\nQUERY: Tania Tetlow\n' + format_context([edge.fact for edge in edges]))
edges = await graphiti.search('issues with higher ed', group_ids=None)
logger.info('\nQUERY: issues with higher ed\n' + format_context([edge.fact for edge in edges]))
results = await graphiti._search('new house', COMBINED_HYBRID_SEARCH_RRF, group_ids=None)
pretty_results = {
'edges': [edge.fact for edge in results.edges],