diff --git a/examples/podcast/podcast_runner.py b/examples/podcast/podcast_runner.py
index f100926e..1de7eeed 100644
--- a/examples/podcast/podcast_runner.py
+++ b/examples/podcast/podcast_runner.py
@@ -86,4 +86,4 @@ async def main(use_bulk: bool = True):
         await client.add_episode_bulk(episodes)
 
 
-asyncio.run(main(False))
+asyncio.run(main(True))
diff --git a/graphiti_core/graphiti.py b/graphiti_core/graphiti.py
index 2a23d79c..3f4dcfcb 100644
--- a/graphiti_core/graphiti.py
+++ b/graphiti_core/graphiti.py
@@ -180,9 +180,9 @@ class Graphiti:
         await build_indices_and_constraints(self.driver)
 
     async def retrieve_episodes(
-            self,
-            reference_time: datetime,
-            last_n: int = EPISODE_WINDOW_LEN,
+        self,
+        reference_time: datetime,
+        last_n: int = EPISODE_WINDOW_LEN,
     ) -> list[EpisodicNode]:
         """
         Retrieve the last n episodic nodes from the graph.
@@ -210,14 +210,14 @@ class Graphiti:
         return await retrieve_episodes(self.driver, reference_time, last_n)
 
     async def add_episode(
-            self,
-            name: str,
-            episode_body: str,
-            source_description: str,
-            reference_time: datetime,
-            source: EpisodeType = EpisodeType.message,
-            success_callback: Callable | None = None,
-            error_callback: Callable | None = None,
+        self,
+        name: str,
+        episode_body: str,
+        source_description: str,
+        reference_time: datetime,
+        source: EpisodeType = EpisodeType.message,
+        success_callback: Callable | None = None,
+        error_callback: Callable | None = None,
     ):
         """
         Process an episode and update the graph.
@@ -321,11 +321,11 @@ class Graphiti:
                 await asyncio.gather(
                     *[
                         get_relevant_edges(
-                            [edge],
                             self.driver,
-                            RELEVANT_SCHEMA_LIMIT,
+                            [edge],
                             edge.source_node_uuid,
                             edge.target_node_uuid,
+                            RELEVANT_SCHEMA_LIMIT,
                         )
                         for edge in extracted_edges
                     ]
@@ -422,8 +422,8 @@ class Graphiti:
             raise e
 
     async def add_episode_bulk(
-            self,
-            bulk_episodes: list[RawEpisode],
+        self,
+        bulk_episodes: list[RawEpisode],
     ):
         """
         Process multiple episodes in bulk and update the graph.
@@ -587,18 +587,18 @@ class Graphiti:
         return edges
 
     async def _search(
-            self,
-            query: str,
-            timestamp: datetime,
-            config: SearchConfig,
-            center_node_uuid: str | None = None,
+        self,
+        query: str,
+        timestamp: datetime,
+        config: SearchConfig,
+        center_node_uuid: str | None = None,
     ):
         return await hybrid_search(
             self.driver, self.llm_client.get_embedder(), query, timestamp, config, center_node_uuid
         )
 
     async def get_nodes_by_query(
-            self, query: str, limit: int = RELEVANT_SCHEMA_LIMIT
+        self, query: str, limit: int = RELEVANT_SCHEMA_LIMIT
     ) -> list[EntityNode]:
         """
         Retrieve nodes from the graph database based on a text query.
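Note: with the call-site change above, add_episode now scopes each relevance lookup to the extracted edge's own endpoints instead of issuing unconstrained searches. A minimal, dependency-free sketch of that fan-out pattern, assuming nothing beyond the standard library (Edge, find_candidates, and gather_candidates are illustrative stand-ins, not library API):

import asyncio
from dataclasses import dataclass


@dataclass
class Edge:
    source_node_uuid: str
    target_node_uuid: str
    fact: str


async def find_candidates(edge: Edge, limit: int) -> list[str]:
    # Stand-in for edge_similarity_search / edge_fulltext_search: the real
    # functions are now handed the edge's endpoint UUIDs so the graph query
    # only matches relationships between those two nodes.
    return []


async def gather_candidates(edges: list[Edge], limit: int = 10) -> list[list[str]]:
    # One endpoint-constrained search per extracted edge, run concurrently,
    # mirroring the asyncio.gather call in add_episode above.
    return await asyncio.gather(*(find_candidates(e, limit) for e in edges))


print(asyncio.run(gather_candidates([Edge('u1', 'u2', 'Alice knows Bob')])))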
diff --git a/graphiti_core/search/search.py b/graphiti_core/search/search.py
index f27249c4..172cd123 100644
--- a/graphiti_core/search/search.py
+++ b/graphiti_core/search/search.py
@@ -83,7 +83,7 @@ async def hybrid_search(
         nodes.extend(await get_mentioned_nodes(driver, episodes))
 
     if SearchMethod.bm25 in config.search_methods:
-        text_search = await edge_fulltext_search(driver, query, 2 * config.num_edges)
+        text_search = await edge_fulltext_search(driver, query, None, None, 2 * config.num_edges)
         search_results.append(text_search)
 
     if SearchMethod.cosine_similarity in config.search_methods:
@@ -95,7 +95,7 @@
         )
 
         similarity_search = await edge_similarity_search(
-            driver, search_vector, 2 * config.num_edges
+            driver, search_vector, None, None, 2 * config.num_edges
         )
         search_results.append(similarity_search)
 
diff --git a/graphiti_core/search/search_utils.py b/graphiti_core/search/search_utils.py
index 0cfdb601..22235cbc 100644
--- a/graphiti_core/search/search_utils.py
+++ b/graphiti_core/search/search_utils.py
@@ -1,11 +1,11 @@
 import asyncio
 import logging
 import re
-import typing
 from collections import defaultdict
 from time import time
+from typing import Any
 
-from neo4j import AsyncDriver
+from neo4j import AsyncDriver, Query
 
 from graphiti_core.edges import EntityEdge
 from graphiti_core.helpers import parse_db_date
@@ -66,12 +66,12 @@ async def bfs(node_ids: list[str], driver: AsyncDriver):
             r.expired_at AS expired_at,
             r.valid_at AS valid_at,
             r.invalid_at AS invalid_at
-        
+
         """,
         node_ids=node_ids,
     )
 
-    context: dict[str, typing.Any] = {}
+    context: dict[str, Any] = {}
 
     for record in records:
         n_uuid = record['source_node_uuid']
@@ -96,15 +96,14 @@
 
 
 async def edge_similarity_search(
-        driver: AsyncDriver,
-        search_vector: list[float],
-        limit: int = RELEVANT_SCHEMA_LIMIT,
-        source_node_uuid: str = '*',
-        target_node_uuid: str = '*',
+    driver: AsyncDriver,
+    search_vector: list[float],
+    source_node_uuid: str | None,
+    target_node_uuid: str | None,
+    limit: int = RELEVANT_SCHEMA_LIMIT,
 ) -> list[EntityEdge]:
     # vector similarity search over embedded facts
-    records, _, _ = await driver.execute_query(
-        """
+    query = Query("""
     CALL db.index.vector.queryRelationships("fact_embedding", $limit, $search_vector)
     YIELD relationship AS rel, score
     MATCH (n:Entity {uuid: $source_uuid})-[r {uuid: rel.uuid}]-(m:Entity {uuid: $target_uuid})
@@ -121,7 +120,68 @@
         r.valid_at AS valid_at,
         r.invalid_at AS invalid_at
     ORDER BY score DESC
-    """,
+    """)
+
+    if source_node_uuid is None and target_node_uuid is None:
+        query = Query("""
+        CALL db.index.vector.queryRelationships("fact_embedding", $limit, $search_vector)
+        YIELD relationship AS rel, score
+        MATCH (n:Entity)-[r {uuid: rel.uuid}]-(m:Entity)
+        RETURN
+            r.uuid AS uuid,
+            n.uuid AS source_node_uuid,
+            m.uuid AS target_node_uuid,
+            r.created_at AS created_at,
+            r.name AS name,
+            r.fact AS fact,
+            r.fact_embedding AS fact_embedding,
+            r.episodes AS episodes,
+            r.expired_at AS expired_at,
+            r.valid_at AS valid_at,
+            r.invalid_at AS invalid_at
+        ORDER BY score DESC
+        """)
+    elif source_node_uuid is None:
+        query = Query("""
+        CALL db.index.vector.queryRelationships("fact_embedding", $limit, $search_vector)
+        YIELD relationship AS rel, score
+        MATCH (n:Entity)-[r {uuid: rel.uuid}]-(m:Entity {uuid: $target_uuid})
+        RETURN
+            r.uuid AS uuid,
+            n.uuid AS source_node_uuid,
+            m.uuid AS target_node_uuid,
+            r.created_at AS created_at,
+            r.name AS name,
+            r.fact AS fact,
+            r.fact_embedding AS fact_embedding,
+            r.episodes AS episodes,
+            r.expired_at AS expired_at,
+            r.valid_at AS valid_at,
+            r.invalid_at AS invalid_at
+        ORDER BY score DESC
+        """)
+    elif target_node_uuid is None:
+        query = Query("""
+        CALL db.index.vector.queryRelationships("fact_embedding", $limit, $search_vector)
+        YIELD relationship AS rel, score
+        MATCH (n:Entity {uuid: $source_uuid})-[r {uuid: rel.uuid}]-(m:Entity)
+        RETURN
+            r.uuid AS uuid,
+            n.uuid AS source_node_uuid,
+            m.uuid AS target_node_uuid,
+            r.created_at AS created_at,
+            r.name AS name,
+            r.fact AS fact,
+            r.fact_embedding AS fact_embedding,
+            r.episodes AS episodes,
+            r.expired_at AS expired_at,
+            r.valid_at AS valid_at,
+            r.invalid_at AS invalid_at
+        ORDER BY score DESC
+        """)
+
+    records, _, _ = await driver.execute_query(
+        query,
         search_vector=search_vector,
         source_uuid=source_node_uuid,
         target_uuid=target_node_uuid,
@@ -151,7 +211,7 @@
 
 
 async def entity_similarity_search(
-        search_vector: list[float], driver: AsyncDriver, limit=RELEVANT_SCHEMA_LIMIT
+    search_vector: list[float], driver: AsyncDriver, limit=RELEVANT_SCHEMA_LIMIT
 ) -> list[EntityNode]:
     # vector similarity search over entity names
     records, _, _ = await driver.execute_query(
@@ -161,6 +221,7 @@
     RETURN
         n.uuid As uuid,
         n.name AS name,
+        n.name_embeddings AS name_embedding,
         n.created_at AS created_at,
         n.summary AS summary
     ORDER BY score DESC
@@ -175,6 +236,7 @@
         EntityNode(
             uuid=record['uuid'],
             name=record['name'],
+            name_embedding=record['name_embedding'],
             labels=['Entity'],
             created_at=record['created_at'].to_native(),
             summary=record['summary'],
@@ -185,7 +247,7 @@
 
 
 async def entity_fulltext_search(
-        query: str, driver: AsyncDriver, limit=RELEVANT_SCHEMA_LIMIT
+    query: str, driver: AsyncDriver, limit=RELEVANT_SCHEMA_LIMIT
 ) -> list[EntityNode]:
     # BM25 search to get top nodes
     fuzzy_query = re.sub(r'[^\w\s]', '', query) + '~'
@@ -193,8 +255,9 @@
         """
     CALL db.index.fulltext.queryNodes("name_and_summary", $query) YIELD node, score
     RETURN
-        node.uuid As uuid,
+        node.uuid AS uuid,
         node.name AS name,
+        node.name_embeddings AS name_embedding,
         node.created_at AS created_at,
         node.summary AS summary
     ORDER BY score DESC
@@ -210,6 +273,7 @@
         EntityNode(
             uuid=record['uuid'],
             name=record['name'],
+            name_embedding=record['name_embedding'],
             labels=['Entity'],
             created_at=record['created_at'].to_native(),
             summary=record['summary'],
@@ -220,21 +284,18 @@
 
 
 async def edge_fulltext_search(
-        driver: AsyncDriver,
-        query: str,
-        limit=RELEVANT_SCHEMA_LIMIT,
-        source_node_uuid: str = '*',
-        target_node_uuid: str = '*',
+    driver: AsyncDriver,
+    query: str,
+    source_node_uuid: str | None,
+    target_node_uuid: str | None,
+    limit=RELEVANT_SCHEMA_LIMIT,
 ) -> list[EntityEdge]:
     # fulltext search over facts
-    fuzzy_query = re.sub(r'[^\w\s]', '', query) + '~'
-
-    records, _, _ = await driver.execute_query(
-        """
-    CALL db.index.fulltext.queryRelationships("name_and_fact", $query)
-    YIELD relationship AS rel, score
-    MATCH (n:Entity {uuid: $source_uuid})-[r {uuid: rel.uuid}]-(m:Entity {uuid: $target_uuid})
-    RETURN
+    cypher_query = Query("""
+    CALL db.index.fulltext.queryRelationships("name_and_fact", $query)
+    YIELD relationship AS rel, score
+    MATCH (n:Entity {uuid: $source_uuid})-[r {uuid: rel.uuid}]-(m:Entity {uuid: $target_uuid})
+    RETURN
         r.uuid AS uuid,
         n.uuid AS source_node_uuid,
         m.uuid AS target_node_uuid,
@@ -247,7 +308,70 @@
         r.valid_at AS valid_at,
         r.invalid_at AS invalid_at
     ORDER BY score DESC LIMIT $limit
-    """,
+    """)
+
+    if source_node_uuid is None and target_node_uuid is None:
+        cypher_query = Query("""
+        CALL db.index.fulltext.queryRelationships("name_and_fact", $query)
+        YIELD relationship AS rel, score
+        MATCH (n:Entity)-[r {uuid: rel.uuid}]-(m:Entity)
+        RETURN
+            r.uuid AS uuid,
+            n.uuid AS source_node_uuid,
+            m.uuid AS target_node_uuid,
+            r.created_at AS created_at,
+            r.name AS name,
+            r.fact AS fact,
+            r.fact_embedding AS fact_embedding,
+            r.episodes AS episodes,
+            r.expired_at AS expired_at,
+            r.valid_at AS valid_at,
+            r.invalid_at AS invalid_at
+        ORDER BY score DESC LIMIT $limit
+        """)
+    elif source_node_uuid is None:
+        cypher_query = Query("""
+        CALL db.index.fulltext.queryRelationships("name_and_fact", $query)
+        YIELD relationship AS rel, score
+        MATCH (n:Entity)-[r {uuid: rel.uuid}]-(m:Entity {uuid: $target_uuid})
+        RETURN
+            r.uuid AS uuid,
+            n.uuid AS source_node_uuid,
+            m.uuid AS target_node_uuid,
+            r.created_at AS created_at,
+            r.name AS name,
+            r.fact AS fact,
+            r.fact_embedding AS fact_embedding,
+            r.episodes AS episodes,
+            r.expired_at AS expired_at,
+            r.valid_at AS valid_at,
+            r.invalid_at AS invalid_at
+        ORDER BY score DESC LIMIT $limit
+        """)
+    elif target_node_uuid is None:
+        cypher_query = Query("""
+        CALL db.index.fulltext.queryRelationships("name_and_fact", $query)
+        YIELD relationship AS rel, score
+        MATCH (n:Entity {uuid: $source_uuid})-[r {uuid: rel.uuid}]-(m:Entity)
+        RETURN
+            r.uuid AS uuid,
+            n.uuid AS source_node_uuid,
+            m.uuid AS target_node_uuid,
+            r.created_at AS created_at,
+            r.name AS name,
+            r.fact AS fact,
+            r.fact_embedding AS fact_embedding,
+            r.episodes AS episodes,
+            r.expired_at AS expired_at,
+            r.valid_at AS valid_at,
+            r.invalid_at AS invalid_at
+        ORDER BY score DESC LIMIT $limit
+        """)
+
+    fuzzy_query = re.sub(r'[^\w\s]', '', query) + '~'
+
+    records, _, _ = await driver.execute_query(
+        cypher_query,
         query=fuzzy_query,
         source_uuid=source_node_uuid,
         target_uuid=target_node_uuid,
@@ -277,16 +401,16 @@
 
 
 async def hybrid_node_search(
-        queries: list[str],
-        embeddings: list[list[float]],
-        driver: AsyncDriver,
-        limit: int = RELEVANT_SCHEMA_LIMIT,
+    queries: list[str],
+    embeddings: list[list[float]],
+    driver: AsyncDriver,
+    limit: int = RELEVANT_SCHEMA_LIMIT,
 ) -> list[EntityNode]:
     """
     Perform a hybrid search for nodes using both text queries and embeddings.
 
     This method combines fulltext search and vector similarity search to find
-    relevant nodes in the graph database. It uses an rrf reranker.
+    relevant nodes in the graph database. It uses an RRF reranker.
 
     Parameters
     ----------
@@ -342,8 +466,8 @@ async def hybrid_node_search(
 
 
 async def get_relevant_nodes(
-        nodes: list[EntityNode],
-        driver: AsyncDriver,
+    nodes: list[EntityNode],
+    driver: AsyncDriver,
 ) -> list[EntityNode]:
     """
     Retrieve relevant nodes based on the provided list of EntityNodes.
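The four near-identical Cypher strings in each search function differ only in their MATCH pattern, depending on which endpoint UUIDs are known. As a sketch of the underlying rule (build_match_clause is a hypothetical helper, not part of this patch): an endpoint is only constrained when a UUID was supplied, and a parameter the chosen query never references is harmless to pass as None.

from typing import Optional


def build_match_clause(source_uuid: Optional[str], target_uuid: Optional[str]) -> str:
    # Same branch selection as the four Query(...) variants above: each node
    # pattern is constrained by uuid only when that uuid was actually given.
    n = '(n:Entity {uuid: $source_uuid})' if source_uuid is not None else '(n:Entity)'
    m = '(m:Entity {uuid: $target_uuid})' if target_uuid is not None else '(m:Entity)'
    return f'MATCH {n}-[r {{uuid: rel.uuid}}]-{m}'


assert build_match_clause(None, None) == 'MATCH (n:Entity)-[r {uuid: rel.uuid}]-(m:Entity)'
assert '$source_uuid' in build_match_clause('abc', None)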
@@ -379,11 +503,11 @@ async def get_relevant_nodes(
 
 
 async def get_relevant_edges(
-        edges: list[EntityEdge],
-        driver: AsyncDriver,
-        limit: int = RELEVANT_SCHEMA_LIMIT,
-        source_node_uuid: str = '*',
-        target_node_uuid: str = '*',
+    driver: AsyncDriver,
+    edges: list[EntityEdge],
+    source_node_uuid: str | None,
+    target_node_uuid: str | None,
+    limit: int = RELEVANT_SCHEMA_LIMIT,
 ) -> list[EntityEdge]:
     start = time()
     relevant_edges: list[EntityEdge] = []
@@ -392,13 +516,13 @@ async def get_relevant_edges(
     results = await asyncio.gather(
         *[
             edge_similarity_search(
-                driver, edge.fact_embedding, limit, source_node_uuid, target_node_uuid
+                driver, edge.fact_embedding, source_node_uuid, target_node_uuid, limit
             )
             for edge in edges
             if edge.fact_embedding is not None
         ],
         *[
-            edge_fulltext_search(driver, edge.fact, limit, source_node_uuid, target_node_uuid)
+            edge_fulltext_search(driver, edge.fact, source_node_uuid, target_node_uuid, limit)
             for edge in edges
         ],
     )
@@ -433,14 +557,14 @@ def rrf(results: list[list[str]], rank_const=1) -> list[str]:
 
 
 async def node_distance_reranker(
-        driver: AsyncDriver, results: list[list[str]], center_node_uuid: str
+    driver: AsyncDriver, results: list[list[str]], center_node_uuid: str
 ) -> list[str]:
     # use rrf as a preliminary ranker
     sorted_uuids = rrf(results)
     scores: dict[str, float] = {}
 
     for uuid in sorted_uuids:
-        # Find shortest path to center node
+        # Find the shortest path to center node
         records, _, _ = await driver.execute_query(
             """
         MATCH (source:Entity)-[r:RELATES_TO {uuid: $edge_uuid}]->(target:Entity)
@@ -455,8 +579,8 @@ async def node_distance_reranker(
 
     for record in records:
         if (
-                record['source_uuid'] == center_node_uuid
-                or record['target_uuid'] == center_node_uuid
+            record['source_uuid'] == center_node_uuid
+            or record['target_uuid'] == center_node_uuid
         ):
             continue
         distance = record['score']
diff --git a/graphiti_core/utils/bulk_utils.py b/graphiti_core/utils/bulk_utils.py
index 50c702aa..4c8f12a4 100644
--- a/graphiti_core/utils/bulk_utils.py
+++ b/graphiti_core/utils/bulk_utils.py
@@ -158,7 +158,7 @@ async def dedupe_edges_bulk(
 
     relevant_edges_chunks: list[list[EntityEdge]] = list(
         await asyncio.gather(
-            *[get_relevant_edges(edge_chunk, driver) for edge_chunk in edge_chunks]
+            *[get_relevant_edges(driver, edge_chunk, None, None) for edge_chunk in edge_chunks]
         )
     )
diff --git a/graphiti_core/utils/maintenance/node_operations.py b/graphiti_core/utils/maintenance/node_operations.py
index d1e8f43b..30673eef 100644
--- a/graphiti_core/utils/maintenance/node_operations.py
+++ b/graphiti_core/utils/maintenance/node_operations.py
@@ -28,7 +28,7 @@ logger = logging.getLogger(__name__)
 
 
 async def extract_message_nodes(
-        llm_client: LLMClient, episode: EpisodicNode, previous_episodes: list[EpisodicNode]
+    llm_client: LLMClient, episode: EpisodicNode, previous_episodes: list[EpisodicNode]
 ) -> list[dict[str, Any]]:
     # Prepare context for LLM
     context = {
@@ -49,8 +49,8 @@ async def extract_message_nodes(
 
 
 async def extract_json_nodes(
-        llm_client: LLMClient,
-        episode: EpisodicNode,
+    llm_client: LLMClient,
+    episode: EpisodicNode,
 ) -> list[dict[str, Any]]:
     # Prepare context for LLM
     context = {
@@ -67,9 +67,9 @@ async def extract_json_nodes(
 
 
 async def extract_nodes(
-        llm_client: LLMClient,
-        episode: EpisodicNode,
-        previous_episodes: list[EpisodicNode],
+    llm_client: LLMClient,
+    episode: EpisodicNode,
+    previous_episodes: list[EpisodicNode],
 ) -> list[EntityNode]:
     start = time()
     extracted_node_data: list[dict[str, Any]] = []
@@ -96,9 +96,9 @@ async def extract_nodes(
 
 
 async def dedupe_extracted_nodes(
-        llm_client: LLMClient,
-        extracted_nodes: list[EntityNode],
-        existing_nodes: list[EntityNode],
+    llm_client: LLMClient,
+    extracted_nodes: list[EntityNode],
+    existing_nodes: list[EntityNode],
 ) -> tuple[list[EntityNode], dict[str, str]]:
     start = time()
 
@@ -146,9 +146,9 @@ async def dedupe_extracted_nodes(
 
 
 async def resolve_extracted_nodes(
-        llm_client: LLMClient,
-        extracted_nodes: list[EntityNode],
-        existing_nodes_lists: list[list[EntityNode]],
+    llm_client: LLMClient,
+    extracted_nodes: list[EntityNode],
+    existing_nodes_lists: list[list[EntityNode]],
 ) -> tuple[list[EntityNode], dict[str, str]]:
     uuid_map: dict[str, str] = {}
     resolved_nodes: list[EntityNode] = []
@@ -169,7 +169,7 @@ async def resolve_extracted_nodes(
 
 
 async def resolve_extracted_node(
-        llm_client: LLMClient, extracted_node: EntityNode, existing_nodes: list[EntityNode]
+    llm_client: LLMClient, extracted_node: EntityNode, existing_nodes: list[EntityNode]
 ) -> tuple[EntityNode, dict[str, str]]:
     start = time()
 
@@ -214,8 +214,8 @@ async def resolve_extracted_node(
 
 
 async def dedupe_node_list(
-        llm_client: LLMClient,
-        nodes: list[EntityNode],
+    llm_client: LLMClient,
+    nodes: list[EntityNode],
 ) -> tuple[list[EntityNode], dict[str, str]]:
     start = time()
 
diff --git a/pyproject.toml b/pyproject.toml
index a0cc66a0..81d87afd 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "graphiti-core"
-version = "0.2.0"
+version = "0.2.1"
 description = "A temporal graph building library"
 authors = [
     "Paul Paliychuk ",
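A note for downstream callers: get_relevant_edges swapped its first two positional parameters and moved limit to the end, so 0.2.1 is source-breaking for positional call sites despite the patch-level version bump. A hedged migration sketch (migrate_example is hypothetical; the import path and signature are taken from this diff):

from graphiti_core.search.search_utils import get_relevant_edges


async def migrate_example(driver, edges):
    # graphiti-core 0.2.0 was called with edges first and an optional limit:
    #     relevant = await get_relevant_edges(edges, driver)
    # graphiti-core 0.2.1 takes the driver first and requires explicit
    # endpoint UUIDs, where None means "do not constrain that endpoint":
    relevant = await get_relevant_edges(driver, edges, None, None)
    return relevant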