diff --git a/examples/multi_session_conversation_memory/msc_eval.py b/examples/multi_session_conversation_memory/msc_eval.py
new file mode 100644
index 00000000..0042deee
--- /dev/null
+++ b/examples/multi_session_conversation_memory/msc_eval.py
@@ -0,0 +1,126 @@
+"""
+Copyright 2024, Zep Software, Inc.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
+import asyncio
+import csv
+import logging
+import os
+import sys
+from time import time
+
+from dotenv import load_dotenv
+
+from examples.multi_session_conversation_memory.parse_msc_messages import conversation_q_and_a
+from graphiti_core import Graphiti
+from graphiti_core.prompts import prompt_library
+from graphiti_core.search.search_config_recipes import COMBINED_HYBRID_SEARCH_RRF
+
+load_dotenv()  # populate os.environ from a .env file before the lookups below
+
+neo4j_uri = os.environ.get('NEO4J_URI') or 'bolt://localhost:7687'  # default if unset or empty
+neo4j_user = os.environ.get('NEO4J_USER') or 'neo4j'  # default if unset or empty
+neo4j_password = os.environ.get('NEO4J_PASSWORD') or 'password'  # default if unset or empty
+
+
+def setup_logging():
+    """Configure the root logger to emit INFO-level records to stdout.
+
+    Returns
+    -------
+    logging.Logger
+        The root logger, with a newly added stdout stream handler.
+    """
+    stdout_handler = logging.StreamHandler(sys.stdout)
+    stdout_handler.setLevel(logging.INFO)
+    stdout_handler.setFormatter(
+        logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
+    )
+
+    # logging.getLogger() called without a name returns the root logger.
+    root_logger = logging.getLogger()
+    root_logger.setLevel(logging.INFO)
+    root_logger.addHandler(stdout_handler)
+    return root_logger
+
+
+async def evaluate_qa(graphiti: Graphiti, group_id: str, query: str, answer: str):
+    """Answer `query` from graph search results, then have an LLM judge score it 0 or 1."""
+    search_start = time()
+    results = await graphiti._search(query, COMBINED_HYBRID_SEARCH_RRF, group_ids=[str(group_id)])
+    search_duration = time() - search_start  # seconds; only the search is timed
+
+    llm = graphiti.llm_client
+    context = {
+        'facts': [edge.fact for edge in results.edges],
+        'entity_summaries': [node.name + ': ' + node.summary for node in results.nodes],
+        'query': 'Bob: ' + query,
+    }
+    qa_reply = await llm.generate_response(prompt_library.eval.qa_prompt(context))
+    response = qa_reply.get('ANSWER', '')
+
+    eval_context = {
+        'query': 'Bob: ' + query,
+        'answer': 'Alice: ' + answer,
+        'response': 'Alice: ' + response,
+    }
+    verdict = await llm.generate_response(prompt_library.eval.eval_prompt(eval_context))
+    # The judge may return the flag as a string; 'false' must not count as correct.
+    is_correct = verdict.get('is_correct', False)
+    if isinstance(is_correct, str):
+        is_correct = is_correct.strip().lower() in ('true', 'yes', '1')
+
+    return {  # one CSV row; keys are the column names
+        'Group id': group_id,
+        'Question': query,
+        'Answer': answer,
+        'Response': response,
+        'Score': 1 if is_correct else 0,
+        'Search Duration (ms)': search_duration * 1000,
+    }
+
+
+async def main():
+    """Evaluate the first 500 MSC question/answer pairs and write scores to CSV.
+
+    Pairs are evaluated concurrently in batches of 20; each batch is written to
+    disk as soon as it finishes so partial results survive a failure.
+    """
+    setup_logging()
+    graphiti = Graphiti(neo4j_uri, neo4j_user, neo4j_password)
+
+    fields = ['Group id', 'Question', 'Answer', 'Response', 'Score', 'Search Duration (ms)']
+    qa = conversation_q_and_a()[0:500]
+
+    try:
+        with open('../data/msc_eval.csv', 'w', newline='') as file:
+            writer = csv.DictWriter(file, fieldnames=fields)
+            writer.writeheader()
+            # Step by the real number of pairs so short datasets do not spin empty batches.
+            for start in range(0, len(qa), 20):
+                results = await asyncio.gather(
+                    *[
+                        evaluate_qa(graphiti, str(start + offset), query, answer)
+                        for offset, (query, answer) in enumerate(qa[start : start + 20])
+                    ]
+                )
+                writer.writerows(results)
+                file.flush()  # persist each finished batch immediately
+    finally:
+        # Close the Graphiti connection even when a batch raises.
+        await graphiti.close()
+
+
+asyncio.run(main())  # NOTE: no __main__ guard, so this also runs when the module is imported
diff --git a/examples/multi_session_conversation_memory/msc_runner.py b/examples/multi_session_conversation_memory/msc_runner.py
new file mode 100644
index 00000000..3ea412ae
--- /dev/null
+++ b/examples/multi_session_conversation_memory/msc_runner.py
@@ -0,0 +1,91 @@
+"""
+Copyright 2024, Zep Software, Inc.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
+import asyncio
+import logging
+import os
+import sys
+
+from dotenv import load_dotenv
+
+from examples.multi_session_conversation_memory.parse_msc_messages import (
+ ParsedMscMessage,
+ parse_msc_messages,
+)
+from graphiti_core import Graphiti
+
+load_dotenv()  # read a .env file into the environment first
+
+neo4j_uri = os.environ.get('NEO4J_URI') or 'bolt://localhost:7687'  # `or` also covers ''
+neo4j_user = os.environ.get('NEO4J_USER') or 'neo4j'  # `or` also covers ''
+neo4j_password = os.environ.get('NEO4J_PASSWORD') or 'password'  # `or` also covers ''
+
+
+def setup_logging():
+    """Attach an INFO-level stdout handler to the root logger.
+
+    Each call adds another handler, so call this once per process.
+
+    Returns the root logger.
+    """
+    log_format = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
+
+    # Build the fully configured handler first ...
+    handler = logging.StreamHandler(sys.stdout)
+    handler.setLevel(logging.INFO)
+    handler.setFormatter(logging.Formatter(log_format))
+
+    # ... then wire it into the root logger (getLogger() without a name).
+    root = logging.getLogger()
+    root.setLevel(logging.INFO)
+    root.addHandler(handler)
+    return root
+
+
+async def add_conversation(graphiti: Graphiti, group_id: str, messages: list[ParsedMscMessage]):
+    for index, msg in enumerate(messages):  # episodes are awaited one by one, in message order
+        await graphiti.add_episode(
+            name=f'Message {group_id}-{index}',
+            episode_body=f'{msg.speaker_name}: {msg.content}',
+            reference_time=msg.actual_timestamp,
+            source_description='Multi-Session Conversation',
+            group_id=group_id,
+        )
+
+
+async def main():
+    """Ingest the first 500 MSC conversations into the graph, 10 at a time."""
+    setup_logging()
+    graphiti = Graphiti(neo4j_uri, neo4j_user, neo4j_password)
+
+    try:
+        # Only the first 500 conversations are ingested.
+        msc_messages = parse_msc_messages()[0:500]
+        for start in range(0, len(msc_messages), 10):
+            # Conversations in a batch run concurrently; the list index is the group id.
+            await asyncio.gather(
+                *[
+                    add_conversation(graphiti, str(start + offset), messages)
+                    for offset, messages in enumerate(msc_messages[start : start + 10])
+                ]
+            )
+        # Communities are not built here; call graphiti.build_communities() if needed.
+    finally:
+        # Always close the Graphiti connection, even if ingestion fails part-way.
+        await graphiti.close()
+
+
+asyncio.run(main())  # NOTE: executes on import as well; there is no __main__ guard
diff --git a/examples/multi_session_conversation_memory/parse_msc_messages.py b/examples/multi_session_conversation_memory/parse_msc_messages.py
new file mode 100644
index 00000000..1cb28ef8
--- /dev/null
+++ b/examples/multi_session_conversation_memory/parse_msc_messages.py
@@ -0,0 +1,85 @@
+"""
+Copyright 2024, Zep Software, Inc.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
+import json
+from datetime import datetime
+
+from pydantic import BaseModel
+
+
+class ParsedMscMessage(BaseModel):
+ speaker_name: str  # parse_msc_messages fills this with 'Alice' or 'Bob'
+ actual_timestamp: datetime  # parse_msc_messages fills this with datetime.now()
+ content: str  # the utterance text
+ group_id: str  # parse_msc_messages uses the conversation's index, as a string
+
+
+def parse_msc_messages() -> list[list[ParsedMscMessage]]:
+    """Load '../data/msc.json' and flatten every conversation into one message list.
+
+    For each conversation the utterances of all previous dialogs come first,
+    followed by those of the current dialog. Within each dialog the speakers
+    alternate, always starting with 'Alice'. The conversation's index in the
+    file becomes the `group_id` of all of its messages.
+
+    Returns
+    -------
+    list[list[ParsedMscMessage]]
+        One list of messages per conversation, in file order.
+
+    Notes
+    -----
+    Every message is stamped with the wall-clock time at which it is parsed
+    (`datetime.now()`), not with a time taken from the dataset.
+    """
+    speakers = ['Alice', 'Bob']
+
+    with open('../data/msc.json') as file:
+        data = json.load(file)['data']
+
+    msc_messages: list[list[ParsedMscMessage]] = []
+    for i, conversation in enumerate(data):
+        # Previous dialogs first, then the current one; each restarts at Alice.
+        dialogs = [previous['dialog'] for previous in conversation['previous_dialogs']]
+        dialogs.append(conversation['dialog'])
+
+        msc_messages.append(
+            [
+                ParsedMscMessage(
+                    speaker_name=speakers[turn % 2],
+                    content=utterance['text'],
+                    actual_timestamp=datetime.now(),
+                    group_id=str(i),
+                )
+                for dialog in dialogs
+                for turn, utterance in enumerate(dialog)
+            ]
+        )
+
+    return msc_messages
+
+
+def conversation_q_and_a() -> list[tuple[str, str]]:
+    """Return one (question, answer) pair per conversation in '../data/msc.json'.
+
+    The question is the `self_instruct` entry 'B', the answer is entry 'A'.
+    """
+    with open('../data/msc.json') as file:
+        conversations = json.load(file)['data']
+
+    return [
+        (entry['self_instruct']['B'], entry['self_instruct']['A']) for entry in conversations
+    ]
diff --git a/graphiti_core/graphiti.py b/graphiti_core/graphiti.py
index 0037c74e..4b104698 100644
--- a/graphiti_core/graphiti.py
+++ b/graphiti_core/graphiti.py
@@ -161,7 +161,7 @@ class Graphiti:
"""
await self.driver.close()
- async def build_indices_and_constraints(self):
+ async def build_indices_and_constraints(self, delete_existing: bool = False):
"""
Build indices and constraints in the Neo4j database.
@@ -171,6 +171,9 @@ class Graphiti:
Parameters
----------
self
+ delete_existing : bool, optional
+ Whether to clear existing indices before creating new ones.
+
Returns
-------
@@ -191,7 +194,7 @@ class Graphiti:
Caution: Running this method on a large existing database may take some time
and could impact database performance during execution.
"""
- await build_indices_and_constraints(self.driver)
+ await build_indices_and_constraints(self.driver, delete_existing)
async def retrieve_episodes(
self,
diff --git a/graphiti_core/helpers.py b/graphiti_core/helpers.py
index 9471058e..80cd9752 100644
--- a/graphiti_core/helpers.py
+++ b/graphiti_core/helpers.py
@@ -21,3 +21,33 @@ from neo4j import time as neo4j_time
def parse_db_date(neo_date: neo4j_time.DateTime | None) -> datetime | None:
return neo_date.to_native() if neo_date else None
+
+
+def lucene_sanitize(query: str) -> str:
+    """Return `query` with every Lucene query-syntax character backslash-escaped."""
+    # Special characters: + - && || ! ( ) { } [ ] ^ " ~ * ? : \ /
+    escape_map = str.maketrans(
+        {
+            '+': r'\+',
+            '-': r'\-',
+            '&': r'\&',
+            '|': r'\|',
+            '!': r'\!',
+            '(': r'\(',
+            ')': r'\)',
+            '{': r'\{',
+            '}': r'\}',
+            '[': r'\[',
+            ']': r'\]',
+            '^': r'\^',
+            '"': r'\"',
+            '~': r'\~',
+            '*': r'\*',
+            '?': r'\?',
+            ':': r'\:',
+            '\\': r'\\',
+            '/': r'\/',
+        }
+    )
+    # '/' delimits a regular expression in Lucene; unescaped it breaks queries like 'a/b'.
+    return query.translate(escape_map)
diff --git a/graphiti_core/prompts/eval.py b/graphiti_core/prompts/eval.py
new file mode 100644
index 00000000..8fa22d3f
--- /dev/null
+++ b/graphiti_core/prompts/eval.py
@@ -0,0 +1,90 @@
+"""
+Copyright 2024, Zep Software, Inc.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
+import json
+from typing import Any, Protocol, TypedDict
+
+from .models import Message, PromptFunction, PromptVersion
+
+
+class Prompt(Protocol):
+ qa_prompt: PromptVersion  # prompt that answers a question as Alice from retrieved context
+ eval_prompt: PromptVersion  # prompt that judges a response against the gold answer
+
+
+class Versions(TypedDict):
+ qa_prompt: PromptFunction  # builder for the question-answering messages
+ eval_prompt: PromptFunction  # builder for the judging messages
+
+
+def qa_prompt(context: dict[str, Any]) -> list[Message]:
+ """Build the system/user messages asking the model to answer context['query'] as Alice."""
+ sys_prompt = """You are Alice and should respond to all questions from the first person perspective of Alice"""
+ user_prompt = f"""
+ Your task is to briefly answer the question in the way that you think Alice would answer the question.
+ You are given the following entity summaries and facts to help you determine the answer to your question.
+ <ENTITY_SUMMARIES>
+ {json.dumps(context['entity_summaries'])}
+ </ENTITY_SUMMARIES>
+ <FACTS>
+ {json.dumps(context['facts'])}
+ </FACTS>
+ <QUESTION>
+ {context['query']}
+ </QUESTION>
+ respond with a JSON object in the following format:
+ {{
+ "ANSWER": "how Alice would answer the question"
+ }}
+ """
+ return [
+ Message(role='system', content=sys_prompt),
+ Message(role='user', content=user_prompt),
+ ]
+
+
+def eval_prompt(context: dict[str, Any]) -> list[Message]:
+ """Build the judge messages comparing context['response'] with the gold context['answer']."""
+ sys_prompt = (
+ """You are a judge that determines if answers to questions match a gold standard answer"""
+ )
+ user_prompt = f"""
+ Given the QUESTION and the gold standard ANSWER determine if the RESPONSE to the question is correct or incorrect.
+ Although the RESPONSE may be more verbose, mark it as correct as long as it references the same topic
+ as the gold standard ANSWER. Also include your reasoning for the grade.
+ <QUESTION>
+ {context['query']}
+ </QUESTION>
+ <ANSWER>
+ {context['answer']}
+ </ANSWER>
+ <RESPONSE>
+ {context['response']}
+ </RESPONSE>
+
+ respond with a JSON object in the following format:
+ {{
+ "is_correct": "boolean if the answer is correct or incorrect",
+ "reasoning": "why you determined the response was correct or incorrect"
+ }}
+ """
+ return [
+ Message(role='system', content=sys_prompt),
+ Message(role='user', content=user_prompt),
+ ]
+
+
+versions: Versions = {'qa_prompt': qa_prompt, 'eval_prompt': eval_prompt}  # name -> builder
diff --git a/graphiti_core/prompts/lib.py b/graphiti_core/prompts/lib.py
index 09cc3ec0..cb971d0f 100644
--- a/graphiti_core/prompts/lib.py
+++ b/graphiti_core/prompts/lib.py
@@ -34,6 +34,9 @@ from .dedupe_nodes import (
from .dedupe_nodes import (
versions as dedupe_nodes_versions,
)
+from .eval import Prompt as EvalPrompt
+from .eval import Versions as EvalVersions
+from .eval import versions as eval_versions
from .extract_edge_dates import (
Prompt as ExtractEdgeDatesPrompt,
)
@@ -84,6 +87,7 @@ class PromptLibrary(Protocol):
invalidate_edges: InvalidateEdgesPrompt
extract_edge_dates: ExtractEdgeDatesPrompt
summarize_nodes: SummarizeNodesPrompt
+ eval: EvalPrompt
class PromptLibraryImpl(TypedDict):
@@ -94,6 +98,7 @@ class PromptLibraryImpl(TypedDict):
invalidate_edges: InvalidateEdgesVersions
extract_edge_dates: ExtractEdgeDatesVersions
summarize_nodes: SummarizeNodesVersions
+ eval: EvalVersions
class VersionWrapper:
@@ -124,5 +129,6 @@ PROMPT_LIBRARY_IMPL: PromptLibraryImpl = {
'invalidate_edges': invalidate_edges_versions,
'extract_edge_dates': extract_edge_dates_versions,
'summarize_nodes': summarize_nodes_versions,
+ 'eval': eval_versions,
}
prompt_library: PromptLibrary = PromptLibraryWrapper(PROMPT_LIBRARY_IMPL) # type: ignore[assignment]
diff --git a/graphiti_core/search/search.py b/graphiti_core/search/search.py
index 9a5fa311..0fda25c3 100644
--- a/graphiti_core/search/search.py
+++ b/graphiti_core/search/search.py
@@ -14,6 +14,7 @@ See the License for the specific language governing permissions and
limitations under the License.
"""
+import asyncio
import logging
from collections import defaultdict
from time import time
@@ -65,32 +66,20 @@ async def search(
query = query.replace('\n', ' ')
# if group_ids is empty, set it to None
group_ids = group_ids if group_ids else None
- edges = (
- await edge_search(
- driver, embedder, query, group_ids, config.edge_config, center_node_uuid, config.limit, config.embedding_model
- )
- if config.edge_config is not None
- else []
- )
- nodes = (
- await node_search(
- driver, embedder, query, group_ids, config.node_config, center_node_uuid, config.limit, config.embedding_model
- )
- if config.node_config is not None
- else []
- )
- communities = (
- await community_search(
- driver, embedder, query, group_ids, config.community_config, config.limit, config.embedding_model
- )
- if config.community_config is not None
- else []
+ edges, nodes, communities = await asyncio.gather(
+ edge_search(
+ driver, embedder, query, group_ids, config.edge_config, center_node_uuid, config.limit
+ ),
+ node_search(
+ driver, embedder, query, group_ids, config.node_config, center_node_uuid, config.limit
+ ),
+ community_search(driver, embedder, query, group_ids, config.community_config, config.limit),
)
results = SearchResults(
- edges=edges[: config.limit],
- nodes=nodes[: config.limit],
- communities=communities[: config.limit],
+ edges=edges,
+ nodes=nodes,
+ communities=communities,
)
end = time()
@@ -105,11 +94,14 @@ async def edge_search(
embedder,
query: str,
group_ids: list[str] | None,
- config: EdgeSearchConfig,
+ config: EdgeSearchConfig | None,
center_node_uuid: str | None = None,
limit=DEFAULT_SEARCH_LIMIT,
embedding_model: str | None = None,
) -> list[EntityEdge]:
+ if config is None:
+ return []
+
search_results: list[list[EntityEdge]] = []
if EdgeSearchMethod.bm25 in config.search_methods:
@@ -163,7 +155,7 @@ async def edge_search(
if config.reranker == EdgeReranker.episode_mentions:
reranked_edges.sort(reverse=True, key=lambda edge: len(edge.episodes))
- return reranked_edges
+ return reranked_edges[:limit]
async def node_search(
@@ -171,11 +163,14 @@ async def node_search(
embedder,
query: str,
group_ids: list[str] | None,
- config: NodeSearchConfig,
+ config: NodeSearchConfig | None,
center_node_uuid: str | None = None,
limit=DEFAULT_SEARCH_LIMIT,
embedding_model: str | None = None,
) -> list[EntityNode]:
+ if config is None:
+ return []
+
search_results: list[list[EntityNode]] = []
if NodeSearchMethod.bm25 in config.search_methods:
@@ -214,7 +209,7 @@ async def node_search(
reranked_nodes = [node_uuid_map[uuid] for uuid in reranked_uuids]
- return reranked_nodes
+ return reranked_nodes[:limit]
async def community_search(
@@ -222,10 +217,13 @@ async def community_search(
embedder,
query: str,
group_ids: list[str] | None,
- config: CommunitySearchConfig,
+ config: CommunitySearchConfig | None,
limit=DEFAULT_SEARCH_LIMIT,
embedding_model: str | None = None,
) -> list[CommunityNode]:
+ if config is None:
+ return []
+
search_results: list[list[CommunityNode]] = []
if CommunitySearchMethod.bm25 in config.search_methods:
@@ -258,4 +256,4 @@ async def community_search(
reranked_communities = [community_uuid_map[uuid] for uuid in reranked_uuids]
- return reranked_communities
+ return reranked_communities[:limit]
diff --git a/graphiti_core/search/search_utils.py b/graphiti_core/search/search_utils.py
index 110c3210..af672bfd 100644
--- a/graphiti_core/search/search_utils.py
+++ b/graphiti_core/search/search_utils.py
@@ -16,13 +16,13 @@ limitations under the License.
import asyncio
import logging
-import re
from collections import defaultdict
from time import time
from neo4j import AsyncDriver, Query
from graphiti_core.edges import EntityEdge, get_entity_edge_from_record
+from graphiti_core.helpers import lucene_sanitize
from graphiti_core.nodes import (
CommunityNode,
EntityNode,
@@ -36,6 +36,22 @@ logger = logging.getLogger(__name__)
RELEVANT_SCHEMA_LIMIT = 3
+def fulltext_query(query: str, group_ids: list[str] | None = None) -> str:
+    """Build a Lucene query string: fuzzy `query`, optionally limited to `group_ids`.
+
+    Both parts are parenthesised; without the grouping Lucene does not apply
+    `a OR b AND c` as "(a OR b) AND c", so multi-group filters match wrongly.
+    """
+    fuzzy_query = lucene_sanitize(query) + '~'
+    if not group_ids:
+        # None or an empty list means "no group restriction".
+        return fuzzy_query
+
+    # Each id is escaped and quoted, then the alternatives are OR-ed together.
+    group_filter = ' OR '.join(f'group_id:"{lucene_sanitize(g)}"' for g in group_ids)
+    return f'({group_filter}) AND ({fuzzy_query})'
+
+
async def get_mentioned_nodes(
driver: AsyncDriver, episodes: list[EpisodicNode]
) -> list[EntityNode]:
@@ -91,11 +107,15 @@ async def edge_fulltext_search(
limit=RELEVANT_SCHEMA_LIMIT,
) -> list[EntityEdge]:
# fulltext search over facts
+ fuzzy_query = fulltext_query(query, group_ids)
+
cypher_query = Query("""
- CALL db.index.fulltext.queryRelationships("name_and_fact", $query)
+ CALL db.index.fulltext.queryRelationships("edge_name_and_fact", $query)
YIELD relationship AS rel, score
- MATCH (n:Entity {uuid: $source_uuid})-[r {uuid: rel.uuid}]-(m:Entity {uuid: $target_uuid})
- WHERE $group_ids IS NULL OR n.group_id IN $group_ids
+ MATCH (n:Entity)-[r {uuid: rel.uuid}]-(m:Entity)
+ WHERE ($source_uuid IS NULL OR n.uuid = $source_uuid)
+ AND ($target_uuid IS NULL OR m.uuid = $target_uuid)
+ AND ($group_ids IS NULL OR n.group_id IN $group_ids)
RETURN
r.uuid AS uuid,
r.group_id AS group_id,
@@ -112,72 +132,6 @@ async def edge_fulltext_search(
ORDER BY score DESC LIMIT $limit
""")
- if source_node_uuid is None and target_node_uuid is None:
- cypher_query = Query("""
- CALL db.index.fulltext.queryRelationships("name_and_fact", $query)
- YIELD relationship AS rel, score
- MATCH (n:Entity)-[r {uuid: rel.uuid}]-(m:Entity)
- WHERE $group_ids IS NULL OR r.group_id IN $group_ids
- RETURN
- r.uuid AS uuid,
- r.group_id AS group_id,
- n.uuid AS source_node_uuid,
- m.uuid AS target_node_uuid,
- r.created_at AS created_at,
- r.name AS name,
- r.fact AS fact,
- r.fact_embedding AS fact_embedding,
- r.episodes AS episodes,
- r.expired_at AS expired_at,
- r.valid_at AS valid_at,
- r.invalid_at AS invalid_at
- ORDER BY score DESC LIMIT $limit
- """)
- elif source_node_uuid is None:
- cypher_query = Query("""
- CALL db.index.fulltext.queryRelationships("name_and_fact", $query)
- YIELD relationship AS rel, score
- MATCH (n:Entity)-[r {uuid: rel.uuid}]-(m:Entity {uuid: $target_uuid})
- WHERE $group_ids IS NULL OR r.group_id IN $group_ids
- RETURN
- r.uuid AS uuid,
- r.group_id AS group_id,
- n.uuid AS source_node_uuid,
- m.uuid AS target_node_uuid,
- r.created_at AS created_at,
- r.name AS name,
- r.fact AS fact,
- r.fact_embedding AS fact_embedding,
- r.episodes AS episodes,
- r.expired_at AS expired_at,
- r.valid_at AS valid_at,
- r.invalid_at AS invalid_at
- ORDER BY score DESC LIMIT $limit
- """)
- elif target_node_uuid is None:
- cypher_query = Query("""
- CALL db.index.fulltext.queryRelationships("name_and_fact", $query)
- YIELD relationship AS rel, score
- MATCH (n:Entity {uuid: $source_uuid})-[r {uuid: rel.uuid}]-(m:Entity)
- WHERE $group_ids IS NULL OR r.group_id IN $group_ids
- RETURN
- r.uuid AS uuid,
- r.group_id AS group_id,
- n.uuid AS source_node_uuid,
- m.uuid AS target_node_uuid,
- r.created_at AS created_at,
- r.name AS name,
- r.fact AS fact,
- r.fact_embedding AS fact_embedding,
- r.episodes AS episodes,
- r.expired_at AS expired_at,
- r.valid_at AS valid_at,
- r.invalid_at AS invalid_at
- ORDER BY score DESC LIMIT $limit
- """)
-
- fuzzy_query = re.sub(r'[^\w\s]', '', query) + '~'
-
records, _, _ = await driver.execute_query(
cypher_query,
query=fuzzy_query,
@@ -202,11 +156,12 @@ async def edge_similarity_search(
) -> list[EntityEdge]:
# vector similarity search over embedded facts
query = Query("""
- CALL db.index.vector.queryRelationships("fact_embedding", $limit, $search_vector)
- YIELD relationship AS rel, score
- MATCH (n:Entity {uuid: $source_uuid})-[r {uuid: rel.uuid}]-(m:Entity {uuid: $target_uuid})
- WHERE $group_ids IS NULL OR r.group_id IN $group_ids
+ MATCH (n:Entity)-[r:RELATES_TO]-(m:Entity)
+ WHERE ($group_ids IS NULL OR r.group_id IN $group_ids)
+ AND ($source_uuid IS NULL OR n.uuid = $source_uuid)
+ AND ($target_uuid IS NULL OR m.uuid = $target_uuid)
RETURN
+ vector.similarity.cosine(r.fact_embedding, $search_vector) AS score,
r.uuid AS uuid,
r.group_id AS group_id,
n.uuid AS source_node_uuid,
@@ -220,72 +175,9 @@ async def edge_similarity_search(
r.valid_at AS valid_at,
r.invalid_at AS invalid_at
ORDER BY score DESC
+ LIMIT $limit
""")
- if source_node_uuid is None and target_node_uuid is None:
- query = Query("""
- CALL db.index.vector.queryRelationships("fact_embedding", $limit, $search_vector)
- YIELD relationship AS rel, score
- MATCH (n:Entity)-[r {uuid: rel.uuid}]-(m:Entity)
- WHERE $group_ids IS NULL OR r.group_id IN $group_ids
- RETURN
- r.uuid AS uuid,
- r.group_id AS group_id,
- n.uuid AS source_node_uuid,
- m.uuid AS target_node_uuid,
- r.created_at AS created_at,
- r.name AS name,
- r.fact AS fact,
- r.fact_embedding AS fact_embedding,
- r.episodes AS episodes,
- r.expired_at AS expired_at,
- r.valid_at AS valid_at,
- r.invalid_at AS invalid_at
- ORDER BY score DESC
- """)
- elif source_node_uuid is None:
- query = Query("""
- CALL db.index.vector.queryRelationships("fact_embedding", $limit, $search_vector)
- YIELD relationship AS rel, score
- MATCH (n:Entity)-[r {uuid: rel.uuid}]-(m:Entity {uuid: $target_uuid})
- WHERE $group_ids IS NULL OR r.group_id IN $group_ids
- RETURN
- r.uuid AS uuid,
- r.group_id AS group_id,
- n.uuid AS source_node_uuid,
- m.uuid AS target_node_uuid,
- r.created_at AS created_at,
- r.name AS name,
- r.fact AS fact,
- r.fact_embedding AS fact_embedding,
- r.episodes AS episodes,
- r.expired_at AS expired_at,
- r.valid_at AS valid_at,
- r.invalid_at AS invalid_at
- ORDER BY score DESC
- """)
- elif target_node_uuid is None:
- query = Query("""
- CALL db.index.vector.queryRelationships("fact_embedding", $limit, $search_vector)
- YIELD relationship AS rel, score
- MATCH (n:Entity {uuid: $source_uuid})-[r {uuid: rel.uuid}]-(m:Entity)
- WHERE $group_ids IS NULL OR r.group_id IN $group_ids
- RETURN
- r.uuid AS uuid,
- r.group_id AS group_id,
- n.uuid AS source_node_uuid,
- m.uuid AS target_node_uuid,
- r.created_at AS created_at,
- r.name AS name,
- r.fact AS fact,
- r.fact_embedding AS fact_embedding,
- r.episodes AS episodes,
- r.expired_at AS expired_at,
- r.valid_at AS valid_at,
- r.invalid_at AS invalid_at
- ORDER BY score DESC
- """)
-
records, _, _ = await driver.execute_query(
query,
search_vector=search_vector,
@@ -307,10 +199,11 @@ async def node_fulltext_search(
limit=RELEVANT_SCHEMA_LIMIT,
) -> list[EntityNode]:
# BM25 search to get top nodes
- fuzzy_query = re.sub(r'[^\w\s]', '', query) + '~'
+ fuzzy_query = fulltext_query(query, group_ids)
+
records, _, _ = await driver.execute_query(
"""
- CALL db.index.fulltext.queryNodes("name_and_summary", $query)
+ CALL db.index.fulltext.queryNodes("node_name_and_summary", $query)
YIELD node AS n, score
WHERE $group_ids IS NULL OR n.group_id IN $group_ids
RETURN
@@ -341,11 +234,10 @@ async def node_similarity_search(
# vector similarity search over entity names
records, _, _ = await driver.execute_query(
"""
- CALL db.index.vector.queryNodes("name_embedding", $limit, $search_vector)
- YIELD node AS n, score
MATCH (n:Entity)
WHERE $group_ids IS NULL OR n.group_id IN $group_ids
RETURN
+ vector.similarity.cosine(n.name_embedding, $search_vector) AS score,
n.uuid As uuid,
n.group_id AS group_id,
n.name AS name,
@@ -353,6 +245,7 @@ async def node_similarity_search(
n.created_at AS created_at,
n.summary AS summary
ORDER BY score DESC
+ LIMIT $limit
""",
search_vector=search_vector,
group_ids=group_ids,
@@ -370,7 +263,8 @@ async def community_fulltext_search(
limit=RELEVANT_SCHEMA_LIMIT,
) -> list[CommunityNode]:
# BM25 search to get top communities
- fuzzy_query = re.sub(r'[^\w\s]', '', query) + '~'
+ fuzzy_query = fulltext_query(query, group_ids)
+
records, _, _ = await driver.execute_query(
"""
CALL db.index.fulltext.queryNodes("community_name", $query)
@@ -405,11 +299,10 @@ async def community_similarity_search(
# vector similarity search over entity names
records, _, _ = await driver.execute_query(
"""
- CALL db.index.vector.queryNodes("community_name_embedding", $limit, $search_vector)
- YIELD node AS comm, score
MATCH (comm:Community)
- WHERE $group_ids IS NULL OR comm.group_id IN $group_ids
+ WHERE ($group_ids IS NULL OR comm.group_id IN $group_ids)
RETURN
+ vector.similarity.cosine(comm.name_embedding, $search_vector) AS score,
comm.uuid As uuid,
comm.group_id AS group_id,
comm.name AS name,
@@ -417,6 +310,7 @@ async def community_similarity_search(
comm.created_at AS created_at,
comm.summary AS summary
ORDER BY score DESC
+ LIMIT $limit
""",
search_vector=search_vector,
group_ids=group_ids,
diff --git a/graphiti_core/utils/maintenance/graph_data_operations.py b/graphiti_core/utils/maintenance/graph_data_operations.py
index cdc39b31..6d67a707 100644
--- a/graphiti_core/utils/maintenance/graph_data_operations.py
+++ b/graphiti_core/utils/maintenance/graph_data_operations.py
@@ -28,7 +28,16 @@ EPISODE_WINDOW_LEN = 3
logger = logging.getLogger(__name__)
-async def build_indices_and_constraints(driver: AsyncDriver):
+async def build_indices_and_constraints(driver: AsyncDriver, delete_existing: bool = False):
+ if delete_existing:
+ records, _, _ = await driver.execute_query("""
+ SHOW INDEXES YIELD name
+ """)
+ index_names = [record['name'] for record in records]
+ await asyncio.gather(
+ *[driver.execute_query("""DROP INDEX $name""", name=name) for name in index_names]
+ )
+
range_indices: list[LiteralString] = [
'CREATE INDEX entity_uuid IF NOT EXISTS FOR (n:Entity) ON (n.uuid)',
'CREATE INDEX episode_uuid IF NOT EXISTS FOR (n:Episodic) ON (n.uuid)',
@@ -52,38 +61,15 @@ async def build_indices_and_constraints(driver: AsyncDriver):
]
fulltext_indices: list[LiteralString] = [
- 'CREATE FULLTEXT INDEX name_and_summary IF NOT EXISTS FOR (n:Entity) ON EACH [n.name, n.summary]',
- 'CREATE FULLTEXT INDEX community_name IF NOT EXISTS FOR (n:Community) ON EACH [n.name]',
- 'CREATE FULLTEXT INDEX name_and_fact IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON EACH [e.name, e.fact]',
+ """CREATE FULLTEXT INDEX node_name_and_summary IF NOT EXISTS
+ FOR (n:Entity) ON EACH [n.name, n.summary, n.group_id]""",
+ """CREATE FULLTEXT INDEX community_name IF NOT EXISTS
+ FOR (n:Community) ON EACH [n.name, n.group_id]""",
+ """CREATE FULLTEXT INDEX edge_name_and_fact IF NOT EXISTS
+ FOR ()-[e:RELATES_TO]-() ON EACH [e.name, e.fact, e.group_id]""",
]
- vector_indices: list[LiteralString] = [
- """
- CREATE VECTOR INDEX fact_embedding IF NOT EXISTS
- FOR ()-[r:RELATES_TO]-() ON (r.fact_embedding)
- OPTIONS {indexConfig: {
- `vector.dimensions`: 1024,
- `vector.similarity_function`: 'cosine'
- }}
- """,
- """
- CREATE VECTOR INDEX name_embedding IF NOT EXISTS
- FOR (n:Entity) ON (n.name_embedding)
- OPTIONS {indexConfig: {
- `vector.dimensions`: 1024,
- `vector.similarity_function`: 'cosine'
- }}
- """,
- """
- CREATE VECTOR INDEX community_name_embedding IF NOT EXISTS
- FOR (n:Community) ON (n.name_embedding)
- OPTIONS {indexConfig: {
- `vector.dimensions`: 1024,
- `vector.similarity_function`: 'cosine'
- }}
- """,
- ]
- index_queries: list[LiteralString] = range_indices + fulltext_indices + vector_indices
+ index_queries: list[LiteralString] = range_indices + fulltext_indices
await asyncio.gather(*[driver.execute_query(query) for query in index_queries])
diff --git a/pyproject.toml b/pyproject.toml
index c8ccfcaf..6b5d6647 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
[tool.poetry]
name = "graphiti-core"
-version = "0.3.6"
+version = "0.3.7"
description = "A temporal graph building library"
authors = [
"Paul Paliychuk ",
diff --git a/tests/test_graphiti_int.py b/tests/test_graphiti_int.py
index eec33923..f9a94201 100644
--- a/tests/test_graphiti_int.py
+++ b/tests/test_graphiti_int.py
@@ -74,7 +74,6 @@ def format_context(facts):
async def test_graphiti_init():
logger = setup_logging()
graphiti = Graphiti(NEO4J_URI, NEO4j_USER, NEO4j_PASSWORD)
- await graphiti.build_communities()
edges = await graphiti.search(
'tania tetlow', center_node_uuid='4bf7ebb3-3a98-46c7-90a6-8e516c487961', group_ids=None
@@ -96,7 +95,7 @@ async def test_graphiti_init():
}
logger.info(pretty_results)
- graphiti.close()
+ await graphiti.close()
@pytest.mark.asyncio