Make default DB explicit (#195)
* add default database * update * init tests * update test * bump version * removed unused imports
This commit is contained in:
parent
8b72250f0b
commit
b217d1e51f
13 changed files with 142 additions and 58 deletions
|
|
@ -26,7 +26,12 @@ from pydantic import BaseModel, Field
|
|||
|
||||
from graphiti_core.embedder import EmbedderClient
|
||||
from graphiti_core.errors import EdgeNotFoundError, GroupsEdgesNotFoundError
|
||||
from graphiti_core.helpers import parse_db_date
|
||||
from graphiti_core.helpers import DEFAULT_DATABASE, parse_db_date
|
||||
from graphiti_core.models.edges.edge_db_queries import (
|
||||
COMMUNITY_EDGE_SAVE,
|
||||
ENTITY_EDGE_SAVE,
|
||||
EPISODIC_EDGE_SAVE,
|
||||
)
|
||||
from graphiti_core.nodes import Node
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
@ -49,6 +54,7 @@ class Edge(BaseModel, ABC):
|
|||
DELETE e
|
||||
""",
|
||||
uuid=self.uuid,
|
||||
_database=DEFAULT_DATABASE,
|
||||
)
|
||||
|
||||
logger.debug(f'Deleted Edge: {self.uuid}')
|
||||
|
|
@ -70,17 +76,13 @@ class Edge(BaseModel, ABC):
|
|||
class EpisodicEdge(Edge):
|
||||
async def save(self, driver: AsyncDriver):
|
||||
result = await driver.execute_query(
|
||||
"""
|
||||
MATCH (episode:Episodic {uuid: $episode_uuid})
|
||||
MATCH (node:Entity {uuid: $entity_uuid})
|
||||
MERGE (episode)-[r:MENTIONS {uuid: $uuid}]->(node)
|
||||
SET r = {uuid: $uuid, group_id: $group_id, created_at: $created_at}
|
||||
RETURN r.uuid AS uuid""",
|
||||
EPISODIC_EDGE_SAVE,
|
||||
episode_uuid=self.source_node_uuid,
|
||||
entity_uuid=self.target_node_uuid,
|
||||
uuid=self.uuid,
|
||||
group_id=self.group_id,
|
||||
created_at=self.created_at,
|
||||
_database=DEFAULT_DATABASE,
|
||||
)
|
||||
|
||||
logger.debug(f'Saved edge to neo4j: {self.uuid}')
|
||||
|
|
@ -100,6 +102,7 @@ class EpisodicEdge(Edge):
|
|||
e.created_at AS created_at
|
||||
""",
|
||||
uuid=uuid,
|
||||
_database=DEFAULT_DATABASE,
|
||||
)
|
||||
|
||||
edges = [get_episodic_edge_from_record(record) for record in records]
|
||||
|
|
@ -122,6 +125,7 @@ class EpisodicEdge(Edge):
|
|||
e.created_at AS created_at
|
||||
""",
|
||||
uuids=uuids,
|
||||
_database=DEFAULT_DATABASE,
|
||||
)
|
||||
|
||||
edges = [get_episodic_edge_from_record(record) for record in records]
|
||||
|
|
@ -144,6 +148,7 @@ class EpisodicEdge(Edge):
|
|||
e.created_at AS created_at
|
||||
""",
|
||||
group_ids=group_ids,
|
||||
_database=DEFAULT_DATABASE,
|
||||
)
|
||||
|
||||
edges = [get_episodic_edge_from_record(record) for record in records]
|
||||
|
|
@ -184,14 +189,7 @@ class EntityEdge(Edge):
|
|||
|
||||
async def save(self, driver: AsyncDriver):
|
||||
result = await driver.execute_query(
|
||||
"""
|
||||
MATCH (source:Entity {uuid: $source_uuid})
|
||||
MATCH (target:Entity {uuid: $target_uuid})
|
||||
MERGE (source)-[r:RELATES_TO {uuid: $uuid}]->(target)
|
||||
SET r = {uuid: $uuid, name: $name, group_id: $group_id, fact: $fact, episodes: $episodes,
|
||||
created_at: $created_at, expired_at: $expired_at, valid_at: $valid_at, invalid_at: $invalid_at}
|
||||
WITH r CALL db.create.setRelationshipVectorProperty(r, "fact_embedding", $fact_embedding)
|
||||
RETURN r.uuid AS uuid""",
|
||||
ENTITY_EDGE_SAVE,
|
||||
source_uuid=self.source_node_uuid,
|
||||
target_uuid=self.target_node_uuid,
|
||||
uuid=self.uuid,
|
||||
|
|
@ -204,6 +202,7 @@ class EntityEdge(Edge):
|
|||
expired_at=self.expired_at,
|
||||
valid_at=self.valid_at,
|
||||
invalid_at=self.invalid_at,
|
||||
_database=DEFAULT_DATABASE,
|
||||
)
|
||||
|
||||
logger.debug(f'Saved edge to neo4j: {self.uuid}')
|
||||
|
|
@ -230,6 +229,7 @@ class EntityEdge(Edge):
|
|||
e.invalid_at AS invalid_at
|
||||
""",
|
||||
uuid=uuid,
|
||||
_database=DEFAULT_DATABASE,
|
||||
)
|
||||
|
||||
edges = [get_entity_edge_from_record(record) for record in records]
|
||||
|
|
@ -259,6 +259,7 @@ class EntityEdge(Edge):
|
|||
e.invalid_at AS invalid_at
|
||||
""",
|
||||
uuids=uuids,
|
||||
_database=DEFAULT_DATABASE,
|
||||
)
|
||||
|
||||
edges = [get_entity_edge_from_record(record) for record in records]
|
||||
|
|
@ -288,6 +289,7 @@ class EntityEdge(Edge):
|
|||
e.invalid_at AS invalid_at
|
||||
""",
|
||||
group_ids=group_ids,
|
||||
_database=DEFAULT_DATABASE,
|
||||
)
|
||||
|
||||
edges = [get_entity_edge_from_record(record) for record in records]
|
||||
|
|
@ -300,17 +302,13 @@ class EntityEdge(Edge):
|
|||
class CommunityEdge(Edge):
|
||||
async def save(self, driver: AsyncDriver):
|
||||
result = await driver.execute_query(
|
||||
"""
|
||||
MATCH (community:Community {uuid: $community_uuid})
|
||||
MATCH (node:Entity | Community {uuid: $entity_uuid})
|
||||
MERGE (community)-[r:HAS_MEMBER {uuid: $uuid}]->(node)
|
||||
SET r = {uuid: $uuid, group_id: $group_id, created_at: $created_at}
|
||||
RETURN r.uuid AS uuid""",
|
||||
COMMUNITY_EDGE_SAVE,
|
||||
community_uuid=self.source_node_uuid,
|
||||
entity_uuid=self.target_node_uuid,
|
||||
uuid=self.uuid,
|
||||
group_id=self.group_id,
|
||||
created_at=self.created_at,
|
||||
_database=DEFAULT_DATABASE,
|
||||
)
|
||||
|
||||
logger.debug(f'Saved edge to neo4j: {self.uuid}')
|
||||
|
|
@ -330,6 +328,7 @@ class CommunityEdge(Edge):
|
|||
e.created_at AS created_at
|
||||
""",
|
||||
uuid=uuid,
|
||||
_database=DEFAULT_DATABASE,
|
||||
)
|
||||
|
||||
edges = [get_community_edge_from_record(record) for record in records]
|
||||
|
|
@ -350,6 +349,7 @@ class CommunityEdge(Edge):
|
|||
e.created_at AS created_at
|
||||
""",
|
||||
uuids=uuids,
|
||||
_database=DEFAULT_DATABASE,
|
||||
)
|
||||
|
||||
edges = [get_community_edge_from_record(record) for record in records]
|
||||
|
|
@ -370,6 +370,7 @@ class CommunityEdge(Edge):
|
|||
e.created_at AS created_at
|
||||
""",
|
||||
group_ids=group_ids,
|
||||
_database=DEFAULT_DATABASE,
|
||||
)
|
||||
|
||||
edges = [get_community_edge_from_record(record) for record in records]
|
||||
|
|
|
|||
|
|
@ -14,11 +14,14 @@ See the License for the specific language governing permissions and
|
|||
limitations under the License.
|
||||
"""
|
||||
|
||||
import os
|
||||
from datetime import datetime
|
||||
|
||||
import numpy as np
|
||||
from neo4j import time as neo4j_time
|
||||
|
||||
DEFAULT_DATABASE = os.getenv('DEFAULT_DATABASE', None)
|
||||
|
||||
|
||||
def parse_db_date(neo_date: neo4j_time.DateTime | None) -> datetime | None:
|
||||
return neo_date.to_native() if neo_date else None
|
||||
|
|
|
|||
0
graphiti_core/models/__init__.py
Normal file
0
graphiti_core/models/__init__.py
Normal file
0
graphiti_core/models/edges/__init__.py
Normal file
0
graphiti_core/models/edges/__init__.py
Normal file
22
graphiti_core/models/edges/edge_db_queries.py
Normal file
22
graphiti_core/models/edges/edge_db_queries.py
Normal file
|
|
@ -0,0 +1,22 @@
|
|||
EPISODIC_EDGE_SAVE = """
|
||||
MATCH (episode:Episodic {uuid: $episode_uuid})
|
||||
MATCH (node:Entity {uuid: $entity_uuid})
|
||||
MERGE (episode)-[r:MENTIONS {uuid: $uuid}]->(node)
|
||||
SET r = {uuid: $uuid, group_id: $group_id, created_at: $created_at}
|
||||
RETURN r.uuid AS uuid"""
|
||||
|
||||
ENTITY_EDGE_SAVE = """
|
||||
MATCH (source:Entity {uuid: $source_uuid})
|
||||
MATCH (target:Entity {uuid: $target_uuid})
|
||||
MERGE (source)-[r:RELATES_TO {uuid: $uuid}]->(target)
|
||||
SET r = {uuid: $uuid, name: $name, group_id: $group_id, fact: $fact, episodes: $episodes,
|
||||
created_at: $created_at, expired_at: $expired_at, valid_at: $valid_at, invalid_at: $invalid_at}
|
||||
WITH r CALL db.create.setRelationshipVectorProperty(r, "fact_embedding", $fact_embedding)
|
||||
RETURN r.uuid AS uuid"""
|
||||
|
||||
COMMUNITY_EDGE_SAVE = """
|
||||
MATCH (community:Community {uuid: $community_uuid})
|
||||
MATCH (node:Entity | Community {uuid: $entity_uuid})
|
||||
MERGE (community)-[r:HAS_MEMBER {uuid: $uuid}]->(node)
|
||||
SET r = {uuid: $uuid, group_id: $group_id, created_at: $created_at}
|
||||
RETURN r.uuid AS uuid"""
|
||||
0
graphiti_core/models/nodes/__init__.py
Normal file
0
graphiti_core/models/nodes/__init__.py
Normal file
17
graphiti_core/models/nodes/node_db_queries.py
Normal file
17
graphiti_core/models/nodes/node_db_queries.py
Normal file
|
|
@ -0,0 +1,17 @@
|
|||
EPISODIC_NODE_SAVE = """
|
||||
MERGE (n:Episodic {uuid: $uuid})
|
||||
SET n = {uuid: $uuid, name: $name, group_id: $group_id, source_description: $source_description, source: $source, content: $content,
|
||||
entity_edges: $entity_edges, created_at: $created_at, valid_at: $valid_at}
|
||||
RETURN n.uuid AS uuid"""
|
||||
|
||||
ENTITY_NODE_SAVE = """
|
||||
MERGE (n:Entity {uuid: $uuid})
|
||||
SET n = {uuid: $uuid, name: $name, group_id: $group_id, summary: $summary, created_at: $created_at}
|
||||
WITH n CALL db.create.setNodeVectorProperty(n, "name_embedding", $name_embedding)
|
||||
RETURN n.uuid AS uuid"""
|
||||
|
||||
COMMUNITY_NODE_SAVE = """
|
||||
MERGE (n:Community {uuid: $uuid})
|
||||
SET n = {uuid: $uuid, name: $name, group_id: $group_id, summary: $summary, created_at: $created_at}
|
||||
WITH n CALL db.create.setNodeVectorProperty(n, "name_embedding", $name_embedding)
|
||||
RETURN n.uuid AS uuid"""
|
||||
|
|
@ -27,6 +27,12 @@ from pydantic import BaseModel, Field
|
|||
|
||||
from graphiti_core.embedder import EmbedderClient
|
||||
from graphiti_core.errors import NodeNotFoundError
|
||||
from graphiti_core.helpers import DEFAULT_DATABASE
|
||||
from graphiti_core.models.nodes.node_db_queries import (
|
||||
COMMUNITY_NODE_SAVE,
|
||||
ENTITY_NODE_SAVE,
|
||||
EPISODIC_NODE_SAVE,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -84,6 +90,7 @@ class Node(BaseModel, ABC):
|
|||
DETACH DELETE n
|
||||
""",
|
||||
uuid=self.uuid,
|
||||
_database=DEFAULT_DATABASE,
|
||||
)
|
||||
|
||||
logger.debug(f'Deleted Node: {self.uuid}')
|
||||
|
|
@ -119,11 +126,7 @@ class EpisodicNode(Node):
|
|||
|
||||
async def save(self, driver: AsyncDriver):
|
||||
result = await driver.execute_query(
|
||||
"""
|
||||
MERGE (n:Episodic {uuid: $uuid})
|
||||
SET n = {uuid: $uuid, name: $name, group_id: $group_id, source_description: $source_description, source: $source, content: $content,
|
||||
entity_edges: $entity_edges, created_at: $created_at, valid_at: $valid_at}
|
||||
RETURN n.uuid AS uuid""",
|
||||
EPISODIC_NODE_SAVE,
|
||||
uuid=self.uuid,
|
||||
name=self.name,
|
||||
group_id=self.group_id,
|
||||
|
|
@ -133,6 +136,7 @@ class EpisodicNode(Node):
|
|||
created_at=self.created_at,
|
||||
valid_at=self.valid_at,
|
||||
source=self.source.value,
|
||||
_database=DEFAULT_DATABASE,
|
||||
)
|
||||
|
||||
logger.debug(f'Saved Node to neo4j: {self.uuid}')
|
||||
|
|
@ -154,6 +158,7 @@ class EpisodicNode(Node):
|
|||
e.source AS source
|
||||
""",
|
||||
uuid=uuid,
|
||||
_database=DEFAULT_DATABASE,
|
||||
)
|
||||
|
||||
episodes = [get_episodic_node_from_record(record) for record in records]
|
||||
|
|
@ -179,6 +184,7 @@ class EpisodicNode(Node):
|
|||
e.source AS source
|
||||
""",
|
||||
uuids=uuids,
|
||||
_database=DEFAULT_DATABASE,
|
||||
)
|
||||
|
||||
episodes = [get_episodic_node_from_record(record) for record in records]
|
||||
|
|
@ -201,6 +207,7 @@ class EpisodicNode(Node):
|
|||
e.source AS source
|
||||
""",
|
||||
group_ids=group_ids,
|
||||
_database=DEFAULT_DATABASE,
|
||||
)
|
||||
|
||||
episodes = [get_episodic_node_from_record(record) for record in records]
|
||||
|
|
@ -223,17 +230,14 @@ class EntityNode(Node):
|
|||
|
||||
async def save(self, driver: AsyncDriver):
|
||||
result = await driver.execute_query(
|
||||
"""
|
||||
MERGE (n:Entity {uuid: $uuid})
|
||||
SET n = {uuid: $uuid, name: $name, group_id: $group_id, summary: $summary, created_at: $created_at}
|
||||
WITH n CALL db.create.setNodeVectorProperty(n, "name_embedding", $name_embedding)
|
||||
RETURN n.uuid AS uuid""",
|
||||
ENTITY_NODE_SAVE,
|
||||
uuid=self.uuid,
|
||||
name=self.name,
|
||||
group_id=self.group_id,
|
||||
summary=self.summary,
|
||||
name_embedding=self.name_embedding,
|
||||
created_at=self.created_at,
|
||||
_database=DEFAULT_DATABASE,
|
||||
)
|
||||
|
||||
logger.debug(f'Saved Node to neo4j: {self.uuid}')
|
||||
|
|
@ -254,6 +258,7 @@ class EntityNode(Node):
|
|||
n.summary AS summary
|
||||
""",
|
||||
uuid=uuid,
|
||||
_database=DEFAULT_DATABASE,
|
||||
)
|
||||
|
||||
nodes = [get_entity_node_from_record(record) for record in records]
|
||||
|
|
@ -277,6 +282,7 @@ class EntityNode(Node):
|
|||
n.summary AS summary
|
||||
""",
|
||||
uuids=uuids,
|
||||
_database=DEFAULT_DATABASE,
|
||||
)
|
||||
|
||||
nodes = [get_entity_node_from_record(record) for record in records]
|
||||
|
|
@ -297,6 +303,7 @@ class EntityNode(Node):
|
|||
n.summary AS summary
|
||||
""",
|
||||
group_ids=group_ids,
|
||||
_database=DEFAULT_DATABASE,
|
||||
)
|
||||
|
||||
nodes = [get_entity_node_from_record(record) for record in records]
|
||||
|
|
@ -310,17 +317,14 @@ class CommunityNode(Node):
|
|||
|
||||
async def save(self, driver: AsyncDriver):
|
||||
result = await driver.execute_query(
|
||||
"""
|
||||
MERGE (n:Community {uuid: $uuid})
|
||||
SET n = {uuid: $uuid, name: $name, group_id: $group_id, summary: $summary, created_at: $created_at}
|
||||
WITH n CALL db.create.setNodeVectorProperty(n, "name_embedding", $name_embedding)
|
||||
RETURN n.uuid AS uuid""",
|
||||
COMMUNITY_NODE_SAVE,
|
||||
uuid=self.uuid,
|
||||
name=self.name,
|
||||
group_id=self.group_id,
|
||||
summary=self.summary,
|
||||
name_embedding=self.name_embedding,
|
||||
created_at=self.created_at,
|
||||
_database=DEFAULT_DATABASE,
|
||||
)
|
||||
|
||||
logger.debug(f'Saved Node to neo4j: {self.uuid}')
|
||||
|
|
@ -350,6 +354,7 @@ class CommunityNode(Node):
|
|||
n.summary AS summary
|
||||
""",
|
||||
uuid=uuid,
|
||||
_database=DEFAULT_DATABASE,
|
||||
)
|
||||
|
||||
nodes = [get_community_node_from_record(record) for record in records]
|
||||
|
|
@ -373,6 +378,7 @@ class CommunityNode(Node):
|
|||
n.summary AS summary
|
||||
""",
|
||||
uuids=uuids,
|
||||
_database=DEFAULT_DATABASE,
|
||||
)
|
||||
|
||||
communities = [get_community_node_from_record(record) for record in records]
|
||||
|
|
@ -393,6 +399,7 @@ class CommunityNode(Node):
|
|||
n.summary AS summary
|
||||
""",
|
||||
group_ids=group_ids,
|
||||
_database=DEFAULT_DATABASE,
|
||||
)
|
||||
|
||||
communities = [get_community_node_from_record(record) for record in records]
|
||||
|
|
|
|||
|
|
@ -23,7 +23,7 @@ import numpy as np
|
|||
from neo4j import AsyncDriver, Query
|
||||
|
||||
from graphiti_core.edges import EntityEdge, get_entity_edge_from_record
|
||||
from graphiti_core.helpers import lucene_sanitize, normalize_l2
|
||||
from graphiti_core.helpers import DEFAULT_DATABASE, lucene_sanitize, normalize_l2
|
||||
from graphiti_core.nodes import (
|
||||
CommunityNode,
|
||||
EntityNode,
|
||||
|
|
@ -37,7 +37,7 @@ logger = logging.getLogger(__name__)
|
|||
RELEVANT_SCHEMA_LIMIT = 3
|
||||
DEFAULT_MIN_SCORE = 0.6
|
||||
DEFAULT_MMR_LAMBDA = 0.5
|
||||
MAX_QUERY_LENGTH = 512
|
||||
MAX_QUERY_LENGTH = 128
|
||||
|
||||
|
||||
def fulltext_query(query: str, group_ids: list[str] | None = None):
|
||||
|
|
@ -91,6 +91,7 @@ async def get_mentioned_nodes(
|
|||
n.summary AS summary
|
||||
""",
|
||||
uuids=episode_uuids,
|
||||
_database=DEFAULT_DATABASE,
|
||||
)
|
||||
|
||||
nodes = [get_entity_node_from_record(record) for record in records]
|
||||
|
|
@ -114,6 +115,7 @@ async def get_communities_by_nodes(
|
|||
c.summary AS summary
|
||||
""",
|
||||
uuids=node_uuids,
|
||||
_database=DEFAULT_DATABASE,
|
||||
)
|
||||
|
||||
communities = [get_community_node_from_record(record) for record in records]
|
||||
|
|
@ -161,6 +163,7 @@ async def edge_fulltext_search(
|
|||
target_uuid=target_node_uuid,
|
||||
group_ids=group_ids,
|
||||
limit=limit,
|
||||
_database=DEFAULT_DATABASE,
|
||||
)
|
||||
|
||||
edges = [get_entity_edge_from_record(record) for record in records]
|
||||
|
|
@ -211,6 +214,7 @@ async def edge_similarity_search(
|
|||
group_ids=group_ids,
|
||||
limit=limit,
|
||||
min_score=min_score,
|
||||
_database=DEFAULT_DATABASE,
|
||||
)
|
||||
|
||||
edges = [get_entity_edge_from_record(record) for record in records]
|
||||
|
|
@ -246,6 +250,7 @@ async def node_fulltext_search(
|
|||
query=fuzzy_query,
|
||||
group_ids=group_ids,
|
||||
limit=limit,
|
||||
_database=DEFAULT_DATABASE,
|
||||
)
|
||||
nodes = [get_entity_node_from_record(record) for record in records]
|
||||
|
||||
|
|
@ -281,6 +286,7 @@ async def node_similarity_search(
|
|||
group_ids=group_ids,
|
||||
limit=limit,
|
||||
min_score=min_score,
|
||||
_database=DEFAULT_DATABASE,
|
||||
)
|
||||
nodes = [get_entity_node_from_record(record) for record in records]
|
||||
|
||||
|
|
@ -315,6 +321,7 @@ async def community_fulltext_search(
|
|||
query=fuzzy_query,
|
||||
group_ids=group_ids,
|
||||
limit=limit,
|
||||
_database=DEFAULT_DATABASE,
|
||||
)
|
||||
communities = [get_community_node_from_record(record) for record in records]
|
||||
|
||||
|
|
@ -350,6 +357,7 @@ async def community_similarity_search(
|
|||
group_ids=group_ids,
|
||||
limit=limit,
|
||||
min_score=min_score,
|
||||
_database=DEFAULT_DATABASE,
|
||||
)
|
||||
communities = [get_community_node_from_record(record) for record in records]
|
||||
|
||||
|
|
@ -541,6 +549,7 @@ async def node_distance_reranker(
|
|||
query,
|
||||
node_uuid=uuid,
|
||||
center_uuid=center_node_uuid,
|
||||
_database=DEFAULT_DATABASE,
|
||||
)
|
||||
for uuid in filtered_uuids
|
||||
]
|
||||
|
|
@ -577,6 +586,7 @@ async def episode_mentions_reranker(driver: AsyncDriver, node_uuids: list[list[s
|
|||
driver.execute_query(
|
||||
query,
|
||||
node_uuid=uuid,
|
||||
_database=DEFAULT_DATABASE,
|
||||
)
|
||||
for uuid in sorted_uuids
|
||||
]
|
||||
|
|
|
|||
|
|
@ -8,8 +8,13 @@ from pydantic import BaseModel
|
|||
|
||||
from graphiti_core.edges import CommunityEdge
|
||||
from graphiti_core.embedder import EmbedderClient
|
||||
from graphiti_core.helpers import DEFAULT_DATABASE
|
||||
from graphiti_core.llm_client import LLMClient
|
||||
from graphiti_core.nodes import CommunityNode, EntityNode, get_community_node_from_record
|
||||
from graphiti_core.nodes import (
|
||||
CommunityNode,
|
||||
EntityNode,
|
||||
get_community_node_from_record,
|
||||
)
|
||||
from graphiti_core.prompts import prompt_library
|
||||
from graphiti_core.utils.maintenance.edge_operations import build_community_edges
|
||||
|
||||
|
|
@ -29,11 +34,14 @@ async def get_community_clusters(
|
|||
community_clusters: list[list[EntityNode]] = []
|
||||
|
||||
if group_ids is None:
|
||||
group_id_values, _, _ = await driver.execute_query("""
|
||||
group_id_values, _, _ = await driver.execute_query(
|
||||
"""
|
||||
MATCH (n:Entity WHERE n.group_id IS NOT NULL)
|
||||
RETURN
|
||||
collect(DISTINCT n.group_id) AS group_ids
|
||||
""")
|
||||
""",
|
||||
_database=DEFAULT_DATABASE,
|
||||
)
|
||||
|
||||
group_ids = group_id_values[0]['group_ids']
|
||||
|
||||
|
|
@ -51,6 +59,7 @@ async def get_community_clusters(
|
|||
""",
|
||||
uuid=node.uuid,
|
||||
group_id=group_id,
|
||||
_database=DEFAULT_DATABASE,
|
||||
)
|
||||
|
||||
projection[node.uuid] = [
|
||||
|
|
@ -209,10 +218,13 @@ async def build_communities(
|
|||
|
||||
|
||||
async def remove_communities(driver: AsyncDriver):
|
||||
await driver.execute_query("""
|
||||
await driver.execute_query(
|
||||
"""
|
||||
MATCH (c:Community)
|
||||
DETACH DELETE c
|
||||
""")
|
||||
""",
|
||||
_database=DEFAULT_DATABASE,
|
||||
)
|
||||
|
||||
|
||||
async def determine_entity_community(
|
||||
|
|
@ -231,6 +243,7 @@ async def determine_entity_community(
|
|||
c.summary AS summary
|
||||
""",
|
||||
entity_uuid=entity.uuid,
|
||||
_database=DEFAULT_DATABASE,
|
||||
)
|
||||
|
||||
if len(records) > 0:
|
||||
|
|
@ -249,6 +262,7 @@ async def determine_entity_community(
|
|||
c.summary AS summary
|
||||
""",
|
||||
entity_uuid=entity.uuid,
|
||||
_database=DEFAULT_DATABASE,
|
||||
)
|
||||
|
||||
communities: list[CommunityNode] = [
|
||||
|
|
|
|||
|
|
@ -21,6 +21,7 @@ from datetime import datetime, timezone
|
|||
from neo4j import AsyncDriver
|
||||
from typing_extensions import LiteralString
|
||||
|
||||
from graphiti_core.helpers import DEFAULT_DATABASE
|
||||
from graphiti_core.nodes import EpisodeType, EpisodicNode
|
||||
|
||||
EPISODE_WINDOW_LEN = 3
|
||||
|
|
@ -30,12 +31,22 @@ logger = logging.getLogger(__name__)
|
|||
|
||||
async def build_indices_and_constraints(driver: AsyncDriver, delete_existing: bool = False):
|
||||
if delete_existing:
|
||||
records, _, _ = await driver.execute_query("""
|
||||
records, _, _ = await driver.execute_query(
|
||||
"""
|
||||
SHOW INDEXES YIELD name
|
||||
""")
|
||||
""",
|
||||
_database=DEFAULT_DATABASE,
|
||||
)
|
||||
index_names = [record['name'] for record in records]
|
||||
await asyncio.gather(
|
||||
*[driver.execute_query("""DROP INDEX $name""", name=name) for name in index_names]
|
||||
*[
|
||||
driver.execute_query(
|
||||
"""DROP INDEX $name""",
|
||||
name=name,
|
||||
_database=DEFAULT_DATABASE,
|
||||
)
|
||||
for name in index_names
|
||||
]
|
||||
)
|
||||
|
||||
range_indices: list[LiteralString] = [
|
||||
|
|
@ -71,7 +82,15 @@ async def build_indices_and_constraints(driver: AsyncDriver, delete_existing: bo
|
|||
|
||||
index_queries: list[LiteralString] = range_indices + fulltext_indices
|
||||
|
||||
await asyncio.gather(*[driver.execute_query(query) for query in index_queries])
|
||||
await asyncio.gather(
|
||||
*[
|
||||
driver.execute_query(
|
||||
query,
|
||||
_database=DEFAULT_DATABASE,
|
||||
)
|
||||
for query in index_queries
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
async def clear_data(driver: AsyncDriver):
|
||||
|
|
@ -121,6 +140,7 @@ async def retrieve_episodes(
|
|||
reference_time=reference_time,
|
||||
num_episodes=last_n,
|
||||
group_ids=group_ids,
|
||||
_database=DEFAULT_DATABASE,
|
||||
)
|
||||
episodes = [
|
||||
EpisodicNode(
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
[tool.poetry]
|
||||
name = "graphiti-core"
|
||||
version = "0.3.13"
|
||||
version = "0.3.14"
|
||||
description = "A temporal graph building library"
|
||||
authors = [
|
||||
"Paul Paliychuk <paul@getzep.com>",
|
||||
|
|
|
|||
|
|
@ -75,16 +75,6 @@ async def test_graphiti_init():
|
|||
logger = setup_logging()
|
||||
graphiti = Graphiti(NEO4J_URI, NEO4j_USER, NEO4j_PASSWORD)
|
||||
|
||||
edges = await graphiti.search(
|
||||
'tania tetlow', center_node_uuid='4bf7ebb3-3a98-46c7-90a6-8e516c487961', group_ids=None
|
||||
)
|
||||
|
||||
logger.info('\nQUERY: Tania Tetlow\n' + format_context([edge.fact for edge in edges]))
|
||||
|
||||
edges = await graphiti.search('issues with higher ed', group_ids=None)
|
||||
|
||||
logger.info('\nQUERY: issues with higher ed\n' + format_context([edge.fact for edge in edges]))
|
||||
|
||||
results = await graphiti._search('new house', COMBINED_HYBRID_SEARCH_RRF, group_ids=None)
|
||||
pretty_results = {
|
||||
'edges': [edge.fact for edge in results.edges],
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue