diff --git a/graphiti_core/cross_encoder/openai_reranker_client.py b/graphiti_core/cross_encoder/openai_reranker_client.py new file mode 100644 index 00000000..d1e37228 --- /dev/null +++ b/graphiti_core/cross_encoder/openai_reranker_client.py @@ -0,0 +1,111 @@ +""" +Copyright 2024, Zep Software, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +import asyncio +import logging +from typing import Any + +import openai +from openai import AsyncOpenAI +from pydantic import BaseModel + +from ..llm_client import LLMConfig, RateLimitError +from ..prompts import Message +from .client import CrossEncoderClient + +logger = logging.getLogger(__name__) + +DEFAULT_MODEL = 'gpt-4o-mini' + + +class BooleanClassifier(BaseModel): + isTrue: bool + + +class OpenAIRerankerClient(CrossEncoderClient): + def __init__(self, config: LLMConfig | None = None): + """ + Initialize the OpenAIRerankerClient with the provided configuration. + + Args: + config (LLMConfig | None): The configuration for the LLM client, including API key, model, base URL, temperature, and max tokens. + + """ + if config is None: + config = LLMConfig() + + self.config = config + self.client = AsyncOpenAI(api_key=config.api_key, base_url=config.base_url) + + async def rank(self, query: str, passages: list[str]) -> list[tuple[str, float]]: + openai_messages_list: Any = [ + [ + Message( + role='system', + content='You are an expert tasked with determining whether the passage is relevant to the query', + ), + Message( + role='user', + content=f""" + Respond with "True" if PASSAGE is relevant to QUERY and "False" otherwise. 
+ <QUERY> + {query} + </QUERY> + <PASSAGE> + {passage} + </PASSAGE> + """, + ), + ] + for passage in passages + ] + try: + responses = await asyncio.gather( + *[ + self.client.chat.completions.create( + model=DEFAULT_MODEL, + messages=openai_messages, + temperature=0, + max_tokens=1, + logit_bias={'6432': 1, '7983': 1}, + logprobs=True, + top_logprobs=2, + ) + for openai_messages in openai_messages_list + ] + ) + + responses_top_logprobs = [ + response.choices[0].logprobs.content[0].top_logprobs + if response.choices[0].logprobs is not None + and response.choices[0].logprobs.content is not None + else [] + for response in responses + ] + scores: list[float] = [] + for top_logprobs in responses_top_logprobs: + for logprob in top_logprobs: + if bool(logprob.token): + scores.append(logprob.logprob) + + results = [(passage, score) for passage, score in zip(passages, scores)] + results.sort(reverse=True, key=lambda x: x[1]) + return results + except openai.RateLimitError as e: + raise RateLimitError from e + except Exception as e: + logger.error(f'Error in generating LLM response: {e}') + raise diff --git a/graphiti_core/graphiti.py b/graphiti_core/graphiti.py index 9f9ec988..c7829c77 100644 --- a/graphiti_core/graphiti.py +++ b/graphiti_core/graphiti.py @@ -23,8 +23,11 @@ from dotenv import load_dotenv from neo4j import AsyncGraphDatabase from pydantic import BaseModel +from graphiti_core.cross_encoder.client import CrossEncoderClient +from graphiti_core.cross_encoder.openai_reranker_client import OpenAIRerankerClient from graphiti_core.edges import EntityEdge, EpisodicEdge from graphiti_core.embedder import EmbedderClient, OpenAIEmbedder +from graphiti_core.helpers import DEFAULT_DATABASE from graphiti_core.llm_client import LLMClient, OpenAIClient from graphiti_core.nodes import CommunityNode, EntityNode, EpisodeType, EpisodicNode from graphiti_core.search.search import SearchConfig, search @@ -92,6 +95,7 @@ class Graphiti: password: str, llm_client: LLMClient | None = None, embedder: EmbedderClient | None = None, + cross_encoder: CrossEncoderClient | None = None, store_raw_episode_content: bool = True, ): """ @@ -131,7 +135,7 @@ class Graphiti: Graphiti if you're using the default OpenAIClient. 
""" self.driver = AsyncGraphDatabase.driver(uri, auth=(user, password)) - self.database = 'neo4j' + self.database = DEFAULT_DATABASE self.store_raw_episode_content = store_raw_episode_content if llm_client: self.llm_client = llm_client @@ -141,6 +145,10 @@ class Graphiti: self.embedder = embedder else: self.embedder = OpenAIEmbedder() + if cross_encoder: + self.cross_encoder = cross_encoder + else: + self.cross_encoder = OpenAIRerankerClient() async def close(self): """ @@ -648,6 +656,7 @@ class Graphiti: await search( self.driver, self.embedder, + self.cross_encoder, query, group_ids, search_config, @@ -663,8 +672,18 @@ class Graphiti: config: SearchConfig, group_ids: list[str] | None = None, center_node_uuid: str | None = None, + bfs_origin_node_uuids: list[str] | None = None, ) -> SearchResults: - return await search(self.driver, self.embedder, query, group_ids, config, center_node_uuid) + return await search( + self.driver, + self.embedder, + self.cross_encoder, + query, + group_ids, + config, + center_node_uuid, + bfs_origin_node_uuids, + ) async def get_nodes_by_query( self, @@ -716,7 +735,13 @@ class Graphiti: nodes = ( await search( - self.driver, self.embedder, query, group_ids, search_config, center_node_uuid + self.driver, + self.embedder, + self.cross_encoder, + query, + group_ids, + search_config, + center_node_uuid, ) ).nodes return nodes diff --git a/graphiti_core/search/search.py b/graphiti_core/search/search.py index 1fa5eb7d..206ed5fd 100644 --- a/graphiti_core/search/search.py +++ b/graphiti_core/search/search.py @@ -21,6 +21,7 @@ from time import time from neo4j import AsyncDriver +from graphiti_core.cross_encoder.client import CrossEncoderClient from graphiti_core.edges import EntityEdge from graphiti_core.embedder import EmbedderClient from graphiti_core.errors import SearchRerankerError @@ -39,6 +40,7 @@ from graphiti_core.search.search_config import ( from graphiti_core.search.search_utils import ( community_fulltext_search, community_similarity_search, + edge_bfs_search, edge_fulltext_search, edge_similarity_search, episode_mentions_reranker, @@ -55,10 +57,12 @@ logger = logging.getLogger(__name__) async def search( driver: AsyncDriver, embedder: EmbedderClient, + cross_encoder: CrossEncoderClient, query: str, group_ids: list[str] | None, config: SearchConfig, center_node_uuid: str | None = None, + bfs_origin_node_uuids: list[str] | None = None, ) -> SearchResults: start = time() query_vector = await embedder.create(input=[query.replace('\n', ' ')]) @@ -68,28 +72,34 @@ async def search( edges, nodes, communities = await asyncio.gather( edge_search( driver, + cross_encoder, query, query_vector, group_ids, config.edge_config, center_node_uuid, + bfs_origin_node_uuids, config.limit, ), node_search( driver, + cross_encoder, query, query_vector, group_ids, config.node_config, center_node_uuid, + bfs_origin_node_uuids, config.limit, ), community_search( driver, + cross_encoder, query, query_vector, group_ids, config.community_config, + bfs_origin_node_uuids, config.limit, ), ) @@ -109,11 +119,13 @@ async def search( async def edge_search( driver: AsyncDriver, + cross_encoder: CrossEncoderClient, query: str, query_vector: list[float], group_ids: list[str] | None, config: EdgeSearchConfig | None, center_node_uuid: str | None = None, + bfs_origin_node_uuids: list[str] | None = None, limit=DEFAULT_SEARCH_LIMIT, ) -> list[EntityEdge]: if config is None: @@ -126,6 +138,7 @@ async def edge_search( edge_similarity_search( driver, query_vector, None, None, group_ids, 2 * limit, 
config.sim_min_score ), + edge_bfs_search(driver, bfs_origin_node_uuids, config.bfs_max_depth), ] ) ) @@ -146,6 +159,10 @@ async def edge_search( reranked_uuids = maximal_marginal_relevance( query_vector, search_result_uuids_and_vectors, config.mmr_lambda ) + elif config.reranker == EdgeReranker.cross_encoder: + fact_to_uuid_map = {edge.fact: edge.uuid for result in search_results for edge in result} + reranked_facts = await cross_encoder.rank(query, list(fact_to_uuid_map.keys())) + reranked_uuids = [fact_to_uuid_map[fact] for fact, _ in reranked_facts] elif config.reranker == EdgeReranker.node_distance: if center_node_uuid is None: raise SearchRerankerError('No center node provided for Node Distance reranker') @@ -176,11 +193,13 @@ async def edge_search( async def node_search( driver: AsyncDriver, + cross_encoder: CrossEncoderClient, query: str, query_vector: list[float], group_ids: list[str] | None, config: NodeSearchConfig | None, center_node_uuid: str | None = None, + bfs_origin_node_uuids: list[str] | None = None, limit=DEFAULT_SEARCH_LIMIT, ) -> list[EntityNode]: if config is None: @@ -212,6 +231,12 @@ async def node_search( reranked_uuids = maximal_marginal_relevance( query_vector, search_result_uuids_and_vectors, config.mmr_lambda ) + elif config.reranker == NodeReranker.cross_encoder: + summary_to_uuid_map = { + node.summary: node.uuid for result in search_results for node in result + } + reranked_summaries = await cross_encoder.rank(query, list(summary_to_uuid_map.keys())) + reranked_uuids = [summary_to_uuid_map[fact] for fact, _ in reranked_summaries] elif config.reranker == NodeReranker.episode_mentions: reranked_uuids = await episode_mentions_reranker(driver, search_result_uuids) elif config.reranker == NodeReranker.node_distance: @@ -228,10 +253,12 @@ async def node_search( async def community_search( driver: AsyncDriver, + cross_encoder: CrossEncoderClient, query: str, query_vector: list[float], group_ids: list[str] | None, config: CommunitySearchConfig | None, + bfs_origin_node_uuids: list[str] | None = None, limit=DEFAULT_SEARCH_LIMIT, ) -> list[CommunityNode]: if config is None: @@ -268,6 +295,12 @@ async def community_search( reranked_uuids = maximal_marginal_relevance( query_vector, search_result_uuids_and_vectors, config.mmr_lambda ) + elif config.reranker == CommunityReranker.cross_encoder: + summary_to_uuid_map = { + node.summary: node.uuid for result in search_results for node in result + } + reranked_summaries = await cross_encoder.rank(query, list(summary_to_uuid_map.keys())) + reranked_uuids = [summary_to_uuid_map[fact] for fact, _ in reranked_summaries] reranked_communities = [community_uuid_map[uuid] for uuid in reranked_uuids] diff --git a/graphiti_core/search/search_config.py b/graphiti_core/search/search_config.py index badee7c6..9aa23daa 100644 --- a/graphiti_core/search/search_config.py +++ b/graphiti_core/search/search_config.py @@ -20,7 +20,11 @@ from pydantic import BaseModel, Field from graphiti_core.edges import EntityEdge from graphiti_core.nodes import CommunityNode, EntityNode -from graphiti_core.search.search_utils import DEFAULT_MIN_SCORE, DEFAULT_MMR_LAMBDA +from graphiti_core.search.search_utils import ( + DEFAULT_MIN_SCORE, + DEFAULT_MMR_LAMBDA, + MAX_SEARCH_DEPTH, +) DEFAULT_SEARCH_LIMIT = 10 @@ -28,11 +32,13 @@ DEFAULT_SEARCH_LIMIT = 10 class EdgeSearchMethod(Enum): cosine_similarity = 'cosine_similarity' bm25 = 'bm25' + bfs = 'breadth_first_search' class NodeSearchMethod(Enum): cosine_similarity = 'cosine_similarity' bm25 = 'bm25' + bfs = 
'breadth_first_search' class CommunitySearchMethod(Enum): @@ -45,6 +51,7 @@ class EdgeReranker(Enum): node_distance = 'node_distance' episode_mentions = 'episode_mentions' mmr = 'mmr' + cross_encoder = 'cross_encoder' class NodeReranker(Enum): @@ -52,11 +59,13 @@ class NodeReranker(Enum): node_distance = 'node_distance' episode_mentions = 'episode_mentions' mmr = 'mmr' + cross_encoder = 'cross_encoder' class CommunityReranker(Enum): rrf = 'reciprocal_rank_fusion' mmr = 'mmr' + cross_encoder = 'cross_encoder' class EdgeSearchConfig(BaseModel): @@ -64,6 +73,7 @@ class EdgeSearchConfig(BaseModel): reranker: EdgeReranker = Field(default=EdgeReranker.rrf) sim_min_score: float = Field(default=DEFAULT_MIN_SCORE) mmr_lambda: float = Field(default=DEFAULT_MMR_LAMBDA) + bfs_max_depth: int = Field(default=MAX_SEARCH_DEPTH) class NodeSearchConfig(BaseModel): @@ -71,6 +81,7 @@ class NodeSearchConfig(BaseModel): reranker: NodeReranker = Field(default=NodeReranker.rrf) sim_min_score: float = Field(default=DEFAULT_MIN_SCORE) mmr_lambda: float = Field(default=DEFAULT_MMR_LAMBDA) + bfs_max_depth: int = Field(default=MAX_SEARCH_DEPTH) class CommunitySearchConfig(BaseModel): @@ -78,6 +89,7 @@ class CommunitySearchConfig(BaseModel): reranker: CommunityReranker = Field(default=CommunityReranker.rrf) sim_min_score: float = Field(default=DEFAULT_MIN_SCORE) mmr_lambda: float = Field(default=DEFAULT_MMR_LAMBDA) + bfs_max_depth: int = Field(default=MAX_SEARCH_DEPTH) class SearchConfig(BaseModel): diff --git a/graphiti_core/search/search_config_recipes.py b/graphiti_core/search/search_config_recipes.py index abba4c29..712793db 100644 --- a/graphiti_core/search/search_config_recipes.py +++ b/graphiti_core/search/search_config_recipes.py @@ -48,14 +48,41 @@ COMBINED_HYBRID_SEARCH_MMR = SearchConfig( edge_config=EdgeSearchConfig( search_methods=[EdgeSearchMethod.bm25, EdgeSearchMethod.cosine_similarity], reranker=EdgeReranker.mmr, + mmr_lambda=1, ), node_config=NodeSearchConfig( search_methods=[NodeSearchMethod.bm25, NodeSearchMethod.cosine_similarity], reranker=NodeReranker.mmr, + mmr_lambda=1, ), community_config=CommunitySearchConfig( search_methods=[CommunitySearchMethod.bm25, CommunitySearchMethod.cosine_similarity], reranker=CommunityReranker.mmr, + mmr_lambda=1, + ), +) + +# Performs a full-text search, similarity search, and bfs with cross_encoder reranking over edges, nodes, and communities +COMBINED_HYBRID_SEARCH_CROSS_ENCODER = SearchConfig( + edge_config=EdgeSearchConfig( + search_methods=[ + EdgeSearchMethod.bm25, + EdgeSearchMethod.cosine_similarity, + EdgeSearchMethod.bfs, + ], + reranker=EdgeReranker.cross_encoder, + ), + node_config=NodeSearchConfig( + search_methods=[ + NodeSearchMethod.bm25, + NodeSearchMethod.cosine_similarity, + NodeSearchMethod.bfs, + ], + reranker=NodeReranker.cross_encoder, + ), + community_config=CommunitySearchConfig( + search_methods=[CommunitySearchMethod.bm25, CommunitySearchMethod.cosine_similarity], + reranker=CommunityReranker.cross_encoder, ), ) @@ -81,7 +108,6 @@ EDGE_HYBRID_SEARCH_NODE_DISTANCE = SearchConfig( search_methods=[EdgeSearchMethod.bm25, EdgeSearchMethod.cosine_similarity], reranker=EdgeReranker.node_distance, ), - limit=30, ) # performs a hybrid search over edges with episode mention reranking diff --git a/graphiti_core/search/search_utils.py b/graphiti_core/search/search_utils.py index 57d8b58b..dd963e04 100644 --- a/graphiti_core/search/search_utils.py +++ b/graphiti_core/search/search_utils.py @@ -19,7 +19,6 @@ import logging from collections import 
defaultdict from time import time -import neo4j import numpy as np from neo4j import AsyncDriver, Query @@ -38,6 +37,7 @@ logger = logging.getLogger(__name__) RELEVANT_SCHEMA_LIMIT = 3 DEFAULT_MIN_SCORE = 0.6 DEFAULT_MMR_LAMBDA = 0.5 +MAX_SEARCH_DEPTH = 3 MAX_QUERY_LENGTH = 128 @@ -80,23 +80,21 @@ async def get_mentioned_nodes( driver: AsyncDriver, episodes: list[EpisodicNode] ) -> list[EntityNode]: episode_uuids = [episode.uuid for episode in episodes] - async with driver.session( - database=DEFAULT_DATABASE, default_access_mode=neo4j.READ_ACCESS - ) as session: - result = await session.run( - """ - MATCH (episode:Episodic)-[:MENTIONS]->(n:Entity) WHERE episode.uuid IN $uuids - RETURN DISTINCT - n.uuid As uuid, - n.group_id AS group_id, - n.name AS name, - n.name_embedding AS name_embedding, - n.created_at AS created_at, - n.summary AS summary - """, - {'uuids': episode_uuids}, - ) - records = [record async for record in result] + records, _, _ = await driver.execute_query( + """ + MATCH (episode:Episodic)-[:MENTIONS]->(n:Entity) WHERE episode.uuid IN $uuids + RETURN DISTINCT + n.uuid As uuid, + n.group_id AS group_id, + n.name AS name, + n.name_embedding AS name_embedding, + n.created_at AS created_at, + n.summary AS summary + """, + uuids=episode_uuids, + database_=DEFAULT_DATABASE, + routing_='r', + ) nodes = [get_entity_node_from_record(record) for record in records] @@ -107,23 +105,21 @@ async def get_communities_by_nodes( driver: AsyncDriver, nodes: list[EntityNode] ) -> list[CommunityNode]: node_uuids = [node.uuid for node in nodes] - async with driver.session( - database=DEFAULT_DATABASE, default_access_mode=neo4j.READ_ACCESS - ) as session: - result = await session.run( - """ - MATCH (c:Community)-[:HAS_MEMBER]->(n:Entity) WHERE n.uuid IN $uuids - RETURN DISTINCT - c.uuid As uuid, - c.group_id AS group_id, - c.name AS name, - c.name_embedding AS name_embedding - c.created_at AS created_at, - c.summary AS summary - """, - {'uuids': node_uuids}, - ) - records = [record async for record in result] + records, _, _ = await driver.execute_query( + """ + MATCH (c:Community)-[:HAS_MEMBER]->(n:Entity) WHERE n.uuid IN $uuids + RETURN DISTINCT + c.uuid As uuid, + c.group_id AS group_id, + c.name AS name, + c.name_embedding AS name_embedding + c.created_at AS created_at, + c.summary AS summary + """, + uuids=node_uuids, + database_=DEFAULT_DATABASE, + routing_='r', + ) communities = [get_community_node_from_record(record) for record in records] @@ -149,7 +145,7 @@ async def edge_fulltext_search( MATCH (n:Entity)-[r {uuid: rel.uuid}]->(m:Entity) WHERE ($source_uuid IS NULL OR n.uuid IN [$source_uuid, $target_uuid]) AND ($target_uuid IS NULL OR m.uuid IN [$source_uuid, $target_uuid]) - RETURN + RETURN r.uuid AS uuid, r.group_id AS group_id, n.uuid AS source_node_uuid, @@ -165,20 +161,16 @@ async def edge_fulltext_search( ORDER BY score DESC LIMIT $limit """) - async with driver.session( - database=DEFAULT_DATABASE, default_access_mode=neo4j.READ_ACCESS - ) as session: - result = await session.run( - cypher_query, - { - 'query': fuzzy_query, - 'source_uuid': source_node_uuid, - 'target_uuid': target_node_uuid, - 'group_ids': group_ids, - 'limit': limit, - }, - ) - records = [record async for record in result] + records, _, _ = await driver.execute_query( + cypher_query, + query=fuzzy_query, + source_uuid=source_node_uuid, + target_uuid=target_node_uuid, + group_ids=group_ids, + limit=limit, + database_=DEFAULT_DATABASE, + routing_='r', + ) edges = [get_entity_edge_from_record(record) for record 
in records] @@ -201,13 +193,13 @@ async def edge_similarity_search( WHERE ($group_ids IS NULL OR r.group_id IN $group_ids) AND ($source_uuid IS NULL OR n.uuid IN [$source_uuid, $target_uuid]) AND ($target_uuid IS NULL OR m.uuid IN [$source_uuid, $target_uuid]) - WITH DISTINCT n, r, m, vector.similarity.cosine(r.fact_embedding, $search_vector) AS score + WITH DISTINCT r, vector.similarity.cosine(r.fact_embedding, $search_vector) AS score WHERE score > $min_score RETURN r.uuid AS uuid, r.group_id AS group_id, - n.uuid AS source_node_uuid, - m.uuid AS target_node_uuid, + startNode(r).uuid AS source_node_uuid, + endNode(r).uuid AS target_node_uuid, r.created_at AS created_at, r.name AS name, r.fact AS fact, @@ -220,21 +212,59 @@ async def edge_similarity_search( LIMIT $limit """) - async with driver.session( - database=DEFAULT_DATABASE, default_access_mode=neo4j.READ_ACCESS - ) as session: - result = await session.run( - query, - { - 'search_vector': search_vector, - 'source_uuid': source_node_uuid, - 'target_uuid': target_node_uuid, - 'group_ids': group_ids, - 'limit': limit, - 'min_score': min_score, - }, - ) - records = [record async for record in result] + records, _, _ = await driver.execute_query( + query, + search_vector=search_vector, + source_uuid=source_node_uuid, + target_uuid=target_node_uuid, + group_ids=group_ids, + limit=limit, + min_score=min_score, + database_=DEFAULT_DATABASE, + routing_='r', + ) + + edges = [get_entity_edge_from_record(record) for record in records] + + return edges + + +async def edge_bfs_search( + driver: AsyncDriver, + bfs_origin_node_uuids: list[str] | None, + bfs_max_depth: int, +) -> list[EntityEdge]: + # vector similarity search over embedded facts + if bfs_origin_node_uuids is None: + return [] + + query = Query(""" + UNWIND $bfs_origin_node_uuids AS origin_uuid + MATCH path = (origin:Entity|Episodic {uuid: origin_uuid})-[:RELATES_TO|MENTIONS]->{1,3}(n:Entity) + UNWIND relationships(path) AS rel + MATCH ()-[r:RELATES_TO {uuid: rel.uuid}]-() + RETURN DISTINCT + r.uuid AS uuid, + r.group_id AS group_id, + startNode(r).uuid AS source_node_uuid, + endNode(r).uuid AS target_node_uuid, + r.created_at AS created_at, + r.name AS name, + r.fact AS fact, + r.fact_embedding AS fact_embedding, + r.episodes AS episodes, + r.expired_at AS expired_at, + r.valid_at AS valid_at, + r.invalid_at AS invalid_at + """) + + records, _, _ = await driver.execute_query( + query, + bfs_origin_node_uuids=bfs_origin_node_uuids, + depth=bfs_max_depth, + database_=DEFAULT_DATABASE, + routing_='r', + ) edges = [get_entity_edge_from_record(record) for record in records] @@ -252,30 +282,26 @@ async def node_fulltext_search( if fuzzy_query == '': return [] - async with driver.session( - database=DEFAULT_DATABASE, default_access_mode=neo4j.READ_ACCESS - ) as session: - result = await session.run( - """ - CALL db.index.fulltext.queryNodes("node_name_and_summary", $query) - YIELD node AS n, score - RETURN - n.uuid AS uuid, - n.group_id AS group_id, - n.name AS name, - n.name_embedding AS name_embedding, - n.created_at AS created_at, - n.summary AS summary - ORDER BY score DESC - LIMIT $limit - """, - { - 'query': fuzzy_query, - 'group_ids': group_ids, - 'limit': limit, - }, - ) - records = [record async for record in result] + records, _, _ = await driver.execute_query( + """ + CALL db.index.fulltext.queryNodes("node_name_and_summary", $query) + YIELD node AS n, score + RETURN + n.uuid AS uuid, + n.group_id AS group_id, + n.name AS name, + n.name_embedding AS name_embedding, + 
n.created_at AS created_at, + n.summary AS summary + ORDER BY score DESC + LIMIT $limit + """, + query=fuzzy_query, + group_ids=group_ids, + limit=limit, + database_=DEFAULT_DATABASE, + routing_='r', + ) nodes = [get_entity_node_from_record(record) for record in records] return nodes @@ -289,34 +315,62 @@ async def node_similarity_search( min_score: float = DEFAULT_MIN_SCORE, ) -> list[EntityNode]: # vector similarity search over entity names - async with driver.session( - database=DEFAULT_DATABASE, default_access_mode=neo4j.READ_ACCESS - ) as session: - result = await session.run( - """ - CYPHER runtime = parallel parallelRuntimeSupport=all - MATCH (n:Entity) - WHERE $group_ids IS NULL OR n.group_id IN $group_ids - WITH n, vector.similarity.cosine(n.name_embedding, $search_vector) AS score - WHERE score > $min_score - RETURN - n.uuid As uuid, - n.group_id AS group_id, - n.name AS name, - n.name_embedding AS name_embedding, - n.created_at AS created_at, - n.summary AS summary - ORDER BY score DESC - LIMIT $limit - """, - { - 'search_vector': search_vector, - 'group_ids': group_ids, - 'limit': limit, - 'min_score': min_score, - }, - ) - records = [record async for record in result] + records, _, _ = await driver.execute_query( + """ + CYPHER runtime = parallel parallelRuntimeSupport=all + MATCH (n:Entity) + WHERE $group_ids IS NULL OR n.group_id IN $group_ids + WITH n, vector.similarity.cosine(n.name_embedding, $search_vector) AS score + WHERE score > $min_score + RETURN + n.uuid As uuid, + n.group_id AS group_id, + n.name AS name, + n.name_embedding AS name_embedding, + n.created_at AS created_at, + n.summary AS summary + ORDER BY score DESC + LIMIT $limit + """, + search_vector=search_vector, + group_ids=group_ids, + limit=limit, + min_score=min_score, + database_=DEFAULT_DATABASE, + routing_='r', + ) + nodes = [get_entity_node_from_record(record) for record in records] + + return nodes + + +async def node_bfs_search( + driver: AsyncDriver, + bfs_origin_node_uuids: list[str] | None, + bfs_max_depth: int, +) -> list[EntityNode]: + # vector similarity search over entity names + if bfs_origin_node_uuids is None: + return [] + + records, _, _ = await driver.execute_query( + """ + UNWIND $bfs_origin_node_uuids AS origin_uuid + MATCH (origin:Entity|Episodic {uuid: origin_uuid})-[:RELATES_TO|MENTIONS]->{1,3}(n:Entity) + RETURN DISTINCT + n.uuid As uuid, + n.group_id AS group_id, + n.name AS name, + n.name_embedding AS name_embedding, + n.created_at AS created_at, + n.summary AS summary + LIMIT $limit + """, + bfs_origin_node_uuids=bfs_origin_node_uuids, + depth=bfs_max_depth, + database_=DEFAULT_DATABASE, + routing_='r', + ) nodes = [get_entity_node_from_record(record) for record in records] return nodes @@ -333,30 +387,26 @@ async def community_fulltext_search( if fuzzy_query == '': return [] - async with driver.session( - database=DEFAULT_DATABASE, default_access_mode=neo4j.READ_ACCESS - ) as session: - result = await session.run( - """ - CALL db.index.fulltext.queryNodes("community_name", $query) - YIELD node AS comm, score - RETURN - comm.uuid AS uuid, - comm.group_id AS group_id, - comm.name AS name, - comm.name_embedding AS name_embedding, - comm.created_at AS created_at, - comm.summary AS summary - ORDER BY score DESC - LIMIT $limit - """, - { - 'query': fuzzy_query, - 'group_ids': group_ids, - 'limit': limit, - }, - ) - records = [record async for record in result] + records, _, _ = await driver.execute_query( + """ + CALL db.index.fulltext.queryNodes("community_name", $query) + YIELD 
node AS comm, score + RETURN + comm.uuid AS uuid, + comm.group_id AS group_id, + comm.name AS name, + comm.name_embedding AS name_embedding, + comm.created_at AS created_at, + comm.summary AS summary + ORDER BY score DESC + LIMIT $limit + """, + query=fuzzy_query, + group_ids=group_ids, + limit=limit, + database_=DEFAULT_DATABASE, + routing_='r', + ) communities = [get_community_node_from_record(record) for record in records] return communities @@ -370,34 +420,30 @@ async def community_similarity_search( min_score=DEFAULT_MIN_SCORE, ) -> list[CommunityNode]: # vector similarity search over entity names - async with driver.session( - database=DEFAULT_DATABASE, default_access_mode=neo4j.READ_ACCESS - ) as session: - result = await session.run( - """ - CYPHER runtime = parallel parallelRuntimeSupport=all - MATCH (comm:Community) - WHERE ($group_ids IS NULL OR comm.group_id IN $group_ids) - WITH comm, vector.similarity.cosine(comm.name_embedding, $search_vector) AS score - WHERE score > $min_score - RETURN - comm.uuid As uuid, - comm.group_id AS group_id, - comm.name AS name, - comm.name_embedding AS name_embedding, - comm.created_at AS created_at, - comm.summary AS summary - ORDER BY score DESC - LIMIT $limit - """, - { - 'search_vector': search_vector, - 'group_ids': group_ids, - 'limit': limit, - 'min_score': min_score, - }, - ) - records = [record async for record in result] + records, _, _ = await driver.execute_query( + """ + CYPHER runtime = parallel parallelRuntimeSupport=all + MATCH (comm:Community) + WHERE ($group_ids IS NULL OR comm.group_id IN $group_ids) + WITH comm, vector.similarity.cosine(comm.name_embedding, $search_vector) AS score + WHERE score > $min_score + RETURN + comm.uuid As uuid, + comm.group_id AS group_id, + comm.name AS name, + comm.name_embedding AS name_embedding, + comm.created_at AS created_at, + comm.summary AS summary + ORDER BY score DESC + LIMIT $limit + """, + search_vector=search_vector, + group_ids=group_ids, + limit=limit, + min_score=min_score, + database_=DEFAULT_DATABASE, + routing_='r', + ) communities = [get_community_node_from_record(record) for record in records] return communities diff --git a/pyproject.toml b/pyproject.toml index 45fcea33..d2458b2f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "graphiti-core" -version = "0.3.16" +version = "0.3.17" description = "A temporal graph building library" authors = [ "Paul Paliychuk ", diff --git a/tests/test_graphiti_int.py b/tests/test_graphiti_int.py index aa6a95b4..f7296f39 100644 --- a/tests/test_graphiti_int.py +++ b/tests/test_graphiti_int.py @@ -26,7 +26,9 @@ from dotenv import load_dotenv from graphiti_core.edges import EntityEdge, EpisodicEdge from graphiti_core.graphiti import Graphiti from graphiti_core.nodes import EntityNode, EpisodicNode -from graphiti_core.search.search_config_recipes import COMBINED_HYBRID_SEARCH_RRF +from graphiti_core.search.search_config_recipes import ( + COMBINED_HYBRID_SEARCH_CROSS_ENCODER, +) pytestmark = pytest.mark.integration @@ -60,22 +62,19 @@ def setup_logging(): return logger -def format_context(facts): - formatted_string = '' - formatted_string += 'FACTS:\n' - for fact in facts: - formatted_string += f' - {fact}\n' - formatted_string += '\n' - - return formatted_string.strip() - - @pytest.mark.asyncio async def test_graphiti_init(): logger = setup_logging() graphiti = Graphiti(NEO4J_URI, NEO4j_USER, NEO4j_PASSWORD) + episodes = await graphiti.retrieve_episodes(datetime.now(), group_ids=None) + episode_uuids = 
[episode.uuid for episode in episodes] - results = await graphiti._search('new house', COMBINED_HYBRID_SEARCH_RRF, group_ids=None) + results = await graphiti._search( + "Emily: I can't log in", + COMBINED_HYBRID_SEARCH_CROSS_ENCODER, + bfs_origin_node_uuids=episode_uuids, + group_ids=None, + ) pretty_results = { 'edges': [edge.fact for edge in results.edges], 'nodes': [node.name for node in results.nodes], @@ -83,6 +82,7 @@ async def test_graphiti_init(): } logger.info(pretty_results) + await graphiti.close()
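
Usage note (not part of the patch): a minimal sketch of the new cross-encoder path introduced above, mirroring the updated integration test. It assumes a Neo4j instance already populated by Graphiti (with indexes built), `OPENAI_API_KEY` available in the environment, and placeholder `NEO4J_*` connection settings; none of these values come from the diff itself.

```python
# Sketch only: exercises OpenAIRerankerClient and the new
# COMBINED_HYBRID_SEARCH_CROSS_ENCODER recipe from this patch.
# Assumes a populated Graphiti Neo4j database and OPENAI_API_KEY in the env;
# the NEO4J_* defaults below are placeholders.
import asyncio
import os

from graphiti_core.cross_encoder.openai_reranker_client import OpenAIRerankerClient
from graphiti_core.graphiti import Graphiti
from graphiti_core.search.search_config_recipes import COMBINED_HYBRID_SEARCH_CROSS_ENCODER


async def main() -> None:
    graphiti = Graphiti(
        os.environ.get('NEO4J_URI', 'bolt://localhost:7687'),
        os.environ.get('NEO4J_USER', 'neo4j'),
        os.environ.get('NEO4J_PASSWORD', 'password'),
        # Explicit for illustration; Graphiti now falls back to this client by default.
        cross_encoder=OpenAIRerankerClient(),
    )
    try:
        # The reranker can be called directly: rank() returns (passage, score)
        # pairs sorted by relevance to the query.
        ranked = await graphiti.cross_encoder.rank(
            'password reset', ["Emily can't log in", 'The weather is sunny']
        )
        print(ranked)

        # Or through _search() (as the integration test does), using the recipe
        # that adds BFS retrieval and cross-encoder reranking.
        results = await graphiti._search(
            "Emily: I can't log in",
            COMBINED_HYBRID_SEARCH_CROSS_ENCODER,
            bfs_origin_node_uuids=None,  # or a list of episode/entity uuids to seed the BFS
        )
        print([edge.fact for edge in results.edges])
    finally:
        await graphiti.close()


asyncio.run(main())
```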