chore/prepare kuzu integration (#762)
* Prepare code * Fix tests * As -> AS, remove trailing spaces * Enable more tests for FalkorDB * Fix more cypher queries * Return all created nodes and edges * Add Neo4j service to unit tests workflow - Introduced Neo4j as a service in the GitHub Actions workflow for unit tests. - Configured Neo4j with appropriate ports, authentication, and health checks. - Updated test steps to include waiting for Neo4j and running integration tests against it. - Set environment variables for Neo4j connection in both non-integration and integration test steps. * Update Neo4j authentication in unit tests workflow - Changed Neo4j authentication password from 'test' to 'testpass' in the GitHub Actions workflow. - Updated health check command to reflect the new password. - Ensured consistency across all test steps that utilize Neo4j credentials. * fix health check * Fix Neo4j integration tests in CI workflow Remove reference to non-existent test_neo4j_driver.py file from test command. Integration tests now run via parametrized tests using the drivers list. 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude <noreply@anthropic.com> * Add OPENAI_API_KEY to Neo4j integration tests Neo4j integration tests require OpenAI API access for LLM functionality. Add the secret environment variable to enable these tests to run properly. 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude <noreply@anthropic.com> * Fix Neo4j Cypher syntax error in BFS search queries Replace parameter substitution in relationship pattern ranges (*1..$depth) with direct string interpolation (*1..{bfs_max_depth}). Neo4j doesn't allow parameter maps in MATCH pattern ranges - they must be literal values. Fixed in both node_bfs_search and edge_bfs_search functions. 
🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude <noreply@anthropic.com> * Fix variable name mismatch in edge_bfs_search query Change relationship variable from 'r' to 'e' to match ENTITY_EDGE_RETURN constant expectations. The ENTITY_EDGE_RETURN constant references variable 'e' for relationships, but the query was using 'r', causing "Variable e not defined" errors. 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude <noreply@anthropic.com> * Isolate database tests in CI workflow - FalkorDB tests: Add DISABLE_NEO4J=1 and remove Neo4j env vars - Neo4j tests: Keep current setup without DISABLE_NEO4J flag This ensures proper test isolation where each test suite only runs against its intended database backend. 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude <noreply@anthropic.com> --------- Co-authored-by: Siddhartha Sahu <sid@kuzudb.com> Co-authored-by: Claude <noreply@anthropic.com>
This commit is contained in:
parent
9ceeb54186
commit
dcc9da3f68
25 changed files with 1339 additions and 1068 deletions
27
.github/workflows/unit_tests.yml
vendored
27
.github/workflows/unit_tests.yml
vendored
|
|
@ -20,6 +20,15 @@ jobs:
|
||||||
ports:
|
ports:
|
||||||
- 6379:6379
|
- 6379:6379
|
||||||
options: --health-cmd "redis-cli ping" --health-interval 10s --health-timeout 5s --health-retries 5
|
options: --health-cmd "redis-cli ping" --health-interval 10s --health-timeout 5s --health-retries 5
|
||||||
|
neo4j:
|
||||||
|
image: neo4j:5.26-community
|
||||||
|
ports:
|
||||||
|
- 7687:7687
|
||||||
|
- 7474:7474
|
||||||
|
env:
|
||||||
|
NEO4J_AUTH: neo4j/testpass
|
||||||
|
NEO4J_PLUGINS: '["apoc"]'
|
||||||
|
options: --health-cmd "cypher-shell -u neo4j -p testpass 'RETURN 1'" --health-interval 10s --health-timeout 5s --health-retries 10
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@v4
|
- uses: actions/checkout@v4
|
||||||
- name: Set up Python
|
- name: Set up Python
|
||||||
|
|
@ -37,15 +46,33 @@ jobs:
|
||||||
- name: Run non-integration tests
|
- name: Run non-integration tests
|
||||||
env:
|
env:
|
||||||
PYTHONPATH: ${{ github.workspace }}
|
PYTHONPATH: ${{ github.workspace }}
|
||||||
|
NEO4J_URI: bolt://localhost:7687
|
||||||
|
NEO4J_USER: neo4j
|
||||||
|
NEO4J_PASSWORD: testpass
|
||||||
run: |
|
run: |
|
||||||
uv run pytest -m "not integration"
|
uv run pytest -m "not integration"
|
||||||
- name: Wait for FalkorDB
|
- name: Wait for FalkorDB
|
||||||
run: |
|
run: |
|
||||||
timeout 60 bash -c 'until redis-cli -h localhost -p 6379 ping; do sleep 1; done'
|
timeout 60 bash -c 'until redis-cli -h localhost -p 6379 ping; do sleep 1; done'
|
||||||
|
- name: Wait for Neo4j
|
||||||
|
run: |
|
||||||
|
timeout 60 bash -c 'until wget -O /dev/null http://localhost:7474 >/dev/null 2>&1; do sleep 1; done'
|
||||||
- name: Run FalkorDB integration tests
|
- name: Run FalkorDB integration tests
|
||||||
env:
|
env:
|
||||||
PYTHONPATH: ${{ github.workspace }}
|
PYTHONPATH: ${{ github.workspace }}
|
||||||
FALKORDB_HOST: localhost
|
FALKORDB_HOST: localhost
|
||||||
FALKORDB_PORT: 6379
|
FALKORDB_PORT: 6379
|
||||||
|
DISABLE_NEO4J: 1
|
||||||
run: |
|
run: |
|
||||||
uv run pytest tests/driver/test_falkordb_driver.py
|
uv run pytest tests/driver/test_falkordb_driver.py
|
||||||
|
- name: Run Neo4j integration tests
|
||||||
|
env:
|
||||||
|
PYTHONPATH: ${{ github.workspace }}
|
||||||
|
NEO4J_URI: bolt://localhost:7687
|
||||||
|
NEO4J_USER: neo4j
|
||||||
|
NEO4J_PASSWORD: testpass
|
||||||
|
FALKORDB_HOST: localhost
|
||||||
|
FALKORDB_PORT: 6379
|
||||||
|
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||||
|
run: |
|
||||||
|
uv run pytest tests/test_*_int.py -k "neo4j"
|
||||||
|
|
|
||||||
|
|
@ -18,11 +18,17 @@ import copy
|
||||||
import logging
|
import logging
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from collections.abc import Coroutine
|
from collections.abc import Coroutine
|
||||||
|
from enum import Enum
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class GraphProvider(Enum):
|
||||||
|
NEO4J = 'neo4j'
|
||||||
|
FALKORDB = 'falkordb'
|
||||||
|
|
||||||
|
|
||||||
class GraphDriverSession(ABC):
|
class GraphDriverSession(ABC):
|
||||||
async def __aenter__(self):
|
async def __aenter__(self):
|
||||||
return self
|
return self
|
||||||
|
|
@ -46,7 +52,7 @@ class GraphDriverSession(ABC):
|
||||||
|
|
||||||
|
|
||||||
class GraphDriver(ABC):
|
class GraphDriver(ABC):
|
||||||
provider: str
|
provider: GraphProvider
|
||||||
fulltext_syntax: str = (
|
fulltext_syntax: str = (
|
||||||
'' # Neo4j (default) syntax does not require a prefix for fulltext queries
|
'' # Neo4j (default) syntax does not require a prefix for fulltext queries
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -32,7 +32,7 @@ else:
|
||||||
'Install it with: pip install graphiti-core[falkordb]'
|
'Install it with: pip install graphiti-core[falkordb]'
|
||||||
) from None
|
) from None
|
||||||
|
|
||||||
from graphiti_core.driver.driver import GraphDriver, GraphDriverSession
|
from graphiti_core.driver.driver import GraphDriver, GraphDriverSession, GraphProvider
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
@ -71,7 +71,7 @@ class FalkorDriverSession(GraphDriverSession):
|
||||||
|
|
||||||
|
|
||||||
class FalkorDriver(GraphDriver):
|
class FalkorDriver(GraphDriver):
|
||||||
provider: str = 'falkordb'
|
provider = GraphProvider.FALKORDB
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
|
@ -119,7 +119,7 @@ class FalkorDriver(GraphDriver):
|
||||||
# check if index already exists
|
# check if index already exists
|
||||||
logger.info(f'Index already exists: {e}')
|
logger.info(f'Index already exists: {e}')
|
||||||
return None
|
return None
|
||||||
logger.error(f'Error executing FalkorDB query: {e}')
|
logger.error(f'Error executing FalkorDB query: {e}\n{cypher_query_}\n{params}')
|
||||||
raise
|
raise
|
||||||
|
|
||||||
# Convert the result header to a list of strings
|
# Convert the result header to a list of strings
|
||||||
|
|
|
||||||
|
|
@ -21,13 +21,13 @@ from typing import Any
|
||||||
from neo4j import AsyncGraphDatabase, EagerResult
|
from neo4j import AsyncGraphDatabase, EagerResult
|
||||||
from typing_extensions import LiteralString
|
from typing_extensions import LiteralString
|
||||||
|
|
||||||
from graphiti_core.driver.driver import GraphDriver, GraphDriverSession
|
from graphiti_core.driver.driver import GraphDriver, GraphDriverSession, GraphProvider
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class Neo4jDriver(GraphDriver):
|
class Neo4jDriver(GraphDriver):
|
||||||
provider: str = 'neo4j'
|
provider = GraphProvider.NEO4J
|
||||||
|
|
||||||
def __init__(self, uri: str, user: str | None, password: str | None, database: str = 'neo4j'):
|
def __init__(self, uri: str, user: str | None, password: str | None, database: str = 'neo4j'):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
@ -45,7 +45,11 @@ class Neo4jDriver(GraphDriver):
|
||||||
params = {}
|
params = {}
|
||||||
params.setdefault('database_', self._database)
|
params.setdefault('database_', self._database)
|
||||||
|
|
||||||
result = await self.client.execute_query(cypher_query_, parameters_=params, **kwargs)
|
try:
|
||||||
|
result = await self.client.execute_query(cypher_query_, parameters_=params, **kwargs)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f'Error executing Neo4j query: {e}\n{cypher_query_}\n{params}')
|
||||||
|
raise
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -29,29 +29,17 @@ from graphiti_core.embedder import EmbedderClient
|
||||||
from graphiti_core.errors import EdgeNotFoundError, GroupsEdgesNotFoundError
|
from graphiti_core.errors import EdgeNotFoundError, GroupsEdgesNotFoundError
|
||||||
from graphiti_core.helpers import parse_db_date
|
from graphiti_core.helpers import parse_db_date
|
||||||
from graphiti_core.models.edges.edge_db_queries import (
|
from graphiti_core.models.edges.edge_db_queries import (
|
||||||
COMMUNITY_EDGE_SAVE,
|
COMMUNITY_EDGE_RETURN,
|
||||||
ENTITY_EDGE_SAVE,
|
ENTITY_EDGE_RETURN,
|
||||||
|
EPISODIC_EDGE_RETURN,
|
||||||
EPISODIC_EDGE_SAVE,
|
EPISODIC_EDGE_SAVE,
|
||||||
|
get_community_edge_save_query,
|
||||||
|
get_entity_edge_save_query,
|
||||||
)
|
)
|
||||||
from graphiti_core.nodes import Node
|
from graphiti_core.nodes import Node
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
ENTITY_EDGE_RETURN: LiteralString = """
|
|
||||||
RETURN
|
|
||||||
e.uuid AS uuid,
|
|
||||||
startNode(e).uuid AS source_node_uuid,
|
|
||||||
endNode(e).uuid AS target_node_uuid,
|
|
||||||
e.created_at AS created_at,
|
|
||||||
e.name AS name,
|
|
||||||
e.group_id AS group_id,
|
|
||||||
e.fact AS fact,
|
|
||||||
e.episodes AS episodes,
|
|
||||||
e.expired_at AS expired_at,
|
|
||||||
e.valid_at AS valid_at,
|
|
||||||
e.invalid_at AS invalid_at,
|
|
||||||
properties(e) AS attributes"""
|
|
||||||
|
|
||||||
|
|
||||||
class Edge(BaseModel, ABC):
|
class Edge(BaseModel, ABC):
|
||||||
uuid: str = Field(default_factory=lambda: str(uuid4()))
|
uuid: str = Field(default_factory=lambda: str(uuid4()))
|
||||||
|
|
@ -66,9 +54,9 @@ class Edge(BaseModel, ABC):
|
||||||
async def delete(self, driver: GraphDriver):
|
async def delete(self, driver: GraphDriver):
|
||||||
result = await driver.execute_query(
|
result = await driver.execute_query(
|
||||||
"""
|
"""
|
||||||
MATCH (n)-[e:MENTIONS|RELATES_TO|HAS_MEMBER {uuid: $uuid}]->(m)
|
MATCH (n)-[e:MENTIONS|RELATES_TO|HAS_MEMBER {uuid: $uuid}]->(m)
|
||||||
DELETE e
|
DELETE e
|
||||||
""",
|
""",
|
||||||
uuid=self.uuid,
|
uuid=self.uuid,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -107,14 +95,10 @@ class EpisodicEdge(Edge):
|
||||||
async def get_by_uuid(cls, driver: GraphDriver, uuid: str):
|
async def get_by_uuid(cls, driver: GraphDriver, uuid: str):
|
||||||
records, _, _ = await driver.execute_query(
|
records, _, _ = await driver.execute_query(
|
||||||
"""
|
"""
|
||||||
MATCH (n:Episodic)-[e:MENTIONS {uuid: $uuid}]->(m:Entity)
|
MATCH (n:Episodic)-[e:MENTIONS {uuid: $uuid}]->(m:Entity)
|
||||||
RETURN
|
RETURN
|
||||||
e.uuid As uuid,
|
"""
|
||||||
e.group_id AS group_id,
|
+ EPISODIC_EDGE_RETURN,
|
||||||
n.uuid AS source_node_uuid,
|
|
||||||
m.uuid AS target_node_uuid,
|
|
||||||
e.created_at AS created_at
|
|
||||||
""",
|
|
||||||
uuid=uuid,
|
uuid=uuid,
|
||||||
routing_='r',
|
routing_='r',
|
||||||
)
|
)
|
||||||
|
|
@ -129,15 +113,11 @@ class EpisodicEdge(Edge):
|
||||||
async def get_by_uuids(cls, driver: GraphDriver, uuids: list[str]):
|
async def get_by_uuids(cls, driver: GraphDriver, uuids: list[str]):
|
||||||
records, _, _ = await driver.execute_query(
|
records, _, _ = await driver.execute_query(
|
||||||
"""
|
"""
|
||||||
MATCH (n:Episodic)-[e:MENTIONS]->(m:Entity)
|
MATCH (n:Episodic)-[e:MENTIONS]->(m:Entity)
|
||||||
WHERE e.uuid IN $uuids
|
WHERE e.uuid IN $uuids
|
||||||
RETURN
|
RETURN
|
||||||
e.uuid As uuid,
|
"""
|
||||||
e.group_id AS group_id,
|
+ EPISODIC_EDGE_RETURN,
|
||||||
n.uuid AS source_node_uuid,
|
|
||||||
m.uuid AS target_node_uuid,
|
|
||||||
e.created_at AS created_at
|
|
||||||
""",
|
|
||||||
uuids=uuids,
|
uuids=uuids,
|
||||||
routing_='r',
|
routing_='r',
|
||||||
)
|
)
|
||||||
|
|
@ -161,19 +141,17 @@ class EpisodicEdge(Edge):
|
||||||
|
|
||||||
records, _, _ = await driver.execute_query(
|
records, _, _ = await driver.execute_query(
|
||||||
"""
|
"""
|
||||||
MATCH (n:Episodic)-[e:MENTIONS]->(m:Entity)
|
MATCH (n:Episodic)-[e:MENTIONS]->(m:Entity)
|
||||||
WHERE e.group_id IN $group_ids
|
WHERE e.group_id IN $group_ids
|
||||||
"""
|
"""
|
||||||
+ cursor_query
|
+ cursor_query
|
||||||
+ """
|
+ """
|
||||||
RETURN
|
RETURN
|
||||||
e.uuid As uuid,
|
"""
|
||||||
e.group_id AS group_id,
|
+ EPISODIC_EDGE_RETURN
|
||||||
n.uuid AS source_node_uuid,
|
+ """
|
||||||
m.uuid AS target_node_uuid,
|
ORDER BY e.uuid DESC
|
||||||
e.created_at AS created_at
|
"""
|
||||||
ORDER BY e.uuid DESC
|
|
||||||
"""
|
|
||||||
+ limit_query,
|
+ limit_query,
|
||||||
group_ids=group_ids,
|
group_ids=group_ids,
|
||||||
uuid=uuid_cursor,
|
uuid=uuid_cursor,
|
||||||
|
|
@ -221,11 +199,14 @@ class EntityEdge(Edge):
|
||||||
return self.fact_embedding
|
return self.fact_embedding
|
||||||
|
|
||||||
async def load_fact_embedding(self, driver: GraphDriver):
|
async def load_fact_embedding(self, driver: GraphDriver):
|
||||||
query: LiteralString = """
|
records, _, _ = await driver.execute_query(
|
||||||
|
"""
|
||||||
MATCH (n:Entity)-[e:RELATES_TO {uuid: $uuid}]->(m:Entity)
|
MATCH (n:Entity)-[e:RELATES_TO {uuid: $uuid}]->(m:Entity)
|
||||||
RETURN e.fact_embedding AS fact_embedding
|
RETURN e.fact_embedding AS fact_embedding
|
||||||
"""
|
""",
|
||||||
records, _, _ = await driver.execute_query(query, uuid=self.uuid, routing_='r')
|
uuid=self.uuid,
|
||||||
|
routing_='r',
|
||||||
|
)
|
||||||
|
|
||||||
if len(records) == 0:
|
if len(records) == 0:
|
||||||
raise EdgeNotFoundError(self.uuid)
|
raise EdgeNotFoundError(self.uuid)
|
||||||
|
|
@ -251,7 +232,7 @@ class EntityEdge(Edge):
|
||||||
edge_data.update(self.attributes or {})
|
edge_data.update(self.attributes or {})
|
||||||
|
|
||||||
result = await driver.execute_query(
|
result = await driver.execute_query(
|
||||||
ENTITY_EDGE_SAVE,
|
get_entity_edge_save_query(driver.provider),
|
||||||
edge_data=edge_data,
|
edge_data=edge_data,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -263,8 +244,9 @@ class EntityEdge(Edge):
|
||||||
async def get_by_uuid(cls, driver: GraphDriver, uuid: str):
|
async def get_by_uuid(cls, driver: GraphDriver, uuid: str):
|
||||||
records, _, _ = await driver.execute_query(
|
records, _, _ = await driver.execute_query(
|
||||||
"""
|
"""
|
||||||
MATCH (n:Entity)-[e:RELATES_TO {uuid: $uuid}]->(m:Entity)
|
MATCH (n:Entity)-[e:RELATES_TO {uuid: $uuid}]->(m:Entity)
|
||||||
"""
|
RETURN
|
||||||
|
"""
|
||||||
+ ENTITY_EDGE_RETURN,
|
+ ENTITY_EDGE_RETURN,
|
||||||
uuid=uuid,
|
uuid=uuid,
|
||||||
routing_='r',
|
routing_='r',
|
||||||
|
|
@ -283,9 +265,10 @@ class EntityEdge(Edge):
|
||||||
|
|
||||||
records, _, _ = await driver.execute_query(
|
records, _, _ = await driver.execute_query(
|
||||||
"""
|
"""
|
||||||
MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
|
MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
|
||||||
WHERE e.uuid IN $uuids
|
WHERE e.uuid IN $uuids
|
||||||
"""
|
RETURN
|
||||||
|
"""
|
||||||
+ ENTITY_EDGE_RETURN,
|
+ ENTITY_EDGE_RETURN,
|
||||||
uuids=uuids,
|
uuids=uuids,
|
||||||
routing_='r',
|
routing_='r',
|
||||||
|
|
@ -314,22 +297,21 @@ class EntityEdge(Edge):
|
||||||
else ''
|
else ''
|
||||||
)
|
)
|
||||||
|
|
||||||
query: LiteralString = (
|
records, _, _ = await driver.execute_query(
|
||||||
"""
|
"""
|
||||||
MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
|
MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
|
||||||
WHERE e.group_id IN $group_ids
|
WHERE e.group_id IN $group_ids
|
||||||
"""
|
"""
|
||||||
+ cursor_query
|
+ cursor_query
|
||||||
|
+ """
|
||||||
|
RETURN
|
||||||
|
"""
|
||||||
+ ENTITY_EDGE_RETURN
|
+ ENTITY_EDGE_RETURN
|
||||||
+ with_embeddings_query
|
+ with_embeddings_query
|
||||||
+ """
|
+ """
|
||||||
ORDER BY e.uuid DESC
|
ORDER BY e.uuid DESC
|
||||||
"""
|
"""
|
||||||
+ limit_query
|
+ limit_query,
|
||||||
)
|
|
||||||
|
|
||||||
records, _, _ = await driver.execute_query(
|
|
||||||
query,
|
|
||||||
group_ids=group_ids,
|
group_ids=group_ids,
|
||||||
uuid=uuid_cursor,
|
uuid=uuid_cursor,
|
||||||
limit=limit,
|
limit=limit,
|
||||||
|
|
@ -344,13 +326,15 @@ class EntityEdge(Edge):
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def get_by_node_uuid(cls, driver: GraphDriver, node_uuid: str):
|
async def get_by_node_uuid(cls, driver: GraphDriver, node_uuid: str):
|
||||||
query: LiteralString = (
|
records, _, _ = await driver.execute_query(
|
||||||
"""
|
"""
|
||||||
MATCH (n:Entity {uuid: $node_uuid})-[e:RELATES_TO]-(m:Entity)
|
MATCH (n:Entity {uuid: $node_uuid})-[e:RELATES_TO]-(m:Entity)
|
||||||
"""
|
RETURN
|
||||||
+ ENTITY_EDGE_RETURN
|
"""
|
||||||
|
+ ENTITY_EDGE_RETURN,
|
||||||
|
node_uuid=node_uuid,
|
||||||
|
routing_='r',
|
||||||
)
|
)
|
||||||
records, _, _ = await driver.execute_query(query, node_uuid=node_uuid, routing_='r')
|
|
||||||
|
|
||||||
edges = [get_entity_edge_from_record(record) for record in records]
|
edges = [get_entity_edge_from_record(record) for record in records]
|
||||||
|
|
||||||
|
|
@ -360,7 +344,7 @@ class EntityEdge(Edge):
|
||||||
class CommunityEdge(Edge):
|
class CommunityEdge(Edge):
|
||||||
async def save(self, driver: GraphDriver):
|
async def save(self, driver: GraphDriver):
|
||||||
result = await driver.execute_query(
|
result = await driver.execute_query(
|
||||||
COMMUNITY_EDGE_SAVE,
|
get_community_edge_save_query(driver.provider),
|
||||||
community_uuid=self.source_node_uuid,
|
community_uuid=self.source_node_uuid,
|
||||||
entity_uuid=self.target_node_uuid,
|
entity_uuid=self.target_node_uuid,
|
||||||
uuid=self.uuid,
|
uuid=self.uuid,
|
||||||
|
|
@ -376,14 +360,10 @@ class CommunityEdge(Edge):
|
||||||
async def get_by_uuid(cls, driver: GraphDriver, uuid: str):
|
async def get_by_uuid(cls, driver: GraphDriver, uuid: str):
|
||||||
records, _, _ = await driver.execute_query(
|
records, _, _ = await driver.execute_query(
|
||||||
"""
|
"""
|
||||||
MATCH (n:Community)-[e:HAS_MEMBER {uuid: $uuid}]->(m:Entity | Community)
|
MATCH (n:Community)-[e:HAS_MEMBER {uuid: $uuid}]->(m)
|
||||||
RETURN
|
RETURN
|
||||||
e.uuid As uuid,
|
"""
|
||||||
e.group_id AS group_id,
|
+ COMMUNITY_EDGE_RETURN,
|
||||||
n.uuid AS source_node_uuid,
|
|
||||||
m.uuid AS target_node_uuid,
|
|
||||||
e.created_at AS created_at
|
|
||||||
""",
|
|
||||||
uuid=uuid,
|
uuid=uuid,
|
||||||
routing_='r',
|
routing_='r',
|
||||||
)
|
)
|
||||||
|
|
@ -396,15 +376,11 @@ class CommunityEdge(Edge):
|
||||||
async def get_by_uuids(cls, driver: GraphDriver, uuids: list[str]):
|
async def get_by_uuids(cls, driver: GraphDriver, uuids: list[str]):
|
||||||
records, _, _ = await driver.execute_query(
|
records, _, _ = await driver.execute_query(
|
||||||
"""
|
"""
|
||||||
MATCH (n:Community)-[e:HAS_MEMBER]->(m:Entity | Community)
|
MATCH (n:Community)-[e:HAS_MEMBER]->(m)
|
||||||
WHERE e.uuid IN $uuids
|
WHERE e.uuid IN $uuids
|
||||||
RETURN
|
RETURN
|
||||||
e.uuid As uuid,
|
"""
|
||||||
e.group_id AS group_id,
|
+ COMMUNITY_EDGE_RETURN,
|
||||||
n.uuid AS source_node_uuid,
|
|
||||||
m.uuid AS target_node_uuid,
|
|
||||||
e.created_at AS created_at
|
|
||||||
""",
|
|
||||||
uuids=uuids,
|
uuids=uuids,
|
||||||
routing_='r',
|
routing_='r',
|
||||||
)
|
)
|
||||||
|
|
@ -426,19 +402,17 @@ class CommunityEdge(Edge):
|
||||||
|
|
||||||
records, _, _ = await driver.execute_query(
|
records, _, _ = await driver.execute_query(
|
||||||
"""
|
"""
|
||||||
MATCH (n:Community)-[e:HAS_MEMBER]->(m:Entity | Community)
|
MATCH (n:Community)-[e:HAS_MEMBER]->(m)
|
||||||
WHERE e.group_id IN $group_ids
|
WHERE e.group_id IN $group_ids
|
||||||
"""
|
"""
|
||||||
+ cursor_query
|
+ cursor_query
|
||||||
+ """
|
+ """
|
||||||
RETURN
|
RETURN
|
||||||
e.uuid As uuid,
|
"""
|
||||||
e.group_id AS group_id,
|
+ COMMUNITY_EDGE_RETURN
|
||||||
n.uuid AS source_node_uuid,
|
+ """
|
||||||
m.uuid AS target_node_uuid,
|
ORDER BY e.uuid DESC
|
||||||
e.created_at AS created_at
|
"""
|
||||||
ORDER BY e.uuid DESC
|
|
||||||
"""
|
|
||||||
+ limit_query,
|
+ limit_query,
|
||||||
group_ids=group_ids,
|
group_ids=group_ids,
|
||||||
uuid=uuid_cursor,
|
uuid=uuid_cursor,
|
||||||
|
|
|
||||||
|
|
@ -5,16 +5,9 @@ This module provides database-agnostic query generation for Neo4j and FalkorDB,
|
||||||
supporting index creation, fulltext search, and bulk operations.
|
supporting index creation, fulltext search, and bulk operations.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from typing_extensions import LiteralString
|
from typing_extensions import LiteralString
|
||||||
|
|
||||||
from graphiti_core.models.edges.edge_db_queries import (
|
from graphiti_core.driver.driver import GraphProvider
|
||||||
ENTITY_EDGE_SAVE_BULK,
|
|
||||||
)
|
|
||||||
from graphiti_core.models.nodes.node_db_queries import (
|
|
||||||
ENTITY_NODE_SAVE_BULK,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Mapping from Neo4j fulltext index names to FalkorDB node labels
|
# Mapping from Neo4j fulltext index names to FalkorDB node labels
|
||||||
NEO4J_TO_FALKORDB_MAPPING = {
|
NEO4J_TO_FALKORDB_MAPPING = {
|
||||||
|
|
@ -25,8 +18,8 @@ NEO4J_TO_FALKORDB_MAPPING = {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def get_range_indices(db_type: str = 'neo4j') -> list[LiteralString]:
|
def get_range_indices(provider: GraphProvider) -> list[LiteralString]:
|
||||||
if db_type == 'falkordb':
|
if provider == GraphProvider.FALKORDB:
|
||||||
return [
|
return [
|
||||||
# Entity node
|
# Entity node
|
||||||
'CREATE INDEX FOR (n:Entity) ON (n.uuid, n.group_id, n.name, n.created_at)',
|
'CREATE INDEX FOR (n:Entity) ON (n.uuid, n.group_id, n.name, n.created_at)',
|
||||||
|
|
@ -41,109 +34,70 @@ def get_range_indices(db_type: str = 'neo4j') -> list[LiteralString]:
|
||||||
# HAS_MEMBER edge
|
# HAS_MEMBER edge
|
||||||
'CREATE INDEX FOR ()-[e:HAS_MEMBER]-() ON (e.uuid)',
|
'CREATE INDEX FOR ()-[e:HAS_MEMBER]-() ON (e.uuid)',
|
||||||
]
|
]
|
||||||
else:
|
|
||||||
return [
|
return [
|
||||||
'CREATE INDEX entity_uuid IF NOT EXISTS FOR (n:Entity) ON (n.uuid)',
|
'CREATE INDEX entity_uuid IF NOT EXISTS FOR (n:Entity) ON (n.uuid)',
|
||||||
'CREATE INDEX episode_uuid IF NOT EXISTS FOR (n:Episodic) ON (n.uuid)',
|
'CREATE INDEX episode_uuid IF NOT EXISTS FOR (n:Episodic) ON (n.uuid)',
|
||||||
'CREATE INDEX community_uuid IF NOT EXISTS FOR (n:Community) ON (n.uuid)',
|
'CREATE INDEX community_uuid IF NOT EXISTS FOR (n:Community) ON (n.uuid)',
|
||||||
'CREATE INDEX relation_uuid IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.uuid)',
|
'CREATE INDEX relation_uuid IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.uuid)',
|
||||||
'CREATE INDEX mention_uuid IF NOT EXISTS FOR ()-[e:MENTIONS]-() ON (e.uuid)',
|
'CREATE INDEX mention_uuid IF NOT EXISTS FOR ()-[e:MENTIONS]-() ON (e.uuid)',
|
||||||
'CREATE INDEX has_member_uuid IF NOT EXISTS FOR ()-[e:HAS_MEMBER]-() ON (e.uuid)',
|
'CREATE INDEX has_member_uuid IF NOT EXISTS FOR ()-[e:HAS_MEMBER]-() ON (e.uuid)',
|
||||||
'CREATE INDEX entity_group_id IF NOT EXISTS FOR (n:Entity) ON (n.group_id)',
|
'CREATE INDEX entity_group_id IF NOT EXISTS FOR (n:Entity) ON (n.group_id)',
|
||||||
'CREATE INDEX episode_group_id IF NOT EXISTS FOR (n:Episodic) ON (n.group_id)',
|
'CREATE INDEX episode_group_id IF NOT EXISTS FOR (n:Episodic) ON (n.group_id)',
|
||||||
'CREATE INDEX relation_group_id IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.group_id)',
|
'CREATE INDEX relation_group_id IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.group_id)',
|
||||||
'CREATE INDEX mention_group_id IF NOT EXISTS FOR ()-[e:MENTIONS]-() ON (e.group_id)',
|
'CREATE INDEX mention_group_id IF NOT EXISTS FOR ()-[e:MENTIONS]-() ON (e.group_id)',
|
||||||
'CREATE INDEX name_entity_index IF NOT EXISTS FOR (n:Entity) ON (n.name)',
|
'CREATE INDEX name_entity_index IF NOT EXISTS FOR (n:Entity) ON (n.name)',
|
||||||
'CREATE INDEX created_at_entity_index IF NOT EXISTS FOR (n:Entity) ON (n.created_at)',
|
'CREATE INDEX created_at_entity_index IF NOT EXISTS FOR (n:Entity) ON (n.created_at)',
|
||||||
'CREATE INDEX created_at_episodic_index IF NOT EXISTS FOR (n:Episodic) ON (n.created_at)',
|
'CREATE INDEX created_at_episodic_index IF NOT EXISTS FOR (n:Episodic) ON (n.created_at)',
|
||||||
'CREATE INDEX valid_at_episodic_index IF NOT EXISTS FOR (n:Episodic) ON (n.valid_at)',
|
'CREATE INDEX valid_at_episodic_index IF NOT EXISTS FOR (n:Episodic) ON (n.valid_at)',
|
||||||
'CREATE INDEX name_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.name)',
|
'CREATE INDEX name_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.name)',
|
||||||
'CREATE INDEX created_at_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.created_at)',
|
'CREATE INDEX created_at_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.created_at)',
|
||||||
'CREATE INDEX expired_at_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.expired_at)',
|
'CREATE INDEX expired_at_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.expired_at)',
|
||||||
'CREATE INDEX valid_at_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.valid_at)',
|
'CREATE INDEX valid_at_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.valid_at)',
|
||||||
'CREATE INDEX invalid_at_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.invalid_at)',
|
'CREATE INDEX invalid_at_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.invalid_at)',
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
def get_fulltext_indices(db_type: str = 'neo4j') -> list[LiteralString]:
|
def get_fulltext_indices(provider: GraphProvider) -> list[LiteralString]:
|
||||||
if db_type == 'falkordb':
|
if provider == GraphProvider.FALKORDB:
|
||||||
return [
|
return [
|
||||||
"""CREATE FULLTEXT INDEX FOR (e:Episodic) ON (e.content, e.source, e.source_description, e.group_id)""",
|
"""CREATE FULLTEXT INDEX FOR (e:Episodic) ON (e.content, e.source, e.source_description, e.group_id)""",
|
||||||
"""CREATE FULLTEXT INDEX FOR (n:Entity) ON (n.name, n.summary, n.group_id)""",
|
"""CREATE FULLTEXT INDEX FOR (n:Entity) ON (n.name, n.summary, n.group_id)""",
|
||||||
"""CREATE FULLTEXT INDEX FOR (n:Community) ON (n.name, n.group_id)""",
|
"""CREATE FULLTEXT INDEX FOR (n:Community) ON (n.name, n.group_id)""",
|
||||||
"""CREATE FULLTEXT INDEX FOR ()-[e:RELATES_TO]-() ON (e.name, e.fact, e.group_id)""",
|
"""CREATE FULLTEXT INDEX FOR ()-[e:RELATES_TO]-() ON (e.name, e.fact, e.group_id)""",
|
||||||
]
|
]
|
||||||
else:
|
|
||||||
return [
|
return [
|
||||||
"""CREATE FULLTEXT INDEX episode_content IF NOT EXISTS
|
"""CREATE FULLTEXT INDEX episode_content IF NOT EXISTS
|
||||||
FOR (e:Episodic) ON EACH [e.content, e.source, e.source_description, e.group_id]""",
|
FOR (e:Episodic) ON EACH [e.content, e.source, e.source_description, e.group_id]""",
|
||||||
"""CREATE FULLTEXT INDEX node_name_and_summary IF NOT EXISTS
|
"""CREATE FULLTEXT INDEX node_name_and_summary IF NOT EXISTS
|
||||||
FOR (n:Entity) ON EACH [n.name, n.summary, n.group_id]""",
|
FOR (n:Entity) ON EACH [n.name, n.summary, n.group_id]""",
|
||||||
"""CREATE FULLTEXT INDEX community_name IF NOT EXISTS
|
"""CREATE FULLTEXT INDEX community_name IF NOT EXISTS
|
||||||
FOR (n:Community) ON EACH [n.name, n.group_id]""",
|
FOR (n:Community) ON EACH [n.name, n.group_id]""",
|
||||||
"""CREATE FULLTEXT INDEX edge_name_and_fact IF NOT EXISTS
|
"""CREATE FULLTEXT INDEX edge_name_and_fact IF NOT EXISTS
|
||||||
FOR ()-[e:RELATES_TO]-() ON EACH [e.name, e.fact, e.group_id]""",
|
FOR ()-[e:RELATES_TO]-() ON EACH [e.name, e.fact, e.group_id]""",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
def get_nodes_query(db_type: str = 'neo4j', name: str = '', query: str | None = None) -> str:
|
def get_nodes_query(provider: GraphProvider, name: str = '', query: str | None = None) -> str:
|
||||||
if db_type == 'falkordb':
|
if provider == GraphProvider.FALKORDB:
|
||||||
label = NEO4J_TO_FALKORDB_MAPPING[name]
|
label = NEO4J_TO_FALKORDB_MAPPING[name]
|
||||||
return f"CALL db.idx.fulltext.queryNodes('{label}', {query})"
|
return f"CALL db.idx.fulltext.queryNodes('{label}', {query})"
|
||||||
else:
|
|
||||||
return f'CALL db.index.fulltext.queryNodes("{name}", {query}, {{limit: $limit}})'
|
return f'CALL db.index.fulltext.queryNodes("{name}", {query}, {{limit: $limit}})'
|
||||||
|
|
||||||
|
|
||||||
def get_vector_cosine_func_query(vec1, vec2, db_type: str = 'neo4j') -> str:
|
def get_vector_cosine_func_query(vec1, vec2, provider: GraphProvider) -> str:
|
||||||
if db_type == 'falkordb':
|
if provider == GraphProvider.FALKORDB:
|
||||||
# FalkorDB uses a different syntax for regular cosine similarity and Neo4j uses normalized cosine similarity
|
# FalkorDB uses a different syntax for regular cosine similarity and Neo4j uses normalized cosine similarity
|
||||||
return f'(2 - vec.cosineDistance({vec1}, vecf32({vec2})))/2'
|
return f'(2 - vec.cosineDistance({vec1}, vecf32({vec2})))/2'
|
||||||
else:
|
|
||||||
return f'vector.similarity.cosine({vec1}, {vec2})'
|
return f'vector.similarity.cosine({vec1}, {vec2})'
|
||||||
|
|
||||||
|
|
||||||
def get_relationships_query(name: str, db_type: str = 'neo4j') -> str:
|
def get_relationships_query(name: str, provider: GraphProvider) -> str:
|
||||||
if db_type == 'falkordb':
|
if provider == GraphProvider.FALKORDB:
|
||||||
label = NEO4J_TO_FALKORDB_MAPPING[name]
|
label = NEO4J_TO_FALKORDB_MAPPING[name]
|
||||||
return f"CALL db.idx.fulltext.queryRelationships('{label}', $query)"
|
return f"CALL db.idx.fulltext.queryRelationships('{label}', $query)"
|
||||||
else:
|
|
||||||
return f'CALL db.index.fulltext.queryRelationships("{name}", $query, {{limit: $limit}})'
|
|
||||||
|
|
||||||
|
return f'CALL db.index.fulltext.queryRelationships("{name}", $query, {{limit: $limit}})'
|
||||||
def get_entity_node_save_bulk_query(nodes, db_type: str = 'neo4j') -> str | Any:
|
|
||||||
if db_type == 'falkordb':
|
|
||||||
queries = []
|
|
||||||
for node in nodes:
|
|
||||||
for label in node['labels']:
|
|
||||||
queries.append(
|
|
||||||
(
|
|
||||||
f"""
|
|
||||||
UNWIND $nodes AS node
|
|
||||||
MERGE (n:Entity {{uuid: node.uuid}})
|
|
||||||
SET n:{label}
|
|
||||||
SET n = node
|
|
||||||
WITH n, node
|
|
||||||
SET n.name_embedding = vecf32(node.name_embedding)
|
|
||||||
RETURN n.uuid AS uuid
|
|
||||||
""",
|
|
||||||
{'nodes': [node]},
|
|
||||||
)
|
|
||||||
)
|
|
||||||
return queries
|
|
||||||
else:
|
|
||||||
return ENTITY_NODE_SAVE_BULK
|
|
||||||
|
|
||||||
|
|
||||||
def get_entity_edge_save_bulk_query(db_type: str = 'neo4j') -> str:
|
|
||||||
if db_type == 'falkordb':
|
|
||||||
return """
|
|
||||||
UNWIND $entity_edges AS edge
|
|
||||||
MATCH (source:Entity {uuid: edge.source_node_uuid})
|
|
||||||
MATCH (target:Entity {uuid: edge.target_node_uuid})
|
|
||||||
MERGE (source)-[r:RELATES_TO {uuid: edge.uuid}]->(target)
|
|
||||||
SET r = {uuid: edge.uuid, name: edge.name, group_id: edge.group_id, fact: edge.fact, episodes: edge.episodes,
|
|
||||||
created_at: edge.created_at, expired_at: edge.expired_at, valid_at: edge.valid_at, invalid_at: edge.invalid_at, fact_embedding: vecf32(edge.fact_embedding)}
|
|
||||||
WITH r, edge
|
|
||||||
RETURN edge.uuid AS uuid"""
|
|
||||||
else:
|
|
||||||
return ENTITY_EDGE_SAVE_BULK
|
|
||||||
|
|
|
||||||
|
|
@ -26,7 +26,7 @@ from graphiti_core.cross_encoder.client import CrossEncoderClient
|
||||||
from graphiti_core.cross_encoder.openai_reranker_client import OpenAIRerankerClient
|
from graphiti_core.cross_encoder.openai_reranker_client import OpenAIRerankerClient
|
||||||
from graphiti_core.driver.driver import GraphDriver
|
from graphiti_core.driver.driver import GraphDriver
|
||||||
from graphiti_core.driver.neo4j_driver import Neo4jDriver
|
from graphiti_core.driver.neo4j_driver import Neo4jDriver
|
||||||
from graphiti_core.edges import EntityEdge, EpisodicEdge
|
from graphiti_core.edges import CommunityEdge, EntityEdge, EpisodicEdge
|
||||||
from graphiti_core.embedder import EmbedderClient, OpenAIEmbedder
|
from graphiti_core.embedder import EmbedderClient, OpenAIEmbedder
|
||||||
from graphiti_core.graphiti_types import GraphitiClients
|
from graphiti_core.graphiti_types import GraphitiClients
|
||||||
from graphiti_core.helpers import (
|
from graphiti_core.helpers import (
|
||||||
|
|
@ -93,8 +93,11 @@ load_dotenv()
|
||||||
|
|
||||||
class AddEpisodeResults(BaseModel):
|
class AddEpisodeResults(BaseModel):
|
||||||
episode: EpisodicNode
|
episode: EpisodicNode
|
||||||
|
episodic_edges: list[EpisodicEdge]
|
||||||
nodes: list[EntityNode]
|
nodes: list[EntityNode]
|
||||||
edges: list[EntityEdge]
|
edges: list[EntityEdge]
|
||||||
|
communities: list[CommunityNode]
|
||||||
|
community_edges: list[CommunityEdge]
|
||||||
|
|
||||||
|
|
||||||
class Graphiti:
|
class Graphiti:
|
||||||
|
|
@ -520,9 +523,12 @@ class Graphiti:
|
||||||
self.driver, [episode], episodic_edges, hydrated_nodes, entity_edges, self.embedder
|
self.driver, [episode], episodic_edges, hydrated_nodes, entity_edges, self.embedder
|
||||||
)
|
)
|
||||||
|
|
||||||
|
communities = []
|
||||||
|
community_edges = []
|
||||||
|
|
||||||
# Update any communities
|
# Update any communities
|
||||||
if update_communities:
|
if update_communities:
|
||||||
await semaphore_gather(
|
communities, community_edges = await semaphore_gather(
|
||||||
*[
|
*[
|
||||||
update_community(self.driver, self.llm_client, self.embedder, node)
|
update_community(self.driver, self.llm_client, self.embedder, node)
|
||||||
for node in nodes
|
for node in nodes
|
||||||
|
|
@ -532,7 +538,14 @@ class Graphiti:
|
||||||
end = time()
|
end = time()
|
||||||
logger.info(f'Completed add_episode in {(end - start) * 1000} ms')
|
logger.info(f'Completed add_episode in {(end - start) * 1000} ms')
|
||||||
|
|
||||||
return AddEpisodeResults(episode=episode, nodes=nodes, edges=entity_edges)
|
return AddEpisodeResults(
|
||||||
|
episode=episode,
|
||||||
|
episodic_edges=episodic_edges,
|
||||||
|
nodes=hydrated_nodes,
|
||||||
|
edges=entity_edges,
|
||||||
|
communities=communities,
|
||||||
|
community_edges=community_edges,
|
||||||
|
)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise e
|
raise e
|
||||||
|
|
@ -817,7 +830,9 @@ class Graphiti:
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
async def build_communities(self, group_ids: list[str] | None = None) -> list[CommunityNode]:
|
async def build_communities(
|
||||||
|
self, group_ids: list[str] | None = None
|
||||||
|
) -> tuple[list[CommunityNode], list[CommunityEdge]]:
|
||||||
"""
|
"""
|
||||||
Use a community clustering algorithm to find communities of nodes. Create community nodes summarising
|
Use a community clustering algorithm to find communities of nodes. Create community nodes summarising
|
||||||
the content of these communities.
|
the content of these communities.
|
||||||
|
|
@ -846,7 +861,7 @@ class Graphiti:
|
||||||
max_coroutines=self.max_coroutines,
|
max_coroutines=self.max_coroutines,
|
||||||
)
|
)
|
||||||
|
|
||||||
return community_nodes
|
return community_nodes, community_edges
|
||||||
|
|
||||||
async def search(
|
async def search(
|
||||||
self,
|
self,
|
||||||
|
|
|
||||||
|
|
@ -28,6 +28,7 @@ from numpy._typing import NDArray
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from typing_extensions import LiteralString
|
from typing_extensions import LiteralString
|
||||||
|
|
||||||
|
from graphiti_core.driver.driver import GraphProvider
|
||||||
from graphiti_core.errors import GroupIdValidationError
|
from graphiti_core.errors import GroupIdValidationError
|
||||||
|
|
||||||
load_dotenv()
|
load_dotenv()
|
||||||
|
|
@ -52,12 +53,12 @@ def parse_db_date(neo_date: neo4j_time.DateTime | str | None) -> datetime | None
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def get_default_group_id(db_type: str) -> str:
|
def get_default_group_id(provider: GraphProvider) -> str:
|
||||||
"""
|
"""
|
||||||
This function differentiates the default group id based on the database type.
|
This function differentiates the default group id based on the database type.
|
||||||
For most databases, the default group id is an empty string, while there are database types that require a specific default group id.
|
For most databases, the default group id is an empty string, while there are database types that require a specific default group id.
|
||||||
"""
|
"""
|
||||||
if db_type == 'falkordb':
|
if provider == GraphProvider.FALKORDB:
|
||||||
return '_'
|
return '_'
|
||||||
else:
|
else:
|
||||||
return ''
|
return ''
|
||||||
|
|
|
||||||
|
|
@ -14,43 +14,117 @@ See the License for the specific language governing permissions and
|
||||||
limitations under the License.
|
limitations under the License.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
from graphiti_core.driver.driver import GraphProvider
|
||||||
|
|
||||||
EPISODIC_EDGE_SAVE = """
|
EPISODIC_EDGE_SAVE = """
|
||||||
MATCH (episode:Episodic {uuid: $episode_uuid})
|
MATCH (episode:Episodic {uuid: $episode_uuid})
|
||||||
MATCH (node:Entity {uuid: $entity_uuid})
|
MATCH (node:Entity {uuid: $entity_uuid})
|
||||||
MERGE (episode)-[r:MENTIONS {uuid: $uuid}]->(node)
|
MERGE (episode)-[e:MENTIONS {uuid: $uuid}]->(node)
|
||||||
SET r = {uuid: $uuid, group_id: $group_id, created_at: $created_at}
|
SET e = {uuid: $uuid, group_id: $group_id, created_at: $created_at}
|
||||||
RETURN r.uuid AS uuid"""
|
RETURN e.uuid AS uuid
|
||||||
|
"""
|
||||||
|
|
||||||
EPISODIC_EDGE_SAVE_BULK = """
|
EPISODIC_EDGE_SAVE_BULK = """
|
||||||
UNWIND $episodic_edges AS edge
|
UNWIND $episodic_edges AS edge
|
||||||
MATCH (episode:Episodic {uuid: edge.source_node_uuid})
|
MATCH (episode:Episodic {uuid: edge.source_node_uuid})
|
||||||
MATCH (node:Entity {uuid: edge.target_node_uuid})
|
MATCH (node:Entity {uuid: edge.target_node_uuid})
|
||||||
MERGE (episode)-[r:MENTIONS {uuid: edge.uuid}]->(node)
|
MERGE (episode)-[e:MENTIONS {uuid: edge.uuid}]->(node)
|
||||||
SET r = {uuid: edge.uuid, group_id: edge.group_id, created_at: edge.created_at}
|
SET e = {uuid: edge.uuid, group_id: edge.group_id, created_at: edge.created_at}
|
||||||
RETURN r.uuid AS uuid
|
RETURN e.uuid AS uuid
|
||||||
"""
|
"""
|
||||||
|
|
||||||
ENTITY_EDGE_SAVE = """
|
EPISODIC_EDGE_RETURN = """
|
||||||
MATCH (source:Entity {uuid: $edge_data.source_uuid})
|
e.uuid AS uuid,
|
||||||
MATCH (target:Entity {uuid: $edge_data.target_uuid})
|
e.group_id AS group_id,
|
||||||
MERGE (source)-[r:RELATES_TO {uuid: $edge_data.uuid}]->(target)
|
n.uuid AS source_node_uuid,
|
||||||
SET r = $edge_data
|
m.uuid AS target_node_uuid,
|
||||||
WITH r CALL db.create.setRelationshipVectorProperty(r, "fact_embedding", $edge_data.fact_embedding)
|
e.created_at AS created_at
|
||||||
RETURN r.uuid AS uuid"""
|
|
||||||
|
|
||||||
ENTITY_EDGE_SAVE_BULK = """
|
|
||||||
UNWIND $entity_edges AS edge
|
|
||||||
MATCH (source:Entity {uuid: edge.source_node_uuid})
|
|
||||||
MATCH (target:Entity {uuid: edge.target_node_uuid})
|
|
||||||
MERGE (source)-[r:RELATES_TO {uuid: edge.uuid}]->(target)
|
|
||||||
SET r = edge
|
|
||||||
WITH r, edge CALL db.create.setRelationshipVectorProperty(r, "fact_embedding", edge.fact_embedding)
|
|
||||||
RETURN edge.uuid AS uuid
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
COMMUNITY_EDGE_SAVE = """
|
|
||||||
MATCH (community:Community {uuid: $community_uuid})
|
def get_entity_edge_save_query(provider: GraphProvider) -> str:
|
||||||
MATCH (node:Entity | Community {uuid: $entity_uuid})
|
if provider == GraphProvider.FALKORDB:
|
||||||
MERGE (community)-[r:HAS_MEMBER {uuid: $uuid}]->(node)
|
return """
|
||||||
SET r = {uuid: $uuid, group_id: $group_id, created_at: $created_at}
|
MATCH (source:Entity {uuid: $edge_data.source_uuid})
|
||||||
RETURN r.uuid AS uuid"""
|
MATCH (target:Entity {uuid: $edge_data.target_uuid})
|
||||||
|
MERGE (source)-[e:RELATES_TO {uuid: $edge_data.uuid}]->(target)
|
||||||
|
SET e = $edge_data
|
||||||
|
RETURN e.uuid AS uuid
|
||||||
|
"""
|
||||||
|
|
||||||
|
return """
|
||||||
|
MATCH (source:Entity {uuid: $edge_data.source_uuid})
|
||||||
|
MATCH (target:Entity {uuid: $edge_data.target_uuid})
|
||||||
|
MERGE (source)-[e:RELATES_TO {uuid: $edge_data.uuid}]->(target)
|
||||||
|
SET e = $edge_data
|
||||||
|
WITH e CALL db.create.setRelationshipVectorProperty(e, "fact_embedding", $edge_data.fact_embedding)
|
||||||
|
RETURN e.uuid AS uuid
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
def get_entity_edge_save_bulk_query(provider: GraphProvider) -> str:
|
||||||
|
if provider == GraphProvider.FALKORDB:
|
||||||
|
return """
|
||||||
|
UNWIND $entity_edges AS edge
|
||||||
|
MATCH (source:Entity {uuid: edge.source_node_uuid})
|
||||||
|
MATCH (target:Entity {uuid: edge.target_node_uuid})
|
||||||
|
MERGE (source)-[r:RELATES_TO {uuid: edge.uuid}]->(target)
|
||||||
|
SET r = {uuid: edge.uuid, name: edge.name, group_id: edge.group_id, fact: edge.fact, episodes: edge.episodes,
|
||||||
|
created_at: edge.created_at, expired_at: edge.expired_at, valid_at: edge.valid_at, invalid_at: edge.invalid_at, fact_embedding: vecf32(edge.fact_embedding)}
|
||||||
|
WITH r, edge
|
||||||
|
RETURN edge.uuid AS uuid
|
||||||
|
"""
|
||||||
|
|
||||||
|
return """
|
||||||
|
UNWIND $entity_edges AS edge
|
||||||
|
MATCH (source:Entity {uuid: edge.source_node_uuid})
|
||||||
|
MATCH (target:Entity {uuid: edge.target_node_uuid})
|
||||||
|
MERGE (source)-[e:RELATES_TO {uuid: edge.uuid}]->(target)
|
||||||
|
SET e = edge
|
||||||
|
WITH e, edge CALL db.create.setRelationshipVectorProperty(e, "fact_embedding", edge.fact_embedding)
|
||||||
|
RETURN edge.uuid AS uuid
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
ENTITY_EDGE_RETURN = """
|
||||||
|
e.uuid AS uuid,
|
||||||
|
n.uuid AS source_node_uuid,
|
||||||
|
m.uuid AS target_node_uuid,
|
||||||
|
e.group_id AS group_id,
|
||||||
|
e.name AS name,
|
||||||
|
e.fact AS fact,
|
||||||
|
e.episodes AS episodes,
|
||||||
|
e.created_at AS created_at,
|
||||||
|
e.expired_at AS expired_at,
|
||||||
|
e.valid_at AS valid_at,
|
||||||
|
e.invalid_at AS invalid_at,
|
||||||
|
properties(e) AS attributes
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
def get_community_edge_save_query(provider: GraphProvider) -> str:
|
||||||
|
if provider == GraphProvider.FALKORDB:
|
||||||
|
return """
|
||||||
|
MATCH (community:Community {uuid: $community_uuid})
|
||||||
|
MATCH (node {uuid: $entity_uuid})
|
||||||
|
MERGE (community)-[e:HAS_MEMBER {uuid: $uuid}]->(node)
|
||||||
|
SET e = {uuid: $uuid, group_id: $group_id, created_at: $created_at}
|
||||||
|
RETURN e.uuid AS uuid
|
||||||
|
"""
|
||||||
|
|
||||||
|
return """
|
||||||
|
MATCH (community:Community {uuid: $community_uuid})
|
||||||
|
MATCH (node:Entity | Community {uuid: $entity_uuid})
|
||||||
|
MERGE (community)-[e:HAS_MEMBER {uuid: $uuid}]->(node)
|
||||||
|
SET e = {uuid: $uuid, group_id: $group_id, created_at: $created_at}
|
||||||
|
RETURN e.uuid AS uuid
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
COMMUNITY_EDGE_RETURN = """
|
||||||
|
e.uuid AS uuid,
|
||||||
|
e.group_id AS group_id,
|
||||||
|
n.uuid AS source_node_uuid,
|
||||||
|
m.uuid AS target_node_uuid,
|
||||||
|
e.created_at AS created_at
|
||||||
|
"""
|
||||||
|
|
|
||||||
|
|
@ -14,39 +14,120 @@ See the License for the specific language governing permissions and
|
||||||
limitations under the License.
|
limitations under the License.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from graphiti_core.driver.driver import GraphProvider
|
||||||
|
|
||||||
EPISODIC_NODE_SAVE = """
|
EPISODIC_NODE_SAVE = """
|
||||||
MERGE (n:Episodic {uuid: $uuid})
|
MERGE (n:Episodic {uuid: $uuid})
|
||||||
SET n = {uuid: $uuid, name: $name, group_id: $group_id, source_description: $source_description, source: $source, content: $content,
|
SET n = {uuid: $uuid, name: $name, group_id: $group_id, source_description: $source_description, source: $source, content: $content,
|
||||||
entity_edges: $entity_edges, created_at: $created_at, valid_at: $valid_at}
|
entity_edges: $entity_edges, created_at: $created_at, valid_at: $valid_at}
|
||||||
RETURN n.uuid AS uuid"""
|
RETURN n.uuid AS uuid
|
||||||
|
"""
|
||||||
|
|
||||||
EPISODIC_NODE_SAVE_BULK = """
|
EPISODIC_NODE_SAVE_BULK = """
|
||||||
UNWIND $episodes AS episode
|
UNWIND $episodes AS episode
|
||||||
MERGE (n:Episodic {uuid: episode.uuid})
|
MERGE (n:Episodic {uuid: episode.uuid})
|
||||||
SET n = {uuid: episode.uuid, name: episode.name, group_id: episode.group_id, source_description: episode.source_description,
|
SET n = {uuid: episode.uuid, name: episode.name, group_id: episode.group_id, source_description: episode.source_description,
|
||||||
source: episode.source, content: episode.content,
|
source: episode.source, content: episode.content,
|
||||||
entity_edges: episode.entity_edges, created_at: episode.created_at, valid_at: episode.valid_at}
|
entity_edges: episode.entity_edges, created_at: episode.created_at, valid_at: episode.valid_at}
|
||||||
RETURN n.uuid AS uuid
|
RETURN n.uuid AS uuid
|
||||||
"""
|
"""
|
||||||
|
|
||||||
ENTITY_NODE_SAVE = """
|
EPISODIC_NODE_RETURN = """
|
||||||
MERGE (n:Entity {uuid: $entity_data.uuid})
|
e.content AS content,
|
||||||
SET n:$($labels)
|
e.created_at AS created_at,
|
||||||
SET n = $entity_data
|
e.valid_at AS valid_at,
|
||||||
WITH n CALL db.create.setNodeVectorProperty(n, "name_embedding", $entity_data.name_embedding)
|
e.uuid AS uuid,
|
||||||
RETURN n.uuid AS uuid"""
|
e.name AS name,
|
||||||
|
e.group_id AS group_id,
|
||||||
ENTITY_NODE_SAVE_BULK = """
|
e.source_description AS source_description,
|
||||||
UNWIND $nodes AS node
|
e.source AS source,
|
||||||
MERGE (n:Entity {uuid: node.uuid})
|
e.entity_edges AS entity_edges
|
||||||
SET n:$(node.labels)
|
|
||||||
SET n = node
|
|
||||||
WITH n, node CALL db.create.setNodeVectorProperty(n, "name_embedding", node.name_embedding)
|
|
||||||
RETURN n.uuid AS uuid
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
COMMUNITY_NODE_SAVE = """
|
|
||||||
|
def get_entity_node_save_query(provider: GraphProvider, labels: str) -> str:
|
||||||
|
if provider == GraphProvider.FALKORDB:
|
||||||
|
return f"""
|
||||||
|
MERGE (n:Entity {{uuid: $entity_data.uuid}})
|
||||||
|
SET n:{labels}
|
||||||
|
SET n = $entity_data
|
||||||
|
RETURN n.uuid AS uuid
|
||||||
|
"""
|
||||||
|
|
||||||
|
return f"""
|
||||||
|
MERGE (n:Entity {{uuid: $entity_data.uuid}})
|
||||||
|
SET n:{labels}
|
||||||
|
SET n = $entity_data
|
||||||
|
WITH n CALL db.create.setNodeVectorProperty(n, "name_embedding", $entity_data.name_embedding)
|
||||||
|
RETURN n.uuid AS uuid
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
def get_entity_node_save_bulk_query(provider: GraphProvider, nodes: list[dict]) -> str | Any:
|
||||||
|
if provider == GraphProvider.FALKORDB:
|
||||||
|
queries = []
|
||||||
|
for node in nodes:
|
||||||
|
for label in node['labels']:
|
||||||
|
queries.append(
|
||||||
|
(
|
||||||
|
f"""
|
||||||
|
UNWIND $nodes AS node
|
||||||
|
MERGE (n:Entity {{uuid: node.uuid}})
|
||||||
|
SET n:{label}
|
||||||
|
SET n = node
|
||||||
|
WITH n, node
|
||||||
|
SET n.name_embedding = vecf32(node.name_embedding)
|
||||||
|
RETURN n.uuid AS uuid
|
||||||
|
""",
|
||||||
|
{'nodes': [node]},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return queries
|
||||||
|
|
||||||
|
return """
|
||||||
|
UNWIND $nodes AS node
|
||||||
|
MERGE (n:Entity {uuid: node.uuid})
|
||||||
|
SET n:$(node.labels)
|
||||||
|
SET n = node
|
||||||
|
WITH n, node CALL db.create.setNodeVectorProperty(n, "name_embedding", node.name_embedding)
|
||||||
|
RETURN n.uuid AS uuid
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
ENTITY_NODE_RETURN = """
|
||||||
|
n.uuid AS uuid,
|
||||||
|
n.name AS name,
|
||||||
|
n.group_id AS group_id,
|
||||||
|
n.created_at AS created_at,
|
||||||
|
n.summary AS summary,
|
||||||
|
labels(n) AS labels,
|
||||||
|
properties(n) AS attributes
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
def get_community_node_save_query(provider: GraphProvider) -> str:
|
||||||
|
if provider == GraphProvider.FALKORDB:
|
||||||
|
return """
|
||||||
|
MERGE (n:Community {uuid: $uuid})
|
||||||
|
SET n = {uuid: $uuid, name: $name, group_id: $group_id, summary: $summary, created_at: $created_at, name_embedding: $name_embedding}
|
||||||
|
RETURN n.uuid AS uuid
|
||||||
|
"""
|
||||||
|
|
||||||
|
return """
|
||||||
MERGE (n:Community {uuid: $uuid})
|
MERGE (n:Community {uuid: $uuid})
|
||||||
SET n = {uuid: $uuid, name: $name, group_id: $group_id, summary: $summary, created_at: $created_at}
|
SET n = {uuid: $uuid, name: $name, group_id: $group_id, summary: $summary, created_at: $created_at}
|
||||||
WITH n CALL db.create.setNodeVectorProperty(n, "name_embedding", $name_embedding)
|
WITH n CALL db.create.setNodeVectorProperty(n, "name_embedding", $name_embedding)
|
||||||
RETURN n.uuid AS uuid"""
|
RETURN n.uuid AS uuid
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
COMMUNITY_NODE_RETURN = """
|
||||||
|
n.uuid AS uuid,
|
||||||
|
n.name AS name,
|
||||||
|
n.name_embedding AS name_embedding,
|
||||||
|
n.group_id AS group_id,
|
||||||
|
n.summary AS summary,
|
||||||
|
n.created_at AS created_at
|
||||||
|
"""
|
||||||
|
|
|
||||||
|
|
@ -25,29 +25,22 @@ from uuid import uuid4
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
from typing_extensions import LiteralString
|
from typing_extensions import LiteralString
|
||||||
|
|
||||||
from graphiti_core.driver.driver import GraphDriver
|
from graphiti_core.driver.driver import GraphDriver, GraphProvider
|
||||||
from graphiti_core.embedder import EmbedderClient
|
from graphiti_core.embedder import EmbedderClient
|
||||||
from graphiti_core.errors import NodeNotFoundError
|
from graphiti_core.errors import NodeNotFoundError
|
||||||
from graphiti_core.helpers import parse_db_date
|
from graphiti_core.helpers import parse_db_date
|
||||||
from graphiti_core.models.nodes.node_db_queries import (
|
from graphiti_core.models.nodes.node_db_queries import (
|
||||||
COMMUNITY_NODE_SAVE,
|
COMMUNITY_NODE_RETURN,
|
||||||
ENTITY_NODE_SAVE,
|
ENTITY_NODE_RETURN,
|
||||||
|
EPISODIC_NODE_RETURN,
|
||||||
EPISODIC_NODE_SAVE,
|
EPISODIC_NODE_SAVE,
|
||||||
|
get_community_node_save_query,
|
||||||
|
get_entity_node_save_query,
|
||||||
)
|
)
|
||||||
from graphiti_core.utils.datetime_utils import utc_now
|
from graphiti_core.utils.datetime_utils import utc_now
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
ENTITY_NODE_RETURN: LiteralString = """
|
|
||||||
RETURN
|
|
||||||
n.uuid As uuid,
|
|
||||||
n.name AS name,
|
|
||||||
n.group_id AS group_id,
|
|
||||||
n.created_at AS created_at,
|
|
||||||
n.summary AS summary,
|
|
||||||
labels(n) AS labels,
|
|
||||||
properties(n) AS attributes"""
|
|
||||||
|
|
||||||
|
|
||||||
class EpisodeType(Enum):
|
class EpisodeType(Enum):
|
||||||
"""
|
"""
|
||||||
|
|
@ -96,18 +89,26 @@ class Node(BaseModel, ABC):
|
||||||
async def save(self, driver: GraphDriver): ...
|
async def save(self, driver: GraphDriver): ...
|
||||||
|
|
||||||
async def delete(self, driver: GraphDriver):
|
async def delete(self, driver: GraphDriver):
|
||||||
result = await driver.execute_query(
|
if driver.provider == GraphProvider.FALKORDB:
|
||||||
"""
|
for label in ['Entity', 'Episodic', 'Community']:
|
||||||
MATCH (n:Entity|Episodic|Community {uuid: $uuid})
|
await driver.execute_query(
|
||||||
DETACH DELETE n
|
f"""
|
||||||
""",
|
MATCH (n:{label} {{uuid: $uuid}})
|
||||||
uuid=self.uuid,
|
DETACH DELETE n
|
||||||
)
|
""",
|
||||||
|
uuid=self.uuid,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
await driver.execute_query(
|
||||||
|
"""
|
||||||
|
MATCH (n:Entity|Episodic|Community {uuid: $uuid})
|
||||||
|
DETACH DELETE n
|
||||||
|
""",
|
||||||
|
uuid=self.uuid,
|
||||||
|
)
|
||||||
|
|
||||||
logger.debug(f'Deleted Node: {self.uuid}')
|
logger.debug(f'Deleted Node: {self.uuid}')
|
||||||
|
|
||||||
return result
|
|
||||||
|
|
||||||
def __hash__(self):
|
def __hash__(self):
|
||||||
return hash(self.uuid)
|
return hash(self.uuid)
|
||||||
|
|
||||||
|
|
@ -118,15 +119,23 @@ class Node(BaseModel, ABC):
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def delete_by_group_id(cls, driver: GraphDriver, group_id: str):
|
async def delete_by_group_id(cls, driver: GraphDriver, group_id: str):
|
||||||
await driver.execute_query(
|
if driver.provider == GraphProvider.FALKORDB:
|
||||||
"""
|
for label in ['Entity', 'Episodic', 'Community']:
|
||||||
MATCH (n:Entity|Episodic|Community {group_id: $group_id})
|
await driver.execute_query(
|
||||||
DETACH DELETE n
|
f"""
|
||||||
""",
|
MATCH (n:{label} {{group_id: $group_id}})
|
||||||
group_id=group_id,
|
DETACH DELETE n
|
||||||
)
|
""",
|
||||||
|
group_id=group_id,
|
||||||
return 'SUCCESS'
|
)
|
||||||
|
else:
|
||||||
|
await driver.execute_query(
|
||||||
|
"""
|
||||||
|
MATCH (n:Entity|Episodic|Community {group_id: $group_id})
|
||||||
|
DETACH DELETE n
|
||||||
|
""",
|
||||||
|
group_id=group_id,
|
||||||
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def get_by_uuid(cls, driver: GraphDriver, uuid: str): ...
|
async def get_by_uuid(cls, driver: GraphDriver, uuid: str): ...
|
||||||
|
|
@ -169,17 +178,10 @@ class EpisodicNode(Node):
|
||||||
async def get_by_uuid(cls, driver: GraphDriver, uuid: str):
|
async def get_by_uuid(cls, driver: GraphDriver, uuid: str):
|
||||||
records, _, _ = await driver.execute_query(
|
records, _, _ = await driver.execute_query(
|
||||||
"""
|
"""
|
||||||
MATCH (e:Episodic {uuid: $uuid})
|
MATCH (e:Episodic {uuid: $uuid})
|
||||||
RETURN e.content AS content,
|
RETURN
|
||||||
e.created_at AS created_at,
|
"""
|
||||||
e.valid_at AS valid_at,
|
+ EPISODIC_NODE_RETURN,
|
||||||
e.uuid AS uuid,
|
|
||||||
e.name AS name,
|
|
||||||
e.group_id AS group_id,
|
|
||||||
e.source_description AS source_description,
|
|
||||||
e.source AS source,
|
|
||||||
e.entity_edges AS entity_edges
|
|
||||||
""",
|
|
||||||
uuid=uuid,
|
uuid=uuid,
|
||||||
routing_='r',
|
routing_='r',
|
||||||
)
|
)
|
||||||
|
|
@ -195,18 +197,11 @@ class EpisodicNode(Node):
|
||||||
async def get_by_uuids(cls, driver: GraphDriver, uuids: list[str]):
|
async def get_by_uuids(cls, driver: GraphDriver, uuids: list[str]):
|
||||||
records, _, _ = await driver.execute_query(
|
records, _, _ = await driver.execute_query(
|
||||||
"""
|
"""
|
||||||
MATCH (e:Episodic) WHERE e.uuid IN $uuids
|
MATCH (e:Episodic)
|
||||||
|
WHERE e.uuid IN $uuids
|
||||||
RETURN DISTINCT
|
RETURN DISTINCT
|
||||||
e.content AS content,
|
"""
|
||||||
e.created_at AS created_at,
|
+ EPISODIC_NODE_RETURN,
|
||||||
e.valid_at AS valid_at,
|
|
||||||
e.uuid AS uuid,
|
|
||||||
e.name AS name,
|
|
||||||
e.group_id AS group_id,
|
|
||||||
e.source_description AS source_description,
|
|
||||||
e.source AS source,
|
|
||||||
e.entity_edges AS entity_edges
|
|
||||||
""",
|
|
||||||
uuids=uuids,
|
uuids=uuids,
|
||||||
routing_='r',
|
routing_='r',
|
||||||
)
|
)
|
||||||
|
|
@ -228,22 +223,17 @@ class EpisodicNode(Node):
|
||||||
|
|
||||||
records, _, _ = await driver.execute_query(
|
records, _, _ = await driver.execute_query(
|
||||||
"""
|
"""
|
||||||
MATCH (e:Episodic) WHERE e.group_id IN $group_ids
|
MATCH (e:Episodic)
|
||||||
"""
|
WHERE e.group_id IN $group_ids
|
||||||
|
"""
|
||||||
+ cursor_query
|
+ cursor_query
|
||||||
+ """
|
+ """
|
||||||
RETURN DISTINCT
|
RETURN DISTINCT
|
||||||
e.content AS content,
|
"""
|
||||||
e.created_at AS created_at,
|
+ EPISODIC_NODE_RETURN
|
||||||
e.valid_at AS valid_at,
|
+ """
|
||||||
e.uuid AS uuid,
|
ORDER BY uuid DESC
|
||||||
e.name AS name,
|
"""
|
||||||
e.group_id AS group_id,
|
|
||||||
e.source_description AS source_description,
|
|
||||||
e.source AS source,
|
|
||||||
e.entity_edges AS entity_edges
|
|
||||||
ORDER BY e.uuid DESC
|
|
||||||
"""
|
|
||||||
+ limit_query,
|
+ limit_query,
|
||||||
group_ids=group_ids,
|
group_ids=group_ids,
|
||||||
uuid=uuid_cursor,
|
uuid=uuid_cursor,
|
||||||
|
|
@ -259,18 +249,10 @@ class EpisodicNode(Node):
|
||||||
async def get_by_entity_node_uuid(cls, driver: GraphDriver, entity_node_uuid: str):
|
async def get_by_entity_node_uuid(cls, driver: GraphDriver, entity_node_uuid: str):
|
||||||
records, _, _ = await driver.execute_query(
|
records, _, _ = await driver.execute_query(
|
||||||
"""
|
"""
|
||||||
MATCH (e:Episodic)-[r:MENTIONS]->(n:Entity {uuid: $entity_node_uuid})
|
MATCH (e:Episodic)-[r:MENTIONS]->(n:Entity {uuid: $entity_node_uuid})
|
||||||
RETURN DISTINCT
|
RETURN DISTINCT
|
||||||
e.content AS content,
|
"""
|
||||||
e.created_at AS created_at,
|
+ EPISODIC_NODE_RETURN,
|
||||||
e.valid_at AS valid_at,
|
|
||||||
e.uuid AS uuid,
|
|
||||||
e.name AS name,
|
|
||||||
e.group_id AS group_id,
|
|
||||||
e.source_description AS source_description,
|
|
||||||
e.source AS source,
|
|
||||||
e.entity_edges AS entity_edges
|
|
||||||
""",
|
|
||||||
entity_node_uuid=entity_node_uuid,
|
entity_node_uuid=entity_node_uuid,
|
||||||
routing_='r',
|
routing_='r',
|
||||||
)
|
)
|
||||||
|
|
@ -297,11 +279,14 @@ class EntityNode(Node):
|
||||||
return self.name_embedding
|
return self.name_embedding
|
||||||
|
|
||||||
async def load_name_embedding(self, driver: GraphDriver):
|
async def load_name_embedding(self, driver: GraphDriver):
|
||||||
query: LiteralString = """
|
records, _, _ = await driver.execute_query(
|
||||||
|
"""
|
||||||
MATCH (n:Entity {uuid: $uuid})
|
MATCH (n:Entity {uuid: $uuid})
|
||||||
RETURN n.name_embedding AS name_embedding
|
RETURN n.name_embedding AS name_embedding
|
||||||
"""
|
""",
|
||||||
records, _, _ = await driver.execute_query(query, uuid=self.uuid, routing_='r')
|
uuid=self.uuid,
|
||||||
|
routing_='r',
|
||||||
|
)
|
||||||
|
|
||||||
if len(records) == 0:
|
if len(records) == 0:
|
||||||
raise NodeNotFoundError(self.uuid)
|
raise NodeNotFoundError(self.uuid)
|
||||||
|
|
@ -317,12 +302,12 @@ class EntityNode(Node):
|
||||||
'summary': self.summary,
|
'summary': self.summary,
|
||||||
'created_at': self.created_at,
|
'created_at': self.created_at,
|
||||||
}
|
}
|
||||||
|
|
||||||
entity_data.update(self.attributes or {})
|
entity_data.update(self.attributes or {})
|
||||||
|
|
||||||
|
labels = ':'.join(self.labels + ['Entity'])
|
||||||
|
|
||||||
result = await driver.execute_query(
|
result = await driver.execute_query(
|
||||||
ENTITY_NODE_SAVE,
|
get_entity_node_save_query(driver.provider, labels),
|
||||||
labels=self.labels + ['Entity'],
|
|
||||||
entity_data=entity_data,
|
entity_data=entity_data,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -332,14 +317,12 @@ class EntityNode(Node):
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def get_by_uuid(cls, driver: GraphDriver, uuid: str):
|
async def get_by_uuid(cls, driver: GraphDriver, uuid: str):
|
||||||
query = (
|
|
||||||
"""
|
|
||||||
MATCH (n:Entity {uuid: $uuid})
|
|
||||||
"""
|
|
||||||
+ ENTITY_NODE_RETURN
|
|
||||||
)
|
|
||||||
records, _, _ = await driver.execute_query(
|
records, _, _ = await driver.execute_query(
|
||||||
query,
|
"""
|
||||||
|
MATCH (n:Entity {uuid: $uuid})
|
||||||
|
RETURN
|
||||||
|
"""
|
||||||
|
+ ENTITY_NODE_RETURN,
|
||||||
uuid=uuid,
|
uuid=uuid,
|
||||||
routing_='r',
|
routing_='r',
|
||||||
)
|
)
|
||||||
|
|
@ -355,8 +338,10 @@ class EntityNode(Node):
|
||||||
async def get_by_uuids(cls, driver: GraphDriver, uuids: list[str]):
|
async def get_by_uuids(cls, driver: GraphDriver, uuids: list[str]):
|
||||||
records, _, _ = await driver.execute_query(
|
records, _, _ = await driver.execute_query(
|
||||||
"""
|
"""
|
||||||
MATCH (n:Entity) WHERE n.uuid IN $uuids
|
MATCH (n:Entity)
|
||||||
"""
|
WHERE n.uuid IN $uuids
|
||||||
|
RETURN
|
||||||
|
"""
|
||||||
+ ENTITY_NODE_RETURN,
|
+ ENTITY_NODE_RETURN,
|
||||||
uuids=uuids,
|
uuids=uuids,
|
||||||
routing_='r',
|
routing_='r',
|
||||||
|
|
@ -379,22 +364,26 @@ class EntityNode(Node):
|
||||||
limit_query: LiteralString = 'LIMIT $limit' if limit is not None else ''
|
limit_query: LiteralString = 'LIMIT $limit' if limit is not None else ''
|
||||||
with_embeddings_query: LiteralString = (
|
with_embeddings_query: LiteralString = (
|
||||||
""",
|
""",
|
||||||
n.name_embedding AS name_embedding
|
n.name_embedding AS name_embedding
|
||||||
"""
|
"""
|
||||||
if with_embeddings
|
if with_embeddings
|
||||||
else ''
|
else ''
|
||||||
)
|
)
|
||||||
|
|
||||||
records, _, _ = await driver.execute_query(
|
records, _, _ = await driver.execute_query(
|
||||||
"""
|
"""
|
||||||
MATCH (n:Entity) WHERE n.group_id IN $group_ids
|
MATCH (n:Entity)
|
||||||
"""
|
WHERE n.group_id IN $group_ids
|
||||||
|
"""
|
||||||
+ cursor_query
|
+ cursor_query
|
||||||
|
+ """
|
||||||
|
RETURN
|
||||||
|
"""
|
||||||
+ ENTITY_NODE_RETURN
|
+ ENTITY_NODE_RETURN
|
||||||
+ with_embeddings_query
|
+ with_embeddings_query
|
||||||
+ """
|
+ """
|
||||||
ORDER BY n.uuid DESC
|
ORDER BY n.uuid DESC
|
||||||
"""
|
"""
|
||||||
+ limit_query,
|
+ limit_query,
|
||||||
group_ids=group_ids,
|
group_ids=group_ids,
|
||||||
uuid=uuid_cursor,
|
uuid=uuid_cursor,
|
||||||
|
|
@ -413,7 +402,7 @@ class CommunityNode(Node):
|
||||||
|
|
||||||
async def save(self, driver: GraphDriver):
|
async def save(self, driver: GraphDriver):
|
||||||
result = await driver.execute_query(
|
result = await driver.execute_query(
|
||||||
COMMUNITY_NODE_SAVE,
|
get_community_node_save_query(driver.provider),
|
||||||
uuid=self.uuid,
|
uuid=self.uuid,
|
||||||
name=self.name,
|
name=self.name,
|
||||||
group_id=self.group_id,
|
group_id=self.group_id,
|
||||||
|
|
@ -436,11 +425,14 @@ class CommunityNode(Node):
|
||||||
return self.name_embedding
|
return self.name_embedding
|
||||||
|
|
||||||
async def load_name_embedding(self, driver: GraphDriver):
|
async def load_name_embedding(self, driver: GraphDriver):
|
||||||
query: LiteralString = """
|
records, _, _ = await driver.execute_query(
|
||||||
|
"""
|
||||||
MATCH (c:Community {uuid: $uuid})
|
MATCH (c:Community {uuid: $uuid})
|
||||||
RETURN c.name_embedding AS name_embedding
|
RETURN c.name_embedding AS name_embedding
|
||||||
"""
|
""",
|
||||||
records, _, _ = await driver.execute_query(query, uuid=self.uuid, routing_='r')
|
uuid=self.uuid,
|
||||||
|
routing_='r',
|
||||||
|
)
|
||||||
|
|
||||||
if len(records) == 0:
|
if len(records) == 0:
|
||||||
raise NodeNotFoundError(self.uuid)
|
raise NodeNotFoundError(self.uuid)
|
||||||
|
|
@ -451,14 +443,10 @@ class CommunityNode(Node):
|
||||||
async def get_by_uuid(cls, driver: GraphDriver, uuid: str):
|
async def get_by_uuid(cls, driver: GraphDriver, uuid: str):
|
||||||
records, _, _ = await driver.execute_query(
|
records, _, _ = await driver.execute_query(
|
||||||
"""
|
"""
|
||||||
MATCH (n:Community {uuid: $uuid})
|
MATCH (n:Community {uuid: $uuid})
|
||||||
RETURN
|
RETURN
|
||||||
n.uuid As uuid,
|
"""
|
||||||
n.name AS name,
|
+ COMMUNITY_NODE_RETURN,
|
||||||
n.group_id AS group_id,
|
|
||||||
n.created_at AS created_at,
|
|
||||||
n.summary AS summary
|
|
||||||
""",
|
|
||||||
uuid=uuid,
|
uuid=uuid,
|
||||||
routing_='r',
|
routing_='r',
|
||||||
)
|
)
|
||||||
|
|
@ -474,14 +462,11 @@ class CommunityNode(Node):
|
||||||
async def get_by_uuids(cls, driver: GraphDriver, uuids: list[str]):
|
async def get_by_uuids(cls, driver: GraphDriver, uuids: list[str]):
|
||||||
records, _, _ = await driver.execute_query(
|
records, _, _ = await driver.execute_query(
|
||||||
"""
|
"""
|
||||||
MATCH (n:Community) WHERE n.uuid IN $uuids
|
MATCH (n:Community)
|
||||||
RETURN
|
WHERE n.uuid IN $uuids
|
||||||
n.uuid As uuid,
|
RETURN
|
||||||
n.name AS name,
|
"""
|
||||||
n.group_id AS group_id,
|
+ COMMUNITY_NODE_RETURN,
|
||||||
n.created_at AS created_at,
|
|
||||||
n.summary AS summary
|
|
||||||
""",
|
|
||||||
uuids=uuids,
|
uuids=uuids,
|
||||||
routing_='r',
|
routing_='r',
|
||||||
)
|
)
|
||||||
|
|
@ -503,18 +488,17 @@ class CommunityNode(Node):
|
||||||
|
|
||||||
records, _, _ = await driver.execute_query(
|
records, _, _ = await driver.execute_query(
|
||||||
"""
|
"""
|
||||||
MATCH (n:Community) WHERE n.group_id IN $group_ids
|
MATCH (n:Community)
|
||||||
"""
|
WHERE n.group_id IN $group_ids
|
||||||
|
"""
|
||||||
+ cursor_query
|
+ cursor_query
|
||||||
+ """
|
+ """
|
||||||
RETURN
|
RETURN
|
||||||
n.uuid As uuid,
|
"""
|
||||||
n.name AS name,
|
+ COMMUNITY_NODE_RETURN
|
||||||
n.group_id AS group_id,
|
+ """
|
||||||
n.created_at AS created_at,
|
ORDER BY n.uuid DESC
|
||||||
n.summary AS summary
|
"""
|
||||||
ORDER BY n.uuid DESC
|
|
||||||
"""
|
|
||||||
+ limit_query,
|
+ limit_query,
|
||||||
group_ids=group_ids,
|
group_ids=group_ids,
|
||||||
uuid=uuid_cursor,
|
uuid=uuid_cursor,
|
||||||
|
|
@ -586,6 +570,7 @@ def get_community_node_from_record(record: Any) -> CommunityNode:
|
||||||
async def create_entity_node_embeddings(embedder: EmbedderClient, nodes: list[EntityNode]):
|
async def create_entity_node_embeddings(embedder: EmbedderClient, nodes: list[EntityNode]):
|
||||||
if not nodes: # Handle empty list case
|
if not nodes: # Handle empty list case
|
||||||
return
|
return
|
||||||
|
|
||||||
name_embeddings = await embedder.create_batch([node.name for node in nodes])
|
name_embeddings = await embedder.create_batch([node.name for node in nodes])
|
||||||
for node, name_embedding in zip(nodes, name_embeddings, strict=True):
|
for node, name_embedding in zip(nodes, name_embeddings, strict=True):
|
||||||
node.name_embedding = name_embedding
|
node.name_embedding = name_embedding
|
||||||
|
|
|
||||||
|
|
@ -72,7 +72,7 @@ def edge_search_filter_query_constructor(
|
||||||
|
|
||||||
if filters.edge_types is not None:
|
if filters.edge_types is not None:
|
||||||
edge_types = filters.edge_types
|
edge_types = filters.edge_types
|
||||||
edge_types_filter = '\nAND r.name in $edge_types'
|
edge_types_filter = '\nAND e.name in $edge_types'
|
||||||
filter_query += edge_types_filter
|
filter_query += edge_types_filter
|
||||||
filter_params['edge_types'] = edge_types
|
filter_params['edge_types'] = edge_types
|
||||||
|
|
||||||
|
|
@ -88,7 +88,7 @@ def edge_search_filter_query_constructor(
|
||||||
filter_params['valid_at_' + str(j)] = date_filter.date
|
filter_params['valid_at_' + str(j)] = date_filter.date
|
||||||
|
|
||||||
and_filters = [
|
and_filters = [
|
||||||
'(r.valid_at ' + date_filter.comparison_operator.value + f' $valid_at_{j})'
|
'(e.valid_at ' + date_filter.comparison_operator.value + f' $valid_at_{j})'
|
||||||
for j, date_filter in enumerate(or_list)
|
for j, date_filter in enumerate(or_list)
|
||||||
]
|
]
|
||||||
and_filter_query = ''
|
and_filter_query = ''
|
||||||
|
|
@ -113,7 +113,7 @@ def edge_search_filter_query_constructor(
|
||||||
filter_params['invalid_at_' + str(j)] = date_filter.date
|
filter_params['invalid_at_' + str(j)] = date_filter.date
|
||||||
|
|
||||||
and_filters = [
|
and_filters = [
|
||||||
'(r.invalid_at ' + date_filter.comparison_operator.value + f' $invalid_at_{j})'
|
'(e.invalid_at ' + date_filter.comparison_operator.value + f' $invalid_at_{j})'
|
||||||
for j, date_filter in enumerate(or_list)
|
for j, date_filter in enumerate(or_list)
|
||||||
]
|
]
|
||||||
and_filter_query = ''
|
and_filter_query = ''
|
||||||
|
|
@ -138,7 +138,7 @@ def edge_search_filter_query_constructor(
|
||||||
filter_params['created_at_' + str(j)] = date_filter.date
|
filter_params['created_at_' + str(j)] = date_filter.date
|
||||||
|
|
||||||
and_filters = [
|
and_filters = [
|
||||||
'(r.created_at ' + date_filter.comparison_operator.value + f' $created_at_{j})'
|
'(e.created_at ' + date_filter.comparison_operator.value + f' $created_at_{j})'
|
||||||
for j, date_filter in enumerate(or_list)
|
for j, date_filter in enumerate(or_list)
|
||||||
]
|
]
|
||||||
and_filter_query = ''
|
and_filter_query = ''
|
||||||
|
|
@ -163,7 +163,7 @@ def edge_search_filter_query_constructor(
|
||||||
filter_params['expired_at_' + str(j)] = date_filter.date
|
filter_params['expired_at_' + str(j)] = date_filter.date
|
||||||
|
|
||||||
and_filters = [
|
and_filters = [
|
||||||
'(r.expired_at ' + date_filter.comparison_operator.value + f' $expired_at_{j})'
|
'(e.expired_at ' + date_filter.comparison_operator.value + f' $expired_at_{j})'
|
||||||
for j, date_filter in enumerate(or_list)
|
for j, date_filter in enumerate(or_list)
|
||||||
]
|
]
|
||||||
and_filter_query = ''
|
and_filter_query = ''
|
||||||
|
|
|
||||||
|
|
@ -23,7 +23,7 @@ import numpy as np
|
||||||
from numpy._typing import NDArray
|
from numpy._typing import NDArray
|
||||||
from typing_extensions import LiteralString
|
from typing_extensions import LiteralString
|
||||||
|
|
||||||
from graphiti_core.driver.driver import GraphDriver
|
from graphiti_core.driver.driver import GraphDriver, GraphProvider
|
||||||
from graphiti_core.edges import EntityEdge, get_entity_edge_from_record
|
from graphiti_core.edges import EntityEdge, get_entity_edge_from_record
|
||||||
from graphiti_core.graph_queries import (
|
from graphiti_core.graph_queries import (
|
||||||
get_nodes_query,
|
get_nodes_query,
|
||||||
|
|
@ -36,6 +36,8 @@ from graphiti_core.helpers import (
|
||||||
normalize_l2,
|
normalize_l2,
|
||||||
semaphore_gather,
|
semaphore_gather,
|
||||||
)
|
)
|
||||||
|
from graphiti_core.models.edges.edge_db_queries import ENTITY_EDGE_RETURN
|
||||||
|
from graphiti_core.models.nodes.node_db_queries import COMMUNITY_NODE_RETURN, EPISODIC_NODE_RETURN
|
||||||
from graphiti_core.nodes import (
|
from graphiti_core.nodes import (
|
||||||
ENTITY_NODE_RETURN,
|
ENTITY_NODE_RETURN,
|
||||||
CommunityNode,
|
CommunityNode,
|
||||||
|
|
@ -100,20 +102,13 @@ async def get_mentioned_nodes(
|
||||||
) -> list[EntityNode]:
|
) -> list[EntityNode]:
|
||||||
episode_uuids = [episode.uuid for episode in episodes]
|
episode_uuids = [episode.uuid for episode in episodes]
|
||||||
|
|
||||||
query = """
|
|
||||||
MATCH (episode:Episodic)-[:MENTIONS]->(n:Entity) WHERE episode.uuid IN $uuids
|
|
||||||
RETURN DISTINCT
|
|
||||||
n.uuid As uuid,
|
|
||||||
n.group_id AS group_id,
|
|
||||||
n.name AS name,
|
|
||||||
n.created_at AS created_at,
|
|
||||||
n.summary AS summary,
|
|
||||||
labels(n) AS labels,
|
|
||||||
properties(n) AS attributes
|
|
||||||
"""
|
|
||||||
|
|
||||||
records, _, _ = await driver.execute_query(
|
records, _, _ = await driver.execute_query(
|
||||||
query,
|
"""
|
||||||
|
MATCH (episode:Episodic)-[:MENTIONS]->(n:Entity)
|
||||||
|
WHERE episode.uuid IN $uuids
|
||||||
|
RETURN DISTINCT
|
||||||
|
"""
|
||||||
|
+ ENTITY_NODE_RETURN,
|
||||||
uuids=episode_uuids,
|
uuids=episode_uuids,
|
||||||
routing_='r',
|
routing_='r',
|
||||||
)
|
)
|
||||||
|
|
@ -128,18 +123,13 @@ async def get_communities_by_nodes(
|
||||||
) -> list[CommunityNode]:
|
) -> list[CommunityNode]:
|
||||||
node_uuids = [node.uuid for node in nodes]
|
node_uuids = [node.uuid for node in nodes]
|
||||||
|
|
||||||
query = """
|
|
||||||
MATCH (c:Community)-[:HAS_MEMBER]->(n:Entity) WHERE n.uuid IN $uuids
|
|
||||||
RETURN DISTINCT
|
|
||||||
c.uuid As uuid,
|
|
||||||
c.group_id AS group_id,
|
|
||||||
c.name AS name,
|
|
||||||
c.created_at AS created_at,
|
|
||||||
c.summary AS summary
|
|
||||||
"""
|
|
||||||
|
|
||||||
records, _, _ = await driver.execute_query(
|
records, _, _ = await driver.execute_query(
|
||||||
query,
|
"""
|
||||||
|
MATCH (n:Community)-[:HAS_MEMBER]->(m:Entity)
|
||||||
|
WHERE m.uuid IN $uuids
|
||||||
|
RETURN DISTINCT
|
||||||
|
"""
|
||||||
|
+ COMMUNITY_NODE_RETURN,
|
||||||
uuids=node_uuids,
|
uuids=node_uuids,
|
||||||
routing_='r',
|
routing_='r',
|
||||||
)
|
)
|
||||||
|
|
@ -164,38 +154,30 @@ async def edge_fulltext_search(
|
||||||
filter_query, filter_params = edge_search_filter_query_constructor(search_filter)
|
filter_query, filter_params = edge_search_filter_query_constructor(search_filter)
|
||||||
|
|
||||||
query = (
|
query = (
|
||||||
get_relationships_query('edge_name_and_fact', db_type=driver.provider)
|
get_relationships_query('edge_name_and_fact', provider=driver.provider)
|
||||||
+ """
|
+ """
|
||||||
YIELD relationship AS rel, score
|
YIELD relationship AS rel, score
|
||||||
MATCH (n:Entity)-[r:RELATES_TO {uuid: rel.uuid}]->(m:Entity)
|
MATCH (n:Entity)-[e:RELATES_TO {uuid: rel.uuid}]->(m:Entity)
|
||||||
WHERE r.group_id IN $group_ids """
|
WHERE e.group_id IN $group_ids """
|
||||||
+ filter_query
|
+ filter_query
|
||||||
+ """
|
+ """
|
||||||
WITH r, score, startNode(r) AS n, endNode(r) AS m
|
WITH e, score, n, m
|
||||||
RETURN
|
RETURN
|
||||||
r.uuid AS uuid,
|
"""
|
||||||
r.group_id AS group_id,
|
+ ENTITY_EDGE_RETURN
|
||||||
n.uuid AS source_node_uuid,
|
+ """
|
||||||
m.uuid AS target_node_uuid,
|
ORDER BY score DESC
|
||||||
r.created_at AS created_at,
|
LIMIT $limit
|
||||||
r.name AS name,
|
|
||||||
r.fact AS fact,
|
|
||||||
r.episodes AS episodes,
|
|
||||||
r.expired_at AS expired_at,
|
|
||||||
r.valid_at AS valid_at,
|
|
||||||
r.invalid_at AS invalid_at,
|
|
||||||
properties(r) AS attributes
|
|
||||||
ORDER BY score DESC LIMIT $limit
|
|
||||||
"""
|
"""
|
||||||
)
|
)
|
||||||
|
|
||||||
records, _, _ = await driver.execute_query(
|
records, _, _ = await driver.execute_query(
|
||||||
query,
|
query,
|
||||||
params=filter_params,
|
|
||||||
query=fuzzy_query,
|
query=fuzzy_query,
|
||||||
group_ids=group_ids,
|
group_ids=group_ids,
|
||||||
limit=limit,
|
limit=limit,
|
||||||
routing_='r',
|
routing_='r',
|
||||||
|
**filter_params,
|
||||||
)
|
)
|
||||||
|
|
||||||
edges = [get_entity_edge_from_record(record) for record in records]
|
edges = [get_entity_edge_from_record(record) for record in records]
|
||||||
|
|
@ -219,58 +201,47 @@ async def edge_similarity_search(
|
||||||
filter_query, filter_params = edge_search_filter_query_constructor(search_filter)
|
filter_query, filter_params = edge_search_filter_query_constructor(search_filter)
|
||||||
query_params.update(filter_params)
|
query_params.update(filter_params)
|
||||||
|
|
||||||
group_filter_query: LiteralString = 'WHERE r.group_id IS NOT NULL'
|
group_filter_query: LiteralString = 'WHERE e.group_id IS NOT NULL'
|
||||||
if group_ids is not None:
|
if group_ids is not None:
|
||||||
group_filter_query += '\nAND r.group_id IN $group_ids'
|
group_filter_query += '\nAND e.group_id IN $group_ids'
|
||||||
query_params['group_ids'] = group_ids
|
query_params['group_ids'] = group_ids
|
||||||
query_params['source_node_uuid'] = source_node_uuid
|
|
||||||
query_params['target_node_uuid'] = target_node_uuid
|
|
||||||
|
|
||||||
if source_node_uuid is not None:
|
if source_node_uuid is not None:
|
||||||
group_filter_query += '\nAND (n.uuid IN [$source_uuid, $target_uuid])'
|
query_params['source_uuid'] = source_node_uuid
|
||||||
|
group_filter_query += '\nAND (n.uuid = $source_uuid)'
|
||||||
|
|
||||||
if target_node_uuid is not None:
|
if target_node_uuid is not None:
|
||||||
group_filter_query += '\nAND (m.uuid IN [$source_uuid, $target_uuid])'
|
query_params['target_uuid'] = target_node_uuid
|
||||||
|
group_filter_query += '\nAND (m.uuid = $target_uuid)'
|
||||||
|
|
||||||
query = (
|
query = (
|
||||||
RUNTIME_QUERY
|
RUNTIME_QUERY
|
||||||
+ """
|
+ """
|
||||||
MATCH (n:Entity)-[r:RELATES_TO]->(m:Entity)
|
MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
|
||||||
"""
|
"""
|
||||||
+ group_filter_query
|
+ group_filter_query
|
||||||
+ filter_query
|
+ filter_query
|
||||||
+ """
|
+ """
|
||||||
WITH DISTINCT r, """
|
WITH DISTINCT e, n, m, """
|
||||||
+ get_vector_cosine_func_query('r.fact_embedding', '$search_vector', driver.provider)
|
+ get_vector_cosine_func_query('e.fact_embedding', '$search_vector', driver.provider)
|
||||||
+ """ AS score
|
+ """ AS score
|
||||||
WHERE score > $min_score
|
WHERE score > $min_score
|
||||||
RETURN
|
RETURN
|
||||||
r.uuid AS uuid,
|
"""
|
||||||
r.group_id AS group_id,
|
+ ENTITY_EDGE_RETURN
|
||||||
startNode(r).uuid AS source_node_uuid,
|
+ """
|
||||||
endNode(r).uuid AS target_node_uuid,
|
|
||||||
r.created_at AS created_at,
|
|
||||||
r.name AS name,
|
|
||||||
r.fact AS fact,
|
|
||||||
r.episodes AS episodes,
|
|
||||||
r.expired_at AS expired_at,
|
|
||||||
r.valid_at AS valid_at,
|
|
||||||
r.invalid_at AS invalid_at,
|
|
||||||
properties(r) AS attributes
|
|
||||||
ORDER BY score DESC
|
ORDER BY score DESC
|
||||||
LIMIT $limit
|
LIMIT $limit
|
||||||
"""
|
"""
|
||||||
)
|
)
|
||||||
records, header, _ = await driver.execute_query(
|
|
||||||
|
records, _, _ = await driver.execute_query(
|
||||||
query,
|
query,
|
||||||
params=query_params,
|
|
||||||
search_vector=search_vector,
|
search_vector=search_vector,
|
||||||
source_uuid=source_node_uuid,
|
|
||||||
target_uuid=target_node_uuid,
|
|
||||||
group_ids=group_ids,
|
|
||||||
limit=limit,
|
limit=limit,
|
||||||
min_score=min_score,
|
min_score=min_score,
|
||||||
routing_='r',
|
routing_='r',
|
||||||
|
**query_params,
|
||||||
)
|
)
|
||||||
|
|
||||||
edges = [get_entity_edge_from_record(record) for record in records]
|
edges = [get_entity_edge_from_record(record) for record in records]
|
||||||
|
|
@ -293,41 +264,31 @@ async def edge_bfs_search(
|
||||||
filter_query, filter_params = edge_search_filter_query_constructor(search_filter)
|
filter_query, filter_params = edge_search_filter_query_constructor(search_filter)
|
||||||
|
|
||||||
query = (
|
query = (
|
||||||
|
f"""
|
||||||
|
UNWIND $bfs_origin_node_uuids AS origin_uuid
|
||||||
|
MATCH path = (origin:Entity|Episodic {{uuid: origin_uuid}})-[:RELATES_TO|MENTIONS*1..{bfs_max_depth}]->(:Entity)
|
||||||
|
UNWIND relationships(path) AS rel
|
||||||
|
MATCH (n:Entity)-[e:RELATES_TO]-(m:Entity)
|
||||||
|
WHERE e.uuid = rel.uuid
|
||||||
|
AND e.group_id IN $group_ids
|
||||||
"""
|
"""
|
||||||
UNWIND $bfs_origin_node_uuids AS origin_uuid
|
|
||||||
MATCH path = (origin:Entity|Episodic {uuid: origin_uuid})-[:RELATES_TO|MENTIONS]->{1,3}(n:Entity)
|
|
||||||
UNWIND relationships(path) AS rel
|
|
||||||
MATCH (n:Entity)-[r:RELATES_TO]-(m:Entity)
|
|
||||||
WHERE r.uuid = rel.uuid
|
|
||||||
AND r.group_id IN $group_ids
|
|
||||||
"""
|
|
||||||
+ filter_query
|
+ filter_query
|
||||||
+ """
|
+ """
|
||||||
RETURN DISTINCT
|
RETURN DISTINCT
|
||||||
r.uuid AS uuid,
|
"""
|
||||||
r.group_id AS group_id,
|
+ ENTITY_EDGE_RETURN
|
||||||
startNode(r).uuid AS source_node_uuid,
|
+ """
|
||||||
endNode(r).uuid AS target_node_uuid,
|
LIMIT $limit
|
||||||
r.created_at AS created_at,
|
|
||||||
r.name AS name,
|
|
||||||
r.fact AS fact,
|
|
||||||
r.episodes AS episodes,
|
|
||||||
r.expired_at AS expired_at,
|
|
||||||
r.valid_at AS valid_at,
|
|
||||||
r.invalid_at AS invalid_at,
|
|
||||||
properties(r) AS attributes
|
|
||||||
LIMIT $limit
|
|
||||||
"""
|
"""
|
||||||
)
|
)
|
||||||
|
|
||||||
records, _, _ = await driver.execute_query(
|
records, _, _ = await driver.execute_query(
|
||||||
query,
|
query,
|
||||||
params=filter_params,
|
|
||||||
bfs_origin_node_uuids=bfs_origin_node_uuids,
|
bfs_origin_node_uuids=bfs_origin_node_uuids,
|
||||||
depth=bfs_max_depth,
|
|
||||||
group_ids=group_ids,
|
group_ids=group_ids,
|
||||||
limit=limit,
|
limit=limit,
|
||||||
routing_='r',
|
routing_='r',
|
||||||
|
**filter_params,
|
||||||
)
|
)
|
||||||
|
|
||||||
edges = [get_entity_edge_from_record(record) for record in records]
|
edges = [get_entity_edge_from_record(record) for record in records]
|
||||||
|
|
@ -352,23 +313,27 @@ async def node_fulltext_search(
|
||||||
get_nodes_query(driver.provider, 'node_name_and_summary', '$query')
|
get_nodes_query(driver.provider, 'node_name_and_summary', '$query')
|
||||||
+ """
|
+ """
|
||||||
YIELD node AS n, score
|
YIELD node AS n, score
|
||||||
WITH n, score
|
WHERE n:Entity AND n.group_id IN $group_ids
|
||||||
LIMIT $limit
|
WITH n, score
|
||||||
WHERE n:Entity AND n.group_id IN $group_ids
|
LIMIT $limit
|
||||||
"""
|
"""
|
||||||
+ filter_query
|
+ filter_query
|
||||||
|
+ """
|
||||||
|
RETURN
|
||||||
|
"""
|
||||||
+ ENTITY_NODE_RETURN
|
+ ENTITY_NODE_RETURN
|
||||||
+ """
|
+ """
|
||||||
ORDER BY score DESC
|
ORDER BY score DESC
|
||||||
"""
|
"""
|
||||||
)
|
)
|
||||||
records, header, _ = await driver.execute_query(
|
|
||||||
|
records, _, _ = await driver.execute_query(
|
||||||
query,
|
query,
|
||||||
params=filter_params,
|
|
||||||
query=fuzzy_query,
|
query=fuzzy_query,
|
||||||
group_ids=group_ids,
|
group_ids=group_ids,
|
||||||
limit=limit,
|
limit=limit,
|
||||||
routing_='r',
|
routing_='r',
|
||||||
|
**filter_params,
|
||||||
)
|
)
|
||||||
|
|
||||||
nodes = [get_entity_node_from_record(record) for record in records]
|
nodes = [get_entity_node_from_record(record) for record in records]
|
||||||
|
|
@ -406,22 +371,23 @@ async def node_similarity_search(
|
||||||
WITH n, """
|
WITH n, """
|
||||||
+ get_vector_cosine_func_query('n.name_embedding', '$search_vector', driver.provider)
|
+ get_vector_cosine_func_query('n.name_embedding', '$search_vector', driver.provider)
|
||||||
+ """ AS score
|
+ """ AS score
|
||||||
WHERE score > $min_score"""
|
WHERE score > $min_score
|
||||||
|
RETURN
|
||||||
|
"""
|
||||||
+ ENTITY_NODE_RETURN
|
+ ENTITY_NODE_RETURN
|
||||||
+ """
|
+ """
|
||||||
ORDER BY score DESC
|
ORDER BY score DESC
|
||||||
LIMIT $limit
|
LIMIT $limit
|
||||||
"""
|
"""
|
||||||
)
|
)
|
||||||
|
|
||||||
records, header, _ = await driver.execute_query(
|
records, _, _ = await driver.execute_query(
|
||||||
query,
|
query,
|
||||||
params=query_params,
|
|
||||||
search_vector=search_vector,
|
search_vector=search_vector,
|
||||||
group_ids=group_ids,
|
|
||||||
limit=limit,
|
limit=limit,
|
||||||
min_score=min_score,
|
min_score=min_score,
|
||||||
routing_='r',
|
routing_='r',
|
||||||
|
**query_params,
|
||||||
)
|
)
|
||||||
|
|
||||||
nodes = [get_entity_node_from_record(record) for record in records]
|
nodes = [get_entity_node_from_record(record) for record in records]
|
||||||
|
|
@ -444,26 +410,29 @@ async def node_bfs_search(
|
||||||
filter_query, filter_params = node_search_filter_query_constructor(search_filter)
|
filter_query, filter_params = node_search_filter_query_constructor(search_filter)
|
||||||
|
|
||||||
query = (
|
query = (
|
||||||
|
f"""
|
||||||
|
UNWIND $bfs_origin_node_uuids AS origin_uuid
|
||||||
|
MATCH (origin:Entity|Episodic {{uuid: origin_uuid}})-[:RELATES_TO|MENTIONS*1..{bfs_max_depth}]->(n:Entity)
|
||||||
|
WHERE n.group_id = origin.group_id
|
||||||
|
AND origin.group_id IN $group_ids
|
||||||
"""
|
"""
|
||||||
UNWIND $bfs_origin_node_uuids AS origin_uuid
|
|
||||||
MATCH (origin:Entity|Episodic {uuid: origin_uuid})-[:RELATES_TO|MENTIONS]->{1,3}(n:Entity)
|
|
||||||
WHERE n.group_id = origin.group_id
|
|
||||||
AND origin.group_id IN $group_ids
|
|
||||||
"""
|
|
||||||
+ filter_query
|
+ filter_query
|
||||||
|
+ """
|
||||||
|
RETURN
|
||||||
|
"""
|
||||||
+ ENTITY_NODE_RETURN
|
+ ENTITY_NODE_RETURN
|
||||||
+ """
|
+ """
|
||||||
LIMIT $limit
|
LIMIT $limit
|
||||||
"""
|
"""
|
||||||
)
|
)
|
||||||
|
|
||||||
records, _, _ = await driver.execute_query(
|
records, _, _ = await driver.execute_query(
|
||||||
query,
|
query,
|
||||||
params=filter_params,
|
|
||||||
bfs_origin_node_uuids=bfs_origin_node_uuids,
|
bfs_origin_node_uuids=bfs_origin_node_uuids,
|
||||||
depth=bfs_max_depth,
|
|
||||||
group_ids=group_ids,
|
group_ids=group_ids,
|
||||||
limit=limit,
|
limit=limit,
|
||||||
routing_='r',
|
routing_='r',
|
||||||
|
**filter_params,
|
||||||
)
|
)
|
||||||
nodes = [get_entity_node_from_record(record) for record in records]
|
nodes = [get_entity_node_from_record(record) for record in records]
|
||||||
|
|
||||||
|
|
@ -489,16 +458,10 @@ async def episode_fulltext_search(
|
||||||
MATCH (e:Episodic)
|
MATCH (e:Episodic)
|
||||||
WHERE e.uuid = episode.uuid
|
WHERE e.uuid = episode.uuid
|
||||||
AND e.group_id IN $group_ids
|
AND e.group_id IN $group_ids
|
||||||
RETURN
|
RETURN
|
||||||
e.content AS content,
|
"""
|
||||||
e.created_at AS created_at,
|
+ EPISODIC_NODE_RETURN
|
||||||
e.valid_at AS valid_at,
|
+ """
|
||||||
e.uuid AS uuid,
|
|
||||||
e.name AS name,
|
|
||||||
e.group_id AS group_id,
|
|
||||||
e.source_description AS source_description,
|
|
||||||
e.source AS source,
|
|
||||||
e.entity_edges AS entity_edges
|
|
||||||
ORDER BY score DESC
|
ORDER BY score DESC
|
||||||
LIMIT $limit
|
LIMIT $limit
|
||||||
"""
|
"""
|
||||||
|
|
@ -530,15 +493,12 @@ async def community_fulltext_search(
|
||||||
query = (
|
query = (
|
||||||
get_nodes_query(driver.provider, 'community_name', '$query')
|
get_nodes_query(driver.provider, 'community_name', '$query')
|
||||||
+ """
|
+ """
|
||||||
YIELD node AS comm, score
|
YIELD node AS n, score
|
||||||
WHERE comm.group_id IN $group_ids
|
WHERE n.group_id IN $group_ids
|
||||||
RETURN
|
RETURN
|
||||||
comm.uuid AS uuid,
|
"""
|
||||||
comm.group_id AS group_id,
|
+ COMMUNITY_NODE_RETURN
|
||||||
comm.name AS name,
|
+ """
|
||||||
comm.created_at AS created_at,
|
|
||||||
comm.summary AS summary,
|
|
||||||
comm.name_embedding AS name_embedding
|
|
||||||
ORDER BY score DESC
|
ORDER BY score DESC
|
||||||
LIMIT $limit
|
LIMIT $limit
|
||||||
"""
|
"""
|
||||||
|
|
@ -568,39 +528,37 @@ async def community_similarity_search(
|
||||||
|
|
||||||
group_filter_query: LiteralString = ''
|
group_filter_query: LiteralString = ''
|
||||||
if group_ids is not None:
|
if group_ids is not None:
|
||||||
group_filter_query += 'WHERE comm.group_id IN $group_ids'
|
group_filter_query += 'WHERE n.group_id IN $group_ids'
|
||||||
query_params['group_ids'] = group_ids
|
query_params['group_ids'] = group_ids
|
||||||
|
|
||||||
query = (
|
query = (
|
||||||
RUNTIME_QUERY
|
RUNTIME_QUERY
|
||||||
+ """
|
+ """
|
||||||
MATCH (comm:Community)
|
MATCH (n:Community)
|
||||||
"""
|
"""
|
||||||
+ group_filter_query
|
+ group_filter_query
|
||||||
+ """
|
+ """
|
||||||
WITH comm, """
|
WITH n,
|
||||||
+ get_vector_cosine_func_query('comm.name_embedding', '$search_vector', driver.provider)
|
"""
|
||||||
|
+ get_vector_cosine_func_query('n.name_embedding', '$search_vector', driver.provider)
|
||||||
+ """ AS score
|
+ """ AS score
|
||||||
WHERE score > $min_score
|
WHERE score > $min_score
|
||||||
RETURN
|
RETURN
|
||||||
comm.uuid As uuid,
|
"""
|
||||||
comm.group_id AS group_id,
|
+ COMMUNITY_NODE_RETURN
|
||||||
comm.name AS name,
|
+ """
|
||||||
comm.created_at AS created_at,
|
ORDER BY score DESC
|
||||||
comm.summary AS summary,
|
LIMIT $limit
|
||||||
comm.name_embedding AS name_embedding
|
|
||||||
ORDER BY score DESC
|
|
||||||
LIMIT $limit
|
|
||||||
"""
|
"""
|
||||||
)
|
)
|
||||||
|
|
||||||
records, _, _ = await driver.execute_query(
|
records, _, _ = await driver.execute_query(
|
||||||
query,
|
query,
|
||||||
search_vector=search_vector,
|
search_vector=search_vector,
|
||||||
group_ids=group_ids,
|
|
||||||
limit=limit,
|
limit=limit,
|
||||||
min_score=min_score,
|
min_score=min_score,
|
||||||
routing_='r',
|
routing_='r',
|
||||||
|
**query_params,
|
||||||
)
|
)
|
||||||
communities = [get_community_node_from_record(record) for record in records]
|
communities = [get_community_node_from_record(record) for record in records]
|
||||||
|
|
||||||
|
|
@ -719,8 +677,8 @@ async def get_relevant_nodes(
|
||||||
WHERE m.group_id = $group_id
|
WHERE m.group_id = $group_id
|
||||||
WITH node, top_vector_nodes, vector_node_uuids, collect(m) AS fulltext_nodes
|
WITH node, top_vector_nodes, vector_node_uuids, collect(m) AS fulltext_nodes
|
||||||
|
|
||||||
WITH node,
|
WITH node,
|
||||||
top_vector_nodes,
|
top_vector_nodes,
|
||||||
[m IN fulltext_nodes WHERE NOT m.uuid IN vector_node_uuids] AS filtered_fulltext_nodes
|
[m IN fulltext_nodes WHERE NOT m.uuid IN vector_node_uuids] AS filtered_fulltext_nodes
|
||||||
|
|
||||||
WITH node, top_vector_nodes + filtered_fulltext_nodes AS combined_nodes
|
WITH node, top_vector_nodes + filtered_fulltext_nodes AS combined_nodes
|
||||||
|
|
@ -728,10 +686,10 @@ async def get_relevant_nodes(
|
||||||
UNWIND combined_nodes AS combined_node
|
UNWIND combined_nodes AS combined_node
|
||||||
WITH node, collect(DISTINCT combined_node) AS deduped_nodes
|
WITH node, collect(DISTINCT combined_node) AS deduped_nodes
|
||||||
|
|
||||||
RETURN
|
RETURN
|
||||||
node.uuid AS search_node_uuid,
|
node.uuid AS search_node_uuid,
|
||||||
[x IN deduped_nodes | {
|
[x IN deduped_nodes | {
|
||||||
uuid: x.uuid,
|
uuid: x.uuid,
|
||||||
name: x.name,
|
name: x.name,
|
||||||
name_embedding: x.name_embedding,
|
name_embedding: x.name_embedding,
|
||||||
group_id: x.group_id,
|
group_id: x.group_id,
|
||||||
|
|
@ -755,12 +713,12 @@ async def get_relevant_nodes(
|
||||||
|
|
||||||
results, _, _ = await driver.execute_query(
|
results, _, _ = await driver.execute_query(
|
||||||
query,
|
query,
|
||||||
params=query_params,
|
|
||||||
nodes=query_nodes,
|
nodes=query_nodes,
|
||||||
group_id=group_id,
|
group_id=group_id,
|
||||||
limit=limit,
|
limit=limit,
|
||||||
min_score=min_score,
|
min_score=min_score,
|
||||||
routing_='r',
|
routing_='r',
|
||||||
|
**query_params,
|
||||||
)
|
)
|
||||||
|
|
||||||
relevant_nodes_dict: dict[str, list[EntityNode]] = {
|
relevant_nodes_dict: dict[str, list[EntityNode]] = {
|
||||||
|
|
@ -825,11 +783,11 @@ async def get_relevant_edges(
|
||||||
|
|
||||||
results, _, _ = await driver.execute_query(
|
results, _, _ = await driver.execute_query(
|
||||||
query,
|
query,
|
||||||
params=query_params,
|
|
||||||
edges=[edge.model_dump() for edge in edges],
|
edges=[edge.model_dump() for edge in edges],
|
||||||
limit=limit,
|
limit=limit,
|
||||||
min_score=min_score,
|
min_score=min_score,
|
||||||
routing_='r',
|
routing_='r',
|
||||||
|
**query_params,
|
||||||
)
|
)
|
||||||
|
|
||||||
relevant_edges_dict: dict[str, list[EntityEdge]] = {
|
relevant_edges_dict: dict[str, list[EntityEdge]] = {
|
||||||
|
|
@ -895,11 +853,11 @@ async def get_edge_invalidation_candidates(
|
||||||
|
|
||||||
results, _, _ = await driver.execute_query(
|
results, _, _ = await driver.execute_query(
|
||||||
query,
|
query,
|
||||||
params=query_params,
|
|
||||||
edges=[edge.model_dump() for edge in edges],
|
edges=[edge.model_dump() for edge in edges],
|
||||||
limit=limit,
|
limit=limit,
|
||||||
min_score=min_score,
|
min_score=min_score,
|
||||||
routing_='r',
|
routing_='r',
|
||||||
|
**query_params,
|
||||||
)
|
)
|
||||||
invalidation_edges_dict: dict[str, list[EntityEdge]] = {
|
invalidation_edges_dict: dict[str, list[EntityEdge]] = {
|
||||||
result['search_edge_uuid']: [
|
result['search_edge_uuid']: [
|
||||||
|
|
@ -943,18 +901,17 @@ async def node_distance_reranker(
|
||||||
scores: dict[str, float] = {center_node_uuid: 0.0}
|
scores: dict[str, float] = {center_node_uuid: 0.0}
|
||||||
|
|
||||||
# Find the shortest path to center node
|
# Find the shortest path to center node
|
||||||
query = """
|
results, header, _ = await driver.execute_query(
|
||||||
|
"""
|
||||||
UNWIND $node_uuids AS node_uuid
|
UNWIND $node_uuids AS node_uuid
|
||||||
MATCH (center:Entity {uuid: $center_uuid})-[:RELATES_TO]-(n:Entity {uuid: node_uuid})
|
MATCH (center:Entity {uuid: $center_uuid})-[:RELATES_TO]-(n:Entity {uuid: node_uuid})
|
||||||
RETURN 1 AS score, node_uuid AS uuid
|
RETURN 1 AS score, node_uuid AS uuid
|
||||||
"""
|
""",
|
||||||
results, header, _ = await driver.execute_query(
|
|
||||||
query,
|
|
||||||
node_uuids=filtered_uuids,
|
node_uuids=filtered_uuids,
|
||||||
center_uuid=center_node_uuid,
|
center_uuid=center_node_uuid,
|
||||||
routing_='r',
|
routing_='r',
|
||||||
)
|
)
|
||||||
if driver.provider == 'falkordb':
|
if driver.provider == GraphProvider.FALKORDB:
|
||||||
results = [dict(zip(header, row, strict=True)) for row in results]
|
results = [dict(zip(header, row, strict=True)) for row in results]
|
||||||
|
|
||||||
for result in results:
|
for result in results:
|
||||||
|
|
@ -987,13 +944,12 @@ async def episode_mentions_reranker(
|
||||||
scores: dict[str, float] = {}
|
scores: dict[str, float] = {}
|
||||||
|
|
||||||
# Find the shortest path to center node
|
# Find the shortest path to center node
|
||||||
query = """
|
results, _, _ = await driver.execute_query(
|
||||||
UNWIND $node_uuids AS node_uuid
|
"""
|
||||||
|
UNWIND $node_uuids AS node_uuid
|
||||||
MATCH (episode:Episodic)-[r:MENTIONS]->(n:Entity {uuid: node_uuid})
|
MATCH (episode:Episodic)-[r:MENTIONS]->(n:Entity {uuid: node_uuid})
|
||||||
RETURN count(*) AS score, n.uuid AS uuid
|
RETURN count(*) AS score, n.uuid AS uuid
|
||||||
"""
|
""",
|
||||||
results, _, _ = await driver.execute_query(
|
|
||||||
query,
|
|
||||||
node_uuids=sorted_uuids,
|
node_uuids=sorted_uuids,
|
||||||
routing_='r',
|
routing_='r',
|
||||||
)
|
)
|
||||||
|
|
@ -1053,15 +1009,16 @@ def maximal_marginal_relevance(
|
||||||
async def get_embeddings_for_nodes(
|
async def get_embeddings_for_nodes(
|
||||||
driver: GraphDriver, nodes: list[EntityNode]
|
driver: GraphDriver, nodes: list[EntityNode]
|
||||||
) -> dict[str, list[float]]:
|
) -> dict[str, list[float]]:
|
||||||
query: LiteralString = """MATCH (n:Entity)
|
|
||||||
WHERE n.uuid IN $node_uuids
|
|
||||||
RETURN DISTINCT
|
|
||||||
n.uuid AS uuid,
|
|
||||||
n.name_embedding AS name_embedding
|
|
||||||
"""
|
|
||||||
|
|
||||||
results, _, _ = await driver.execute_query(
|
results, _, _ = await driver.execute_query(
|
||||||
query, node_uuids=[node.uuid for node in nodes], routing_='r'
|
"""
|
||||||
|
MATCH (n:Entity)
|
||||||
|
WHERE n.uuid IN $node_uuids
|
||||||
|
RETURN DISTINCT
|
||||||
|
n.uuid AS uuid,
|
||||||
|
n.name_embedding AS name_embedding
|
||||||
|
""",
|
||||||
|
node_uuids=[node.uuid for node in nodes],
|
||||||
|
routing_='r',
|
||||||
)
|
)
|
||||||
|
|
||||||
embeddings_dict: dict[str, list[float]] = {}
|
embeddings_dict: dict[str, list[float]] = {}
|
||||||
|
|
@ -1077,15 +1034,14 @@ async def get_embeddings_for_nodes(
|
||||||
async def get_embeddings_for_communities(
|
async def get_embeddings_for_communities(
|
||||||
driver: GraphDriver, communities: list[CommunityNode]
|
driver: GraphDriver, communities: list[CommunityNode]
|
||||||
) -> dict[str, list[float]]:
|
) -> dict[str, list[float]]:
|
||||||
query: LiteralString = """MATCH (c:Community)
|
|
||||||
WHERE c.uuid IN $community_uuids
|
|
||||||
RETURN DISTINCT
|
|
||||||
c.uuid AS uuid,
|
|
||||||
c.name_embedding AS name_embedding
|
|
||||||
"""
|
|
||||||
|
|
||||||
results, _, _ = await driver.execute_query(
|
results, _, _ = await driver.execute_query(
|
||||||
query,
|
"""
|
||||||
|
MATCH (c:Community)
|
||||||
|
WHERE c.uuid IN $community_uuids
|
||||||
|
RETURN DISTINCT
|
||||||
|
c.uuid AS uuid,
|
||||||
|
c.name_embedding AS name_embedding
|
||||||
|
""",
|
||||||
community_uuids=[community.uuid for community in communities],
|
community_uuids=[community.uuid for community in communities],
|
||||||
routing_='r',
|
routing_='r',
|
||||||
)
|
)
|
||||||
|
|
@ -1103,15 +1059,14 @@ async def get_embeddings_for_communities(
|
||||||
async def get_embeddings_for_edges(
|
async def get_embeddings_for_edges(
|
||||||
driver: GraphDriver, edges: list[EntityEdge]
|
driver: GraphDriver, edges: list[EntityEdge]
|
||||||
) -> dict[str, list[float]]:
|
) -> dict[str, list[float]]:
|
||||||
query: LiteralString = """MATCH (n:Entity)-[e:RELATES_TO]-(m:Entity)
|
|
||||||
WHERE e.uuid IN $edge_uuids
|
|
||||||
RETURN DISTINCT
|
|
||||||
e.uuid AS uuid,
|
|
||||||
e.fact_embedding AS fact_embedding
|
|
||||||
"""
|
|
||||||
|
|
||||||
results, _, _ = await driver.execute_query(
|
results, _, _ = await driver.execute_query(
|
||||||
query,
|
"""
|
||||||
|
MATCH (n:Entity)-[e:RELATES_TO]-(m:Entity)
|
||||||
|
WHERE e.uuid IN $edge_uuids
|
||||||
|
RETURN DISTINCT
|
||||||
|
e.uuid AS uuid,
|
||||||
|
e.fact_embedding AS fact_embedding
|
||||||
|
""",
|
||||||
edge_uuids=[edge.uuid for edge in edges],
|
edge_uuids=[edge.uuid for edge in edges],
|
||||||
routing_='r',
|
routing_='r',
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -25,17 +25,15 @@ from typing_extensions import Any
|
||||||
from graphiti_core.driver.driver import GraphDriver, GraphDriverSession
|
from graphiti_core.driver.driver import GraphDriver, GraphDriverSession
|
||||||
from graphiti_core.edges import Edge, EntityEdge, EpisodicEdge, create_entity_edge_embeddings
|
from graphiti_core.edges import Edge, EntityEdge, EpisodicEdge, create_entity_edge_embeddings
|
||||||
from graphiti_core.embedder import EmbedderClient
|
from graphiti_core.embedder import EmbedderClient
|
||||||
from graphiti_core.graph_queries import (
|
|
||||||
get_entity_edge_save_bulk_query,
|
|
||||||
get_entity_node_save_bulk_query,
|
|
||||||
)
|
|
||||||
from graphiti_core.graphiti_types import GraphitiClients
|
from graphiti_core.graphiti_types import GraphitiClients
|
||||||
from graphiti_core.helpers import normalize_l2, semaphore_gather
|
from graphiti_core.helpers import normalize_l2, semaphore_gather
|
||||||
from graphiti_core.models.edges.edge_db_queries import (
|
from graphiti_core.models.edges.edge_db_queries import (
|
||||||
EPISODIC_EDGE_SAVE_BULK,
|
EPISODIC_EDGE_SAVE_BULK,
|
||||||
|
get_entity_edge_save_bulk_query,
|
||||||
)
|
)
|
||||||
from graphiti_core.models.nodes.node_db_queries import (
|
from graphiti_core.models.nodes.node_db_queries import (
|
||||||
EPISODIC_NODE_SAVE_BULK,
|
EPISODIC_NODE_SAVE_BULK,
|
||||||
|
get_entity_node_save_bulk_query,
|
||||||
)
|
)
|
||||||
from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode, create_entity_node_embeddings
|
from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode, create_entity_node_embeddings
|
||||||
from graphiti_core.utils.maintenance.edge_operations import (
|
from graphiti_core.utils.maintenance.edge_operations import (
|
||||||
|
|
@ -158,7 +156,7 @@ async def add_nodes_and_edges_bulk_tx(
|
||||||
edges.append(edge_data)
|
edges.append(edge_data)
|
||||||
|
|
||||||
await tx.run(EPISODIC_NODE_SAVE_BULK, episodes=episodes)
|
await tx.run(EPISODIC_NODE_SAVE_BULK, episodes=episodes)
|
||||||
entity_node_save_bulk = get_entity_node_save_bulk_query(nodes, driver.provider)
|
entity_node_save_bulk = get_entity_node_save_bulk_query(driver.provider, nodes)
|
||||||
await tx.run(entity_node_save_bulk, nodes=nodes)
|
await tx.run(entity_node_save_bulk, nodes=nodes)
|
||||||
await tx.run(
|
await tx.run(
|
||||||
EPISODIC_EDGE_SAVE_BULK, episodic_edges=[edge.model_dump() for edge in episodic_edges]
|
EPISODIC_EDGE_SAVE_BULK, episodic_edges=[edge.model_dump() for edge in episodic_edges]
|
||||||
|
|
|
||||||
|
|
@ -34,7 +34,7 @@ async def get_community_clusters(
|
||||||
group_id_values, _, _ = await driver.execute_query(
|
group_id_values, _, _ = await driver.execute_query(
|
||||||
"""
|
"""
|
||||||
MATCH (n:Entity WHERE n.group_id IS NOT NULL)
|
MATCH (n:Entity WHERE n.group_id IS NOT NULL)
|
||||||
RETURN
|
RETURN
|
||||||
collect(DISTINCT n.group_id) AS group_ids
|
collect(DISTINCT n.group_id) AS group_ids
|
||||||
""",
|
""",
|
||||||
)
|
)
|
||||||
|
|
@ -233,10 +233,10 @@ async def determine_entity_community(
|
||||||
"""
|
"""
|
||||||
MATCH (c:Community)-[:HAS_MEMBER]->(n:Entity {uuid: $entity_uuid})
|
MATCH (c:Community)-[:HAS_MEMBER]->(n:Entity {uuid: $entity_uuid})
|
||||||
RETURN
|
RETURN
|
||||||
c.uuid As uuid,
|
c.uuid AS uuid,
|
||||||
c.name AS name,
|
c.name AS name,
|
||||||
c.group_id AS group_id,
|
c.group_id AS group_id,
|
||||||
c.created_at AS created_at,
|
c.created_at AS created_at,
|
||||||
c.summary AS summary
|
c.summary AS summary
|
||||||
""",
|
""",
|
||||||
entity_uuid=entity.uuid,
|
entity_uuid=entity.uuid,
|
||||||
|
|
@ -250,10 +250,10 @@ async def determine_entity_community(
|
||||||
"""
|
"""
|
||||||
MATCH (c:Community)-[:HAS_MEMBER]->(m:Entity)-[:RELATES_TO]-(n:Entity {uuid: $entity_uuid})
|
MATCH (c:Community)-[:HAS_MEMBER]->(m:Entity)-[:RELATES_TO]-(n:Entity {uuid: $entity_uuid})
|
||||||
RETURN
|
RETURN
|
||||||
c.uuid As uuid,
|
c.uuid AS uuid,
|
||||||
c.name AS name,
|
c.name AS name,
|
||||||
c.group_id AS group_id,
|
c.group_id AS group_id,
|
||||||
c.created_at AS created_at,
|
c.created_at AS created_at,
|
||||||
c.summary AS summary
|
c.summary AS summary
|
||||||
""",
|
""",
|
||||||
entity_uuid=entity.uuid,
|
entity_uuid=entity.uuid,
|
||||||
|
|
@ -286,11 +286,11 @@ async def determine_entity_community(
|
||||||
|
|
||||||
async def update_community(
|
async def update_community(
|
||||||
driver: GraphDriver, llm_client: LLMClient, embedder: EmbedderClient, entity: EntityNode
|
driver: GraphDriver, llm_client: LLMClient, embedder: EmbedderClient, entity: EntityNode
|
||||||
):
|
) -> tuple[list[CommunityNode], list[CommunityEdge]]:
|
||||||
community, is_new = await determine_entity_community(driver, entity)
|
community, is_new = await determine_entity_community(driver, entity)
|
||||||
|
|
||||||
if community is None:
|
if community is None:
|
||||||
return
|
return [], []
|
||||||
|
|
||||||
new_summary = await summarize_pair(llm_client, (entity.summary, community.summary))
|
new_summary = await summarize_pair(llm_client, (entity.summary, community.summary))
|
||||||
new_name = await generate_summary_description(llm_client, new_summary)
|
new_name = await generate_summary_description(llm_client, new_summary)
|
||||||
|
|
@ -298,10 +298,14 @@ async def update_community(
|
||||||
community.summary = new_summary
|
community.summary = new_summary
|
||||||
community.name = new_name
|
community.name = new_name
|
||||||
|
|
||||||
|
community_edges = []
|
||||||
if is_new:
|
if is_new:
|
||||||
community_edge = (build_community_edges([entity], community, utc_now()))[0]
|
community_edge = (build_community_edges([entity], community, utc_now()))[0]
|
||||||
await community_edge.save(driver)
|
await community_edge.save(driver)
|
||||||
|
community_edges.append(community_edge)
|
||||||
|
|
||||||
await community.generate_name_embedding(embedder)
|
await community.generate_name_embedding(embedder)
|
||||||
|
|
||||||
await community.save(driver)
|
await community.save(driver)
|
||||||
|
|
||||||
|
return [community], community_edges
|
||||||
|
|
|
||||||
|
|
@ -15,14 +15,15 @@ limitations under the License.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime
|
||||||
|
|
||||||
from typing_extensions import LiteralString
|
from typing_extensions import LiteralString
|
||||||
|
|
||||||
from graphiti_core.driver.driver import GraphDriver
|
from graphiti_core.driver.driver import GraphDriver
|
||||||
from graphiti_core.graph_queries import get_fulltext_indices, get_range_indices
|
from graphiti_core.graph_queries import get_fulltext_indices, get_range_indices
|
||||||
from graphiti_core.helpers import parse_db_date, semaphore_gather
|
from graphiti_core.helpers import semaphore_gather
|
||||||
from graphiti_core.nodes import EpisodeType, EpisodicNode
|
from graphiti_core.models.nodes.node_db_queries import EPISODIC_NODE_RETURN
|
||||||
|
from graphiti_core.nodes import EpisodeType, EpisodicNode, get_episodic_node_from_record
|
||||||
|
|
||||||
EPISODE_WINDOW_LEN = 3
|
EPISODE_WINDOW_LEN = 3
|
||||||
|
|
||||||
|
|
@ -33,8 +34,8 @@ async def build_indices_and_constraints(driver: GraphDriver, delete_existing: bo
|
||||||
if delete_existing:
|
if delete_existing:
|
||||||
records, _, _ = await driver.execute_query(
|
records, _, _ = await driver.execute_query(
|
||||||
"""
|
"""
|
||||||
SHOW INDEXES YIELD name
|
SHOW INDEXES YIELD name
|
||||||
""",
|
""",
|
||||||
)
|
)
|
||||||
index_names = [record['name'] for record in records]
|
index_names = [record['name'] for record in records]
|
||||||
await semaphore_gather(
|
await semaphore_gather(
|
||||||
|
|
@ -108,19 +109,16 @@ async def retrieve_episodes(
|
||||||
|
|
||||||
query: LiteralString = (
|
query: LiteralString = (
|
||||||
"""
|
"""
|
||||||
MATCH (e:Episodic) WHERE e.valid_at <= $reference_time
|
MATCH (e:Episodic)
|
||||||
"""
|
WHERE e.valid_at <= $reference_time
|
||||||
|
"""
|
||||||
+ group_id_filter
|
+ group_id_filter
|
||||||
+ source_filter
|
+ source_filter
|
||||||
+ """
|
+ """
|
||||||
RETURN e.content AS content,
|
RETURN
|
||||||
e.created_at AS created_at,
|
"""
|
||||||
e.valid_at AS valid_at,
|
+ EPISODIC_NODE_RETURN
|
||||||
e.uuid AS uuid,
|
+ """
|
||||||
e.group_id AS group_id,
|
|
||||||
e.name AS name,
|
|
||||||
e.source_description AS source_description,
|
|
||||||
e.source AS source
|
|
||||||
ORDER BY e.valid_at DESC
|
ORDER BY e.valid_at DESC
|
||||||
LIMIT $num_episodes
|
LIMIT $num_episodes
|
||||||
"""
|
"""
|
||||||
|
|
@ -133,18 +131,5 @@ async def retrieve_episodes(
|
||||||
group_ids=group_ids,
|
group_ids=group_ids,
|
||||||
)
|
)
|
||||||
|
|
||||||
episodes = [
|
episodes = [get_episodic_node_from_record(record) for record in result]
|
||||||
EpisodicNode(
|
|
||||||
content=record['content'],
|
|
||||||
created_at=parse_db_date(record['created_at'])
|
|
||||||
or datetime.min.replace(tzinfo=timezone.utc),
|
|
||||||
valid_at=parse_db_date(record['valid_at']) or datetime.min.replace(tzinfo=timezone.utc),
|
|
||||||
uuid=record['uuid'],
|
|
||||||
group_id=record['group_id'],
|
|
||||||
source=EpisodeType.from_str(record['source']),
|
|
||||||
name=record['name'],
|
|
||||||
source_description=record['source_description'],
|
|
||||||
)
|
|
||||||
for record in result
|
|
||||||
]
|
|
||||||
return list(reversed(episodes)) # Return in chronological order
|
return list(reversed(episodes)) # Return in chronological order
|
||||||
|
|
|
||||||
|
|
@ -3,9 +3,9 @@ name = "graphiti-core"
|
||||||
description = "A temporal graph building library"
|
description = "A temporal graph building library"
|
||||||
version = "0.18.0"
|
version = "0.18.0"
|
||||||
authors = [
|
authors = [
|
||||||
{ "name" = "Paul Paliychuk", "email" = "paul@getzep.com" },
|
{ name = "Paul Paliychuk", email = "paul@getzep.com" },
|
||||||
{ "name" = "Preston Rasmussen", "email" = "preston@getzep.com" },
|
{ name = "Preston Rasmussen", email = "preston@getzep.com" },
|
||||||
{ "name" = "Daniel Chalef", "email" = "daniel@getzep.com" },
|
{ name = "Daniel Chalef", email = "daniel@getzep.com" },
|
||||||
]
|
]
|
||||||
readme = "README.md"
|
readme = "README.md"
|
||||||
license = "Apache-2.0"
|
license = "Apache-2.0"
|
||||||
|
|
|
||||||
|
|
@ -21,6 +21,8 @@ from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
from graphiti_core.driver.driver import GraphProvider
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from graphiti_core.driver.falkordb_driver import FalkorDriver, FalkorDriverSession
|
from graphiti_core.driver.falkordb_driver import FalkorDriver, FalkorDriverSession
|
||||||
|
|
||||||
|
|
@ -48,7 +50,7 @@ class TestFalkorDriver:
|
||||||
driver = FalkorDriver(
|
driver = FalkorDriver(
|
||||||
host='test-host', port='1234', username='test-user', password='test-pass'
|
host='test-host', port='1234', username='test-user', password='test-pass'
|
||||||
)
|
)
|
||||||
assert driver.provider == 'falkordb'
|
assert driver.provider == GraphProvider.FALKORDB
|
||||||
mock_falkor_db.assert_called_once_with(
|
mock_falkor_db.assert_called_once_with(
|
||||||
host='test-host', port='1234', username='test-user', password='test-pass'
|
host='test-host', port='1234', username='test-user', password='test-pass'
|
||||||
)
|
)
|
||||||
|
|
@ -59,14 +61,14 @@ class TestFalkorDriver:
|
||||||
with patch('graphiti_core.driver.falkordb_driver.FalkorDB') as mock_falkor_db_class:
|
with patch('graphiti_core.driver.falkordb_driver.FalkorDB') as mock_falkor_db_class:
|
||||||
mock_falkor_db = MagicMock()
|
mock_falkor_db = MagicMock()
|
||||||
driver = FalkorDriver(falkor_db=mock_falkor_db)
|
driver = FalkorDriver(falkor_db=mock_falkor_db)
|
||||||
assert driver.provider == 'falkordb'
|
assert driver.provider == GraphProvider.FALKORDB
|
||||||
assert driver.client is mock_falkor_db
|
assert driver.client is mock_falkor_db
|
||||||
mock_falkor_db_class.assert_not_called()
|
mock_falkor_db_class.assert_not_called()
|
||||||
|
|
||||||
@unittest.skipIf(not HAS_FALKORDB, 'FalkorDB is not installed')
|
@unittest.skipIf(not HAS_FALKORDB, 'FalkorDB is not installed')
|
||||||
def test_provider(self):
|
def test_provider(self):
|
||||||
"""Test driver provider identification."""
|
"""Test driver provider identification."""
|
||||||
assert self.driver.provider == 'falkordb'
|
assert self.driver.provider == GraphProvider.FALKORDB
|
||||||
|
|
||||||
@unittest.skipIf(not HAS_FALKORDB, 'FalkorDB is not installed')
|
@unittest.skipIf(not HAS_FALKORDB, 'FalkorDB is not installed')
|
||||||
def test_get_graph_with_name(self):
|
def test_get_graph_with_name(self):
|
||||||
|
|
|
||||||
|
|
@ -14,10 +14,68 @@ See the License for the specific language governing permissions and
|
||||||
limitations under the License.
|
limitations under the License.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import pytest
|
import os
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
|
||||||
|
from graphiti_core.driver.driver import GraphDriver
|
||||||
from graphiti_core.helpers import lucene_sanitize
|
from graphiti_core.helpers import lucene_sanitize
|
||||||
|
|
||||||
|
load_dotenv()
|
||||||
|
|
||||||
|
HAS_NEO4J = False
|
||||||
|
HAS_FALKORDB = False
|
||||||
|
if os.getenv('DISABLE_NEO4J') is None:
|
||||||
|
try:
|
||||||
|
from graphiti_core.driver.neo4j_driver import Neo4jDriver
|
||||||
|
|
||||||
|
HAS_NEO4J = True
|
||||||
|
except ImportError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
if os.getenv('DISABLE_FALKORDB') is None:
|
||||||
|
try:
|
||||||
|
from graphiti_core.driver.falkordb_driver import FalkorDriver
|
||||||
|
|
||||||
|
HAS_FALKORDB = True
|
||||||
|
except ImportError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
NEO4J_URI = os.getenv('NEO4J_URI', 'bolt://localhost:7687')
|
||||||
|
NEO4J_USER = os.getenv('NEO4J_USER', 'neo4j')
|
||||||
|
NEO4J_PASSWORD = os.getenv('NEO4J_PASSWORD', 'test')
|
||||||
|
|
||||||
|
FALKORDB_HOST = os.getenv('FALKORDB_HOST', 'localhost')
|
||||||
|
FALKORDB_PORT = os.getenv('FALKORDB_PORT', '6379')
|
||||||
|
FALKORDB_USER = os.getenv('FALKORDB_USER', None)
|
||||||
|
FALKORDB_PASSWORD = os.getenv('FALKORDB_PASSWORD', None)
|
||||||
|
|
||||||
|
|
||||||
|
def get_driver(driver_name: str) -> GraphDriver:
|
||||||
|
if driver_name == 'neo4j':
|
||||||
|
return Neo4jDriver(
|
||||||
|
uri=NEO4J_URI,
|
||||||
|
user=NEO4J_USER,
|
||||||
|
password=NEO4J_PASSWORD,
|
||||||
|
)
|
||||||
|
elif driver_name == 'falkordb':
|
||||||
|
return FalkorDriver(
|
||||||
|
host=FALKORDB_HOST,
|
||||||
|
port=int(FALKORDB_PORT),
|
||||||
|
username=FALKORDB_USER,
|
||||||
|
password=FALKORDB_PASSWORD,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise ValueError(f'Driver {driver_name} not available')
|
||||||
|
|
||||||
|
|
||||||
|
drivers: list[str] = []
|
||||||
|
if HAS_NEO4J:
|
||||||
|
drivers.append('neo4j')
|
||||||
|
if HAS_FALKORDB:
|
||||||
|
drivers.append('falkordb')
|
||||||
|
|
||||||
|
|
||||||
def test_lucene_sanitize():
|
def test_lucene_sanitize():
|
||||||
# Call the function with test data
|
# Call the function with test data
|
||||||
|
|
|
||||||
384
tests/test_edge_int.py
Normal file
384
tests/test_edge_int.py
Normal file
|
|
@ -0,0 +1,384 @@
|
||||||
|
"""
|
||||||
|
Copyright 2024, Zep Software, Inc.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import sys
|
||||||
|
from datetime import datetime
|
||||||
|
from uuid import uuid4
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from graphiti_core.driver.driver import GraphDriver
|
||||||
|
from graphiti_core.edges import CommunityEdge, EntityEdge, EpisodicEdge
|
||||||
|
from graphiti_core.embedder.openai import OpenAIEmbedder
|
||||||
|
from graphiti_core.nodes import CommunityNode, EntityNode, EpisodeType, EpisodicNode
|
||||||
|
from tests.helpers_test import drivers, get_driver
|
||||||
|
|
||||||
|
pytestmark = pytest.mark.integration
|
||||||
|
|
||||||
|
pytest_plugins = ('pytest_asyncio',)
|
||||||
|
|
||||||
|
group_id = f'test_group_{str(uuid4())}'
|
||||||
|
|
||||||
|
|
||||||
|
def setup_logging():
|
||||||
|
# Create a logger
|
||||||
|
logger = logging.getLogger()
|
||||||
|
logger.setLevel(logging.INFO) # Set the logging level to INFO
|
||||||
|
|
||||||
|
# Create console handler and set level to INFO
|
||||||
|
console_handler = logging.StreamHandler(sys.stdout)
|
||||||
|
console_handler.setLevel(logging.INFO)
|
||||||
|
|
||||||
|
# Create formatter
|
||||||
|
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
|
||||||
|
|
||||||
|
# Add formatter to console handler
|
||||||
|
console_handler.setFormatter(formatter)
|
||||||
|
|
||||||
|
# Add console handler to logger
|
||||||
|
logger.addHandler(console_handler)
|
||||||
|
|
||||||
|
return logger
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
'driver',
|
||||||
|
drivers,
|
||||||
|
ids=drivers,
|
||||||
|
)
|
||||||
|
async def test_episodic_edge(driver):
|
||||||
|
graph_driver = get_driver(driver)
|
||||||
|
embedder = OpenAIEmbedder()
|
||||||
|
|
||||||
|
now = datetime.now()
|
||||||
|
|
||||||
|
episode_node = EpisodicNode(
|
||||||
|
name='test_episode',
|
||||||
|
labels=[],
|
||||||
|
created_at=now,
|
||||||
|
valid_at=now,
|
||||||
|
source=EpisodeType.message,
|
||||||
|
source_description='conversation message',
|
||||||
|
content='Alice likes Bob',
|
||||||
|
entity_edges=[],
|
||||||
|
group_id=group_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
node_count = await get_node_count(graph_driver, episode_node.uuid)
|
||||||
|
assert node_count == 0
|
||||||
|
await episode_node.save(graph_driver)
|
||||||
|
node_count = await get_node_count(graph_driver, episode_node.uuid)
|
||||||
|
assert node_count == 1
|
||||||
|
|
||||||
|
alice_node = EntityNode(
|
||||||
|
name='Alice',
|
||||||
|
labels=[],
|
||||||
|
created_at=now,
|
||||||
|
summary='Alice summary',
|
||||||
|
group_id=group_id,
|
||||||
|
)
|
||||||
|
await alice_node.generate_name_embedding(embedder)
|
||||||
|
|
||||||
|
node_count = await get_node_count(graph_driver, alice_node.uuid)
|
||||||
|
assert node_count == 0
|
||||||
|
await alice_node.save(graph_driver)
|
||||||
|
node_count = await get_node_count(graph_driver, alice_node.uuid)
|
||||||
|
assert node_count == 1
|
||||||
|
|
||||||
|
episodic_edge = EpisodicEdge(
|
||||||
|
source_node_uuid=episode_node.uuid,
|
||||||
|
target_node_uuid=alice_node.uuid,
|
||||||
|
created_at=now,
|
||||||
|
group_id=group_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
edge_count = await get_edge_count(graph_driver, episodic_edge.uuid)
|
||||||
|
assert edge_count == 0
|
||||||
|
await episodic_edge.save(graph_driver)
|
||||||
|
edge_count = await get_edge_count(graph_driver, episodic_edge.uuid)
|
||||||
|
assert edge_count == 1
|
||||||
|
|
||||||
|
retrieved = await EpisodicEdge.get_by_uuid(graph_driver, episodic_edge.uuid)
|
||||||
|
assert retrieved.uuid == episodic_edge.uuid
|
||||||
|
assert retrieved.source_node_uuid == episode_node.uuid
|
||||||
|
assert retrieved.target_node_uuid == alice_node.uuid
|
||||||
|
assert retrieved.created_at == now
|
||||||
|
assert retrieved.group_id == group_id
|
||||||
|
|
||||||
|
retrieved = await EpisodicEdge.get_by_uuids(graph_driver, [episodic_edge.uuid])
|
||||||
|
assert len(retrieved) == 1
|
||||||
|
assert retrieved[0].uuid == episodic_edge.uuid
|
||||||
|
assert retrieved[0].source_node_uuid == episode_node.uuid
|
||||||
|
assert retrieved[0].target_node_uuid == alice_node.uuid
|
||||||
|
assert retrieved[0].created_at == now
|
||||||
|
assert retrieved[0].group_id == group_id
|
||||||
|
|
||||||
|
retrieved = await EpisodicEdge.get_by_group_ids(graph_driver, [group_id], limit=2)
|
||||||
|
assert len(retrieved) == 1
|
||||||
|
assert retrieved[0].uuid == episodic_edge.uuid
|
||||||
|
assert retrieved[0].source_node_uuid == episode_node.uuid
|
||||||
|
assert retrieved[0].target_node_uuid == alice_node.uuid
|
||||||
|
assert retrieved[0].created_at == now
|
||||||
|
assert retrieved[0].group_id == group_id
|
||||||
|
|
||||||
|
await episodic_edge.delete(graph_driver)
|
||||||
|
edge_count = await get_edge_count(graph_driver, episodic_edge.uuid)
|
||||||
|
assert edge_count == 0
|
||||||
|
|
||||||
|
await episode_node.delete(graph_driver)
|
||||||
|
node_count = await get_node_count(graph_driver, episode_node.uuid)
|
||||||
|
assert node_count == 0
|
||||||
|
|
||||||
|
await alice_node.delete(graph_driver)
|
||||||
|
node_count = await get_node_count(graph_driver, alice_node.uuid)
|
||||||
|
assert node_count == 0
|
||||||
|
|
||||||
|
await graph_driver.close()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
'driver',
|
||||||
|
drivers,
|
||||||
|
ids=drivers,
|
||||||
|
)
|
||||||
|
async def test_entity_edge(driver):
|
||||||
|
graph_driver = get_driver(driver)
|
||||||
|
embedder = OpenAIEmbedder()
|
||||||
|
|
||||||
|
now = datetime.now()
|
||||||
|
|
||||||
|
alice_node = EntityNode(
|
||||||
|
name='Alice',
|
||||||
|
labels=[],
|
||||||
|
created_at=now,
|
||||||
|
summary='Alice summary',
|
||||||
|
group_id=group_id,
|
||||||
|
)
|
||||||
|
await alice_node.generate_name_embedding(embedder)
|
||||||
|
|
||||||
|
node_count = await get_node_count(graph_driver, alice_node.uuid)
|
||||||
|
assert node_count == 0
|
||||||
|
await alice_node.save(graph_driver)
|
||||||
|
node_count = await get_node_count(graph_driver, alice_node.uuid)
|
||||||
|
assert node_count == 1
|
||||||
|
|
||||||
|
bob_node = EntityNode(
|
||||||
|
name='Bob', labels=[], created_at=now, summary='Bob summary', group_id=group_id
|
||||||
|
)
|
||||||
|
await bob_node.generate_name_embedding(embedder)
|
||||||
|
|
||||||
|
node_count = await get_node_count(graph_driver, bob_node.uuid)
|
||||||
|
assert node_count == 0
|
||||||
|
await bob_node.save(graph_driver)
|
||||||
|
node_count = await get_node_count(graph_driver, bob_node.uuid)
|
||||||
|
assert node_count == 1
|
||||||
|
|
||||||
|
entity_edge = EntityEdge(
|
||||||
|
source_node_uuid=alice_node.uuid,
|
||||||
|
target_node_uuid=bob_node.uuid,
|
||||||
|
created_at=now,
|
||||||
|
name='likes',
|
||||||
|
fact='Alice likes Bob',
|
||||||
|
episodes=[],
|
||||||
|
expired_at=now,
|
||||||
|
valid_at=now,
|
||||||
|
invalid_at=now,
|
||||||
|
group_id=group_id,
|
||||||
|
)
|
||||||
|
edge_embedding = await entity_edge.generate_embedding(embedder)
|
||||||
|
|
||||||
|
edge_count = await get_edge_count(graph_driver, entity_edge.uuid)
|
||||||
|
assert edge_count == 0
|
||||||
|
await entity_edge.save(graph_driver)
|
||||||
|
edge_count = await get_edge_count(graph_driver, entity_edge.uuid)
|
||||||
|
assert edge_count == 1
|
||||||
|
|
||||||
|
retrieved = await EntityEdge.get_by_uuid(graph_driver, entity_edge.uuid)
|
||||||
|
assert retrieved.uuid == entity_edge.uuid
|
||||||
|
assert retrieved.source_node_uuid == alice_node.uuid
|
||||||
|
assert retrieved.target_node_uuid == bob_node.uuid
|
||||||
|
assert retrieved.created_at == now
|
||||||
|
assert retrieved.group_id == group_id
|
||||||
|
|
||||||
|
retrieved = await EntityEdge.get_by_uuids(graph_driver, [entity_edge.uuid])
|
||||||
|
assert len(retrieved) == 1
|
||||||
|
assert retrieved[0].uuid == entity_edge.uuid
|
||||||
|
assert retrieved[0].source_node_uuid == alice_node.uuid
|
||||||
|
assert retrieved[0].target_node_uuid == bob_node.uuid
|
||||||
|
assert retrieved[0].created_at == now
|
||||||
|
assert retrieved[0].group_id == group_id
|
||||||
|
|
||||||
|
retrieved = await EntityEdge.get_by_group_ids(graph_driver, [group_id], limit=2)
|
||||||
|
assert len(retrieved) == 1
|
||||||
|
assert retrieved[0].uuid == entity_edge.uuid
|
||||||
|
assert retrieved[0].source_node_uuid == alice_node.uuid
|
||||||
|
assert retrieved[0].target_node_uuid == bob_node.uuid
|
||||||
|
assert retrieved[0].created_at == now
|
||||||
|
assert retrieved[0].group_id == group_id
|
||||||
|
|
||||||
|
retrieved = await EntityEdge.get_by_node_uuid(graph_driver, alice_node.uuid)
|
||||||
|
assert len(retrieved) == 1
|
||||||
|
assert retrieved[0].uuid == entity_edge.uuid
|
||||||
|
assert retrieved[0].source_node_uuid == alice_node.uuid
|
||||||
|
assert retrieved[0].target_node_uuid == bob_node.uuid
|
||||||
|
assert retrieved[0].created_at == now
|
||||||
|
assert retrieved[0].group_id == group_id
|
||||||
|
|
||||||
|
await entity_edge.load_fact_embedding(graph_driver)
|
||||||
|
assert np.allclose(entity_edge.fact_embedding, edge_embedding)
|
||||||
|
|
||||||
|
await entity_edge.delete(graph_driver)
|
||||||
|
edge_count = await get_edge_count(graph_driver, entity_edge.uuid)
|
||||||
|
assert edge_count == 0
|
||||||
|
|
||||||
|
await alice_node.delete(graph_driver)
|
||||||
|
node_count = await get_node_count(graph_driver, alice_node.uuid)
|
||||||
|
assert node_count == 0
|
||||||
|
|
||||||
|
await bob_node.delete(graph_driver)
|
||||||
|
node_count = await get_node_count(graph_driver, bob_node.uuid)
|
||||||
|
assert node_count == 0
|
||||||
|
|
||||||
|
await graph_driver.close()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
'driver',
|
||||||
|
drivers,
|
||||||
|
ids=drivers,
|
||||||
|
)
|
||||||
|
async def test_community_edge(driver):
|
||||||
|
graph_driver = get_driver(driver)
|
||||||
|
embedder = OpenAIEmbedder()
|
||||||
|
|
||||||
|
now = datetime.now()
|
||||||
|
|
||||||
|
community_node_1 = CommunityNode(
|
||||||
|
name='Community A',
|
||||||
|
group_id=group_id,
|
||||||
|
summary='Community A summary',
|
||||||
|
)
|
||||||
|
await community_node_1.generate_name_embedding(embedder)
|
||||||
|
node_count = await get_node_count(graph_driver, community_node_1.uuid)
|
||||||
|
assert node_count == 0
|
||||||
|
await community_node_1.save(graph_driver)
|
||||||
|
node_count = await get_node_count(graph_driver, community_node_1.uuid)
|
||||||
|
assert node_count == 1
|
||||||
|
|
||||||
|
community_node_2 = CommunityNode(
|
||||||
|
name='Community B',
|
||||||
|
group_id=group_id,
|
||||||
|
summary='Community B summary',
|
||||||
|
)
|
||||||
|
await community_node_2.generate_name_embedding(embedder)
|
||||||
|
node_count = await get_node_count(graph_driver, community_node_2.uuid)
|
||||||
|
assert node_count == 0
|
||||||
|
await community_node_2.save(graph_driver)
|
||||||
|
node_count = await get_node_count(graph_driver, community_node_2.uuid)
|
||||||
|
assert node_count == 1
|
||||||
|
|
||||||
|
alice_node = EntityNode(
|
||||||
|
name='Alice', labels=[], created_at=now, summary='Alice summary', group_id=group_id
|
||||||
|
)
|
||||||
|
await alice_node.generate_name_embedding(embedder)
|
||||||
|
node_count = await get_node_count(graph_driver, alice_node.uuid)
|
||||||
|
assert node_count == 0
|
||||||
|
await alice_node.save(graph_driver)
|
||||||
|
node_count = await get_node_count(graph_driver, alice_node.uuid)
|
||||||
|
assert node_count == 1
|
||||||
|
|
||||||
|
community_edge = CommunityEdge(
|
||||||
|
source_node_uuid=community_node_1.uuid,
|
||||||
|
target_node_uuid=community_node_2.uuid,
|
||||||
|
created_at=now,
|
||||||
|
group_id=group_id,
|
||||||
|
)
|
||||||
|
edge_count = await get_edge_count(graph_driver, community_edge.uuid)
|
||||||
|
assert edge_count == 0
|
||||||
|
await community_edge.save(graph_driver)
|
||||||
|
edge_count = await get_edge_count(graph_driver, community_edge.uuid)
|
||||||
|
assert edge_count == 1
|
||||||
|
|
||||||
|
retrieved = await CommunityEdge.get_by_uuid(graph_driver, community_edge.uuid)
|
||||||
|
assert retrieved.uuid == community_edge.uuid
|
||||||
|
assert retrieved.source_node_uuid == community_node_1.uuid
|
||||||
|
assert retrieved.target_node_uuid == community_node_2.uuid
|
||||||
|
assert retrieved.created_at == now
|
||||||
|
assert retrieved.group_id == group_id
|
||||||
|
|
||||||
|
retrieved = await CommunityEdge.get_by_uuids(graph_driver, [community_edge.uuid])
|
||||||
|
assert len(retrieved) == 1
|
||||||
|
assert retrieved[0].uuid == community_edge.uuid
|
||||||
|
assert retrieved[0].source_node_uuid == community_node_1.uuid
|
||||||
|
assert retrieved[0].target_node_uuid == community_node_2.uuid
|
||||||
|
assert retrieved[0].created_at == now
|
||||||
|
assert retrieved[0].group_id == group_id
|
||||||
|
|
||||||
|
retrieved = await CommunityEdge.get_by_group_ids(graph_driver, [group_id], limit=1)
|
||||||
|
assert len(retrieved) == 1
|
||||||
|
assert retrieved[0].uuid == community_edge.uuid
|
||||||
|
assert retrieved[0].source_node_uuid == community_node_1.uuid
|
||||||
|
assert retrieved[0].target_node_uuid == community_node_2.uuid
|
||||||
|
assert retrieved[0].created_at == now
|
||||||
|
assert retrieved[0].group_id == group_id
|
||||||
|
|
||||||
|
await community_edge.delete(graph_driver)
|
||||||
|
edge_count = await get_edge_count(graph_driver, community_edge.uuid)
|
||||||
|
assert edge_count == 0
|
||||||
|
|
||||||
|
await alice_node.delete(graph_driver)
|
||||||
|
node_count = await get_node_count(graph_driver, alice_node.uuid)
|
||||||
|
assert node_count == 0
|
||||||
|
|
||||||
|
await community_node_1.delete(graph_driver)
|
||||||
|
node_count = await get_node_count(graph_driver, community_node_1.uuid)
|
||||||
|
assert node_count == 0
|
||||||
|
|
||||||
|
await community_node_2.delete(graph_driver)
|
||||||
|
node_count = await get_node_count(graph_driver, community_node_2.uuid)
|
||||||
|
assert node_count == 0
|
||||||
|
|
||||||
|
await graph_driver.close()
|
||||||
|
|
||||||
|
|
||||||
|
async def get_node_count(driver: GraphDriver, uuid: str):
|
||||||
|
results, _, _ = await driver.execute_query(
|
||||||
|
"""
|
||||||
|
MATCH (n {uuid: $uuid})
|
||||||
|
RETURN COUNT(n) as count
|
||||||
|
""",
|
||||||
|
uuid=uuid,
|
||||||
|
)
|
||||||
|
return int(results[0]['count'])
|
||||||
|
|
||||||
|
|
||||||
|
async def get_edge_count(driver: GraphDriver, uuid: str):
|
||||||
|
results, _, _ = await driver.execute_query(
|
||||||
|
"""
|
||||||
|
MATCH (n)-[e {uuid: $uuid}]->(m)
|
||||||
|
RETURN COUNT(e) as count
|
||||||
|
UNION ALL
|
||||||
|
MATCH (n)-[e:RELATES_TO]->(m {uuid: $uuid})-[e2:RELATES_TO]->(m2)
|
||||||
|
RETURN COUNT(m) as count
|
||||||
|
""",
|
||||||
|
uuid=uuid,
|
||||||
|
)
|
||||||
|
return sum(int(result['count']) for result in results)
|
||||||
|
|
@ -14,26 +14,18 @@ See the License for the specific language governing permissions and
|
||||||
limitations under the License.
|
limitations under the License.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import os
|
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from dotenv import load_dotenv
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
from graphiti_core.graphiti import Graphiti
|
from graphiti_core.graphiti import Graphiti
|
||||||
from graphiti_core.helpers import validate_excluded_entity_types
|
from graphiti_core.helpers import validate_excluded_entity_types
|
||||||
|
from tests.helpers_test import drivers, get_driver
|
||||||
|
|
||||||
pytestmark = pytest.mark.integration
|
pytestmark = pytest.mark.integration
|
||||||
|
|
||||||
pytest_plugins = ('pytest_asyncio',)
|
pytest_plugins = ('pytest_asyncio',)
|
||||||
|
|
||||||
load_dotenv()
|
|
||||||
|
|
||||||
NEO4J_URI = os.getenv('NEO4J_URI')
|
|
||||||
NEO4J_USER = os.getenv('NEO4J_USER')
|
|
||||||
NEO4J_PASSWORD = os.getenv('NEO4J_PASSWORD')
|
|
||||||
|
|
||||||
|
|
||||||
# Test entity type definitions
|
# Test entity type definitions
|
||||||
class Person(BaseModel):
|
class Person(BaseModel):
|
||||||
|
|
@ -65,9 +57,14 @@ class Location(BaseModel):
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_exclude_default_entity_type():
|
@pytest.mark.parametrize(
|
||||||
|
'driver',
|
||||||
|
drivers,
|
||||||
|
ids=drivers,
|
||||||
|
)
|
||||||
|
async def test_exclude_default_entity_type(driver):
|
||||||
"""Test excluding the default 'Entity' type while keeping custom types."""
|
"""Test excluding the default 'Entity' type while keeping custom types."""
|
||||||
graphiti = Graphiti(NEO4J_URI, NEO4J_USER, NEO4J_PASSWORD)
|
graphiti = Graphiti(graph_driver=get_driver(driver))
|
||||||
|
|
||||||
try:
|
try:
|
||||||
await graphiti.build_indices_and_constraints()
|
await graphiti.build_indices_and_constraints()
|
||||||
|
|
@ -118,9 +115,14 @@ async def test_exclude_default_entity_type():
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_exclude_specific_custom_types():
|
@pytest.mark.parametrize(
|
||||||
|
'driver',
|
||||||
|
drivers,
|
||||||
|
ids=drivers,
|
||||||
|
)
|
||||||
|
async def test_exclude_specific_custom_types(driver):
|
||||||
"""Test excluding specific custom entity types while keeping others."""
|
"""Test excluding specific custom entity types while keeping others."""
|
||||||
graphiti = Graphiti(NEO4J_URI, NEO4J_USER, NEO4J_PASSWORD)
|
graphiti = Graphiti(graph_driver=get_driver(driver))
|
||||||
|
|
||||||
try:
|
try:
|
||||||
await graphiti.build_indices_and_constraints()
|
await graphiti.build_indices_and_constraints()
|
||||||
|
|
@ -177,9 +179,14 @@ async def test_exclude_specific_custom_types():
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_exclude_all_types():
|
@pytest.mark.parametrize(
|
||||||
|
'driver',
|
||||||
|
drivers,
|
||||||
|
ids=drivers,
|
||||||
|
)
|
||||||
|
async def test_exclude_all_types(driver):
|
||||||
"""Test excluding all entity types (edge case)."""
|
"""Test excluding all entity types (edge case)."""
|
||||||
graphiti = Graphiti(NEO4J_URI, NEO4J_USER, NEO4J_PASSWORD)
|
graphiti = Graphiti(graph_driver=get_driver(driver))
|
||||||
|
|
||||||
try:
|
try:
|
||||||
await graphiti.build_indices_and_constraints()
|
await graphiti.build_indices_and_constraints()
|
||||||
|
|
@ -221,9 +228,14 @@ async def test_exclude_all_types():
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_exclude_no_types():
|
@pytest.mark.parametrize(
|
||||||
|
'driver',
|
||||||
|
drivers,
|
||||||
|
ids=drivers,
|
||||||
|
)
|
||||||
|
async def test_exclude_no_types(driver):
|
||||||
"""Test normal behavior when no types are excluded (baseline test)."""
|
"""Test normal behavior when no types are excluded (baseline test)."""
|
||||||
graphiti = Graphiti(NEO4J_URI, NEO4J_USER, NEO4J_PASSWORD)
|
graphiti = Graphiti(graph_driver=get_driver(driver))
|
||||||
|
|
||||||
try:
|
try:
|
||||||
await graphiti.build_indices_and_constraints()
|
await graphiti.build_indices_and_constraints()
|
||||||
|
|
@ -299,9 +311,14 @@ def test_validation_invalid_excluded_types():
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_excluded_types_parameter_validation_in_add_episode():
|
@pytest.mark.parametrize(
|
||||||
|
'driver',
|
||||||
|
drivers,
|
||||||
|
ids=drivers,
|
||||||
|
)
|
||||||
|
async def test_excluded_types_parameter_validation_in_add_episode(driver):
|
||||||
"""Test that add_episode validates excluded_entity_types parameter."""
|
"""Test that add_episode validates excluded_entity_types parameter."""
|
||||||
graphiti = Graphiti(NEO4J_URI, NEO4J_USER, NEO4J_PASSWORD)
|
graphiti = Graphiti(graph_driver=get_driver(driver))
|
||||||
|
|
||||||
try:
|
try:
|
||||||
entity_types = {
|
entity_types = {
|
||||||
|
|
|
||||||
|
|
@ -1,164 +0,0 @@
|
||||||
"""
|
|
||||||
Copyright 2024, Zep Software, Inc.
|
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
you may not use this file except in compliance with the License.
|
|
||||||
You may obtain a copy of the License at
|
|
||||||
|
|
||||||
http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
|
|
||||||
Unless required by applicable law or agreed to in writing, software
|
|
||||||
distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
See the License for the specific language governing permissions and
|
|
||||||
limitations under the License.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import logging
|
|
||||||
import os
|
|
||||||
import sys
|
|
||||||
import unittest
|
|
||||||
from datetime import datetime, timezone
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
from dotenv import load_dotenv
|
|
||||||
|
|
||||||
from graphiti_core.edges import EntityEdge, EpisodicEdge
|
|
||||||
from graphiti_core.graphiti import Graphiti
|
|
||||||
from graphiti_core.helpers import semaphore_gather
|
|
||||||
from graphiti_core.nodes import EntityNode, EpisodicNode
|
|
||||||
from graphiti_core.search.search_helpers import search_results_to_context_string
|
|
||||||
|
|
||||||
try:
|
|
||||||
from graphiti_core.driver.falkordb_driver import FalkorDriver
|
|
||||||
|
|
||||||
HAS_FALKORDB = True
|
|
||||||
except ImportError:
|
|
||||||
FalkorDriver = None
|
|
||||||
HAS_FALKORDB = False
|
|
||||||
|
|
||||||
pytestmark = pytest.mark.integration
|
|
||||||
|
|
||||||
pytest_plugins = ('pytest_asyncio',)
|
|
||||||
|
|
||||||
load_dotenv()
|
|
||||||
|
|
||||||
FALKORDB_HOST = os.getenv('FALKORDB_HOST', 'localhost')
|
|
||||||
FALKORDB_PORT = os.getenv('FALKORDB_PORT', '6379')
|
|
||||||
FALKORDB_USER = os.getenv('FALKORDB_USER', None)
|
|
||||||
FALKORDB_PASSWORD = os.getenv('FALKORDB_PASSWORD', None)
|
|
||||||
|
|
||||||
|
|
||||||
def setup_logging():
|
|
||||||
# Create a logger
|
|
||||||
logger = logging.getLogger()
|
|
||||||
logger.setLevel(logging.INFO) # Set the logging level to INFO
|
|
||||||
|
|
||||||
# Create console handler and set level to INFO
|
|
||||||
console_handler = logging.StreamHandler(sys.stdout)
|
|
||||||
console_handler.setLevel(logging.INFO)
|
|
||||||
|
|
||||||
# Create formatter
|
|
||||||
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
|
|
||||||
|
|
||||||
# Add formatter to console handler
|
|
||||||
console_handler.setFormatter(formatter)
|
|
||||||
|
|
||||||
# Add console handler to logger
|
|
||||||
logger.addHandler(console_handler)
|
|
||||||
|
|
||||||
return logger
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
@unittest.skipIf(not HAS_FALKORDB, 'FalkorDB is not installed')
|
|
||||||
async def test_graphiti_falkordb_init():
|
|
||||||
logger = setup_logging()
|
|
||||||
|
|
||||||
falkor_driver = FalkorDriver(
|
|
||||||
host=FALKORDB_HOST, port=FALKORDB_PORT, username=FALKORDB_USER, password=FALKORDB_PASSWORD
|
|
||||||
)
|
|
||||||
|
|
||||||
graphiti = Graphiti(graph_driver=falkor_driver)
|
|
||||||
|
|
||||||
results = await graphiti.search_(query='Who is the user?')
|
|
||||||
|
|
||||||
pretty_results = search_results_to_context_string(results)
|
|
||||||
|
|
||||||
logger.info(pretty_results)
|
|
||||||
|
|
||||||
await graphiti.close()
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
@unittest.skipIf(not HAS_FALKORDB, 'FalkorDB is not installed')
|
|
||||||
async def test_graph_falkordb_integration():
|
|
||||||
falkor_driver = FalkorDriver(
|
|
||||||
host=FALKORDB_HOST, port=FALKORDB_PORT, username=FALKORDB_USER, password=FALKORDB_PASSWORD
|
|
||||||
)
|
|
||||||
|
|
||||||
client = Graphiti(graph_driver=falkor_driver)
|
|
||||||
embedder = client.embedder
|
|
||||||
driver = client.driver
|
|
||||||
|
|
||||||
now = datetime.now(timezone.utc)
|
|
||||||
episode = EpisodicNode(
|
|
||||||
name='test_episode',
|
|
||||||
labels=[],
|
|
||||||
created_at=now,
|
|
||||||
valid_at=now,
|
|
||||||
source='message',
|
|
||||||
source_description='conversation message',
|
|
||||||
content='Alice likes Bob',
|
|
||||||
entity_edges=[],
|
|
||||||
)
|
|
||||||
|
|
||||||
alice_node = EntityNode(
|
|
||||||
name='Alice',
|
|
||||||
labels=[],
|
|
||||||
created_at=now,
|
|
||||||
summary='Alice summary',
|
|
||||||
)
|
|
||||||
|
|
||||||
bob_node = EntityNode(name='Bob', labels=[], created_at=now, summary='Bob summary')
|
|
||||||
|
|
||||||
episodic_edge_1 = EpisodicEdge(
|
|
||||||
source_node_uuid=episode.uuid, target_node_uuid=alice_node.uuid, created_at=now
|
|
||||||
)
|
|
||||||
|
|
||||||
episodic_edge_2 = EpisodicEdge(
|
|
||||||
source_node_uuid=episode.uuid, target_node_uuid=bob_node.uuid, created_at=now
|
|
||||||
)
|
|
||||||
|
|
||||||
entity_edge = EntityEdge(
|
|
||||||
source_node_uuid=alice_node.uuid,
|
|
||||||
target_node_uuid=bob_node.uuid,
|
|
||||||
created_at=now,
|
|
||||||
name='likes',
|
|
||||||
fact='Alice likes Bob',
|
|
||||||
episodes=[],
|
|
||||||
expired_at=now,
|
|
||||||
valid_at=now,
|
|
||||||
invalid_at=now,
|
|
||||||
)
|
|
||||||
|
|
||||||
await entity_edge.generate_embedding(embedder)
|
|
||||||
|
|
||||||
nodes = [episode, alice_node, bob_node]
|
|
||||||
edges = [episodic_edge_1, episodic_edge_2, entity_edge]
|
|
||||||
|
|
||||||
# test save
|
|
||||||
await semaphore_gather(*[node.save(driver) for node in nodes])
|
|
||||||
await semaphore_gather(*[edge.save(driver) for edge in edges])
|
|
||||||
|
|
||||||
# test get
|
|
||||||
assert await EpisodicNode.get_by_uuid(driver, episode.uuid) is not None
|
|
||||||
assert await EntityNode.get_by_uuid(driver, alice_node.uuid) is not None
|
|
||||||
assert await EpisodicEdge.get_by_uuid(driver, episodic_edge_1.uuid) is not None
|
|
||||||
assert await EntityEdge.get_by_uuid(driver, entity_edge.uuid) is not None
|
|
||||||
|
|
||||||
# test delete
|
|
||||||
await semaphore_gather(*[node.delete(driver) for node in nodes])
|
|
||||||
await semaphore_gather(*[edge.delete(driver) for edge in edges])
|
|
||||||
|
|
||||||
await client.close()
|
|
||||||
|
|
@ -15,31 +15,19 @@ limitations under the License.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import os
|
|
||||||
import sys
|
import sys
|
||||||
from datetime import datetime, timezone
|
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from dotenv import load_dotenv
|
|
||||||
|
|
||||||
from graphiti_core.edges import EntityEdge, EpisodicEdge
|
|
||||||
from graphiti_core.graphiti import Graphiti
|
from graphiti_core.graphiti import Graphiti
|
||||||
from graphiti_core.helpers import semaphore_gather
|
|
||||||
from graphiti_core.nodes import EntityNode, EpisodicNode
|
|
||||||
from graphiti_core.search.search_filters import ComparisonOperator, DateFilter, SearchFilters
|
from graphiti_core.search.search_filters import ComparisonOperator, DateFilter, SearchFilters
|
||||||
from graphiti_core.search.search_helpers import search_results_to_context_string
|
from graphiti_core.search.search_helpers import search_results_to_context_string
|
||||||
from graphiti_core.utils.datetime_utils import utc_now
|
from graphiti_core.utils.datetime_utils import utc_now
|
||||||
|
from tests.helpers_test import drivers, get_driver
|
||||||
|
|
||||||
pytestmark = pytest.mark.integration
|
pytestmark = pytest.mark.integration
|
||||||
|
|
||||||
pytest_plugins = ('pytest_asyncio',)
|
pytest_plugins = ('pytest_asyncio',)
|
||||||
|
|
||||||
load_dotenv()
|
|
||||||
|
|
||||||
NEO4J_URI = os.getenv('NEO4J_URI')
|
|
||||||
NEO4j_USER = os.getenv('NEO4J_USER')
|
|
||||||
NEO4j_PASSWORD = os.getenv('NEO4J_PASSWORD')
|
|
||||||
|
|
||||||
|
|
||||||
def setup_logging():
|
def setup_logging():
|
||||||
# Create a logger
|
# Create a logger
|
||||||
|
|
@ -63,9 +51,18 @@ def setup_logging():
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_graphiti_init():
|
@pytest.mark.parametrize(
|
||||||
|
'driver',
|
||||||
|
drivers,
|
||||||
|
ids=drivers,
|
||||||
|
)
|
||||||
|
async def test_graphiti_init(driver):
|
||||||
logger = setup_logging()
|
logger = setup_logging()
|
||||||
graphiti = Graphiti(NEO4J_URI, NEO4j_USER, NEO4j_PASSWORD)
|
driver = get_driver(driver)
|
||||||
|
graphiti = Graphiti(graph_driver=driver)
|
||||||
|
|
||||||
|
await graphiti.build_indices_and_constraints()
|
||||||
|
|
||||||
search_filter = SearchFilters(
|
search_filter = SearchFilters(
|
||||||
created_at=[[DateFilter(date=utc_now(), comparison_operator=ComparisonOperator.less_than)]]
|
created_at=[[DateFilter(date=utc_now(), comparison_operator=ComparisonOperator.less_than)]]
|
||||||
)
|
)
|
||||||
|
|
@ -76,74 +73,6 @@ async def test_graphiti_init():
|
||||||
)
|
)
|
||||||
|
|
||||||
pretty_results = search_results_to_context_string(results)
|
pretty_results = search_results_to_context_string(results)
|
||||||
|
|
||||||
logger.info(pretty_results)
|
logger.info(pretty_results)
|
||||||
|
|
||||||
await graphiti.close()
|
await graphiti.close()
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_graph_integration():
|
|
||||||
client = Graphiti(NEO4J_URI, NEO4j_USER, NEO4j_PASSWORD)
|
|
||||||
embedder = client.embedder
|
|
||||||
driver = client.driver
|
|
||||||
|
|
||||||
now = datetime.now(timezone.utc)
|
|
||||||
episode = EpisodicNode(
|
|
||||||
name='test_episode',
|
|
||||||
labels=[],
|
|
||||||
created_at=now,
|
|
||||||
valid_at=now,
|
|
||||||
source='message',
|
|
||||||
source_description='conversation message',
|
|
||||||
content='Alice likes Bob',
|
|
||||||
entity_edges=[],
|
|
||||||
)
|
|
||||||
|
|
||||||
alice_node = EntityNode(
|
|
||||||
name='Alice',
|
|
||||||
labels=[],
|
|
||||||
created_at=now,
|
|
||||||
summary='Alice summary',
|
|
||||||
)
|
|
||||||
|
|
||||||
bob_node = EntityNode(name='Bob', labels=[], created_at=now, summary='Bob summary')
|
|
||||||
|
|
||||||
episodic_edge_1 = EpisodicEdge(
|
|
||||||
source_node_uuid=episode.uuid, target_node_uuid=alice_node.uuid, created_at=now
|
|
||||||
)
|
|
||||||
|
|
||||||
episodic_edge_2 = EpisodicEdge(
|
|
||||||
source_node_uuid=episode.uuid, target_node_uuid=bob_node.uuid, created_at=now
|
|
||||||
)
|
|
||||||
|
|
||||||
entity_edge = EntityEdge(
|
|
||||||
source_node_uuid=alice_node.uuid,
|
|
||||||
target_node_uuid=bob_node.uuid,
|
|
||||||
created_at=now,
|
|
||||||
name='likes',
|
|
||||||
fact='Alice likes Bob',
|
|
||||||
episodes=[],
|
|
||||||
expired_at=now,
|
|
||||||
valid_at=now,
|
|
||||||
invalid_at=now,
|
|
||||||
)
|
|
||||||
|
|
||||||
await entity_edge.generate_embedding(embedder)
|
|
||||||
|
|
||||||
nodes = [episode, alice_node, bob_node]
|
|
||||||
edges = [episodic_edge_1, episodic_edge_2, entity_edge]
|
|
||||||
|
|
||||||
# test save
|
|
||||||
await semaphore_gather(*[node.save(driver) for node in nodes])
|
|
||||||
await semaphore_gather(*[edge.save(driver) for edge in edges])
|
|
||||||
|
|
||||||
# test get
|
|
||||||
assert await EpisodicNode.get_by_uuid(driver, episode.uuid) is not None
|
|
||||||
assert await EntityNode.get_by_uuid(driver, alice_node.uuid) is not None
|
|
||||||
assert await EpisodicEdge.get_by_uuid(driver, episodic_edge_1.uuid) is not None
|
|
||||||
assert await EntityEdge.get_by_uuid(driver, entity_edge.uuid) is not None
|
|
||||||
|
|
||||||
# test delete
|
|
||||||
await semaphore_gather(*[node.delete(driver) for node in nodes])
|
|
||||||
await semaphore_gather(*[edge.delete(driver) for edge in edges])
|
|
||||||
|
|
|
||||||
|
|
@ -1,139 +0,0 @@
|
||||||
"""
|
|
||||||
Copyright 2024, Zep Software, Inc.
|
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
you may not use this file except in compliance with the License.
|
|
||||||
You may obtain a copy of the License at
|
|
||||||
|
|
||||||
http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
|
|
||||||
Unless required by applicable law or agreed to in writing, software
|
|
||||||
distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
See the License for the specific language governing permissions and
|
|
||||||
limitations under the License.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import os
|
|
||||||
import unittest
|
|
||||||
from datetime import datetime, timezone
|
|
||||||
from uuid import uuid4
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
from graphiti_core.nodes import (
|
|
||||||
CommunityNode,
|
|
||||||
EntityNode,
|
|
||||||
EpisodeType,
|
|
||||||
EpisodicNode,
|
|
||||||
)
|
|
||||||
|
|
||||||
FALKORDB_HOST = os.getenv('FALKORDB_HOST', 'localhost')
|
|
||||||
FALKORDB_PORT = os.getenv('FALKORDB_PORT', '6379')
|
|
||||||
FALKORDB_USER = os.getenv('FALKORDB_USER', None)
|
|
||||||
FALKORDB_PASSWORD = os.getenv('FALKORDB_PASSWORD', None)
|
|
||||||
|
|
||||||
try:
|
|
||||||
from graphiti_core.driver.falkordb_driver import FalkorDriver
|
|
||||||
|
|
||||||
HAS_FALKORDB = True
|
|
||||||
except ImportError:
|
|
||||||
FalkorDriver = None
|
|
||||||
HAS_FALKORDB = False
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def sample_entity_node():
|
|
||||||
return EntityNode(
|
|
||||||
uuid=str(uuid4()),
|
|
||||||
name='Test Entity',
|
|
||||||
group_id='test_group',
|
|
||||||
labels=['Entity'],
|
|
||||||
name_embedding=[0.5] * 1024,
|
|
||||||
summary='Entity Summary',
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def sample_episodic_node():
|
|
||||||
return EpisodicNode(
|
|
||||||
uuid=str(uuid4()),
|
|
||||||
name='Episode 1',
|
|
||||||
group_id='test_group',
|
|
||||||
source=EpisodeType.text,
|
|
||||||
source_description='Test source',
|
|
||||||
content='Some content here',
|
|
||||||
valid_at=datetime.now(timezone.utc),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def sample_community_node():
|
|
||||||
return CommunityNode(
|
|
||||||
uuid=str(uuid4()),
|
|
||||||
name='Community A',
|
|
||||||
name_embedding=[0.5] * 1024,
|
|
||||||
group_id='test_group',
|
|
||||||
summary='Community summary',
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
@pytest.mark.integration
|
|
||||||
@unittest.skipIf(not HAS_FALKORDB, 'FalkorDB is not installed')
|
|
||||||
async def test_entity_node_save_get_and_delete(sample_entity_node):
|
|
||||||
falkor_driver = FalkorDriver(
|
|
||||||
host=FALKORDB_HOST, port=FALKORDB_PORT, username=FALKORDB_USER, password=FALKORDB_PASSWORD
|
|
||||||
)
|
|
||||||
|
|
||||||
await sample_entity_node.save(falkor_driver)
|
|
||||||
|
|
||||||
retrieved = await EntityNode.get_by_uuid(falkor_driver, sample_entity_node.uuid)
|
|
||||||
assert retrieved.uuid == sample_entity_node.uuid
|
|
||||||
assert retrieved.name == 'Test Entity'
|
|
||||||
assert retrieved.group_id == 'test_group'
|
|
||||||
|
|
||||||
await sample_entity_node.delete(falkor_driver)
|
|
||||||
await falkor_driver.close()
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
@pytest.mark.integration
|
|
||||||
@unittest.skipIf(not HAS_FALKORDB, 'FalkorDB is not installed')
|
|
||||||
async def test_community_node_save_get_and_delete(sample_community_node):
|
|
||||||
falkor_driver = FalkorDriver(
|
|
||||||
host=FALKORDB_HOST, port=FALKORDB_PORT, username=FALKORDB_USER, password=FALKORDB_PASSWORD
|
|
||||||
)
|
|
||||||
|
|
||||||
await sample_community_node.save(falkor_driver)
|
|
||||||
|
|
||||||
retrieved = await CommunityNode.get_by_uuid(falkor_driver, sample_community_node.uuid)
|
|
||||||
assert retrieved.uuid == sample_community_node.uuid
|
|
||||||
assert retrieved.name == 'Community A'
|
|
||||||
assert retrieved.group_id == 'test_group'
|
|
||||||
assert retrieved.summary == 'Community summary'
|
|
||||||
|
|
||||||
await sample_community_node.delete(falkor_driver)
|
|
||||||
await falkor_driver.close()
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
@pytest.mark.integration
|
|
||||||
@unittest.skipIf(not HAS_FALKORDB, 'FalkorDB is not installed')
|
|
||||||
async def test_episodic_node_save_get_and_delete(sample_episodic_node):
|
|
||||||
falkor_driver = FalkorDriver(
|
|
||||||
host=FALKORDB_HOST, port=FALKORDB_PORT, username=FALKORDB_USER, password=FALKORDB_PASSWORD
|
|
||||||
)
|
|
||||||
|
|
||||||
await sample_episodic_node.save(falkor_driver)
|
|
||||||
|
|
||||||
retrieved = await EpisodicNode.get_by_uuid(falkor_driver, sample_episodic_node.uuid)
|
|
||||||
assert retrieved.uuid == sample_episodic_node.uuid
|
|
||||||
assert retrieved.name == 'Episode 1'
|
|
||||||
assert retrieved.group_id == 'test_group'
|
|
||||||
assert retrieved.source == EpisodeType.text
|
|
||||||
assert retrieved.source_description == 'Test source'
|
|
||||||
assert retrieved.content == 'Some content here'
|
|
||||||
|
|
||||||
await sample_episodic_node.delete(falkor_driver)
|
|
||||||
await falkor_driver.close()
|
|
||||||
|
|
@ -14,23 +14,22 @@ See the License for the specific language governing permissions and
|
||||||
limitations under the License.
|
limitations under the License.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import os
|
from datetime import datetime
|
||||||
from datetime import datetime, timezone
|
|
||||||
from uuid import uuid4
|
from uuid import uuid4
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
import pytest
|
import pytest
|
||||||
from neo4j import AsyncGraphDatabase
|
|
||||||
|
|
||||||
|
from graphiti_core.driver.driver import GraphDriver
|
||||||
from graphiti_core.nodes import (
|
from graphiti_core.nodes import (
|
||||||
CommunityNode,
|
CommunityNode,
|
||||||
EntityNode,
|
EntityNode,
|
||||||
EpisodeType,
|
EpisodeType,
|
||||||
EpisodicNode,
|
EpisodicNode,
|
||||||
)
|
)
|
||||||
|
from tests.helpers_test import drivers, get_driver
|
||||||
|
|
||||||
NEO4J_URI = os.getenv('NEO4J_URI', 'bolt://localhost:7687')
|
group_id = f'test_group_{str(uuid4())}'
|
||||||
NEO4J_USER = os.getenv('NEO4J_USER', 'neo4j')
|
|
||||||
NEO4J_PASSWORD = os.getenv('NEO4J_PASSWORD', 'test')
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
|
|
@ -38,8 +37,8 @@ def sample_entity_node():
|
||||||
return EntityNode(
|
return EntityNode(
|
||||||
uuid=str(uuid4()),
|
uuid=str(uuid4()),
|
||||||
name='Test Entity',
|
name='Test Entity',
|
||||||
group_id='test_group',
|
group_id=group_id,
|
||||||
labels=['Entity'],
|
labels=[],
|
||||||
name_embedding=[0.5] * 1024,
|
name_embedding=[0.5] * 1024,
|
||||||
summary='Entity Summary',
|
summary='Entity Summary',
|
||||||
)
|
)
|
||||||
|
|
@ -50,11 +49,11 @@ def sample_episodic_node():
|
||||||
return EpisodicNode(
|
return EpisodicNode(
|
||||||
uuid=str(uuid4()),
|
uuid=str(uuid4()),
|
||||||
name='Episode 1',
|
name='Episode 1',
|
||||||
group_id='test_group',
|
group_id=group_id,
|
||||||
source=EpisodeType.text,
|
source=EpisodeType.text,
|
||||||
source_description='Test source',
|
source_description='Test source',
|
||||||
content='Some content here',
|
content='Some content here',
|
||||||
valid_at=datetime.now(timezone.utc),
|
valid_at=datetime.now(),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -64,59 +63,181 @@ def sample_community_node():
|
||||||
uuid=str(uuid4()),
|
uuid=str(uuid4()),
|
||||||
name='Community A',
|
name='Community A',
|
||||||
name_embedding=[0.5] * 1024,
|
name_embedding=[0.5] * 1024,
|
||||||
group_id='test_group',
|
group_id=group_id,
|
||||||
summary='Community summary',
|
summary='Community summary',
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@pytest.mark.integration
|
@pytest.mark.parametrize(
|
||||||
async def test_entity_node_save_get_and_delete(sample_entity_node):
|
'driver',
|
||||||
neo4j_driver = AsyncGraphDatabase.driver(NEO4J_URI, auth=(NEO4J_USER, NEO4J_PASSWORD))
|
drivers,
|
||||||
await sample_entity_node.save(neo4j_driver)
|
ids=drivers,
|
||||||
retrieved = await EntityNode.get_by_uuid(neo4j_driver, sample_entity_node.uuid)
|
)
|
||||||
|
async def test_entity_node(sample_entity_node, driver):
|
||||||
|
driver = get_driver(driver)
|
||||||
|
uuid = sample_entity_node.uuid
|
||||||
|
|
||||||
|
# Create node
|
||||||
|
node_count = await get_node_count(driver, uuid)
|
||||||
|
assert node_count == 0
|
||||||
|
await sample_entity_node.save(driver)
|
||||||
|
node_count = await get_node_count(driver, uuid)
|
||||||
|
assert node_count == 1
|
||||||
|
|
||||||
|
retrieved = await EntityNode.get_by_uuid(driver, sample_entity_node.uuid)
|
||||||
assert retrieved.uuid == sample_entity_node.uuid
|
assert retrieved.uuid == sample_entity_node.uuid
|
||||||
assert retrieved.name == 'Test Entity'
|
assert retrieved.name == 'Test Entity'
|
||||||
assert retrieved.group_id == 'test_group'
|
assert retrieved.group_id == group_id
|
||||||
|
|
||||||
await sample_entity_node.delete(neo4j_driver)
|
retrieved = await EntityNode.get_by_uuids(driver, [sample_entity_node.uuid])
|
||||||
|
assert retrieved[0].uuid == sample_entity_node.uuid
|
||||||
|
assert retrieved[0].name == 'Test Entity'
|
||||||
|
assert retrieved[0].group_id == group_id
|
||||||
|
|
||||||
await neo4j_driver.close()
|
retrieved = await EntityNode.get_by_group_ids(driver, [group_id], limit=2)
|
||||||
|
assert len(retrieved) == 1
|
||||||
|
assert retrieved[0].uuid == sample_entity_node.uuid
|
||||||
|
assert retrieved[0].name == 'Test Entity'
|
||||||
|
assert retrieved[0].group_id == group_id
|
||||||
|
|
||||||
|
await sample_entity_node.load_name_embedding(driver)
|
||||||
|
assert np.allclose(sample_entity_node.name_embedding, [0.5] * 1024)
|
||||||
|
|
||||||
|
# Delete node by uuid
|
||||||
|
await sample_entity_node.delete(driver)
|
||||||
|
node_count = await get_node_count(driver, uuid)
|
||||||
|
assert node_count == 0
|
||||||
|
|
||||||
|
# Delete node by group id
|
||||||
|
await sample_entity_node.save(driver)
|
||||||
|
node_count = await get_node_count(driver, uuid)
|
||||||
|
assert node_count == 1
|
||||||
|
await sample_entity_node.delete_by_group_id(driver, group_id)
|
||||||
|
node_count = await get_node_count(driver, uuid)
|
||||||
|
assert node_count == 0
|
||||||
|
|
||||||
|
await driver.close()
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@pytest.mark.integration
|
@pytest.mark.parametrize(
|
||||||
async def test_community_node_save_get_and_delete(sample_community_node):
|
'driver',
|
||||||
neo4j_driver = AsyncGraphDatabase.driver(NEO4J_URI, auth=(NEO4J_USER, NEO4J_PASSWORD))
|
drivers,
|
||||||
|
ids=drivers,
|
||||||
|
)
|
||||||
|
async def test_community_node(sample_community_node, driver):
|
||||||
|
driver = get_driver(driver)
|
||||||
|
uuid = sample_community_node.uuid
|
||||||
|
|
||||||
await sample_community_node.save(neo4j_driver)
|
# Create node
|
||||||
|
node_count = await get_node_count(driver, uuid)
|
||||||
|
assert node_count == 0
|
||||||
|
await sample_community_node.save(driver)
|
||||||
|
node_count = await get_node_count(driver, uuid)
|
||||||
|
assert node_count == 1
|
||||||
|
|
||||||
retrieved = await CommunityNode.get_by_uuid(neo4j_driver, sample_community_node.uuid)
|
retrieved = await CommunityNode.get_by_uuid(driver, sample_community_node.uuid)
|
||||||
assert retrieved.uuid == sample_community_node.uuid
|
assert retrieved.uuid == sample_community_node.uuid
|
||||||
assert retrieved.name == 'Community A'
|
assert retrieved.name == 'Community A'
|
||||||
assert retrieved.group_id == 'test_group'
|
assert retrieved.group_id == group_id
|
||||||
assert retrieved.summary == 'Community summary'
|
assert retrieved.summary == 'Community summary'
|
||||||
|
|
||||||
await sample_community_node.delete(neo4j_driver)
|
retrieved = await CommunityNode.get_by_uuids(driver, [sample_community_node.uuid])
|
||||||
|
assert retrieved[0].uuid == sample_community_node.uuid
|
||||||
|
assert retrieved[0].name == 'Community A'
|
||||||
|
assert retrieved[0].group_id == group_id
|
||||||
|
assert retrieved[0].summary == 'Community summary'
|
||||||
|
|
||||||
await neo4j_driver.close()
|
retrieved = await CommunityNode.get_by_group_ids(driver, [group_id], limit=2)
|
||||||
|
assert len(retrieved) == 1
|
||||||
|
assert retrieved[0].uuid == sample_community_node.uuid
|
||||||
|
assert retrieved[0].name == 'Community A'
|
||||||
|
assert retrieved[0].group_id == group_id
|
||||||
|
|
||||||
|
# Delete node by uuid
|
||||||
|
await sample_community_node.delete(driver)
|
||||||
|
node_count = await get_node_count(driver, uuid)
|
||||||
|
assert node_count == 0
|
||||||
|
|
||||||
|
# Delete node by group id
|
||||||
|
await sample_community_node.save(driver)
|
||||||
|
node_count = await get_node_count(driver, uuid)
|
||||||
|
assert node_count == 1
|
||||||
|
await sample_community_node.delete_by_group_id(driver, group_id)
|
||||||
|
node_count = await get_node_count(driver, uuid)
|
||||||
|
assert node_count == 0
|
||||||
|
|
||||||
|
await driver.close()
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
@pytest.mark.integration
@pytest.mark.parametrize(
    'driver',
    drivers,
    ids=drivers,
)
async def test_episodic_node(sample_episodic_node, driver):
    """Round-trip an EpisodicNode through each configured graph driver.

    Exercises save, retrieval by uuid / uuids / group_ids, delete by uuid,
    and delete by group id, asserting the node count in the backing store
    at each step.
    """
    # Resolve the parametrized driver name into a concrete driver instance.
    driver = get_driver(driver)
    uuid = sample_episodic_node.uuid

    # Create node: store must start empty for this uuid, then contain exactly one.
    node_count = await get_node_count(driver, uuid)
    assert node_count == 0
    await sample_episodic_node.save(driver)
    node_count = await get_node_count(driver, uuid)
    assert node_count == 1

    # Retrieve node by uuid and verify every persisted field survived the round trip.
    retrieved = await EpisodicNode.get_by_uuid(driver, sample_episodic_node.uuid)
    assert retrieved.uuid == sample_episodic_node.uuid
    assert retrieved.name == 'Episode 1'
    assert retrieved.group_id == group_id
    assert retrieved.source == EpisodeType.text
    assert retrieved.source_description == 'Test source'
    assert retrieved.content == 'Some content here'
    assert retrieved.valid_at == sample_episodic_node.valid_at

    # Retrieve node via the bulk uuid lookup.
    retrieved = await EpisodicNode.get_by_uuids(driver, [sample_episodic_node.uuid])
    assert retrieved[0].uuid == sample_episodic_node.uuid
    assert retrieved[0].name == 'Episode 1'
    assert retrieved[0].group_id == group_id
    assert retrieved[0].source == EpisodeType.text
    assert retrieved[0].source_description == 'Test source'
    assert retrieved[0].content == 'Some content here'
    assert retrieved[0].valid_at == sample_episodic_node.valid_at

    # Retrieve node by group id; limit=2 proves only one node exists in the group.
    retrieved = await EpisodicNode.get_by_group_ids(driver, [group_id], limit=2)
    assert len(retrieved) == 1
    assert retrieved[0].uuid == sample_episodic_node.uuid
    assert retrieved[0].name == 'Episode 1'
    assert retrieved[0].group_id == group_id
    assert retrieved[0].source == EpisodeType.text
    assert retrieved[0].source_description == 'Test source'
    assert retrieved[0].content == 'Some content here'
    assert retrieved[0].valid_at == sample_episodic_node.valid_at

    # Delete node by uuid
    await sample_episodic_node.delete(driver)
    node_count = await get_node_count(driver, uuid)
    assert node_count == 0

    # Delete node by group id
    await sample_episodic_node.save(driver)
    node_count = await get_node_count(driver, uuid)
    assert node_count == 1
    await sample_episodic_node.delete_by_group_id(driver, group_id)
    node_count = await get_node_count(driver, uuid)
    assert node_count == 0

    await driver.close()
async def get_node_count(driver: GraphDriver, uuid: str) -> int:
    """Return the number of nodes in the store whose ``uuid`` property matches *uuid*.

    Matches nodes of any label (``MATCH (n {uuid: $uuid})``), so it counts
    entity, episodic, and community nodes alike; the uuid is passed as a
    query parameter, never interpolated into the Cypher string.
    """
    result, _, _ = await driver.execute_query(
        """
        MATCH (n {uuid: $uuid})
        RETURN COUNT(n) as count
        """,
        uuid=uuid,
    )
    # Coerce explicitly: some drivers may return the aggregate as a non-int type.
    return int(result[0]['count'])
Loading…
Add table
Reference in a new issue