diff --git a/examples/langgraph-agent/agent.ipynb b/examples/langgraph-agent/agent.ipynb index 54fe4b20..f04fb573 100644 --- a/examples/langgraph-agent/agent.ipynb +++ b/examples/langgraph-agent/agent.ipynb @@ -22,10 +22,10 @@ ] }, { - "metadata": {}, "cell_type": "code", - "outputs": [], "execution_count": null, + "metadata": {}, + "outputs": [], "source": [ "import asyncio\n", "import json\n", @@ -47,10 +47,10 @@ ] }, { - "metadata": {}, "cell_type": "code", - "outputs": [], "execution_count": null, + "metadata": {}, + "outputs": [], "source": [ "def setup_logging():\n", " logger = logging.getLogger()\n", @@ -67,8 +67,8 @@ ] }, { - "metadata": {}, "cell_type": "markdown", + "metadata": {}, "source": [ "## LangSmith integration (Optional)\n", "\n", @@ -78,18 +78,18 @@ ] }, { - "metadata": {}, "cell_type": "code", - "outputs": [], "execution_count": null, + "metadata": {}, + "outputs": [], "source": [ "os.environ['LANGCHAIN_TRACING_V2'] = 'false'\n", "os.environ['LANGCHAIN_PROJECT'] = 'Graphiti LangGraph Tutorial'" ] }, { - "metadata": {}, "cell_type": "markdown", + "metadata": {}, "source": [ "## Configure Graphiti\n", "\n", @@ -103,10 +103,10 @@ ] }, { - "metadata": {}, "cell_type": "code", - "outputs": [], "execution_count": null, + "metadata": {}, + "outputs": [], "source": [ "# Configure Graphiti\n", "\n", @@ -127,8 +127,8 @@ ] }, { - "metadata": {}, "cell_type": "markdown", + "metadata": {}, "source": [ "## Generating a database schema \n", "\n", @@ -138,10 +138,10 @@ ] }, { - "metadata": {}, "cell_type": "code", - "outputs": [], "execution_count": null, + "metadata": {}, + "outputs": [], "source": [ "# Note: This will clear the database\n", "await clear_data(client.driver)\n", @@ -149,8 +149,8 @@ ] }, { - "metadata": {}, "cell_type": "markdown", + "metadata": {}, "source": [ "## Load Shoe Data into the Graph\n", "\n", @@ -161,10 +161,10 @@ ] }, { - "metadata": {}, "cell_type": "code", - "outputs": [], "execution_count": null, + "metadata": {}, + "outputs": 
[], "source": [ "async def ingest_products_data(client: Graphiti):\n", " script_dir = Path.cwd().parent\n", @@ -187,8 +187,8 @@ ] }, { - "metadata": {}, "cell_type": "markdown", + "metadata": {}, "source": [ "## Create a user node in the Graphiti graph\n", "\n", @@ -196,10 +196,10 @@ ] }, { - "metadata": {}, "cell_type": "code", - "outputs": [], "execution_count": null, + "metadata": {}, + "outputs": [], "source": [ "from graphiti_core.search.search_config_recipes import NODE_HYBRID_SEARCH_EPISODE_MENTIONS\n", "\n", @@ -224,20 +224,20 @@ ] }, { - "metadata": {}, "cell_type": "code", - "outputs": [], "execution_count": null, + "metadata": {}, + "outputs": [], "source": [ "def edges_to_facts_string(entities: list[EntityEdge]):\n", " return '-' + '\\n- '.join([edge.fact for edge in entities])" ] }, { - "metadata": {}, "cell_type": "code", - "outputs": [], "execution_count": null, + "metadata": {}, + "outputs": [], "source": [ "from langchain_core.messages import AIMessage, SystemMessage\n", "from langchain_core.tools import tool\n", @@ -248,8 +248,8 @@ ] }, { - "metadata": {}, "cell_type": "markdown", + "metadata": {}, "source": [ "## `get_shoe_data` Tool\n", "\n", @@ -257,10 +257,10 @@ ] }, { - "metadata": {}, "cell_type": "code", - "outputs": [], "execution_count": null, + "metadata": {}, + "outputs": [], "source": [ "@tool\n", "async def get_shoe_data(query: str) -> str:\n", @@ -278,25 +278,27 @@ ] }, { - "metadata": {}, "cell_type": "code", - "outputs": [], "execution_count": null, - "source": "llm = ChatOpenAI(model='gpt-4o-mini', temperature=0).bind_tools(tools)" + "metadata": {}, + "outputs": [], + "source": [ + "llm = ChatOpenAI(model='gpt-4o-mini', temperature=0).bind_tools(tools)" + ] }, { - "metadata": {}, "cell_type": "code", - "outputs": [], "execution_count": null, + "metadata": {}, + "outputs": [], "source": [ "# Test the tool node\n", "await tool_node.ainvoke({'messages': [await llm.ainvoke('wool shoes')]})" ] }, { - "metadata": {}, "cell_type": 
"markdown", + "metadata": {}, "source": [ "## Chatbot Function Explanation\n", "\n", @@ -312,10 +314,10 @@ ] }, { - "metadata": {}, "cell_type": "code", - "outputs": [], "execution_count": null, + "metadata": {}, + "outputs": [], "source": [ "class State(TypedDict):\n", " messages: Annotated[list, add_messages]\n", @@ -372,8 +374,8 @@ ] }, { - "metadata": {}, "cell_type": "markdown", + "metadata": {}, "source": [ "## Setting up the Agent\n", "\n", @@ -387,10 +389,10 @@ ] }, { - "metadata": {}, "cell_type": "code", - "outputs": [], "execution_count": null, + "metadata": {}, + "outputs": [], "source": [ "graph_builder = StateGraph(State)\n", "\n", @@ -420,23 +422,23 @@ ] }, { - "metadata": {}, "cell_type": "markdown", + "metadata": {}, "source": "Our LangGraph agent graph is illustrated below." }, { - "metadata": {}, "cell_type": "code", - "outputs": [], "execution_count": null, + "metadata": {}, + "outputs": [], "source": [ "with suppress(Exception):\n", " display(Image(graph.get_graph().draw_mermaid_png()))" ] }, { - "metadata": {}, "cell_type": "markdown", + "metadata": {}, "source": [ "## Running the Agent\n", "\n", @@ -444,10 +446,10 @@ ] }, { - "metadata": {}, "cell_type": "code", - "outputs": [], "execution_count": null, + "metadata": {}, + "outputs": [], "source": [ "await graph.ainvoke(\n", " {\n", @@ -465,8 +467,8 @@ ] }, { - "metadata": {}, "cell_type": "markdown", + "metadata": {}, "source": [ "## Viewing the Graph\n", "\n", @@ -474,15 +476,17 @@ ] }, { - "metadata": {}, "cell_type": "code", - "outputs": [], "execution_count": null, - "source": "display(Image(filename='tinybirds-jess.png', width=850))" + "metadata": {}, + "outputs": [], + "source": [ + "display(Image(filename='tinybirds-jess.png', width=850))" + ] }, { - "metadata": {}, "cell_type": "markdown", + "metadata": {}, "source": [ "## Running the Agent interactively\n", "\n", @@ -490,10 +494,10 @@ ] }, { - "metadata": {}, "cell_type": "code", - "outputs": [], "execution_count": null, + 
"metadata": {}, + "outputs": [], "source": [ "conversation_output = widgets.Output()\n", "config = {'configurable': {'thread_id': uuid.uuid4().hex}}\n", @@ -512,14 +516,14 @@ "\n", " try:\n", " async for event in graph.astream(\n", - " graph_state,\n", - " config=config,\n", + " graph_state,\n", + " config=config,\n", " ):\n", " for value in event.values():\n", " if 'messages' in value:\n", " last_message = value['messages'][-1]\n", " if isinstance(last_message, AIMessage) and isinstance(\n", - " last_message.content, str\n", + " last_message.content, str\n", " ):\n", " conversation_output.append_stdout(last_message.content)\n", " except Exception as e:\n", diff --git a/graphiti_core/llm_client/client.py b/graphiti_core/llm_client/client.py index ac01bed7..0c28e895 100644 --- a/graphiti_core/llm_client/client.py +++ b/graphiti_core/llm_client/client.py @@ -56,7 +56,6 @@ class LLMClient(ABC): self.cache_enabled = cache self.cache_dir = Cache(DEFAULT_CACHE_DIR) # Create a cache directory - def _clean_input(self, input: str) -> str: """Clean input string of invalid unicode and control characters. 
diff --git a/graphiti_core/search/search_utils.py b/graphiti_core/search/search_utils.py index 726ff46b..e271dc01 100644 --- a/graphiti_core/search/search_utils.py +++ b/graphiti_core/search/search_utils.py @@ -18,6 +18,7 @@ import asyncio import logging from collections import defaultdict from time import time +from typing import Any import numpy as np from neo4j import AsyncDriver, Query @@ -191,12 +192,27 @@ async def edge_similarity_search( 'CYPHER runtime = parallel parallelRuntimeSupport=all\n' if USE_PARALLEL_RUNTIME else '' ) - query: LiteralString = """ - MATCH (n:Entity)-[r:RELATES_TO]->(m:Entity) - WHERE ($group_ids IS NULL OR r.group_id IN $group_ids) - AND ($source_uuid IS NULL OR n.uuid IN [$source_uuid, $target_uuid]) - AND ($target_uuid IS NULL OR m.uuid IN [$source_uuid, $target_uuid]) - WITH DISTINCT r, vector.similarity.cosine(r.fact_embedding, $search_vector) AS score + query_params: dict[str, Any] = {} + + group_filter_query: LiteralString = '' + if group_ids is not None: + group_filter_query += 'WHERE r.group_id IN $group_ids' + query_params['group_ids'] = group_ids + query_params['source_node_uuid'] = source_node_uuid + query_params['target_node_uuid'] = target_node_uuid + + if source_node_uuid is not None: + group_filter_query += ('\nAND' if group_filter_query else 'WHERE') + ' (n.uuid IN [$source_uuid, $target_uuid])' + + if target_node_uuid is not None: + group_filter_query += ('\nAND' if group_filter_query else 'WHERE') + ' (m.uuid IN [$source_uuid, $target_uuid])' + + query: LiteralString = ( + """ + MATCH (n:Entity)-[r:RELATES_TO]->(m:Entity) + """ + + group_filter_query + + """\nWITH DISTINCT r, vector.similarity.cosine(r.fact_embedding, $search_vector) AS score WHERE score > $min_score RETURN r.uuid AS uuid, @@ -214,9 +230,11 @@ ORDER BY score DESC LIMIT $limit """ + ) records, _, _ = await driver.execute_query( runtime_query + query, + query_params, search_vector=search_vector, source_uuid=source_node_uuid, target_uuid=target_node_uuid, @@ -325,11 +343,20 @@ async def
node_similarity_search( 'CYPHER runtime = parallel parallelRuntimeSupport=all\n' if USE_PARALLEL_RUNTIME else '' ) + query_params: dict[str, Any] = {} + + group_filter_query: LiteralString = '' + if group_ids is not None: + group_filter_query += 'WHERE n.group_id IN $group_ids' + query_params['group_ids'] = group_ids + records, _, _ = await driver.execute_query( runtime_query + """ MATCH (n:Entity) - WHERE $group_ids IS NULL OR n.group_id IN $group_ids + """ + + group_filter_query + + """ WITH n, vector.similarity.cosine(n.name_embedding, $search_vector) AS score WHERE score > $min_score RETURN @@ -342,6 +369,7 @@ async def node_similarity_search( ORDER BY score DESC LIMIT $limit """, + query_params, search_vector=search_vector, group_ids=group_ids, limit=limit, @@ -436,11 +464,20 @@ async def community_similarity_search( 'CYPHER runtime = parallel parallelRuntimeSupport=all\n' if USE_PARALLEL_RUNTIME else '' ) + query_params: dict[str, Any] = {} + + group_filter_query: LiteralString = '' + if group_ids is not None: + group_filter_query += 'WHERE comm.group_id IN $group_ids' + query_params['group_ids'] = group_ids + records, _, _ = await driver.execute_query( runtime_query + """ MATCH (comm:Community) - WHERE ($group_ids IS NULL OR comm.group_id IN $group_ids) + """ + + group_filter_query + + """ WITH comm, vector.similarity.cosine(comm.name_embedding, $search_vector) AS score WHERE score > $min_score RETURN