Make default DB explicit (#195)
* add default database * update * init tests * update test * bump version * removed unused imports
This commit is contained in:
parent
8b72250f0b
commit
b217d1e51f
13 changed files with 142 additions and 58 deletions
|
|
@ -26,7 +26,12 @@ from pydantic import BaseModel, Field
|
||||||
|
|
||||||
from graphiti_core.embedder import EmbedderClient
|
from graphiti_core.embedder import EmbedderClient
|
||||||
from graphiti_core.errors import EdgeNotFoundError, GroupsEdgesNotFoundError
|
from graphiti_core.errors import EdgeNotFoundError, GroupsEdgesNotFoundError
|
||||||
from graphiti_core.helpers import parse_db_date
|
from graphiti_core.helpers import DEFAULT_DATABASE, parse_db_date
|
||||||
|
from graphiti_core.models.edges.edge_db_queries import (
|
||||||
|
COMMUNITY_EDGE_SAVE,
|
||||||
|
ENTITY_EDGE_SAVE,
|
||||||
|
EPISODIC_EDGE_SAVE,
|
||||||
|
)
|
||||||
from graphiti_core.nodes import Node
|
from graphiti_core.nodes import Node
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
@ -49,6 +54,7 @@ class Edge(BaseModel, ABC):
|
||||||
DELETE e
|
DELETE e
|
||||||
""",
|
""",
|
||||||
uuid=self.uuid,
|
uuid=self.uuid,
|
||||||
|
_database=DEFAULT_DATABASE,
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.debug(f'Deleted Edge: {self.uuid}')
|
logger.debug(f'Deleted Edge: {self.uuid}')
|
||||||
|
|
@ -70,17 +76,13 @@ class Edge(BaseModel, ABC):
|
||||||
class EpisodicEdge(Edge):
|
class EpisodicEdge(Edge):
|
||||||
async def save(self, driver: AsyncDriver):
|
async def save(self, driver: AsyncDriver):
|
||||||
result = await driver.execute_query(
|
result = await driver.execute_query(
|
||||||
"""
|
EPISODIC_EDGE_SAVE,
|
||||||
MATCH (episode:Episodic {uuid: $episode_uuid})
|
|
||||||
MATCH (node:Entity {uuid: $entity_uuid})
|
|
||||||
MERGE (episode)-[r:MENTIONS {uuid: $uuid}]->(node)
|
|
||||||
SET r = {uuid: $uuid, group_id: $group_id, created_at: $created_at}
|
|
||||||
RETURN r.uuid AS uuid""",
|
|
||||||
episode_uuid=self.source_node_uuid,
|
episode_uuid=self.source_node_uuid,
|
||||||
entity_uuid=self.target_node_uuid,
|
entity_uuid=self.target_node_uuid,
|
||||||
uuid=self.uuid,
|
uuid=self.uuid,
|
||||||
group_id=self.group_id,
|
group_id=self.group_id,
|
||||||
created_at=self.created_at,
|
created_at=self.created_at,
|
||||||
|
_database=DEFAULT_DATABASE,
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.debug(f'Saved edge to neo4j: {self.uuid}')
|
logger.debug(f'Saved edge to neo4j: {self.uuid}')
|
||||||
|
|
@ -100,6 +102,7 @@ class EpisodicEdge(Edge):
|
||||||
e.created_at AS created_at
|
e.created_at AS created_at
|
||||||
""",
|
""",
|
||||||
uuid=uuid,
|
uuid=uuid,
|
||||||
|
_database=DEFAULT_DATABASE,
|
||||||
)
|
)
|
||||||
|
|
||||||
edges = [get_episodic_edge_from_record(record) for record in records]
|
edges = [get_episodic_edge_from_record(record) for record in records]
|
||||||
|
|
@ -122,6 +125,7 @@ class EpisodicEdge(Edge):
|
||||||
e.created_at AS created_at
|
e.created_at AS created_at
|
||||||
""",
|
""",
|
||||||
uuids=uuids,
|
uuids=uuids,
|
||||||
|
_database=DEFAULT_DATABASE,
|
||||||
)
|
)
|
||||||
|
|
||||||
edges = [get_episodic_edge_from_record(record) for record in records]
|
edges = [get_episodic_edge_from_record(record) for record in records]
|
||||||
|
|
@ -144,6 +148,7 @@ class EpisodicEdge(Edge):
|
||||||
e.created_at AS created_at
|
e.created_at AS created_at
|
||||||
""",
|
""",
|
||||||
group_ids=group_ids,
|
group_ids=group_ids,
|
||||||
|
_database=DEFAULT_DATABASE,
|
||||||
)
|
)
|
||||||
|
|
||||||
edges = [get_episodic_edge_from_record(record) for record in records]
|
edges = [get_episodic_edge_from_record(record) for record in records]
|
||||||
|
|
@ -184,14 +189,7 @@ class EntityEdge(Edge):
|
||||||
|
|
||||||
async def save(self, driver: AsyncDriver):
|
async def save(self, driver: AsyncDriver):
|
||||||
result = await driver.execute_query(
|
result = await driver.execute_query(
|
||||||
"""
|
ENTITY_EDGE_SAVE,
|
||||||
MATCH (source:Entity {uuid: $source_uuid})
|
|
||||||
MATCH (target:Entity {uuid: $target_uuid})
|
|
||||||
MERGE (source)-[r:RELATES_TO {uuid: $uuid}]->(target)
|
|
||||||
SET r = {uuid: $uuid, name: $name, group_id: $group_id, fact: $fact, episodes: $episodes,
|
|
||||||
created_at: $created_at, expired_at: $expired_at, valid_at: $valid_at, invalid_at: $invalid_at}
|
|
||||||
WITH r CALL db.create.setRelationshipVectorProperty(r, "fact_embedding", $fact_embedding)
|
|
||||||
RETURN r.uuid AS uuid""",
|
|
||||||
source_uuid=self.source_node_uuid,
|
source_uuid=self.source_node_uuid,
|
||||||
target_uuid=self.target_node_uuid,
|
target_uuid=self.target_node_uuid,
|
||||||
uuid=self.uuid,
|
uuid=self.uuid,
|
||||||
|
|
@ -204,6 +202,7 @@ class EntityEdge(Edge):
|
||||||
expired_at=self.expired_at,
|
expired_at=self.expired_at,
|
||||||
valid_at=self.valid_at,
|
valid_at=self.valid_at,
|
||||||
invalid_at=self.invalid_at,
|
invalid_at=self.invalid_at,
|
||||||
|
_database=DEFAULT_DATABASE,
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.debug(f'Saved edge to neo4j: {self.uuid}')
|
logger.debug(f'Saved edge to neo4j: {self.uuid}')
|
||||||
|
|
@ -230,6 +229,7 @@ class EntityEdge(Edge):
|
||||||
e.invalid_at AS invalid_at
|
e.invalid_at AS invalid_at
|
||||||
""",
|
""",
|
||||||
uuid=uuid,
|
uuid=uuid,
|
||||||
|
_database=DEFAULT_DATABASE,
|
||||||
)
|
)
|
||||||
|
|
||||||
edges = [get_entity_edge_from_record(record) for record in records]
|
edges = [get_entity_edge_from_record(record) for record in records]
|
||||||
|
|
@ -259,6 +259,7 @@ class EntityEdge(Edge):
|
||||||
e.invalid_at AS invalid_at
|
e.invalid_at AS invalid_at
|
||||||
""",
|
""",
|
||||||
uuids=uuids,
|
uuids=uuids,
|
||||||
|
_database=DEFAULT_DATABASE,
|
||||||
)
|
)
|
||||||
|
|
||||||
edges = [get_entity_edge_from_record(record) for record in records]
|
edges = [get_entity_edge_from_record(record) for record in records]
|
||||||
|
|
@ -288,6 +289,7 @@ class EntityEdge(Edge):
|
||||||
e.invalid_at AS invalid_at
|
e.invalid_at AS invalid_at
|
||||||
""",
|
""",
|
||||||
group_ids=group_ids,
|
group_ids=group_ids,
|
||||||
|
_database=DEFAULT_DATABASE,
|
||||||
)
|
)
|
||||||
|
|
||||||
edges = [get_entity_edge_from_record(record) for record in records]
|
edges = [get_entity_edge_from_record(record) for record in records]
|
||||||
|
|
@ -300,17 +302,13 @@ class EntityEdge(Edge):
|
||||||
class CommunityEdge(Edge):
|
class CommunityEdge(Edge):
|
||||||
async def save(self, driver: AsyncDriver):
|
async def save(self, driver: AsyncDriver):
|
||||||
result = await driver.execute_query(
|
result = await driver.execute_query(
|
||||||
"""
|
COMMUNITY_EDGE_SAVE,
|
||||||
MATCH (community:Community {uuid: $community_uuid})
|
|
||||||
MATCH (node:Entity | Community {uuid: $entity_uuid})
|
|
||||||
MERGE (community)-[r:HAS_MEMBER {uuid: $uuid}]->(node)
|
|
||||||
SET r = {uuid: $uuid, group_id: $group_id, created_at: $created_at}
|
|
||||||
RETURN r.uuid AS uuid""",
|
|
||||||
community_uuid=self.source_node_uuid,
|
community_uuid=self.source_node_uuid,
|
||||||
entity_uuid=self.target_node_uuid,
|
entity_uuid=self.target_node_uuid,
|
||||||
uuid=self.uuid,
|
uuid=self.uuid,
|
||||||
group_id=self.group_id,
|
group_id=self.group_id,
|
||||||
created_at=self.created_at,
|
created_at=self.created_at,
|
||||||
|
_database=DEFAULT_DATABASE,
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.debug(f'Saved edge to neo4j: {self.uuid}')
|
logger.debug(f'Saved edge to neo4j: {self.uuid}')
|
||||||
|
|
@ -330,6 +328,7 @@ class CommunityEdge(Edge):
|
||||||
e.created_at AS created_at
|
e.created_at AS created_at
|
||||||
""",
|
""",
|
||||||
uuid=uuid,
|
uuid=uuid,
|
||||||
|
_database=DEFAULT_DATABASE,
|
||||||
)
|
)
|
||||||
|
|
||||||
edges = [get_community_edge_from_record(record) for record in records]
|
edges = [get_community_edge_from_record(record) for record in records]
|
||||||
|
|
@ -350,6 +349,7 @@ class CommunityEdge(Edge):
|
||||||
e.created_at AS created_at
|
e.created_at AS created_at
|
||||||
""",
|
""",
|
||||||
uuids=uuids,
|
uuids=uuids,
|
||||||
|
_database=DEFAULT_DATABASE,
|
||||||
)
|
)
|
||||||
|
|
||||||
edges = [get_community_edge_from_record(record) for record in records]
|
edges = [get_community_edge_from_record(record) for record in records]
|
||||||
|
|
@ -370,6 +370,7 @@ class CommunityEdge(Edge):
|
||||||
e.created_at AS created_at
|
e.created_at AS created_at
|
||||||
""",
|
""",
|
||||||
group_ids=group_ids,
|
group_ids=group_ids,
|
||||||
|
_database=DEFAULT_DATABASE,
|
||||||
)
|
)
|
||||||
|
|
||||||
edges = [get_community_edge_from_record(record) for record in records]
|
edges = [get_community_edge_from_record(record) for record in records]
|
||||||
|
|
|
||||||
|
|
@ -14,11 +14,14 @@ See the License for the specific language governing permissions and
|
||||||
limitations under the License.
|
limitations under the License.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from neo4j import time as neo4j_time
|
from neo4j import time as neo4j_time
|
||||||
|
|
||||||
|
DEFAULT_DATABASE = os.getenv('DEFAULT_DATABASE', None)
|
||||||
|
|
||||||
|
|
||||||
def parse_db_date(neo_date: neo4j_time.DateTime | None) -> datetime | None:
|
def parse_db_date(neo_date: neo4j_time.DateTime | None) -> datetime | None:
|
||||||
return neo_date.to_native() if neo_date else None
|
return neo_date.to_native() if neo_date else None
|
||||||
|
|
|
||||||
0
graphiti_core/models/__init__.py
Normal file
0
graphiti_core/models/__init__.py
Normal file
0
graphiti_core/models/edges/__init__.py
Normal file
0
graphiti_core/models/edges/__init__.py
Normal file
22
graphiti_core/models/edges/edge_db_queries.py
Normal file
22
graphiti_core/models/edges/edge_db_queries.py
Normal file
|
|
@ -0,0 +1,22 @@
|
||||||
|
EPISODIC_EDGE_SAVE = """
|
||||||
|
MATCH (episode:Episodic {uuid: $episode_uuid})
|
||||||
|
MATCH (node:Entity {uuid: $entity_uuid})
|
||||||
|
MERGE (episode)-[r:MENTIONS {uuid: $uuid}]->(node)
|
||||||
|
SET r = {uuid: $uuid, group_id: $group_id, created_at: $created_at}
|
||||||
|
RETURN r.uuid AS uuid"""
|
||||||
|
|
||||||
|
ENTITY_EDGE_SAVE = """
|
||||||
|
MATCH (source:Entity {uuid: $source_uuid})
|
||||||
|
MATCH (target:Entity {uuid: $target_uuid})
|
||||||
|
MERGE (source)-[r:RELATES_TO {uuid: $uuid}]->(target)
|
||||||
|
SET r = {uuid: $uuid, name: $name, group_id: $group_id, fact: $fact, episodes: $episodes,
|
||||||
|
created_at: $created_at, expired_at: $expired_at, valid_at: $valid_at, invalid_at: $invalid_at}
|
||||||
|
WITH r CALL db.create.setRelationshipVectorProperty(r, "fact_embedding", $fact_embedding)
|
||||||
|
RETURN r.uuid AS uuid"""
|
||||||
|
|
||||||
|
COMMUNITY_EDGE_SAVE = """
|
||||||
|
MATCH (community:Community {uuid: $community_uuid})
|
||||||
|
MATCH (node:Entity | Community {uuid: $entity_uuid})
|
||||||
|
MERGE (community)-[r:HAS_MEMBER {uuid: $uuid}]->(node)
|
||||||
|
SET r = {uuid: $uuid, group_id: $group_id, created_at: $created_at}
|
||||||
|
RETURN r.uuid AS uuid"""
|
||||||
0
graphiti_core/models/nodes/__init__.py
Normal file
0
graphiti_core/models/nodes/__init__.py
Normal file
17
graphiti_core/models/nodes/node_db_queries.py
Normal file
17
graphiti_core/models/nodes/node_db_queries.py
Normal file
|
|
@ -0,0 +1,17 @@
|
||||||
|
EPISODIC_NODE_SAVE = """
|
||||||
|
MERGE (n:Episodic {uuid: $uuid})
|
||||||
|
SET n = {uuid: $uuid, name: $name, group_id: $group_id, source_description: $source_description, source: $source, content: $content,
|
||||||
|
entity_edges: $entity_edges, created_at: $created_at, valid_at: $valid_at}
|
||||||
|
RETURN n.uuid AS uuid"""
|
||||||
|
|
||||||
|
ENTITY_NODE_SAVE = """
|
||||||
|
MERGE (n:Entity {uuid: $uuid})
|
||||||
|
SET n = {uuid: $uuid, name: $name, group_id: $group_id, summary: $summary, created_at: $created_at}
|
||||||
|
WITH n CALL db.create.setNodeVectorProperty(n, "name_embedding", $name_embedding)
|
||||||
|
RETURN n.uuid AS uuid"""
|
||||||
|
|
||||||
|
COMMUNITY_NODE_SAVE = """
|
||||||
|
MERGE (n:Community {uuid: $uuid})
|
||||||
|
SET n = {uuid: $uuid, name: $name, group_id: $group_id, summary: $summary, created_at: $created_at}
|
||||||
|
WITH n CALL db.create.setNodeVectorProperty(n, "name_embedding", $name_embedding)
|
||||||
|
RETURN n.uuid AS uuid"""
|
||||||
|
|
@ -27,6 +27,12 @@ from pydantic import BaseModel, Field
|
||||||
|
|
||||||
from graphiti_core.embedder import EmbedderClient
|
from graphiti_core.embedder import EmbedderClient
|
||||||
from graphiti_core.errors import NodeNotFoundError
|
from graphiti_core.errors import NodeNotFoundError
|
||||||
|
from graphiti_core.helpers import DEFAULT_DATABASE
|
||||||
|
from graphiti_core.models.nodes.node_db_queries import (
|
||||||
|
COMMUNITY_NODE_SAVE,
|
||||||
|
ENTITY_NODE_SAVE,
|
||||||
|
EPISODIC_NODE_SAVE,
|
||||||
|
)
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
@ -84,6 +90,7 @@ class Node(BaseModel, ABC):
|
||||||
DETACH DELETE n
|
DETACH DELETE n
|
||||||
""",
|
""",
|
||||||
uuid=self.uuid,
|
uuid=self.uuid,
|
||||||
|
_database=DEFAULT_DATABASE,
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.debug(f'Deleted Node: {self.uuid}')
|
logger.debug(f'Deleted Node: {self.uuid}')
|
||||||
|
|
@ -119,11 +126,7 @@ class EpisodicNode(Node):
|
||||||
|
|
||||||
async def save(self, driver: AsyncDriver):
|
async def save(self, driver: AsyncDriver):
|
||||||
result = await driver.execute_query(
|
result = await driver.execute_query(
|
||||||
"""
|
EPISODIC_NODE_SAVE,
|
||||||
MERGE (n:Episodic {uuid: $uuid})
|
|
||||||
SET n = {uuid: $uuid, name: $name, group_id: $group_id, source_description: $source_description, source: $source, content: $content,
|
|
||||||
entity_edges: $entity_edges, created_at: $created_at, valid_at: $valid_at}
|
|
||||||
RETURN n.uuid AS uuid""",
|
|
||||||
uuid=self.uuid,
|
uuid=self.uuid,
|
||||||
name=self.name,
|
name=self.name,
|
||||||
group_id=self.group_id,
|
group_id=self.group_id,
|
||||||
|
|
@ -133,6 +136,7 @@ class EpisodicNode(Node):
|
||||||
created_at=self.created_at,
|
created_at=self.created_at,
|
||||||
valid_at=self.valid_at,
|
valid_at=self.valid_at,
|
||||||
source=self.source.value,
|
source=self.source.value,
|
||||||
|
_database=DEFAULT_DATABASE,
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.debug(f'Saved Node to neo4j: {self.uuid}')
|
logger.debug(f'Saved Node to neo4j: {self.uuid}')
|
||||||
|
|
@ -154,6 +158,7 @@ class EpisodicNode(Node):
|
||||||
e.source AS source
|
e.source AS source
|
||||||
""",
|
""",
|
||||||
uuid=uuid,
|
uuid=uuid,
|
||||||
|
_database=DEFAULT_DATABASE,
|
||||||
)
|
)
|
||||||
|
|
||||||
episodes = [get_episodic_node_from_record(record) for record in records]
|
episodes = [get_episodic_node_from_record(record) for record in records]
|
||||||
|
|
@ -179,6 +184,7 @@ class EpisodicNode(Node):
|
||||||
e.source AS source
|
e.source AS source
|
||||||
""",
|
""",
|
||||||
uuids=uuids,
|
uuids=uuids,
|
||||||
|
_database=DEFAULT_DATABASE,
|
||||||
)
|
)
|
||||||
|
|
||||||
episodes = [get_episodic_node_from_record(record) for record in records]
|
episodes = [get_episodic_node_from_record(record) for record in records]
|
||||||
|
|
@ -201,6 +207,7 @@ class EpisodicNode(Node):
|
||||||
e.source AS source
|
e.source AS source
|
||||||
""",
|
""",
|
||||||
group_ids=group_ids,
|
group_ids=group_ids,
|
||||||
|
_database=DEFAULT_DATABASE,
|
||||||
)
|
)
|
||||||
|
|
||||||
episodes = [get_episodic_node_from_record(record) for record in records]
|
episodes = [get_episodic_node_from_record(record) for record in records]
|
||||||
|
|
@ -223,17 +230,14 @@ class EntityNode(Node):
|
||||||
|
|
||||||
async def save(self, driver: AsyncDriver):
|
async def save(self, driver: AsyncDriver):
|
||||||
result = await driver.execute_query(
|
result = await driver.execute_query(
|
||||||
"""
|
ENTITY_NODE_SAVE,
|
||||||
MERGE (n:Entity {uuid: $uuid})
|
|
||||||
SET n = {uuid: $uuid, name: $name, group_id: $group_id, summary: $summary, created_at: $created_at}
|
|
||||||
WITH n CALL db.create.setNodeVectorProperty(n, "name_embedding", $name_embedding)
|
|
||||||
RETURN n.uuid AS uuid""",
|
|
||||||
uuid=self.uuid,
|
uuid=self.uuid,
|
||||||
name=self.name,
|
name=self.name,
|
||||||
group_id=self.group_id,
|
group_id=self.group_id,
|
||||||
summary=self.summary,
|
summary=self.summary,
|
||||||
name_embedding=self.name_embedding,
|
name_embedding=self.name_embedding,
|
||||||
created_at=self.created_at,
|
created_at=self.created_at,
|
||||||
|
_database=DEFAULT_DATABASE,
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.debug(f'Saved Node to neo4j: {self.uuid}')
|
logger.debug(f'Saved Node to neo4j: {self.uuid}')
|
||||||
|
|
@ -254,6 +258,7 @@ class EntityNode(Node):
|
||||||
n.summary AS summary
|
n.summary AS summary
|
||||||
""",
|
""",
|
||||||
uuid=uuid,
|
uuid=uuid,
|
||||||
|
_database=DEFAULT_DATABASE,
|
||||||
)
|
)
|
||||||
|
|
||||||
nodes = [get_entity_node_from_record(record) for record in records]
|
nodes = [get_entity_node_from_record(record) for record in records]
|
||||||
|
|
@ -277,6 +282,7 @@ class EntityNode(Node):
|
||||||
n.summary AS summary
|
n.summary AS summary
|
||||||
""",
|
""",
|
||||||
uuids=uuids,
|
uuids=uuids,
|
||||||
|
_database=DEFAULT_DATABASE,
|
||||||
)
|
)
|
||||||
|
|
||||||
nodes = [get_entity_node_from_record(record) for record in records]
|
nodes = [get_entity_node_from_record(record) for record in records]
|
||||||
|
|
@ -297,6 +303,7 @@ class EntityNode(Node):
|
||||||
n.summary AS summary
|
n.summary AS summary
|
||||||
""",
|
""",
|
||||||
group_ids=group_ids,
|
group_ids=group_ids,
|
||||||
|
_database=DEFAULT_DATABASE,
|
||||||
)
|
)
|
||||||
|
|
||||||
nodes = [get_entity_node_from_record(record) for record in records]
|
nodes = [get_entity_node_from_record(record) for record in records]
|
||||||
|
|
@ -310,17 +317,14 @@ class CommunityNode(Node):
|
||||||
|
|
||||||
async def save(self, driver: AsyncDriver):
|
async def save(self, driver: AsyncDriver):
|
||||||
result = await driver.execute_query(
|
result = await driver.execute_query(
|
||||||
"""
|
COMMUNITY_NODE_SAVE,
|
||||||
MERGE (n:Community {uuid: $uuid})
|
|
||||||
SET n = {uuid: $uuid, name: $name, group_id: $group_id, summary: $summary, created_at: $created_at}
|
|
||||||
WITH n CALL db.create.setNodeVectorProperty(n, "name_embedding", $name_embedding)
|
|
||||||
RETURN n.uuid AS uuid""",
|
|
||||||
uuid=self.uuid,
|
uuid=self.uuid,
|
||||||
name=self.name,
|
name=self.name,
|
||||||
group_id=self.group_id,
|
group_id=self.group_id,
|
||||||
summary=self.summary,
|
summary=self.summary,
|
||||||
name_embedding=self.name_embedding,
|
name_embedding=self.name_embedding,
|
||||||
created_at=self.created_at,
|
created_at=self.created_at,
|
||||||
|
_database=DEFAULT_DATABASE,
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.debug(f'Saved Node to neo4j: {self.uuid}')
|
logger.debug(f'Saved Node to neo4j: {self.uuid}')
|
||||||
|
|
@ -350,6 +354,7 @@ class CommunityNode(Node):
|
||||||
n.summary AS summary
|
n.summary AS summary
|
||||||
""",
|
""",
|
||||||
uuid=uuid,
|
uuid=uuid,
|
||||||
|
_database=DEFAULT_DATABASE,
|
||||||
)
|
)
|
||||||
|
|
||||||
nodes = [get_community_node_from_record(record) for record in records]
|
nodes = [get_community_node_from_record(record) for record in records]
|
||||||
|
|
@ -373,6 +378,7 @@ class CommunityNode(Node):
|
||||||
n.summary AS summary
|
n.summary AS summary
|
||||||
""",
|
""",
|
||||||
uuids=uuids,
|
uuids=uuids,
|
||||||
|
_database=DEFAULT_DATABASE,
|
||||||
)
|
)
|
||||||
|
|
||||||
communities = [get_community_node_from_record(record) for record in records]
|
communities = [get_community_node_from_record(record) for record in records]
|
||||||
|
|
@ -393,6 +399,7 @@ class CommunityNode(Node):
|
||||||
n.summary AS summary
|
n.summary AS summary
|
||||||
""",
|
""",
|
||||||
group_ids=group_ids,
|
group_ids=group_ids,
|
||||||
|
_database=DEFAULT_DATABASE,
|
||||||
)
|
)
|
||||||
|
|
||||||
communities = [get_community_node_from_record(record) for record in records]
|
communities = [get_community_node_from_record(record) for record in records]
|
||||||
|
|
|
||||||
|
|
@ -23,7 +23,7 @@ import numpy as np
|
||||||
from neo4j import AsyncDriver, Query
|
from neo4j import AsyncDriver, Query
|
||||||
|
|
||||||
from graphiti_core.edges import EntityEdge, get_entity_edge_from_record
|
from graphiti_core.edges import EntityEdge, get_entity_edge_from_record
|
||||||
from graphiti_core.helpers import lucene_sanitize, normalize_l2
|
from graphiti_core.helpers import DEFAULT_DATABASE, lucene_sanitize, normalize_l2
|
||||||
from graphiti_core.nodes import (
|
from graphiti_core.nodes import (
|
||||||
CommunityNode,
|
CommunityNode,
|
||||||
EntityNode,
|
EntityNode,
|
||||||
|
|
@ -37,7 +37,7 @@ logger = logging.getLogger(__name__)
|
||||||
RELEVANT_SCHEMA_LIMIT = 3
|
RELEVANT_SCHEMA_LIMIT = 3
|
||||||
DEFAULT_MIN_SCORE = 0.6
|
DEFAULT_MIN_SCORE = 0.6
|
||||||
DEFAULT_MMR_LAMBDA = 0.5
|
DEFAULT_MMR_LAMBDA = 0.5
|
||||||
MAX_QUERY_LENGTH = 512
|
MAX_QUERY_LENGTH = 128
|
||||||
|
|
||||||
|
|
||||||
def fulltext_query(query: str, group_ids: list[str] | None = None):
|
def fulltext_query(query: str, group_ids: list[str] | None = None):
|
||||||
|
|
@ -91,6 +91,7 @@ async def get_mentioned_nodes(
|
||||||
n.summary AS summary
|
n.summary AS summary
|
||||||
""",
|
""",
|
||||||
uuids=episode_uuids,
|
uuids=episode_uuids,
|
||||||
|
_database=DEFAULT_DATABASE,
|
||||||
)
|
)
|
||||||
|
|
||||||
nodes = [get_entity_node_from_record(record) for record in records]
|
nodes = [get_entity_node_from_record(record) for record in records]
|
||||||
|
|
@ -114,6 +115,7 @@ async def get_communities_by_nodes(
|
||||||
c.summary AS summary
|
c.summary AS summary
|
||||||
""",
|
""",
|
||||||
uuids=node_uuids,
|
uuids=node_uuids,
|
||||||
|
_database=DEFAULT_DATABASE,
|
||||||
)
|
)
|
||||||
|
|
||||||
communities = [get_community_node_from_record(record) for record in records]
|
communities = [get_community_node_from_record(record) for record in records]
|
||||||
|
|
@ -161,6 +163,7 @@ async def edge_fulltext_search(
|
||||||
target_uuid=target_node_uuid,
|
target_uuid=target_node_uuid,
|
||||||
group_ids=group_ids,
|
group_ids=group_ids,
|
||||||
limit=limit,
|
limit=limit,
|
||||||
|
_database=DEFAULT_DATABASE,
|
||||||
)
|
)
|
||||||
|
|
||||||
edges = [get_entity_edge_from_record(record) for record in records]
|
edges = [get_entity_edge_from_record(record) for record in records]
|
||||||
|
|
@ -211,6 +214,7 @@ async def edge_similarity_search(
|
||||||
group_ids=group_ids,
|
group_ids=group_ids,
|
||||||
limit=limit,
|
limit=limit,
|
||||||
min_score=min_score,
|
min_score=min_score,
|
||||||
|
_database=DEFAULT_DATABASE,
|
||||||
)
|
)
|
||||||
|
|
||||||
edges = [get_entity_edge_from_record(record) for record in records]
|
edges = [get_entity_edge_from_record(record) for record in records]
|
||||||
|
|
@ -246,6 +250,7 @@ async def node_fulltext_search(
|
||||||
query=fuzzy_query,
|
query=fuzzy_query,
|
||||||
group_ids=group_ids,
|
group_ids=group_ids,
|
||||||
limit=limit,
|
limit=limit,
|
||||||
|
_database=DEFAULT_DATABASE,
|
||||||
)
|
)
|
||||||
nodes = [get_entity_node_from_record(record) for record in records]
|
nodes = [get_entity_node_from_record(record) for record in records]
|
||||||
|
|
||||||
|
|
@ -281,6 +286,7 @@ async def node_similarity_search(
|
||||||
group_ids=group_ids,
|
group_ids=group_ids,
|
||||||
limit=limit,
|
limit=limit,
|
||||||
min_score=min_score,
|
min_score=min_score,
|
||||||
|
_database=DEFAULT_DATABASE,
|
||||||
)
|
)
|
||||||
nodes = [get_entity_node_from_record(record) for record in records]
|
nodes = [get_entity_node_from_record(record) for record in records]
|
||||||
|
|
||||||
|
|
@ -315,6 +321,7 @@ async def community_fulltext_search(
|
||||||
query=fuzzy_query,
|
query=fuzzy_query,
|
||||||
group_ids=group_ids,
|
group_ids=group_ids,
|
||||||
limit=limit,
|
limit=limit,
|
||||||
|
_database=DEFAULT_DATABASE,
|
||||||
)
|
)
|
||||||
communities = [get_community_node_from_record(record) for record in records]
|
communities = [get_community_node_from_record(record) for record in records]
|
||||||
|
|
||||||
|
|
@ -350,6 +357,7 @@ async def community_similarity_search(
|
||||||
group_ids=group_ids,
|
group_ids=group_ids,
|
||||||
limit=limit,
|
limit=limit,
|
||||||
min_score=min_score,
|
min_score=min_score,
|
||||||
|
_database=DEFAULT_DATABASE,
|
||||||
)
|
)
|
||||||
communities = [get_community_node_from_record(record) for record in records]
|
communities = [get_community_node_from_record(record) for record in records]
|
||||||
|
|
||||||
|
|
@ -541,6 +549,7 @@ async def node_distance_reranker(
|
||||||
query,
|
query,
|
||||||
node_uuid=uuid,
|
node_uuid=uuid,
|
||||||
center_uuid=center_node_uuid,
|
center_uuid=center_node_uuid,
|
||||||
|
_database=DEFAULT_DATABASE,
|
||||||
)
|
)
|
||||||
for uuid in filtered_uuids
|
for uuid in filtered_uuids
|
||||||
]
|
]
|
||||||
|
|
@ -577,6 +586,7 @@ async def episode_mentions_reranker(driver: AsyncDriver, node_uuids: list[list[s
|
||||||
driver.execute_query(
|
driver.execute_query(
|
||||||
query,
|
query,
|
||||||
node_uuid=uuid,
|
node_uuid=uuid,
|
||||||
|
_database=DEFAULT_DATABASE,
|
||||||
)
|
)
|
||||||
for uuid in sorted_uuids
|
for uuid in sorted_uuids
|
||||||
]
|
]
|
||||||
|
|
|
||||||
|
|
@ -8,8 +8,13 @@ from pydantic import BaseModel
|
||||||
|
|
||||||
from graphiti_core.edges import CommunityEdge
|
from graphiti_core.edges import CommunityEdge
|
||||||
from graphiti_core.embedder import EmbedderClient
|
from graphiti_core.embedder import EmbedderClient
|
||||||
|
from graphiti_core.helpers import DEFAULT_DATABASE
|
||||||
from graphiti_core.llm_client import LLMClient
|
from graphiti_core.llm_client import LLMClient
|
||||||
from graphiti_core.nodes import CommunityNode, EntityNode, get_community_node_from_record
|
from graphiti_core.nodes import (
|
||||||
|
CommunityNode,
|
||||||
|
EntityNode,
|
||||||
|
get_community_node_from_record,
|
||||||
|
)
|
||||||
from graphiti_core.prompts import prompt_library
|
from graphiti_core.prompts import prompt_library
|
||||||
from graphiti_core.utils.maintenance.edge_operations import build_community_edges
|
from graphiti_core.utils.maintenance.edge_operations import build_community_edges
|
||||||
|
|
||||||
|
|
@ -29,11 +34,14 @@ async def get_community_clusters(
|
||||||
community_clusters: list[list[EntityNode]] = []
|
community_clusters: list[list[EntityNode]] = []
|
||||||
|
|
||||||
if group_ids is None:
|
if group_ids is None:
|
||||||
group_id_values, _, _ = await driver.execute_query("""
|
group_id_values, _, _ = await driver.execute_query(
|
||||||
|
"""
|
||||||
MATCH (n:Entity WHERE n.group_id IS NOT NULL)
|
MATCH (n:Entity WHERE n.group_id IS NOT NULL)
|
||||||
RETURN
|
RETURN
|
||||||
collect(DISTINCT n.group_id) AS group_ids
|
collect(DISTINCT n.group_id) AS group_ids
|
||||||
""")
|
""",
|
||||||
|
_database=DEFAULT_DATABASE,
|
||||||
|
)
|
||||||
|
|
||||||
group_ids = group_id_values[0]['group_ids']
|
group_ids = group_id_values[0]['group_ids']
|
||||||
|
|
||||||
|
|
@ -51,6 +59,7 @@ async def get_community_clusters(
|
||||||
""",
|
""",
|
||||||
uuid=node.uuid,
|
uuid=node.uuid,
|
||||||
group_id=group_id,
|
group_id=group_id,
|
||||||
|
_database=DEFAULT_DATABASE,
|
||||||
)
|
)
|
||||||
|
|
||||||
projection[node.uuid] = [
|
projection[node.uuid] = [
|
||||||
|
|
@ -209,10 +218,13 @@ async def build_communities(
|
||||||
|
|
||||||
|
|
||||||
async def remove_communities(driver: AsyncDriver):
|
async def remove_communities(driver: AsyncDriver):
|
||||||
await driver.execute_query("""
|
await driver.execute_query(
|
||||||
|
"""
|
||||||
MATCH (c:Community)
|
MATCH (c:Community)
|
||||||
DETACH DELETE c
|
DETACH DELETE c
|
||||||
""")
|
""",
|
||||||
|
_database=DEFAULT_DATABASE,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
async def determine_entity_community(
|
async def determine_entity_community(
|
||||||
|
|
@ -231,6 +243,7 @@ async def determine_entity_community(
|
||||||
c.summary AS summary
|
c.summary AS summary
|
||||||
""",
|
""",
|
||||||
entity_uuid=entity.uuid,
|
entity_uuid=entity.uuid,
|
||||||
|
_database=DEFAULT_DATABASE,
|
||||||
)
|
)
|
||||||
|
|
||||||
if len(records) > 0:
|
if len(records) > 0:
|
||||||
|
|
@ -249,6 +262,7 @@ async def determine_entity_community(
|
||||||
c.summary AS summary
|
c.summary AS summary
|
||||||
""",
|
""",
|
||||||
entity_uuid=entity.uuid,
|
entity_uuid=entity.uuid,
|
||||||
|
_database=DEFAULT_DATABASE,
|
||||||
)
|
)
|
||||||
|
|
||||||
communities: list[CommunityNode] = [
|
communities: list[CommunityNode] = [
|
||||||
|
|
|
||||||
|
|
@ -21,6 +21,7 @@ from datetime import datetime, timezone
|
||||||
from neo4j import AsyncDriver
|
from neo4j import AsyncDriver
|
||||||
from typing_extensions import LiteralString
|
from typing_extensions import LiteralString
|
||||||
|
|
||||||
|
from graphiti_core.helpers import DEFAULT_DATABASE
|
||||||
from graphiti_core.nodes import EpisodeType, EpisodicNode
|
from graphiti_core.nodes import EpisodeType, EpisodicNode
|
||||||
|
|
||||||
EPISODE_WINDOW_LEN = 3
|
EPISODE_WINDOW_LEN = 3
|
||||||
|
|
@ -30,12 +31,22 @@ logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
async def build_indices_and_constraints(driver: AsyncDriver, delete_existing: bool = False):
|
async def build_indices_and_constraints(driver: AsyncDriver, delete_existing: bool = False):
|
||||||
if delete_existing:
|
if delete_existing:
|
||||||
records, _, _ = await driver.execute_query("""
|
records, _, _ = await driver.execute_query(
|
||||||
|
"""
|
||||||
SHOW INDEXES YIELD name
|
SHOW INDEXES YIELD name
|
||||||
""")
|
""",
|
||||||
|
_database=DEFAULT_DATABASE,
|
||||||
|
)
|
||||||
index_names = [record['name'] for record in records]
|
index_names = [record['name'] for record in records]
|
||||||
await asyncio.gather(
|
await asyncio.gather(
|
||||||
*[driver.execute_query("""DROP INDEX $name""", name=name) for name in index_names]
|
*[
|
||||||
|
driver.execute_query(
|
||||||
|
"""DROP INDEX $name""",
|
||||||
|
name=name,
|
||||||
|
_database=DEFAULT_DATABASE,
|
||||||
|
)
|
||||||
|
for name in index_names
|
||||||
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
range_indices: list[LiteralString] = [
|
range_indices: list[LiteralString] = [
|
||||||
|
|
@ -71,7 +82,15 @@ async def build_indices_and_constraints(driver: AsyncDriver, delete_existing: bo
|
||||||
|
|
||||||
index_queries: list[LiteralString] = range_indices + fulltext_indices
|
index_queries: list[LiteralString] = range_indices + fulltext_indices
|
||||||
|
|
||||||
await asyncio.gather(*[driver.execute_query(query) for query in index_queries])
|
await asyncio.gather(
|
||||||
|
*[
|
||||||
|
driver.execute_query(
|
||||||
|
query,
|
||||||
|
_database=DEFAULT_DATABASE,
|
||||||
|
)
|
||||||
|
for query in index_queries
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
async def clear_data(driver: AsyncDriver):
|
async def clear_data(driver: AsyncDriver):
|
||||||
|
|
@ -121,6 +140,7 @@ async def retrieve_episodes(
|
||||||
reference_time=reference_time,
|
reference_time=reference_time,
|
||||||
num_episodes=last_n,
|
num_episodes=last_n,
|
||||||
group_ids=group_ids,
|
group_ids=group_ids,
|
||||||
|
_database=DEFAULT_DATABASE,
|
||||||
)
|
)
|
||||||
episodes = [
|
episodes = [
|
||||||
EpisodicNode(
|
EpisodicNode(
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,6 @@
|
||||||
[tool.poetry]
|
[tool.poetry]
|
||||||
name = "graphiti-core"
|
name = "graphiti-core"
|
||||||
version = "0.3.13"
|
version = "0.3.14"
|
||||||
description = "A temporal graph building library"
|
description = "A temporal graph building library"
|
||||||
authors = [
|
authors = [
|
||||||
"Paul Paliychuk <paul@getzep.com>",
|
"Paul Paliychuk <paul@getzep.com>",
|
||||||
|
|
|
||||||
|
|
@ -75,16 +75,6 @@ async def test_graphiti_init():
|
||||||
logger = setup_logging()
|
logger = setup_logging()
|
||||||
graphiti = Graphiti(NEO4J_URI, NEO4j_USER, NEO4j_PASSWORD)
|
graphiti = Graphiti(NEO4J_URI, NEO4j_USER, NEO4j_PASSWORD)
|
||||||
|
|
||||||
edges = await graphiti.search(
|
|
||||||
'tania tetlow', center_node_uuid='4bf7ebb3-3a98-46c7-90a6-8e516c487961', group_ids=None
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.info('\nQUERY: Tania Tetlow\n' + format_context([edge.fact for edge in edges]))
|
|
||||||
|
|
||||||
edges = await graphiti.search('issues with higher ed', group_ids=None)
|
|
||||||
|
|
||||||
logger.info('\nQUERY: issues with higher ed\n' + format_context([edge.fact for edge in edges]))
|
|
||||||
|
|
||||||
results = await graphiti._search('new house', COMBINED_HYBRID_SEARCH_RRF, group_ids=None)
|
results = await graphiti._search('new house', COMBINED_HYBRID_SEARCH_RRF, group_ids=None)
|
||||||
pretty_results = {
|
pretty_results = {
|
||||||
'edges': [edge.fact for edge in results.edges],
|
'edges': [edge.fact for edge in results.edges],
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue