Merge branch 'main' into fix-bfs-max-depth-ignored

commit ecf9a75fbf

25 changed files with 1339 additions and 1068 deletions

.github/workflows/unit_tests.yml (vendored, 27 lines changed)
@@ -20,6 +20,15 @@ jobs:
        ports:
          - 6379:6379
        options: --health-cmd "redis-cli ping" --health-interval 10s --health-timeout 5s --health-retries 5
      neo4j:
        image: neo4j:5.26-community
        ports:
          - 7687:7687
          - 7474:7474
        env:
          NEO4J_AUTH: neo4j/testpass
          NEO4J_PLUGINS: '["apoc"]'
        options: --health-cmd "cypher-shell -u neo4j -p testpass 'RETURN 1'" --health-interval 10s --health-timeout 5s --health-retries 10
    steps:
      - uses: actions/checkout@v4
      - name: Set up Python

@@ -37,15 +46,33 @@ jobs:
      - name: Run non-integration tests
        env:
          PYTHONPATH: ${{ github.workspace }}
          NEO4J_URI: bolt://localhost:7687
          NEO4J_USER: neo4j
          NEO4J_PASSWORD: testpass
        run: |
          uv run pytest -m "not integration"
      - name: Wait for FalkorDB
        run: |
          timeout 60 bash -c 'until redis-cli -h localhost -p 6379 ping; do sleep 1; done'
      - name: Wait for Neo4j
        run: |
          timeout 60 bash -c 'until wget -O /dev/null http://localhost:7474 >/dev/null 2>&1; do sleep 1; done'
      - name: Run FalkorDB integration tests
        env:
          PYTHONPATH: ${{ github.workspace }}
          FALKORDB_HOST: localhost
          FALKORDB_PORT: 6379
          DISABLE_NEO4J: 1
        run: |
          uv run pytest tests/driver/test_falkordb_driver.py
      - name: Run Neo4j integration tests
        env:
          PYTHONPATH: ${{ github.workspace }}
          NEO4J_URI: bolt://localhost:7687
          NEO4J_USER: neo4j
          NEO4J_PASSWORD: testpass
          FALKORDB_HOST: localhost
          FALKORDB_PORT: 6379
          OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
        run: |
          uv run pytest tests/test_*_int.py -k "neo4j"
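The job above selects tests with pytest markers and per-database environment variables. As a rough, hypothetical illustration (test names and an `integration` marker assumed to be registered in the project's pytest config; this is not code from the repository), tests can honor those variables like this:

import os

import pytest

HAS_FALKORDB = bool(os.getenv('FALKORDB_HOST'))
NEO4J_DISABLED = os.getenv('DISABLE_NEO4J') == '1'


@pytest.mark.integration
@pytest.mark.skipif(not HAS_FALKORDB, reason='FALKORDB_HOST is not set')
def test_falkordb_smoke():
    # Placeholder check; a real test would open a FalkorDB connection here.
    assert os.getenv('FALKORDB_PORT', '6379') == '6379'


@pytest.mark.integration
@pytest.mark.skipif(NEO4J_DISABLED, reason='DISABLE_NEO4J=1 for this step')
def test_neo4j_smoke():
    assert os.getenv('NEO4J_URI', '').startswith('bolt://')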
@@ -18,11 +18,17 @@ import copy
import logging
from abc import ABC, abstractmethod
from collections.abc import Coroutine
from enum import Enum
from typing import Any

logger = logging.getLogger(__name__)


class GraphProvider(Enum):
    NEO4J = 'neo4j'
    FALKORDB = 'falkordb'


class GraphDriverSession(ABC):
    async def __aenter__(self):
        return self

@@ -46,7 +52,7 @@ class GraphDriverSession(ABC):


class GraphDriver(ABC):
-    provider: str
+    provider: GraphProvider
    fulltext_syntax: str = (
        ''  # Neo4j (default) syntax does not require a prefix for fulltext queries
    )
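The hunk above replaces the free-form provider string with a GraphProvider enum on the driver base class. A minimal usage sketch, assuming graphiti-core at this revision is installed:

from graphiti_core.driver.driver import GraphProvider

# Enum members replace brittle string comparisons such as provider == 'falkordb'.
provider = GraphProvider.FALKORDB
assert provider.value == 'falkordb'
assert provider is GraphProvider('falkordb')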
@@ -32,7 +32,7 @@ else:
        'Install it with: pip install graphiti-core[falkordb]'
    ) from None

-from graphiti_core.driver.driver import GraphDriver, GraphDriverSession
+from graphiti_core.driver.driver import GraphDriver, GraphDriverSession, GraphProvider

logger = logging.getLogger(__name__)

@@ -71,7 +71,7 @@ class FalkorDriverSession(GraphDriverSession):


class FalkorDriver(GraphDriver):
-    provider: str = 'falkordb'
+    provider = GraphProvider.FALKORDB

    def __init__(
        self,

@@ -119,7 +119,7 @@ class FalkorDriver(GraphDriver):
                # check if index already exists
                logger.info(f'Index already exists: {e}')
                return None
-            logger.error(f'Error executing FalkorDB query: {e}')
+            logger.error(f'Error executing FalkorDB query: {e}\n{cypher_query_}\n{params}')
            raise

        # Convert the result header to a list of strings
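The FalkorDB change above enriches the error log with the failing query text and parameters. A generic sketch of that pattern (hypothetical helper, not the driver's actual method; `session` is assumed to expose an async `run` call):

import logging

logger = logging.getLogger(__name__)


async def run_query(session, cypher_query_: str, **params):
    try:
        return await session.run(cypher_query_, **params)
    except Exception as e:
        # Log the query and its parameters alongside the error, as in the diff,
        # so a failure can be reproduced from the logs alone.
        logger.error(f'Error executing FalkorDB query: {e}\n{cypher_query_}\n{params}')
        raise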
@@ -21,13 +21,13 @@ from typing import Any
from neo4j import AsyncGraphDatabase, EagerResult
from typing_extensions import LiteralString

-from graphiti_core.driver.driver import GraphDriver, GraphDriverSession
+from graphiti_core.driver.driver import GraphDriver, GraphDriverSession, GraphProvider

logger = logging.getLogger(__name__)


class Neo4jDriver(GraphDriver):
-    provider: str = 'neo4j'
+    provider = GraphProvider.NEO4J

    def __init__(self, uri: str, user: str | None, password: str | None, database: str = 'neo4j'):
        super().__init__()

@@ -45,7 +45,11 @@ class Neo4jDriver(GraphDriver):
            params = {}
        params.setdefault('database_', self._database)

-        result = await self.client.execute_query(cypher_query_, parameters_=params, **kwargs)
+        try:
+            result = await self.client.execute_query(cypher_query_, parameters_=params, **kwargs)
+        except Exception as e:
+            logger.error(f'Error executing Neo4j query: {e}\n{cypher_query_}\n{params}')
+            raise

        return result
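Both drivers switch from a string-typed class attribute to assigning an enum member, while the abstract base keeps a bare `provider: GraphProvider` annotation. The distinction matters in Python: a bare annotation creates no runtime attribute, whereas an assignment does, which is what lets callers dispatch on `driver.provider`. A small standalone illustration:

class Annotated:
    provider: str  # annotation only; no class attribute is created


class Assigned:
    provider = 'neo4j'  # real class attribute


assert not hasattr(Annotated, 'provider')
assert Assigned.provider == 'neo4j'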
@ -29,29 +29,17 @@ from graphiti_core.embedder import EmbedderClient
|
|||
from graphiti_core.errors import EdgeNotFoundError, GroupsEdgesNotFoundError
|
||||
from graphiti_core.helpers import parse_db_date
|
||||
from graphiti_core.models.edges.edge_db_queries import (
|
||||
COMMUNITY_EDGE_SAVE,
|
||||
ENTITY_EDGE_SAVE,
|
||||
COMMUNITY_EDGE_RETURN,
|
||||
ENTITY_EDGE_RETURN,
|
||||
EPISODIC_EDGE_RETURN,
|
||||
EPISODIC_EDGE_SAVE,
|
||||
get_community_edge_save_query,
|
||||
get_entity_edge_save_query,
|
||||
)
|
||||
from graphiti_core.nodes import Node
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
ENTITY_EDGE_RETURN: LiteralString = """
|
||||
RETURN
|
||||
e.uuid AS uuid,
|
||||
startNode(e).uuid AS source_node_uuid,
|
||||
endNode(e).uuid AS target_node_uuid,
|
||||
e.created_at AS created_at,
|
||||
e.name AS name,
|
||||
e.group_id AS group_id,
|
||||
e.fact AS fact,
|
||||
e.episodes AS episodes,
|
||||
e.expired_at AS expired_at,
|
||||
e.valid_at AS valid_at,
|
||||
e.invalid_at AS invalid_at,
|
||||
properties(e) AS attributes"""
|
||||
|
||||
|
||||
class Edge(BaseModel, ABC):
|
||||
uuid: str = Field(default_factory=lambda: str(uuid4()))
|
||||
|
|
@ -66,9 +54,9 @@ class Edge(BaseModel, ABC):
|
|||
async def delete(self, driver: GraphDriver):
|
||||
result = await driver.execute_query(
|
||||
"""
|
||||
MATCH (n)-[e:MENTIONS|RELATES_TO|HAS_MEMBER {uuid: $uuid}]->(m)
|
||||
DELETE e
|
||||
""",
|
||||
MATCH (n)-[e:MENTIONS|RELATES_TO|HAS_MEMBER {uuid: $uuid}]->(m)
|
||||
DELETE e
|
||||
""",
|
||||
uuid=self.uuid,
|
||||
)
|
||||
|
||||
|
|
@ -107,14 +95,10 @@ class EpisodicEdge(Edge):
|
|||
async def get_by_uuid(cls, driver: GraphDriver, uuid: str):
|
||||
records, _, _ = await driver.execute_query(
|
||||
"""
|
||||
MATCH (n:Episodic)-[e:MENTIONS {uuid: $uuid}]->(m:Entity)
|
||||
RETURN
|
||||
e.uuid As uuid,
|
||||
e.group_id AS group_id,
|
||||
n.uuid AS source_node_uuid,
|
||||
m.uuid AS target_node_uuid,
|
||||
e.created_at AS created_at
|
||||
""",
|
||||
MATCH (n:Episodic)-[e:MENTIONS {uuid: $uuid}]->(m:Entity)
|
||||
RETURN
|
||||
"""
|
||||
+ EPISODIC_EDGE_RETURN,
|
||||
uuid=uuid,
|
||||
routing_='r',
|
||||
)
|
||||
|
|
@ -129,15 +113,11 @@ class EpisodicEdge(Edge):
|
|||
async def get_by_uuids(cls, driver: GraphDriver, uuids: list[str]):
|
||||
records, _, _ = await driver.execute_query(
|
||||
"""
|
||||
MATCH (n:Episodic)-[e:MENTIONS]->(m:Entity)
|
||||
WHERE e.uuid IN $uuids
|
||||
RETURN
|
||||
e.uuid As uuid,
|
||||
e.group_id AS group_id,
|
||||
n.uuid AS source_node_uuid,
|
||||
m.uuid AS target_node_uuid,
|
||||
e.created_at AS created_at
|
||||
""",
|
||||
MATCH (n:Episodic)-[e:MENTIONS]->(m:Entity)
|
||||
WHERE e.uuid IN $uuids
|
||||
RETURN
|
||||
"""
|
||||
+ EPISODIC_EDGE_RETURN,
|
||||
uuids=uuids,
|
||||
routing_='r',
|
||||
)
|
||||
|
|
@ -161,19 +141,17 @@ class EpisodicEdge(Edge):
|
|||
|
||||
records, _, _ = await driver.execute_query(
|
||||
"""
|
||||
MATCH (n:Episodic)-[e:MENTIONS]->(m:Entity)
|
||||
WHERE e.group_id IN $group_ids
|
||||
"""
|
||||
MATCH (n:Episodic)-[e:MENTIONS]->(m:Entity)
|
||||
WHERE e.group_id IN $group_ids
|
||||
"""
|
||||
+ cursor_query
|
||||
+ """
|
||||
RETURN
|
||||
e.uuid As uuid,
|
||||
e.group_id AS group_id,
|
||||
n.uuid AS source_node_uuid,
|
||||
m.uuid AS target_node_uuid,
|
||||
e.created_at AS created_at
|
||||
ORDER BY e.uuid DESC
|
||||
"""
|
||||
RETURN
|
||||
"""
|
||||
+ EPISODIC_EDGE_RETURN
|
||||
+ """
|
||||
ORDER BY e.uuid DESC
|
||||
"""
|
||||
+ limit_query,
|
||||
group_ids=group_ids,
|
||||
uuid=uuid_cursor,
|
||||
|
|
@ -221,11 +199,14 @@ class EntityEdge(Edge):
|
|||
return self.fact_embedding
|
||||
|
||||
async def load_fact_embedding(self, driver: GraphDriver):
|
||||
query: LiteralString = """
|
||||
records, _, _ = await driver.execute_query(
|
||||
"""
|
||||
MATCH (n:Entity)-[e:RELATES_TO {uuid: $uuid}]->(m:Entity)
|
||||
RETURN e.fact_embedding AS fact_embedding
|
||||
"""
|
||||
records, _, _ = await driver.execute_query(query, uuid=self.uuid, routing_='r')
|
||||
""",
|
||||
uuid=self.uuid,
|
||||
routing_='r',
|
||||
)
|
||||
|
||||
if len(records) == 0:
|
||||
raise EdgeNotFoundError(self.uuid)
|
||||
|
|
@ -251,7 +232,7 @@ class EntityEdge(Edge):
|
|||
edge_data.update(self.attributes or {})
|
||||
|
||||
result = await driver.execute_query(
|
||||
ENTITY_EDGE_SAVE,
|
||||
get_entity_edge_save_query(driver.provider),
|
||||
edge_data=edge_data,
|
||||
)
|
||||
|
||||
|
|
@ -263,8 +244,9 @@ class EntityEdge(Edge):
|
|||
async def get_by_uuid(cls, driver: GraphDriver, uuid: str):
|
||||
records, _, _ = await driver.execute_query(
|
||||
"""
|
||||
MATCH (n:Entity)-[e:RELATES_TO {uuid: $uuid}]->(m:Entity)
|
||||
"""
|
||||
MATCH (n:Entity)-[e:RELATES_TO {uuid: $uuid}]->(m:Entity)
|
||||
RETURN
|
||||
"""
|
||||
+ ENTITY_EDGE_RETURN,
|
||||
uuid=uuid,
|
||||
routing_='r',
|
||||
|
|
@ -283,9 +265,10 @@ class EntityEdge(Edge):
|
|||
|
||||
records, _, _ = await driver.execute_query(
|
||||
"""
|
||||
MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
|
||||
WHERE e.uuid IN $uuids
|
||||
"""
|
||||
MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
|
||||
WHERE e.uuid IN $uuids
|
||||
RETURN
|
||||
"""
|
||||
+ ENTITY_EDGE_RETURN,
|
||||
uuids=uuids,
|
||||
routing_='r',
|
||||
|
|
@ -314,22 +297,21 @@ class EntityEdge(Edge):
|
|||
else ''
|
||||
)
|
||||
|
||||
query: LiteralString = (
|
||||
records, _, _ = await driver.execute_query(
|
||||
"""
|
||||
MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
|
||||
WHERE e.group_id IN $group_ids
|
||||
"""
|
||||
+ cursor_query
|
||||
+ """
|
||||
RETURN
|
||||
"""
|
||||
+ ENTITY_EDGE_RETURN
|
||||
+ with_embeddings_query
|
||||
+ """
|
||||
ORDER BY e.uuid DESC
|
||||
"""
|
||||
+ limit_query
|
||||
)
|
||||
|
||||
records, _, _ = await driver.execute_query(
|
||||
query,
|
||||
ORDER BY e.uuid DESC
|
||||
"""
|
||||
+ limit_query,
|
||||
group_ids=group_ids,
|
||||
uuid=uuid_cursor,
|
||||
limit=limit,
|
||||
|
|
@ -344,13 +326,15 @@ class EntityEdge(Edge):
|
|||
|
||||
@classmethod
|
||||
async def get_by_node_uuid(cls, driver: GraphDriver, node_uuid: str):
|
||||
query: LiteralString = (
|
||||
records, _, _ = await driver.execute_query(
|
||||
"""
|
||||
MATCH (n:Entity {uuid: $node_uuid})-[e:RELATES_TO]-(m:Entity)
|
||||
"""
|
||||
+ ENTITY_EDGE_RETURN
|
||||
MATCH (n:Entity {uuid: $node_uuid})-[e:RELATES_TO]-(m:Entity)
|
||||
RETURN
|
||||
"""
|
||||
+ ENTITY_EDGE_RETURN,
|
||||
node_uuid=node_uuid,
|
||||
routing_='r',
|
||||
)
|
||||
records, _, _ = await driver.execute_query(query, node_uuid=node_uuid, routing_='r')
|
||||
|
||||
edges = [get_entity_edge_from_record(record) for record in records]
|
||||
|
||||
|
|
@ -360,7 +344,7 @@ class EntityEdge(Edge):
|
|||
class CommunityEdge(Edge):
|
||||
async def save(self, driver: GraphDriver):
|
||||
result = await driver.execute_query(
|
||||
COMMUNITY_EDGE_SAVE,
|
||||
get_community_edge_save_query(driver.provider),
|
||||
community_uuid=self.source_node_uuid,
|
||||
entity_uuid=self.target_node_uuid,
|
||||
uuid=self.uuid,
|
||||
|
|
@ -376,14 +360,10 @@ class CommunityEdge(Edge):
|
|||
async def get_by_uuid(cls, driver: GraphDriver, uuid: str):
|
||||
records, _, _ = await driver.execute_query(
|
||||
"""
|
||||
MATCH (n:Community)-[e:HAS_MEMBER {uuid: $uuid}]->(m:Entity | Community)
|
||||
RETURN
|
||||
e.uuid As uuid,
|
||||
e.group_id AS group_id,
|
||||
n.uuid AS source_node_uuid,
|
||||
m.uuid AS target_node_uuid,
|
||||
e.created_at AS created_at
|
||||
""",
|
||||
MATCH (n:Community)-[e:HAS_MEMBER {uuid: $uuid}]->(m)
|
||||
RETURN
|
||||
"""
|
||||
+ COMMUNITY_EDGE_RETURN,
|
||||
uuid=uuid,
|
||||
routing_='r',
|
||||
)
|
||||
|
|
@ -396,15 +376,11 @@ class CommunityEdge(Edge):
|
|||
async def get_by_uuids(cls, driver: GraphDriver, uuids: list[str]):
|
||||
records, _, _ = await driver.execute_query(
|
||||
"""
|
||||
MATCH (n:Community)-[e:HAS_MEMBER]->(m:Entity | Community)
|
||||
WHERE e.uuid IN $uuids
|
||||
RETURN
|
||||
e.uuid As uuid,
|
||||
e.group_id AS group_id,
|
||||
n.uuid AS source_node_uuid,
|
||||
m.uuid AS target_node_uuid,
|
||||
e.created_at AS created_at
|
||||
""",
|
||||
MATCH (n:Community)-[e:HAS_MEMBER]->(m)
|
||||
WHERE e.uuid IN $uuids
|
||||
RETURN
|
||||
"""
|
||||
+ COMMUNITY_EDGE_RETURN,
|
||||
uuids=uuids,
|
||||
routing_='r',
|
||||
)
|
||||
|
|
@ -426,19 +402,17 @@ class CommunityEdge(Edge):
|
|||
|
||||
records, _, _ = await driver.execute_query(
|
||||
"""
|
||||
MATCH (n:Community)-[e:HAS_MEMBER]->(m:Entity | Community)
|
||||
WHERE e.group_id IN $group_ids
|
||||
"""
|
||||
MATCH (n:Community)-[e:HAS_MEMBER]->(m)
|
||||
WHERE e.group_id IN $group_ids
|
||||
"""
|
||||
+ cursor_query
|
||||
+ """
|
||||
RETURN
|
||||
e.uuid As uuid,
|
||||
e.group_id AS group_id,
|
||||
n.uuid AS source_node_uuid,
|
||||
m.uuid AS target_node_uuid,
|
||||
e.created_at AS created_at
|
||||
ORDER BY e.uuid DESC
|
||||
"""
|
||||
RETURN
|
||||
"""
|
||||
+ COMMUNITY_EDGE_RETURN
|
||||
+ """
|
||||
ORDER BY e.uuid DESC
|
||||
"""
|
||||
+ limit_query,
|
||||
group_ids=group_ids,
|
||||
uuid=uuid_cursor,
|
||||
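The edges module above now builds each Cypher statement by concatenating a MATCH prefix with a shared RETURN fragment (ENTITY_EDGE_RETURN, EPISODIC_EDGE_RETURN, COMMUNITY_EDGE_RETURN) instead of repeating the column list in every method. A stripped-down sketch of that composition (the fragment below is shortened for brevity; the real constants live in graphiti_core.models.edges.edge_db_queries):

ENTITY_EDGE_RETURN = """
    e.uuid AS uuid,
    n.uuid AS source_node_uuid,
    m.uuid AS target_node_uuid,
    e.fact AS fact
"""


def get_by_uuid_query() -> str:
    # One shared RETURN block keeps every edge query's columns in sync.
    return (
        """
        MATCH (n:Entity)-[e:RELATES_TO {uuid: $uuid}]->(m:Entity)
        RETURN
        """
        + ENTITY_EDGE_RETURN
    )


def get_by_group_ids_query(limit_query: str = 'LIMIT $limit') -> str:
    return (
        """
        MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
        WHERE e.group_id IN $group_ids
        RETURN
        """
        + ENTITY_EDGE_RETURN
        + """
        ORDER BY e.uuid DESC
        """
        + limit_query
    )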
@ -5,16 +5,9 @@ This module provides database-agnostic query generation for Neo4j and FalkorDB,
|
|||
supporting index creation, fulltext search, and bulk operations.
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from typing_extensions import LiteralString
|
||||
|
||||
from graphiti_core.models.edges.edge_db_queries import (
|
||||
ENTITY_EDGE_SAVE_BULK,
|
||||
)
|
||||
from graphiti_core.models.nodes.node_db_queries import (
|
||||
ENTITY_NODE_SAVE_BULK,
|
||||
)
|
||||
from graphiti_core.driver.driver import GraphProvider
|
||||
|
||||
# Mapping from Neo4j fulltext index names to FalkorDB node labels
|
||||
NEO4J_TO_FALKORDB_MAPPING = {
|
||||
|
|
@ -25,8 +18,8 @@ NEO4J_TO_FALKORDB_MAPPING = {
|
|||
}
|
||||
|
||||
|
||||
def get_range_indices(db_type: str = 'neo4j') -> list[LiteralString]:
|
||||
if db_type == 'falkordb':
|
||||
def get_range_indices(provider: GraphProvider) -> list[LiteralString]:
|
||||
if provider == GraphProvider.FALKORDB:
|
||||
return [
|
||||
# Entity node
|
||||
'CREATE INDEX FOR (n:Entity) ON (n.uuid, n.group_id, n.name, n.created_at)',
|
||||
|
|
@ -41,109 +34,70 @@ def get_range_indices(db_type: str = 'neo4j') -> list[LiteralString]:
|
|||
# HAS_MEMBER edge
|
||||
'CREATE INDEX FOR ()-[e:HAS_MEMBER]-() ON (e.uuid)',
|
||||
]
|
||||
else:
|
||||
return [
|
||||
'CREATE INDEX entity_uuid IF NOT EXISTS FOR (n:Entity) ON (n.uuid)',
|
||||
'CREATE INDEX episode_uuid IF NOT EXISTS FOR (n:Episodic) ON (n.uuid)',
|
||||
'CREATE INDEX community_uuid IF NOT EXISTS FOR (n:Community) ON (n.uuid)',
|
||||
'CREATE INDEX relation_uuid IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.uuid)',
|
||||
'CREATE INDEX mention_uuid IF NOT EXISTS FOR ()-[e:MENTIONS]-() ON (e.uuid)',
|
||||
'CREATE INDEX has_member_uuid IF NOT EXISTS FOR ()-[e:HAS_MEMBER]-() ON (e.uuid)',
|
||||
'CREATE INDEX entity_group_id IF NOT EXISTS FOR (n:Entity) ON (n.group_id)',
|
||||
'CREATE INDEX episode_group_id IF NOT EXISTS FOR (n:Episodic) ON (n.group_id)',
|
||||
'CREATE INDEX relation_group_id IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.group_id)',
|
||||
'CREATE INDEX mention_group_id IF NOT EXISTS FOR ()-[e:MENTIONS]-() ON (e.group_id)',
|
||||
'CREATE INDEX name_entity_index IF NOT EXISTS FOR (n:Entity) ON (n.name)',
|
||||
'CREATE INDEX created_at_entity_index IF NOT EXISTS FOR (n:Entity) ON (n.created_at)',
|
||||
'CREATE INDEX created_at_episodic_index IF NOT EXISTS FOR (n:Episodic) ON (n.created_at)',
|
||||
'CREATE INDEX valid_at_episodic_index IF NOT EXISTS FOR (n:Episodic) ON (n.valid_at)',
|
||||
'CREATE INDEX name_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.name)',
|
||||
'CREATE INDEX created_at_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.created_at)',
|
||||
'CREATE INDEX expired_at_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.expired_at)',
|
||||
'CREATE INDEX valid_at_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.valid_at)',
|
||||
'CREATE INDEX invalid_at_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.invalid_at)',
|
||||
]
|
||||
|
||||
return [
|
||||
'CREATE INDEX entity_uuid IF NOT EXISTS FOR (n:Entity) ON (n.uuid)',
|
||||
'CREATE INDEX episode_uuid IF NOT EXISTS FOR (n:Episodic) ON (n.uuid)',
|
||||
'CREATE INDEX community_uuid IF NOT EXISTS FOR (n:Community) ON (n.uuid)',
|
||||
'CREATE INDEX relation_uuid IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.uuid)',
|
||||
'CREATE INDEX mention_uuid IF NOT EXISTS FOR ()-[e:MENTIONS]-() ON (e.uuid)',
|
||||
'CREATE INDEX has_member_uuid IF NOT EXISTS FOR ()-[e:HAS_MEMBER]-() ON (e.uuid)',
|
||||
'CREATE INDEX entity_group_id IF NOT EXISTS FOR (n:Entity) ON (n.group_id)',
|
||||
'CREATE INDEX episode_group_id IF NOT EXISTS FOR (n:Episodic) ON (n.group_id)',
|
||||
'CREATE INDEX relation_group_id IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.group_id)',
|
||||
'CREATE INDEX mention_group_id IF NOT EXISTS FOR ()-[e:MENTIONS]-() ON (e.group_id)',
|
||||
'CREATE INDEX name_entity_index IF NOT EXISTS FOR (n:Entity) ON (n.name)',
|
||||
'CREATE INDEX created_at_entity_index IF NOT EXISTS FOR (n:Entity) ON (n.created_at)',
|
||||
'CREATE INDEX created_at_episodic_index IF NOT EXISTS FOR (n:Episodic) ON (n.created_at)',
|
||||
'CREATE INDEX valid_at_episodic_index IF NOT EXISTS FOR (n:Episodic) ON (n.valid_at)',
|
||||
'CREATE INDEX name_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.name)',
|
||||
'CREATE INDEX created_at_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.created_at)',
|
||||
'CREATE INDEX expired_at_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.expired_at)',
|
||||
'CREATE INDEX valid_at_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.valid_at)',
|
||||
'CREATE INDEX invalid_at_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.invalid_at)',
|
||||
]
|
||||
|
||||
|
||||
def get_fulltext_indices(db_type: str = 'neo4j') -> list[LiteralString]:
|
||||
if db_type == 'falkordb':
|
||||
def get_fulltext_indices(provider: GraphProvider) -> list[LiteralString]:
|
||||
if provider == GraphProvider.FALKORDB:
|
||||
return [
|
||||
"""CREATE FULLTEXT INDEX FOR (e:Episodic) ON (e.content, e.source, e.source_description, e.group_id)""",
|
||||
"""CREATE FULLTEXT INDEX FOR (n:Entity) ON (n.name, n.summary, n.group_id)""",
|
||||
"""CREATE FULLTEXT INDEX FOR (n:Community) ON (n.name, n.group_id)""",
|
||||
"""CREATE FULLTEXT INDEX FOR ()-[e:RELATES_TO]-() ON (e.name, e.fact, e.group_id)""",
|
||||
]
|
||||
else:
|
||||
return [
|
||||
"""CREATE FULLTEXT INDEX episode_content IF NOT EXISTS
|
||||
FOR (e:Episodic) ON EACH [e.content, e.source, e.source_description, e.group_id]""",
|
||||
"""CREATE FULLTEXT INDEX node_name_and_summary IF NOT EXISTS
|
||||
FOR (n:Entity) ON EACH [n.name, n.summary, n.group_id]""",
|
||||
"""CREATE FULLTEXT INDEX community_name IF NOT EXISTS
|
||||
FOR (n:Community) ON EACH [n.name, n.group_id]""",
|
||||
"""CREATE FULLTEXT INDEX edge_name_and_fact IF NOT EXISTS
|
||||
FOR ()-[e:RELATES_TO]-() ON EACH [e.name, e.fact, e.group_id]""",
|
||||
]
|
||||
|
||||
return [
|
||||
"""CREATE FULLTEXT INDEX episode_content IF NOT EXISTS
|
||||
FOR (e:Episodic) ON EACH [e.content, e.source, e.source_description, e.group_id]""",
|
||||
"""CREATE FULLTEXT INDEX node_name_and_summary IF NOT EXISTS
|
||||
FOR (n:Entity) ON EACH [n.name, n.summary, n.group_id]""",
|
||||
"""CREATE FULLTEXT INDEX community_name IF NOT EXISTS
|
||||
FOR (n:Community) ON EACH [n.name, n.group_id]""",
|
||||
"""CREATE FULLTEXT INDEX edge_name_and_fact IF NOT EXISTS
|
||||
FOR ()-[e:RELATES_TO]-() ON EACH [e.name, e.fact, e.group_id]""",
|
||||
]
|
||||
|
||||
|
||||
def get_nodes_query(db_type: str = 'neo4j', name: str = '', query: str | None = None) -> str:
|
||||
if db_type == 'falkordb':
|
||||
def get_nodes_query(provider: GraphProvider, name: str = '', query: str | None = None) -> str:
|
||||
if provider == GraphProvider.FALKORDB:
|
||||
label = NEO4J_TO_FALKORDB_MAPPING[name]
|
||||
return f"CALL db.idx.fulltext.queryNodes('{label}', {query})"
|
||||
else:
|
||||
return f'CALL db.index.fulltext.queryNodes("{name}", {query}, {{limit: $limit}})'
|
||||
|
||||
return f'CALL db.index.fulltext.queryNodes("{name}", {query}, {{limit: $limit}})'
|
||||
|
||||
|
||||
def get_vector_cosine_func_query(vec1, vec2, db_type: str = 'neo4j') -> str:
|
||||
if db_type == 'falkordb':
|
||||
def get_vector_cosine_func_query(vec1, vec2, provider: GraphProvider) -> str:
|
||||
if provider == GraphProvider.FALKORDB:
|
||||
# FalkorDB uses a different syntax for regular cosine similarity and Neo4j uses normalized cosine similarity
|
||||
return f'(2 - vec.cosineDistance({vec1}, vecf32({vec2})))/2'
|
||||
else:
|
||||
return f'vector.similarity.cosine({vec1}, {vec2})'
|
||||
|
||||
return f'vector.similarity.cosine({vec1}, {vec2})'
|
||||
|
||||
|
||||
def get_relationships_query(name: str, db_type: str = 'neo4j') -> str:
|
||||
if db_type == 'falkordb':
|
||||
def get_relationships_query(name: str, provider: GraphProvider) -> str:
|
||||
if provider == GraphProvider.FALKORDB:
|
||||
label = NEO4J_TO_FALKORDB_MAPPING[name]
|
||||
return f"CALL db.idx.fulltext.queryRelationships('{label}', $query)"
|
||||
else:
|
||||
return f'CALL db.index.fulltext.queryRelationships("{name}", $query, {{limit: $limit}})'
|
||||
|
||||
|
||||
def get_entity_node_save_bulk_query(nodes, db_type: str = 'neo4j') -> str | Any:
|
||||
if db_type == 'falkordb':
|
||||
queries = []
|
||||
for node in nodes:
|
||||
for label in node['labels']:
|
||||
queries.append(
|
||||
(
|
||||
f"""
|
||||
UNWIND $nodes AS node
|
||||
MERGE (n:Entity {{uuid: node.uuid}})
|
||||
SET n:{label}
|
||||
SET n = node
|
||||
WITH n, node
|
||||
SET n.name_embedding = vecf32(node.name_embedding)
|
||||
RETURN n.uuid AS uuid
|
||||
""",
|
||||
{'nodes': [node]},
|
||||
)
|
||||
)
|
||||
return queries
|
||||
else:
|
||||
return ENTITY_NODE_SAVE_BULK
|
||||
|
||||
|
||||
def get_entity_edge_save_bulk_query(db_type: str = 'neo4j') -> str:
|
||||
if db_type == 'falkordb':
|
||||
return """
|
||||
UNWIND $entity_edges AS edge
|
||||
MATCH (source:Entity {uuid: edge.source_node_uuid})
|
||||
MATCH (target:Entity {uuid: edge.target_node_uuid})
|
||||
MERGE (source)-[r:RELATES_TO {uuid: edge.uuid}]->(target)
|
||||
SET r = {uuid: edge.uuid, name: edge.name, group_id: edge.group_id, fact: edge.fact, episodes: edge.episodes,
|
||||
created_at: edge.created_at, expired_at: edge.expired_at, valid_at: edge.valid_at, invalid_at: edge.invalid_at, fact_embedding: vecf32(edge.fact_embedding)}
|
||||
WITH r, edge
|
||||
RETURN edge.uuid AS uuid"""
|
||||
else:
|
||||
return ENTITY_EDGE_SAVE_BULK
|
||||
return f'CALL db.index.fulltext.queryRelationships("{name}", $query, {{limit: $limit}})'
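The query builders above now take a GraphProvider instead of a `db_type` string. A short check of the cosine-similarity builder, assuming graphiti-core at this revision is installed; per the hunk, Neo4j exposes a normalized cosine similarity directly, while FalkorDB only has a distance function that the builder rescales into the same [0, 1] range:

from graphiti_core.driver.driver import GraphProvider
from graphiti_core.graph_queries import get_vector_cosine_func_query

print(get_vector_cosine_func_query('e.fact_embedding', '$search_vector', GraphProvider.NEO4J))
print(get_vector_cosine_func_query('e.fact_embedding', '$search_vector', GraphProvider.FALKORDB))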
@@ -26,7 +26,7 @@ from graphiti_core.cross_encoder.client import CrossEncoderClient
from graphiti_core.cross_encoder.openai_reranker_client import OpenAIRerankerClient
from graphiti_core.driver.driver import GraphDriver
from graphiti_core.driver.neo4j_driver import Neo4jDriver
-from graphiti_core.edges import EntityEdge, EpisodicEdge
+from graphiti_core.edges import CommunityEdge, EntityEdge, EpisodicEdge
from graphiti_core.embedder import EmbedderClient, OpenAIEmbedder
from graphiti_core.graphiti_types import GraphitiClients
from graphiti_core.helpers import (

@@ -93,8 +93,11 @@ load_dotenv()

class AddEpisodeResults(BaseModel):
    episode: EpisodicNode
+    episodic_edges: list[EpisodicEdge]
    nodes: list[EntityNode]
    edges: list[EntityEdge]
+    communities: list[CommunityNode]
+    community_edges: list[CommunityEdge]


class Graphiti:

@@ -520,9 +523,12 @@ class Graphiti:
                self.driver, [episode], episodic_edges, hydrated_nodes, entity_edges, self.embedder
            )

+            communities = []
+            community_edges = []
+
            # Update any communities
            if update_communities:
-                await semaphore_gather(
+                communities, community_edges = await semaphore_gather(
                    *[
                        update_community(self.driver, self.llm_client, self.embedder, node)
                        for node in nodes

@@ -532,7 +538,14 @@ class Graphiti:
            end = time()
            logger.info(f'Completed add_episode in {(end - start) * 1000} ms')

-            return AddEpisodeResults(episode=episode, nodes=nodes, edges=entity_edges)
+            return AddEpisodeResults(
+                episode=episode,
+                episodic_edges=episodic_edges,
+                nodes=hydrated_nodes,
+                edges=entity_edges,
+                communities=communities,
+                community_edges=community_edges,
+            )

        except Exception as e:
            raise e

@@ -817,7 +830,9 @@ class Graphiti:
        except Exception as e:
            raise e

-    async def build_communities(self, group_ids: list[str] | None = None) -> list[CommunityNode]:
+    async def build_communities(
+        self, group_ids: list[str] | None = None
+    ) -> tuple[list[CommunityNode], list[CommunityEdge]]:
        """
        Use a community clustering algorithm to find communities of nodes. Create community nodes summarising
        the content of these communities.

@@ -846,7 +861,7 @@ class Graphiti:
            max_coroutines=self.max_coroutines,
        )

-        return community_nodes
+        return community_nodes, community_edges

    async def search(
        self,
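With the change above, add_episode also reports the episodic and community edges it created. A hedged usage sketch (field names are taken from the AddEpisodeResults hunk; the constructor and add_episode parameters follow the project's public quickstart and should be checked against the installed version, and a running Neo4j with these credentials is assumed):

import asyncio
from datetime import datetime, timezone

from graphiti_core import Graphiti


async def main():
    graphiti = Graphiti('bolt://localhost:7687', 'neo4j', 'testpass')
    results = await graphiti.add_episode(
        name='demo',
        episode_body='Alice met Bob in Paris.',
        source_description='example text',
        reference_time=datetime.now(timezone.utc),
    )
    # New fields exposed by this change, alongside the existing nodes/edges.
    print(len(results.nodes), len(results.edges))
    print(len(results.episodic_edges), len(results.community_edges))


if __name__ == '__main__':
    asyncio.run(main())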
@@ -28,6 +28,7 @@ from numpy._typing import NDArray
from pydantic import BaseModel
from typing_extensions import LiteralString

+from graphiti_core.driver.driver import GraphProvider
from graphiti_core.errors import GroupIdValidationError

load_dotenv()

@@ -52,12 +53,12 @@ def parse_db_date(neo_date: neo4j_time.DateTime | str | None) -> datetime | None
    )


-def get_default_group_id(db_type: str) -> str:
+def get_default_group_id(provider: GraphProvider) -> str:
    """
    This function differentiates the default group id based on the database type.
    For most databases, the default group id is an empty string, while there are database types that require a specific default group id.
    """
-    if db_type == 'falkordb':
+    if provider == GraphProvider.FALKORDB:
        return '_'
    else:
        return ''
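A short check of the helper above, assuming graphiti-core at this revision is installed: only FalkorDB gets a non-empty default group id.

from graphiti_core.driver.driver import GraphProvider
from graphiti_core.helpers import get_default_group_id

assert get_default_group_id(GraphProvider.NEO4J) == ''
assert get_default_group_id(GraphProvider.FALKORDB) == '_'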
@ -14,43 +14,117 @@ See the License for the specific language governing permissions and
|
|||
limitations under the License.
|
||||
"""
|
||||
|
||||
from graphiti_core.driver.driver import GraphProvider
|
||||
|
||||
EPISODIC_EDGE_SAVE = """
|
||||
MATCH (episode:Episodic {uuid: $episode_uuid})
|
||||
MATCH (node:Entity {uuid: $entity_uuid})
|
||||
MERGE (episode)-[r:MENTIONS {uuid: $uuid}]->(node)
|
||||
SET r = {uuid: $uuid, group_id: $group_id, created_at: $created_at}
|
||||
RETURN r.uuid AS uuid"""
|
||||
MATCH (episode:Episodic {uuid: $episode_uuid})
|
||||
MATCH (node:Entity {uuid: $entity_uuid})
|
||||
MERGE (episode)-[e:MENTIONS {uuid: $uuid}]->(node)
|
||||
SET e = {uuid: $uuid, group_id: $group_id, created_at: $created_at}
|
||||
RETURN e.uuid AS uuid
|
||||
"""
|
||||
|
||||
EPISODIC_EDGE_SAVE_BULK = """
|
||||
UNWIND $episodic_edges AS edge
|
||||
MATCH (episode:Episodic {uuid: edge.source_node_uuid})
|
||||
MATCH (node:Entity {uuid: edge.target_node_uuid})
|
||||
MERGE (episode)-[r:MENTIONS {uuid: edge.uuid}]->(node)
|
||||
SET r = {uuid: edge.uuid, group_id: edge.group_id, created_at: edge.created_at}
|
||||
RETURN r.uuid AS uuid
|
||||
MATCH (episode:Episodic {uuid: edge.source_node_uuid})
|
||||
MATCH (node:Entity {uuid: edge.target_node_uuid})
|
||||
MERGE (episode)-[e:MENTIONS {uuid: edge.uuid}]->(node)
|
||||
SET e = {uuid: edge.uuid, group_id: edge.group_id, created_at: edge.created_at}
|
||||
RETURN e.uuid AS uuid
|
||||
"""
|
||||
|
||||
ENTITY_EDGE_SAVE = """
|
||||
MATCH (source:Entity {uuid: $edge_data.source_uuid})
|
||||
MATCH (target:Entity {uuid: $edge_data.target_uuid})
|
||||
MERGE (source)-[r:RELATES_TO {uuid: $edge_data.uuid}]->(target)
|
||||
SET r = $edge_data
|
||||
WITH r CALL db.create.setRelationshipVectorProperty(r, "fact_embedding", $edge_data.fact_embedding)
|
||||
RETURN r.uuid AS uuid"""
|
||||
|
||||
ENTITY_EDGE_SAVE_BULK = """
|
||||
UNWIND $entity_edges AS edge
|
||||
MATCH (source:Entity {uuid: edge.source_node_uuid})
|
||||
MATCH (target:Entity {uuid: edge.target_node_uuid})
|
||||
MERGE (source)-[r:RELATES_TO {uuid: edge.uuid}]->(target)
|
||||
SET r = edge
|
||||
WITH r, edge CALL db.create.setRelationshipVectorProperty(r, "fact_embedding", edge.fact_embedding)
|
||||
RETURN edge.uuid AS uuid
|
||||
EPISODIC_EDGE_RETURN = """
|
||||
e.uuid AS uuid,
|
||||
e.group_id AS group_id,
|
||||
n.uuid AS source_node_uuid,
|
||||
m.uuid AS target_node_uuid,
|
||||
e.created_at AS created_at
|
||||
"""
|
||||
|
||||
COMMUNITY_EDGE_SAVE = """
|
||||
MATCH (community:Community {uuid: $community_uuid})
|
||||
MATCH (node:Entity | Community {uuid: $entity_uuid})
|
||||
MERGE (community)-[r:HAS_MEMBER {uuid: $uuid}]->(node)
|
||||
SET r = {uuid: $uuid, group_id: $group_id, created_at: $created_at}
|
||||
RETURN r.uuid AS uuid"""
|
||||
|
||||
def get_entity_edge_save_query(provider: GraphProvider) -> str:
|
||||
if provider == GraphProvider.FALKORDB:
|
||||
return """
|
||||
MATCH (source:Entity {uuid: $edge_data.source_uuid})
|
||||
MATCH (target:Entity {uuid: $edge_data.target_uuid})
|
||||
MERGE (source)-[e:RELATES_TO {uuid: $edge_data.uuid}]->(target)
|
||||
SET e = $edge_data
|
||||
RETURN e.uuid AS uuid
|
||||
"""
|
||||
|
||||
return """
|
||||
MATCH (source:Entity {uuid: $edge_data.source_uuid})
|
||||
MATCH (target:Entity {uuid: $edge_data.target_uuid})
|
||||
MERGE (source)-[e:RELATES_TO {uuid: $edge_data.uuid}]->(target)
|
||||
SET e = $edge_data
|
||||
WITH e CALL db.create.setRelationshipVectorProperty(e, "fact_embedding", $edge_data.fact_embedding)
|
||||
RETURN e.uuid AS uuid
|
||||
"""
|
||||
|
||||
|
||||
def get_entity_edge_save_bulk_query(provider: GraphProvider) -> str:
|
||||
if provider == GraphProvider.FALKORDB:
|
||||
return """
|
||||
UNWIND $entity_edges AS edge
|
||||
MATCH (source:Entity {uuid: edge.source_node_uuid})
|
||||
MATCH (target:Entity {uuid: edge.target_node_uuid})
|
||||
MERGE (source)-[r:RELATES_TO {uuid: edge.uuid}]->(target)
|
||||
SET r = {uuid: edge.uuid, name: edge.name, group_id: edge.group_id, fact: edge.fact, episodes: edge.episodes,
|
||||
created_at: edge.created_at, expired_at: edge.expired_at, valid_at: edge.valid_at, invalid_at: edge.invalid_at, fact_embedding: vecf32(edge.fact_embedding)}
|
||||
WITH r, edge
|
||||
RETURN edge.uuid AS uuid
|
||||
"""
|
||||
|
||||
return """
|
||||
UNWIND $entity_edges AS edge
|
||||
MATCH (source:Entity {uuid: edge.source_node_uuid})
|
||||
MATCH (target:Entity {uuid: edge.target_node_uuid})
|
||||
MERGE (source)-[e:RELATES_TO {uuid: edge.uuid}]->(target)
|
||||
SET e = edge
|
||||
WITH e, edge CALL db.create.setRelationshipVectorProperty(e, "fact_embedding", edge.fact_embedding)
|
||||
RETURN edge.uuid AS uuid
|
||||
"""
|
||||
|
||||
|
||||
ENTITY_EDGE_RETURN = """
|
||||
e.uuid AS uuid,
|
||||
n.uuid AS source_node_uuid,
|
||||
m.uuid AS target_node_uuid,
|
||||
e.group_id AS group_id,
|
||||
e.name AS name,
|
||||
e.fact AS fact,
|
||||
e.episodes AS episodes,
|
||||
e.created_at AS created_at,
|
||||
e.expired_at AS expired_at,
|
||||
e.valid_at AS valid_at,
|
||||
e.invalid_at AS invalid_at,
|
||||
properties(e) AS attributes
|
||||
"""
|
||||
|
||||
|
||||
def get_community_edge_save_query(provider: GraphProvider) -> str:
|
||||
if provider == GraphProvider.FALKORDB:
|
||||
return """
|
||||
MATCH (community:Community {uuid: $community_uuid})
|
||||
MATCH (node {uuid: $entity_uuid})
|
||||
MERGE (community)-[e:HAS_MEMBER {uuid: $uuid}]->(node)
|
||||
SET e = {uuid: $uuid, group_id: $group_id, created_at: $created_at}
|
||||
RETURN e.uuid AS uuid
|
||||
"""
|
||||
|
||||
return """
|
||||
MATCH (community:Community {uuid: $community_uuid})
|
||||
MATCH (node:Entity | Community {uuid: $entity_uuid})
|
||||
MERGE (community)-[e:HAS_MEMBER {uuid: $uuid}]->(node)
|
||||
SET e = {uuid: $uuid, group_id: $group_id, created_at: $created_at}
|
||||
RETURN e.uuid AS uuid
|
||||
"""
|
||||
|
||||
|
||||
COMMUNITY_EDGE_RETURN = """
|
||||
e.uuid AS uuid,
|
||||
e.group_id AS group_id,
|
||||
n.uuid AS source_node_uuid,
|
||||
m.uuid AS target_node_uuid,
|
||||
e.created_at AS created_at
|
||||
"""
|
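The edge save queries above differ per backend mainly in how the fact embedding is written: the FalkorDB variant stores it with a plain SET (using vecf32 in the bulk form), while the Neo4j variant calls db.create.setRelationshipVectorProperty. A small check of the dispatch, assuming graphiti-core at this revision is installed:

from graphiti_core.driver.driver import GraphProvider
from graphiti_core.models.edges.edge_db_queries import get_entity_edge_save_query

neo4j_query = get_entity_edge_save_query(GraphProvider.NEO4J)
falkor_query = get_entity_edge_save_query(GraphProvider.FALKORDB)

# Only the Neo4j variant calls the vector-property procedure.
assert 'setRelationshipVectorProperty' in neo4j_query
assert 'setRelationshipVectorProperty' not in falkor_query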
@ -14,39 +14,120 @@ See the License for the specific language governing permissions and
|
|||
limitations under the License.
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from graphiti_core.driver.driver import GraphProvider
|
||||
|
||||
EPISODIC_NODE_SAVE = """
|
||||
MERGE (n:Episodic {uuid: $uuid})
|
||||
SET n = {uuid: $uuid, name: $name, group_id: $group_id, source_description: $source_description, source: $source, content: $content,
|
||||
entity_edges: $entity_edges, created_at: $created_at, valid_at: $valid_at}
|
||||
RETURN n.uuid AS uuid"""
|
||||
MERGE (n:Episodic {uuid: $uuid})
|
||||
SET n = {uuid: $uuid, name: $name, group_id: $group_id, source_description: $source_description, source: $source, content: $content,
|
||||
entity_edges: $entity_edges, created_at: $created_at, valid_at: $valid_at}
|
||||
RETURN n.uuid AS uuid
|
||||
"""
|
||||
|
||||
EPISODIC_NODE_SAVE_BULK = """
|
||||
UNWIND $episodes AS episode
|
||||
MERGE (n:Episodic {uuid: episode.uuid})
|
||||
SET n = {uuid: episode.uuid, name: episode.name, group_id: episode.group_id, source_description: episode.source_description,
|
||||
source: episode.source, content: episode.content,
|
||||
SET n = {uuid: episode.uuid, name: episode.name, group_id: episode.group_id, source_description: episode.source_description,
|
||||
source: episode.source, content: episode.content,
|
||||
entity_edges: episode.entity_edges, created_at: episode.created_at, valid_at: episode.valid_at}
|
||||
RETURN n.uuid AS uuid
|
||||
"""
|
||||
|
||||
ENTITY_NODE_SAVE = """
|
||||
MERGE (n:Entity {uuid: $entity_data.uuid})
|
||||
SET n:$($labels)
|
||||
SET n = $entity_data
|
||||
WITH n CALL db.create.setNodeVectorProperty(n, "name_embedding", $entity_data.name_embedding)
|
||||
RETURN n.uuid AS uuid"""
|
||||
|
||||
ENTITY_NODE_SAVE_BULK = """
|
||||
UNWIND $nodes AS node
|
||||
MERGE (n:Entity {uuid: node.uuid})
|
||||
SET n:$(node.labels)
|
||||
SET n = node
|
||||
WITH n, node CALL db.create.setNodeVectorProperty(n, "name_embedding", node.name_embedding)
|
||||
RETURN n.uuid AS uuid
|
||||
EPISODIC_NODE_RETURN = """
|
||||
e.content AS content,
|
||||
e.created_at AS created_at,
|
||||
e.valid_at AS valid_at,
|
||||
e.uuid AS uuid,
|
||||
e.name AS name,
|
||||
e.group_id AS group_id,
|
||||
e.source_description AS source_description,
|
||||
e.source AS source,
|
||||
e.entity_edges AS entity_edges
|
||||
"""
|
||||
|
||||
COMMUNITY_NODE_SAVE = """
|
||||
|
||||
def get_entity_node_save_query(provider: GraphProvider, labels: str) -> str:
|
||||
if provider == GraphProvider.FALKORDB:
|
||||
return f"""
|
||||
MERGE (n:Entity {{uuid: $entity_data.uuid}})
|
||||
SET n:{labels}
|
||||
SET n = $entity_data
|
||||
RETURN n.uuid AS uuid
|
||||
"""
|
||||
|
||||
return f"""
|
||||
MERGE (n:Entity {{uuid: $entity_data.uuid}})
|
||||
SET n:{labels}
|
||||
SET n = $entity_data
|
||||
WITH n CALL db.create.setNodeVectorProperty(n, "name_embedding", $entity_data.name_embedding)
|
||||
RETURN n.uuid AS uuid
|
||||
"""
|
||||
|
||||
|
||||
def get_entity_node_save_bulk_query(provider: GraphProvider, nodes: list[dict]) -> str | Any:
|
||||
if provider == GraphProvider.FALKORDB:
|
||||
queries = []
|
||||
for node in nodes:
|
||||
for label in node['labels']:
|
||||
queries.append(
|
||||
(
|
||||
f"""
|
||||
UNWIND $nodes AS node
|
||||
MERGE (n:Entity {{uuid: node.uuid}})
|
||||
SET n:{label}
|
||||
SET n = node
|
||||
WITH n, node
|
||||
SET n.name_embedding = vecf32(node.name_embedding)
|
||||
RETURN n.uuid AS uuid
|
||||
""",
|
||||
{'nodes': [node]},
|
||||
)
|
||||
)
|
||||
return queries
|
||||
|
||||
return """
|
||||
UNWIND $nodes AS node
|
||||
MERGE (n:Entity {uuid: node.uuid})
|
||||
SET n:$(node.labels)
|
||||
SET n = node
|
||||
WITH n, node CALL db.create.setNodeVectorProperty(n, "name_embedding", node.name_embedding)
|
||||
RETURN n.uuid AS uuid
|
||||
"""
|
||||
|
||||
|
||||
ENTITY_NODE_RETURN = """
|
||||
n.uuid AS uuid,
|
||||
n.name AS name,
|
||||
n.group_id AS group_id,
|
||||
n.created_at AS created_at,
|
||||
n.summary AS summary,
|
||||
labels(n) AS labels,
|
||||
properties(n) AS attributes
|
||||
"""
|
||||
|
||||
|
||||
def get_community_node_save_query(provider: GraphProvider) -> str:
|
||||
if provider == GraphProvider.FALKORDB:
|
||||
return """
|
||||
MERGE (n:Community {uuid: $uuid})
|
||||
SET n = {uuid: $uuid, name: $name, group_id: $group_id, summary: $summary, created_at: $created_at, name_embedding: $name_embedding}
|
||||
RETURN n.uuid AS uuid
|
||||
"""
|
||||
|
||||
return """
|
||||
MERGE (n:Community {uuid: $uuid})
|
||||
SET n = {uuid: $uuid, name: $name, group_id: $group_id, summary: $summary, created_at: $created_at}
|
||||
WITH n CALL db.create.setNodeVectorProperty(n, "name_embedding", $name_embedding)
|
||||
RETURN n.uuid AS uuid"""
|
||||
RETURN n.uuid AS uuid
|
||||
"""
|
||||
|
||||
|
||||
COMMUNITY_NODE_RETURN = """
|
||||
n.uuid AS uuid,
|
||||
n.name AS name,
|
||||
n.name_embedding AS name_embedding,
|
||||
n.group_id AS group_id,
|
||||
n.summary AS summary,
|
||||
n.created_at AS created_at
|
||||
"""
|
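Note the asymmetric return type above: for FalkorDB the bulk node save helper returns a list of (query, params) pairs, one per entity label, while for Neo4j it returns a single UNWIND query string. A hedged sketch of how a caller might handle both shapes (the helper name and signature come from the hunk; the surrounding execution loop is illustrative only, not the repository's bulk-save code):

from graphiti_core.driver.driver import GraphProvider
from graphiti_core.models.nodes.node_db_queries import get_entity_node_save_bulk_query


async def save_nodes_bulk(driver, nodes: list[dict]):
    queries = get_entity_node_save_bulk_query(driver.provider, nodes)
    if isinstance(queries, str):
        # Neo4j path: one UNWIND query over all nodes.
        await driver.execute_query(queries, nodes=nodes)
    else:
        # FalkorDB path: one (query, params) pair per node label.
        for query, params in queries:
            await driver.execute_query(query, **params)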
@ -25,29 +25,22 @@ from uuid import uuid4
|
|||
from pydantic import BaseModel, Field
|
||||
from typing_extensions import LiteralString
|
||||
|
||||
from graphiti_core.driver.driver import GraphDriver
|
||||
from graphiti_core.driver.driver import GraphDriver, GraphProvider
|
||||
from graphiti_core.embedder import EmbedderClient
|
||||
from graphiti_core.errors import NodeNotFoundError
|
||||
from graphiti_core.helpers import parse_db_date
|
||||
from graphiti_core.models.nodes.node_db_queries import (
|
||||
COMMUNITY_NODE_SAVE,
|
||||
ENTITY_NODE_SAVE,
|
||||
COMMUNITY_NODE_RETURN,
|
||||
ENTITY_NODE_RETURN,
|
||||
EPISODIC_NODE_RETURN,
|
||||
EPISODIC_NODE_SAVE,
|
||||
get_community_node_save_query,
|
||||
get_entity_node_save_query,
|
||||
)
|
||||
from graphiti_core.utils.datetime_utils import utc_now
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
ENTITY_NODE_RETURN: LiteralString = """
|
||||
RETURN
|
||||
n.uuid As uuid,
|
||||
n.name AS name,
|
||||
n.group_id AS group_id,
|
||||
n.created_at AS created_at,
|
||||
n.summary AS summary,
|
||||
labels(n) AS labels,
|
||||
properties(n) AS attributes"""
|
||||
|
||||
|
||||
class EpisodeType(Enum):
|
||||
"""
|
||||
|
|
@ -96,18 +89,26 @@ class Node(BaseModel, ABC):
|
|||
async def save(self, driver: GraphDriver): ...
|
||||
|
||||
async def delete(self, driver: GraphDriver):
|
||||
result = await driver.execute_query(
|
||||
"""
|
||||
MATCH (n:Entity|Episodic|Community {uuid: $uuid})
|
||||
DETACH DELETE n
|
||||
""",
|
||||
uuid=self.uuid,
|
||||
)
|
||||
if driver.provider == GraphProvider.FALKORDB:
|
||||
for label in ['Entity', 'Episodic', 'Community']:
|
||||
await driver.execute_query(
|
||||
f"""
|
||||
MATCH (n:{label} {{uuid: $uuid}})
|
||||
DETACH DELETE n
|
||||
""",
|
||||
uuid=self.uuid,
|
||||
)
|
||||
else:
|
||||
await driver.execute_query(
|
||||
"""
|
||||
MATCH (n:Entity|Episodic|Community {uuid: $uuid})
|
||||
DETACH DELETE n
|
||||
""",
|
||||
uuid=self.uuid,
|
||||
)
|
||||
|
||||
logger.debug(f'Deleted Node: {self.uuid}')
|
||||
|
||||
return result
|
||||
|
||||
def __hash__(self):
|
||||
return hash(self.uuid)
|
||||
|
||||
|
|
@ -118,15 +119,23 @@ class Node(BaseModel, ABC):
|
|||
|
||||
@classmethod
|
||||
async def delete_by_group_id(cls, driver: GraphDriver, group_id: str):
|
||||
await driver.execute_query(
|
||||
"""
|
||||
MATCH (n:Entity|Episodic|Community {group_id: $group_id})
|
||||
DETACH DELETE n
|
||||
""",
|
||||
group_id=group_id,
|
||||
)
|
||||
|
||||
return 'SUCCESS'
|
||||
if driver.provider == GraphProvider.FALKORDB:
|
||||
for label in ['Entity', 'Episodic', 'Community']:
|
||||
await driver.execute_query(
|
||||
f"""
|
||||
MATCH (n:{label} {{group_id: $group_id}})
|
||||
DETACH DELETE n
|
||||
""",
|
||||
group_id=group_id,
|
||||
)
|
||||
else:
|
||||
await driver.execute_query(
|
||||
"""
|
||||
MATCH (n:Entity|Episodic|Community {group_id: $group_id})
|
||||
DETACH DELETE n
|
||||
""",
|
||||
group_id=group_id,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def get_by_uuid(cls, driver: GraphDriver, uuid: str): ...
|
||||
|
|
@ -169,17 +178,10 @@ class EpisodicNode(Node):
|
|||
async def get_by_uuid(cls, driver: GraphDriver, uuid: str):
|
||||
records, _, _ = await driver.execute_query(
|
||||
"""
|
||||
MATCH (e:Episodic {uuid: $uuid})
|
||||
RETURN e.content AS content,
|
||||
e.created_at AS created_at,
|
||||
e.valid_at AS valid_at,
|
||||
e.uuid AS uuid,
|
||||
e.name AS name,
|
||||
e.group_id AS group_id,
|
||||
e.source_description AS source_description,
|
||||
e.source AS source,
|
||||
e.entity_edges AS entity_edges
|
||||
""",
|
||||
MATCH (e:Episodic {uuid: $uuid})
|
||||
RETURN
|
||||
"""
|
||||
+ EPISODIC_NODE_RETURN,
|
||||
uuid=uuid,
|
||||
routing_='r',
|
||||
)
|
||||
|
|
@ -195,18 +197,11 @@ class EpisodicNode(Node):
|
|||
async def get_by_uuids(cls, driver: GraphDriver, uuids: list[str]):
|
||||
records, _, _ = await driver.execute_query(
|
||||
"""
|
||||
MATCH (e:Episodic) WHERE e.uuid IN $uuids
|
||||
MATCH (e:Episodic)
|
||||
WHERE e.uuid IN $uuids
|
||||
RETURN DISTINCT
|
||||
e.content AS content,
|
||||
e.created_at AS created_at,
|
||||
e.valid_at AS valid_at,
|
||||
e.uuid AS uuid,
|
||||
e.name AS name,
|
||||
e.group_id AS group_id,
|
||||
e.source_description AS source_description,
|
||||
e.source AS source,
|
||||
e.entity_edges AS entity_edges
|
||||
""",
|
||||
"""
|
||||
+ EPISODIC_NODE_RETURN,
|
||||
uuids=uuids,
|
||||
routing_='r',
|
||||
)
|
||||
|
|
@ -228,22 +223,17 @@ class EpisodicNode(Node):
|
|||
|
||||
records, _, _ = await driver.execute_query(
|
||||
"""
|
||||
MATCH (e:Episodic) WHERE e.group_id IN $group_ids
|
||||
"""
|
||||
MATCH (e:Episodic)
|
||||
WHERE e.group_id IN $group_ids
|
||||
"""
|
||||
+ cursor_query
|
||||
+ """
|
||||
RETURN DISTINCT
|
||||
e.content AS content,
|
||||
e.created_at AS created_at,
|
||||
e.valid_at AS valid_at,
|
||||
e.uuid AS uuid,
|
||||
e.name AS name,
|
||||
e.group_id AS group_id,
|
||||
e.source_description AS source_description,
|
||||
e.source AS source,
|
||||
e.entity_edges AS entity_edges
|
||||
ORDER BY e.uuid DESC
|
||||
"""
|
||||
"""
|
||||
+ EPISODIC_NODE_RETURN
|
||||
+ """
|
||||
ORDER BY uuid DESC
|
||||
"""
|
||||
+ limit_query,
|
||||
group_ids=group_ids,
|
||||
uuid=uuid_cursor,
|
||||
|
|
@ -259,18 +249,10 @@ class EpisodicNode(Node):
|
|||
async def get_by_entity_node_uuid(cls, driver: GraphDriver, entity_node_uuid: str):
|
||||
records, _, _ = await driver.execute_query(
|
||||
"""
|
||||
MATCH (e:Episodic)-[r:MENTIONS]->(n:Entity {uuid: $entity_node_uuid})
|
||||
MATCH (e:Episodic)-[r:MENTIONS]->(n:Entity {uuid: $entity_node_uuid})
|
||||
RETURN DISTINCT
|
||||
e.content AS content,
|
||||
e.created_at AS created_at,
|
||||
e.valid_at AS valid_at,
|
||||
e.uuid AS uuid,
|
||||
e.name AS name,
|
||||
e.group_id AS group_id,
|
||||
e.source_description AS source_description,
|
||||
e.source AS source,
|
||||
e.entity_edges AS entity_edges
|
||||
""",
|
||||
"""
|
||||
+ EPISODIC_NODE_RETURN,
|
||||
entity_node_uuid=entity_node_uuid,
|
||||
routing_='r',
|
||||
)
|
||||
|
|
@ -297,11 +279,14 @@ class EntityNode(Node):
|
|||
return self.name_embedding
|
||||
|
||||
async def load_name_embedding(self, driver: GraphDriver):
|
||||
query: LiteralString = """
|
||||
records, _, _ = await driver.execute_query(
|
||||
"""
|
||||
MATCH (n:Entity {uuid: $uuid})
|
||||
RETURN n.name_embedding AS name_embedding
|
||||
"""
|
||||
records, _, _ = await driver.execute_query(query, uuid=self.uuid, routing_='r')
|
||||
""",
|
||||
uuid=self.uuid,
|
||||
routing_='r',
|
||||
)
|
||||
|
||||
if len(records) == 0:
|
||||
raise NodeNotFoundError(self.uuid)
|
||||
|
|
@ -317,12 +302,12 @@ class EntityNode(Node):
|
|||
'summary': self.summary,
|
||||
'created_at': self.created_at,
|
||||
}
|
||||
|
||||
entity_data.update(self.attributes or {})
|
||||
|
||||
labels = ':'.join(self.labels + ['Entity'])
|
||||
|
||||
result = await driver.execute_query(
|
||||
ENTITY_NODE_SAVE,
|
||||
labels=self.labels + ['Entity'],
|
||||
get_entity_node_save_query(driver.provider, labels),
|
||||
entity_data=entity_data,
|
||||
)
|
||||
|
||||
|
|
@ -332,14 +317,12 @@ class EntityNode(Node):
|
|||
|
||||
@classmethod
|
||||
async def get_by_uuid(cls, driver: GraphDriver, uuid: str):
|
||||
query = (
|
||||
"""
|
||||
MATCH (n:Entity {uuid: $uuid})
|
||||
"""
|
||||
+ ENTITY_NODE_RETURN
|
||||
)
|
||||
records, _, _ = await driver.execute_query(
|
||||
query,
|
||||
"""
|
||||
MATCH (n:Entity {uuid: $uuid})
|
||||
RETURN
|
||||
"""
|
||||
+ ENTITY_NODE_RETURN,
|
||||
uuid=uuid,
|
||||
routing_='r',
|
||||
)
|
||||
|
|
@ -355,8 +338,10 @@ class EntityNode(Node):
|
|||
async def get_by_uuids(cls, driver: GraphDriver, uuids: list[str]):
|
||||
records, _, _ = await driver.execute_query(
|
||||
"""
|
||||
MATCH (n:Entity) WHERE n.uuid IN $uuids
|
||||
"""
|
||||
MATCH (n:Entity)
|
||||
WHERE n.uuid IN $uuids
|
||||
RETURN
|
||||
"""
|
||||
+ ENTITY_NODE_RETURN,
|
||||
uuids=uuids,
|
||||
routing_='r',
|
||||
|
|
@ -379,22 +364,26 @@ class EntityNode(Node):
|
|||
limit_query: LiteralString = 'LIMIT $limit' if limit is not None else ''
|
||||
with_embeddings_query: LiteralString = (
|
||||
""",
|
||||
n.name_embedding AS name_embedding
|
||||
"""
|
||||
n.name_embedding AS name_embedding
|
||||
"""
|
||||
if with_embeddings
|
||||
else ''
|
||||
)
|
||||
|
||||
records, _, _ = await driver.execute_query(
|
||||
"""
|
||||
MATCH (n:Entity) WHERE n.group_id IN $group_ids
|
||||
"""
|
||||
MATCH (n:Entity)
|
||||
WHERE n.group_id IN $group_ids
|
||||
"""
|
||||
+ cursor_query
|
||||
+ """
|
||||
RETURN
|
||||
"""
|
||||
+ ENTITY_NODE_RETURN
|
||||
+ with_embeddings_query
|
||||
+ """
|
||||
ORDER BY n.uuid DESC
|
||||
"""
|
||||
ORDER BY n.uuid DESC
|
||||
"""
|
||||
+ limit_query,
|
||||
group_ids=group_ids,
|
||||
uuid=uuid_cursor,
|
||||
|
|
@ -413,7 +402,7 @@ class CommunityNode(Node):
|
|||
|
||||
async def save(self, driver: GraphDriver):
|
||||
result = await driver.execute_query(
|
||||
COMMUNITY_NODE_SAVE,
|
||||
get_community_node_save_query(driver.provider),
|
||||
uuid=self.uuid,
|
||||
name=self.name,
|
||||
group_id=self.group_id,
|
||||
|
|
@ -436,11 +425,14 @@ class CommunityNode(Node):
|
|||
return self.name_embedding
|
||||
|
||||
async def load_name_embedding(self, driver: GraphDriver):
|
||||
query: LiteralString = """
|
||||
records, _, _ = await driver.execute_query(
|
||||
"""
|
||||
MATCH (c:Community {uuid: $uuid})
|
||||
RETURN c.name_embedding AS name_embedding
|
||||
"""
|
||||
records, _, _ = await driver.execute_query(query, uuid=self.uuid, routing_='r')
|
||||
""",
|
||||
uuid=self.uuid,
|
||||
routing_='r',
|
||||
)
|
||||
|
||||
if len(records) == 0:
|
||||
raise NodeNotFoundError(self.uuid)
|
||||
|
|
@ -451,14 +443,10 @@ class CommunityNode(Node):
|
|||
async def get_by_uuid(cls, driver: GraphDriver, uuid: str):
|
||||
records, _, _ = await driver.execute_query(
|
||||
"""
|
||||
MATCH (n:Community {uuid: $uuid})
|
||||
RETURN
|
||||
n.uuid As uuid,
|
||||
n.name AS name,
|
||||
n.group_id AS group_id,
|
||||
n.created_at AS created_at,
|
||||
n.summary AS summary
|
||||
""",
|
||||
MATCH (n:Community {uuid: $uuid})
|
||||
RETURN
|
||||
"""
|
||||
+ COMMUNITY_NODE_RETURN,
|
||||
uuid=uuid,
|
||||
routing_='r',
|
||||
)
|
||||
|
|
@ -474,14 +462,11 @@ class CommunityNode(Node):
|
|||
async def get_by_uuids(cls, driver: GraphDriver, uuids: list[str]):
|
||||
records, _, _ = await driver.execute_query(
|
||||
"""
|
||||
MATCH (n:Community) WHERE n.uuid IN $uuids
|
||||
RETURN
|
||||
n.uuid As uuid,
|
||||
n.name AS name,
|
||||
n.group_id AS group_id,
|
||||
n.created_at AS created_at,
|
||||
n.summary AS summary
|
||||
""",
|
||||
MATCH (n:Community)
|
||||
WHERE n.uuid IN $uuids
|
||||
RETURN
|
||||
"""
|
||||
+ COMMUNITY_NODE_RETURN,
|
||||
uuids=uuids,
|
||||
routing_='r',
|
||||
)
|
||||
|
|
@ -503,18 +488,17 @@ class CommunityNode(Node):
|
|||
|
||||
records, _, _ = await driver.execute_query(
|
||||
"""
|
||||
MATCH (n:Community) WHERE n.group_id IN $group_ids
|
||||
"""
|
||||
MATCH (n:Community)
|
||||
WHERE n.group_id IN $group_ids
|
||||
"""
|
||||
+ cursor_query
|
||||
+ """
|
||||
RETURN
|
||||
n.uuid As uuid,
|
||||
n.name AS name,
|
||||
n.group_id AS group_id,
|
||||
n.created_at AS created_at,
|
||||
n.summary AS summary
|
||||
ORDER BY n.uuid DESC
|
||||
"""
|
||||
RETURN
|
||||
"""
|
||||
+ COMMUNITY_NODE_RETURN
|
||||
+ """
|
||||
ORDER BY n.uuid DESC
|
||||
"""
|
||||
+ limit_query,
|
||||
group_ids=group_ids,
|
||||
uuid=uuid_cursor,
|
||||
|
|
@ -586,6 +570,7 @@ def get_community_node_from_record(record: Any) -> CommunityNode:
|
|||
async def create_entity_node_embeddings(embedder: EmbedderClient, nodes: list[EntityNode]):
|
||||
if not nodes: # Handle empty list case
|
||||
return
|
||||
|
||||
name_embeddings = await embedder.create_batch([node.name for node in nodes])
|
||||
for node, name_embedding in zip(nodes, name_embeddings, strict=True):
|
||||
node.name_embedding = name_embedding
|
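The nodes module above splits Node.delete (and delete_by_group_id) into a per-label loop for FalkorDB while keeping the multi-label Entity|Episodic|Community match for Neo4j, presumably because the FalkorDB dialect does not accept that multi-label pattern. A condensed sketch of the branch, simplified from the hunk:

from graphiti_core.driver.driver import GraphDriver, GraphProvider


async def delete_node(driver: GraphDriver, uuid: str):
    if driver.provider == GraphProvider.FALKORDB:
        # One DETACH DELETE per label, mirroring the loop in the diff.
        for label in ['Entity', 'Episodic', 'Community']:
            await driver.execute_query(
                f'MATCH (n:{label} {{uuid: $uuid}}) DETACH DELETE n',
                uuid=uuid,
            )
    else:
        await driver.execute_query(
            'MATCH (n:Entity|Episodic|Community {uuid: $uuid}) DETACH DELETE n',
            uuid=uuid,
        )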
@@ -72,7 +72,7 @@ def edge_search_filter_query_constructor(

    if filters.edge_types is not None:
        edge_types = filters.edge_types
-        edge_types_filter = '\nAND r.name in $edge_types'
+        edge_types_filter = '\nAND e.name in $edge_types'
        filter_query += edge_types_filter
        filter_params['edge_types'] = edge_types

@@ -88,7 +88,7 @@ def edge_search_filter_query_constructor(
            filter_params['valid_at_' + str(j)] = date_filter.date

        and_filters = [
-            '(r.valid_at ' + date_filter.comparison_operator.value + f' $valid_at_{j})'
+            '(e.valid_at ' + date_filter.comparison_operator.value + f' $valid_at_{j})'
            for j, date_filter in enumerate(or_list)
        ]
        and_filter_query = ''

@@ -113,7 +113,7 @@ def edge_search_filter_query_constructor(
            filter_params['invalid_at_' + str(j)] = date_filter.date

        and_filters = [
-            '(r.invalid_at ' + date_filter.comparison_operator.value + f' $invalid_at_{j})'
+            '(e.invalid_at ' + date_filter.comparison_operator.value + f' $invalid_at_{j})'
            for j, date_filter in enumerate(or_list)
        ]
        and_filter_query = ''

@@ -138,7 +138,7 @@ def edge_search_filter_query_constructor(
            filter_params['created_at_' + str(j)] = date_filter.date

        and_filters = [
-            '(r.created_at ' + date_filter.comparison_operator.value + f' $created_at_{j})'
+            '(e.created_at ' + date_filter.comparison_operator.value + f' $created_at_{j})'
            for j, date_filter in enumerate(or_list)
        ]
        and_filter_query = ''

@@ -163,7 +163,7 @@ def edge_search_filter_query_constructor(
            filter_params['expired_at_' + str(j)] = date_filter.date

        and_filters = [
-            '(r.expired_at ' + date_filter.comparison_operator.value + f' $expired_at_{j})'
+            '(e.expired_at ' + date_filter.comparison_operator.value + f' $expired_at_{j})'
            for j, date_filter in enumerate(or_list)
        ]
        and_filter_query = ''
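The filter constructor above now emits its conditions against the alias `e` instead of `r`, matching the search queries that were rewritten to bind the relationship as `e`. A toy sketch of the shape of its output (hypothetical minimal inputs; the real function takes a SearchFilters object and handles several date fields):

def build_edge_type_filter(edge_types: list[str] | None) -> tuple[str, dict]:
    # Mirrors the diff: the generated fragment must reference the `e` alias used
    # by the calling MATCH clauses, e.g. (n:Entity)-[e:RELATES_TO]->(m:Entity).
    if not edge_types:
        return '', {}
    return '\nAND e.name in $edge_types', {'edge_types': edge_types}


query_fragment, params = build_edge_type_filter(['WORKS_AT'])
assert 'e.name' in query_fragment and params['edge_types'] == ['WORKS_AT']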
@ -23,7 +23,7 @@ import numpy as np
|
|||
from numpy._typing import NDArray
|
||||
from typing_extensions import LiteralString
|
||||
|
||||
from graphiti_core.driver.driver import GraphDriver
|
||||
from graphiti_core.driver.driver import GraphDriver, GraphProvider
|
||||
from graphiti_core.edges import EntityEdge, get_entity_edge_from_record
|
||||
from graphiti_core.graph_queries import (
|
||||
get_nodes_query,
|
||||
|
|
@ -36,6 +36,8 @@ from graphiti_core.helpers import (
|
|||
normalize_l2,
|
||||
semaphore_gather,
|
||||
)
|
||||
from graphiti_core.models.edges.edge_db_queries import ENTITY_EDGE_RETURN
|
||||
from graphiti_core.models.nodes.node_db_queries import COMMUNITY_NODE_RETURN, EPISODIC_NODE_RETURN
|
||||
from graphiti_core.nodes import (
|
||||
ENTITY_NODE_RETURN,
|
||||
CommunityNode,
|
||||
|
|
@ -100,20 +102,13 @@ async def get_mentioned_nodes(
|
|||
) -> list[EntityNode]:
|
||||
episode_uuids = [episode.uuid for episode in episodes]
|
||||
|
||||
query = """
|
||||
MATCH (episode:Episodic)-[:MENTIONS]->(n:Entity) WHERE episode.uuid IN $uuids
|
||||
RETURN DISTINCT
|
||||
n.uuid As uuid,
|
||||
n.group_id AS group_id,
|
||||
n.name AS name,
|
||||
n.created_at AS created_at,
|
||||
n.summary AS summary,
|
||||
labels(n) AS labels,
|
||||
properties(n) AS attributes
|
||||
"""
|
||||
|
||||
records, _, _ = await driver.execute_query(
|
||||
query,
|
||||
"""
|
||||
MATCH (episode:Episodic)-[:MENTIONS]->(n:Entity)
|
||||
WHERE episode.uuid IN $uuids
|
||||
RETURN DISTINCT
|
||||
"""
|
||||
+ ENTITY_NODE_RETURN,
|
||||
uuids=episode_uuids,
|
||||
routing_='r',
|
||||
)
|
||||
|
|
@ -128,18 +123,13 @@ async def get_communities_by_nodes(
|
|||
) -> list[CommunityNode]:
|
||||
node_uuids = [node.uuid for node in nodes]
|
||||
|
||||
query = """
|
||||
MATCH (c:Community)-[:HAS_MEMBER]->(n:Entity) WHERE n.uuid IN $uuids
|
||||
RETURN DISTINCT
|
||||
c.uuid As uuid,
|
||||
c.group_id AS group_id,
|
||||
c.name AS name,
|
||||
c.created_at AS created_at,
|
||||
c.summary AS summary
|
||||
"""
|
||||
|
||||
records, _, _ = await driver.execute_query(
|
||||
query,
|
||||
"""
|
||||
MATCH (n:Community)-[:HAS_MEMBER]->(m:Entity)
|
||||
WHERE m.uuid IN $uuids
|
||||
RETURN DISTINCT
|
||||
"""
|
||||
+ COMMUNITY_NODE_RETURN,
|
||||
uuids=node_uuids,
|
||||
routing_='r',
|
||||
)
|
||||
|
|
@ -164,38 +154,30 @@ async def edge_fulltext_search(
|
|||
filter_query, filter_params = edge_search_filter_query_constructor(search_filter)
|
||||
|
||||
query = (
|
||||
get_relationships_query('edge_name_and_fact', db_type=driver.provider)
|
||||
get_relationships_query('edge_name_and_fact', provider=driver.provider)
|
||||
+ """
|
||||
YIELD relationship AS rel, score
|
||||
MATCH (n:Entity)-[r:RELATES_TO {uuid: rel.uuid}]->(m:Entity)
|
||||
WHERE r.group_id IN $group_ids """
|
||||
MATCH (n:Entity)-[e:RELATES_TO {uuid: rel.uuid}]->(m:Entity)
|
||||
WHERE e.group_id IN $group_ids """
|
||||
+ filter_query
|
||||
+ """
|
||||
WITH r, score, startNode(r) AS n, endNode(r) AS m
|
||||
WITH e, score, n, m
|
||||
RETURN
|
||||
r.uuid AS uuid,
|
||||
r.group_id AS group_id,
|
||||
n.uuid AS source_node_uuid,
|
||||
m.uuid AS target_node_uuid,
|
||||
r.created_at AS created_at,
|
||||
r.name AS name,
|
||||
r.fact AS fact,
|
||||
r.episodes AS episodes,
|
||||
r.expired_at AS expired_at,
|
||||
r.valid_at AS valid_at,
|
||||
r.invalid_at AS invalid_at,
|
||||
properties(r) AS attributes
|
||||
ORDER BY score DESC LIMIT $limit
|
||||
"""
|
||||
+ ENTITY_EDGE_RETURN
|
||||
+ """
|
||||
ORDER BY score DESC
|
||||
LIMIT $limit
|
||||
"""
|
||||
)
|
||||
|
||||
records, _, _ = await driver.execute_query(
|
||||
query,
|
||||
params=filter_params,
|
||||
query=fuzzy_query,
|
||||
group_ids=group_ids,
|
||||
limit=limit,
|
||||
routing_='r',
|
||||
**filter_params,
|
||||
)
|
||||
|
||||
edges = [get_entity_edge_from_record(record) for record in records]
|
||||
|
|
@ -219,58 +201,47 @@ async def edge_similarity_search(
|
|||
filter_query, filter_params = edge_search_filter_query_constructor(search_filter)
|
||||
query_params.update(filter_params)
|
||||
|
||||
group_filter_query: LiteralString = 'WHERE r.group_id IS NOT NULL'
|
||||
group_filter_query: LiteralString = 'WHERE e.group_id IS NOT NULL'
|
||||
if group_ids is not None:
|
||||
group_filter_query += '\nAND r.group_id IN $group_ids'
|
||||
group_filter_query += '\nAND e.group_id IN $group_ids'
|
||||
query_params['group_ids'] = group_ids
|
||||
query_params['source_node_uuid'] = source_node_uuid
|
||||
query_params['target_node_uuid'] = target_node_uuid
|
||||
|
||||
if source_node_uuid is not None:
|
||||
group_filter_query += '\nAND (n.uuid IN [$source_uuid, $target_uuid])'
|
||||
query_params['source_uuid'] = source_node_uuid
|
||||
group_filter_query += '\nAND (n.uuid = $source_uuid)'
|
||||
|
||||
if target_node_uuid is not None:
|
||||
group_filter_query += '\nAND (m.uuid IN [$source_uuid, $target_uuid])'
|
||||
query_params['target_uuid'] = target_node_uuid
|
||||
group_filter_query += '\nAND (m.uuid = $target_uuid)'
|
||||
|
||||
query = (
|
||||
RUNTIME_QUERY
|
||||
+ """
|
||||
MATCH (n:Entity)-[r:RELATES_TO]->(m:Entity)
|
||||
MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
|
||||
"""
|
||||
+ group_filter_query
|
||||
+ filter_query
|
||||
+ """
|
||||
WITH DISTINCT r, """
|
||||
+ get_vector_cosine_func_query('r.fact_embedding', '$search_vector', driver.provider)
|
||||
WITH DISTINCT e, n, m, """
|
||||
+ get_vector_cosine_func_query('e.fact_embedding', '$search_vector', driver.provider)
|
||||
+ """ AS score
|
||||
WHERE score > $min_score
|
||||
RETURN
|
||||
r.uuid AS uuid,
|
||||
r.group_id AS group_id,
|
||||
startNode(r).uuid AS source_node_uuid,
|
||||
endNode(r).uuid AS target_node_uuid,
|
||||
r.created_at AS created_at,
|
||||
r.name AS name,
|
||||
r.fact AS fact,
|
||||
r.episodes AS episodes,
|
||||
r.expired_at AS expired_at,
|
||||
r.valid_at AS valid_at,
|
||||
r.invalid_at AS invalid_at,
|
||||
properties(r) AS attributes
|
||||
"""
|
||||
+ ENTITY_EDGE_RETURN
|
||||
+ """
|
||||
ORDER BY score DESC
|
||||
LIMIT $limit
|
||||
"""
|
||||
)
|
||||
records, header, _ = await driver.execute_query(
|
||||
|
||||
records, _, _ = await driver.execute_query(
|
||||
query,
|
||||
params=query_params,
|
||||
search_vector=search_vector,
|
||||
source_uuid=source_node_uuid,
|
||||
target_uuid=target_node_uuid,
|
||||
group_ids=group_ids,
|
||||
limit=limit,
|
||||
min_score=min_score,
|
||||
routing_='r',
|
||||
**query_params,
|
||||
)
|
||||
|
||||
edges = [get_entity_edge_from_record(record) for record in records]
|
||||
|
|
@@ -293,41 +264,31 @@ async def edge_bfs_search(
filter_query, filter_params = edge_search_filter_query_constructor(search_filter)

query = (
f"""
UNWIND $bfs_origin_node_uuids AS origin_uuid
MATCH path = (origin:Entity|Episodic {{uuid: origin_uuid}})-[:RELATES_TO|MENTIONS*1..{bfs_max_depth}]->(:Entity)
UNWIND relationships(path) AS rel
MATCH (n:Entity)-[e:RELATES_TO]-(m:Entity)
WHERE e.uuid = rel.uuid
AND e.group_id IN $group_ids
"""
UNWIND $bfs_origin_node_uuids AS origin_uuid
MATCH path = (origin:Entity|Episodic {uuid: origin_uuid})-[:RELATES_TO|MENTIONS]->{1,$depth}(n:Entity)
UNWIND relationships(path) AS rel
MATCH (n:Entity)-[r:RELATES_TO]-(m:Entity)
WHERE r.uuid = rel.uuid
AND r.group_id IN $group_ids
"""
+ filter_query
+ """
RETURN DISTINCT
r.uuid AS uuid,
r.group_id AS group_id,
startNode(r).uuid AS source_node_uuid,
endNode(r).uuid AS target_node_uuid,
r.created_at AS created_at,
r.name AS name,
r.fact AS fact,
r.episodes AS episodes,
r.expired_at AS expired_at,
r.valid_at AS valid_at,
r.invalid_at AS invalid_at,
properties(r) AS attributes
LIMIT $limit
+ """
RETURN DISTINCT
"""
+ ENTITY_EDGE_RETURN
+ """
LIMIT $limit
"""
)

records, _, _ = await driver.execute_query(
query,
params=filter_params,
bfs_origin_node_uuids=bfs_origin_node_uuids,
depth=bfs_max_depth,
group_ids=group_ids,
limit=limit,
routing_='r',
**filter_params,
)

edges = [get_entity_edge_from_record(record) for record in records]

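Note: in the hunk above the traversal depth is now bound as the $depth query parameter instead of being interpolated into the Cypher text, which lines up with this branch's goal of making the configured max depth take effect. A hedged sketch of a call site; keyword names are inferred from the function body above, and the exact signature, ordering and defaults may differ:

    edges = await edge_bfs_search(
        driver,
        bfs_origin_node_uuids=[origin_node.uuid],  # origin_node is an illustrative EntityNode
        bfs_max_depth=2,                           # passed through as the $depth parameter
        search_filter=SearchFilters(),             # assumes default-constructible filters
        group_ids=['my-group'],
        limit=25,
    )
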
@ -352,23 +313,27 @@ async def node_fulltext_search(
|
|||
get_nodes_query(driver.provider, 'node_name_and_summary', '$query')
|
||||
+ """
|
||||
YIELD node AS n, score
|
||||
WITH n, score
|
||||
LIMIT $limit
|
||||
WHERE n:Entity AND n.group_id IN $group_ids
|
||||
WHERE n:Entity AND n.group_id IN $group_ids
|
||||
WITH n, score
|
||||
LIMIT $limit
|
||||
"""
|
||||
+ filter_query
|
||||
+ """
|
||||
RETURN
|
||||
"""
|
||||
+ ENTITY_NODE_RETURN
|
||||
+ """
|
||||
ORDER BY score DESC
|
||||
"""
|
||||
)
|
||||
records, header, _ = await driver.execute_query(
|
||||
|
||||
records, _, _ = await driver.execute_query(
|
||||
query,
|
||||
params=filter_params,
|
||||
query=fuzzy_query,
|
||||
group_ids=group_ids,
|
||||
limit=limit,
|
||||
routing_='r',
|
||||
**filter_params,
|
||||
)
|
||||
|
||||
nodes = [get_entity_node_from_record(record) for record in records]
|
||||
|
|
@ -406,22 +371,23 @@ async def node_similarity_search(
|
|||
WITH n, """
|
||||
+ get_vector_cosine_func_query('n.name_embedding', '$search_vector', driver.provider)
|
||||
+ """ AS score
|
||||
WHERE score > $min_score"""
|
||||
WHERE score > $min_score
|
||||
RETURN
|
||||
"""
|
||||
+ ENTITY_NODE_RETURN
|
||||
+ """
|
||||
ORDER BY score DESC
|
||||
LIMIT $limit
|
||||
"""
|
||||
"""
|
||||
)
|
||||
|
||||
records, header, _ = await driver.execute_query(
|
||||
records, _, _ = await driver.execute_query(
|
||||
query,
|
||||
params=query_params,
|
||||
search_vector=search_vector,
|
||||
group_ids=group_ids,
|
||||
limit=limit,
|
||||
min_score=min_score,
|
||||
routing_='r',
|
||||
**query_params,
|
||||
)
|
||||
|
||||
nodes = [get_entity_node_from_record(record) for record in records]
|
||||
|
|
@@ -444,26 +410,29 @@ async def node_bfs_search(
filter_query, filter_params = node_search_filter_query_constructor(search_filter)

query = (
f"""
UNWIND $bfs_origin_node_uuids AS origin_uuid
MATCH (origin:Entity|Episodic {{uuid: origin_uuid}})-[:RELATES_TO|MENTIONS*1..{bfs_max_depth}]->(n:Entity)
WHERE n.group_id = origin.group_id
AND origin.group_id IN $group_ids
"""
UNWIND $bfs_origin_node_uuids AS origin_uuid
MATCH (origin:Entity|Episodic {uuid: origin_uuid})-[:RELATES_TO|MENTIONS]->{1,$depth}(n:Entity)
WHERE n.group_id = origin.group_id
AND origin.group_id IN $group_ids
"""
+ filter_query
+ """
RETURN
"""
+ ENTITY_NODE_RETURN
+ """
LIMIT $limit
"""
)

records, _, _ = await driver.execute_query(
query,
params=filter_params,
bfs_origin_node_uuids=bfs_origin_node_uuids,
depth=bfs_max_depth,
group_ids=group_ids,
limit=limit,
routing_='r',
**filter_params,
)
nodes = [get_entity_node_from_record(record) for record in records]

@ -489,16 +458,10 @@ async def episode_fulltext_search(
|
|||
MATCH (e:Episodic)
|
||||
WHERE e.uuid = episode.uuid
|
||||
AND e.group_id IN $group_ids
|
||||
RETURN
|
||||
e.content AS content,
|
||||
e.created_at AS created_at,
|
||||
e.valid_at AS valid_at,
|
||||
e.uuid AS uuid,
|
||||
e.name AS name,
|
||||
e.group_id AS group_id,
|
||||
e.source_description AS source_description,
|
||||
e.source AS source,
|
||||
e.entity_edges AS entity_edges
|
||||
RETURN
|
||||
"""
|
||||
+ EPISODIC_NODE_RETURN
|
||||
+ """
|
||||
ORDER BY score DESC
|
||||
LIMIT $limit
|
||||
"""
|
||||
|
|
@ -530,15 +493,12 @@ async def community_fulltext_search(
|
|||
query = (
|
||||
get_nodes_query(driver.provider, 'community_name', '$query')
|
||||
+ """
|
||||
YIELD node AS comm, score
|
||||
WHERE comm.group_id IN $group_ids
|
||||
YIELD node AS n, score
|
||||
WHERE n.group_id IN $group_ids
|
||||
RETURN
|
||||
comm.uuid AS uuid,
|
||||
comm.group_id AS group_id,
|
||||
comm.name AS name,
|
||||
comm.created_at AS created_at,
|
||||
comm.summary AS summary,
|
||||
comm.name_embedding AS name_embedding
|
||||
"""
|
||||
+ COMMUNITY_NODE_RETURN
|
||||
+ """
|
||||
ORDER BY score DESC
|
||||
LIMIT $limit
|
||||
"""
|
||||
|
|
@ -568,39 +528,37 @@ async def community_similarity_search(
|
|||
|
||||
group_filter_query: LiteralString = ''
|
||||
if group_ids is not None:
|
||||
group_filter_query += 'WHERE comm.group_id IN $group_ids'
|
||||
group_filter_query += 'WHERE n.group_id IN $group_ids'
|
||||
query_params['group_ids'] = group_ids
|
||||
|
||||
query = (
|
||||
RUNTIME_QUERY
|
||||
+ """
|
||||
MATCH (comm:Community)
|
||||
"""
|
||||
MATCH (n:Community)
|
||||
"""
|
||||
+ group_filter_query
|
||||
+ """
|
||||
WITH comm, """
|
||||
+ get_vector_cosine_func_query('comm.name_embedding', '$search_vector', driver.provider)
|
||||
WITH n,
|
||||
"""
|
||||
+ get_vector_cosine_func_query('n.name_embedding', '$search_vector', driver.provider)
|
||||
+ """ AS score
|
||||
WHERE score > $min_score
|
||||
RETURN
|
||||
comm.uuid As uuid,
|
||||
comm.group_id AS group_id,
|
||||
comm.name AS name,
|
||||
comm.created_at AS created_at,
|
||||
comm.summary AS summary,
|
||||
comm.name_embedding AS name_embedding
|
||||
ORDER BY score DESC
|
||||
LIMIT $limit
|
||||
WHERE score > $min_score
|
||||
RETURN
|
||||
"""
|
||||
+ COMMUNITY_NODE_RETURN
|
||||
+ """
|
||||
ORDER BY score DESC
|
||||
LIMIT $limit
|
||||
"""
|
||||
)
|
||||
|
||||
records, _, _ = await driver.execute_query(
|
||||
query,
|
||||
search_vector=search_vector,
|
||||
group_ids=group_ids,
|
||||
limit=limit,
|
||||
min_score=min_score,
|
||||
routing_='r',
|
||||
**query_params,
|
||||
)
|
||||
communities = [get_community_node_from_record(record) for record in records]
|
||||
|
||||
|
|
@ -719,8 +677,8 @@ async def get_relevant_nodes(
|
|||
WHERE m.group_id = $group_id
|
||||
WITH node, top_vector_nodes, vector_node_uuids, collect(m) AS fulltext_nodes
|
||||
|
||||
WITH node,
|
||||
top_vector_nodes,
|
||||
WITH node,
|
||||
top_vector_nodes,
|
||||
[m IN fulltext_nodes WHERE NOT m.uuid IN vector_node_uuids] AS filtered_fulltext_nodes
|
||||
|
||||
WITH node, top_vector_nodes + filtered_fulltext_nodes AS combined_nodes
|
||||
|
|
@ -728,10 +686,10 @@ async def get_relevant_nodes(
|
|||
UNWIND combined_nodes AS combined_node
|
||||
WITH node, collect(DISTINCT combined_node) AS deduped_nodes
|
||||
|
||||
RETURN
|
||||
RETURN
|
||||
node.uuid AS search_node_uuid,
|
||||
[x IN deduped_nodes | {
|
||||
uuid: x.uuid,
|
||||
uuid: x.uuid,
|
||||
name: x.name,
|
||||
name_embedding: x.name_embedding,
|
||||
group_id: x.group_id,
|
||||
|
|
@ -755,12 +713,12 @@ async def get_relevant_nodes(
|
|||
|
||||
results, _, _ = await driver.execute_query(
|
||||
query,
|
||||
params=query_params,
|
||||
nodes=query_nodes,
|
||||
group_id=group_id,
|
||||
limit=limit,
|
||||
min_score=min_score,
|
||||
routing_='r',
|
||||
**query_params,
|
||||
)
|
||||
|
||||
relevant_nodes_dict: dict[str, list[EntityNode]] = {
|
||||
|
|
@ -825,11 +783,11 @@ async def get_relevant_edges(
|
|||
|
||||
results, _, _ = await driver.execute_query(
|
||||
query,
|
||||
params=query_params,
|
||||
edges=[edge.model_dump() for edge in edges],
|
||||
limit=limit,
|
||||
min_score=min_score,
|
||||
routing_='r',
|
||||
**query_params,
|
||||
)
|
||||
|
||||
relevant_edges_dict: dict[str, list[EntityEdge]] = {
|
||||
|
|
@ -895,11 +853,11 @@ async def get_edge_invalidation_candidates(
|
|||
|
||||
results, _, _ = await driver.execute_query(
|
||||
query,
|
||||
params=query_params,
|
||||
edges=[edge.model_dump() for edge in edges],
|
||||
limit=limit,
|
||||
min_score=min_score,
|
||||
routing_='r',
|
||||
**query_params,
|
||||
)
|
||||
invalidation_edges_dict: dict[str, list[EntityEdge]] = {
|
||||
result['search_edge_uuid']: [
|
||||
|
|
@ -943,18 +901,17 @@ async def node_distance_reranker(
|
|||
scores: dict[str, float] = {center_node_uuid: 0.0}
|
||||
|
||||
# Find the shortest path to center node
|
||||
query = """
|
||||
results, header, _ = await driver.execute_query(
|
||||
"""
|
||||
UNWIND $node_uuids AS node_uuid
|
||||
MATCH (center:Entity {uuid: $center_uuid})-[:RELATES_TO]-(n:Entity {uuid: node_uuid})
|
||||
RETURN 1 AS score, node_uuid AS uuid
|
||||
"""
|
||||
results, header, _ = await driver.execute_query(
|
||||
query,
|
||||
""",
|
||||
node_uuids=filtered_uuids,
|
||||
center_uuid=center_node_uuid,
|
||||
routing_='r',
|
||||
)
|
||||
if driver.provider == 'falkordb':
|
||||
if driver.provider == GraphProvider.FALKORDB:
|
||||
results = [dict(zip(header, row, strict=True)) for row in results]
|
||||
|
||||
for result in results:
|
||||
|
|
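Note: provider checks such as the one in the reranker hunk above now compare against the GraphProvider enum rather than a raw string. Because GraphProvider is a plain Enum rather than a str subclass, a leftover comparison against 'falkordb' would silently evaluate to False, so any remaining string checks should be migrated as well. A small illustration; the helper function is hypothetical:

    from graphiti_core.driver.driver import GraphProvider

    def needs_row_dict_conversion(driver) -> bool:
        # hypothetical helper, not part of this PR: FalkorDB rows need the header zip shown above
        return driver.provider == GraphProvider.FALKORDB

    # enum members do not compare equal to their string values
    assert GraphProvider.FALKORDB != 'falkordb'
    assert GraphProvider.FALKORDB.value == 'falkordb'
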
@ -987,13 +944,12 @@ async def episode_mentions_reranker(
|
|||
scores: dict[str, float] = {}
|
||||
|
||||
# Find the shortest path to center node
|
||||
query = """
|
||||
UNWIND $node_uuids AS node_uuid
|
||||
results, _, _ = await driver.execute_query(
|
||||
"""
|
||||
UNWIND $node_uuids AS node_uuid
|
||||
MATCH (episode:Episodic)-[r:MENTIONS]->(n:Entity {uuid: node_uuid})
|
||||
RETURN count(*) AS score, n.uuid AS uuid
|
||||
"""
|
||||
results, _, _ = await driver.execute_query(
|
||||
query,
|
||||
""",
|
||||
node_uuids=sorted_uuids,
|
||||
routing_='r',
|
||||
)
|
||||
|
|
@ -1053,15 +1009,16 @@ def maximal_marginal_relevance(
|
|||
async def get_embeddings_for_nodes(
|
||||
driver: GraphDriver, nodes: list[EntityNode]
|
||||
) -> dict[str, list[float]]:
|
||||
query: LiteralString = """MATCH (n:Entity)
|
||||
WHERE n.uuid IN $node_uuids
|
||||
RETURN DISTINCT
|
||||
n.uuid AS uuid,
|
||||
n.name_embedding AS name_embedding
|
||||
"""
|
||||
|
||||
results, _, _ = await driver.execute_query(
|
||||
query, node_uuids=[node.uuid for node in nodes], routing_='r'
|
||||
"""
|
||||
MATCH (n:Entity)
|
||||
WHERE n.uuid IN $node_uuids
|
||||
RETURN DISTINCT
|
||||
n.uuid AS uuid,
|
||||
n.name_embedding AS name_embedding
|
||||
""",
|
||||
node_uuids=[node.uuid for node in nodes],
|
||||
routing_='r',
|
||||
)
|
||||
|
||||
embeddings_dict: dict[str, list[float]] = {}
|
||||
|
|
@ -1077,15 +1034,14 @@ async def get_embeddings_for_nodes(
|
|||
async def get_embeddings_for_communities(
|
||||
driver: GraphDriver, communities: list[CommunityNode]
|
||||
) -> dict[str, list[float]]:
|
||||
query: LiteralString = """MATCH (c:Community)
|
||||
WHERE c.uuid IN $community_uuids
|
||||
RETURN DISTINCT
|
||||
c.uuid AS uuid,
|
||||
c.name_embedding AS name_embedding
|
||||
"""
|
||||
|
||||
results, _, _ = await driver.execute_query(
|
||||
query,
|
||||
"""
|
||||
MATCH (c:Community)
|
||||
WHERE c.uuid IN $community_uuids
|
||||
RETURN DISTINCT
|
||||
c.uuid AS uuid,
|
||||
c.name_embedding AS name_embedding
|
||||
""",
|
||||
community_uuids=[community.uuid for community in communities],
|
||||
routing_='r',
|
||||
)
|
||||
|
|
@ -1103,15 +1059,14 @@ async def get_embeddings_for_communities(
|
|||
async def get_embeddings_for_edges(
|
||||
driver: GraphDriver, edges: list[EntityEdge]
|
||||
) -> dict[str, list[float]]:
|
||||
query: LiteralString = """MATCH (n:Entity)-[e:RELATES_TO]-(m:Entity)
|
||||
WHERE e.uuid IN $edge_uuids
|
||||
RETURN DISTINCT
|
||||
e.uuid AS uuid,
|
||||
e.fact_embedding AS fact_embedding
|
||||
"""
|
||||
|
||||
results, _, _ = await driver.execute_query(
|
||||
query,
|
||||
"""
|
||||
MATCH (n:Entity)-[e:RELATES_TO]-(m:Entity)
|
||||
WHERE e.uuid IN $edge_uuids
|
||||
RETURN DISTINCT
|
||||
e.uuid AS uuid,
|
||||
e.fact_embedding AS fact_embedding
|
||||
""",
|
||||
edge_uuids=[edge.uuid for edge in edges],
|
||||
routing_='r',
|
||||
)
|
||||
|
|
|
|||
|
|
@ -25,17 +25,15 @@ from typing_extensions import Any
|
|||
from graphiti_core.driver.driver import GraphDriver, GraphDriverSession
|
||||
from graphiti_core.edges import Edge, EntityEdge, EpisodicEdge, create_entity_edge_embeddings
|
||||
from graphiti_core.embedder import EmbedderClient
|
||||
from graphiti_core.graph_queries import (
|
||||
get_entity_edge_save_bulk_query,
|
||||
get_entity_node_save_bulk_query,
|
||||
)
|
||||
from graphiti_core.graphiti_types import GraphitiClients
|
||||
from graphiti_core.helpers import normalize_l2, semaphore_gather
|
||||
from graphiti_core.models.edges.edge_db_queries import (
|
||||
EPISODIC_EDGE_SAVE_BULK,
|
||||
get_entity_edge_save_bulk_query,
|
||||
)
|
||||
from graphiti_core.models.nodes.node_db_queries import (
|
||||
EPISODIC_NODE_SAVE_BULK,
|
||||
get_entity_node_save_bulk_query,
|
||||
)
|
||||
from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode, create_entity_node_embeddings
|
||||
from graphiti_core.utils.maintenance.edge_operations import (
|
||||
|
|
@ -158,7 +156,7 @@ async def add_nodes_and_edges_bulk_tx(
|
|||
edges.append(edge_data)
|
||||
|
||||
await tx.run(EPISODIC_NODE_SAVE_BULK, episodes=episodes)
|
||||
entity_node_save_bulk = get_entity_node_save_bulk_query(nodes, driver.provider)
|
||||
entity_node_save_bulk = get_entity_node_save_bulk_query(driver.provider, nodes)
|
||||
await tx.run(entity_node_save_bulk, nodes=nodes)
|
||||
await tx.run(
|
||||
EPISODIC_EDGE_SAVE_BULK, episodic_edges=[edge.model_dump() for edge in episodic_edges]
|
||||
|
|
|
|||
|
|
@ -34,7 +34,7 @@ async def get_community_clusters(
|
|||
group_id_values, _, _ = await driver.execute_query(
|
||||
"""
|
||||
MATCH (n:Entity WHERE n.group_id IS NOT NULL)
|
||||
RETURN
|
||||
RETURN
|
||||
collect(DISTINCT n.group_id) AS group_ids
|
||||
""",
|
||||
)
|
||||
|
|
@ -233,10 +233,10 @@ async def determine_entity_community(
|
|||
"""
|
||||
MATCH (c:Community)-[:HAS_MEMBER]->(n:Entity {uuid: $entity_uuid})
|
||||
RETURN
|
||||
c.uuid As uuid,
|
||||
c.uuid AS uuid,
|
||||
c.name AS name,
|
||||
c.group_id AS group_id,
|
||||
c.created_at AS created_at,
|
||||
c.created_at AS created_at,
|
||||
c.summary AS summary
|
||||
""",
|
||||
entity_uuid=entity.uuid,
|
||||
|
|
@ -250,10 +250,10 @@ async def determine_entity_community(
|
|||
"""
|
||||
MATCH (c:Community)-[:HAS_MEMBER]->(m:Entity)-[:RELATES_TO]-(n:Entity {uuid: $entity_uuid})
|
||||
RETURN
|
||||
c.uuid As uuid,
|
||||
c.uuid AS uuid,
|
||||
c.name AS name,
|
||||
c.group_id AS group_id,
|
||||
c.created_at AS created_at,
|
||||
c.created_at AS created_at,
|
||||
c.summary AS summary
|
||||
""",
|
||||
entity_uuid=entity.uuid,
|
||||
|
|
@@ -286,11 +286,11 @@ async def determine_entity_community(

async def update_community(
driver: GraphDriver, llm_client: LLMClient, embedder: EmbedderClient, entity: EntityNode
):
) -> tuple[list[CommunityNode], list[CommunityEdge]]:
community, is_new = await determine_entity_community(driver, entity)

if community is None:
return
return [], []

new_summary = await summarize_pair(llm_client, (entity.summary, community.summary))
new_name = await generate_summary_description(llm_client, new_summary)

@@ -298,10 +298,14 @@ async def update_community(
community.summary = new_summary
community.name = new_name

community_edges = []
if is_new:
community_edge = (build_community_edges([entity], community, utc_now()))[0]
await community_edge.save(driver)
community_edges.append(community_edge)

await community.generate_name_embedding(embedder)

await community.save(driver)

return [community], community_edges

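Note: update_community now always returns a (communities, community_edges) tuple instead of a bare None, so callers can unpack it unconditionally. A hedged sketch of a call site; only the function name, arguments and return shape come from this diff, the surrounding caller code is illustrative:

    communities, community_edges = await update_community(driver, llm_client, embedder, entity)
    if not communities:
        # the entity did not resolve to a community; nothing was saved
        return
    updated_community = communities[0]
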
|
|||
|
|
@ -15,14 +15,15 @@ limitations under the License.
|
|||
"""
|
||||
|
||||
import logging
|
||||
from datetime import datetime, timezone
|
||||
from datetime import datetime
|
||||
|
||||
from typing_extensions import LiteralString
|
||||
|
||||
from graphiti_core.driver.driver import GraphDriver
|
||||
from graphiti_core.graph_queries import get_fulltext_indices, get_range_indices
|
||||
from graphiti_core.helpers import parse_db_date, semaphore_gather
|
||||
from graphiti_core.nodes import EpisodeType, EpisodicNode
|
||||
from graphiti_core.helpers import semaphore_gather
|
||||
from graphiti_core.models.nodes.node_db_queries import EPISODIC_NODE_RETURN
|
||||
from graphiti_core.nodes import EpisodeType, EpisodicNode, get_episodic_node_from_record
|
||||
|
||||
EPISODE_WINDOW_LEN = 3
|
||||
|
||||
|
|
@ -33,8 +34,8 @@ async def build_indices_and_constraints(driver: GraphDriver, delete_existing: bo
|
|||
if delete_existing:
|
||||
records, _, _ = await driver.execute_query(
|
||||
"""
|
||||
SHOW INDEXES YIELD name
|
||||
""",
|
||||
SHOW INDEXES YIELD name
|
||||
""",
|
||||
)
|
||||
index_names = [record['name'] for record in records]
|
||||
await semaphore_gather(
|
||||
|
|
@ -108,19 +109,16 @@ async def retrieve_episodes(
|
|||
|
||||
query: LiteralString = (
|
||||
"""
|
||||
MATCH (e:Episodic) WHERE e.valid_at <= $reference_time
|
||||
"""
|
||||
MATCH (e:Episodic)
|
||||
WHERE e.valid_at <= $reference_time
|
||||
"""
|
||||
+ group_id_filter
|
||||
+ source_filter
|
||||
+ """
|
||||
RETURN e.content AS content,
|
||||
e.created_at AS created_at,
|
||||
e.valid_at AS valid_at,
|
||||
e.uuid AS uuid,
|
||||
e.group_id AS group_id,
|
||||
e.name AS name,
|
||||
e.source_description AS source_description,
|
||||
e.source AS source
|
||||
RETURN
|
||||
"""
|
||||
+ EPISODIC_NODE_RETURN
|
||||
+ """
|
||||
ORDER BY e.valid_at DESC
|
||||
LIMIT $num_episodes
|
||||
"""
|
||||
|
|
@ -133,18 +131,5 @@ async def retrieve_episodes(
|
|||
group_ids=group_ids,
|
||||
)
|
||||
|
||||
episodes = [
|
||||
EpisodicNode(
|
||||
content=record['content'],
|
||||
created_at=parse_db_date(record['created_at'])
|
||||
or datetime.min.replace(tzinfo=timezone.utc),
|
||||
valid_at=parse_db_date(record['valid_at']) or datetime.min.replace(tzinfo=timezone.utc),
|
||||
uuid=record['uuid'],
|
||||
group_id=record['group_id'],
|
||||
source=EpisodeType.from_str(record['source']),
|
||||
name=record['name'],
|
||||
source_description=record['source_description'],
|
||||
)
|
||||
for record in result
|
||||
]
|
||||
episodes = [get_episodic_node_from_record(record) for record in result]
|
||||
return list(reversed(episodes)) # Return in chronological order
|
||||
|
|
|
|||
|
|
@@ -3,9 +3,9 @@ name = "graphiti-core"
description = "A temporal graph building library"
version = "0.18.0"
authors = [
{ "name" = "Paul Paliychuk", "email" = "paul@getzep.com" },
{ "name" = "Preston Rasmussen", "email" = "preston@getzep.com" },
{ "name" = "Daniel Chalef", "email" = "daniel@getzep.com" },
{ name = "Paul Paliychuk", email = "paul@getzep.com" },
{ name = "Preston Rasmussen", email = "preston@getzep.com" },
{ name = "Daniel Chalef", email = "daniel@getzep.com" },
]
readme = "README.md"
license = "Apache-2.0"

|
|||
|
|
@ -21,6 +21,8 @@ from unittest.mock import AsyncMock, MagicMock, patch
|
|||
|
||||
import pytest
|
||||
|
||||
from graphiti_core.driver.driver import GraphProvider
|
||||
|
||||
try:
|
||||
from graphiti_core.driver.falkordb_driver import FalkorDriver, FalkorDriverSession
|
||||
|
||||
|
|
@ -48,7 +50,7 @@ class TestFalkorDriver:
|
|||
driver = FalkorDriver(
|
||||
host='test-host', port='1234', username='test-user', password='test-pass'
|
||||
)
|
||||
assert driver.provider == 'falkordb'
|
||||
assert driver.provider == GraphProvider.FALKORDB
|
||||
mock_falkor_db.assert_called_once_with(
|
||||
host='test-host', port='1234', username='test-user', password='test-pass'
|
||||
)
|
||||
|
|
@ -59,14 +61,14 @@ class TestFalkorDriver:
|
|||
with patch('graphiti_core.driver.falkordb_driver.FalkorDB') as mock_falkor_db_class:
|
||||
mock_falkor_db = MagicMock()
|
||||
driver = FalkorDriver(falkor_db=mock_falkor_db)
|
||||
assert driver.provider == 'falkordb'
|
||||
assert driver.provider == GraphProvider.FALKORDB
|
||||
assert driver.client is mock_falkor_db
|
||||
mock_falkor_db_class.assert_not_called()
|
||||
|
||||
@unittest.skipIf(not HAS_FALKORDB, 'FalkorDB is not installed')
|
||||
def test_provider(self):
|
||||
"""Test driver provider identification."""
|
||||
assert self.driver.provider == 'falkordb'
|
||||
assert self.driver.provider == GraphProvider.FALKORDB
|
||||
|
||||
@unittest.skipIf(not HAS_FALKORDB, 'FalkorDB is not installed')
|
||||
def test_get_graph_with_name(self):
|
||||
|
|
|
|||
|
|
@ -14,10 +14,68 @@ See the License for the specific language governing permissions and
|
|||
limitations under the License.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import os
|
||||
|
||||
import pytest
|
||||
from dotenv import load_dotenv
|
||||
|
||||
from graphiti_core.driver.driver import GraphDriver
|
||||
from graphiti_core.helpers import lucene_sanitize
|
||||
|
||||
load_dotenv()
|
||||
|
||||
HAS_NEO4J = False
|
||||
HAS_FALKORDB = False
|
||||
if os.getenv('DISABLE_NEO4J') is None:
|
||||
try:
|
||||
from graphiti_core.driver.neo4j_driver import Neo4jDriver
|
||||
|
||||
HAS_NEO4J = True
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
if os.getenv('DISABLE_FALKORDB') is None:
|
||||
try:
|
||||
from graphiti_core.driver.falkordb_driver import FalkorDriver
|
||||
|
||||
HAS_FALKORDB = True
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
NEO4J_URI = os.getenv('NEO4J_URI', 'bolt://localhost:7687')
|
||||
NEO4J_USER = os.getenv('NEO4J_USER', 'neo4j')
|
||||
NEO4J_PASSWORD = os.getenv('NEO4J_PASSWORD', 'test')
|
||||
|
||||
FALKORDB_HOST = os.getenv('FALKORDB_HOST', 'localhost')
|
||||
FALKORDB_PORT = os.getenv('FALKORDB_PORT', '6379')
|
||||
FALKORDB_USER = os.getenv('FALKORDB_USER', None)
|
||||
FALKORDB_PASSWORD = os.getenv('FALKORDB_PASSWORD', None)
|
||||
|
||||
|
||||
def get_driver(driver_name: str) -> GraphDriver:
|
||||
if driver_name == 'neo4j':
|
||||
return Neo4jDriver(
|
||||
uri=NEO4J_URI,
|
||||
user=NEO4J_USER,
|
||||
password=NEO4J_PASSWORD,
|
||||
)
|
||||
elif driver_name == 'falkordb':
|
||||
return FalkorDriver(
|
||||
host=FALKORDB_HOST,
|
||||
port=int(FALKORDB_PORT),
|
||||
username=FALKORDB_USER,
|
||||
password=FALKORDB_PASSWORD,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f'Driver {driver_name} not available')
|
||||
|
||||
|
||||
drivers: list[str] = []
|
||||
if HAS_NEO4J:
|
||||
drivers.append('neo4j')
|
||||
if HAS_FALKORDB:
|
||||
drivers.append('falkordb')
|
||||
|
||||
|
||||
def test_lucene_sanitize():
|
||||
# Call the function with test data
|
||||
|
|
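Note: the new helpers in tests/helpers_test.py centralize integration-test driver setup, so a test only needs to parametrize over drivers and build the driver with get_driver. A minimal sketch with an illustrative test body; only the helper names and the parametrize pattern come from this diff:

    import pytest

    from tests.helpers_test import drivers, get_driver

    @pytest.mark.asyncio
    @pytest.mark.parametrize('driver', drivers, ids=drivers)
    async def test_driver_roundtrip(driver):
        graph_driver = get_driver(driver)  # Neo4jDriver or FalkorDriver, configured from env vars
        try:
            records, _, _ = await graph_driver.execute_query('RETURN 1 AS one')
            assert records
        finally:
            await graph_driver.close()
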
|
|||
384 tests/test_edge_int.py Normal file
|
|
@ -0,0 +1,384 @@
|
|||
"""
|
||||
Copyright 2024, Zep Software, Inc.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import sys
|
||||
from datetime import datetime
|
||||
from uuid import uuid4
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from graphiti_core.driver.driver import GraphDriver
|
||||
from graphiti_core.edges import CommunityEdge, EntityEdge, EpisodicEdge
|
||||
from graphiti_core.embedder.openai import OpenAIEmbedder
|
||||
from graphiti_core.nodes import CommunityNode, EntityNode, EpisodeType, EpisodicNode
|
||||
from tests.helpers_test import drivers, get_driver
|
||||
|
||||
pytestmark = pytest.mark.integration
|
||||
|
||||
pytest_plugins = ('pytest_asyncio',)
|
||||
|
||||
group_id = f'test_group_{str(uuid4())}'
|
||||
|
||||
|
||||
def setup_logging():
|
||||
# Create a logger
|
||||
logger = logging.getLogger()
|
||||
logger.setLevel(logging.INFO) # Set the logging level to INFO
|
||||
|
||||
# Create console handler and set level to INFO
|
||||
console_handler = logging.StreamHandler(sys.stdout)
|
||||
console_handler.setLevel(logging.INFO)
|
||||
|
||||
# Create formatter
|
||||
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
|
||||
|
||||
# Add formatter to console handler
|
||||
console_handler.setFormatter(formatter)
|
||||
|
||||
# Add console handler to logger
|
||||
logger.addHandler(console_handler)
|
||||
|
||||
return logger
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
'driver',
|
||||
drivers,
|
||||
ids=drivers,
|
||||
)
|
||||
async def test_episodic_edge(driver):
|
||||
graph_driver = get_driver(driver)
|
||||
embedder = OpenAIEmbedder()
|
||||
|
||||
now = datetime.now()
|
||||
|
||||
episode_node = EpisodicNode(
|
||||
name='test_episode',
|
||||
labels=[],
|
||||
created_at=now,
|
||||
valid_at=now,
|
||||
source=EpisodeType.message,
|
||||
source_description='conversation message',
|
||||
content='Alice likes Bob',
|
||||
entity_edges=[],
|
||||
group_id=group_id,
|
||||
)
|
||||
|
||||
node_count = await get_node_count(graph_driver, episode_node.uuid)
|
||||
assert node_count == 0
|
||||
await episode_node.save(graph_driver)
|
||||
node_count = await get_node_count(graph_driver, episode_node.uuid)
|
||||
assert node_count == 1
|
||||
|
||||
alice_node = EntityNode(
|
||||
name='Alice',
|
||||
labels=[],
|
||||
created_at=now,
|
||||
summary='Alice summary',
|
||||
group_id=group_id,
|
||||
)
|
||||
await alice_node.generate_name_embedding(embedder)
|
||||
|
||||
node_count = await get_node_count(graph_driver, alice_node.uuid)
|
||||
assert node_count == 0
|
||||
await alice_node.save(graph_driver)
|
||||
node_count = await get_node_count(graph_driver, alice_node.uuid)
|
||||
assert node_count == 1
|
||||
|
||||
episodic_edge = EpisodicEdge(
|
||||
source_node_uuid=episode_node.uuid,
|
||||
target_node_uuid=alice_node.uuid,
|
||||
created_at=now,
|
||||
group_id=group_id,
|
||||
)
|
||||
|
||||
edge_count = await get_edge_count(graph_driver, episodic_edge.uuid)
|
||||
assert edge_count == 0
|
||||
await episodic_edge.save(graph_driver)
|
||||
edge_count = await get_edge_count(graph_driver, episodic_edge.uuid)
|
||||
assert edge_count == 1
|
||||
|
||||
retrieved = await EpisodicEdge.get_by_uuid(graph_driver, episodic_edge.uuid)
|
||||
assert retrieved.uuid == episodic_edge.uuid
|
||||
assert retrieved.source_node_uuid == episode_node.uuid
|
||||
assert retrieved.target_node_uuid == alice_node.uuid
|
||||
assert retrieved.created_at == now
|
||||
assert retrieved.group_id == group_id
|
||||
|
||||
retrieved = await EpisodicEdge.get_by_uuids(graph_driver, [episodic_edge.uuid])
|
||||
assert len(retrieved) == 1
|
||||
assert retrieved[0].uuid == episodic_edge.uuid
|
||||
assert retrieved[0].source_node_uuid == episode_node.uuid
|
||||
assert retrieved[0].target_node_uuid == alice_node.uuid
|
||||
assert retrieved[0].created_at == now
|
||||
assert retrieved[0].group_id == group_id
|
||||
|
||||
retrieved = await EpisodicEdge.get_by_group_ids(graph_driver, [group_id], limit=2)
|
||||
assert len(retrieved) == 1
|
||||
assert retrieved[0].uuid == episodic_edge.uuid
|
||||
assert retrieved[0].source_node_uuid == episode_node.uuid
|
||||
assert retrieved[0].target_node_uuid == alice_node.uuid
|
||||
assert retrieved[0].created_at == now
|
||||
assert retrieved[0].group_id == group_id
|
||||
|
||||
await episodic_edge.delete(graph_driver)
|
||||
edge_count = await get_edge_count(graph_driver, episodic_edge.uuid)
|
||||
assert edge_count == 0
|
||||
|
||||
await episode_node.delete(graph_driver)
|
||||
node_count = await get_node_count(graph_driver, episode_node.uuid)
|
||||
assert node_count == 0
|
||||
|
||||
await alice_node.delete(graph_driver)
|
||||
node_count = await get_node_count(graph_driver, alice_node.uuid)
|
||||
assert node_count == 0
|
||||
|
||||
await graph_driver.close()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
'driver',
|
||||
drivers,
|
||||
ids=drivers,
|
||||
)
|
||||
async def test_entity_edge(driver):
|
||||
graph_driver = get_driver(driver)
|
||||
embedder = OpenAIEmbedder()
|
||||
|
||||
now = datetime.now()
|
||||
|
||||
alice_node = EntityNode(
|
||||
name='Alice',
|
||||
labels=[],
|
||||
created_at=now,
|
||||
summary='Alice summary',
|
||||
group_id=group_id,
|
||||
)
|
||||
await alice_node.generate_name_embedding(embedder)
|
||||
|
||||
node_count = await get_node_count(graph_driver, alice_node.uuid)
|
||||
assert node_count == 0
|
||||
await alice_node.save(graph_driver)
|
||||
node_count = await get_node_count(graph_driver, alice_node.uuid)
|
||||
assert node_count == 1
|
||||
|
||||
bob_node = EntityNode(
|
||||
name='Bob', labels=[], created_at=now, summary='Bob summary', group_id=group_id
|
||||
)
|
||||
await bob_node.generate_name_embedding(embedder)
|
||||
|
||||
node_count = await get_node_count(graph_driver, bob_node.uuid)
|
||||
assert node_count == 0
|
||||
await bob_node.save(graph_driver)
|
||||
node_count = await get_node_count(graph_driver, bob_node.uuid)
|
||||
assert node_count == 1
|
||||
|
||||
entity_edge = EntityEdge(
|
||||
source_node_uuid=alice_node.uuid,
|
||||
target_node_uuid=bob_node.uuid,
|
||||
created_at=now,
|
||||
name='likes',
|
||||
fact='Alice likes Bob',
|
||||
episodes=[],
|
||||
expired_at=now,
|
||||
valid_at=now,
|
||||
invalid_at=now,
|
||||
group_id=group_id,
|
||||
)
|
||||
edge_embedding = await entity_edge.generate_embedding(embedder)
|
||||
|
||||
edge_count = await get_edge_count(graph_driver, entity_edge.uuid)
|
||||
assert edge_count == 0
|
||||
await entity_edge.save(graph_driver)
|
||||
edge_count = await get_edge_count(graph_driver, entity_edge.uuid)
|
||||
assert edge_count == 1
|
||||
|
||||
retrieved = await EntityEdge.get_by_uuid(graph_driver, entity_edge.uuid)
|
||||
assert retrieved.uuid == entity_edge.uuid
|
||||
assert retrieved.source_node_uuid == alice_node.uuid
|
||||
assert retrieved.target_node_uuid == bob_node.uuid
|
||||
assert retrieved.created_at == now
|
||||
assert retrieved.group_id == group_id
|
||||
|
||||
retrieved = await EntityEdge.get_by_uuids(graph_driver, [entity_edge.uuid])
|
||||
assert len(retrieved) == 1
|
||||
assert retrieved[0].uuid == entity_edge.uuid
|
||||
assert retrieved[0].source_node_uuid == alice_node.uuid
|
||||
assert retrieved[0].target_node_uuid == bob_node.uuid
|
||||
assert retrieved[0].created_at == now
|
||||
assert retrieved[0].group_id == group_id
|
||||
|
||||
retrieved = await EntityEdge.get_by_group_ids(graph_driver, [group_id], limit=2)
|
||||
assert len(retrieved) == 1
|
||||
assert retrieved[0].uuid == entity_edge.uuid
|
||||
assert retrieved[0].source_node_uuid == alice_node.uuid
|
||||
assert retrieved[0].target_node_uuid == bob_node.uuid
|
||||
assert retrieved[0].created_at == now
|
||||
assert retrieved[0].group_id == group_id
|
||||
|
||||
retrieved = await EntityEdge.get_by_node_uuid(graph_driver, alice_node.uuid)
|
||||
assert len(retrieved) == 1
|
||||
assert retrieved[0].uuid == entity_edge.uuid
|
||||
assert retrieved[0].source_node_uuid == alice_node.uuid
|
||||
assert retrieved[0].target_node_uuid == bob_node.uuid
|
||||
assert retrieved[0].created_at == now
|
||||
assert retrieved[0].group_id == group_id
|
||||
|
||||
await entity_edge.load_fact_embedding(graph_driver)
|
||||
assert np.allclose(entity_edge.fact_embedding, edge_embedding)
|
||||
|
||||
await entity_edge.delete(graph_driver)
|
||||
edge_count = await get_edge_count(graph_driver, entity_edge.uuid)
|
||||
assert edge_count == 0
|
||||
|
||||
await alice_node.delete(graph_driver)
|
||||
node_count = await get_node_count(graph_driver, alice_node.uuid)
|
||||
assert node_count == 0
|
||||
|
||||
await bob_node.delete(graph_driver)
|
||||
node_count = await get_node_count(graph_driver, bob_node.uuid)
|
||||
assert node_count == 0
|
||||
|
||||
await graph_driver.close()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
'driver',
|
||||
drivers,
|
||||
ids=drivers,
|
||||
)
|
||||
async def test_community_edge(driver):
|
||||
graph_driver = get_driver(driver)
|
||||
embedder = OpenAIEmbedder()
|
||||
|
||||
now = datetime.now()
|
||||
|
||||
community_node_1 = CommunityNode(
|
||||
name='Community A',
|
||||
group_id=group_id,
|
||||
summary='Community A summary',
|
||||
)
|
||||
await community_node_1.generate_name_embedding(embedder)
|
||||
node_count = await get_node_count(graph_driver, community_node_1.uuid)
|
||||
assert node_count == 0
|
||||
await community_node_1.save(graph_driver)
|
||||
node_count = await get_node_count(graph_driver, community_node_1.uuid)
|
||||
assert node_count == 1
|
||||
|
||||
community_node_2 = CommunityNode(
|
||||
name='Community B',
|
||||
group_id=group_id,
|
||||
summary='Community B summary',
|
||||
)
|
||||
await community_node_2.generate_name_embedding(embedder)
|
||||
node_count = await get_node_count(graph_driver, community_node_2.uuid)
|
||||
assert node_count == 0
|
||||
await community_node_2.save(graph_driver)
|
||||
node_count = await get_node_count(graph_driver, community_node_2.uuid)
|
||||
assert node_count == 1
|
||||
|
||||
alice_node = EntityNode(
|
||||
name='Alice', labels=[], created_at=now, summary='Alice summary', group_id=group_id
|
||||
)
|
||||
await alice_node.generate_name_embedding(embedder)
|
||||
node_count = await get_node_count(graph_driver, alice_node.uuid)
|
||||
assert node_count == 0
|
||||
await alice_node.save(graph_driver)
|
||||
node_count = await get_node_count(graph_driver, alice_node.uuid)
|
||||
assert node_count == 1
|
||||
|
||||
community_edge = CommunityEdge(
|
||||
source_node_uuid=community_node_1.uuid,
|
||||
target_node_uuid=community_node_2.uuid,
|
||||
created_at=now,
|
||||
group_id=group_id,
|
||||
)
|
||||
edge_count = await get_edge_count(graph_driver, community_edge.uuid)
|
||||
assert edge_count == 0
|
||||
await community_edge.save(graph_driver)
|
||||
edge_count = await get_edge_count(graph_driver, community_edge.uuid)
|
||||
assert edge_count == 1
|
||||
|
||||
retrieved = await CommunityEdge.get_by_uuid(graph_driver, community_edge.uuid)
|
||||
assert retrieved.uuid == community_edge.uuid
|
||||
assert retrieved.source_node_uuid == community_node_1.uuid
|
||||
assert retrieved.target_node_uuid == community_node_2.uuid
|
||||
assert retrieved.created_at == now
|
||||
assert retrieved.group_id == group_id
|
||||
|
||||
retrieved = await CommunityEdge.get_by_uuids(graph_driver, [community_edge.uuid])
|
||||
assert len(retrieved) == 1
|
||||
assert retrieved[0].uuid == community_edge.uuid
|
||||
assert retrieved[0].source_node_uuid == community_node_1.uuid
|
||||
assert retrieved[0].target_node_uuid == community_node_2.uuid
|
||||
assert retrieved[0].created_at == now
|
||||
assert retrieved[0].group_id == group_id
|
||||
|
||||
retrieved = await CommunityEdge.get_by_group_ids(graph_driver, [group_id], limit=1)
|
||||
assert len(retrieved) == 1
|
||||
assert retrieved[0].uuid == community_edge.uuid
|
||||
assert retrieved[0].source_node_uuid == community_node_1.uuid
|
||||
assert retrieved[0].target_node_uuid == community_node_2.uuid
|
||||
assert retrieved[0].created_at == now
|
||||
assert retrieved[0].group_id == group_id
|
||||
|
||||
await community_edge.delete(graph_driver)
|
||||
edge_count = await get_edge_count(graph_driver, community_edge.uuid)
|
||||
assert edge_count == 0
|
||||
|
||||
await alice_node.delete(graph_driver)
|
||||
node_count = await get_node_count(graph_driver, alice_node.uuid)
|
||||
assert node_count == 0
|
||||
|
||||
await community_node_1.delete(graph_driver)
|
||||
node_count = await get_node_count(graph_driver, community_node_1.uuid)
|
||||
assert node_count == 0
|
||||
|
||||
await community_node_2.delete(graph_driver)
|
||||
node_count = await get_node_count(graph_driver, community_node_2.uuid)
|
||||
assert node_count == 0
|
||||
|
||||
await graph_driver.close()
|
||||
|
||||
|
||||
async def get_node_count(driver: GraphDriver, uuid: str):
|
||||
results, _, _ = await driver.execute_query(
|
||||
"""
|
||||
MATCH (n {uuid: $uuid})
|
||||
RETURN COUNT(n) as count
|
||||
""",
|
||||
uuid=uuid,
|
||||
)
|
||||
return int(results[0]['count'])
|
||||
|
||||
|
||||
async def get_edge_count(driver: GraphDriver, uuid: str):
|
||||
results, _, _ = await driver.execute_query(
|
||||
"""
|
||||
MATCH (n)-[e {uuid: $uuid}]->(m)
|
||||
RETURN COUNT(e) as count
|
||||
UNION ALL
|
||||
MATCH (n)-[e:RELATES_TO]->(m {uuid: $uuid})-[e2:RELATES_TO]->(m2)
|
||||
RETURN COUNT(m) as count
|
||||
""",
|
||||
uuid=uuid,
|
||||
)
|
||||
return sum(int(result['count']) for result in results)
|
||||
|
|
@ -14,26 +14,18 @@ See the License for the specific language governing permissions and
|
|||
limitations under the License.
|
||||
"""
|
||||
|
||||
import os
|
||||
from datetime import datetime, timezone
|
||||
|
||||
import pytest
|
||||
from dotenv import load_dotenv
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from graphiti_core.graphiti import Graphiti
|
||||
from graphiti_core.helpers import validate_excluded_entity_types
|
||||
from tests.helpers_test import drivers, get_driver
|
||||
|
||||
pytestmark = pytest.mark.integration
|
||||
|
||||
pytest_plugins = ('pytest_asyncio',)
|
||||
|
||||
load_dotenv()
|
||||
|
||||
NEO4J_URI = os.getenv('NEO4J_URI')
|
||||
NEO4J_USER = os.getenv('NEO4J_USER')
|
||||
NEO4J_PASSWORD = os.getenv('NEO4J_PASSWORD')
|
||||
|
||||
|
||||
# Test entity type definitions
|
||||
class Person(BaseModel):
|
||||
|
|
@ -65,9 +57,14 @@ class Location(BaseModel):
|
|||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_exclude_default_entity_type():
|
||||
@pytest.mark.parametrize(
|
||||
'driver',
|
||||
drivers,
|
||||
ids=drivers,
|
||||
)
|
||||
async def test_exclude_default_entity_type(driver):
|
||||
"""Test excluding the default 'Entity' type while keeping custom types."""
|
||||
graphiti = Graphiti(NEO4J_URI, NEO4J_USER, NEO4J_PASSWORD)
|
||||
graphiti = Graphiti(graph_driver=get_driver(driver))
|
||||
|
||||
try:
|
||||
await graphiti.build_indices_and_constraints()
|
||||
|
|
@ -118,9 +115,14 @@ async def test_exclude_default_entity_type():
|
|||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_exclude_specific_custom_types():
|
||||
@pytest.mark.parametrize(
|
||||
'driver',
|
||||
drivers,
|
||||
ids=drivers,
|
||||
)
|
||||
async def test_exclude_specific_custom_types(driver):
|
||||
"""Test excluding specific custom entity types while keeping others."""
|
||||
graphiti = Graphiti(NEO4J_URI, NEO4J_USER, NEO4J_PASSWORD)
|
||||
graphiti = Graphiti(graph_driver=get_driver(driver))
|
||||
|
||||
try:
|
||||
await graphiti.build_indices_and_constraints()
|
||||
|
|
@ -177,9 +179,14 @@ async def test_exclude_specific_custom_types():
|
|||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_exclude_all_types():
|
||||
@pytest.mark.parametrize(
|
||||
'driver',
|
||||
drivers,
|
||||
ids=drivers,
|
||||
)
|
||||
async def test_exclude_all_types(driver):
|
||||
"""Test excluding all entity types (edge case)."""
|
||||
graphiti = Graphiti(NEO4J_URI, NEO4J_USER, NEO4J_PASSWORD)
|
||||
graphiti = Graphiti(graph_driver=get_driver(driver))
|
||||
|
||||
try:
|
||||
await graphiti.build_indices_and_constraints()
|
||||
|
|
@ -221,9 +228,14 @@ async def test_exclude_all_types():
|
|||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_exclude_no_types():
|
||||
@pytest.mark.parametrize(
|
||||
'driver',
|
||||
drivers,
|
||||
ids=drivers,
|
||||
)
|
||||
async def test_exclude_no_types(driver):
|
||||
"""Test normal behavior when no types are excluded (baseline test)."""
|
||||
graphiti = Graphiti(NEO4J_URI, NEO4J_USER, NEO4J_PASSWORD)
|
||||
graphiti = Graphiti(graph_driver=get_driver(driver))
|
||||
|
||||
try:
|
||||
await graphiti.build_indices_and_constraints()
|
||||
|
|
@ -299,9 +311,14 @@ def test_validation_invalid_excluded_types():
|
|||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_excluded_types_parameter_validation_in_add_episode():
|
||||
@pytest.mark.parametrize(
|
||||
'driver',
|
||||
drivers,
|
||||
ids=drivers,
|
||||
)
|
||||
async def test_excluded_types_parameter_validation_in_add_episode(driver):
|
||||
"""Test that add_episode validates excluded_entity_types parameter."""
|
||||
graphiti = Graphiti(NEO4J_URI, NEO4J_USER, NEO4J_PASSWORD)
|
||||
graphiti = Graphiti(graph_driver=get_driver(driver))
|
||||
|
||||
try:
|
||||
entity_types = {
|
||||
|
|
|
|||
|
|
@ -1,164 +0,0 @@
|
|||
"""
|
||||
Copyright 2024, Zep Software, Inc.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import unittest
|
||||
from datetime import datetime, timezone
|
||||
|
||||
import pytest
|
||||
from dotenv import load_dotenv
|
||||
|
||||
from graphiti_core.edges import EntityEdge, EpisodicEdge
|
||||
from graphiti_core.graphiti import Graphiti
|
||||
from graphiti_core.helpers import semaphore_gather
|
||||
from graphiti_core.nodes import EntityNode, EpisodicNode
|
||||
from graphiti_core.search.search_helpers import search_results_to_context_string
|
||||
|
||||
try:
|
||||
from graphiti_core.driver.falkordb_driver import FalkorDriver
|
||||
|
||||
HAS_FALKORDB = True
|
||||
except ImportError:
|
||||
FalkorDriver = None
|
||||
HAS_FALKORDB = False
|
||||
|
||||
pytestmark = pytest.mark.integration
|
||||
|
||||
pytest_plugins = ('pytest_asyncio',)
|
||||
|
||||
load_dotenv()
|
||||
|
||||
FALKORDB_HOST = os.getenv('FALKORDB_HOST', 'localhost')
|
||||
FALKORDB_PORT = os.getenv('FALKORDB_PORT', '6379')
|
||||
FALKORDB_USER = os.getenv('FALKORDB_USER', None)
|
||||
FALKORDB_PASSWORD = os.getenv('FALKORDB_PASSWORD', None)
|
||||
|
||||
|
||||
def setup_logging():
|
||||
# Create a logger
|
||||
logger = logging.getLogger()
|
||||
logger.setLevel(logging.INFO) # Set the logging level to INFO
|
||||
|
||||
# Create console handler and set level to INFO
|
||||
console_handler = logging.StreamHandler(sys.stdout)
|
||||
console_handler.setLevel(logging.INFO)
|
||||
|
||||
# Create formatter
|
||||
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
|
||||
|
||||
# Add formatter to console handler
|
||||
console_handler.setFormatter(formatter)
|
||||
|
||||
# Add console handler to logger
|
||||
logger.addHandler(console_handler)
|
||||
|
||||
return logger
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@unittest.skipIf(not HAS_FALKORDB, 'FalkorDB is not installed')
|
||||
async def test_graphiti_falkordb_init():
|
||||
logger = setup_logging()
|
||||
|
||||
falkor_driver = FalkorDriver(
|
||||
host=FALKORDB_HOST, port=FALKORDB_PORT, username=FALKORDB_USER, password=FALKORDB_PASSWORD
|
||||
)
|
||||
|
||||
graphiti = Graphiti(graph_driver=falkor_driver)
|
||||
|
||||
results = await graphiti.search_(query='Who is the user?')
|
||||
|
||||
pretty_results = search_results_to_context_string(results)
|
||||
|
||||
logger.info(pretty_results)
|
||||
|
||||
await graphiti.close()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@unittest.skipIf(not HAS_FALKORDB, 'FalkorDB is not installed')
|
||||
async def test_graph_falkordb_integration():
|
||||
falkor_driver = FalkorDriver(
|
||||
host=FALKORDB_HOST, port=FALKORDB_PORT, username=FALKORDB_USER, password=FALKORDB_PASSWORD
|
||||
)
|
||||
|
||||
client = Graphiti(graph_driver=falkor_driver)
|
||||
embedder = client.embedder
|
||||
driver = client.driver
|
||||
|
||||
now = datetime.now(timezone.utc)
|
||||
episode = EpisodicNode(
|
||||
name='test_episode',
|
||||
labels=[],
|
||||
created_at=now,
|
||||
valid_at=now,
|
||||
source='message',
|
||||
source_description='conversation message',
|
||||
content='Alice likes Bob',
|
||||
entity_edges=[],
|
||||
)
|
||||
|
||||
alice_node = EntityNode(
|
||||
name='Alice',
|
||||
labels=[],
|
||||
created_at=now,
|
||||
summary='Alice summary',
|
||||
)
|
||||
|
||||
bob_node = EntityNode(name='Bob', labels=[], created_at=now, summary='Bob summary')
|
||||
|
||||
episodic_edge_1 = EpisodicEdge(
|
||||
source_node_uuid=episode.uuid, target_node_uuid=alice_node.uuid, created_at=now
|
||||
)
|
||||
|
||||
episodic_edge_2 = EpisodicEdge(
|
||||
source_node_uuid=episode.uuid, target_node_uuid=bob_node.uuid, created_at=now
|
||||
)
|
||||
|
||||
entity_edge = EntityEdge(
|
||||
source_node_uuid=alice_node.uuid,
|
||||
target_node_uuid=bob_node.uuid,
|
||||
created_at=now,
|
||||
name='likes',
|
||||
fact='Alice likes Bob',
|
||||
episodes=[],
|
||||
expired_at=now,
|
||||
valid_at=now,
|
||||
invalid_at=now,
|
||||
)
|
||||
|
||||
await entity_edge.generate_embedding(embedder)
|
||||
|
||||
nodes = [episode, alice_node, bob_node]
|
||||
edges = [episodic_edge_1, episodic_edge_2, entity_edge]
|
||||
|
||||
# test save
|
||||
await semaphore_gather(*[node.save(driver) for node in nodes])
|
||||
await semaphore_gather(*[edge.save(driver) for edge in edges])
|
||||
|
||||
# test get
|
||||
assert await EpisodicNode.get_by_uuid(driver, episode.uuid) is not None
|
||||
assert await EntityNode.get_by_uuid(driver, alice_node.uuid) is not None
|
||||
assert await EpisodicEdge.get_by_uuid(driver, episodic_edge_1.uuid) is not None
|
||||
assert await EntityEdge.get_by_uuid(driver, entity_edge.uuid) is not None
|
||||
|
||||
# test delete
|
||||
await semaphore_gather(*[node.delete(driver) for node in nodes])
|
||||
await semaphore_gather(*[edge.delete(driver) for edge in edges])
|
||||
|
||||
await client.close()
|
||||
|
|
@ -15,31 +15,19 @@ limitations under the License.
|
|||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
from datetime import datetime, timezone
|
||||
|
||||
import pytest
|
||||
from dotenv import load_dotenv
|
||||
|
||||
from graphiti_core.edges import EntityEdge, EpisodicEdge
|
||||
from graphiti_core.graphiti import Graphiti
|
||||
from graphiti_core.helpers import semaphore_gather
|
||||
from graphiti_core.nodes import EntityNode, EpisodicNode
|
||||
from graphiti_core.search.search_filters import ComparisonOperator, DateFilter, SearchFilters
|
||||
from graphiti_core.search.search_helpers import search_results_to_context_string
|
||||
from graphiti_core.utils.datetime_utils import utc_now
|
||||
from tests.helpers_test import drivers, get_driver
|
||||
|
||||
pytestmark = pytest.mark.integration
|
||||
|
||||
pytest_plugins = ('pytest_asyncio',)
|
||||
|
||||
load_dotenv()
|
||||
|
||||
NEO4J_URI = os.getenv('NEO4J_URI')
|
||||
NEO4j_USER = os.getenv('NEO4J_USER')
|
||||
NEO4j_PASSWORD = os.getenv('NEO4J_PASSWORD')
|
||||
|
||||
|
||||
def setup_logging():
|
||||
# Create a logger
|
||||
|
|
@ -63,9 +51,18 @@ def setup_logging():
|
|||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_graphiti_init():
|
||||
@pytest.mark.parametrize(
|
||||
'driver',
|
||||
drivers,
|
||||
ids=drivers,
|
||||
)
|
||||
async def test_graphiti_init(driver):
|
||||
logger = setup_logging()
|
||||
graphiti = Graphiti(NEO4J_URI, NEO4j_USER, NEO4j_PASSWORD)
|
||||
driver = get_driver(driver)
|
||||
graphiti = Graphiti(graph_driver=driver)
|
||||
|
||||
await graphiti.build_indices_and_constraints()
|
||||
|
||||
search_filter = SearchFilters(
|
||||
created_at=[[DateFilter(date=utc_now(), comparison_operator=ComparisonOperator.less_than)]]
|
||||
)
|
||||
|
|
@ -76,74 +73,6 @@ async def test_graphiti_init():
|
|||
)
|
||||
|
||||
pretty_results = search_results_to_context_string(results)
|
||||
|
||||
logger.info(pretty_results)
|
||||
|
||||
await graphiti.close()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_graph_integration():
|
||||
client = Graphiti(NEO4J_URI, NEO4j_USER, NEO4j_PASSWORD)
|
||||
embedder = client.embedder
|
||||
driver = client.driver
|
||||
|
||||
now = datetime.now(timezone.utc)
|
||||
episode = EpisodicNode(
|
||||
name='test_episode',
|
||||
labels=[],
|
||||
created_at=now,
|
||||
valid_at=now,
|
||||
source='message',
|
||||
source_description='conversation message',
|
||||
content='Alice likes Bob',
|
||||
entity_edges=[],
|
||||
)
|
||||
|
||||
alice_node = EntityNode(
|
||||
name='Alice',
|
||||
labels=[],
|
||||
created_at=now,
|
||||
summary='Alice summary',
|
||||
)
|
||||
|
||||
bob_node = EntityNode(name='Bob', labels=[], created_at=now, summary='Bob summary')
|
||||
|
||||
episodic_edge_1 = EpisodicEdge(
|
||||
source_node_uuid=episode.uuid, target_node_uuid=alice_node.uuid, created_at=now
|
||||
)
|
||||
|
||||
episodic_edge_2 = EpisodicEdge(
|
||||
source_node_uuid=episode.uuid, target_node_uuid=bob_node.uuid, created_at=now
|
||||
)
|
||||
|
||||
entity_edge = EntityEdge(
|
||||
source_node_uuid=alice_node.uuid,
|
||||
target_node_uuid=bob_node.uuid,
|
||||
created_at=now,
|
||||
name='likes',
|
||||
fact='Alice likes Bob',
|
||||
episodes=[],
|
||||
expired_at=now,
|
||||
valid_at=now,
|
||||
invalid_at=now,
|
||||
)
|
||||
|
||||
await entity_edge.generate_embedding(embedder)
|
||||
|
||||
nodes = [episode, alice_node, bob_node]
|
||||
edges = [episodic_edge_1, episodic_edge_2, entity_edge]
|
||||
|
||||
# test save
|
||||
await semaphore_gather(*[node.save(driver) for node in nodes])
|
||||
await semaphore_gather(*[edge.save(driver) for edge in edges])
|
||||
|
||||
# test get
|
||||
assert await EpisodicNode.get_by_uuid(driver, episode.uuid) is not None
|
||||
assert await EntityNode.get_by_uuid(driver, alice_node.uuid) is not None
|
||||
assert await EpisodicEdge.get_by_uuid(driver, episodic_edge_1.uuid) is not None
|
||||
assert await EntityEdge.get_by_uuid(driver, entity_edge.uuid) is not None
|
||||
|
||||
# test delete
|
||||
await semaphore_gather(*[node.delete(driver) for node in nodes])
|
||||
await semaphore_gather(*[edge.delete(driver) for edge in edges])
|
||||
|
|
|
|||
|
|
@ -1,139 +0,0 @@
|
|||
"""
|
||||
Copyright 2024, Zep Software, Inc.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
"""
|
||||
|
||||
import os
|
||||
import unittest
|
||||
from datetime import datetime, timezone
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from graphiti_core.nodes import (
|
||||
CommunityNode,
|
||||
EntityNode,
|
||||
EpisodeType,
|
||||
EpisodicNode,
|
||||
)
|
||||
|
||||
FALKORDB_HOST = os.getenv('FALKORDB_HOST', 'localhost')
|
||||
FALKORDB_PORT = os.getenv('FALKORDB_PORT', '6379')
|
||||
FALKORDB_USER = os.getenv('FALKORDB_USER', None)
|
||||
FALKORDB_PASSWORD = os.getenv('FALKORDB_PASSWORD', None)
|
||||
|
||||
try:
|
||||
from graphiti_core.driver.falkordb_driver import FalkorDriver
|
||||
|
||||
HAS_FALKORDB = True
|
||||
except ImportError:
|
||||
FalkorDriver = None
|
||||
HAS_FALKORDB = False
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_entity_node():
|
||||
return EntityNode(
|
||||
uuid=str(uuid4()),
|
||||
name='Test Entity',
|
||||
group_id='test_group',
|
||||
labels=['Entity'],
|
||||
name_embedding=[0.5] * 1024,
|
||||
summary='Entity Summary',
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_episodic_node():
|
||||
return EpisodicNode(
|
||||
uuid=str(uuid4()),
|
||||
name='Episode 1',
|
||||
group_id='test_group',
|
||||
source=EpisodeType.text,
|
||||
source_description='Test source',
|
||||
content='Some content here',
|
||||
valid_at=datetime.now(timezone.utc),
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_community_node():
|
||||
return CommunityNode(
|
||||
uuid=str(uuid4()),
|
||||
name='Community A',
|
||||
name_embedding=[0.5] * 1024,
|
||||
group_id='test_group',
|
||||
summary='Community summary',
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.integration
|
||||
@unittest.skipIf(not HAS_FALKORDB, 'FalkorDB is not installed')
|
||||
async def test_entity_node_save_get_and_delete(sample_entity_node):
|
||||
falkor_driver = FalkorDriver(
|
||||
host=FALKORDB_HOST, port=FALKORDB_PORT, username=FALKORDB_USER, password=FALKORDB_PASSWORD
|
||||
)
|
||||
|
||||
await sample_entity_node.save(falkor_driver)
|
||||
|
||||
retrieved = await EntityNode.get_by_uuid(falkor_driver, sample_entity_node.uuid)
|
||||
assert retrieved.uuid == sample_entity_node.uuid
|
||||
assert retrieved.name == 'Test Entity'
|
||||
assert retrieved.group_id == 'test_group'
|
||||
|
||||
await sample_entity_node.delete(falkor_driver)
|
||||
await falkor_driver.close()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.integration
|
||||
@unittest.skipIf(not HAS_FALKORDB, 'FalkorDB is not installed')
|
||||
async def test_community_node_save_get_and_delete(sample_community_node):
|
||||
falkor_driver = FalkorDriver(
|
||||
host=FALKORDB_HOST, port=FALKORDB_PORT, username=FALKORDB_USER, password=FALKORDB_PASSWORD
|
||||
)
|
||||
|
||||
await sample_community_node.save(falkor_driver)
|
||||
|
||||
retrieved = await CommunityNode.get_by_uuid(falkor_driver, sample_community_node.uuid)
|
||||
assert retrieved.uuid == sample_community_node.uuid
|
||||
assert retrieved.name == 'Community A'
|
||||
assert retrieved.group_id == 'test_group'
|
||||
assert retrieved.summary == 'Community summary'
|
||||
|
||||
await sample_community_node.delete(falkor_driver)
|
||||
await falkor_driver.close()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.integration
|
||||
@unittest.skipIf(not HAS_FALKORDB, 'FalkorDB is not installed')
|
||||
async def test_episodic_node_save_get_and_delete(sample_episodic_node):
|
||||
falkor_driver = FalkorDriver(
|
||||
host=FALKORDB_HOST, port=FALKORDB_PORT, username=FALKORDB_USER, password=FALKORDB_PASSWORD
|
||||
)
|
||||
|
||||
await sample_episodic_node.save(falkor_driver)
|
||||
|
||||
retrieved = await EpisodicNode.get_by_uuid(falkor_driver, sample_episodic_node.uuid)
|
||||
assert retrieved.uuid == sample_episodic_node.uuid
|
||||
assert retrieved.name == 'Episode 1'
|
||||
assert retrieved.group_id == 'test_group'
|
||||
assert retrieved.source == EpisodeType.text
|
||||
assert retrieved.source_description == 'Test source'
|
||||
assert retrieved.content == 'Some content here'
|
||||
|
||||
await sample_episodic_node.delete(falkor_driver)
|
||||
await falkor_driver.close()
|
||||
|
|

@@ -14,23 +14,22 @@ See the License for the specific language governing permissions and
limitations under the License.
"""

import os
from datetime import datetime, timezone
from datetime import datetime
from uuid import uuid4

import numpy as np
import pytest
from neo4j import AsyncGraphDatabase

from graphiti_core.driver.driver import GraphDriver
from graphiti_core.nodes import (
    CommunityNode,
    EntityNode,
    EpisodeType,
    EpisodicNode,
)
from tests.helpers_test import drivers, get_driver

NEO4J_URI = os.getenv('NEO4J_URI', 'bolt://localhost:7687')
NEO4J_USER = os.getenv('NEO4J_USER', 'neo4j')
NEO4J_PASSWORD = os.getenv('NEO4J_PASSWORD', 'test')
group_id = f'test_group_{str(uuid4())}'


@pytest.fixture
@@ -38,8 +37,8 @@ def sample_entity_node():
    return EntityNode(
        uuid=str(uuid4()),
        name='Test Entity',
        group_id='test_group',
        labels=['Entity'],
        group_id=group_id,
        labels=[],
        name_embedding=[0.5] * 1024,
        summary='Entity Summary',
    )
@@ -50,11 +49,11 @@ def sample_episodic_node():
    return EpisodicNode(
        uuid=str(uuid4()),
        name='Episode 1',
        group_id='test_group',
        group_id=group_id,
        source=EpisodeType.text,
        source_description='Test source',
        content='Some content here',
        valid_at=datetime.now(timezone.utc),
        valid_at=datetime.now(),
    )

@@ -64,59 +63,181 @@ def sample_community_node():
        uuid=str(uuid4()),
        name='Community A',
        name_embedding=[0.5] * 1024,
        group_id='test_group',
        group_id=group_id,
        summary='Community summary',
    )


@pytest.mark.asyncio
@pytest.mark.integration
async def test_entity_node_save_get_and_delete(sample_entity_node):
    neo4j_driver = AsyncGraphDatabase.driver(NEO4J_URI, auth=(NEO4J_USER, NEO4J_PASSWORD))
    await sample_entity_node.save(neo4j_driver)
    retrieved = await EntityNode.get_by_uuid(neo4j_driver, sample_entity_node.uuid)
@pytest.mark.parametrize(
    'driver',
    drivers,
    ids=drivers,
)
async def test_entity_node(sample_entity_node, driver):
    driver = get_driver(driver)
    uuid = sample_entity_node.uuid

    # Create node
    node_count = await get_node_count(driver, uuid)
    assert node_count == 0
    await sample_entity_node.save(driver)
    node_count = await get_node_count(driver, uuid)
    assert node_count == 1

    retrieved = await EntityNode.get_by_uuid(driver, sample_entity_node.uuid)
    assert retrieved.uuid == sample_entity_node.uuid
    assert retrieved.name == 'Test Entity'
    assert retrieved.group_id == 'test_group'
    assert retrieved.group_id == group_id

    await sample_entity_node.delete(neo4j_driver)
    retrieved = await EntityNode.get_by_uuids(driver, [sample_entity_node.uuid])
    assert retrieved[0].uuid == sample_entity_node.uuid
    assert retrieved[0].name == 'Test Entity'
    assert retrieved[0].group_id == group_id

    await neo4j_driver.close()
    retrieved = await EntityNode.get_by_group_ids(driver, [group_id], limit=2)
    assert len(retrieved) == 1
    assert retrieved[0].uuid == sample_entity_node.uuid
    assert retrieved[0].name == 'Test Entity'
    assert retrieved[0].group_id == group_id

    await sample_entity_node.load_name_embedding(driver)
    assert np.allclose(sample_entity_node.name_embedding, [0.5] * 1024)

    # Delete node by uuid
    await sample_entity_node.delete(driver)
    node_count = await get_node_count(driver, uuid)
    assert node_count == 0

    # Delete node by group id
    await sample_entity_node.save(driver)
    node_count = await get_node_count(driver, uuid)
    assert node_count == 1
    await sample_entity_node.delete_by_group_id(driver, group_id)
    node_count = await get_node_count(driver, uuid)
    assert node_count == 0

    await driver.close()


@pytest.mark.asyncio
@pytest.mark.integration
async def test_community_node_save_get_and_delete(sample_community_node):
    neo4j_driver = AsyncGraphDatabase.driver(NEO4J_URI, auth=(NEO4J_USER, NEO4J_PASSWORD))
@pytest.mark.parametrize(
    'driver',
    drivers,
    ids=drivers,
)
async def test_community_node(sample_community_node, driver):
    driver = get_driver(driver)
    uuid = sample_community_node.uuid

    await sample_community_node.save(neo4j_driver)
    # Create node
    node_count = await get_node_count(driver, uuid)
    assert node_count == 0
    await sample_community_node.save(driver)
    node_count = await get_node_count(driver, uuid)
    assert node_count == 1

    retrieved = await CommunityNode.get_by_uuid(neo4j_driver, sample_community_node.uuid)
    retrieved = await CommunityNode.get_by_uuid(driver, sample_community_node.uuid)
    assert retrieved.uuid == sample_community_node.uuid
    assert retrieved.name == 'Community A'
    assert retrieved.group_id == 'test_group'
    assert retrieved.group_id == group_id
    assert retrieved.summary == 'Community summary'

    await sample_community_node.delete(neo4j_driver)
    retrieved = await CommunityNode.get_by_uuids(driver, [sample_community_node.uuid])
    assert retrieved[0].uuid == sample_community_node.uuid
    assert retrieved[0].name == 'Community A'
    assert retrieved[0].group_id == group_id
    assert retrieved[0].summary == 'Community summary'

    await neo4j_driver.close()
    retrieved = await CommunityNode.get_by_group_ids(driver, [group_id], limit=2)
    assert len(retrieved) == 1
    assert retrieved[0].uuid == sample_community_node.uuid
    assert retrieved[0].name == 'Community A'
    assert retrieved[0].group_id == group_id

    # Delete node by uuid
    await sample_community_node.delete(driver)
    node_count = await get_node_count(driver, uuid)
    assert node_count == 0

    # Delete node by group id
    await sample_community_node.save(driver)
    node_count = await get_node_count(driver, uuid)
    assert node_count == 1
    await sample_community_node.delete_by_group_id(driver, group_id)
    node_count = await get_node_count(driver, uuid)
    assert node_count == 0

    await driver.close()


@pytest.mark.asyncio
@pytest.mark.integration
async def test_episodic_node_save_get_and_delete(sample_episodic_node):
    neo4j_driver = AsyncGraphDatabase.driver(NEO4J_URI, auth=(NEO4J_USER, NEO4J_PASSWORD))
@pytest.mark.parametrize(
    'driver',
    drivers,
    ids=drivers,
)
async def test_episodic_node(sample_episodic_node, driver):
    driver = get_driver(driver)
    uuid = sample_episodic_node.uuid

    await sample_episodic_node.save(neo4j_driver)
    # Create node
    node_count = await get_node_count(driver, uuid)
    assert node_count == 0
    await sample_episodic_node.save(driver)
    node_count = await get_node_count(driver, uuid)
    assert node_count == 1

    retrieved = await EpisodicNode.get_by_uuid(neo4j_driver, sample_episodic_node.uuid)
    retrieved = await EpisodicNode.get_by_uuid(driver, sample_episodic_node.uuid)
    assert retrieved.uuid == sample_episodic_node.uuid
    assert retrieved.name == 'Episode 1'
    assert retrieved.group_id == 'test_group'
    assert retrieved.group_id == group_id
    assert retrieved.source == EpisodeType.text
    assert retrieved.source_description == 'Test source'
    assert retrieved.content == 'Some content here'
    assert retrieved.valid_at == sample_episodic_node.valid_at

    await sample_episodic_node.delete(neo4j_driver)
    retrieved = await EpisodicNode.get_by_uuids(driver, [sample_episodic_node.uuid])
    assert retrieved[0].uuid == sample_episodic_node.uuid
    assert retrieved[0].name == 'Episode 1'
    assert retrieved[0].group_id == group_id
    assert retrieved[0].source == EpisodeType.text
    assert retrieved[0].source_description == 'Test source'
    assert retrieved[0].content == 'Some content here'
    assert retrieved[0].valid_at == sample_episodic_node.valid_at

    await neo4j_driver.close()
    retrieved = await EpisodicNode.get_by_group_ids(driver, [group_id], limit=2)
    assert len(retrieved) == 1
    assert retrieved[0].uuid == sample_episodic_node.uuid
    assert retrieved[0].name == 'Episode 1'
    assert retrieved[0].group_id == group_id
    assert retrieved[0].source == EpisodeType.text
    assert retrieved[0].source_description == 'Test source'
    assert retrieved[0].content == 'Some content here'
    assert retrieved[0].valid_at == sample_episodic_node.valid_at

    # Delete node by uuid
    await sample_episodic_node.delete(driver)
    node_count = await get_node_count(driver, uuid)
    assert node_count == 0

    # Delete node by group id
    await sample_episodic_node.save(driver)
    node_count = await get_node_count(driver, uuid)
    assert node_count == 1
    await sample_episodic_node.delete_by_group_id(driver, group_id)
    node_count = await get_node_count(driver, uuid)
    assert node_count == 0

    await driver.close()


async def get_node_count(driver: GraphDriver, uuid: str):
    result, _, _ = await driver.execute_query(
        """
        MATCH (n {uuid: $uuid})
        RETURN COUNT(n) as count
        """,
        uuid=uuid,
    )
    return int(result[0]['count'])
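
Note: the rewritten tests above import `drivers` and `get_driver` from `tests.helpers_test`, a module that is not included in this part of the diff. As a minimal sketch only, assuming a `Neo4jDriver` class in `graphiti_core.driver.neo4j_driver` with a `(uri, user, password)` constructor alongside the `FalkorDriver` shown earlier (both assumptions, not confirmed here), that helper could look roughly like this:

# Hypothetical sketch of tests/helpers_test.py; names and constructor
# signatures other than FalkorDriver are assumptions, not taken from this diff.
import os

from graphiti_core.driver.driver import GraphDriver
from graphiti_core.driver.falkordb_driver import FalkorDriver
from graphiti_core.driver.neo4j_driver import Neo4jDriver  # assumed module path

# Ids consumed by @pytest.mark.parametrize('driver', drivers, ids=drivers).
drivers = ['neo4j', 'falkordb']


def get_driver(provider: str) -> GraphDriver:
    # Build a fresh driver for the requested backend from environment variables.
    if provider == 'falkordb':
        return FalkorDriver(
            host=os.getenv('FALKORDB_HOST', 'localhost'),
            port=os.getenv('FALKORDB_PORT', '6379'),
        )
    return Neo4jDriver(  # assumed constructor signature
        uri=os.getenv('NEO4J_URI', 'bolt://localhost:7687'),
        user=os.getenv('NEO4J_USER', 'neo4j'),
        password=os.getenv('NEO4J_PASSWORD', 'test'),
    )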