Abstract Neo4j filters in search queries (#243)
* move null check for search queries to python * update search filtering * update * update
This commit is contained in:
parent
425b35ba2d
commit
34496ffa6a
3 changed files with 102 additions and 62 deletions
|
|
@ -22,10 +22,10 @@
|
|||
]
|
||||
},
|
||||
{
|
||||
"metadata": {},
|
||||
"cell_type": "code",
|
||||
"outputs": [],
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import asyncio\n",
|
||||
"import json\n",
|
||||
|
|
@ -47,10 +47,10 @@
|
|||
]
|
||||
},
|
||||
{
|
||||
"metadata": {},
|
||||
"cell_type": "code",
|
||||
"outputs": [],
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def setup_logging():\n",
|
||||
" logger = logging.getLogger()\n",
|
||||
|
|
@ -67,8 +67,8 @@
|
|||
]
|
||||
},
|
||||
{
|
||||
"metadata": {},
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## LangSmith integration (Optional)\n",
|
||||
"\n",
|
||||
|
|
@ -78,18 +78,18 @@
|
|||
]
|
||||
},
|
||||
{
|
||||
"metadata": {},
|
||||
"cell_type": "code",
|
||||
"outputs": [],
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"os.environ['LANGCHAIN_TRACING_V2'] = 'false'\n",
|
||||
"os.environ['LANGCHAIN_PROJECT'] = 'Graphiti LangGraph Tutorial'"
|
||||
]
|
||||
},
|
||||
{
|
||||
"metadata": {},
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Configure Graphiti\n",
|
||||
"\n",
|
||||
|
|
@ -103,10 +103,10 @@
|
|||
]
|
||||
},
|
||||
{
|
||||
"metadata": {},
|
||||
"cell_type": "code",
|
||||
"outputs": [],
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Configure Graphiti\n",
|
||||
"\n",
|
||||
|
|
@ -127,8 +127,8 @@
|
|||
]
|
||||
},
|
||||
{
|
||||
"metadata": {},
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Generating a database schema \n",
|
||||
"\n",
|
||||
|
|
@ -138,10 +138,10 @@
|
|||
]
|
||||
},
|
||||
{
|
||||
"metadata": {},
|
||||
"cell_type": "code",
|
||||
"outputs": [],
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Note: This will clear the database\n",
|
||||
"await clear_data(client.driver)\n",
|
||||
|
|
@ -149,8 +149,8 @@
|
|||
]
|
||||
},
|
||||
{
|
||||
"metadata": {},
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Load Shoe Data into the Graph\n",
|
||||
"\n",
|
||||
|
|
@ -161,10 +161,10 @@
|
|||
]
|
||||
},
|
||||
{
|
||||
"metadata": {},
|
||||
"cell_type": "code",
|
||||
"outputs": [],
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"async def ingest_products_data(client: Graphiti):\n",
|
||||
" script_dir = Path.cwd().parent\n",
|
||||
|
|
@ -187,8 +187,8 @@
|
|||
]
|
||||
},
|
||||
{
|
||||
"metadata": {},
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Create a user node in the Graphiti graph\n",
|
||||
"\n",
|
||||
|
|
@ -196,10 +196,10 @@
|
|||
]
|
||||
},
|
||||
{
|
||||
"metadata": {},
|
||||
"cell_type": "code",
|
||||
"outputs": [],
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from graphiti_core.search.search_config_recipes import NODE_HYBRID_SEARCH_EPISODE_MENTIONS\n",
|
||||
"\n",
|
||||
|
|
@ -224,20 +224,20 @@
|
|||
]
|
||||
},
|
||||
{
|
||||
"metadata": {},
|
||||
"cell_type": "code",
|
||||
"outputs": [],
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def edges_to_facts_string(entities: list[EntityEdge]):\n",
|
||||
" return '-' + '\\n- '.join([edge.fact for edge in entities])"
|
||||
]
|
||||
},
|
||||
{
|
||||
"metadata": {},
|
||||
"cell_type": "code",
|
||||
"outputs": [],
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langchain_core.messages import AIMessage, SystemMessage\n",
|
||||
"from langchain_core.tools import tool\n",
|
||||
|
|
@ -248,8 +248,8 @@
|
|||
]
|
||||
},
|
||||
{
|
||||
"metadata": {},
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## `get_shoe_data` Tool\n",
|
||||
"\n",
|
||||
|
|
@ -257,10 +257,10 @@
|
|||
]
|
||||
},
|
||||
{
|
||||
"metadata": {},
|
||||
"cell_type": "code",
|
||||
"outputs": [],
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"@tool\n",
|
||||
"async def get_shoe_data(query: str) -> str:\n",
|
||||
|
|
@ -278,25 +278,27 @@
|
|||
]
|
||||
},
|
||||
{
|
||||
"metadata": {},
|
||||
"cell_type": "code",
|
||||
"outputs": [],
|
||||
"execution_count": null,
|
||||
"source": "llm = ChatOpenAI(model='gpt-4o-mini', temperature=0).bind_tools(tools)"
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"llm = ChatOpenAI(model='gpt-4o-mini', temperature=0).bind_tools(tools)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"metadata": {},
|
||||
"cell_type": "code",
|
||||
"outputs": [],
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Test the tool node\n",
|
||||
"await tool_node.ainvoke({'messages': [await llm.ainvoke('wool shoes')]})"
|
||||
]
|
||||
},
|
||||
{
|
||||
"metadata": {},
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Chatbot Function Explanation\n",
|
||||
"\n",
|
||||
|
|
@ -312,10 +314,10 @@
|
|||
]
|
||||
},
|
||||
{
|
||||
"metadata": {},
|
||||
"cell_type": "code",
|
||||
"outputs": [],
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"class State(TypedDict):\n",
|
||||
" messages: Annotated[list, add_messages]\n",
|
||||
|
|
@ -372,8 +374,8 @@
|
|||
]
|
||||
},
|
||||
{
|
||||
"metadata": {},
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Setting up the Agent\n",
|
||||
"\n",
|
||||
|
|
@ -387,10 +389,10 @@
|
|||
]
|
||||
},
|
||||
{
|
||||
"metadata": {},
|
||||
"cell_type": "code",
|
||||
"outputs": [],
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"graph_builder = StateGraph(State)\n",
|
||||
"\n",
|
||||
|
|
@ -420,23 +422,23 @@
|
|||
]
|
||||
},
|
||||
{
|
||||
"metadata": {},
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": "Our LangGraph agent graph is illustrated below."
|
||||
},
|
||||
{
|
||||
"metadata": {},
|
||||
"cell_type": "code",
|
||||
"outputs": [],
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"with suppress(Exception):\n",
|
||||
" display(Image(graph.get_graph().draw_mermaid_png()))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"metadata": {},
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Running the Agent\n",
|
||||
"\n",
|
||||
|
|
@ -444,10 +446,10 @@
|
|||
]
|
||||
},
|
||||
{
|
||||
"metadata": {},
|
||||
"cell_type": "code",
|
||||
"outputs": [],
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"await graph.ainvoke(\n",
|
||||
" {\n",
|
||||
|
|
@ -465,8 +467,8 @@
|
|||
]
|
||||
},
|
||||
{
|
||||
"metadata": {},
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Viewing the Graph\n",
|
||||
"\n",
|
||||
|
|
@ -474,15 +476,17 @@
|
|||
]
|
||||
},
|
||||
{
|
||||
"metadata": {},
|
||||
"cell_type": "code",
|
||||
"outputs": [],
|
||||
"execution_count": null,
|
||||
"source": "display(Image(filename='tinybirds-jess.png', width=850))"
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"display(Image(filename='tinybirds-jess.png', width=850))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"metadata": {},
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Running the Agent interactively\n",
|
||||
"\n",
|
||||
|
|
@ -490,10 +494,10 @@
|
|||
]
|
||||
},
|
||||
{
|
||||
"metadata": {},
|
||||
"cell_type": "code",
|
||||
"outputs": [],
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"conversation_output = widgets.Output()\n",
|
||||
"config = {'configurable': {'thread_id': uuid.uuid4().hex}}\n",
|
||||
|
|
@ -512,14 +516,14 @@
|
|||
"\n",
|
||||
" try:\n",
|
||||
" async for event in graph.astream(\n",
|
||||
" graph_state,\n",
|
||||
" config=config,\n",
|
||||
" graph_state,\n",
|
||||
" config=config,\n",
|
||||
" ):\n",
|
||||
" for value in event.values():\n",
|
||||
" if 'messages' in value:\n",
|
||||
" last_message = value['messages'][-1]\n",
|
||||
" if isinstance(last_message, AIMessage) and isinstance(\n",
|
||||
" last_message.content, str\n",
|
||||
" last_message.content, str\n",
|
||||
" ):\n",
|
||||
" conversation_output.append_stdout(last_message.content)\n",
|
||||
" except Exception as e:\n",
|
||||
|
|
|
|||
|
|
@ -56,7 +56,6 @@ class LLMClient(ABC):
|
|||
self.cache_enabled = cache
|
||||
self.cache_dir = Cache(DEFAULT_CACHE_DIR) # Create a cache directory
|
||||
|
||||
|
||||
def _clean_input(self, input: str) -> str:
|
||||
"""Clean input string of invalid unicode and control characters.
|
||||
|
||||
|
|
|
|||
|
|
@ -18,6 +18,7 @@ import asyncio
|
|||
import logging
|
||||
from collections import defaultdict
|
||||
from time import time
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
from neo4j import AsyncDriver, Query
|
||||
|
|
@ -191,12 +192,27 @@ async def edge_similarity_search(
|
|||
'CYPHER runtime = parallel parallelRuntimeSupport=all\n' if USE_PARALLEL_RUNTIME else ''
|
||||
)
|
||||
|
||||
query: LiteralString = """
|
||||
MATCH (n:Entity)-[r:RELATES_TO]->(m:Entity)
|
||||
WHERE ($group_ids IS NULL OR r.group_id IN $group_ids)
|
||||
AND ($source_uuid IS NULL OR n.uuid IN [$source_uuid, $target_uuid])
|
||||
AND ($target_uuid IS NULL OR m.uuid IN [$source_uuid, $target_uuid])
|
||||
WITH DISTINCT r, vector.similarity.cosine(r.fact_embedding, $search_vector) AS score
|
||||
query_params: dict[str, Any] = {}
|
||||
|
||||
group_filter_query: LiteralString = ''
|
||||
if group_ids is not None:
|
||||
group_filter_query += 'WHERE r.group_id IN $group_ids'
|
||||
query_params['group_ids'] = group_ids
|
||||
query_params['source_node_uuid'] = source_node_uuid
|
||||
query_params['target_node_uuid'] = target_node_uuid
|
||||
|
||||
if source_node_uuid is not None:
|
||||
group_filter_query += '\nAND (n.uuid IN [$source_uuid, $target_uuid])'
|
||||
|
||||
if target_node_uuid is not None:
|
||||
group_filter_query += '\nAND (m.uuid IN [$source_uuid, $target_uuid])'
|
||||
|
||||
query: LiteralString = (
|
||||
"""
|
||||
MATCH (n:Entity)-[r:RELATES_TO]->(m:Entity)
|
||||
"""
|
||||
+ group_filter_query
|
||||
+ """\nWITH DISTINCT r, vector.similarity.cosine(r.fact_embedding, $search_vector) AS score
|
||||
WHERE score > $min_score
|
||||
RETURN
|
||||
r.uuid AS uuid,
|
||||
|
|
@ -214,9 +230,11 @@ async def edge_similarity_search(
|
|||
ORDER BY score DESC
|
||||
LIMIT $limit
|
||||
"""
|
||||
)
|
||||
|
||||
records, _, _ = await driver.execute_query(
|
||||
runtime_query + query,
|
||||
query_params,
|
||||
search_vector=search_vector,
|
||||
source_uuid=source_node_uuid,
|
||||
target_uuid=target_node_uuid,
|
||||
|
|
@ -325,11 +343,20 @@ async def node_similarity_search(
|
|||
'CYPHER runtime = parallel parallelRuntimeSupport=all\n' if USE_PARALLEL_RUNTIME else ''
|
||||
)
|
||||
|
||||
query_params: dict[str, Any] = {}
|
||||
|
||||
group_filter_query: LiteralString = ''
|
||||
if group_ids is not None:
|
||||
group_filter_query += 'WHERE n.group_id IN $group_ids'
|
||||
query_params['group_ids'] = group_ids
|
||||
|
||||
records, _, _ = await driver.execute_query(
|
||||
runtime_query
|
||||
+ """
|
||||
MATCH (n:Entity)
|
||||
WHERE $group_ids IS NULL OR n.group_id IN $group_ids
|
||||
"""
|
||||
+ group_filter_query
|
||||
+ """
|
||||
WITH n, vector.similarity.cosine(n.name_embedding, $search_vector) AS score
|
||||
WHERE score > $min_score
|
||||
RETURN
|
||||
|
|
@ -342,6 +369,7 @@ async def node_similarity_search(
|
|||
ORDER BY score DESC
|
||||
LIMIT $limit
|
||||
""",
|
||||
query_params,
|
||||
search_vector=search_vector,
|
||||
group_ids=group_ids,
|
||||
limit=limit,
|
||||
|
|
@ -436,11 +464,20 @@ async def community_similarity_search(
|
|||
'CYPHER runtime = parallel parallelRuntimeSupport=all\n' if USE_PARALLEL_RUNTIME else ''
|
||||
)
|
||||
|
||||
query_params: dict[str, Any] = {}
|
||||
|
||||
group_filter_query: LiteralString = ''
|
||||
if group_ids is not None:
|
||||
group_filter_query += 'WHERE comm.group_id IN $group_ids'
|
||||
query_params['group_ids'] = group_ids
|
||||
|
||||
records, _, _ = await driver.execute_query(
|
||||
runtime_query
|
||||
+ """
|
||||
MATCH (comm:Community)
|
||||
WHERE ($group_ids IS NULL OR comm.group_id IN $group_ids)
|
||||
"""
|
||||
+ group_filter_query
|
||||
+ """
|
||||
WITH comm, vector.similarity.cosine(comm.name_embedding, $search_vector) AS score
|
||||
WHERE score > $min_score
|
||||
RETURN
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue