From fd341a6f16254f896472e3646f509f28351ddb74 Mon Sep 17 00:00:00 2001
From: Preston Rasmussen <109292228+prasmussen15@users.noreply.github.com>
Date: Thu, 26 Sep 2024 16:12:38 -0400
Subject: [PATCH] Add MSC benchmark and improve search performance (#157)

* test cases
* test
* benchmark
* eval updates
* improve search performance
* remove data
* formatting
* add None type to config
* update sanitization
* push version
* maketrans update
* mypy
---
 .../msc_eval.py                               | 126 ++++++++++++
 .../msc_runner.py                             |  91 +++++++++
 .../parse_msc_messages.py                     |  85 ++++++++
 graphiti_core/graphiti.py                     |   7 +-
 graphiti_core/helpers.py                      |  30 +++
 graphiti_core/prompts/eval.py                 |  90 +++++++++
 graphiti_core/prompts/lib.py                  |   6 +
 graphiti_core/search/search.py                |  52 +++--
 graphiti_core/search/search_utils.py          | 186 ++++--------------
 .../maintenance/graph_data_operations.py      |  48 ++---
 pyproject.toml                                |   2 +-
 tests/test_graphiti_int.py                    |   3 +-
 12 files changed, 517 insertions(+), 209 deletions(-)
 create mode 100644 examples/multi_session_conversation_memory/msc_eval.py
 create mode 100644 examples/multi_session_conversation_memory/msc_runner.py
 create mode 100644 examples/multi_session_conversation_memory/parse_msc_messages.py
 create mode 100644 graphiti_core/prompts/eval.py

diff --git a/examples/multi_session_conversation_memory/msc_eval.py b/examples/multi_session_conversation_memory/msc_eval.py
new file mode 100644
index 00000000..0042deee
--- /dev/null
+++ b/examples/multi_session_conversation_memory/msc_eval.py
@@ -0,0 +1,126 @@
+"""
+Copyright 2024, Zep Software, Inc.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+""" + +import asyncio +import csv +import logging +import os +import sys +from time import time + +from dotenv import load_dotenv + +from examples.multi_session_conversation_memory.parse_msc_messages import conversation_q_and_a +from graphiti_core import Graphiti +from graphiti_core.prompts import prompt_library +from graphiti_core.search.search_config_recipes import COMBINED_HYBRID_SEARCH_RRF + +load_dotenv() + +neo4j_uri = os.environ.get('NEO4J_URI') or 'bolt://localhost:7687' +neo4j_user = os.environ.get('NEO4J_USER') or 'neo4j' +neo4j_password = os.environ.get('NEO4J_PASSWORD') or 'password' + + +def setup_logging(): + # Create a logger + logger = logging.getLogger() + logger.setLevel(logging.INFO) # Set the logging level to INFO + + # Create console handler and set level to INFO + console_handler = logging.StreamHandler(sys.stdout) + console_handler.setLevel(logging.INFO) + + # Create formatter + formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') + + # Add formatter to console handler + console_handler.setFormatter(formatter) + + # Add console handler to logger + logger.addHandler(console_handler) + + return logger + + +async def evaluate_qa(graphiti: Graphiti, group_id: str, query: str, answer: str): + search_start = time() + results = await graphiti._search(query, COMBINED_HYBRID_SEARCH_RRF, group_ids=[str(group_id)]) + search_end = time() + search_duration = search_end - search_start + + facts = [edge.fact for edge in results.edges] + entity_summaries = [node.name + ': ' + node.summary for node in results.nodes] + context = {'facts': facts, 'entity_summaries': entity_summaries, 'query': 'Bob: ' + query} + + llm_response = await graphiti.llm_client.generate_response( + prompt_library.eval.qa_prompt(context) + ) + response = llm_response.get('ANSWER', '') + + eval_context = { + 'query': 'Bob: ' + query, + 'answer': 'Alice: ' + answer, + 'response': 'Alice: ' + response, + } + + eval_llm_response = await graphiti.llm_client.generate_response( + prompt_library.eval.eval_prompt(eval_context) + ) + eval_response = 1 if eval_llm_response.get('is_correct', False) else 0 + + return { + 'Group id': group_id, + 'Question': query, + 'Answer': answer, + 'Response': response, + 'Score': eval_response, + 'Search Duration (ms)': search_duration * 1000, + } + + +async def main(): + setup_logging() + graphiti = Graphiti(neo4j_uri, neo4j_user, neo4j_password) + + fields = ['Group id', 'Question', 'Answer', 'Response', 'Score', 'Search Duration (ms)'] + with open('../data/msc_eval.csv', 'w', newline='') as file: + writer = csv.DictWriter(file, fieldnames=fields) + writer.writeheader() + + qa = conversation_q_and_a()[0:500] + i = 0 + while i < 500: + qa_chunk = qa[i : i + 20] + group_ids = range(len(qa))[i : i + 20] + results = list( + await asyncio.gather( + *[ + evaluate_qa(graphiti, str(group_id), query, answer) + for group_id, (query, answer) in zip(group_ids, qa_chunk) + ] + ) + ) + + with open('../data/msc_eval.csv', 'a', newline='') as file: + writer = csv.DictWriter(file, fieldnames=fields) + writer.writerows(results) + i += 20 + + await graphiti.close() + + +asyncio.run(main()) diff --git a/examples/multi_session_conversation_memory/msc_runner.py b/examples/multi_session_conversation_memory/msc_runner.py new file mode 100644 index 00000000..3ea412ae --- /dev/null +++ b/examples/multi_session_conversation_memory/msc_runner.py @@ -0,0 +1,91 @@ +""" +Copyright 2024, Zep Software, Inc. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +import asyncio +import logging +import os +import sys + +from dotenv import load_dotenv + +from examples.multi_session_conversation_memory.parse_msc_messages import ( + ParsedMscMessage, + parse_msc_messages, +) +from graphiti_core import Graphiti + +load_dotenv() + +neo4j_uri = os.environ.get('NEO4J_URI') or 'bolt://localhost:7687' +neo4j_user = os.environ.get('NEO4J_USER') or 'neo4j' +neo4j_password = os.environ.get('NEO4J_PASSWORD') or 'password' + + +def setup_logging(): + # Create a logger + logger = logging.getLogger() + logger.setLevel(logging.INFO) # Set the logging level to INFO + + # Create console handler and set level to INFO + console_handler = logging.StreamHandler(sys.stdout) + console_handler.setLevel(logging.INFO) + + # Create formatter + formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') + + # Add formatter to console handler + console_handler.setFormatter(formatter) + + # Add console handler to logger + logger.addHandler(console_handler) + + return logger + + +async def add_conversation(graphiti: Graphiti, group_id: str, messages: list[ParsedMscMessage]): + for i, message in enumerate(messages): + await graphiti.add_episode( + name=f'Message {group_id + "-" + str(i)}', + episode_body=f'{message.speaker_name}: {message.content}', + reference_time=message.actual_timestamp, + source_description='Multi-Session Conversation', + group_id=group_id, + ) + + +async def main(): + setup_logging() + graphiti = Graphiti(neo4j_uri, neo4j_user, neo4j_password) + msc_messages = parse_msc_messages() + i = 0 + while i <= 490: + msc_message_slice = msc_messages[i : i + 10] + group_ids = range(len(msc_messages))[i : i + 10] + + await asyncio.gather( + *[ + add_conversation(graphiti, str(group_id), messages) + for group_id, messages in zip(group_ids, msc_message_slice) + ] + ) + + i += 10 + + # build communities + # await client.build_communities() + + +asyncio.run(main()) diff --git a/examples/multi_session_conversation_memory/parse_msc_messages.py b/examples/multi_session_conversation_memory/parse_msc_messages.py new file mode 100644 index 00000000..1cb28ef8 --- /dev/null +++ b/examples/multi_session_conversation_memory/parse_msc_messages.py @@ -0,0 +1,85 @@ +""" +Copyright 2024, Zep Software, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
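Note on the runner above: msc_runner.py awaits add_episode one message at a time inside a conversation, so each group's episodes land in chronological order, while ten conversations are ingested concurrently. The shape of that pattern, with a hypothetical ingest coroutine standing in for graphiti.add_episode:

import asyncio

async def ingest(group_id: str, message: str) -> None:
    # Stand-in for graphiti.add_episode(...); order matters within a group.
    await asyncio.sleep(0)

async def ingest_conversation(group_id: str, messages: list[str]) -> None:
    for message in messages:  # sequential: episodes stay in order
        await ingest(group_id, message)

async def main() -> None:
    conversations = {'0': ['hi', 'hello'], '1': ['hey', 'yo']}
    await asyncio.gather(  # concurrent: conversations are independent
        *(ingest_conversation(g, msgs) for g, msgs in conversations.items())
    )

asyncio.run(main())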
+""" + +import json +from datetime import datetime + +from pydantic import BaseModel + + +class ParsedMscMessage(BaseModel): + speaker_name: str + actual_timestamp: datetime + content: str + group_id: str + + +def parse_msc_messages() -> list[list[ParsedMscMessage]]: + msc_messages: list[list[ParsedMscMessage]] = [] + speakers = ['Alice', 'Bob'] + + with open('../data/msc.json') as file: + data = json.load(file)['data'] + for i, conversation in enumerate(data): + messages: list[ParsedMscMessage] = [] + for previous_dialog in conversation['previous_dialogs']: + dialog = previous_dialog['dialog'] + speaker_idx = 0 + + for utterance in dialog: + content = utterance['text'] + messages.append( + ParsedMscMessage( + speaker_name=speakers[speaker_idx], + content=content, + actual_timestamp=datetime.now(), + group_id=str(i), + ) + ) + speaker_idx += 1 + speaker_idx %= 2 + + dialog = conversation['dialog'] + speaker_idx = 0 + for utterance in dialog: + content = utterance['text'] + messages.append( + ParsedMscMessage( + speaker_name=speakers[speaker_idx], + content=content, + actual_timestamp=datetime.now(), + group_id=str(i), + ) + ) + speaker_idx += 1 + speaker_idx %= 2 + + msc_messages.append(messages) + + return msc_messages + + +def conversation_q_and_a() -> list[tuple[str, str]]: + with open('../data/msc.json') as file: + data = json.load(file)['data'] + + qa: list[tuple[str, str]] = [] + for conversation in data: + query = conversation['self_instruct']['B'] + answer = conversation['self_instruct']['A'] + + qa.append((query, answer)) + return qa diff --git a/graphiti_core/graphiti.py b/graphiti_core/graphiti.py index 540c9ff1..b2432110 100644 --- a/graphiti_core/graphiti.py +++ b/graphiti_core/graphiti.py @@ -161,7 +161,7 @@ class Graphiti: """ await self.driver.close() - async def build_indices_and_constraints(self): + async def build_indices_and_constraints(self, delete_existing: bool = False): """ Build indices and constraints in the Neo4j database. @@ -171,6 +171,9 @@ class Graphiti: Parameters ---------- self + delete_existing : bool, optional + Whether to clear existing indices before creating new ones. + Returns ------- @@ -191,7 +194,7 @@ class Graphiti: Caution: Running this method on a large existing database may take some time and could impact database performance during execution. """ - await build_indices_and_constraints(self.driver) + await build_indices_and_constraints(self.driver, delete_existing) async def retrieve_episodes( self, diff --git a/graphiti_core/helpers.py b/graphiti_core/helpers.py index 9471058e..80cd9752 100644 --- a/graphiti_core/helpers.py +++ b/graphiti_core/helpers.py @@ -21,3 +21,33 @@ from neo4j import time as neo4j_time def parse_db_date(neo_date: neo4j_time.DateTime | None) -> datetime | None: return neo_date.to_native() if neo_date else None + + +def lucene_sanitize(query: str) -> str: + # Escape special characters from a query before passing into Lucene + # + - && || ! ( ) { } [ ] ^ " ~ * ? 
: \ + escape_map = str.maketrans( + { + '+': r'\+', + '-': r'\-', + '&': r'\&', + '|': r'\|', + '!': r'\!', + '(': r'\(', + ')': r'\)', + '{': r'\{', + '}': r'\}', + '[': r'\[', + ']': r'\]', + '^': r'\^', + '"': r'\"', + '~': r'\~', + '*': r'\*', + '?': r'\?', + ':': r'\:', + '\\': r'\\', + } + ) + + sanitized = query.translate(escape_map) + return sanitized diff --git a/graphiti_core/prompts/eval.py b/graphiti_core/prompts/eval.py new file mode 100644 index 00000000..8fa22d3f --- /dev/null +++ b/graphiti_core/prompts/eval.py @@ -0,0 +1,90 @@ +""" +Copyright 2024, Zep Software, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +import json +from typing import Any, Protocol, TypedDict + +from .models import Message, PromptFunction, PromptVersion + + +class Prompt(Protocol): + qa_prompt: PromptVersion + eval_prompt: PromptVersion + + +class Versions(TypedDict): + qa_prompt: PromptFunction + eval_prompt: PromptFunction + + +def qa_prompt(context: dict[str, Any]) -> list[Message]: + sys_prompt = """You are Alice and should respond to all questions from the first person perspective of Alice""" + + user_prompt = f""" + Your task is to briefly answer the question in the way that you think Alice would answer the question. + You are given the following entity summaries and facts to help you determine the answer to your question. + + {json.dumps(context['entity_summaries'])} + + {json.dumps(context['facts'])} + + + {context['query']} + + respond with a JSON object in the following format: + {{ + "ANSWER": "how Alice would answer the question" + }} + """ + return [ + Message(role='system', content=sys_prompt), + Message(role='user', content=user_prompt), + ] + + +def eval_prompt(context: dict[str, Any]) -> list[Message]: + sys_prompt = ( + """You are a judge that determines if answers to questions match a gold standard answer""" + ) + + user_prompt = f""" + Given the QUESTION and the gold standard ANSWER determine if the RESPONSE to the question is correct or incorrect. + Although the RESPONSE may be more verbose, mark it as correct as long as it references the same topic + as the gold standard ANSWER. Also include your reasoning for the grade. 
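The lucene_sanitize helper above builds one translation table with str.maketrans, so the whole query is escaped in a single str.translate pass rather than a chain of str.replace calls. A few illustrative inputs and outputs, as implied by the escape map:

from graphiti_core.helpers import lucene_sanitize

# Reserved Lucene characters gain a backslash; plain text passes through.
assert lucene_sanitize('who is Bob?') == r'who is Bob\?'
assert lucene_sanitize('a+b:(c)') == r'a\+b\:\(c\)'
print(lucene_sanitize('group-1 "quoted"'))  # group\-1 \"quoted\"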
+ + {context['query']} + + + {context['answer']} + + + {context['response']} + + + respond with a JSON object in the following format: + {{ + "is_correct": "boolean if the answer is correct or incorrect" + "reasoning": "why you determined the response was correct or incorrect" + }} + """ + return [ + Message(role='system', content=sys_prompt), + Message(role='user', content=user_prompt), + ] + + +versions: Versions = {'qa_prompt': qa_prompt, 'eval_prompt': eval_prompt} diff --git a/graphiti_core/prompts/lib.py b/graphiti_core/prompts/lib.py index 09cc3ec0..cb971d0f 100644 --- a/graphiti_core/prompts/lib.py +++ b/graphiti_core/prompts/lib.py @@ -34,6 +34,9 @@ from .dedupe_nodes import ( from .dedupe_nodes import ( versions as dedupe_nodes_versions, ) +from .eval import Prompt as EvalPrompt +from .eval import Versions as EvalVersions +from .eval import versions as eval_versions from .extract_edge_dates import ( Prompt as ExtractEdgeDatesPrompt, ) @@ -84,6 +87,7 @@ class PromptLibrary(Protocol): invalidate_edges: InvalidateEdgesPrompt extract_edge_dates: ExtractEdgeDatesPrompt summarize_nodes: SummarizeNodesPrompt + eval: EvalPrompt class PromptLibraryImpl(TypedDict): @@ -94,6 +98,7 @@ class PromptLibraryImpl(TypedDict): invalidate_edges: InvalidateEdgesVersions extract_edge_dates: ExtractEdgeDatesVersions summarize_nodes: SummarizeNodesVersions + eval: EvalVersions class VersionWrapper: @@ -124,5 +129,6 @@ PROMPT_LIBRARY_IMPL: PromptLibraryImpl = { 'invalidate_edges': invalidate_edges_versions, 'extract_edge_dates': extract_edge_dates_versions, 'summarize_nodes': summarize_nodes_versions, + 'eval': eval_versions, } prompt_library: PromptLibrary = PromptLibraryWrapper(PROMPT_LIBRARY_IMPL) # type: ignore[assignment] diff --git a/graphiti_core/search/search.py b/graphiti_core/search/search.py index 862ececd..1a3ec50c 100644 --- a/graphiti_core/search/search.py +++ b/graphiti_core/search/search.py @@ -14,6 +14,7 @@ See the License for the specific language governing permissions and limitations under the License. 
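Registering eval in lib.py below is the only wiring the new prompts need; the PromptLibraryWrapper then exposes them like any other prompt family. Usage as exercised by msc_eval.py:

from graphiti_core.prompts import prompt_library

context = {
    'entity_summaries': ['Alice: lives in Boston and enjoys hiking'],
    'facts': ['Alice moved to Boston in 2020'],
    'query': 'Bob: Where do you live?',
}
# Returns [system, user] Message objects ready for llm_client.generate_response.
messages = prompt_library.eval.qa_prompt(context)
print(messages[0].role, '/', messages[1].role)  # system / user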
""" +import asyncio import logging from collections import defaultdict from time import time @@ -65,32 +66,20 @@ async def search( query = query.replace('\n', ' ') # if group_ids is empty, set it to None group_ids = group_ids if group_ids else None - edges = ( - await edge_search( + edges, nodes, communities = await asyncio.gather( + edge_search( driver, embedder, query, group_ids, config.edge_config, center_node_uuid, config.limit - ) - if config.edge_config is not None - else [] - ) - nodes = ( - await node_search( + ), + node_search( driver, embedder, query, group_ids, config.node_config, center_node_uuid, config.limit - ) - if config.node_config is not None - else [] - ) - communities = ( - await community_search( - driver, embedder, query, group_ids, config.community_config, config.limit - ) - if config.community_config is not None - else [] + ), + community_search(driver, embedder, query, group_ids, config.community_config, config.limit), ) results = SearchResults( - edges=edges[: config.limit], - nodes=nodes[: config.limit], - communities=communities[: config.limit], + edges=edges, + nodes=nodes, + communities=communities, ) end = time() @@ -105,10 +94,13 @@ async def edge_search( embedder, query: str, group_ids: list[str] | None, - config: EdgeSearchConfig, + config: EdgeSearchConfig | None, center_node_uuid: str | None = None, limit=DEFAULT_SEARCH_LIMIT, ) -> list[EntityEdge]: + if config is None: + return [] + search_results: list[list[EntityEdge]] = [] if EdgeSearchMethod.bm25 in config.search_methods: @@ -162,7 +154,7 @@ async def edge_search( if config.reranker == EdgeReranker.episode_mentions: reranked_edges.sort(reverse=True, key=lambda edge: len(edge.episodes)) - return reranked_edges + return reranked_edges[:limit] async def node_search( @@ -170,10 +162,13 @@ async def node_search( embedder, query: str, group_ids: list[str] | None, - config: NodeSearchConfig, + config: NodeSearchConfig | None, center_node_uuid: str | None = None, limit=DEFAULT_SEARCH_LIMIT, ) -> list[EntityNode]: + if config is None: + return [] + search_results: list[list[EntityNode]] = [] if NodeSearchMethod.bm25 in config.search_methods: @@ -212,7 +207,7 @@ async def node_search( reranked_nodes = [node_uuid_map[uuid] for uuid in reranked_uuids] - return reranked_nodes + return reranked_nodes[:limit] async def community_search( @@ -220,9 +215,12 @@ async def community_search( embedder, query: str, group_ids: list[str] | None, - config: CommunitySearchConfig, + config: CommunitySearchConfig | None, limit=DEFAULT_SEARCH_LIMIT, ) -> list[CommunityNode]: + if config is None: + return [] + search_results: list[list[CommunityNode]] = [] if CommunitySearchMethod.bm25 in config.search_methods: @@ -255,4 +253,4 @@ async def community_search( reranked_communities = [community_uuid_map[uuid] for uuid in reranked_uuids] - return reranked_communities + return reranked_communities[:limit] diff --git a/graphiti_core/search/search_utils.py b/graphiti_core/search/search_utils.py index 110c3210..af672bfd 100644 --- a/graphiti_core/search/search_utils.py +++ b/graphiti_core/search/search_utils.py @@ -16,13 +16,13 @@ limitations under the License. 
import asyncio import logging -import re from collections import defaultdict from time import time from neo4j import AsyncDriver, Query from graphiti_core.edges import EntityEdge, get_entity_edge_from_record +from graphiti_core.helpers import lucene_sanitize from graphiti_core.nodes import ( CommunityNode, EntityNode, @@ -36,6 +36,22 @@ logger = logging.getLogger(__name__) RELEVANT_SCHEMA_LIMIT = 3 +def fulltext_query(query: str, group_ids: list[str] | None = None): + group_ids_filter_list = ( + [f'group_id:"{lucene_sanitize(g)}"' for g in group_ids] if group_ids is not None else [] + ) + group_ids_filter = '' + for f in group_ids_filter_list: + group_ids_filter += f if not group_ids_filter else f'OR {f}' + + group_ids_filter += ' AND ' if group_ids_filter else '' + + fuzzy_query = lucene_sanitize(query) + '~' + full_query = group_ids_filter + fuzzy_query + + return full_query + + async def get_mentioned_nodes( driver: AsyncDriver, episodes: list[EpisodicNode] ) -> list[EntityNode]: @@ -91,11 +107,15 @@ async def edge_fulltext_search( limit=RELEVANT_SCHEMA_LIMIT, ) -> list[EntityEdge]: # fulltext search over facts + fuzzy_query = fulltext_query(query, group_ids) + cypher_query = Query(""" - CALL db.index.fulltext.queryRelationships("name_and_fact", $query) + CALL db.index.fulltext.queryRelationships("edge_name_and_fact", $query) YIELD relationship AS rel, score - MATCH (n:Entity {uuid: $source_uuid})-[r {uuid: rel.uuid}]-(m:Entity {uuid: $target_uuid}) - WHERE $group_ids IS NULL OR n.group_id IN $group_ids + MATCH (n:Entity)-[r {uuid: rel.uuid}]-(m:Entity) + WHERE ($source_uuid IS NULL OR n.uuid = $source_uuid) + AND ($target_uuid IS NULL OR m.uuid = $target_uuid) + AND ($group_ids IS NULL OR n.group_id IN $group_ids) RETURN r.uuid AS uuid, r.group_id AS group_id, @@ -112,72 +132,6 @@ async def edge_fulltext_search( ORDER BY score DESC LIMIT $limit """) - if source_node_uuid is None and target_node_uuid is None: - cypher_query = Query(""" - CALL db.index.fulltext.queryRelationships("name_and_fact", $query) - YIELD relationship AS rel, score - MATCH (n:Entity)-[r {uuid: rel.uuid}]-(m:Entity) - WHERE $group_ids IS NULL OR r.group_id IN $group_ids - RETURN - r.uuid AS uuid, - r.group_id AS group_id, - n.uuid AS source_node_uuid, - m.uuid AS target_node_uuid, - r.created_at AS created_at, - r.name AS name, - r.fact AS fact, - r.fact_embedding AS fact_embedding, - r.episodes AS episodes, - r.expired_at AS expired_at, - r.valid_at AS valid_at, - r.invalid_at AS invalid_at - ORDER BY score DESC LIMIT $limit - """) - elif source_node_uuid is None: - cypher_query = Query(""" - CALL db.index.fulltext.queryRelationships("name_and_fact", $query) - YIELD relationship AS rel, score - MATCH (n:Entity)-[r {uuid: rel.uuid}]-(m:Entity {uuid: $target_uuid}) - WHERE $group_ids IS NULL OR r.group_id IN $group_ids - RETURN - r.uuid AS uuid, - r.group_id AS group_id, - n.uuid AS source_node_uuid, - m.uuid AS target_node_uuid, - r.created_at AS created_at, - r.name AS name, - r.fact AS fact, - r.fact_embedding AS fact_embedding, - r.episodes AS episodes, - r.expired_at AS expired_at, - r.valid_at AS valid_at, - r.invalid_at AS invalid_at - ORDER BY score DESC LIMIT $limit - """) - elif target_node_uuid is None: - cypher_query = Query(""" - CALL db.index.fulltext.queryRelationships("name_and_fact", $query) - YIELD relationship AS rel, score - MATCH (n:Entity {uuid: $source_uuid})-[r {uuid: rel.uuid}]-(m:Entity) - WHERE $group_ids IS NULL OR r.group_id IN $group_ids - RETURN - r.uuid AS uuid, - r.group_id AS 
group_id, - n.uuid AS source_node_uuid, - m.uuid AS target_node_uuid, - r.created_at AS created_at, - r.name AS name, - r.fact AS fact, - r.fact_embedding AS fact_embedding, - r.episodes AS episodes, - r.expired_at AS expired_at, - r.valid_at AS valid_at, - r.invalid_at AS invalid_at - ORDER BY score DESC LIMIT $limit - """) - - fuzzy_query = re.sub(r'[^\w\s]', '', query) + '~' - records, _, _ = await driver.execute_query( cypher_query, query=fuzzy_query, @@ -202,11 +156,12 @@ async def edge_similarity_search( ) -> list[EntityEdge]: # vector similarity search over embedded facts query = Query(""" - CALL db.index.vector.queryRelationships("fact_embedding", $limit, $search_vector) - YIELD relationship AS rel, score - MATCH (n:Entity {uuid: $source_uuid})-[r {uuid: rel.uuid}]-(m:Entity {uuid: $target_uuid}) - WHERE $group_ids IS NULL OR r.group_id IN $group_ids + MATCH (n:Entity)-[r:RELATES_TO]-(m:Entity) + WHERE ($group_ids IS NULL OR r.group_id IN $group_ids) + AND ($source_uuid IS NULL OR n.uuid = $source_uuid) + AND ($target_uuid IS NULL OR m.uuid = $target_uuid) RETURN + vector.similarity.cosine(r.fact_embedding, $search_vector) AS score, r.uuid AS uuid, r.group_id AS group_id, n.uuid AS source_node_uuid, @@ -220,72 +175,9 @@ async def edge_similarity_search( r.valid_at AS valid_at, r.invalid_at AS invalid_at ORDER BY score DESC + LIMIT $limit """) - if source_node_uuid is None and target_node_uuid is None: - query = Query(""" - CALL db.index.vector.queryRelationships("fact_embedding", $limit, $search_vector) - YIELD relationship AS rel, score - MATCH (n:Entity)-[r {uuid: rel.uuid}]-(m:Entity) - WHERE $group_ids IS NULL OR r.group_id IN $group_ids - RETURN - r.uuid AS uuid, - r.group_id AS group_id, - n.uuid AS source_node_uuid, - m.uuid AS target_node_uuid, - r.created_at AS created_at, - r.name AS name, - r.fact AS fact, - r.fact_embedding AS fact_embedding, - r.episodes AS episodes, - r.expired_at AS expired_at, - r.valid_at AS valid_at, - r.invalid_at AS invalid_at - ORDER BY score DESC - """) - elif source_node_uuid is None: - query = Query(""" - CALL db.index.vector.queryRelationships("fact_embedding", $limit, $search_vector) - YIELD relationship AS rel, score - MATCH (n:Entity)-[r {uuid: rel.uuid}]-(m:Entity {uuid: $target_uuid}) - WHERE $group_ids IS NULL OR r.group_id IN $group_ids - RETURN - r.uuid AS uuid, - r.group_id AS group_id, - n.uuid AS source_node_uuid, - m.uuid AS target_node_uuid, - r.created_at AS created_at, - r.name AS name, - r.fact AS fact, - r.fact_embedding AS fact_embedding, - r.episodes AS episodes, - r.expired_at AS expired_at, - r.valid_at AS valid_at, - r.invalid_at AS invalid_at - ORDER BY score DESC - """) - elif target_node_uuid is None: - query = Query(""" - CALL db.index.vector.queryRelationships("fact_embedding", $limit, $search_vector) - YIELD relationship AS rel, score - MATCH (n:Entity {uuid: $source_uuid})-[r {uuid: rel.uuid}]-(m:Entity) - WHERE $group_ids IS NULL OR r.group_id IN $group_ids - RETURN - r.uuid AS uuid, - r.group_id AS group_id, - n.uuid AS source_node_uuid, - m.uuid AS target_node_uuid, - r.created_at AS created_at, - r.name AS name, - r.fact AS fact, - r.fact_embedding AS fact_embedding, - r.episodes AS episodes, - r.expired_at AS expired_at, - r.valid_at AS valid_at, - r.invalid_at AS invalid_at - ORDER BY score DESC - """) - records, _, _ = await driver.execute_query( query, search_vector=search_vector, @@ -307,10 +199,11 @@ async def node_fulltext_search( limit=RELEVANT_SCHEMA_LIMIT, ) -> list[EntityNode]: # BM25 search to 
get top nodes - fuzzy_query = re.sub(r'[^\w\s]', '', query) + '~' + fuzzy_query = fulltext_query(query, group_ids) + records, _, _ = await driver.execute_query( """ - CALL db.index.fulltext.queryNodes("name_and_summary", $query) + CALL db.index.fulltext.queryNodes("node_name_and_summary", $query) YIELD node AS n, score WHERE $group_ids IS NULL OR n.group_id IN $group_ids RETURN @@ -341,11 +234,10 @@ async def node_similarity_search( # vector similarity search over entity names records, _, _ = await driver.execute_query( """ - CALL db.index.vector.queryNodes("name_embedding", $limit, $search_vector) - YIELD node AS n, score MATCH (n:Entity) WHERE $group_ids IS NULL OR n.group_id IN $group_ids RETURN + vector.similarity.cosine(n.name_embedding, $search_vector) AS score, n.uuid As uuid, n.group_id AS group_id, n.name AS name, @@ -353,6 +245,7 @@ async def node_similarity_search( n.created_at AS created_at, n.summary AS summary ORDER BY score DESC + LIMIT $limit """, search_vector=search_vector, group_ids=group_ids, @@ -370,7 +263,8 @@ async def community_fulltext_search( limit=RELEVANT_SCHEMA_LIMIT, ) -> list[CommunityNode]: # BM25 search to get top communities - fuzzy_query = re.sub(r'[^\w\s]', '', query) + '~' + fuzzy_query = fulltext_query(query, group_ids) + records, _, _ = await driver.execute_query( """ CALL db.index.fulltext.queryNodes("community_name", $query) @@ -405,11 +299,10 @@ async def community_similarity_search( # vector similarity search over entity names records, _, _ = await driver.execute_query( """ - CALL db.index.vector.queryNodes("community_name_embedding", $limit, $search_vector) - YIELD node AS comm, score MATCH (comm:Community) - WHERE $group_ids IS NULL OR comm.group_id IN $group_ids + WHERE ($group_ids IS NULL OR comm.group_id IN $group_ids) RETURN + vector.similarity.cosine(comm.name_embedding, $search_vector) AS score, comm.uuid As uuid, comm.group_id AS group_id, comm.name AS name, @@ -417,6 +310,7 @@ async def community_similarity_search( comm.created_at AS created_at, comm.summary AS summary ORDER BY score DESC + LIMIT $limit """, search_vector=search_vector, group_ids=group_ids, diff --git a/graphiti_core/utils/maintenance/graph_data_operations.py b/graphiti_core/utils/maintenance/graph_data_operations.py index cdc39b31..6d67a707 100644 --- a/graphiti_core/utils/maintenance/graph_data_operations.py +++ b/graphiti_core/utils/maintenance/graph_data_operations.py @@ -28,7 +28,16 @@ EPISODE_WINDOW_LEN = 3 logger = logging.getLogger(__name__) -async def build_indices_and_constraints(driver: AsyncDriver): +async def build_indices_and_constraints(driver: AsyncDriver, delete_existing: bool = False): + if delete_existing: + records, _, _ = await driver.execute_query(""" + SHOW INDEXES YIELD name + """) + index_names = [record['name'] for record in records] + await asyncio.gather( + *[driver.execute_query("""DROP INDEX $name""", name=name) for name in index_names] + ) + range_indices: list[LiteralString] = [ 'CREATE INDEX entity_uuid IF NOT EXISTS FOR (n:Entity) ON (n.uuid)', 'CREATE INDEX episode_uuid IF NOT EXISTS FOR (n:Episodic) ON (n.uuid)', @@ -52,38 +61,15 @@ async def build_indices_and_constraints(driver: AsyncDriver): ] fulltext_indices: list[LiteralString] = [ - 'CREATE FULLTEXT INDEX name_and_summary IF NOT EXISTS FOR (n:Entity) ON EACH [n.name, n.summary]', - 'CREATE FULLTEXT INDEX community_name IF NOT EXISTS FOR (n:Community) ON EACH [n.name]', - 'CREATE FULLTEXT INDEX name_and_fact IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON EACH [e.name, e.fact]', + 
"""CREATE FULLTEXT INDEX node_name_and_summary IF NOT EXISTS + FOR (n:Entity) ON EACH [n.name, n.summary, n.group_id]""", + """CREATE FULLTEXT INDEX community_name IF NOT EXISTS + FOR (n:Community) ON EACH [n.name, n.group_id]""", + """CREATE FULLTEXT INDEX edge_name_and_fact IF NOT EXISTS + FOR ()-[e:RELATES_TO]-() ON EACH [e.name, e.fact, e.group_id]""", ] - vector_indices: list[LiteralString] = [ - """ - CREATE VECTOR INDEX fact_embedding IF NOT EXISTS - FOR ()-[r:RELATES_TO]-() ON (r.fact_embedding) - OPTIONS {indexConfig: { - `vector.dimensions`: 1024, - `vector.similarity_function`: 'cosine' - }} - """, - """ - CREATE VECTOR INDEX name_embedding IF NOT EXISTS - FOR (n:Entity) ON (n.name_embedding) - OPTIONS {indexConfig: { - `vector.dimensions`: 1024, - `vector.similarity_function`: 'cosine' - }} - """, - """ - CREATE VECTOR INDEX community_name_embedding IF NOT EXISTS - FOR (n:Community) ON (n.name_embedding) - OPTIONS {indexConfig: { - `vector.dimensions`: 1024, - `vector.similarity_function`: 'cosine' - }} - """, - ] - index_queries: list[LiteralString] = range_indices + fulltext_indices + vector_indices + index_queries: list[LiteralString] = range_indices + fulltext_indices await asyncio.gather(*[driver.execute_query(query) for query in index_queries]) diff --git a/pyproject.toml b/pyproject.toml index c8ccfcaf..6b5d6647 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "graphiti-core" -version = "0.3.6" +version = "0.3.7" description = "A temporal graph building library" authors = [ "Paul Paliychuk ", diff --git a/tests/test_graphiti_int.py b/tests/test_graphiti_int.py index 5ab04541..29de263d 100644 --- a/tests/test_graphiti_int.py +++ b/tests/test_graphiti_int.py @@ -74,7 +74,6 @@ def format_context(facts): async def test_graphiti_init(): logger = setup_logging() graphiti = Graphiti(NEO4J_URI, NEO4j_USER, NEO4j_PASSWORD) - await graphiti.build_communities() edges = await graphiti.search( 'tania tetlow', center_node_uuid='4bf7ebb3-3a98-46c7-90a6-8e516c487961', group_ids=None @@ -96,7 +95,7 @@ async def test_graphiti_init(): } logger.info(pretty_results) - graphiti.close() + await graphiti.close() @pytest.mark.asyncio