Abstract Neo4j filters in search queries (#243)

* move null check for search queries to Python

* update search filtering

* update

* update
This commit is contained in:
Preston Rasmussen 2024-12-16 21:45:45 -05:00 committed by GitHub
parent 425b35ba2d
commit 34496ffa6a
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 102 additions and 62 deletions

View file

@ -22,10 +22,10 @@
]
},
{
"metadata": {},
"cell_type": "code",
"outputs": [],
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import asyncio\n",
"import json\n",
@ -47,10 +47,10 @@
]
},
{
"metadata": {},
"cell_type": "code",
"outputs": [],
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def setup_logging():\n",
" logger = logging.getLogger()\n",
@ -67,8 +67,8 @@
]
},
{
"metadata": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## LangSmith integration (Optional)\n",
"\n",
@ -78,18 +78,18 @@
]
},
{
"metadata": {},
"cell_type": "code",
"outputs": [],
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"os.environ['LANGCHAIN_TRACING_V2'] = 'false'\n",
"os.environ['LANGCHAIN_PROJECT'] = 'Graphiti LangGraph Tutorial'"
]
},
{
"metadata": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## Configure Graphiti\n",
"\n",
@ -103,10 +103,10 @@
]
},
{
"metadata": {},
"cell_type": "code",
"outputs": [],
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Configure Graphiti\n",
"\n",
@ -127,8 +127,8 @@
]
},
{
"metadata": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## Generating a database schema \n",
"\n",
@ -138,10 +138,10 @@
]
},
{
"metadata": {},
"cell_type": "code",
"outputs": [],
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Note: This will clear the database\n",
"await clear_data(client.driver)\n",
@ -149,8 +149,8 @@
]
},
{
"metadata": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## Load Shoe Data into the Graph\n",
"\n",
@ -161,10 +161,10 @@
]
},
{
"metadata": {},
"cell_type": "code",
"outputs": [],
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"async def ingest_products_data(client: Graphiti):\n",
" script_dir = Path.cwd().parent\n",
@ -187,8 +187,8 @@
]
},
{
"metadata": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## Create a user node in the Graphiti graph\n",
"\n",
@ -196,10 +196,10 @@
]
},
{
"metadata": {},
"cell_type": "code",
"outputs": [],
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from graphiti_core.search.search_config_recipes import NODE_HYBRID_SEARCH_EPISODE_MENTIONS\n",
"\n",
@ -224,20 +224,20 @@
]
},
{
"metadata": {},
"cell_type": "code",
"outputs": [],
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def edges_to_facts_string(entities: list[EntityEdge]):\n",
" return '-' + '\\n- '.join([edge.fact for edge in entities])"
]
},
{
"metadata": {},
"cell_type": "code",
"outputs": [],
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from langchain_core.messages import AIMessage, SystemMessage\n",
"from langchain_core.tools import tool\n",
@ -248,8 +248,8 @@
]
},
{
"metadata": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## `get_shoe_data` Tool\n",
"\n",
@ -257,10 +257,10 @@
]
},
{
"metadata": {},
"cell_type": "code",
"outputs": [],
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"@tool\n",
"async def get_shoe_data(query: str) -> str:\n",
@ -278,25 +278,27 @@
]
},
{
"metadata": {},
"cell_type": "code",
"outputs": [],
"execution_count": null,
"source": "llm = ChatOpenAI(model='gpt-4o-mini', temperature=0).bind_tools(tools)"
"metadata": {},
"outputs": [],
"source": [
"llm = ChatOpenAI(model='gpt-4o-mini', temperature=0).bind_tools(tools)"
]
},
{
"metadata": {},
"cell_type": "code",
"outputs": [],
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Test the tool node\n",
"await tool_node.ainvoke({'messages': [await llm.ainvoke('wool shoes')]})"
]
},
{
"metadata": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## Chatbot Function Explanation\n",
"\n",
@ -312,10 +314,10 @@
]
},
{
"metadata": {},
"cell_type": "code",
"outputs": [],
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class State(TypedDict):\n",
" messages: Annotated[list, add_messages]\n",
@ -372,8 +374,8 @@
]
},
{
"metadata": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## Setting up the Agent\n",
"\n",
@ -387,10 +389,10 @@
]
},
{
"metadata": {},
"cell_type": "code",
"outputs": [],
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"graph_builder = StateGraph(State)\n",
"\n",
@ -420,23 +422,23 @@
]
},
{
"metadata": {},
"cell_type": "markdown",
"metadata": {},
"source": "Our LangGraph agent graph is illustrated below."
},
{
"metadata": {},
"cell_type": "code",
"outputs": [],
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"with suppress(Exception):\n",
" display(Image(graph.get_graph().draw_mermaid_png()))"
]
},
{
"metadata": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## Running the Agent\n",
"\n",
@ -444,10 +446,10 @@
]
},
{
"metadata": {},
"cell_type": "code",
"outputs": [],
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"await graph.ainvoke(\n",
" {\n",
@ -465,8 +467,8 @@
]
},
{
"metadata": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## Viewing the Graph\n",
"\n",
@ -474,15 +476,17 @@
]
},
{
"metadata": {},
"cell_type": "code",
"outputs": [],
"execution_count": null,
"source": "display(Image(filename='tinybirds-jess.png', width=850))"
"metadata": {},
"outputs": [],
"source": [
"display(Image(filename='tinybirds-jess.png', width=850))"
]
},
{
"metadata": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## Running the Agent interactively\n",
"\n",
@ -490,10 +494,10 @@
]
},
{
"metadata": {},
"cell_type": "code",
"outputs": [],
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"conversation_output = widgets.Output()\n",
"config = {'configurable': {'thread_id': uuid.uuid4().hex}}\n",
@ -512,14 +516,14 @@
"\n",
" try:\n",
" async for event in graph.astream(\n",
" graph_state,\n",
" config=config,\n",
" graph_state,\n",
" config=config,\n",
" ):\n",
" for value in event.values():\n",
" if 'messages' in value:\n",
" last_message = value['messages'][-1]\n",
" if isinstance(last_message, AIMessage) and isinstance(\n",
" last_message.content, str\n",
" last_message.content, str\n",
" ):\n",
" conversation_output.append_stdout(last_message.content)\n",
" except Exception as e:\n",

View file

@ -56,7 +56,6 @@ class LLMClient(ABC):
self.cache_enabled = cache
self.cache_dir = Cache(DEFAULT_CACHE_DIR) # Create a cache directory
def _clean_input(self, input: str) -> str:
"""Clean input string of invalid unicode and control characters.

View file

@ -18,6 +18,7 @@ import asyncio
import logging
from collections import defaultdict
from time import time
from typing import Any
import numpy as np
from neo4j import AsyncDriver, Query
@ -191,12 +192,27 @@ async def edge_similarity_search(
'CYPHER runtime = parallel parallelRuntimeSupport=all\n' if USE_PARALLEL_RUNTIME else ''
)
query: LiteralString = """
MATCH (n:Entity)-[r:RELATES_TO]->(m:Entity)
WHERE ($group_ids IS NULL OR r.group_id IN $group_ids)
AND ($source_uuid IS NULL OR n.uuid IN [$source_uuid, $target_uuid])
AND ($target_uuid IS NULL OR m.uuid IN [$source_uuid, $target_uuid])
WITH DISTINCT r, vector.similarity.cosine(r.fact_embedding, $search_vector) AS score
query_params: dict[str, Any] = {}
group_filter_query: LiteralString = ''
if group_ids is not None:
group_filter_query += 'WHERE r.group_id IN $group_ids'
query_params['group_ids'] = group_ids
query_params['source_node_uuid'] = source_node_uuid
query_params['target_node_uuid'] = target_node_uuid
if source_node_uuid is not None:
group_filter_query += '\nAND (n.uuid IN [$source_uuid, $target_uuid])'
if target_node_uuid is not None:
group_filter_query += '\nAND (m.uuid IN [$source_uuid, $target_uuid])'
query: LiteralString = (
"""
MATCH (n:Entity)-[r:RELATES_TO]->(m:Entity)
"""
+ group_filter_query
+ """\nWITH DISTINCT r, vector.similarity.cosine(r.fact_embedding, $search_vector) AS score
WHERE score > $min_score
RETURN
r.uuid AS uuid,
@ -214,9 +230,11 @@ async def edge_similarity_search(
ORDER BY score DESC
LIMIT $limit
"""
)
records, _, _ = await driver.execute_query(
runtime_query + query,
query_params,
search_vector=search_vector,
source_uuid=source_node_uuid,
target_uuid=target_node_uuid,
@ -325,11 +343,20 @@ async def node_similarity_search(
'CYPHER runtime = parallel parallelRuntimeSupport=all\n' if USE_PARALLEL_RUNTIME else ''
)
query_params: dict[str, Any] = {}
group_filter_query: LiteralString = ''
if group_ids is not None:
group_filter_query += 'WHERE n.group_id IN $group_ids'
query_params['group_ids'] = group_ids
records, _, _ = await driver.execute_query(
runtime_query
+ """
MATCH (n:Entity)
WHERE $group_ids IS NULL OR n.group_id IN $group_ids
"""
+ group_filter_query
+ """
WITH n, vector.similarity.cosine(n.name_embedding, $search_vector) AS score
WHERE score > $min_score
RETURN
@ -342,6 +369,7 @@ async def node_similarity_search(
ORDER BY score DESC
LIMIT $limit
""",
query_params,
search_vector=search_vector,
group_ids=group_ids,
limit=limit,
@ -436,11 +464,20 @@ async def community_similarity_search(
'CYPHER runtime = parallel parallelRuntimeSupport=all\n' if USE_PARALLEL_RUNTIME else ''
)
query_params: dict[str, Any] = {}
group_filter_query: LiteralString = ''
if group_ids is not None:
group_filter_query += 'WHERE comm.group_id IN $group_ids'
query_params['group_ids'] = group_ids
records, _, _ = await driver.execute_query(
runtime_query
+ """
MATCH (comm:Community)
WHERE ($group_ids IS NULL OR comm.group_id IN $group_ids)
"""
+ group_filter_query
+ """
WITH comm, vector.similarity.cosine(comm.name_embedding, $search_vector) AS score
WHERE score > $min_score
RETURN