Abstract Neo4j filters in search queries (#243)
* move null check for search queries to python * update search filtering * update * update
This commit is contained in:
parent
425b35ba2d
commit
34496ffa6a
3 changed files with 102 additions and 62 deletions
|
|
@ -22,10 +22,10 @@
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"metadata": {},
|
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"outputs": [],
|
|
||||||
"execution_count": null,
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"import asyncio\n",
|
"import asyncio\n",
|
||||||
"import json\n",
|
"import json\n",
|
||||||
|
|
@ -47,10 +47,10 @@
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"metadata": {},
|
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"outputs": [],
|
|
||||||
"execution_count": null,
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"def setup_logging():\n",
|
"def setup_logging():\n",
|
||||||
" logger = logging.getLogger()\n",
|
" logger = logging.getLogger()\n",
|
||||||
|
|
@ -67,8 +67,8 @@
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"metadata": {},
|
|
||||||
"cell_type": "markdown",
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
"source": [
|
"source": [
|
||||||
"## LangSmith integration (Optional)\n",
|
"## LangSmith integration (Optional)\n",
|
||||||
"\n",
|
"\n",
|
||||||
|
|
@ -78,18 +78,18 @@
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"metadata": {},
|
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"outputs": [],
|
|
||||||
"execution_count": null,
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"os.environ['LANGCHAIN_TRACING_V2'] = 'false'\n",
|
"os.environ['LANGCHAIN_TRACING_V2'] = 'false'\n",
|
||||||
"os.environ['LANGCHAIN_PROJECT'] = 'Graphiti LangGraph Tutorial'"
|
"os.environ['LANGCHAIN_PROJECT'] = 'Graphiti LangGraph Tutorial'"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"metadata": {},
|
|
||||||
"cell_type": "markdown",
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
"source": [
|
"source": [
|
||||||
"## Configure Graphiti\n",
|
"## Configure Graphiti\n",
|
||||||
"\n",
|
"\n",
|
||||||
|
|
@ -103,10 +103,10 @@
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"metadata": {},
|
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"outputs": [],
|
|
||||||
"execution_count": null,
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"# Configure Graphiti\n",
|
"# Configure Graphiti\n",
|
||||||
"\n",
|
"\n",
|
||||||
|
|
@ -127,8 +127,8 @@
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"metadata": {},
|
|
||||||
"cell_type": "markdown",
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
"source": [
|
"source": [
|
||||||
"## Generating a database schema \n",
|
"## Generating a database schema \n",
|
||||||
"\n",
|
"\n",
|
||||||
|
|
@ -138,10 +138,10 @@
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"metadata": {},
|
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"outputs": [],
|
|
||||||
"execution_count": null,
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"# Note: This will clear the database\n",
|
"# Note: This will clear the database\n",
|
||||||
"await clear_data(client.driver)\n",
|
"await clear_data(client.driver)\n",
|
||||||
|
|
@ -149,8 +149,8 @@
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"metadata": {},
|
|
||||||
"cell_type": "markdown",
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
"source": [
|
"source": [
|
||||||
"## Load Shoe Data into the Graph\n",
|
"## Load Shoe Data into the Graph\n",
|
||||||
"\n",
|
"\n",
|
||||||
|
|
@ -161,10 +161,10 @@
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"metadata": {},
|
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"outputs": [],
|
|
||||||
"execution_count": null,
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"async def ingest_products_data(client: Graphiti):\n",
|
"async def ingest_products_data(client: Graphiti):\n",
|
||||||
" script_dir = Path.cwd().parent\n",
|
" script_dir = Path.cwd().parent\n",
|
||||||
|
|
@ -187,8 +187,8 @@
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"metadata": {},
|
|
||||||
"cell_type": "markdown",
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
"source": [
|
"source": [
|
||||||
"## Create a user node in the Graphiti graph\n",
|
"## Create a user node in the Graphiti graph\n",
|
||||||
"\n",
|
"\n",
|
||||||
|
|
@ -196,10 +196,10 @@
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"metadata": {},
|
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"outputs": [],
|
|
||||||
"execution_count": null,
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"from graphiti_core.search.search_config_recipes import NODE_HYBRID_SEARCH_EPISODE_MENTIONS\n",
|
"from graphiti_core.search.search_config_recipes import NODE_HYBRID_SEARCH_EPISODE_MENTIONS\n",
|
||||||
"\n",
|
"\n",
|
||||||
|
|
@ -224,20 +224,20 @@
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"metadata": {},
|
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"outputs": [],
|
|
||||||
"execution_count": null,
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"def edges_to_facts_string(entities: list[EntityEdge]):\n",
|
"def edges_to_facts_string(entities: list[EntityEdge]):\n",
|
||||||
" return '-' + '\\n- '.join([edge.fact for edge in entities])"
|
" return '-' + '\\n- '.join([edge.fact for edge in entities])"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"metadata": {},
|
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"outputs": [],
|
|
||||||
"execution_count": null,
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"from langchain_core.messages import AIMessage, SystemMessage\n",
|
"from langchain_core.messages import AIMessage, SystemMessage\n",
|
||||||
"from langchain_core.tools import tool\n",
|
"from langchain_core.tools import tool\n",
|
||||||
|
|
@ -248,8 +248,8 @@
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"metadata": {},
|
|
||||||
"cell_type": "markdown",
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
"source": [
|
"source": [
|
||||||
"## `get_shoe_data` Tool\n",
|
"## `get_shoe_data` Tool\n",
|
||||||
"\n",
|
"\n",
|
||||||
|
|
@ -257,10 +257,10 @@
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"metadata": {},
|
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"outputs": [],
|
|
||||||
"execution_count": null,
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"@tool\n",
|
"@tool\n",
|
||||||
"async def get_shoe_data(query: str) -> str:\n",
|
"async def get_shoe_data(query: str) -> str:\n",
|
||||||
|
|
@ -278,25 +278,27 @@
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"metadata": {},
|
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"outputs": [],
|
|
||||||
"execution_count": null,
|
"execution_count": null,
|
||||||
"source": "llm = ChatOpenAI(model='gpt-4o-mini', temperature=0).bind_tools(tools)"
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"llm = ChatOpenAI(model='gpt-4o-mini', temperature=0).bind_tools(tools)"
|
||||||
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"metadata": {},
|
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"outputs": [],
|
|
||||||
"execution_count": null,
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"# Test the tool node\n",
|
"# Test the tool node\n",
|
||||||
"await tool_node.ainvoke({'messages': [await llm.ainvoke('wool shoes')]})"
|
"await tool_node.ainvoke({'messages': [await llm.ainvoke('wool shoes')]})"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"metadata": {},
|
|
||||||
"cell_type": "markdown",
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
"source": [
|
"source": [
|
||||||
"## Chatbot Function Explanation\n",
|
"## Chatbot Function Explanation\n",
|
||||||
"\n",
|
"\n",
|
||||||
|
|
@ -312,10 +314,10 @@
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"metadata": {},
|
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"outputs": [],
|
|
||||||
"execution_count": null,
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"class State(TypedDict):\n",
|
"class State(TypedDict):\n",
|
||||||
" messages: Annotated[list, add_messages]\n",
|
" messages: Annotated[list, add_messages]\n",
|
||||||
|
|
@ -372,8 +374,8 @@
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"metadata": {},
|
|
||||||
"cell_type": "markdown",
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
"source": [
|
"source": [
|
||||||
"## Setting up the Agent\n",
|
"## Setting up the Agent\n",
|
||||||
"\n",
|
"\n",
|
||||||
|
|
@ -387,10 +389,10 @@
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"metadata": {},
|
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"outputs": [],
|
|
||||||
"execution_count": null,
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"graph_builder = StateGraph(State)\n",
|
"graph_builder = StateGraph(State)\n",
|
||||||
"\n",
|
"\n",
|
||||||
|
|
@ -420,23 +422,23 @@
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"metadata": {},
|
|
||||||
"cell_type": "markdown",
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
"source": "Our LangGraph agent graph is illustrated below."
|
"source": "Our LangGraph agent graph is illustrated below."
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"metadata": {},
|
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"outputs": [],
|
|
||||||
"execution_count": null,
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"with suppress(Exception):\n",
|
"with suppress(Exception):\n",
|
||||||
" display(Image(graph.get_graph().draw_mermaid_png()))"
|
" display(Image(graph.get_graph().draw_mermaid_png()))"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"metadata": {},
|
|
||||||
"cell_type": "markdown",
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
"source": [
|
"source": [
|
||||||
"## Running the Agent\n",
|
"## Running the Agent\n",
|
||||||
"\n",
|
"\n",
|
||||||
|
|
@ -444,10 +446,10 @@
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"metadata": {},
|
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"outputs": [],
|
|
||||||
"execution_count": null,
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"await graph.ainvoke(\n",
|
"await graph.ainvoke(\n",
|
||||||
" {\n",
|
" {\n",
|
||||||
|
|
@ -465,8 +467,8 @@
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"metadata": {},
|
|
||||||
"cell_type": "markdown",
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
"source": [
|
"source": [
|
||||||
"## Viewing the Graph\n",
|
"## Viewing the Graph\n",
|
||||||
"\n",
|
"\n",
|
||||||
|
|
@ -474,15 +476,17 @@
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"metadata": {},
|
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"outputs": [],
|
|
||||||
"execution_count": null,
|
"execution_count": null,
|
||||||
"source": "display(Image(filename='tinybirds-jess.png', width=850))"
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"display(Image(filename='tinybirds-jess.png', width=850))"
|
||||||
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"metadata": {},
|
|
||||||
"cell_type": "markdown",
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
"source": [
|
"source": [
|
||||||
"## Running the Agent interactively\n",
|
"## Running the Agent interactively\n",
|
||||||
"\n",
|
"\n",
|
||||||
|
|
@ -490,10 +494,10 @@
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"metadata": {},
|
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"outputs": [],
|
|
||||||
"execution_count": null,
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"conversation_output = widgets.Output()\n",
|
"conversation_output = widgets.Output()\n",
|
||||||
"config = {'configurable': {'thread_id': uuid.uuid4().hex}}\n",
|
"config = {'configurable': {'thread_id': uuid.uuid4().hex}}\n",
|
||||||
|
|
@ -512,14 +516,14 @@
|
||||||
"\n",
|
"\n",
|
||||||
" try:\n",
|
" try:\n",
|
||||||
" async for event in graph.astream(\n",
|
" async for event in graph.astream(\n",
|
||||||
" graph_state,\n",
|
" graph_state,\n",
|
||||||
" config=config,\n",
|
" config=config,\n",
|
||||||
" ):\n",
|
" ):\n",
|
||||||
" for value in event.values():\n",
|
" for value in event.values():\n",
|
||||||
" if 'messages' in value:\n",
|
" if 'messages' in value:\n",
|
||||||
" last_message = value['messages'][-1]\n",
|
" last_message = value['messages'][-1]\n",
|
||||||
" if isinstance(last_message, AIMessage) and isinstance(\n",
|
" if isinstance(last_message, AIMessage) and isinstance(\n",
|
||||||
" last_message.content, str\n",
|
" last_message.content, str\n",
|
||||||
" ):\n",
|
" ):\n",
|
||||||
" conversation_output.append_stdout(last_message.content)\n",
|
" conversation_output.append_stdout(last_message.content)\n",
|
||||||
" except Exception as e:\n",
|
" except Exception as e:\n",
|
||||||
|
|
|
||||||
|
|
@ -56,7 +56,6 @@ class LLMClient(ABC):
|
||||||
self.cache_enabled = cache
|
self.cache_enabled = cache
|
||||||
self.cache_dir = Cache(DEFAULT_CACHE_DIR) # Create a cache directory
|
self.cache_dir = Cache(DEFAULT_CACHE_DIR) # Create a cache directory
|
||||||
|
|
||||||
|
|
||||||
def _clean_input(self, input: str) -> str:
|
def _clean_input(self, input: str) -> str:
|
||||||
"""Clean input string of invalid unicode and control characters.
|
"""Clean input string of invalid unicode and control characters.
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -18,6 +18,7 @@ import asyncio
|
||||||
import logging
|
import logging
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from time import time
|
from time import time
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from neo4j import AsyncDriver, Query
|
from neo4j import AsyncDriver, Query
|
||||||
|
|
@ -191,12 +192,27 @@ async def edge_similarity_search(
|
||||||
'CYPHER runtime = parallel parallelRuntimeSupport=all\n' if USE_PARALLEL_RUNTIME else ''
|
'CYPHER runtime = parallel parallelRuntimeSupport=all\n' if USE_PARALLEL_RUNTIME else ''
|
||||||
)
|
)
|
||||||
|
|
||||||
query: LiteralString = """
|
query_params: dict[str, Any] = {}
|
||||||
MATCH (n:Entity)-[r:RELATES_TO]->(m:Entity)
|
|
||||||
WHERE ($group_ids IS NULL OR r.group_id IN $group_ids)
|
group_filter_query: LiteralString = ''
|
||||||
AND ($source_uuid IS NULL OR n.uuid IN [$source_uuid, $target_uuid])
|
if group_ids is not None:
|
||||||
AND ($target_uuid IS NULL OR m.uuid IN [$source_uuid, $target_uuid])
|
group_filter_query += 'WHERE r.group_id IN $group_ids'
|
||||||
WITH DISTINCT r, vector.similarity.cosine(r.fact_embedding, $search_vector) AS score
|
query_params['group_ids'] = group_ids
|
||||||
|
query_params['source_node_uuid'] = source_node_uuid
|
||||||
|
query_params['target_node_uuid'] = target_node_uuid
|
||||||
|
|
||||||
|
if source_node_uuid is not None:
|
||||||
|
group_filter_query += '\nAND (n.uuid IN [$source_uuid, $target_uuid])'
|
||||||
|
|
||||||
|
if target_node_uuid is not None:
|
||||||
|
group_filter_query += '\nAND (m.uuid IN [$source_uuid, $target_uuid])'
|
||||||
|
|
||||||
|
query: LiteralString = (
|
||||||
|
"""
|
||||||
|
MATCH (n:Entity)-[r:RELATES_TO]->(m:Entity)
|
||||||
|
"""
|
||||||
|
+ group_filter_query
|
||||||
|
+ """\nWITH DISTINCT r, vector.similarity.cosine(r.fact_embedding, $search_vector) AS score
|
||||||
WHERE score > $min_score
|
WHERE score > $min_score
|
||||||
RETURN
|
RETURN
|
||||||
r.uuid AS uuid,
|
r.uuid AS uuid,
|
||||||
|
|
@ -214,9 +230,11 @@ async def edge_similarity_search(
|
||||||
ORDER BY score DESC
|
ORDER BY score DESC
|
||||||
LIMIT $limit
|
LIMIT $limit
|
||||||
"""
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
records, _, _ = await driver.execute_query(
|
records, _, _ = await driver.execute_query(
|
||||||
runtime_query + query,
|
runtime_query + query,
|
||||||
|
query_params,
|
||||||
search_vector=search_vector,
|
search_vector=search_vector,
|
||||||
source_uuid=source_node_uuid,
|
source_uuid=source_node_uuid,
|
||||||
target_uuid=target_node_uuid,
|
target_uuid=target_node_uuid,
|
||||||
|
|
@ -325,11 +343,20 @@ async def node_similarity_search(
|
||||||
'CYPHER runtime = parallel parallelRuntimeSupport=all\n' if USE_PARALLEL_RUNTIME else ''
|
'CYPHER runtime = parallel parallelRuntimeSupport=all\n' if USE_PARALLEL_RUNTIME else ''
|
||||||
)
|
)
|
||||||
|
|
||||||
|
query_params: dict[str, Any] = {}
|
||||||
|
|
||||||
|
group_filter_query: LiteralString = ''
|
||||||
|
if group_ids is not None:
|
||||||
|
group_filter_query += 'WHERE n.group_id IN $group_ids'
|
||||||
|
query_params['group_ids'] = group_ids
|
||||||
|
|
||||||
records, _, _ = await driver.execute_query(
|
records, _, _ = await driver.execute_query(
|
||||||
runtime_query
|
runtime_query
|
||||||
+ """
|
+ """
|
||||||
MATCH (n:Entity)
|
MATCH (n:Entity)
|
||||||
WHERE $group_ids IS NULL OR n.group_id IN $group_ids
|
"""
|
||||||
|
+ group_filter_query
|
||||||
|
+ """
|
||||||
WITH n, vector.similarity.cosine(n.name_embedding, $search_vector) AS score
|
WITH n, vector.similarity.cosine(n.name_embedding, $search_vector) AS score
|
||||||
WHERE score > $min_score
|
WHERE score > $min_score
|
||||||
RETURN
|
RETURN
|
||||||
|
|
@ -342,6 +369,7 @@ async def node_similarity_search(
|
||||||
ORDER BY score DESC
|
ORDER BY score DESC
|
||||||
LIMIT $limit
|
LIMIT $limit
|
||||||
""",
|
""",
|
||||||
|
query_params,
|
||||||
search_vector=search_vector,
|
search_vector=search_vector,
|
||||||
group_ids=group_ids,
|
group_ids=group_ids,
|
||||||
limit=limit,
|
limit=limit,
|
||||||
|
|
@ -436,11 +464,20 @@ async def community_similarity_search(
|
||||||
'CYPHER runtime = parallel parallelRuntimeSupport=all\n' if USE_PARALLEL_RUNTIME else ''
|
'CYPHER runtime = parallel parallelRuntimeSupport=all\n' if USE_PARALLEL_RUNTIME else ''
|
||||||
)
|
)
|
||||||
|
|
||||||
|
query_params: dict[str, Any] = {}
|
||||||
|
|
||||||
|
group_filter_query: LiteralString = ''
|
||||||
|
if group_ids is not None:
|
||||||
|
group_filter_query += 'WHERE comm.group_id IN $group_ids'
|
||||||
|
query_params['group_ids'] = group_ids
|
||||||
|
|
||||||
records, _, _ = await driver.execute_query(
|
records, _, _ = await driver.execute_query(
|
||||||
runtime_query
|
runtime_query
|
||||||
+ """
|
+ """
|
||||||
MATCH (comm:Community)
|
MATCH (comm:Community)
|
||||||
WHERE ($group_ids IS NULL OR comm.group_id IN $group_ids)
|
"""
|
||||||
|
+ group_filter_query
|
||||||
|
+ """
|
||||||
WITH comm, vector.similarity.cosine(comm.name_embedding, $search_vector) AS score
|
WITH comm, vector.similarity.cosine(comm.name_embedding, $search_vector) AS score
|
||||||
WHERE score > $min_score
|
WHERE score > $min_score
|
||||||
RETURN
|
RETURN
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue