diff --git a/.github/workflows/unit_tests.yml b/.github/workflows/unit_tests.yml index 9f41d595..3096524a 100644 --- a/.github/workflows/unit_tests.yml +++ b/.github/workflows/unit_tests.yml @@ -20,6 +20,15 @@ jobs: ports: - 6379:6379 options: --health-cmd "redis-cli ping" --health-interval 10s --health-timeout 5s --health-retries 5 + neo4j: + image: neo4j:5.26-community + ports: + - 7687:7687 + - 7474:7474 + env: + NEO4J_AUTH: neo4j/testpass + NEO4J_PLUGINS: '["apoc"]' + options: --health-cmd "cypher-shell -u neo4j -p testpass 'RETURN 1'" --health-interval 10s --health-timeout 5s --health-retries 10 steps: - uses: actions/checkout@v4 - name: Set up Python @@ -37,15 +46,33 @@ jobs: - name: Run non-integration tests env: PYTHONPATH: ${{ github.workspace }} + NEO4J_URI: bolt://localhost:7687 + NEO4J_USER: neo4j + NEO4J_PASSWORD: testpass run: | uv run pytest -m "not integration" - name: Wait for FalkorDB run: | timeout 60 bash -c 'until redis-cli -h localhost -p 6379 ping; do sleep 1; done' + - name: Wait for Neo4j + run: | + timeout 60 bash -c 'until wget -O /dev/null http://localhost:7474 >/dev/null 2>&1; do sleep 1; done' - name: Run FalkorDB integration tests env: PYTHONPATH: ${{ github.workspace }} FALKORDB_HOST: localhost FALKORDB_PORT: 6379 + DISABLE_NEO4J: 1 run: | uv run pytest tests/driver/test_falkordb_driver.py + - name: Run Neo4j integration tests + env: + PYTHONPATH: ${{ github.workspace }} + NEO4J_URI: bolt://localhost:7687 + NEO4J_USER: neo4j + NEO4J_PASSWORD: testpass + FALKORDB_HOST: localhost + FALKORDB_PORT: 6379 + OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }} + run: | + uv run pytest tests/test_*_int.py -k "neo4j" diff --git a/graphiti_core/driver/driver.py b/graphiti_core/driver/driver.py index 663b7859..959eb7e1 100644 --- a/graphiti_core/driver/driver.py +++ b/graphiti_core/driver/driver.py @@ -18,11 +18,17 @@ import copy import logging from abc import ABC, abstractmethod from collections.abc import Coroutine +from enum import Enum from typing import Any logger = logging.getLogger(__name__) +class GraphProvider(Enum): + NEO4J = 'neo4j' + FALKORDB = 'falkordb' + + class GraphDriverSession(ABC): async def __aenter__(self): return self @@ -46,7 +52,7 @@ class GraphDriverSession(ABC): class GraphDriver(ABC): - provider: str + provider: GraphProvider fulltext_syntax: str = ( '' # Neo4j (default) syntax does not require a prefix for fulltext queries ) diff --git a/graphiti_core/driver/falkordb_driver.py b/graphiti_core/driver/falkordb_driver.py index 46e14c7e..b88fd583 100644 --- a/graphiti_core/driver/falkordb_driver.py +++ b/graphiti_core/driver/falkordb_driver.py @@ -32,7 +32,7 @@ else: 'Install it with: pip install graphiti-core[falkordb]' ) from None -from graphiti_core.driver.driver import GraphDriver, GraphDriverSession +from graphiti_core.driver.driver import GraphDriver, GraphDriverSession, GraphProvider logger = logging.getLogger(__name__) @@ -71,7 +71,7 @@ class FalkorDriverSession(GraphDriverSession): class FalkorDriver(GraphDriver): - provider: str = 'falkordb' + provider = GraphProvider.FALKORDB def __init__( self, @@ -119,7 +119,7 @@ class FalkorDriver(GraphDriver): # check if index already exists logger.info(f'Index already exists: {e}') return None - logger.error(f'Error executing FalkorDB query: {e}') + logger.error(f'Error executing FalkorDB query: {e}\n{cypher_query_}\n{params}') raise # Convert the result header to a list of strings diff --git a/graphiti_core/driver/neo4j_driver.py b/graphiti_core/driver/neo4j_driver.py index 95c7f9ce..0bcc9527 100644 
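# --- editor's example (not part of the diff) --------------------------------
# A minimal sketch of how the new GraphProvider enum introduced above is meant
# to be used: callers compare driver.provider against enum members instead of
# the old 'neo4j' / 'falkordb' string literals. The branching mirrors what the
# PR does in helpers.get_default_group_id(); it assumes graphiti-core with this
# change is importable.
from graphiti_core.driver.driver import GraphDriver, GraphProvider


def default_group_id(driver: GraphDriver) -> str:
    # FalkorDB requires a non-empty default group id; Neo4j accepts ''.
    if driver.provider == GraphProvider.FALKORDB:
        return '_'
    return ''
# -----------------------------------------------------------------------------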
--- a/graphiti_core/driver/neo4j_driver.py +++ b/graphiti_core/driver/neo4j_driver.py @@ -21,13 +21,13 @@ from typing import Any from neo4j import AsyncGraphDatabase, EagerResult from typing_extensions import LiteralString -from graphiti_core.driver.driver import GraphDriver, GraphDriverSession +from graphiti_core.driver.driver import GraphDriver, GraphDriverSession, GraphProvider logger = logging.getLogger(__name__) class Neo4jDriver(GraphDriver): - provider: str = 'neo4j' + provider = GraphProvider.NEO4J def __init__(self, uri: str, user: str | None, password: str | None, database: str = 'neo4j'): super().__init__() @@ -45,7 +45,11 @@ class Neo4jDriver(GraphDriver): params = {} params.setdefault('database_', self._database) - result = await self.client.execute_query(cypher_query_, parameters_=params, **kwargs) + try: + result = await self.client.execute_query(cypher_query_, parameters_=params, **kwargs) + except Exception as e: + logger.error(f'Error executing Neo4j query: {e}\n{cypher_query_}\n{params}') + raise return result diff --git a/graphiti_core/edges.py b/graphiti_core/edges.py index cee5b870..d1c4f1d2 100644 --- a/graphiti_core/edges.py +++ b/graphiti_core/edges.py @@ -29,29 +29,17 @@ from graphiti_core.embedder import EmbedderClient from graphiti_core.errors import EdgeNotFoundError, GroupsEdgesNotFoundError from graphiti_core.helpers import parse_db_date from graphiti_core.models.edges.edge_db_queries import ( - COMMUNITY_EDGE_SAVE, - ENTITY_EDGE_SAVE, + COMMUNITY_EDGE_RETURN, + ENTITY_EDGE_RETURN, + EPISODIC_EDGE_RETURN, EPISODIC_EDGE_SAVE, + get_community_edge_save_query, + get_entity_edge_save_query, ) from graphiti_core.nodes import Node logger = logging.getLogger(__name__) -ENTITY_EDGE_RETURN: LiteralString = """ - RETURN - e.uuid AS uuid, - startNode(e).uuid AS source_node_uuid, - endNode(e).uuid AS target_node_uuid, - e.created_at AS created_at, - e.name AS name, - e.group_id AS group_id, - e.fact AS fact, - e.episodes AS episodes, - e.expired_at AS expired_at, - e.valid_at AS valid_at, - e.invalid_at AS invalid_at, - properties(e) AS attributes""" - class Edge(BaseModel, ABC): uuid: str = Field(default_factory=lambda: str(uuid4())) @@ -66,9 +54,9 @@ class Edge(BaseModel, ABC): async def delete(self, driver: GraphDriver): result = await driver.execute_query( """ - MATCH (n)-[e:MENTIONS|RELATES_TO|HAS_MEMBER {uuid: $uuid}]->(m) - DELETE e - """, + MATCH (n)-[e:MENTIONS|RELATES_TO|HAS_MEMBER {uuid: $uuid}]->(m) + DELETE e + """, uuid=self.uuid, ) @@ -107,14 +95,10 @@ class EpisodicEdge(Edge): async def get_by_uuid(cls, driver: GraphDriver, uuid: str): records, _, _ = await driver.execute_query( """ - MATCH (n:Episodic)-[e:MENTIONS {uuid: $uuid}]->(m:Entity) - RETURN - e.uuid As uuid, - e.group_id AS group_id, - n.uuid AS source_node_uuid, - m.uuid AS target_node_uuid, - e.created_at AS created_at - """, + MATCH (n:Episodic)-[e:MENTIONS {uuid: $uuid}]->(m:Entity) + RETURN + """ + + EPISODIC_EDGE_RETURN, uuid=uuid, routing_='r', ) @@ -129,15 +113,11 @@ class EpisodicEdge(Edge): async def get_by_uuids(cls, driver: GraphDriver, uuids: list[str]): records, _, _ = await driver.execute_query( """ - MATCH (n:Episodic)-[e:MENTIONS]->(m:Entity) - WHERE e.uuid IN $uuids - RETURN - e.uuid As uuid, - e.group_id AS group_id, - n.uuid AS source_node_uuid, - m.uuid AS target_node_uuid, - e.created_at AS created_at - """, + MATCH (n:Episodic)-[e:MENTIONS]->(m:Entity) + WHERE e.uuid IN $uuids + RETURN + """ + + EPISODIC_EDGE_RETURN, uuids=uuids, routing_='r', ) @@ -161,19 +141,17 @@ class 
EpisodicEdge(Edge): records, _, _ = await driver.execute_query( """ - MATCH (n:Episodic)-[e:MENTIONS]->(m:Entity) - WHERE e.group_id IN $group_ids - """ + MATCH (n:Episodic)-[e:MENTIONS]->(m:Entity) + WHERE e.group_id IN $group_ids + """ + cursor_query + """ - RETURN - e.uuid As uuid, - e.group_id AS group_id, - n.uuid AS source_node_uuid, - m.uuid AS target_node_uuid, - e.created_at AS created_at - ORDER BY e.uuid DESC - """ + RETURN + """ + + EPISODIC_EDGE_RETURN + + """ + ORDER BY e.uuid DESC + """ + limit_query, group_ids=group_ids, uuid=uuid_cursor, @@ -221,11 +199,14 @@ class EntityEdge(Edge): return self.fact_embedding async def load_fact_embedding(self, driver: GraphDriver): - query: LiteralString = """ + records, _, _ = await driver.execute_query( + """ MATCH (n:Entity)-[e:RELATES_TO {uuid: $uuid}]->(m:Entity) RETURN e.fact_embedding AS fact_embedding - """ - records, _, _ = await driver.execute_query(query, uuid=self.uuid, routing_='r') + """, + uuid=self.uuid, + routing_='r', + ) if len(records) == 0: raise EdgeNotFoundError(self.uuid) @@ -251,7 +232,7 @@ class EntityEdge(Edge): edge_data.update(self.attributes or {}) result = await driver.execute_query( - ENTITY_EDGE_SAVE, + get_entity_edge_save_query(driver.provider), edge_data=edge_data, ) @@ -263,8 +244,9 @@ class EntityEdge(Edge): async def get_by_uuid(cls, driver: GraphDriver, uuid: str): records, _, _ = await driver.execute_query( """ - MATCH (n:Entity)-[e:RELATES_TO {uuid: $uuid}]->(m:Entity) - """ + MATCH (n:Entity)-[e:RELATES_TO {uuid: $uuid}]->(m:Entity) + RETURN + """ + ENTITY_EDGE_RETURN, uuid=uuid, routing_='r', @@ -283,9 +265,10 @@ class EntityEdge(Edge): records, _, _ = await driver.execute_query( """ - MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity) - WHERE e.uuid IN $uuids - """ + MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity) + WHERE e.uuid IN $uuids + RETURN + """ + ENTITY_EDGE_RETURN, uuids=uuids, routing_='r', @@ -314,22 +297,21 @@ class EntityEdge(Edge): else '' ) - query: LiteralString = ( + records, _, _ = await driver.execute_query( """ MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity) WHERE e.group_id IN $group_ids """ + cursor_query + + """ + RETURN + """ + ENTITY_EDGE_RETURN + with_embeddings_query + """ - ORDER BY e.uuid DESC - """ - + limit_query - ) - - records, _, _ = await driver.execute_query( - query, + ORDER BY e.uuid DESC + """ + + limit_query, group_ids=group_ids, uuid=uuid_cursor, limit=limit, @@ -344,13 +326,15 @@ class EntityEdge(Edge): @classmethod async def get_by_node_uuid(cls, driver: GraphDriver, node_uuid: str): - query: LiteralString = ( + records, _, _ = await driver.execute_query( """ - MATCH (n:Entity {uuid: $node_uuid})-[e:RELATES_TO]-(m:Entity) - """ - + ENTITY_EDGE_RETURN + MATCH (n:Entity {uuid: $node_uuid})-[e:RELATES_TO]-(m:Entity) + RETURN + """ + + ENTITY_EDGE_RETURN, + node_uuid=node_uuid, + routing_='r', ) - records, _, _ = await driver.execute_query(query, node_uuid=node_uuid, routing_='r') edges = [get_entity_edge_from_record(record) for record in records] @@ -360,7 +344,7 @@ class EntityEdge(Edge): class CommunityEdge(Edge): async def save(self, driver: GraphDriver): result = await driver.execute_query( - COMMUNITY_EDGE_SAVE, + get_community_edge_save_query(driver.provider), community_uuid=self.source_node_uuid, entity_uuid=self.target_node_uuid, uuid=self.uuid, @@ -376,14 +360,10 @@ class CommunityEdge(Edge): async def get_by_uuid(cls, driver: GraphDriver, uuid: str): records, _, _ = await driver.execute_query( """ - MATCH (n:Community)-[e:HAS_MEMBER {uuid: $uuid}]->(m:Entity 
| Community) - RETURN - e.uuid As uuid, - e.group_id AS group_id, - n.uuid AS source_node_uuid, - m.uuid AS target_node_uuid, - e.created_at AS created_at - """, + MATCH (n:Community)-[e:HAS_MEMBER {uuid: $uuid}]->(m) + RETURN + """ + + COMMUNITY_EDGE_RETURN, uuid=uuid, routing_='r', ) @@ -396,15 +376,11 @@ class CommunityEdge(Edge): async def get_by_uuids(cls, driver: GraphDriver, uuids: list[str]): records, _, _ = await driver.execute_query( """ - MATCH (n:Community)-[e:HAS_MEMBER]->(m:Entity | Community) - WHERE e.uuid IN $uuids - RETURN - e.uuid As uuid, - e.group_id AS group_id, - n.uuid AS source_node_uuid, - m.uuid AS target_node_uuid, - e.created_at AS created_at - """, + MATCH (n:Community)-[e:HAS_MEMBER]->(m) + WHERE e.uuid IN $uuids + RETURN + """ + + COMMUNITY_EDGE_RETURN, uuids=uuids, routing_='r', ) @@ -426,19 +402,17 @@ class CommunityEdge(Edge): records, _, _ = await driver.execute_query( """ - MATCH (n:Community)-[e:HAS_MEMBER]->(m:Entity | Community) - WHERE e.group_id IN $group_ids - """ + MATCH (n:Community)-[e:HAS_MEMBER]->(m) + WHERE e.group_id IN $group_ids + """ + cursor_query + """ - RETURN - e.uuid As uuid, - e.group_id AS group_id, - n.uuid AS source_node_uuid, - m.uuid AS target_node_uuid, - e.created_at AS created_at - ORDER BY e.uuid DESC - """ + RETURN + """ + + COMMUNITY_EDGE_RETURN + + """ + ORDER BY e.uuid DESC + """ + limit_query, group_ids=group_ids, uuid=uuid_cursor, diff --git a/graphiti_core/graph_queries.py b/graphiti_core/graph_queries.py index 10e396c0..06f4e8ab 100644 --- a/graphiti_core/graph_queries.py +++ b/graphiti_core/graph_queries.py @@ -5,16 +5,9 @@ This module provides database-agnostic query generation for Neo4j and FalkorDB, supporting index creation, fulltext search, and bulk operations. """ -from typing import Any - from typing_extensions import LiteralString -from graphiti_core.models.edges.edge_db_queries import ( - ENTITY_EDGE_SAVE_BULK, -) -from graphiti_core.models.nodes.node_db_queries import ( - ENTITY_NODE_SAVE_BULK, -) +from graphiti_core.driver.driver import GraphProvider # Mapping from Neo4j fulltext index names to FalkorDB node labels NEO4J_TO_FALKORDB_MAPPING = { @@ -25,8 +18,8 @@ NEO4J_TO_FALKORDB_MAPPING = { } -def get_range_indices(db_type: str = 'neo4j') -> list[LiteralString]: - if db_type == 'falkordb': +def get_range_indices(provider: GraphProvider) -> list[LiteralString]: + if provider == GraphProvider.FALKORDB: return [ # Entity node 'CREATE INDEX FOR (n:Entity) ON (n.uuid, n.group_id, n.name, n.created_at)', @@ -41,109 +34,70 @@ def get_range_indices(db_type: str = 'neo4j') -> list[LiteralString]: # HAS_MEMBER edge 'CREATE INDEX FOR ()-[e:HAS_MEMBER]-() ON (e.uuid)', ] - else: - return [ - 'CREATE INDEX entity_uuid IF NOT EXISTS FOR (n:Entity) ON (n.uuid)', - 'CREATE INDEX episode_uuid IF NOT EXISTS FOR (n:Episodic) ON (n.uuid)', - 'CREATE INDEX community_uuid IF NOT EXISTS FOR (n:Community) ON (n.uuid)', - 'CREATE INDEX relation_uuid IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.uuid)', - 'CREATE INDEX mention_uuid IF NOT EXISTS FOR ()-[e:MENTIONS]-() ON (e.uuid)', - 'CREATE INDEX has_member_uuid IF NOT EXISTS FOR ()-[e:HAS_MEMBER]-() ON (e.uuid)', - 'CREATE INDEX entity_group_id IF NOT EXISTS FOR (n:Entity) ON (n.group_id)', - 'CREATE INDEX episode_group_id IF NOT EXISTS FOR (n:Episodic) ON (n.group_id)', - 'CREATE INDEX relation_group_id IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.group_id)', - 'CREATE INDEX mention_group_id IF NOT EXISTS FOR ()-[e:MENTIONS]-() ON (e.group_id)', - 'CREATE INDEX 
name_entity_index IF NOT EXISTS FOR (n:Entity) ON (n.name)', - 'CREATE INDEX created_at_entity_index IF NOT EXISTS FOR (n:Entity) ON (n.created_at)', - 'CREATE INDEX created_at_episodic_index IF NOT EXISTS FOR (n:Episodic) ON (n.created_at)', - 'CREATE INDEX valid_at_episodic_index IF NOT EXISTS FOR (n:Episodic) ON (n.valid_at)', - 'CREATE INDEX name_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.name)', - 'CREATE INDEX created_at_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.created_at)', - 'CREATE INDEX expired_at_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.expired_at)', - 'CREATE INDEX valid_at_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.valid_at)', - 'CREATE INDEX invalid_at_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.invalid_at)', - ] + + return [ + 'CREATE INDEX entity_uuid IF NOT EXISTS FOR (n:Entity) ON (n.uuid)', + 'CREATE INDEX episode_uuid IF NOT EXISTS FOR (n:Episodic) ON (n.uuid)', + 'CREATE INDEX community_uuid IF NOT EXISTS FOR (n:Community) ON (n.uuid)', + 'CREATE INDEX relation_uuid IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.uuid)', + 'CREATE INDEX mention_uuid IF NOT EXISTS FOR ()-[e:MENTIONS]-() ON (e.uuid)', + 'CREATE INDEX has_member_uuid IF NOT EXISTS FOR ()-[e:HAS_MEMBER]-() ON (e.uuid)', + 'CREATE INDEX entity_group_id IF NOT EXISTS FOR (n:Entity) ON (n.group_id)', + 'CREATE INDEX episode_group_id IF NOT EXISTS FOR (n:Episodic) ON (n.group_id)', + 'CREATE INDEX relation_group_id IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.group_id)', + 'CREATE INDEX mention_group_id IF NOT EXISTS FOR ()-[e:MENTIONS]-() ON (e.group_id)', + 'CREATE INDEX name_entity_index IF NOT EXISTS FOR (n:Entity) ON (n.name)', + 'CREATE INDEX created_at_entity_index IF NOT EXISTS FOR (n:Entity) ON (n.created_at)', + 'CREATE INDEX created_at_episodic_index IF NOT EXISTS FOR (n:Episodic) ON (n.created_at)', + 'CREATE INDEX valid_at_episodic_index IF NOT EXISTS FOR (n:Episodic) ON (n.valid_at)', + 'CREATE INDEX name_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.name)', + 'CREATE INDEX created_at_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.created_at)', + 'CREATE INDEX expired_at_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.expired_at)', + 'CREATE INDEX valid_at_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.valid_at)', + 'CREATE INDEX invalid_at_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.invalid_at)', + ] -def get_fulltext_indices(db_type: str = 'neo4j') -> list[LiteralString]: - if db_type == 'falkordb': +def get_fulltext_indices(provider: GraphProvider) -> list[LiteralString]: + if provider == GraphProvider.FALKORDB: return [ """CREATE FULLTEXT INDEX FOR (e:Episodic) ON (e.content, e.source, e.source_description, e.group_id)""", """CREATE FULLTEXT INDEX FOR (n:Entity) ON (n.name, n.summary, n.group_id)""", """CREATE FULLTEXT INDEX FOR (n:Community) ON (n.name, n.group_id)""", """CREATE FULLTEXT INDEX FOR ()-[e:RELATES_TO]-() ON (e.name, e.fact, e.group_id)""", ] - else: - return [ - """CREATE FULLTEXT INDEX episode_content IF NOT EXISTS - FOR (e:Episodic) ON EACH [e.content, e.source, e.source_description, e.group_id]""", - """CREATE FULLTEXT INDEX node_name_and_summary IF NOT EXISTS - FOR (n:Entity) ON EACH [n.name, n.summary, n.group_id]""", - """CREATE FULLTEXT INDEX community_name IF NOT EXISTS - FOR (n:Community) ON EACH [n.name, n.group_id]""", - """CREATE FULLTEXT INDEX edge_name_and_fact IF NOT EXISTS - FOR ()-[e:RELATES_TO]-() ON EACH [e.name, e.fact, e.group_id]""", - ] + + 
return [ + """CREATE FULLTEXT INDEX episode_content IF NOT EXISTS + FOR (e:Episodic) ON EACH [e.content, e.source, e.source_description, e.group_id]""", + """CREATE FULLTEXT INDEX node_name_and_summary IF NOT EXISTS + FOR (n:Entity) ON EACH [n.name, n.summary, n.group_id]""", + """CREATE FULLTEXT INDEX community_name IF NOT EXISTS + FOR (n:Community) ON EACH [n.name, n.group_id]""", + """CREATE FULLTEXT INDEX edge_name_and_fact IF NOT EXISTS + FOR ()-[e:RELATES_TO]-() ON EACH [e.name, e.fact, e.group_id]""", + ] -def get_nodes_query(db_type: str = 'neo4j', name: str = '', query: str | None = None) -> str: - if db_type == 'falkordb': +def get_nodes_query(provider: GraphProvider, name: str = '', query: str | None = None) -> str: + if provider == GraphProvider.FALKORDB: label = NEO4J_TO_FALKORDB_MAPPING[name] return f"CALL db.idx.fulltext.queryNodes('{label}', {query})" - else: - return f'CALL db.index.fulltext.queryNodes("{name}", {query}, {{limit: $limit}})' + + return f'CALL db.index.fulltext.queryNodes("{name}", {query}, {{limit: $limit}})' -def get_vector_cosine_func_query(vec1, vec2, db_type: str = 'neo4j') -> str: - if db_type == 'falkordb': +def get_vector_cosine_func_query(vec1, vec2, provider: GraphProvider) -> str: + if provider == GraphProvider.FALKORDB: # FalkorDB uses a different syntax for regular cosine similarity and Neo4j uses normalized cosine similarity return f'(2 - vec.cosineDistance({vec1}, vecf32({vec2})))/2' - else: - return f'vector.similarity.cosine({vec1}, {vec2})' + + return f'vector.similarity.cosine({vec1}, {vec2})' -def get_relationships_query(name: str, db_type: str = 'neo4j') -> str: - if db_type == 'falkordb': +def get_relationships_query(name: str, provider: GraphProvider) -> str: + if provider == GraphProvider.FALKORDB: label = NEO4J_TO_FALKORDB_MAPPING[name] return f"CALL db.idx.fulltext.queryRelationships('{label}', $query)" - else: - return f'CALL db.index.fulltext.queryRelationships("{name}", $query, {{limit: $limit}})' - -def get_entity_node_save_bulk_query(nodes, db_type: str = 'neo4j') -> str | Any: - if db_type == 'falkordb': - queries = [] - for node in nodes: - for label in node['labels']: - queries.append( - ( - f""" - UNWIND $nodes AS node - MERGE (n:Entity {{uuid: node.uuid}}) - SET n:{label} - SET n = node - WITH n, node - SET n.name_embedding = vecf32(node.name_embedding) - RETURN n.uuid AS uuid - """, - {'nodes': [node]}, - ) - ) - return queries - else: - return ENTITY_NODE_SAVE_BULK - - -def get_entity_edge_save_bulk_query(db_type: str = 'neo4j') -> str: - if db_type == 'falkordb': - return """ - UNWIND $entity_edges AS edge - MATCH (source:Entity {uuid: edge.source_node_uuid}) - MATCH (target:Entity {uuid: edge.target_node_uuid}) - MERGE (source)-[r:RELATES_TO {uuid: edge.uuid}]->(target) - SET r = {uuid: edge.uuid, name: edge.name, group_id: edge.group_id, fact: edge.fact, episodes: edge.episodes, - created_at: edge.created_at, expired_at: edge.expired_at, valid_at: edge.valid_at, invalid_at: edge.invalid_at, fact_embedding: vecf32(edge.fact_embedding)} - WITH r, edge - RETURN edge.uuid AS uuid""" - else: - return ENTITY_EDGE_SAVE_BULK + return f'CALL db.index.fulltext.queryRelationships("{name}", $query, {{limit: $limit}})' diff --git a/graphiti_core/graphiti.py b/graphiti_core/graphiti.py index 3459f8ae..03b1e402 100644 --- a/graphiti_core/graphiti.py +++ b/graphiti_core/graphiti.py @@ -26,7 +26,7 @@ from graphiti_core.cross_encoder.client import CrossEncoderClient from graphiti_core.cross_encoder.openai_reranker_client import 
OpenAIRerankerClient from graphiti_core.driver.driver import GraphDriver from graphiti_core.driver.neo4j_driver import Neo4jDriver -from graphiti_core.edges import EntityEdge, EpisodicEdge +from graphiti_core.edges import CommunityEdge, EntityEdge, EpisodicEdge from graphiti_core.embedder import EmbedderClient, OpenAIEmbedder from graphiti_core.graphiti_types import GraphitiClients from graphiti_core.helpers import ( @@ -93,8 +93,11 @@ load_dotenv() class AddEpisodeResults(BaseModel): episode: EpisodicNode + episodic_edges: list[EpisodicEdge] nodes: list[EntityNode] edges: list[EntityEdge] + communities: list[CommunityNode] + community_edges: list[CommunityEdge] class Graphiti: @@ -356,10 +359,10 @@ class Graphiti: group_id: str | None = None, uuid: str | None = None, update_communities: bool = False, - entity_types: dict[str, BaseModel] | None = None, + entity_types: dict[str, type[BaseModel]] | None = None, excluded_entity_types: list[str] | None = None, previous_episode_uuids: list[str] | None = None, - edge_types: dict[str, BaseModel] | None = None, + edge_types: dict[str, type[BaseModel]] | None = None, edge_type_map: dict[tuple[str, str], list[str]] | None = None, ) -> AddEpisodeResults: """ @@ -520,9 +523,12 @@ class Graphiti: self.driver, [episode], episodic_edges, hydrated_nodes, entity_edges, self.embedder ) + communities = [] + community_edges = [] + # Update any communities if update_communities: - await semaphore_gather( + communities, community_edges = await semaphore_gather( *[ update_community(self.driver, self.llm_client, self.embedder, node) for node in nodes @@ -532,7 +538,14 @@ class Graphiti: end = time() logger.info(f'Completed add_episode in {(end - start) * 1000} ms') - return AddEpisodeResults(episode=episode, nodes=nodes, edges=entity_edges) + return AddEpisodeResults( + episode=episode, + episodic_edges=episodic_edges, + nodes=hydrated_nodes, + edges=entity_edges, + communities=communities, + community_edges=community_edges, + ) except Exception as e: raise e @@ -542,9 +555,9 @@ class Graphiti: self, bulk_episodes: list[RawEpisode], group_id: str | None = None, - entity_types: dict[str, BaseModel] | None = None, + entity_types: dict[str, type[BaseModel]] | None = None, excluded_entity_types: list[str] | None = None, - edge_types: dict[str, BaseModel] | None = None, + edge_types: dict[str, type[BaseModel]] | None = None, edge_type_map: dict[tuple[str, str], list[str]] | None = None, ): """ @@ -817,7 +830,9 @@ class Graphiti: except Exception as e: raise e - async def build_communities(self, group_ids: list[str] | None = None) -> list[CommunityNode]: + async def build_communities( + self, group_ids: list[str] | None = None + ) -> tuple[list[CommunityNode], list[CommunityEdge]]: """ Use a community clustering algorithm to find communities of nodes. Create community nodes summarising the content of these communities. 
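# --- editor's example (not part of the diff) --------------------------------
# Hypothetical caller illustrating the knock-on changes above: AddEpisodeResults
# now carries episodic_edges / communities / community_edges, and
# build_communities() returns a (nodes, edges) tuple. The connection details and
# the add_episode keyword names are assumptions drawn from Graphiti's public
# API, not from this diff.
import asyncio
from datetime import datetime, timezone

from graphiti_core import Graphiti


async def main() -> None:
    graphiti = Graphiti('bolt://localhost:7687', 'neo4j', 'password')
    try:
        results = await graphiti.add_episode(
            name='demo',
            episode_body='Alice moved to Paris.',
            source_description='example message',
            reference_time=datetime.now(timezone.utc),
            update_communities=True,
        )
        # New result fields added by this PR.
        print(len(results.nodes), len(results.episodic_edges), len(results.communities))

        # build_communities() now returns both the community nodes and the
        # HAS_MEMBER edges that attach entities to them.
        community_nodes, community_edges = await graphiti.build_communities()
        print(f'{len(community_nodes)} communities, {len(community_edges)} HAS_MEMBER edges')
    finally:
        await graphiti.close()


if __name__ == '__main__':
    asyncio.run(main())
# -----------------------------------------------------------------------------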
@@ -846,7 +861,7 @@ class Graphiti: max_coroutines=self.max_coroutines, ) - return community_nodes + return community_nodes, community_edges async def search( self, diff --git a/graphiti_core/helpers.py b/graphiti_core/helpers.py index 64ec9eaa..a1bbb267 100644 --- a/graphiti_core/helpers.py +++ b/graphiti_core/helpers.py @@ -28,6 +28,7 @@ from numpy._typing import NDArray from pydantic import BaseModel from typing_extensions import LiteralString +from graphiti_core.driver.driver import GraphProvider from graphiti_core.errors import GroupIdValidationError load_dotenv() @@ -52,12 +53,12 @@ def parse_db_date(neo_date: neo4j_time.DateTime | str | None) -> datetime | None ) -def get_default_group_id(db_type: str) -> str: +def get_default_group_id(provider: GraphProvider) -> str: """ This function differentiates the default group id based on the database type. For most databases, the default group id is an empty string, while there are database types that require a specific default group id. """ - if db_type == 'falkordb': + if provider == GraphProvider.FALKORDB: return '_' else: return '' @@ -109,7 +110,7 @@ def validate_group_id(group_id: str) -> bool: def validate_excluded_entity_types( - excluded_entity_types: list[str] | None, entity_types: dict[str, BaseModel] | None = None + excluded_entity_types: list[str] | None, entity_types: dict[str, type[BaseModel]] | None = None ) -> bool: """ Validate that excluded entity types are valid type names. diff --git a/graphiti_core/models/edges/edge_db_queries.py b/graphiti_core/models/edges/edge_db_queries.py index 47b6f04f..ed1a19e5 100644 --- a/graphiti_core/models/edges/edge_db_queries.py +++ b/graphiti_core/models/edges/edge_db_queries.py @@ -14,43 +14,117 @@ See the License for the specific language governing permissions and limitations under the License. 
""" +from graphiti_core.driver.driver import GraphProvider + EPISODIC_EDGE_SAVE = """ - MATCH (episode:Episodic {uuid: $episode_uuid}) - MATCH (node:Entity {uuid: $entity_uuid}) - MERGE (episode)-[r:MENTIONS {uuid: $uuid}]->(node) - SET r = {uuid: $uuid, group_id: $group_id, created_at: $created_at} - RETURN r.uuid AS uuid""" + MATCH (episode:Episodic {uuid: $episode_uuid}) + MATCH (node:Entity {uuid: $entity_uuid}) + MERGE (episode)-[e:MENTIONS {uuid: $uuid}]->(node) + SET e = {uuid: $uuid, group_id: $group_id, created_at: $created_at} + RETURN e.uuid AS uuid +""" EPISODIC_EDGE_SAVE_BULK = """ UNWIND $episodic_edges AS edge - MATCH (episode:Episodic {uuid: edge.source_node_uuid}) - MATCH (node:Entity {uuid: edge.target_node_uuid}) - MERGE (episode)-[r:MENTIONS {uuid: edge.uuid}]->(node) - SET r = {uuid: edge.uuid, group_id: edge.group_id, created_at: edge.created_at} - RETURN r.uuid AS uuid + MATCH (episode:Episodic {uuid: edge.source_node_uuid}) + MATCH (node:Entity {uuid: edge.target_node_uuid}) + MERGE (episode)-[e:MENTIONS {uuid: edge.uuid}]->(node) + SET e = {uuid: edge.uuid, group_id: edge.group_id, created_at: edge.created_at} + RETURN e.uuid AS uuid """ -ENTITY_EDGE_SAVE = """ - MATCH (source:Entity {uuid: $edge_data.source_uuid}) - MATCH (target:Entity {uuid: $edge_data.target_uuid}) - MERGE (source)-[r:RELATES_TO {uuid: $edge_data.uuid}]->(target) - SET r = $edge_data - WITH r CALL db.create.setRelationshipVectorProperty(r, "fact_embedding", $edge_data.fact_embedding) - RETURN r.uuid AS uuid""" - -ENTITY_EDGE_SAVE_BULK = """ - UNWIND $entity_edges AS edge - MATCH (source:Entity {uuid: edge.source_node_uuid}) - MATCH (target:Entity {uuid: edge.target_node_uuid}) - MERGE (source)-[r:RELATES_TO {uuid: edge.uuid}]->(target) - SET r = edge - WITH r, edge CALL db.create.setRelationshipVectorProperty(r, "fact_embedding", edge.fact_embedding) - RETURN edge.uuid AS uuid +EPISODIC_EDGE_RETURN = """ + e.uuid AS uuid, + e.group_id AS group_id, + n.uuid AS source_node_uuid, + m.uuid AS target_node_uuid, + e.created_at AS created_at """ -COMMUNITY_EDGE_SAVE = """ - MATCH (community:Community {uuid: $community_uuid}) - MATCH (node:Entity | Community {uuid: $entity_uuid}) - MERGE (community)-[r:HAS_MEMBER {uuid: $uuid}]->(node) - SET r = {uuid: $uuid, group_id: $group_id, created_at: $created_at} - RETURN r.uuid AS uuid""" + +def get_entity_edge_save_query(provider: GraphProvider) -> str: + if provider == GraphProvider.FALKORDB: + return """ + MATCH (source:Entity {uuid: $edge_data.source_uuid}) + MATCH (target:Entity {uuid: $edge_data.target_uuid}) + MERGE (source)-[e:RELATES_TO {uuid: $edge_data.uuid}]->(target) + SET e = $edge_data + RETURN e.uuid AS uuid + """ + + return """ + MATCH (source:Entity {uuid: $edge_data.source_uuid}) + MATCH (target:Entity {uuid: $edge_data.target_uuid}) + MERGE (source)-[e:RELATES_TO {uuid: $edge_data.uuid}]->(target) + SET e = $edge_data + WITH e CALL db.create.setRelationshipVectorProperty(e, "fact_embedding", $edge_data.fact_embedding) + RETURN e.uuid AS uuid + """ + + +def get_entity_edge_save_bulk_query(provider: GraphProvider) -> str: + if provider == GraphProvider.FALKORDB: + return """ + UNWIND $entity_edges AS edge + MATCH (source:Entity {uuid: edge.source_node_uuid}) + MATCH (target:Entity {uuid: edge.target_node_uuid}) + MERGE (source)-[r:RELATES_TO {uuid: edge.uuid}]->(target) + SET r = {uuid: edge.uuid, name: edge.name, group_id: edge.group_id, fact: edge.fact, episodes: edge.episodes, + created_at: edge.created_at, expired_at: edge.expired_at, 
valid_at: edge.valid_at, invalid_at: edge.invalid_at, fact_embedding: vecf32(edge.fact_embedding)} + WITH r, edge + RETURN edge.uuid AS uuid + """ + + return """ + UNWIND $entity_edges AS edge + MATCH (source:Entity {uuid: edge.source_node_uuid}) + MATCH (target:Entity {uuid: edge.target_node_uuid}) + MERGE (source)-[e:RELATES_TO {uuid: edge.uuid}]->(target) + SET e = edge + WITH e, edge CALL db.create.setRelationshipVectorProperty(e, "fact_embedding", edge.fact_embedding) + RETURN edge.uuid AS uuid + """ + + +ENTITY_EDGE_RETURN = """ + e.uuid AS uuid, + n.uuid AS source_node_uuid, + m.uuid AS target_node_uuid, + e.group_id AS group_id, + e.name AS name, + e.fact AS fact, + e.episodes AS episodes, + e.created_at AS created_at, + e.expired_at AS expired_at, + e.valid_at AS valid_at, + e.invalid_at AS invalid_at, + properties(e) AS attributes +""" + + +def get_community_edge_save_query(provider: GraphProvider) -> str: + if provider == GraphProvider.FALKORDB: + return """ + MATCH (community:Community {uuid: $community_uuid}) + MATCH (node {uuid: $entity_uuid}) + MERGE (community)-[e:HAS_MEMBER {uuid: $uuid}]->(node) + SET e = {uuid: $uuid, group_id: $group_id, created_at: $created_at} + RETURN e.uuid AS uuid + """ + + return """ + MATCH (community:Community {uuid: $community_uuid}) + MATCH (node:Entity | Community {uuid: $entity_uuid}) + MERGE (community)-[e:HAS_MEMBER {uuid: $uuid}]->(node) + SET e = {uuid: $uuid, group_id: $group_id, created_at: $created_at} + RETURN e.uuid AS uuid + """ + + +COMMUNITY_EDGE_RETURN = """ + e.uuid AS uuid, + e.group_id AS group_id, + n.uuid AS source_node_uuid, + m.uuid AS target_node_uuid, + e.created_at AS created_at +""" diff --git a/graphiti_core/models/nodes/node_db_queries.py b/graphiti_core/models/nodes/node_db_queries.py index 5f361e07..fdcf48f1 100644 --- a/graphiti_core/models/nodes/node_db_queries.py +++ b/graphiti_core/models/nodes/node_db_queries.py @@ -14,39 +14,120 @@ See the License for the specific language governing permissions and limitations under the License. 
""" +from typing import Any + +from graphiti_core.driver.driver import GraphProvider + EPISODIC_NODE_SAVE = """ - MERGE (n:Episodic {uuid: $uuid}) - SET n = {uuid: $uuid, name: $name, group_id: $group_id, source_description: $source_description, source: $source, content: $content, - entity_edges: $entity_edges, created_at: $created_at, valid_at: $valid_at} - RETURN n.uuid AS uuid""" + MERGE (n:Episodic {uuid: $uuid}) + SET n = {uuid: $uuid, name: $name, group_id: $group_id, source_description: $source_description, source: $source, content: $content, + entity_edges: $entity_edges, created_at: $created_at, valid_at: $valid_at} + RETURN n.uuid AS uuid +""" EPISODIC_NODE_SAVE_BULK = """ UNWIND $episodes AS episode MERGE (n:Episodic {uuid: episode.uuid}) - SET n = {uuid: episode.uuid, name: episode.name, group_id: episode.group_id, source_description: episode.source_description, - source: episode.source, content: episode.content, + SET n = {uuid: episode.uuid, name: episode.name, group_id: episode.group_id, source_description: episode.source_description, + source: episode.source, content: episode.content, entity_edges: episode.entity_edges, created_at: episode.created_at, valid_at: episode.valid_at} RETURN n.uuid AS uuid """ -ENTITY_NODE_SAVE = """ - MERGE (n:Entity {uuid: $entity_data.uuid}) - SET n:$($labels) - SET n = $entity_data - WITH n CALL db.create.setNodeVectorProperty(n, "name_embedding", $entity_data.name_embedding) - RETURN n.uuid AS uuid""" - -ENTITY_NODE_SAVE_BULK = """ - UNWIND $nodes AS node - MERGE (n:Entity {uuid: node.uuid}) - SET n:$(node.labels) - SET n = node - WITH n, node CALL db.create.setNodeVectorProperty(n, "name_embedding", node.name_embedding) - RETURN n.uuid AS uuid +EPISODIC_NODE_RETURN = """ + e.content AS content, + e.created_at AS created_at, + e.valid_at AS valid_at, + e.uuid AS uuid, + e.name AS name, + e.group_id AS group_id, + e.source_description AS source_description, + e.source AS source, + e.entity_edges AS entity_edges """ -COMMUNITY_NODE_SAVE = """ + +def get_entity_node_save_query(provider: GraphProvider, labels: str) -> str: + if provider == GraphProvider.FALKORDB: + return f""" + MERGE (n:Entity {{uuid: $entity_data.uuid}}) + SET n:{labels} + SET n = $entity_data + RETURN n.uuid AS uuid + """ + + return f""" + MERGE (n:Entity {{uuid: $entity_data.uuid}}) + SET n:{labels} + SET n = $entity_data + WITH n CALL db.create.setNodeVectorProperty(n, "name_embedding", $entity_data.name_embedding) + RETURN n.uuid AS uuid + """ + + +def get_entity_node_save_bulk_query(provider: GraphProvider, nodes: list[dict]) -> str | Any: + if provider == GraphProvider.FALKORDB: + queries = [] + for node in nodes: + for label in node['labels']: + queries.append( + ( + f""" + UNWIND $nodes AS node + MERGE (n:Entity {{uuid: node.uuid}}) + SET n:{label} + SET n = node + WITH n, node + SET n.name_embedding = vecf32(node.name_embedding) + RETURN n.uuid AS uuid + """, + {'nodes': [node]}, + ) + ) + return queries + + return """ + UNWIND $nodes AS node + MERGE (n:Entity {uuid: node.uuid}) + SET n:$(node.labels) + SET n = node + WITH n, node CALL db.create.setNodeVectorProperty(n, "name_embedding", node.name_embedding) + RETURN n.uuid AS uuid + """ + + +ENTITY_NODE_RETURN = """ + n.uuid AS uuid, + n.name AS name, + n.group_id AS group_id, + n.created_at AS created_at, + n.summary AS summary, + labels(n) AS labels, + properties(n) AS attributes +""" + + +def get_community_node_save_query(provider: GraphProvider) -> str: + if provider == GraphProvider.FALKORDB: + return """ + 
MERGE (n:Community {uuid: $uuid}) + SET n = {uuid: $uuid, name: $name, group_id: $group_id, summary: $summary, created_at: $created_at, name_embedding: $name_embedding} + RETURN n.uuid AS uuid + """ + + return """ MERGE (n:Community {uuid: $uuid}) SET n = {uuid: $uuid, name: $name, group_id: $group_id, summary: $summary, created_at: $created_at} WITH n CALL db.create.setNodeVectorProperty(n, "name_embedding", $name_embedding) - RETURN n.uuid AS uuid""" + RETURN n.uuid AS uuid + """ + + +COMMUNITY_NODE_RETURN = """ + n.uuid AS uuid, + n.name AS name, + n.name_embedding AS name_embedding, + n.group_id AS group_id, + n.summary AS summary, + n.created_at AS created_at +""" diff --git a/graphiti_core/nodes.py b/graphiti_core/nodes.py index d2fb79ca..bad8fefa 100644 --- a/graphiti_core/nodes.py +++ b/graphiti_core/nodes.py @@ -25,29 +25,22 @@ from uuid import uuid4 from pydantic import BaseModel, Field from typing_extensions import LiteralString -from graphiti_core.driver.driver import GraphDriver +from graphiti_core.driver.driver import GraphDriver, GraphProvider from graphiti_core.embedder import EmbedderClient from graphiti_core.errors import NodeNotFoundError from graphiti_core.helpers import parse_db_date from graphiti_core.models.nodes.node_db_queries import ( - COMMUNITY_NODE_SAVE, - ENTITY_NODE_SAVE, + COMMUNITY_NODE_RETURN, + ENTITY_NODE_RETURN, + EPISODIC_NODE_RETURN, EPISODIC_NODE_SAVE, + get_community_node_save_query, + get_entity_node_save_query, ) from graphiti_core.utils.datetime_utils import utc_now logger = logging.getLogger(__name__) -ENTITY_NODE_RETURN: LiteralString = """ - RETURN - n.uuid As uuid, - n.name AS name, - n.group_id AS group_id, - n.created_at AS created_at, - n.summary AS summary, - labels(n) AS labels, - properties(n) AS attributes""" - class EpisodeType(Enum): """ @@ -96,18 +89,26 @@ class Node(BaseModel, ABC): async def save(self, driver: GraphDriver): ... async def delete(self, driver: GraphDriver): - result = await driver.execute_query( - """ - MATCH (n:Entity|Episodic|Community {uuid: $uuid}) - DETACH DELETE n - """, - uuid=self.uuid, - ) + if driver.provider == GraphProvider.FALKORDB: + for label in ['Entity', 'Episodic', 'Community']: + await driver.execute_query( + f""" + MATCH (n:{label} {{uuid: $uuid}}) + DETACH DELETE n + """, + uuid=self.uuid, + ) + else: + await driver.execute_query( + """ + MATCH (n:Entity|Episodic|Community {uuid: $uuid}) + DETACH DELETE n + """, + uuid=self.uuid, + ) logger.debug(f'Deleted Node: {self.uuid}') - return result - def __hash__(self): return hash(self.uuid) @@ -118,15 +119,23 @@ class Node(BaseModel, ABC): @classmethod async def delete_by_group_id(cls, driver: GraphDriver, group_id: str): - await driver.execute_query( - """ - MATCH (n:Entity|Episodic|Community {group_id: $group_id}) - DETACH DELETE n - """, - group_id=group_id, - ) - - return 'SUCCESS' + if driver.provider == GraphProvider.FALKORDB: + for label in ['Entity', 'Episodic', 'Community']: + await driver.execute_query( + f""" + MATCH (n:{label} {{group_id: $group_id}}) + DETACH DELETE n + """, + group_id=group_id, + ) + else: + await driver.execute_query( + """ + MATCH (n:Entity|Episodic|Community {group_id: $group_id}) + DETACH DELETE n + """, + group_id=group_id, + ) @classmethod async def get_by_uuid(cls, driver: GraphDriver, uuid: str): ... 
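# --- editor's example (not part of the diff) --------------------------------
# Sketch of the provider-specific query builders used below: EntityNode.save()
# now joins the node's labels into a single string and asks node_db_queries for
# a statement tailored to the backend (the FalkorDB variant skips the Neo4j-only
# db.create.setNodeVectorProperty call). Assumes graphiti-core with this change
# is importable.
from graphiti_core.driver.driver import GraphProvider
from graphiti_core.models.nodes.node_db_queries import get_entity_node_save_query

labels = ':'.join(['Person', 'Entity'])  # e.g. 'Person:Entity'
print(get_entity_node_save_query(GraphProvider.FALKORDB, labels))
print(get_entity_node_save_query(GraphProvider.NEO4J, labels))
# -----------------------------------------------------------------------------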
@@ -169,17 +178,10 @@ class EpisodicNode(Node): async def get_by_uuid(cls, driver: GraphDriver, uuid: str): records, _, _ = await driver.execute_query( """ - MATCH (e:Episodic {uuid: $uuid}) - RETURN e.content AS content, - e.created_at AS created_at, - e.valid_at AS valid_at, - e.uuid AS uuid, - e.name AS name, - e.group_id AS group_id, - e.source_description AS source_description, - e.source AS source, - e.entity_edges AS entity_edges - """, + MATCH (e:Episodic {uuid: $uuid}) + RETURN + """ + + EPISODIC_NODE_RETURN, uuid=uuid, routing_='r', ) @@ -195,18 +197,11 @@ class EpisodicNode(Node): async def get_by_uuids(cls, driver: GraphDriver, uuids: list[str]): records, _, _ = await driver.execute_query( """ - MATCH (e:Episodic) WHERE e.uuid IN $uuids + MATCH (e:Episodic) + WHERE e.uuid IN $uuids RETURN DISTINCT - e.content AS content, - e.created_at AS created_at, - e.valid_at AS valid_at, - e.uuid AS uuid, - e.name AS name, - e.group_id AS group_id, - e.source_description AS source_description, - e.source AS source, - e.entity_edges AS entity_edges - """, + """ + + EPISODIC_NODE_RETURN, uuids=uuids, routing_='r', ) @@ -228,22 +223,17 @@ class EpisodicNode(Node): records, _, _ = await driver.execute_query( """ - MATCH (e:Episodic) WHERE e.group_id IN $group_ids - """ + MATCH (e:Episodic) + WHERE e.group_id IN $group_ids + """ + cursor_query + """ RETURN DISTINCT - e.content AS content, - e.created_at AS created_at, - e.valid_at AS valid_at, - e.uuid AS uuid, - e.name AS name, - e.group_id AS group_id, - e.source_description AS source_description, - e.source AS source, - e.entity_edges AS entity_edges - ORDER BY e.uuid DESC - """ + """ + + EPISODIC_NODE_RETURN + + """ + ORDER BY uuid DESC + """ + limit_query, group_ids=group_ids, uuid=uuid_cursor, @@ -259,18 +249,10 @@ class EpisodicNode(Node): async def get_by_entity_node_uuid(cls, driver: GraphDriver, entity_node_uuid: str): records, _, _ = await driver.execute_query( """ - MATCH (e:Episodic)-[r:MENTIONS]->(n:Entity {uuid: $entity_node_uuid}) + MATCH (e:Episodic)-[r:MENTIONS]->(n:Entity {uuid: $entity_node_uuid}) RETURN DISTINCT - e.content AS content, - e.created_at AS created_at, - e.valid_at AS valid_at, - e.uuid AS uuid, - e.name AS name, - e.group_id AS group_id, - e.source_description AS source_description, - e.source AS source, - e.entity_edges AS entity_edges - """, + """ + + EPISODIC_NODE_RETURN, entity_node_uuid=entity_node_uuid, routing_='r', ) @@ -297,11 +279,14 @@ class EntityNode(Node): return self.name_embedding async def load_name_embedding(self, driver: GraphDriver): - query: LiteralString = """ + records, _, _ = await driver.execute_query( + """ MATCH (n:Entity {uuid: $uuid}) RETURN n.name_embedding AS name_embedding - """ - records, _, _ = await driver.execute_query(query, uuid=self.uuid, routing_='r') + """, + uuid=self.uuid, + routing_='r', + ) if len(records) == 0: raise NodeNotFoundError(self.uuid) @@ -317,12 +302,12 @@ class EntityNode(Node): 'summary': self.summary, 'created_at': self.created_at, } - entity_data.update(self.attributes or {}) + labels = ':'.join(self.labels + ['Entity']) + result = await driver.execute_query( - ENTITY_NODE_SAVE, - labels=self.labels + ['Entity'], + get_entity_node_save_query(driver.provider, labels), entity_data=entity_data, ) @@ -332,14 +317,12 @@ class EntityNode(Node): @classmethod async def get_by_uuid(cls, driver: GraphDriver, uuid: str): - query = ( - """ - MATCH (n:Entity {uuid: $uuid}) - """ - + ENTITY_NODE_RETURN - ) records, _, _ = await driver.execute_query( - query, + """ 
+ MATCH (n:Entity {uuid: $uuid}) + RETURN + """ + + ENTITY_NODE_RETURN, uuid=uuid, routing_='r', ) @@ -355,8 +338,10 @@ class EntityNode(Node): async def get_by_uuids(cls, driver: GraphDriver, uuids: list[str]): records, _, _ = await driver.execute_query( """ - MATCH (n:Entity) WHERE n.uuid IN $uuids - """ + MATCH (n:Entity) + WHERE n.uuid IN $uuids + RETURN + """ + ENTITY_NODE_RETURN, uuids=uuids, routing_='r', @@ -379,22 +364,26 @@ class EntityNode(Node): limit_query: LiteralString = 'LIMIT $limit' if limit is not None else '' with_embeddings_query: LiteralString = ( """, - n.name_embedding AS name_embedding - """ + n.name_embedding AS name_embedding + """ if with_embeddings else '' ) records, _, _ = await driver.execute_query( """ - MATCH (n:Entity) WHERE n.group_id IN $group_ids - """ + MATCH (n:Entity) + WHERE n.group_id IN $group_ids + """ + cursor_query + + """ + RETURN + """ + ENTITY_NODE_RETURN + with_embeddings_query + """ - ORDER BY n.uuid DESC - """ + ORDER BY n.uuid DESC + """ + limit_query, group_ids=group_ids, uuid=uuid_cursor, @@ -413,7 +402,7 @@ class CommunityNode(Node): async def save(self, driver: GraphDriver): result = await driver.execute_query( - COMMUNITY_NODE_SAVE, + get_community_node_save_query(driver.provider), uuid=self.uuid, name=self.name, group_id=self.group_id, @@ -436,11 +425,14 @@ class CommunityNode(Node): return self.name_embedding async def load_name_embedding(self, driver: GraphDriver): - query: LiteralString = """ + records, _, _ = await driver.execute_query( + """ MATCH (c:Community {uuid: $uuid}) RETURN c.name_embedding AS name_embedding - """ - records, _, _ = await driver.execute_query(query, uuid=self.uuid, routing_='r') + """, + uuid=self.uuid, + routing_='r', + ) if len(records) == 0: raise NodeNotFoundError(self.uuid) @@ -451,14 +443,10 @@ class CommunityNode(Node): async def get_by_uuid(cls, driver: GraphDriver, uuid: str): records, _, _ = await driver.execute_query( """ - MATCH (n:Community {uuid: $uuid}) - RETURN - n.uuid As uuid, - n.name AS name, - n.group_id AS group_id, - n.created_at AS created_at, - n.summary AS summary - """, + MATCH (n:Community {uuid: $uuid}) + RETURN + """ + + COMMUNITY_NODE_RETURN, uuid=uuid, routing_='r', ) @@ -474,14 +462,11 @@ class CommunityNode(Node): async def get_by_uuids(cls, driver: GraphDriver, uuids: list[str]): records, _, _ = await driver.execute_query( """ - MATCH (n:Community) WHERE n.uuid IN $uuids - RETURN - n.uuid As uuid, - n.name AS name, - n.group_id AS group_id, - n.created_at AS created_at, - n.summary AS summary - """, + MATCH (n:Community) + WHERE n.uuid IN $uuids + RETURN + """ + + COMMUNITY_NODE_RETURN, uuids=uuids, routing_='r', ) @@ -503,18 +488,17 @@ class CommunityNode(Node): records, _, _ = await driver.execute_query( """ - MATCH (n:Community) WHERE n.group_id IN $group_ids - """ + MATCH (n:Community) + WHERE n.group_id IN $group_ids + """ + cursor_query + """ - RETURN - n.uuid As uuid, - n.name AS name, - n.group_id AS group_id, - n.created_at AS created_at, - n.summary AS summary - ORDER BY n.uuid DESC - """ + RETURN + """ + + COMMUNITY_NODE_RETURN + + """ + ORDER BY n.uuid DESC + """ + limit_query, group_ids=group_ids, uuid=uuid_cursor, @@ -586,6 +570,7 @@ def get_community_node_from_record(record: Any) -> CommunityNode: async def create_entity_node_embeddings(embedder: EmbedderClient, nodes: list[EntityNode]): if not nodes: # Handle empty list case return + name_embeddings = await embedder.create_batch([node.name for node in nodes]) for node, name_embedding in zip(nodes, 
name_embeddings, strict=True): node.name_embedding = name_embedding diff --git a/graphiti_core/prompts/dedupe_nodes.py b/graphiti_core/prompts/dedupe_nodes.py index 53223d2d..836b8934 100644 --- a/graphiti_core/prompts/dedupe_nodes.py +++ b/graphiti_core/prompts/dedupe_nodes.py @@ -34,7 +34,7 @@ class NodeDuplicate(BaseModel): ) duplicates: list[int] = Field( ..., - description='idx of all duplicate entities.', + description='idx of all entities that are a duplicate of the entity with the above id.', ) diff --git a/graphiti_core/prompts/extract_edges.py b/graphiti_core/prompts/extract_edges.py index 50e039a4..0d002fec 100644 --- a/graphiti_core/prompts/extract_edges.py +++ b/graphiti_core/prompts/extract_edges.py @@ -68,6 +68,10 @@ def edge(context: dict[str, Any]) -> list[Message]: Message( role='user', content=f""" + +{context['edge_types']} + + {json.dumps([ep for ep in context['previous_episodes']], indent=2)} @@ -84,10 +88,6 @@ def edge(context: dict[str, Any]) -> list[Message]: {context['reference_time']} # ISO 8601 (UTC); used to resolve relative time mentions - -{context['edge_types']} - - # TASK Extract all factual relationships between the given ENTITIES based on the CURRENT MESSAGE. Only extract facts that: diff --git a/graphiti_core/prompts/extract_nodes.py b/graphiti_core/prompts/extract_nodes.py index 6848bb7b..57a2a577 100644 --- a/graphiti_core/prompts/extract_nodes.py +++ b/graphiti_core/prompts/extract_nodes.py @@ -52,6 +52,13 @@ class EntityClassification(BaseModel): ) +class EntitySummary(BaseModel): + summary: str = Field( + ..., + description='Summary containing the important information about the entity. Under 250 words', + ) + + class Prompt(Protocol): extract_message: PromptVersion extract_json: PromptVersion @@ -59,6 +66,7 @@ class Prompt(Protocol): reflexion: PromptVersion classify_nodes: PromptVersion extract_attributes: PromptVersion + extract_summary: PromptVersion class Versions(TypedDict): @@ -68,6 +76,7 @@ class Versions(TypedDict): reflexion: PromptFunction classify_nodes: PromptFunction extract_attributes: PromptFunction + extract_summary: PromptFunction def extract_message(context: dict[str, Any]) -> list[Message]: @@ -75,6 +84,10 @@ def extract_message(context: dict[str, Any]) -> list[Message]: Your primary task is to extract and classify the speaker and other significant entities mentioned in the conversation.""" user_prompt = f""" + +{context['entity_types']} + + {json.dumps([ep for ep in context['previous_episodes']], indent=2)} @@ -83,10 +96,6 @@ def extract_message(context: dict[str, Any]) -> list[Message]: {context['episode_content']} - -{context['entity_types']} - - Instructions: You are given a conversation context and a CURRENT MESSAGE. Your task is to extract **entity nodes** mentioned **explicitly or implicitly** in the CURRENT MESSAGE. 
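# --- editor's example (not part of the diff) --------------------------------
# Sketch of the structured output the new summary-extraction flow targets:
# EntitySummary (defined above) is a plain Pydantic model, and the matching
# extract_summary prompt is registered in this module's `versions` dict further
# down. The raw dict here stands in for a parsed LLM response.
from graphiti_core.prompts.extract_nodes import EntitySummary

raw = {'summary': 'Alice is a software engineer who recently moved to Paris.'}
entity_summary = EntitySummary(**raw)
print(entity_summary.summary)
# -----------------------------------------------------------------------------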
@@ -124,15 +133,16 @@ def extract_json(context: dict[str, Any]) -> list[Message]: Your primary task is to extract and classify relevant entities from JSON files""" user_prompt = f""" + +{context['entity_types']} + + : {context['source_description']} {context['episode_content']} - -{context['entity_types']} - {context['custom_prompt']} @@ -155,13 +165,14 @@ def extract_text(context: dict[str, Any]) -> list[Message]: Your primary task is to extract and classify the speaker and other significant entities mentioned in the provided text.""" user_prompt = f""" - -{context['episode_content']} - {context['entity_types']} + +{context['episode_content']} + + Given the above text, extract entities from the TEXT that are explicitly or implicitly mentioned. For each entity extracted, also determine its entity type based on the provided ENTITY TYPES and their descriptions. Indicate the classified entity type by providing its entity_type_id. @@ -257,9 +268,39 @@ def extract_attributes(context: dict[str, Any]) -> list[Message]: Guidelines: 1. Do not hallucinate entity property values if they cannot be found in the current context. 2. Only use the provided MESSAGES and ENTITY to set attribute values. + + + {context['node']} + + """, + ), + ] + + +def extract_summary(context: dict[str, Any]) -> list[Message]: + return [ + Message( + role='system', + content='You are a helpful assistant that extracts entity summaries from the provided text.', + ), + Message( + role='user', + content=f""" + + + {json.dumps(context['previous_episodes'], indent=2)} + {json.dumps(context['episode_content'], indent=2)} + + + Given the above MESSAGES and the following ENTITY, update the summary that combines relevant information about the entity + from the messages and relevant information from the existing summary. + + Guidelines: + 1. Do not hallucinate entity summary information if they cannot be found in the current context. + 2. Only use the provided MESSAGES and ENTITY to set attribute values. 3. The summary attribute represents a summary of the ENTITY, and should be updated with new information about the Entity from the MESSAGES. Summaries must be no longer than 250 words. 
- + {context['node']} @@ -273,6 +314,7 @@ versions: Versions = { 'extract_json': extract_json, 'extract_text': extract_text, 'reflexion': reflexion, + 'extract_summary': extract_summary, 'classify_nodes': classify_nodes, 'extract_attributes': extract_attributes, } diff --git a/graphiti_core/search/search_filters.py b/graphiti_core/search/search_filters.py index 40c12576..90abb766 100644 --- a/graphiti_core/search/search_filters.py +++ b/graphiti_core/search/search_filters.py @@ -72,7 +72,7 @@ def edge_search_filter_query_constructor( if filters.edge_types is not None: edge_types = filters.edge_types - edge_types_filter = '\nAND r.name in $edge_types' + edge_types_filter = '\nAND e.name in $edge_types' filter_query += edge_types_filter filter_params['edge_types'] = edge_types @@ -88,7 +88,7 @@ def edge_search_filter_query_constructor( filter_params['valid_at_' + str(j)] = date_filter.date and_filters = [ - '(r.valid_at ' + date_filter.comparison_operator.value + f' $valid_at_{j})' + '(e.valid_at ' + date_filter.comparison_operator.value + f' $valid_at_{j})' for j, date_filter in enumerate(or_list) ] and_filter_query = '' @@ -113,7 +113,7 @@ def edge_search_filter_query_constructor( filter_params['invalid_at_' + str(j)] = date_filter.date and_filters = [ - '(r.invalid_at ' + date_filter.comparison_operator.value + f' $invalid_at_{j})' + '(e.invalid_at ' + date_filter.comparison_operator.value + f' $invalid_at_{j})' for j, date_filter in enumerate(or_list) ] and_filter_query = '' @@ -138,7 +138,7 @@ def edge_search_filter_query_constructor( filter_params['created_at_' + str(j)] = date_filter.date and_filters = [ - '(r.created_at ' + date_filter.comparison_operator.value + f' $created_at_{j})' + '(e.created_at ' + date_filter.comparison_operator.value + f' $created_at_{j})' for j, date_filter in enumerate(or_list) ] and_filter_query = '' @@ -163,7 +163,7 @@ def edge_search_filter_query_constructor( filter_params['expired_at_' + str(j)] = date_filter.date and_filters = [ - '(r.expired_at ' + date_filter.comparison_operator.value + f' $expired_at_{j})' + '(e.expired_at ' + date_filter.comparison_operator.value + f' $expired_at_{j})' for j, date_filter in enumerate(or_list) ] and_filter_query = '' diff --git a/graphiti_core/search/search_utils.py b/graphiti_core/search/search_utils.py index 1ac67396..e358ba0e 100644 --- a/graphiti_core/search/search_utils.py +++ b/graphiti_core/search/search_utils.py @@ -23,7 +23,7 @@ import numpy as np from numpy._typing import NDArray from typing_extensions import LiteralString -from graphiti_core.driver.driver import GraphDriver +from graphiti_core.driver.driver import GraphDriver, GraphProvider from graphiti_core.edges import EntityEdge, get_entity_edge_from_record from graphiti_core.graph_queries import ( get_nodes_query, @@ -35,6 +35,8 @@ from graphiti_core.helpers import ( normalize_l2, semaphore_gather, ) +from graphiti_core.models.edges.edge_db_queries import ENTITY_EDGE_RETURN +from graphiti_core.models.nodes.node_db_queries import COMMUNITY_NODE_RETURN, EPISODIC_NODE_RETURN from graphiti_core.nodes import ( ENTITY_NODE_RETURN, CommunityNode, @@ -100,20 +102,13 @@ async def get_mentioned_nodes( ) -> list[EntityNode]: episode_uuids = [episode.uuid for episode in episodes] - query = """ - MATCH (episode:Episodic)-[:MENTIONS]->(n:Entity) WHERE episode.uuid IN $uuids - RETURN DISTINCT - n.uuid As uuid, - n.group_id AS group_id, - n.name AS name, - n.created_at AS created_at, - n.summary AS summary, - labels(n) AS labels, - properties(n) AS attributes - """ 
- records, _, _ = await driver.execute_query( - query, + """ + MATCH (episode:Episodic)-[:MENTIONS]->(n:Entity) + WHERE episode.uuid IN $uuids + RETURN DISTINCT + """ + + ENTITY_NODE_RETURN, uuids=episode_uuids, routing_='r', ) @@ -128,18 +123,13 @@ async def get_communities_by_nodes( ) -> list[CommunityNode]: node_uuids = [node.uuid for node in nodes] - query = """ - MATCH (c:Community)-[:HAS_MEMBER]->(n:Entity) WHERE n.uuid IN $uuids - RETURN DISTINCT - c.uuid As uuid, - c.group_id AS group_id, - c.name AS name, - c.created_at AS created_at, - c.summary AS summary - """ - records, _, _ = await driver.execute_query( - query, + """ + MATCH (n:Community)-[:HAS_MEMBER]->(m:Entity) + WHERE m.uuid IN $uuids + RETURN DISTINCT + """ + + COMMUNITY_NODE_RETURN, uuids=node_uuids, routing_='r', ) @@ -164,38 +154,30 @@ async def edge_fulltext_search( filter_query, filter_params = edge_search_filter_query_constructor(search_filter) query = ( - get_relationships_query('edge_name_and_fact', db_type=driver.provider) + get_relationships_query('edge_name_and_fact', provider=driver.provider) + """ YIELD relationship AS rel, score - MATCH (n:Entity)-[r:RELATES_TO {uuid: rel.uuid}]->(m:Entity) - WHERE r.group_id IN $group_ids """ + MATCH (n:Entity)-[e:RELATES_TO {uuid: rel.uuid}]->(m:Entity) + WHERE e.group_id IN $group_ids """ + filter_query + """ - WITH r, score, startNode(r) AS n, endNode(r) AS m + WITH e, score, n, m RETURN - r.uuid AS uuid, - r.group_id AS group_id, - n.uuid AS source_node_uuid, - m.uuid AS target_node_uuid, - r.created_at AS created_at, - r.name AS name, - r.fact AS fact, - r.episodes AS episodes, - r.expired_at AS expired_at, - r.valid_at AS valid_at, - r.invalid_at AS invalid_at, - properties(r) AS attributes - ORDER BY score DESC LIMIT $limit + """ + + ENTITY_EDGE_RETURN + + """ + ORDER BY score DESC + LIMIT $limit """ ) records, _, _ = await driver.execute_query( query, - params=filter_params, query=fuzzy_query, group_ids=group_ids, limit=limit, routing_='r', + **filter_params, ) edges = [get_entity_edge_from_record(record) for record in records] @@ -219,58 +201,47 @@ async def edge_similarity_search( filter_query, filter_params = edge_search_filter_query_constructor(search_filter) query_params.update(filter_params) - group_filter_query: LiteralString = 'WHERE r.group_id IS NOT NULL' + group_filter_query: LiteralString = 'WHERE e.group_id IS NOT NULL' if group_ids is not None: - group_filter_query += '\nAND r.group_id IN $group_ids' + group_filter_query += '\nAND e.group_id IN $group_ids' query_params['group_ids'] = group_ids - query_params['source_node_uuid'] = source_node_uuid - query_params['target_node_uuid'] = target_node_uuid if source_node_uuid is not None: - group_filter_query += '\nAND (n.uuid IN [$source_uuid, $target_uuid])' + query_params['source_uuid'] = source_node_uuid + group_filter_query += '\nAND (n.uuid = $source_uuid)' if target_node_uuid is not None: - group_filter_query += '\nAND (m.uuid IN [$source_uuid, $target_uuid])' + query_params['target_uuid'] = target_node_uuid + group_filter_query += '\nAND (m.uuid = $target_uuid)' query = ( RUNTIME_QUERY + """ - MATCH (n:Entity)-[r:RELATES_TO]->(m:Entity) + MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity) """ + group_filter_query + filter_query + """ - WITH DISTINCT r, """ - + get_vector_cosine_func_query('r.fact_embedding', '$search_vector', driver.provider) + WITH DISTINCT e, n, m, """ + + get_vector_cosine_func_query('e.fact_embedding', '$search_vector', driver.provider) + """ AS score WHERE score > $min_score RETURN - 
r.uuid AS uuid, - r.group_id AS group_id, - startNode(r).uuid AS source_node_uuid, - endNode(r).uuid AS target_node_uuid, - r.created_at AS created_at, - r.name AS name, - r.fact AS fact, - r.episodes AS episodes, - r.expired_at AS expired_at, - r.valid_at AS valid_at, - r.invalid_at AS invalid_at, - properties(r) AS attributes + """ + + ENTITY_EDGE_RETURN + + """ ORDER BY score DESC LIMIT $limit """ ) - records, header, _ = await driver.execute_query( + + records, _, _ = await driver.execute_query( query, - params=query_params, search_vector=search_vector, - source_uuid=source_node_uuid, - target_uuid=target_node_uuid, - group_ids=group_ids, limit=limit, min_score=min_score, routing_='r', + **query_params, ) edges = [get_entity_edge_from_record(record) for record in records] @@ -293,41 +264,31 @@ async def edge_bfs_search( filter_query, filter_params = edge_search_filter_query_constructor(search_filter) query = ( + f""" + UNWIND $bfs_origin_node_uuids AS origin_uuid + MATCH path = (origin:Entity|Episodic {{uuid: origin_uuid}})-[:RELATES_TO|MENTIONS*1..{bfs_max_depth}]->(:Entity) + UNWIND relationships(path) AS rel + MATCH (n:Entity)-[e:RELATES_TO]-(m:Entity) + WHERE e.uuid = rel.uuid + AND e.group_id IN $group_ids """ - UNWIND $bfs_origin_node_uuids AS origin_uuid - MATCH path = (origin:Entity|Episodic {uuid: origin_uuid})-[:RELATES_TO|MENTIONS]->{1,3}(n:Entity) - UNWIND relationships(path) AS rel - MATCH (n:Entity)-[r:RELATES_TO]-(m:Entity) - WHERE r.uuid = rel.uuid - AND r.group_id IN $group_ids - """ + filter_query - + """ - RETURN DISTINCT - r.uuid AS uuid, - r.group_id AS group_id, - startNode(r).uuid AS source_node_uuid, - endNode(r).uuid AS target_node_uuid, - r.created_at AS created_at, - r.name AS name, - r.fact AS fact, - r.episodes AS episodes, - r.expired_at AS expired_at, - r.valid_at AS valid_at, - r.invalid_at AS invalid_at, - properties(r) AS attributes - LIMIT $limit + + """ + RETURN DISTINCT + """ + + ENTITY_EDGE_RETURN + + """ + LIMIT $limit """ ) records, _, _ = await driver.execute_query( query, - params=filter_params, bfs_origin_node_uuids=bfs_origin_node_uuids, - depth=bfs_max_depth, group_ids=group_ids, limit=limit, routing_='r', + **filter_params, ) edges = [get_entity_edge_from_record(record) for record in records] @@ -352,23 +313,25 @@ async def node_fulltext_search( get_nodes_query(driver.provider, 'node_name_and_summary', '$query') + """ YIELD node AS n, score - WITH n, score - LIMIT $limit - WHERE n:Entity AND n.group_id IN $group_ids + WHERE n:Entity AND n.group_id IN $group_ids """ + filter_query - + ENTITY_NODE_RETURN + """ + WITH n, score ORDER BY score DESC + LIMIT $limit + RETURN """ + + ENTITY_NODE_RETURN ) - records, header, _ = await driver.execute_query( + + records, _, _ = await driver.execute_query( query, - params=filter_params, query=fuzzy_query, group_ids=group_ids, limit=limit, routing_='r', + **filter_params, ) nodes = [get_entity_node_from_record(record) for record in records] @@ -406,22 +369,23 @@ async def node_similarity_search( WITH n, """ + get_vector_cosine_func_query('n.name_embedding', '$search_vector', driver.provider) + """ AS score - WHERE score > $min_score""" + WHERE score > $min_score + RETURN + """ + ENTITY_NODE_RETURN + """ ORDER BY score DESC LIMIT $limit - """ + """ ) - records, header, _ = await driver.execute_query( + records, _, _ = await driver.execute_query( query, - params=query_params, search_vector=search_vector, - group_ids=group_ids, limit=limit, min_score=min_score, routing_='r', + **query_params, ) nodes = 
[get_entity_node_from_record(record) for record in records] @@ -444,26 +408,29 @@ async def node_bfs_search( filter_query, filter_params = node_search_filter_query_constructor(search_filter) query = ( + f""" + UNWIND $bfs_origin_node_uuids AS origin_uuid + MATCH (origin:Entity|Episodic {{uuid: origin_uuid}})-[:RELATES_TO|MENTIONS*1..{bfs_max_depth}]->(n:Entity) + WHERE n.group_id = origin.group_id + AND origin.group_id IN $group_ids """ - UNWIND $bfs_origin_node_uuids AS origin_uuid - MATCH (origin:Entity|Episodic {uuid: origin_uuid})-[:RELATES_TO|MENTIONS]->{1,3}(n:Entity) - WHERE n.group_id = origin.group_id - AND origin.group_id IN $group_ids - """ + filter_query + + """ + RETURN + """ + ENTITY_NODE_RETURN + """ LIMIT $limit """ ) + records, _, _ = await driver.execute_query( query, - params=filter_params, bfs_origin_node_uuids=bfs_origin_node_uuids, - depth=bfs_max_depth, group_ids=group_ids, limit=limit, routing_='r', + **filter_params, ) nodes = [get_entity_node_from_record(record) for record in records] @@ -489,16 +456,10 @@ async def episode_fulltext_search( MATCH (e:Episodic) WHERE e.uuid = episode.uuid AND e.group_id IN $group_ids - RETURN - e.content AS content, - e.created_at AS created_at, - e.valid_at AS valid_at, - e.uuid AS uuid, - e.name AS name, - e.group_id AS group_id, - e.source_description AS source_description, - e.source AS source, - e.entity_edges AS entity_edges + RETURN + """ + + EPISODIC_NODE_RETURN + + """ ORDER BY score DESC LIMIT $limit """ @@ -530,15 +491,12 @@ async def community_fulltext_search( query = ( get_nodes_query(driver.provider, 'community_name', '$query') + """ - YIELD node AS comm, score - WHERE comm.group_id IN $group_ids + YIELD node AS n, score + WHERE n.group_id IN $group_ids RETURN - comm.uuid AS uuid, - comm.group_id AS group_id, - comm.name AS name, - comm.created_at AS created_at, - comm.summary AS summary, - comm.name_embedding AS name_embedding + """ + + COMMUNITY_NODE_RETURN + + """ ORDER BY score DESC LIMIT $limit """ @@ -568,39 +526,37 @@ async def community_similarity_search( group_filter_query: LiteralString = '' if group_ids is not None: - group_filter_query += 'WHERE comm.group_id IN $group_ids' + group_filter_query += 'WHERE n.group_id IN $group_ids' query_params['group_ids'] = group_ids query = ( RUNTIME_QUERY + """ - MATCH (comm:Community) - """ + MATCH (n:Community) + """ + group_filter_query + """ - WITH comm, """ - + get_vector_cosine_func_query('comm.name_embedding', '$search_vector', driver.provider) + WITH n, + """ + + get_vector_cosine_func_query('n.name_embedding', '$search_vector', driver.provider) + """ AS score - WHERE score > $min_score - RETURN - comm.uuid As uuid, - comm.group_id AS group_id, - comm.name AS name, - comm.created_at AS created_at, - comm.summary AS summary, - comm.name_embedding AS name_embedding - ORDER BY score DESC - LIMIT $limit + WHERE score > $min_score + RETURN + """ + + COMMUNITY_NODE_RETURN + + """ + ORDER BY score DESC + LIMIT $limit """ ) records, _, _ = await driver.execute_query( query, search_vector=search_vector, - group_ids=group_ids, limit=limit, min_score=min_score, routing_='r', + **query_params, ) communities = [get_community_node_from_record(record) for record in records] @@ -719,8 +675,8 @@ async def get_relevant_nodes( WHERE m.group_id = $group_id WITH node, top_vector_nodes, vector_node_uuids, collect(m) AS fulltext_nodes - WITH node, - top_vector_nodes, + WITH node, + top_vector_nodes, [m IN fulltext_nodes WHERE NOT m.uuid IN vector_node_uuids] AS filtered_fulltext_nodes 
WITH node, top_vector_nodes + filtered_fulltext_nodes AS combined_nodes @@ -728,10 +684,10 @@ async def get_relevant_nodes( UNWIND combined_nodes AS combined_node WITH node, collect(DISTINCT combined_node) AS deduped_nodes - RETURN + RETURN node.uuid AS search_node_uuid, [x IN deduped_nodes | { - uuid: x.uuid, + uuid: x.uuid, name: x.name, name_embedding: x.name_embedding, group_id: x.group_id, @@ -755,12 +711,12 @@ async def get_relevant_nodes( results, _, _ = await driver.execute_query( query, - params=query_params, nodes=query_nodes, group_id=group_id, limit=limit, min_score=min_score, routing_='r', + **query_params, ) relevant_nodes_dict: dict[str, list[EntityNode]] = { @@ -825,11 +781,11 @@ async def get_relevant_edges( results, _, _ = await driver.execute_query( query, - params=query_params, edges=[edge.model_dump() for edge in edges], limit=limit, min_score=min_score, routing_='r', + **query_params, ) relevant_edges_dict: dict[str, list[EntityEdge]] = { @@ -895,11 +851,11 @@ async def get_edge_invalidation_candidates( results, _, _ = await driver.execute_query( query, - params=query_params, edges=[edge.model_dump() for edge in edges], limit=limit, min_score=min_score, routing_='r', + **query_params, ) invalidation_edges_dict: dict[str, list[EntityEdge]] = { result['search_edge_uuid']: [ @@ -943,18 +899,17 @@ async def node_distance_reranker( scores: dict[str, float] = {center_node_uuid: 0.0} # Find the shortest path to center node - query = """ + results, header, _ = await driver.execute_query( + """ UNWIND $node_uuids AS node_uuid MATCH (center:Entity {uuid: $center_uuid})-[:RELATES_TO]-(n:Entity {uuid: node_uuid}) RETURN 1 AS score, node_uuid AS uuid - """ - results, header, _ = await driver.execute_query( - query, + """, node_uuids=filtered_uuids, center_uuid=center_node_uuid, routing_='r', ) - if driver.provider == 'falkordb': + if driver.provider == GraphProvider.FALKORDB: results = [dict(zip(header, row, strict=True)) for row in results] for result in results: @@ -987,13 +942,12 @@ async def episode_mentions_reranker( scores: dict[str, float] = {} # Find the shortest path to center node - query = """ - UNWIND $node_uuids AS node_uuid + results, _, _ = await driver.execute_query( + """ + UNWIND $node_uuids AS node_uuid MATCH (episode:Episodic)-[r:MENTIONS]->(n:Entity {uuid: node_uuid}) RETURN count(*) AS score, n.uuid AS uuid - """ - results, _, _ = await driver.execute_query( - query, + """, node_uuids=sorted_uuids, routing_='r', ) @@ -1053,15 +1007,16 @@ def maximal_marginal_relevance( async def get_embeddings_for_nodes( driver: GraphDriver, nodes: list[EntityNode] ) -> dict[str, list[float]]: - query: LiteralString = """MATCH (n:Entity) - WHERE n.uuid IN $node_uuids - RETURN DISTINCT - n.uuid AS uuid, - n.name_embedding AS name_embedding - """ - results, _, _ = await driver.execute_query( - query, node_uuids=[node.uuid for node in nodes], routing_='r' + """ + MATCH (n:Entity) + WHERE n.uuid IN $node_uuids + RETURN DISTINCT + n.uuid AS uuid, + n.name_embedding AS name_embedding + """, + node_uuids=[node.uuid for node in nodes], + routing_='r', ) embeddings_dict: dict[str, list[float]] = {} @@ -1077,15 +1032,14 @@ async def get_embeddings_for_nodes( async def get_embeddings_for_communities( driver: GraphDriver, communities: list[CommunityNode] ) -> dict[str, list[float]]: - query: LiteralString = """MATCH (c:Community) - WHERE c.uuid IN $community_uuids - RETURN DISTINCT - c.uuid AS uuid, - c.name_embedding AS name_embedding - """ - results, _, _ = await driver.execute_query( 
- query, + """ + MATCH (c:Community) + WHERE c.uuid IN $community_uuids + RETURN DISTINCT + c.uuid AS uuid, + c.name_embedding AS name_embedding + """, community_uuids=[community.uuid for community in communities], routing_='r', ) @@ -1103,15 +1057,14 @@ async def get_embeddings_for_communities( async def get_embeddings_for_edges( driver: GraphDriver, edges: list[EntityEdge] ) -> dict[str, list[float]]: - query: LiteralString = """MATCH (n:Entity)-[e:RELATES_TO]-(m:Entity) - WHERE e.uuid IN $edge_uuids - RETURN DISTINCT - e.uuid AS uuid, - e.fact_embedding AS fact_embedding - """ - results, _, _ = await driver.execute_query( - query, + """ + MATCH (n:Entity)-[e:RELATES_TO]-(m:Entity) + WHERE e.uuid IN $edge_uuids + RETURN DISTINCT + e.uuid AS uuid, + e.fact_embedding AS fact_embedding + """, edge_uuids=[edge.uuid for edge in edges], routing_='r', ) diff --git a/graphiti_core/utils/bulk_utils.py b/graphiti_core/utils/bulk_utils.py index 8d5e8a68..b80c4f3b 100644 --- a/graphiti_core/utils/bulk_utils.py +++ b/graphiti_core/utils/bulk_utils.py @@ -25,17 +25,15 @@ from typing_extensions import Any from graphiti_core.driver.driver import GraphDriver, GraphDriverSession from graphiti_core.edges import Edge, EntityEdge, EpisodicEdge, create_entity_edge_embeddings from graphiti_core.embedder import EmbedderClient -from graphiti_core.graph_queries import ( - get_entity_edge_save_bulk_query, - get_entity_node_save_bulk_query, -) from graphiti_core.graphiti_types import GraphitiClients from graphiti_core.helpers import normalize_l2, semaphore_gather from graphiti_core.models.edges.edge_db_queries import ( EPISODIC_EDGE_SAVE_BULK, + get_entity_edge_save_bulk_query, ) from graphiti_core.models.nodes.node_db_queries import ( EPISODIC_NODE_SAVE_BULK, + get_entity_node_save_bulk_query, ) from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode, create_entity_node_embeddings from graphiti_core.utils.maintenance.edge_operations import ( @@ -158,7 +156,7 @@ async def add_nodes_and_edges_bulk_tx( edges.append(edge_data) await tx.run(EPISODIC_NODE_SAVE_BULK, episodes=episodes) - entity_node_save_bulk = get_entity_node_save_bulk_query(nodes, driver.provider) + entity_node_save_bulk = get_entity_node_save_bulk_query(driver.provider, nodes) await tx.run(entity_node_save_bulk, nodes=nodes) await tx.run( EPISODIC_EDGE_SAVE_BULK, episodic_edges=[edge.model_dump() for edge in episodic_edges] @@ -171,9 +169,9 @@ async def extract_nodes_and_edges_bulk( clients: GraphitiClients, episode_tuples: list[tuple[EpisodicNode, list[EpisodicNode]]], edge_type_map: dict[tuple[str, str], list[str]], - entity_types: dict[str, BaseModel] | None = None, + entity_types: dict[str, type[BaseModel]] | None = None, excluded_entity_types: list[str] | None = None, - edge_types: dict[str, BaseModel] | None = None, + edge_types: dict[str, type[BaseModel]] | None = None, ) -> tuple[list[list[EntityNode]], list[list[EntityEdge]]]: extracted_nodes_bulk: list[list[EntityNode]] = await semaphore_gather( *[ @@ -204,7 +202,7 @@ async def dedupe_nodes_bulk( clients: GraphitiClients, extracted_nodes: list[list[EntityNode]], episode_tuples: list[tuple[EpisodicNode, list[EpisodicNode]]], - entity_types: dict[str, BaseModel] | None = None, + entity_types: dict[str, type[BaseModel]] | None = None, ) -> tuple[dict[str, list[EntityNode]], dict[str, str]]: embedder = clients.embedder min_score = 0.8 @@ -292,7 +290,7 @@ async def dedupe_edges_bulk( extracted_edges: list[list[EntityEdge]], episode_tuples: list[tuple[EpisodicNode, 
list[EpisodicNode]]], _entities: list[EntityNode], - edge_types: dict[str, BaseModel], + edge_types: dict[str, type[BaseModel]], _edge_type_map: dict[tuple[str, str], list[str]], ) -> dict[str, list[EntityEdge]]: embedder = clients.embedder diff --git a/graphiti_core/utils/maintenance/community_operations.py b/graphiti_core/utils/maintenance/community_operations.py index de74d56b..243a5079 100644 --- a/graphiti_core/utils/maintenance/community_operations.py +++ b/graphiti_core/utils/maintenance/community_operations.py @@ -34,7 +34,7 @@ async def get_community_clusters( group_id_values, _, _ = await driver.execute_query( """ MATCH (n:Entity WHERE n.group_id IS NOT NULL) - RETURN + RETURN collect(DISTINCT n.group_id) AS group_ids """, ) @@ -233,10 +233,10 @@ async def determine_entity_community( """ MATCH (c:Community)-[:HAS_MEMBER]->(n:Entity {uuid: $entity_uuid}) RETURN - c.uuid As uuid, + c.uuid AS uuid, c.name AS name, c.group_id AS group_id, - c.created_at AS created_at, + c.created_at AS created_at, c.summary AS summary """, entity_uuid=entity.uuid, @@ -250,10 +250,10 @@ async def determine_entity_community( """ MATCH (c:Community)-[:HAS_MEMBER]->(m:Entity)-[:RELATES_TO]-(n:Entity {uuid: $entity_uuid}) RETURN - c.uuid As uuid, + c.uuid AS uuid, c.name AS name, c.group_id AS group_id, - c.created_at AS created_at, + c.created_at AS created_at, c.summary AS summary """, entity_uuid=entity.uuid, @@ -286,11 +286,11 @@ async def determine_entity_community( async def update_community( driver: GraphDriver, llm_client: LLMClient, embedder: EmbedderClient, entity: EntityNode -): +) -> tuple[list[CommunityNode], list[CommunityEdge]]: community, is_new = await determine_entity_community(driver, entity) if community is None: - return + return [], [] new_summary = await summarize_pair(llm_client, (entity.summary, community.summary)) new_name = await generate_summary_description(llm_client, new_summary) @@ -298,10 +298,14 @@ async def update_community( community.summary = new_summary community.name = new_name + community_edges = [] if is_new: community_edge = (build_community_edges([entity], community, utc_now()))[0] await community_edge.save(driver) + community_edges.append(community_edge) await community.generate_name_embedding(embedder) await community.save(driver) + + return [community], community_edges diff --git a/graphiti_core/utils/maintenance/edge_operations.py b/graphiti_core/utils/maintenance/edge_operations.py index ad9267bf..ef78db43 100644 --- a/graphiti_core/utils/maintenance/edge_operations.py +++ b/graphiti_core/utils/maintenance/edge_operations.py @@ -34,7 +34,7 @@ from graphiti_core.llm_client import LLMClient from graphiti_core.llm_client.config import ModelSize from graphiti_core.nodes import CommunityNode, EntityNode, EpisodicNode from graphiti_core.prompts import prompt_library -from graphiti_core.prompts.dedupe_edges import EdgeDuplicate, UniqueFacts +from graphiti_core.prompts.dedupe_edges import EdgeDuplicate from graphiti_core.prompts.extract_edges import ExtractedEdges, MissingFacts from graphiti_core.search.search_filters import SearchFilters from graphiti_core.search.search_utils import get_edge_invalidation_candidates, get_relevant_edges @@ -114,7 +114,7 @@ async def extract_edges( previous_episodes: list[EpisodicNode], edge_type_map: dict[tuple[str, str], list[str]], group_id: str = '', - edge_types: dict[str, BaseModel] | None = None, + edge_types: dict[str, type[BaseModel]] | None = None, ) -> list[EntityEdge]: start = time() @@ -161,9 +161,9 @@ async def 
extract_edges( response_model=ExtractedEdges, max_tokens=extract_edges_max_tokens, ) - edges_data = llm_response.get('edges', []) + edges_data = ExtractedEdges(**llm_response).edges - context['extracted_facts'] = [edge_data.get('fact', '') for edge_data in edges_data] + context['extracted_facts'] = [edge_data.fact for edge_data in edges_data] reflexion_iterations += 1 if reflexion_iterations < MAX_REFLEXION_ITERATIONS: @@ -193,20 +193,20 @@ async def extract_edges( edges = [] for edge_data in edges_data: # Validate Edge Date information - valid_at = edge_data.get('valid_at', None) - invalid_at = edge_data.get('invalid_at', None) + valid_at = edge_data.valid_at + invalid_at = edge_data.invalid_at valid_at_datetime = None invalid_at_datetime = None - source_node_idx = edge_data.get('source_entity_id', -1) - target_node_idx = edge_data.get('target_entity_id', -1) + source_node_idx = edge_data.source_entity_id + target_node_idx = edge_data.target_entity_id if not (-1 < source_node_idx < len(nodes) and -1 < target_node_idx < len(nodes)): logger.warning( - f'WARNING: source or target node not filled {edge_data.get("edge_name")}. source_node_uuid: {source_node_idx} and target_node_uuid: {target_node_idx} ' + f'WARNING: source or target node not filled {edge_data.relation_type}. source_node_uuid: {source_node_idx} and target_node_uuid: {target_node_idx} ' ) continue source_node_uuid = nodes[source_node_idx].uuid - target_node_uuid = nodes[edge_data.get('target_entity_id')].uuid + target_node_uuid = nodes[edge_data.target_entity_id].uuid if valid_at: try: @@ -226,9 +226,9 @@ async def extract_edges( edge = EntityEdge( source_node_uuid=source_node_uuid, target_node_uuid=target_node_uuid, - name=edge_data.get('relation_type', ''), + name=edge_data.relation_type, group_id=group_id, - fact=edge_data.get('fact', ''), + fact=edge_data.fact, episodes=[episode.uuid], created_at=utc_now(), valid_at=valid_at_datetime, @@ -249,7 +249,7 @@ async def resolve_extracted_edges( extracted_edges: list[EntityEdge], episode: EpisodicNode, entities: list[EntityNode], - edge_types: dict[str, BaseModel], + edge_types: dict[str, type[BaseModel]], edge_type_map: dict[tuple[str, str], list[str]], ) -> tuple[list[EntityEdge], list[EntityEdge]]: driver = clients.driver @@ -272,7 +272,7 @@ async def resolve_extracted_edges( uuid_entity_map: dict[str, EntityNode] = {entity.uuid: entity for entity in entities} # Determine which edge types are relevant for each edge - edge_types_lst: list[dict[str, BaseModel]] = [] + edge_types_lst: list[dict[str, type[BaseModel]]] = [] for extracted_edge in extracted_edges: source_node = uuid_entity_map.get(extracted_edge.source_node_uuid) target_node = uuid_entity_map.get(extracted_edge.target_node_uuid) @@ -381,7 +381,7 @@ async def resolve_extracted_edge( related_edges: list[EntityEdge], existing_edges: list[EntityEdge], episode: EpisodicNode, - edge_types: dict[str, BaseModel] | None = None, + edge_types: dict[str, type[BaseModel]] | None = None, ) -> tuple[EntityEdge, list[EntityEdge], list[EntityEdge]]: if len(related_edges) == 0 and len(existing_edges) == 0: return extracted_edge, [], [] @@ -422,10 +422,10 @@ async def resolve_extracted_edge( response_model=EdgeDuplicate, model_size=ModelSize.small, ) + response_object = EdgeDuplicate(**llm_response) + duplicate_facts = response_object.duplicate_facts - duplicate_fact_ids: list[int] = list( - filter(lambda i: 0 <= i < len(related_edges), llm_response.get('duplicate_facts', [])) - ) + duplicate_fact_ids: list[int] = [i for i in 
duplicate_facts if 0 <= i < len(related_edges)] resolved_edge = extracted_edge for duplicate_fact_id in duplicate_fact_ids: @@ -435,11 +435,13 @@ async def resolve_extracted_edge( if duplicate_fact_ids and episode is not None: resolved_edge.episodes.append(episode.uuid) - contradicted_facts: list[int] = llm_response.get('contradicted_facts', []) + contradicted_facts: list[int] = response_object.contradicted_facts - invalidation_candidates: list[EntityEdge] = [existing_edges[i] for i in contradicted_facts] + invalidation_candidates: list[EntityEdge] = [ + existing_edges[i] for i in contradicted_facts if 0 <= i < len(existing_edges) + ] - fact_type: str = str(llm_response.get('fact_type')) + fact_type: str = response_object.fact_type if fact_type.upper() != 'DEFAULT' and edge_types is not None: resolved_edge.name = fact_type @@ -494,39 +496,6 @@ async def resolve_extracted_edge( return resolved_edge, invalidated_edges, duplicate_edges -async def dedupe_edge_list( - llm_client: LLMClient, - edges: list[EntityEdge], -) -> list[EntityEdge]: - start = time() - - # Create edge map - edge_map = {} - for edge in edges: - edge_map[edge.uuid] = edge - - # Prepare context for LLM - context = {'edges': [{'uuid': edge.uuid, 'fact': edge.fact} for edge in edges]} - - llm_response = await llm_client.generate_response( - prompt_library.dedupe_edges.edge_list(context), response_model=UniqueFacts - ) - unique_edges_data = llm_response.get('unique_facts', []) - - end = time() - logger.debug(f'Extracted edge duplicates: {unique_edges_data} in {(end - start) * 1000} ms ') - - # Get full edge data - unique_edges = [] - for edge_data in unique_edges_data: - uuid = edge_data['uuid'] - edge = edge_map[uuid] - edge.fact = edge_data['fact'] - unique_edges.append(edge) - - return unique_edges - - async def filter_existing_duplicate_of_edges( driver: GraphDriver, duplicates_node_tuples: list[tuple[EntityNode, EntityNode]] ) -> list[tuple[EntityNode, EntityNode]]: diff --git a/graphiti_core/utils/maintenance/graph_data_operations.py b/graphiti_core/utils/maintenance/graph_data_operations.py index f6975053..b866607f 100644 --- a/graphiti_core/utils/maintenance/graph_data_operations.py +++ b/graphiti_core/utils/maintenance/graph_data_operations.py @@ -15,14 +15,15 @@ limitations under the License. 
""" import logging -from datetime import datetime, timezone +from datetime import datetime from typing_extensions import LiteralString from graphiti_core.driver.driver import GraphDriver from graphiti_core.graph_queries import get_fulltext_indices, get_range_indices -from graphiti_core.helpers import parse_db_date, semaphore_gather -from graphiti_core.nodes import EpisodeType, EpisodicNode +from graphiti_core.helpers import semaphore_gather +from graphiti_core.models.nodes.node_db_queries import EPISODIC_NODE_RETURN +from graphiti_core.nodes import EpisodeType, EpisodicNode, get_episodic_node_from_record EPISODE_WINDOW_LEN = 3 @@ -33,8 +34,8 @@ async def build_indices_and_constraints(driver: GraphDriver, delete_existing: bo if delete_existing: records, _, _ = await driver.execute_query( """ - SHOW INDEXES YIELD name - """, + SHOW INDEXES YIELD name + """, ) index_names = [record['name'] for record in records] await semaphore_gather( @@ -108,19 +109,16 @@ async def retrieve_episodes( query: LiteralString = ( """ - MATCH (e:Episodic) WHERE e.valid_at <= $reference_time - """ + MATCH (e:Episodic) + WHERE e.valid_at <= $reference_time + """ + group_id_filter + source_filter + """ - RETURN e.content AS content, - e.created_at AS created_at, - e.valid_at AS valid_at, - e.uuid AS uuid, - e.group_id AS group_id, - e.name AS name, - e.source_description AS source_description, - e.source AS source + RETURN + """ + + EPISODIC_NODE_RETURN + + """ ORDER BY e.valid_at DESC LIMIT $num_episodes """ @@ -133,18 +131,5 @@ async def retrieve_episodes( group_ids=group_ids, ) - episodes = [ - EpisodicNode( - content=record['content'], - created_at=parse_db_date(record['created_at']) - or datetime.min.replace(tzinfo=timezone.utc), - valid_at=parse_db_date(record['valid_at']) or datetime.min.replace(tzinfo=timezone.utc), - uuid=record['uuid'], - group_id=record['group_id'], - source=EpisodeType.from_str(record['source']), - name=record['name'], - source_description=record['source_description'], - ) - for record in result - ] + episodes = [get_episodic_node_from_record(record) for record in result] return list(reversed(episodes)) # Return in chronological order diff --git a/graphiti_core/utils/maintenance/node_operations.py b/graphiti_core/utils/maintenance/node_operations.py index 1588d042..51a7be4b 100644 --- a/graphiti_core/utils/maintenance/node_operations.py +++ b/graphiti_core/utils/maintenance/node_operations.py @@ -15,13 +15,10 @@ limitations under the License. 
""" import logging -from contextlib import suppress from time import time from typing import Any -from uuid import uuid4 -import pydantic -from pydantic import BaseModel, Field +from pydantic import BaseModel from graphiti_core.graphiti_types import GraphitiClients from graphiti_core.helpers import MAX_REFLEXION_ITERATIONS, semaphore_gather @@ -29,8 +26,9 @@ from graphiti_core.llm_client import LLMClient from graphiti_core.llm_client.config import ModelSize from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode, create_entity_node_embeddings from graphiti_core.prompts import prompt_library -from graphiti_core.prompts.dedupe_nodes import NodeResolutions +from graphiti_core.prompts.dedupe_nodes import NodeDuplicate, NodeResolutions from graphiti_core.prompts.extract_nodes import ( + EntitySummary, ExtractedEntities, ExtractedEntity, MissedEntities, @@ -70,7 +68,7 @@ async def extract_nodes( clients: GraphitiClients, episode: EpisodicNode, previous_episodes: list[EpisodicNode], - entity_types: dict[str, BaseModel] | None = None, + entity_types: dict[str, type[BaseModel]] | None = None, excluded_entity_types: list[str] | None = None, ) -> list[EntityNode]: start = time() @@ -125,10 +123,9 @@ async def extract_nodes( prompt_library.extract_nodes.extract_json(context), response_model=ExtractedEntities ) - extracted_entities: list[ExtractedEntity] = [ - ExtractedEntity(**entity_types_context) - for entity_types_context in llm_response.get('extracted_entities', []) - ] + response_object = ExtractedEntities(**llm_response) + + extracted_entities: list[ExtractedEntity] = response_object.extracted_entities reflexion_iterations += 1 if reflexion_iterations < MAX_REFLEXION_ITERATIONS: @@ -181,7 +178,7 @@ async def resolve_extracted_nodes( extracted_nodes: list[EntityNode], episode: EpisodicNode | None = None, previous_episodes: list[EpisodicNode] | None = None, - entity_types: dict[str, BaseModel] | None = None, + entity_types: dict[str, type[BaseModel]] | None = None, existing_nodes_override: list[EntityNode] | None = None, ) -> tuple[list[EntityNode], dict[str, str], list[tuple[EntityNode, EntityNode]]]: llm_client = clients.llm_client @@ -224,7 +221,7 @@ async def resolve_extracted_nodes( ], ) - entity_types_dict: dict[str, BaseModel] = entity_types if entity_types is not None else {} + entity_types_dict: dict[str, type[BaseModel]] = entity_types if entity_types is not None else {} # Prepare context for LLM extracted_nodes_context = [ @@ -254,14 +251,14 @@ async def resolve_extracted_nodes( response_model=NodeResolutions, ) - node_resolutions: list = llm_response.get('entity_resolutions', []) + node_resolutions: list[NodeDuplicate] = NodeResolutions(**llm_response).entity_resolutions resolved_nodes: list[EntityNode] = [] uuid_map: dict[str, str] = {} node_duplicates: list[tuple[EntityNode, EntityNode]] = [] for resolution in node_resolutions: - resolution_id: int = resolution.get('id', -1) - duplicate_idx: int = resolution.get('duplicate_idx', -1) + resolution_id: int = resolution.id + duplicate_idx: int = resolution.duplicate_idx extracted_node = extracted_nodes[resolution_id] @@ -276,7 +273,7 @@ async def resolve_extracted_nodes( resolved_nodes.append(resolved_node) uuid_map[extracted_node.uuid] = resolved_node.uuid - duplicates: list[int] = resolution.get('duplicates', []) + duplicates: list[int] = resolution.duplicates if duplicate_idx not in duplicates and duplicate_idx > -1: duplicates.append(duplicate_idx) for idx in duplicates: @@ -298,7 +295,7 @@ async def 
extract_attributes_from_nodes( nodes: list[EntityNode], episode: EpisodicNode | None = None, previous_episodes: list[EpisodicNode] | None = None, - entity_types: dict[str, BaseModel] | None = None, + entity_types: dict[str, type[BaseModel]] | None = None, ) -> list[EntityNode]: llm_client = clients.llm_client embedder = clients.embedder @@ -327,7 +324,7 @@ async def extract_attributes_from_node( node: EntityNode, episode: EpisodicNode | None = None, previous_episodes: list[EpisodicNode] | None = None, - entity_type: BaseModel | None = None, + entity_type: type[BaseModel] | None = None, ) -> EntityNode: node_context: dict[str, Any] = { 'name': node.name, @@ -336,25 +333,14 @@ async def extract_attributes_from_node( 'attributes': node.attributes, } - attributes_definitions: dict[str, Any] = { - 'summary': ( - str, - Field( - description='Summary containing the important information about the entity. Under 250 words', - ), - ) + attributes_context: dict[str, Any] = { + 'node': node_context, + 'episode_content': episode.content if episode is not None else '', + 'previous_episodes': [ep.content for ep in previous_episodes] + if previous_episodes is not None + else [], } - if entity_type is not None: - for field_name, field_info in entity_type.model_fields.items(): - attributes_definitions[field_name] = ( - field_info.annotation, - Field(description=field_info.description), - ) - - unique_model_name = f'EntityAttributes_{uuid4().hex}' - entity_attributes_model = pydantic.create_model(unique_model_name, **attributes_definitions) - summary_context: dict[str, Any] = { 'node': node_context, 'episode_content': episode.content if episode is not None else '', @@ -363,63 +349,30 @@ async def extract_attributes_from_node( else [], } - llm_response = await llm_client.generate_response( - prompt_library.extract_nodes.extract_attributes(summary_context), - response_model=entity_attributes_model, + llm_response = ( + ( + await llm_client.generate_response( + prompt_library.extract_nodes.extract_attributes(attributes_context), + response_model=entity_type, + model_size=ModelSize.small, + ) + ) + if entity_type is not None + else {} + ) + + summary_response = await llm_client.generate_response( + prompt_library.extract_nodes.extract_summary(summary_context), + response_model=EntitySummary, model_size=ModelSize.small, ) - node.summary = llm_response.get('summary', node.summary) - node_attributes = {key: value for key, value in llm_response.items()} + if entity_type is not None: + entity_type(**llm_response) - with suppress(KeyError): - del node_attributes['summary'] + node.summary = summary_response.get('summary', '') + node_attributes = {key: value for key, value in llm_response.items()} node.attributes.update(node_attributes) return node - - -async def dedupe_node_list( - llm_client: LLMClient, - nodes: list[EntityNode], -) -> tuple[list[EntityNode], dict[str, str]]: - start = time() - - # build node map - node_map = {} - for node in nodes: - node_map[node.uuid] = node - - # Prepare context for LLM - nodes_context = [{'uuid': node.uuid, 'name': node.name, **node.attributes} for node in nodes] - - context = { - 'nodes': nodes_context, - } - - llm_response = await llm_client.generate_response( - prompt_library.dedupe_nodes.node_list(context) - ) - - nodes_data = llm_response.get('nodes', []) - - end = time() - logger.debug(f'Deduplicated nodes: {nodes_data} in {(end - start) * 1000} ms') - - # Get full node data - unique_nodes = [] - uuid_map: dict[str, str] = {} - for node_data in nodes_data: - node_instance: 
EntityNode | None = node_map.get(node_data['uuids'][0]) - if node_instance is None: - logger.warning(f'Node {node_data["uuids"][0]} not found in node map') - continue - node_instance.summary = node_data['summary'] - unique_nodes.append(node_instance) - - for uuid in node_data['uuids'][1:]: - uuid_value = node_map[node_data['uuids'][0]].uuid - uuid_map[uuid] = uuid_value - - return unique_nodes, uuid_map diff --git a/graphiti_core/utils/ontology_utils/entity_types_utils.py b/graphiti_core/utils/ontology_utils/entity_types_utils.py index f6cb08fb..bbc07af7 100644 --- a/graphiti_core/utils/ontology_utils/entity_types_utils.py +++ b/graphiti_core/utils/ontology_utils/entity_types_utils.py @@ -21,7 +21,7 @@ from graphiti_core.nodes import EntityNode def validate_entity_types( - entity_types: dict[str, BaseModel] | None, + entity_types: dict[str, type[BaseModel]] | None, ) -> bool: if entity_types is None: return True diff --git a/pyproject.toml b/pyproject.toml index d2e49acc..41639b64 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,11 +1,11 @@ [project] name = "graphiti-core" description = "A temporal graph building library" -version = "0.18.0" +version = "0.18.2" authors = [ - { "name" = "Paul Paliychuk", "email" = "paul@getzep.com" }, - { "name" = "Preston Rasmussen", "email" = "preston@getzep.com" }, - { "name" = "Daniel Chalef", "email" = "daniel@getzep.com" }, + { name = "Paul Paliychuk", email = "paul@getzep.com" }, + { name = "Preston Rasmussen", email = "preston@getzep.com" }, + { name = "Daniel Chalef", email = "daniel@getzep.com" }, ] readme = "README.md" license = "Apache-2.0" diff --git a/signatures/version1/cla.json b/signatures/version1/cla.json index f682fab3..6cc093f1 100644 --- a/signatures/version1/cla.json +++ b/signatures/version1/cla.json @@ -247,6 +247,22 @@ "created_at": "2025-07-24T15:39:36Z", "repoId": 840056306, "pullRequestNo": 764 + }, + { + "name": "gifflet", + "id": 33522742, + "comment_id": 3133869379, + "created_at": "2025-07-29T20:00:27Z", + "repoId": 840056306, + "pullRequestNo": 782 + }, + { + "name": "bechbd", + "id": 6898505, + "comment_id": 3140501814, + "created_at": "2025-07-31T15:58:08Z", + "repoId": 840056306, + "pullRequestNo": 793 } ] } \ No newline at end of file diff --git a/tests/driver/test_falkordb_driver.py b/tests/driver/test_falkordb_driver.py index 260e24d2..6cca9e74 100644 --- a/tests/driver/test_falkordb_driver.py +++ b/tests/driver/test_falkordb_driver.py @@ -21,6 +21,8 @@ from unittest.mock import AsyncMock, MagicMock, patch import pytest +from graphiti_core.driver.driver import GraphProvider + try: from graphiti_core.driver.falkordb_driver import FalkorDriver, FalkorDriverSession @@ -48,7 +50,7 @@ class TestFalkorDriver: driver = FalkorDriver( host='test-host', port='1234', username='test-user', password='test-pass' ) - assert driver.provider == 'falkordb' + assert driver.provider == GraphProvider.FALKORDB mock_falkor_db.assert_called_once_with( host='test-host', port='1234', username='test-user', password='test-pass' ) @@ -59,14 +61,14 @@ class TestFalkorDriver: with patch('graphiti_core.driver.falkordb_driver.FalkorDB') as mock_falkor_db_class: mock_falkor_db = MagicMock() driver = FalkorDriver(falkor_db=mock_falkor_db) - assert driver.provider == 'falkordb' + assert driver.provider == GraphProvider.FALKORDB assert driver.client is mock_falkor_db mock_falkor_db_class.assert_not_called() @unittest.skipIf(not HAS_FALKORDB, 'FalkorDB is not installed') def test_provider(self): """Test driver provider identification.""" - assert 
self.driver.provider == 'falkordb' + assert self.driver.provider == GraphProvider.FALKORDB @unittest.skipIf(not HAS_FALKORDB, 'FalkorDB is not installed') def test_get_graph_with_name(self): diff --git a/tests/helpers_test.py b/tests/helpers_test.py index e0f8b99c..9aa9e1ef 100644 --- a/tests/helpers_test.py +++ b/tests/helpers_test.py @@ -14,16 +14,76 @@ See the License for the specific language governing permissions and limitations under the License. """ +import os + import pytest +from dotenv import load_dotenv + +from graphiti_core.driver.driver import GraphDriver + +load_dotenv() + +HAS_NEO4J = False +HAS_FALKORDB = False +if os.getenv('DISABLE_NEO4J') is None: + try: + from graphiti_core.driver.neo4j_driver import Neo4jDriver + + HAS_NEO4J = True + except ImportError: + pass + +if os.getenv('DISABLE_FALKORDB') is None: + try: + from graphiti_core.driver.falkordb_driver import FalkorDriver + + HAS_FALKORDB = True + except ImportError: + pass + +NEO4J_URI = os.getenv('NEO4J_URI', 'bolt://localhost:7687') +NEO4J_USER = os.getenv('NEO4J_USER', 'neo4j') +NEO4J_PASSWORD = os.getenv('NEO4J_PASSWORD', 'test') + +FALKORDB_HOST = os.getenv('FALKORDB_HOST', 'localhost') +FALKORDB_PORT = os.getenv('FALKORDB_PORT', '6379') +FALKORDB_USER = os.getenv('FALKORDB_USER', None) +FALKORDB_PASSWORD = os.getenv('FALKORDB_PASSWORD', None) + + +def get_driver(driver_name: str) -> GraphDriver: + if driver_name == 'neo4j': + return Neo4jDriver( + uri=NEO4J_URI, + user=NEO4J_USER, + password=NEO4J_PASSWORD, + ) + elif driver_name == 'falkordb': + return FalkorDriver( + host=FALKORDB_HOST, + port=int(FALKORDB_PORT), + username=FALKORDB_USER, + password=FALKORDB_PASSWORD, + ) + else: + raise ValueError(f'Driver {driver_name} not available') + + +drivers: list[str] = [] +if HAS_NEO4J: + drivers.append('neo4j') +if HAS_FALKORDB: + drivers.append('falkordb') def test_neo4j_sanitize(): """Test Neo4j sanitization by importing the driver and using its sanitize method.""" - from graphiti_core.driver.neo4j_driver import Neo4jDriver + if not HAS_NEO4J: + pytest.skip("Neo4j not available - skipping sanitize test") # Create a driver instance - if it fails, we'll handle it gracefully try: - driver = Neo4jDriver(uri="bolt://localhost:7687", user="test", password="test") + driver = Neo4jDriver(uri=NEO4J_URI, user=NEO4J_USER, password=NEO4J_PASSWORD) except Exception: # If we can't create a real driver, skip this test pytest.skip("Neo4j driver connection failed - skipping sanitize test") @@ -44,11 +104,16 @@ def test_neo4j_sanitize(): def test_falkordb_sanitize(): """Test FalkorDB sanitization by importing the driver and using its sanitize method.""" + if not HAS_FALKORDB: + pytest.skip("FalkorDB not available - skipping sanitize test") + try: - from graphiti_core.driver.falkordb_driver import FalkorDriver - driver = FalkorDriver() - except ImportError: - pytest.skip("FalkorDB not installed - skipping sanitize test") + driver = FalkorDriver( + host=FALKORDB_HOST, + port=int(FALKORDB_PORT), + username=FALKORDB_USER, + password=FALKORDB_PASSWORD, + ) except Exception: # If we can't create a real driver, skip this test pytest.skip("FalkorDB driver connection failed - skipping sanitize test") @@ -69,20 +134,5 @@ def test_falkordb_sanitize(): assert expected_result == result -# Keep the drivers and get_driver functions for other tests that import from this file -drivers = ['neo4j', 'falkordb'] - -def get_driver(driver_name: str): - """Helper function to get a driver instance for testing""" - if driver_name == 'neo4j': - from 
graphiti_core.driver.neo4j_driver import Neo4jDriver - return Neo4jDriver(uri="bolt://localhost:7687", user="neo4j", password="test") - elif driver_name == 'falkordb': - from graphiti_core.driver.falkordb_driver import FalkorDriver - return FalkorDriver() - else: - raise ValueError(f"Unknown driver: {driver_name}") - - if __name__ == '__main__': - pytest.main([__file__]) + pytest.main([__file__]) \ No newline at end of file diff --git a/tests/test_edge_int.py b/tests/test_edge_int.py new file mode 100644 index 00000000..6eb769a4 --- /dev/null +++ b/tests/test_edge_int.py @@ -0,0 +1,384 @@ +""" +Copyright 2024, Zep Software, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +import logging +import sys +from datetime import datetime +from uuid import uuid4 + +import numpy as np +import pytest + +from graphiti_core.driver.driver import GraphDriver +from graphiti_core.edges import CommunityEdge, EntityEdge, EpisodicEdge +from graphiti_core.embedder.openai import OpenAIEmbedder +from graphiti_core.nodes import CommunityNode, EntityNode, EpisodeType, EpisodicNode +from tests.helpers_test import drivers, get_driver + +pytestmark = pytest.mark.integration + +pytest_plugins = ('pytest_asyncio',) + +group_id = f'test_group_{str(uuid4())}' + + +def setup_logging(): + # Create a logger + logger = logging.getLogger() + logger.setLevel(logging.INFO) # Set the logging level to INFO + + # Create console handler and set level to INFO + console_handler = logging.StreamHandler(sys.stdout) + console_handler.setLevel(logging.INFO) + + # Create formatter + formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') + + # Add formatter to console handler + console_handler.setFormatter(formatter) + + # Add console handler to logger + logger.addHandler(console_handler) + + return logger + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + 'driver', + drivers, + ids=drivers, +) +async def test_episodic_edge(driver): + graph_driver = get_driver(driver) + embedder = OpenAIEmbedder() + + now = datetime.now() + + episode_node = EpisodicNode( + name='test_episode', + labels=[], + created_at=now, + valid_at=now, + source=EpisodeType.message, + source_description='conversation message', + content='Alice likes Bob', + entity_edges=[], + group_id=group_id, + ) + + node_count = await get_node_count(graph_driver, episode_node.uuid) + assert node_count == 0 + await episode_node.save(graph_driver) + node_count = await get_node_count(graph_driver, episode_node.uuid) + assert node_count == 1 + + alice_node = EntityNode( + name='Alice', + labels=[], + created_at=now, + summary='Alice summary', + group_id=group_id, + ) + await alice_node.generate_name_embedding(embedder) + + node_count = await get_node_count(graph_driver, alice_node.uuid) + assert node_count == 0 + await alice_node.save(graph_driver) + node_count = await get_node_count(graph_driver, alice_node.uuid) + assert node_count == 1 + + episodic_edge = EpisodicEdge( + source_node_uuid=episode_node.uuid, + target_node_uuid=alice_node.uuid, + created_at=now, + 
group_id=group_id, + ) + + edge_count = await get_edge_count(graph_driver, episodic_edge.uuid) + assert edge_count == 0 + await episodic_edge.save(graph_driver) + edge_count = await get_edge_count(graph_driver, episodic_edge.uuid) + assert edge_count == 1 + + retrieved = await EpisodicEdge.get_by_uuid(graph_driver, episodic_edge.uuid) + assert retrieved.uuid == episodic_edge.uuid + assert retrieved.source_node_uuid == episode_node.uuid + assert retrieved.target_node_uuid == alice_node.uuid + assert retrieved.created_at == now + assert retrieved.group_id == group_id + + retrieved = await EpisodicEdge.get_by_uuids(graph_driver, [episodic_edge.uuid]) + assert len(retrieved) == 1 + assert retrieved[0].uuid == episodic_edge.uuid + assert retrieved[0].source_node_uuid == episode_node.uuid + assert retrieved[0].target_node_uuid == alice_node.uuid + assert retrieved[0].created_at == now + assert retrieved[0].group_id == group_id + + retrieved = await EpisodicEdge.get_by_group_ids(graph_driver, [group_id], limit=2) + assert len(retrieved) == 1 + assert retrieved[0].uuid == episodic_edge.uuid + assert retrieved[0].source_node_uuid == episode_node.uuid + assert retrieved[0].target_node_uuid == alice_node.uuid + assert retrieved[0].created_at == now + assert retrieved[0].group_id == group_id + + await episodic_edge.delete(graph_driver) + edge_count = await get_edge_count(graph_driver, episodic_edge.uuid) + assert edge_count == 0 + + await episode_node.delete(graph_driver) + node_count = await get_node_count(graph_driver, episode_node.uuid) + assert node_count == 0 + + await alice_node.delete(graph_driver) + node_count = await get_node_count(graph_driver, alice_node.uuid) + assert node_count == 0 + + await graph_driver.close() + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + 'driver', + drivers, + ids=drivers, +) +async def test_entity_edge(driver): + graph_driver = get_driver(driver) + embedder = OpenAIEmbedder() + + now = datetime.now() + + alice_node = EntityNode( + name='Alice', + labels=[], + created_at=now, + summary='Alice summary', + group_id=group_id, + ) + await alice_node.generate_name_embedding(embedder) + + node_count = await get_node_count(graph_driver, alice_node.uuid) + assert node_count == 0 + await alice_node.save(graph_driver) + node_count = await get_node_count(graph_driver, alice_node.uuid) + assert node_count == 1 + + bob_node = EntityNode( + name='Bob', labels=[], created_at=now, summary='Bob summary', group_id=group_id + ) + await bob_node.generate_name_embedding(embedder) + + node_count = await get_node_count(graph_driver, bob_node.uuid) + assert node_count == 0 + await bob_node.save(graph_driver) + node_count = await get_node_count(graph_driver, bob_node.uuid) + assert node_count == 1 + + entity_edge = EntityEdge( + source_node_uuid=alice_node.uuid, + target_node_uuid=bob_node.uuid, + created_at=now, + name='likes', + fact='Alice likes Bob', + episodes=[], + expired_at=now, + valid_at=now, + invalid_at=now, + group_id=group_id, + ) + edge_embedding = await entity_edge.generate_embedding(embedder) + + edge_count = await get_edge_count(graph_driver, entity_edge.uuid) + assert edge_count == 0 + await entity_edge.save(graph_driver) + edge_count = await get_edge_count(graph_driver, entity_edge.uuid) + assert edge_count == 1 + + retrieved = await EntityEdge.get_by_uuid(graph_driver, entity_edge.uuid) + assert retrieved.uuid == entity_edge.uuid + assert retrieved.source_node_uuid == alice_node.uuid + assert retrieved.target_node_uuid == bob_node.uuid + assert 
retrieved.created_at == now + assert retrieved.group_id == group_id + + retrieved = await EntityEdge.get_by_uuids(graph_driver, [entity_edge.uuid]) + assert len(retrieved) == 1 + assert retrieved[0].uuid == entity_edge.uuid + assert retrieved[0].source_node_uuid == alice_node.uuid + assert retrieved[0].target_node_uuid == bob_node.uuid + assert retrieved[0].created_at == now + assert retrieved[0].group_id == group_id + + retrieved = await EntityEdge.get_by_group_ids(graph_driver, [group_id], limit=2) + assert len(retrieved) == 1 + assert retrieved[0].uuid == entity_edge.uuid + assert retrieved[0].source_node_uuid == alice_node.uuid + assert retrieved[0].target_node_uuid == bob_node.uuid + assert retrieved[0].created_at == now + assert retrieved[0].group_id == group_id + + retrieved = await EntityEdge.get_by_node_uuid(graph_driver, alice_node.uuid) + assert len(retrieved) == 1 + assert retrieved[0].uuid == entity_edge.uuid + assert retrieved[0].source_node_uuid == alice_node.uuid + assert retrieved[0].target_node_uuid == bob_node.uuid + assert retrieved[0].created_at == now + assert retrieved[0].group_id == group_id + + await entity_edge.load_fact_embedding(graph_driver) + assert np.allclose(entity_edge.fact_embedding, edge_embedding) + + await entity_edge.delete(graph_driver) + edge_count = await get_edge_count(graph_driver, entity_edge.uuid) + assert edge_count == 0 + + await alice_node.delete(graph_driver) + node_count = await get_node_count(graph_driver, alice_node.uuid) + assert node_count == 0 + + await bob_node.delete(graph_driver) + node_count = await get_node_count(graph_driver, bob_node.uuid) + assert node_count == 0 + + await graph_driver.close() + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + 'driver', + drivers, + ids=drivers, +) +async def test_community_edge(driver): + graph_driver = get_driver(driver) + embedder = OpenAIEmbedder() + + now = datetime.now() + + community_node_1 = CommunityNode( + name='Community A', + group_id=group_id, + summary='Community A summary', + ) + await community_node_1.generate_name_embedding(embedder) + node_count = await get_node_count(graph_driver, community_node_1.uuid) + assert node_count == 0 + await community_node_1.save(graph_driver) + node_count = await get_node_count(graph_driver, community_node_1.uuid) + assert node_count == 1 + + community_node_2 = CommunityNode( + name='Community B', + group_id=group_id, + summary='Community B summary', + ) + await community_node_2.generate_name_embedding(embedder) + node_count = await get_node_count(graph_driver, community_node_2.uuid) + assert node_count == 0 + await community_node_2.save(graph_driver) + node_count = await get_node_count(graph_driver, community_node_2.uuid) + assert node_count == 1 + + alice_node = EntityNode( + name='Alice', labels=[], created_at=now, summary='Alice summary', group_id=group_id + ) + await alice_node.generate_name_embedding(embedder) + node_count = await get_node_count(graph_driver, alice_node.uuid) + assert node_count == 0 + await alice_node.save(graph_driver) + node_count = await get_node_count(graph_driver, alice_node.uuid) + assert node_count == 1 + + community_edge = CommunityEdge( + source_node_uuid=community_node_1.uuid, + target_node_uuid=community_node_2.uuid, + created_at=now, + group_id=group_id, + ) + edge_count = await get_edge_count(graph_driver, community_edge.uuid) + assert edge_count == 0 + await community_edge.save(graph_driver) + edge_count = await get_edge_count(graph_driver, community_edge.uuid) + assert edge_count == 1 + + retrieved = 
await CommunityEdge.get_by_uuid(graph_driver, community_edge.uuid) + assert retrieved.uuid == community_edge.uuid + assert retrieved.source_node_uuid == community_node_1.uuid + assert retrieved.target_node_uuid == community_node_2.uuid + assert retrieved.created_at == now + assert retrieved.group_id == group_id + + retrieved = await CommunityEdge.get_by_uuids(graph_driver, [community_edge.uuid]) + assert len(retrieved) == 1 + assert retrieved[0].uuid == community_edge.uuid + assert retrieved[0].source_node_uuid == community_node_1.uuid + assert retrieved[0].target_node_uuid == community_node_2.uuid + assert retrieved[0].created_at == now + assert retrieved[0].group_id == group_id + + retrieved = await CommunityEdge.get_by_group_ids(graph_driver, [group_id], limit=1) + assert len(retrieved) == 1 + assert retrieved[0].uuid == community_edge.uuid + assert retrieved[0].source_node_uuid == community_node_1.uuid + assert retrieved[0].target_node_uuid == community_node_2.uuid + assert retrieved[0].created_at == now + assert retrieved[0].group_id == group_id + + await community_edge.delete(graph_driver) + edge_count = await get_edge_count(graph_driver, community_edge.uuid) + assert edge_count == 0 + + await alice_node.delete(graph_driver) + node_count = await get_node_count(graph_driver, alice_node.uuid) + assert node_count == 0 + + await community_node_1.delete(graph_driver) + node_count = await get_node_count(graph_driver, community_node_1.uuid) + assert node_count == 0 + + await community_node_2.delete(graph_driver) + node_count = await get_node_count(graph_driver, community_node_2.uuid) + assert node_count == 0 + + await graph_driver.close() + + +async def get_node_count(driver: GraphDriver, uuid: str): + results, _, _ = await driver.execute_query( + """ + MATCH (n {uuid: $uuid}) + RETURN COUNT(n) as count + """, + uuid=uuid, + ) + return int(results[0]['count']) + + +async def get_edge_count(driver: GraphDriver, uuid: str): + results, _, _ = await driver.execute_query( + """ + MATCH (n)-[e {uuid: $uuid}]->(m) + RETURN COUNT(e) as count + UNION ALL + MATCH (n)-[e:RELATES_TO]->(m {uuid: $uuid})-[e2:RELATES_TO]->(m2) + RETURN COUNT(m) as count + """, + uuid=uuid, + ) + return sum(int(result['count']) for result in results) diff --git a/tests/test_entity_exclusion_int.py b/tests/test_entity_exclusion_int.py index 715e5ada..473177b1 100644 --- a/tests/test_entity_exclusion_int.py +++ b/tests/test_entity_exclusion_int.py @@ -14,26 +14,18 @@ See the License for the specific language governing permissions and limitations under the License. 
""" -import os from datetime import datetime, timezone import pytest -from dotenv import load_dotenv from pydantic import BaseModel, Field from graphiti_core.graphiti import Graphiti from graphiti_core.helpers import validate_excluded_entity_types +from tests.helpers_test import drivers, get_driver pytestmark = pytest.mark.integration - pytest_plugins = ('pytest_asyncio',) -load_dotenv() - -NEO4J_URI = os.getenv('NEO4J_URI') -NEO4J_USER = os.getenv('NEO4J_USER') -NEO4J_PASSWORD = os.getenv('NEO4J_PASSWORD') - # Test entity type definitions class Person(BaseModel): @@ -65,9 +57,14 @@ class Location(BaseModel): @pytest.mark.asyncio -async def test_exclude_default_entity_type(): +@pytest.mark.parametrize( + 'driver', + drivers, + ids=drivers, +) +async def test_exclude_default_entity_type(driver): """Test excluding the default 'Entity' type while keeping custom types.""" - graphiti = Graphiti(NEO4J_URI, NEO4J_USER, NEO4J_PASSWORD) + graphiti = Graphiti(graph_driver=get_driver(driver)) try: await graphiti.build_indices_and_constraints() @@ -118,9 +115,14 @@ async def test_exclude_default_entity_type(): @pytest.mark.asyncio -async def test_exclude_specific_custom_types(): +@pytest.mark.parametrize( + 'driver', + drivers, + ids=drivers, +) +async def test_exclude_specific_custom_types(driver): """Test excluding specific custom entity types while keeping others.""" - graphiti = Graphiti(NEO4J_URI, NEO4J_USER, NEO4J_PASSWORD) + graphiti = Graphiti(graph_driver=get_driver(driver)) try: await graphiti.build_indices_and_constraints() @@ -177,9 +179,14 @@ async def test_exclude_specific_custom_types(): @pytest.mark.asyncio -async def test_exclude_all_types(): +@pytest.mark.parametrize( + 'driver', + drivers, + ids=drivers, +) +async def test_exclude_all_types(driver): """Test excluding all entity types (edge case).""" - graphiti = Graphiti(NEO4J_URI, NEO4J_USER, NEO4J_PASSWORD) + graphiti = Graphiti(graph_driver=get_driver(driver)) try: await graphiti.build_indices_and_constraints() @@ -221,9 +228,14 @@ async def test_exclude_all_types(): @pytest.mark.asyncio -async def test_exclude_no_types(): +@pytest.mark.parametrize( + 'driver', + drivers, + ids=drivers, +) +async def test_exclude_no_types(driver): """Test normal behavior when no types are excluded (baseline test).""" - graphiti = Graphiti(NEO4J_URI, NEO4J_USER, NEO4J_PASSWORD) + graphiti = Graphiti(graph_driver=get_driver(driver)) try: await graphiti.build_indices_and_constraints() @@ -299,9 +311,14 @@ def test_validation_invalid_excluded_types(): @pytest.mark.asyncio -async def test_excluded_types_parameter_validation_in_add_episode(): +@pytest.mark.parametrize( + 'driver', + drivers, + ids=drivers, +) +async def test_excluded_types_parameter_validation_in_add_episode(driver): """Test that add_episode validates excluded_entity_types parameter.""" - graphiti = Graphiti(NEO4J_URI, NEO4J_USER, NEO4J_PASSWORD) + graphiti = Graphiti(graph_driver=get_driver(driver)) try: entity_types = { diff --git a/tests/test_graphiti_falkordb_int.py b/tests/test_graphiti_falkordb_int.py deleted file mode 100644 index c41ef47e..00000000 --- a/tests/test_graphiti_falkordb_int.py +++ /dev/null @@ -1,164 +0,0 @@ -""" -Copyright 2024, Zep Software, Inc. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -""" - -import logging -import os -import sys -import unittest -from datetime import datetime, timezone - -import pytest -from dotenv import load_dotenv - -from graphiti_core.edges import EntityEdge, EpisodicEdge -from graphiti_core.graphiti import Graphiti -from graphiti_core.helpers import semaphore_gather -from graphiti_core.nodes import EntityNode, EpisodicNode -from graphiti_core.search.search_helpers import search_results_to_context_string - -try: - from graphiti_core.driver.falkordb_driver import FalkorDriver - - HAS_FALKORDB = True -except ImportError: - FalkorDriver = None - HAS_FALKORDB = False - -pytestmark = pytest.mark.integration - -pytest_plugins = ('pytest_asyncio',) - -load_dotenv() - -FALKORDB_HOST = os.getenv('FALKORDB_HOST', 'localhost') -FALKORDB_PORT = os.getenv('FALKORDB_PORT', '6379') -FALKORDB_USER = os.getenv('FALKORDB_USER', None) -FALKORDB_PASSWORD = os.getenv('FALKORDB_PASSWORD', None) - - -def setup_logging(): - # Create a logger - logger = logging.getLogger() - logger.setLevel(logging.INFO) # Set the logging level to INFO - - # Create console handler and set level to INFO - console_handler = logging.StreamHandler(sys.stdout) - console_handler.setLevel(logging.INFO) - - # Create formatter - formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') - - # Add formatter to console handler - console_handler.setFormatter(formatter) - - # Add console handler to logger - logger.addHandler(console_handler) - - return logger - - -@pytest.mark.asyncio -@unittest.skipIf(not HAS_FALKORDB, 'FalkorDB is not installed') -async def test_graphiti_falkordb_init(): - logger = setup_logging() - - falkor_driver = FalkorDriver( - host=FALKORDB_HOST, port=FALKORDB_PORT, username=FALKORDB_USER, password=FALKORDB_PASSWORD - ) - - graphiti = Graphiti(graph_driver=falkor_driver) - - results = await graphiti.search_(query='Who is the user?') - - pretty_results = search_results_to_context_string(results) - - logger.info(pretty_results) - - await graphiti.close() - - -@pytest.mark.asyncio -@unittest.skipIf(not HAS_FALKORDB, 'FalkorDB is not installed') -async def test_graph_falkordb_integration(): - falkor_driver = FalkorDriver( - host=FALKORDB_HOST, port=FALKORDB_PORT, username=FALKORDB_USER, password=FALKORDB_PASSWORD - ) - - client = Graphiti(graph_driver=falkor_driver) - embedder = client.embedder - driver = client.driver - - now = datetime.now(timezone.utc) - episode = EpisodicNode( - name='test_episode', - labels=[], - created_at=now, - valid_at=now, - source='message', - source_description='conversation message', - content='Alice likes Bob', - entity_edges=[], - ) - - alice_node = EntityNode( - name='Alice', - labels=[], - created_at=now, - summary='Alice summary', - ) - - bob_node = EntityNode(name='Bob', labels=[], created_at=now, summary='Bob summary') - - episodic_edge_1 = EpisodicEdge( - source_node_uuid=episode.uuid, target_node_uuid=alice_node.uuid, created_at=now - ) - - episodic_edge_2 = EpisodicEdge( - source_node_uuid=episode.uuid, target_node_uuid=bob_node.uuid, created_at=now - ) - - entity_edge = EntityEdge( - source_node_uuid=alice_node.uuid, - 
target_node_uuid=bob_node.uuid, - created_at=now, - name='likes', - fact='Alice likes Bob', - episodes=[], - expired_at=now, - valid_at=now, - invalid_at=now, - ) - - await entity_edge.generate_embedding(embedder) - - nodes = [episode, alice_node, bob_node] - edges = [episodic_edge_1, episodic_edge_2, entity_edge] - - # test save - await semaphore_gather(*[node.save(driver) for node in nodes]) - await semaphore_gather(*[edge.save(driver) for edge in edges]) - - # test get - assert await EpisodicNode.get_by_uuid(driver, episode.uuid) is not None - assert await EntityNode.get_by_uuid(driver, alice_node.uuid) is not None - assert await EpisodicEdge.get_by_uuid(driver, episodic_edge_1.uuid) is not None - assert await EntityEdge.get_by_uuid(driver, entity_edge.uuid) is not None - - # test delete - await semaphore_gather(*[node.delete(driver) for node in nodes]) - await semaphore_gather(*[edge.delete(driver) for edge in edges]) - - await client.close() diff --git a/tests/test_graphiti_int.py b/tests/test_graphiti_int.py index acec74cb..237e2bfc 100644 --- a/tests/test_graphiti_int.py +++ b/tests/test_graphiti_int.py @@ -15,31 +15,19 @@ limitations under the License. """ import logging -import os import sys -from datetime import datetime, timezone import pytest -from dotenv import load_dotenv -from graphiti_core.edges import EntityEdge, EpisodicEdge from graphiti_core.graphiti import Graphiti -from graphiti_core.helpers import semaphore_gather -from graphiti_core.nodes import EntityNode, EpisodicNode from graphiti_core.search.search_filters import ComparisonOperator, DateFilter, SearchFilters from graphiti_core.search.search_helpers import search_results_to_context_string from graphiti_core.utils.datetime_utils import utc_now +from tests.helpers_test import drivers, get_driver pytestmark = pytest.mark.integration - pytest_plugins = ('pytest_asyncio',) -load_dotenv() - -NEO4J_URI = os.getenv('NEO4J_URI') -NEO4j_USER = os.getenv('NEO4J_USER') -NEO4j_PASSWORD = os.getenv('NEO4J_PASSWORD') - def setup_logging(): # Create a logger @@ -63,11 +51,21 @@ def setup_logging(): @pytest.mark.asyncio -async def test_graphiti_init(): +@pytest.mark.parametrize( + 'driver', + drivers, + ids=drivers, +) +async def test_graphiti_init(driver): logger = setup_logging() - graphiti = Graphiti(NEO4J_URI, NEO4j_USER, NEO4j_PASSWORD) + driver = get_driver(driver) + graphiti = Graphiti(graph_driver=driver) + + await graphiti.build_indices_and_constraints() + search_filter = SearchFilters( - created_at=[[DateFilter(date=utc_now(), comparison_operator=ComparisonOperator.less_than)]] + node_labels=['Person'], + created_at=[[DateFilter(date=utc_now(), comparison_operator=ComparisonOperator.less_than)]], ) results = await graphiti.search_( @@ -76,74 +74,6 @@ async def test_graphiti_init(): ) pretty_results = search_results_to_context_string(results) - logger.info(pretty_results) await graphiti.close() - - -@pytest.mark.asyncio -async def test_graph_integration(): - client = Graphiti(NEO4J_URI, NEO4j_USER, NEO4j_PASSWORD) - embedder = client.embedder - driver = client.driver - - now = datetime.now(timezone.utc) - episode = EpisodicNode( - name='test_episode', - labels=[], - created_at=now, - valid_at=now, - source='message', - source_description='conversation message', - content='Alice likes Bob', - entity_edges=[], - ) - - alice_node = EntityNode( - name='Alice', - labels=[], - created_at=now, - summary='Alice summary', - ) - - bob_node = EntityNode(name='Bob', labels=[], created_at=now, summary='Bob summary') - - 
episodic_edge_1 = EpisodicEdge( - source_node_uuid=episode.uuid, target_node_uuid=alice_node.uuid, created_at=now - ) - - episodic_edge_2 = EpisodicEdge( - source_node_uuid=episode.uuid, target_node_uuid=bob_node.uuid, created_at=now - ) - - entity_edge = EntityEdge( - source_node_uuid=alice_node.uuid, - target_node_uuid=bob_node.uuid, - created_at=now, - name='likes', - fact='Alice likes Bob', - episodes=[], - expired_at=now, - valid_at=now, - invalid_at=now, - ) - - await entity_edge.generate_embedding(embedder) - - nodes = [episode, alice_node, bob_node] - edges = [episodic_edge_1, episodic_edge_2, entity_edge] - - # test save - await semaphore_gather(*[node.save(driver) for node in nodes]) - await semaphore_gather(*[edge.save(driver) for edge in edges]) - - # test get - assert await EpisodicNode.get_by_uuid(driver, episode.uuid) is not None - assert await EntityNode.get_by_uuid(driver, alice_node.uuid) is not None - assert await EpisodicEdge.get_by_uuid(driver, episodic_edge_1.uuid) is not None - assert await EntityEdge.get_by_uuid(driver, entity_edge.uuid) is not None - - # test delete - await semaphore_gather(*[node.delete(driver) for node in nodes]) - await semaphore_gather(*[edge.delete(driver) for edge in edges]) diff --git a/tests/test_node_falkordb_int.py b/tests/test_node_falkordb_int.py deleted file mode 100644 index 35eaa578..00000000 --- a/tests/test_node_falkordb_int.py +++ /dev/null @@ -1,139 +0,0 @@ -""" -Copyright 2024, Zep Software, Inc. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
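The per-backend round-trip tests deleted above (the FalkorDB-specific file and the Neo4j-only test_graph_integration) could be re-expressed once in the new parametrized style. A rough sketch under that assumption, reusing only constructs that appear in the deleted tests; this is a hypothetical test, not part of this change:

from datetime import datetime, timezone

import pytest

from graphiti_core.edges import EntityEdge
from graphiti_core.graphiti import Graphiti
from graphiti_core.nodes import EntityNode
from tests.helpers_test import drivers, get_driver


@pytest.mark.asyncio
@pytest.mark.parametrize('driver', drivers, ids=drivers)
async def test_entity_edge_round_trip(driver):
    # Hypothetical consolidation of the deleted per-backend round-trip tests.
    client = Graphiti(graph_driver=get_driver(driver))
    embedder = client.embedder
    driver = client.driver

    now = datetime.now(timezone.utc)
    alice = EntityNode(name='Alice', labels=[], created_at=now, summary='Alice summary')
    bob = EntityNode(name='Bob', labels=[], created_at=now, summary='Bob summary')
    edge = EntityEdge(
        source_node_uuid=alice.uuid,
        target_node_uuid=bob.uuid,
        created_at=now,
        name='likes',
        fact='Alice likes Bob',
        episodes=[],
        valid_at=now,
    )
    await edge.generate_embedding(embedder)

    # Save the nodes first, then the edge; read the edge back; clean up.
    for item in (alice, bob, edge):
        await item.save(driver)
    assert await EntityEdge.get_by_uuid(driver, edge.uuid) is not None
    for item in (alice, bob, edge):
        await item.delete(driver)

    await client.close()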
-""" - -import os -import unittest -from datetime import datetime, timezone -from uuid import uuid4 - -import pytest - -from graphiti_core.nodes import ( - CommunityNode, - EntityNode, - EpisodeType, - EpisodicNode, -) - -FALKORDB_HOST = os.getenv('FALKORDB_HOST', 'localhost') -FALKORDB_PORT = os.getenv('FALKORDB_PORT', '6379') -FALKORDB_USER = os.getenv('FALKORDB_USER', None) -FALKORDB_PASSWORD = os.getenv('FALKORDB_PASSWORD', None) - -try: - from graphiti_core.driver.falkordb_driver import FalkorDriver - - HAS_FALKORDB = True -except ImportError: - FalkorDriver = None - HAS_FALKORDB = False - - -@pytest.fixture -def sample_entity_node(): - return EntityNode( - uuid=str(uuid4()), - name='Test Entity', - group_id='test_group', - labels=['Entity'], - name_embedding=[0.5] * 1024, - summary='Entity Summary', - ) - - -@pytest.fixture -def sample_episodic_node(): - return EpisodicNode( - uuid=str(uuid4()), - name='Episode 1', - group_id='test_group', - source=EpisodeType.text, - source_description='Test source', - content='Some content here', - valid_at=datetime.now(timezone.utc), - ) - - -@pytest.fixture -def sample_community_node(): - return CommunityNode( - uuid=str(uuid4()), - name='Community A', - name_embedding=[0.5] * 1024, - group_id='test_group', - summary='Community summary', - ) - - -@pytest.mark.asyncio -@pytest.mark.integration -@unittest.skipIf(not HAS_FALKORDB, 'FalkorDB is not installed') -async def test_entity_node_save_get_and_delete(sample_entity_node): - falkor_driver = FalkorDriver( - host=FALKORDB_HOST, port=FALKORDB_PORT, username=FALKORDB_USER, password=FALKORDB_PASSWORD - ) - - await sample_entity_node.save(falkor_driver) - - retrieved = await EntityNode.get_by_uuid(falkor_driver, sample_entity_node.uuid) - assert retrieved.uuid == sample_entity_node.uuid - assert retrieved.name == 'Test Entity' - assert retrieved.group_id == 'test_group' - - await sample_entity_node.delete(falkor_driver) - await falkor_driver.close() - - -@pytest.mark.asyncio -@pytest.mark.integration -@unittest.skipIf(not HAS_FALKORDB, 'FalkorDB is not installed') -async def test_community_node_save_get_and_delete(sample_community_node): - falkor_driver = FalkorDriver( - host=FALKORDB_HOST, port=FALKORDB_PORT, username=FALKORDB_USER, password=FALKORDB_PASSWORD - ) - - await sample_community_node.save(falkor_driver) - - retrieved = await CommunityNode.get_by_uuid(falkor_driver, sample_community_node.uuid) - assert retrieved.uuid == sample_community_node.uuid - assert retrieved.name == 'Community A' - assert retrieved.group_id == 'test_group' - assert retrieved.summary == 'Community summary' - - await sample_community_node.delete(falkor_driver) - await falkor_driver.close() - - -@pytest.mark.asyncio -@pytest.mark.integration -@unittest.skipIf(not HAS_FALKORDB, 'FalkorDB is not installed') -async def test_episodic_node_save_get_and_delete(sample_episodic_node): - falkor_driver = FalkorDriver( - host=FALKORDB_HOST, port=FALKORDB_PORT, username=FALKORDB_USER, password=FALKORDB_PASSWORD - ) - - await sample_episodic_node.save(falkor_driver) - - retrieved = await EpisodicNode.get_by_uuid(falkor_driver, sample_episodic_node.uuid) - assert retrieved.uuid == sample_episodic_node.uuid - assert retrieved.name == 'Episode 1' - assert retrieved.group_id == 'test_group' - assert retrieved.source == EpisodeType.text - assert retrieved.source_description == 'Test source' - assert retrieved.content == 'Some content here' - - await sample_episodic_node.delete(falkor_driver) - await falkor_driver.close() diff --git 
a/tests/test_node_int.py b/tests/test_node_int.py index 9f50f18a..b4aa0709 100644 --- a/tests/test_node_int.py +++ b/tests/test_node_int.py @@ -14,23 +14,22 @@ See the License for the specific language governing permissions and limitations under the License. """ -import os -from datetime import datetime, timezone +from datetime import datetime from uuid import uuid4 +import numpy as np import pytest -from neo4j import AsyncGraphDatabase +from graphiti_core.driver.driver import GraphDriver from graphiti_core.nodes import ( CommunityNode, EntityNode, EpisodeType, EpisodicNode, ) +from tests.helpers_test import drivers, get_driver -NEO4J_URI = os.getenv('NEO4J_URI', 'bolt://localhost:7687') -NEO4J_USER = os.getenv('NEO4J_USER', 'neo4j') -NEO4J_PASSWORD = os.getenv('NEO4J_PASSWORD', 'test') +group_id = f'test_group_{str(uuid4())}' @pytest.fixture @@ -38,8 +37,8 @@ def sample_entity_node(): return EntityNode( uuid=str(uuid4()), name='Test Entity', - group_id='test_group', - labels=['Entity'], + group_id=group_id, + labels=[], name_embedding=[0.5] * 1024, summary='Entity Summary', ) @@ -50,11 +49,11 @@ def sample_episodic_node(): return EpisodicNode( uuid=str(uuid4()), name='Episode 1', - group_id='test_group', + group_id=group_id, source=EpisodeType.text, source_description='Test source', content='Some content here', - valid_at=datetime.now(timezone.utc), + valid_at=datetime.now(), ) @@ -64,59 +63,181 @@ def sample_community_node(): uuid=str(uuid4()), name='Community A', name_embedding=[0.5] * 1024, - group_id='test_group', + group_id=group_id, summary='Community summary', ) @pytest.mark.asyncio -@pytest.mark.integration -async def test_entity_node_save_get_and_delete(sample_entity_node): - neo4j_driver = AsyncGraphDatabase.driver(NEO4J_URI, auth=(NEO4J_USER, NEO4J_PASSWORD)) - await sample_entity_node.save(neo4j_driver) - retrieved = await EntityNode.get_by_uuid(neo4j_driver, sample_entity_node.uuid) +@pytest.mark.parametrize( + 'driver', + drivers, + ids=drivers, +) +async def test_entity_node(sample_entity_node, driver): + driver = get_driver(driver) + uuid = sample_entity_node.uuid + + # Create node + node_count = await get_node_count(driver, uuid) + assert node_count == 0 + await sample_entity_node.save(driver) + node_count = await get_node_count(driver, uuid) + assert node_count == 1 + + retrieved = await EntityNode.get_by_uuid(driver, sample_entity_node.uuid) assert retrieved.uuid == sample_entity_node.uuid assert retrieved.name == 'Test Entity' - assert retrieved.group_id == 'test_group' + assert retrieved.group_id == group_id - await sample_entity_node.delete(neo4j_driver) + retrieved = await EntityNode.get_by_uuids(driver, [sample_entity_node.uuid]) + assert retrieved[0].uuid == sample_entity_node.uuid + assert retrieved[0].name == 'Test Entity' + assert retrieved[0].group_id == group_id - await neo4j_driver.close() + retrieved = await EntityNode.get_by_group_ids(driver, [group_id], limit=2) + assert len(retrieved) == 1 + assert retrieved[0].uuid == sample_entity_node.uuid + assert retrieved[0].name == 'Test Entity' + assert retrieved[0].group_id == group_id + + await sample_entity_node.load_name_embedding(driver) + assert np.allclose(sample_entity_node.name_embedding, [0.5] * 1024) + + # Delete node by uuid + await sample_entity_node.delete(driver) + node_count = await get_node_count(driver, uuid) + assert node_count == 0 + + # Delete node by group id + await sample_entity_node.save(driver) + node_count = await get_node_count(driver, uuid) + assert node_count == 1 + await 
sample_entity_node.delete_by_group_id(driver, group_id) + node_count = await get_node_count(driver, uuid) + assert node_count == 0 + + await driver.close() @pytest.mark.asyncio -@pytest.mark.integration -async def test_community_node_save_get_and_delete(sample_community_node): - neo4j_driver = AsyncGraphDatabase.driver(NEO4J_URI, auth=(NEO4J_USER, NEO4J_PASSWORD)) +@pytest.mark.parametrize( + 'driver', + drivers, + ids=drivers, +) +async def test_community_node(sample_community_node, driver): + driver = get_driver(driver) + uuid = sample_community_node.uuid - await sample_community_node.save(neo4j_driver) + # Create node + node_count = await get_node_count(driver, uuid) + assert node_count == 0 + await sample_community_node.save(driver) + node_count = await get_node_count(driver, uuid) + assert node_count == 1 - retrieved = await CommunityNode.get_by_uuid(neo4j_driver, sample_community_node.uuid) + retrieved = await CommunityNode.get_by_uuid(driver, sample_community_node.uuid) assert retrieved.uuid == sample_community_node.uuid assert retrieved.name == 'Community A' - assert retrieved.group_id == 'test_group' + assert retrieved.group_id == group_id assert retrieved.summary == 'Community summary' - await sample_community_node.delete(neo4j_driver) + retrieved = await CommunityNode.get_by_uuids(driver, [sample_community_node.uuid]) + assert retrieved[0].uuid == sample_community_node.uuid + assert retrieved[0].name == 'Community A' + assert retrieved[0].group_id == group_id + assert retrieved[0].summary == 'Community summary' - await neo4j_driver.close() + retrieved = await CommunityNode.get_by_group_ids(driver, [group_id], limit=2) + assert len(retrieved) == 1 + assert retrieved[0].uuid == sample_community_node.uuid + assert retrieved[0].name == 'Community A' + assert retrieved[0].group_id == group_id + + # Delete node by uuid + await sample_community_node.delete(driver) + node_count = await get_node_count(driver, uuid) + assert node_count == 0 + + # Delete node by group id + await sample_community_node.save(driver) + node_count = await get_node_count(driver, uuid) + assert node_count == 1 + await sample_community_node.delete_by_group_id(driver, group_id) + node_count = await get_node_count(driver, uuid) + assert node_count == 0 + + await driver.close() @pytest.mark.asyncio -@pytest.mark.integration -async def test_episodic_node_save_get_and_delete(sample_episodic_node): - neo4j_driver = AsyncGraphDatabase.driver(NEO4J_URI, auth=(NEO4J_USER, NEO4J_PASSWORD)) +@pytest.mark.parametrize( + 'driver', + drivers, + ids=drivers, +) +async def test_episodic_node(sample_episodic_node, driver): + driver = get_driver(driver) + uuid = sample_episodic_node.uuid - await sample_episodic_node.save(neo4j_driver) + # Create node + node_count = await get_node_count(driver, uuid) + assert node_count == 0 + await sample_episodic_node.save(driver) + node_count = await get_node_count(driver, uuid) + assert node_count == 1 - retrieved = await EpisodicNode.get_by_uuid(neo4j_driver, sample_episodic_node.uuid) + retrieved = await EpisodicNode.get_by_uuid(driver, sample_episodic_node.uuid) assert retrieved.uuid == sample_episodic_node.uuid assert retrieved.name == 'Episode 1' - assert retrieved.group_id == 'test_group' + assert retrieved.group_id == group_id assert retrieved.source == EpisodeType.text assert retrieved.source_description == 'Test source' assert retrieved.content == 'Some content here' + assert retrieved.valid_at == sample_episodic_node.valid_at - await sample_episodic_node.delete(neo4j_driver) + 
retrieved = await EpisodicNode.get_by_uuids(driver, [sample_episodic_node.uuid]) + assert retrieved[0].uuid == sample_episodic_node.uuid + assert retrieved[0].name == 'Episode 1' + assert retrieved[0].group_id == group_id + assert retrieved[0].source == EpisodeType.text + assert retrieved[0].source_description == 'Test source' + assert retrieved[0].content == 'Some content here' + assert retrieved[0].valid_at == sample_episodic_node.valid_at - await neo4j_driver.close() + retrieved = await EpisodicNode.get_by_group_ids(driver, [group_id], limit=2) + assert len(retrieved) == 1 + assert retrieved[0].uuid == sample_episodic_node.uuid + assert retrieved[0].name == 'Episode 1' + assert retrieved[0].group_id == group_id + assert retrieved[0].source == EpisodeType.text + assert retrieved[0].source_description == 'Test source' + assert retrieved[0].content == 'Some content here' + assert retrieved[0].valid_at == sample_episodic_node.valid_at + + # Delete node by uuid + await sample_episodic_node.delete(driver) + node_count = await get_node_count(driver, uuid) + assert node_count == 0 + + # Delete node by group id + await sample_episodic_node.save(driver) + node_count = await get_node_count(driver, uuid) + assert node_count == 1 + await sample_episodic_node.delete_by_group_id(driver, group_id) + node_count = await get_node_count(driver, uuid) + assert node_count == 0 + + await driver.close() + + +async def get_node_count(driver: GraphDriver, uuid: str): + result, _, _ = await driver.execute_query( + """ + MATCH (n {uuid: $uuid}) + RETURN COUNT(n) as count + """, + uuid=uuid, + ) + return int(result[0]['count']) diff --git a/uv.lock b/uv.lock index 832bec3b..cc771368 100644 --- a/uv.lock +++ b/uv.lock @@ -746,7 +746,7 @@ wheels = [ [[package]] name = "graphiti-core" -version = "0.18.0" +version = "0.18.2" source = { editable = "." } dependencies = [ { name = "diskcache" },
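The get_node_count helper added to tests/test_node_int.py above keeps the node tests backend-agnostic by counting through the shared driver.execute_query interface rather than a Neo4j-specific session. If the edge tests adopt the same pattern, a matching helper might look like the following sketch (hypothetical, not part of this diff):

async def get_edge_count(driver: GraphDriver, uuid: str) -> int:
    # Count relationships by uuid, mirroring get_node_count above.
    result, _, _ = await driver.execute_query(
        """
        MATCH ()-[e {uuid: $uuid}]->()
        RETURN COUNT(e) AS count
        """,
        uuid=uuid,
    )
    return int(result[0]['count'])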