diff --git a/examples/podcast/podcast_runner.py b/examples/podcast/podcast_runner.py index 47509766..2c53889e 100644 --- a/examples/podcast/podcast_runner.py +++ b/examples/podcast/podcast_runner.py @@ -18,6 +18,7 @@ import asyncio import logging import os import sys +from uuid import uuid4 from dotenv import load_dotenv from pydantic import BaseModel, Field @@ -68,6 +69,7 @@ async def main(): await clear_data(client.driver) await client.build_indices_and_constraints() messages = parse_podcast_messages() + group_id = str(uuid4()) for i, message in enumerate(messages[3:14]): episodes = await client.retrieve_episodes( @@ -80,7 +82,7 @@ async def main(): episode_body=f'{message.speaker_name} ({message.role}): {message.content}', reference_time=message.actual_timestamp, source_description='Podcast Transcript', - group_id='podcast', + group_id=group_id, entity_types={'Person': Person}, previous_episode_uuids=episode_uuids, ) diff --git a/graphiti_core/search/search_utils.py b/graphiti_core/search/search_utils.py index 22baf72e..81a713ca 100644 --- a/graphiti_core/search/search_utils.py +++ b/graphiti_core/search/search_utils.py @@ -341,10 +341,10 @@ async def node_fulltext_search( query = ( """ - CALL db.index.fulltext.queryNodes("node_name_and_summary", $query, {limit: $limit}) - YIELD node AS n, score - WHERE n:Entity - """ + CALL db.index.fulltext.queryNodes("node_name_and_summary", $query, {limit: $limit}) + YIELD node AS n, score + WHERE n:Entity + """ + filter_query + ENTITY_NODE_RETURN + """ @@ -676,7 +676,7 @@ async def get_relevant_nodes( WHERE score > $min_score WITH node, collect(n)[..$limit] AS top_vector_nodes, collect(n.uuid) AS vector_node_uuids - CALL db.index.fulltext.queryNodes("node_name_and_summary", 'group_id:"' + $group_id + '" AND ' + node.name, {limit: $limit}) + CALL db.index.fulltext.queryNodes("node_name_and_summary", node.fulltext_query, {limit: $limit}) YIELD node AS m WHERE m.group_id = $group_id WITH node, top_vector_nodes, vector_node_uuids, collect(m) AS fulltext_nodes @@ -705,18 +705,21 @@ async def get_relevant_nodes( """ ) + query_nodes = [ + { + 'uuid': node.uuid, + 'name': node.name, + 'name_embedding': node.name_embedding, + 'fulltext_query': fulltext_query(node.name, [node.group_id]), + } + for node in nodes + ] + results, _, _ = await driver.execute_query( query, query_params, - nodes=[ - { - 'uuid': node.uuid, - 'name': lucene_sanitize(node.name), - 'name_embedding': node.name_embedding, - } - for node in nodes - ], - group_id=lucene_sanitize(group_id), + nodes=query_nodes, + group_id=group_id, limit=limit, min_score=min_score, database_=DEFAULT_DATABASE, diff --git a/pyproject.toml b/pyproject.toml index 8438837e..1dc3512a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,7 +1,7 @@ [project] name = "graphiti-core" description = "A temporal graph building library" -version = "0.11.1" +version = "0.11.2" authors = [ { "name" = "Paul Paliychuk", "email" = "paul@getzep.com" }, { "name" = "Preston Rasmussen", "email" = "preston@getzep.com" },