Abstract Neo4j filters in search queries (#243)

* Move null check for search queries to Python

* Update search filtering

* Update

* Update
This commit is contained in:
Preston Rasmussen 2024-12-16 21:45:45 -05:00 committed by GitHub
parent 425b35ba2d
commit 34496ffa6a
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 102 additions and 62 deletions

View file

@ -22,10 +22,10 @@
] ]
}, },
{ {
"metadata": {},
"cell_type": "code", "cell_type": "code",
"outputs": [],
"execution_count": null, "execution_count": null,
"metadata": {},
"outputs": [],
"source": [ "source": [
"import asyncio\n", "import asyncio\n",
"import json\n", "import json\n",
@ -47,10 +47,10 @@
] ]
}, },
{ {
"metadata": {},
"cell_type": "code", "cell_type": "code",
"outputs": [],
"execution_count": null, "execution_count": null,
"metadata": {},
"outputs": [],
"source": [ "source": [
"def setup_logging():\n", "def setup_logging():\n",
" logger = logging.getLogger()\n", " logger = logging.getLogger()\n",
@ -67,8 +67,8 @@
] ]
}, },
{ {
"metadata": {},
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {},
"source": [ "source": [
"## LangSmith integration (Optional)\n", "## LangSmith integration (Optional)\n",
"\n", "\n",
@ -78,18 +78,18 @@
] ]
}, },
{ {
"metadata": {},
"cell_type": "code", "cell_type": "code",
"outputs": [],
"execution_count": null, "execution_count": null,
"metadata": {},
"outputs": [],
"source": [ "source": [
"os.environ['LANGCHAIN_TRACING_V2'] = 'false'\n", "os.environ['LANGCHAIN_TRACING_V2'] = 'false'\n",
"os.environ['LANGCHAIN_PROJECT'] = 'Graphiti LangGraph Tutorial'" "os.environ['LANGCHAIN_PROJECT'] = 'Graphiti LangGraph Tutorial'"
] ]
}, },
{ {
"metadata": {},
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {},
"source": [ "source": [
"## Configure Graphiti\n", "## Configure Graphiti\n",
"\n", "\n",
@ -103,10 +103,10 @@
] ]
}, },
{ {
"metadata": {},
"cell_type": "code", "cell_type": "code",
"outputs": [],
"execution_count": null, "execution_count": null,
"metadata": {},
"outputs": [],
"source": [ "source": [
"# Configure Graphiti\n", "# Configure Graphiti\n",
"\n", "\n",
@ -127,8 +127,8 @@
] ]
}, },
{ {
"metadata": {},
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {},
"source": [ "source": [
"## Generating a database schema \n", "## Generating a database schema \n",
"\n", "\n",
@ -138,10 +138,10 @@
] ]
}, },
{ {
"metadata": {},
"cell_type": "code", "cell_type": "code",
"outputs": [],
"execution_count": null, "execution_count": null,
"metadata": {},
"outputs": [],
"source": [ "source": [
"# Note: This will clear the database\n", "# Note: This will clear the database\n",
"await clear_data(client.driver)\n", "await clear_data(client.driver)\n",
@ -149,8 +149,8 @@
] ]
}, },
{ {
"metadata": {},
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {},
"source": [ "source": [
"## Load Shoe Data into the Graph\n", "## Load Shoe Data into the Graph\n",
"\n", "\n",
@ -161,10 +161,10 @@
] ]
}, },
{ {
"metadata": {},
"cell_type": "code", "cell_type": "code",
"outputs": [],
"execution_count": null, "execution_count": null,
"metadata": {},
"outputs": [],
"source": [ "source": [
"async def ingest_products_data(client: Graphiti):\n", "async def ingest_products_data(client: Graphiti):\n",
" script_dir = Path.cwd().parent\n", " script_dir = Path.cwd().parent\n",
@ -187,8 +187,8 @@
] ]
}, },
{ {
"metadata": {},
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {},
"source": [ "source": [
"## Create a user node in the Graphiti graph\n", "## Create a user node in the Graphiti graph\n",
"\n", "\n",
@ -196,10 +196,10 @@
] ]
}, },
{ {
"metadata": {},
"cell_type": "code", "cell_type": "code",
"outputs": [],
"execution_count": null, "execution_count": null,
"metadata": {},
"outputs": [],
"source": [ "source": [
"from graphiti_core.search.search_config_recipes import NODE_HYBRID_SEARCH_EPISODE_MENTIONS\n", "from graphiti_core.search.search_config_recipes import NODE_HYBRID_SEARCH_EPISODE_MENTIONS\n",
"\n", "\n",
@ -224,20 +224,20 @@
] ]
}, },
{ {
"metadata": {},
"cell_type": "code", "cell_type": "code",
"outputs": [],
"execution_count": null, "execution_count": null,
"metadata": {},
"outputs": [],
"source": [ "source": [
"def edges_to_facts_string(entities: list[EntityEdge]):\n", "def edges_to_facts_string(entities: list[EntityEdge]):\n",
" return '-' + '\\n- '.join([edge.fact for edge in entities])" " return '-' + '\\n- '.join([edge.fact for edge in entities])"
] ]
}, },
{ {
"metadata": {},
"cell_type": "code", "cell_type": "code",
"outputs": [],
"execution_count": null, "execution_count": null,
"metadata": {},
"outputs": [],
"source": [ "source": [
"from langchain_core.messages import AIMessage, SystemMessage\n", "from langchain_core.messages import AIMessage, SystemMessage\n",
"from langchain_core.tools import tool\n", "from langchain_core.tools import tool\n",
@ -248,8 +248,8 @@
] ]
}, },
{ {
"metadata": {},
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {},
"source": [ "source": [
"## `get_shoe_data` Tool\n", "## `get_shoe_data` Tool\n",
"\n", "\n",
@ -257,10 +257,10 @@
] ]
}, },
{ {
"metadata": {},
"cell_type": "code", "cell_type": "code",
"outputs": [],
"execution_count": null, "execution_count": null,
"metadata": {},
"outputs": [],
"source": [ "source": [
"@tool\n", "@tool\n",
"async def get_shoe_data(query: str) -> str:\n", "async def get_shoe_data(query: str) -> str:\n",
@ -278,25 +278,27 @@
] ]
}, },
{ {
"metadata": {},
"cell_type": "code", "cell_type": "code",
"outputs": [],
"execution_count": null, "execution_count": null,
"source": "llm = ChatOpenAI(model='gpt-4o-mini', temperature=0).bind_tools(tools)" "metadata": {},
"outputs": [],
"source": [
"llm = ChatOpenAI(model='gpt-4o-mini', temperature=0).bind_tools(tools)"
]
}, },
{ {
"metadata": {},
"cell_type": "code", "cell_type": "code",
"outputs": [],
"execution_count": null, "execution_count": null,
"metadata": {},
"outputs": [],
"source": [ "source": [
"# Test the tool node\n", "# Test the tool node\n",
"await tool_node.ainvoke({'messages': [await llm.ainvoke('wool shoes')]})" "await tool_node.ainvoke({'messages': [await llm.ainvoke('wool shoes')]})"
] ]
}, },
{ {
"metadata": {},
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {},
"source": [ "source": [
"## Chatbot Function Explanation\n", "## Chatbot Function Explanation\n",
"\n", "\n",
@ -312,10 +314,10 @@
] ]
}, },
{ {
"metadata": {},
"cell_type": "code", "cell_type": "code",
"outputs": [],
"execution_count": null, "execution_count": null,
"metadata": {},
"outputs": [],
"source": [ "source": [
"class State(TypedDict):\n", "class State(TypedDict):\n",
" messages: Annotated[list, add_messages]\n", " messages: Annotated[list, add_messages]\n",
@ -372,8 +374,8 @@
] ]
}, },
{ {
"metadata": {},
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {},
"source": [ "source": [
"## Setting up the Agent\n", "## Setting up the Agent\n",
"\n", "\n",
@ -387,10 +389,10 @@
] ]
}, },
{ {
"metadata": {},
"cell_type": "code", "cell_type": "code",
"outputs": [],
"execution_count": null, "execution_count": null,
"metadata": {},
"outputs": [],
"source": [ "source": [
"graph_builder = StateGraph(State)\n", "graph_builder = StateGraph(State)\n",
"\n", "\n",
@ -420,23 +422,23 @@
] ]
}, },
{ {
"metadata": {},
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {},
"source": "Our LangGraph agent graph is illustrated below." "source": "Our LangGraph agent graph is illustrated below."
}, },
{ {
"metadata": {},
"cell_type": "code", "cell_type": "code",
"outputs": [],
"execution_count": null, "execution_count": null,
"metadata": {},
"outputs": [],
"source": [ "source": [
"with suppress(Exception):\n", "with suppress(Exception):\n",
" display(Image(graph.get_graph().draw_mermaid_png()))" " display(Image(graph.get_graph().draw_mermaid_png()))"
] ]
}, },
{ {
"metadata": {},
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {},
"source": [ "source": [
"## Running the Agent\n", "## Running the Agent\n",
"\n", "\n",
@ -444,10 +446,10 @@
] ]
}, },
{ {
"metadata": {},
"cell_type": "code", "cell_type": "code",
"outputs": [],
"execution_count": null, "execution_count": null,
"metadata": {},
"outputs": [],
"source": [ "source": [
"await graph.ainvoke(\n", "await graph.ainvoke(\n",
" {\n", " {\n",
@ -465,8 +467,8 @@
] ]
}, },
{ {
"metadata": {},
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {},
"source": [ "source": [
"## Viewing the Graph\n", "## Viewing the Graph\n",
"\n", "\n",
@ -474,15 +476,17 @@
] ]
}, },
{ {
"metadata": {},
"cell_type": "code", "cell_type": "code",
"outputs": [],
"execution_count": null, "execution_count": null,
"source": "display(Image(filename='tinybirds-jess.png', width=850))" "metadata": {},
"outputs": [],
"source": [
"display(Image(filename='tinybirds-jess.png', width=850))"
]
}, },
{ {
"metadata": {},
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {},
"source": [ "source": [
"## Running the Agent interactively\n", "## Running the Agent interactively\n",
"\n", "\n",
@ -490,10 +494,10 @@
] ]
}, },
{ {
"metadata": {},
"cell_type": "code", "cell_type": "code",
"outputs": [],
"execution_count": null, "execution_count": null,
"metadata": {},
"outputs": [],
"source": [ "source": [
"conversation_output = widgets.Output()\n", "conversation_output = widgets.Output()\n",
"config = {'configurable': {'thread_id': uuid.uuid4().hex}}\n", "config = {'configurable': {'thread_id': uuid.uuid4().hex}}\n",
@ -512,14 +516,14 @@
"\n", "\n",
" try:\n", " try:\n",
" async for event in graph.astream(\n", " async for event in graph.astream(\n",
" graph_state,\n", " graph_state,\n",
" config=config,\n", " config=config,\n",
" ):\n", " ):\n",
" for value in event.values():\n", " for value in event.values():\n",
" if 'messages' in value:\n", " if 'messages' in value:\n",
" last_message = value['messages'][-1]\n", " last_message = value['messages'][-1]\n",
" if isinstance(last_message, AIMessage) and isinstance(\n", " if isinstance(last_message, AIMessage) and isinstance(\n",
" last_message.content, str\n", " last_message.content, str\n",
" ):\n", " ):\n",
" conversation_output.append_stdout(last_message.content)\n", " conversation_output.append_stdout(last_message.content)\n",
" except Exception as e:\n", " except Exception as e:\n",

View file

@ -56,7 +56,6 @@ class LLMClient(ABC):
self.cache_enabled = cache self.cache_enabled = cache
self.cache_dir = Cache(DEFAULT_CACHE_DIR) # Create a cache directory self.cache_dir = Cache(DEFAULT_CACHE_DIR) # Create a cache directory
def _clean_input(self, input: str) -> str: def _clean_input(self, input: str) -> str:
"""Clean input string of invalid unicode and control characters. """Clean input string of invalid unicode and control characters.

View file

@ -18,6 +18,7 @@ import asyncio
import logging import logging
from collections import defaultdict from collections import defaultdict
from time import time from time import time
from typing import Any
import numpy as np import numpy as np
from neo4j import AsyncDriver, Query from neo4j import AsyncDriver, Query
@ -191,12 +192,27 @@ async def edge_similarity_search(
'CYPHER runtime = parallel parallelRuntimeSupport=all\n' if USE_PARALLEL_RUNTIME else '' 'CYPHER runtime = parallel parallelRuntimeSupport=all\n' if USE_PARALLEL_RUNTIME else ''
) )
query: LiteralString = """ query_params: dict[str, Any] = {}
MATCH (n:Entity)-[r:RELATES_TO]->(m:Entity)
WHERE ($group_ids IS NULL OR r.group_id IN $group_ids) group_filter_query: LiteralString = ''
AND ($source_uuid IS NULL OR n.uuid IN [$source_uuid, $target_uuid]) if group_ids is not None:
AND ($target_uuid IS NULL OR m.uuid IN [$source_uuid, $target_uuid]) group_filter_query += 'WHERE r.group_id IN $group_ids'
WITH DISTINCT r, vector.similarity.cosine(r.fact_embedding, $search_vector) AS score query_params['group_ids'] = group_ids
query_params['source_node_uuid'] = source_node_uuid
query_params['target_node_uuid'] = target_node_uuid
if source_node_uuid is not None:
group_filter_query += '\nAND (n.uuid IN [$source_uuid, $target_uuid])'
if target_node_uuid is not None:
group_filter_query += '\nAND (m.uuid IN [$source_uuid, $target_uuid])'
query: LiteralString = (
"""
MATCH (n:Entity)-[r:RELATES_TO]->(m:Entity)
"""
+ group_filter_query
+ """\nWITH DISTINCT r, vector.similarity.cosine(r.fact_embedding, $search_vector) AS score
WHERE score > $min_score WHERE score > $min_score
RETURN RETURN
r.uuid AS uuid, r.uuid AS uuid,
@ -214,9 +230,11 @@ async def edge_similarity_search(
ORDER BY score DESC ORDER BY score DESC
LIMIT $limit LIMIT $limit
""" """
)
records, _, _ = await driver.execute_query( records, _, _ = await driver.execute_query(
runtime_query + query, runtime_query + query,
query_params,
search_vector=search_vector, search_vector=search_vector,
source_uuid=source_node_uuid, source_uuid=source_node_uuid,
target_uuid=target_node_uuid, target_uuid=target_node_uuid,
@ -325,11 +343,20 @@ async def node_similarity_search(
'CYPHER runtime = parallel parallelRuntimeSupport=all\n' if USE_PARALLEL_RUNTIME else '' 'CYPHER runtime = parallel parallelRuntimeSupport=all\n' if USE_PARALLEL_RUNTIME else ''
) )
query_params: dict[str, Any] = {}
group_filter_query: LiteralString = ''
if group_ids is not None:
group_filter_query += 'WHERE n.group_id IN $group_ids'
query_params['group_ids'] = group_ids
records, _, _ = await driver.execute_query( records, _, _ = await driver.execute_query(
runtime_query runtime_query
+ """ + """
MATCH (n:Entity) MATCH (n:Entity)
WHERE $group_ids IS NULL OR n.group_id IN $group_ids """
+ group_filter_query
+ """
WITH n, vector.similarity.cosine(n.name_embedding, $search_vector) AS score WITH n, vector.similarity.cosine(n.name_embedding, $search_vector) AS score
WHERE score > $min_score WHERE score > $min_score
RETURN RETURN
@ -342,6 +369,7 @@ async def node_similarity_search(
ORDER BY score DESC ORDER BY score DESC
LIMIT $limit LIMIT $limit
""", """,
query_params,
search_vector=search_vector, search_vector=search_vector,
group_ids=group_ids, group_ids=group_ids,
limit=limit, limit=limit,
@ -436,11 +464,20 @@ async def community_similarity_search(
'CYPHER runtime = parallel parallelRuntimeSupport=all\n' if USE_PARALLEL_RUNTIME else '' 'CYPHER runtime = parallel parallelRuntimeSupport=all\n' if USE_PARALLEL_RUNTIME else ''
) )
query_params: dict[str, Any] = {}
group_filter_query: LiteralString = ''
if group_ids is not None:
group_filter_query += 'WHERE comm.group_id IN $group_ids'
query_params['group_ids'] = group_ids
records, _, _ = await driver.execute_query( records, _, _ = await driver.execute_query(
runtime_query runtime_query
+ """ + """
MATCH (comm:Community) MATCH (comm:Community)
WHERE ($group_ids IS NULL OR comm.group_id IN $group_ids) """
+ group_filter_query
+ """
WITH comm, vector.similarity.cosine(comm.name_embedding, $search_vector) AS score WITH comm, vector.similarity.cosine(comm.name_embedding, $search_vector) AS score
WHERE score > $min_score WHERE score > $min_score
RETURN RETURN