From f8a583a982f1a37f016ddf014ffa0ba3d951dcc9 Mon Sep 17 00:00:00 2001 From: phact Date: Mon, 11 Aug 2025 21:57:05 -0400 Subject: [PATCH] context variable fix for agentd --- pyproject.toml | 5 ++--- src/api/chat.py | 7 +++++-- src/auth_context.py | 27 +++++++++++++++++++++++++++ src/services/chat_service.py | 7 ++++++- src/services/search_service.py | 19 +++++++++++++------ 5 files changed, 53 insertions(+), 12 deletions(-) create mode 100644 src/auth_context.py diff --git a/pyproject.toml b/pyproject.toml index ae884737..4233fc43 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -5,8 +5,7 @@ description = "Add your description here" readme = "README.md" requires-python = ">=3.13" dependencies = [ -# "agentd>=0.2.2", - "agentd>=0.2.1", + "agentd>=0.2.2", "aiofiles>=24.1.0", "cryptography>=45.0.6", "docling>=2.41.0", @@ -23,7 +22,7 @@ dependencies = [ ] [tool.uv.sources] -#agentd = { path = "/home/tato/Desktop/agentd" } +agentd = { path = "/home/tato/Desktop/agentd" } torch = [ { index = "pytorch-cu128" }, ] diff --git a/src/api/chat.py b/src/api/chat.py index 6dd8a790..42db0134 100644 --- a/src/api/chat.py +++ b/src/api/chat.py @@ -10,13 +10,16 @@ async def chat_endpoint(request: Request, chat_service, session_manager): user = request.state.user user_id = user.user_id + + # Get JWT token from request cookie + jwt_token = request.cookies.get("auth_token") if not prompt: return JSONResponse({"error": "Prompt is required"}, status_code=400) if stream: return StreamingResponse( - await chat_service.chat(prompt, user_id, previous_response_id=previous_response_id, stream=True), + await chat_service.chat(prompt, user_id, jwt_token, previous_response_id=previous_response_id, stream=True), media_type="text/event-stream", headers={ "Cache-Control": "no-cache", @@ -26,7 +29,7 @@ async def chat_endpoint(request: Request, chat_service, session_manager): } ) else: - result = await chat_service.chat(prompt, user_id, previous_response_id=previous_response_id, stream=False) + result = await chat_service.chat(prompt, user_id, jwt_token, previous_response_id=previous_response_id, stream=False) return JSONResponse(result) async def langflow_endpoint(request: Request, chat_service, session_manager): diff --git a/src/auth_context.py b/src/auth_context.py new file mode 100644 index 00000000..25647eab --- /dev/null +++ b/src/auth_context.py @@ -0,0 +1,27 @@ +""" +Authentication context for tool functions. +Uses contextvars to safely pass user auth info through async calls. +""" +from contextvars import ContextVar +from typing import Optional + +# Context variables for current request authentication +_current_user_id: ContextVar[Optional[str]] = ContextVar('current_user_id', default=None) +_current_jwt_token: ContextVar[Optional[str]] = ContextVar('current_jwt_token', default=None) + +def set_auth_context(user_id: str, jwt_token: str): + """Set authentication context for the current async context""" + _current_user_id.set(user_id) + _current_jwt_token.set(jwt_token) + +def get_current_user_id() -> Optional[str]: + """Get current user ID from context""" + return _current_user_id.get() + +def get_current_jwt_token() -> Optional[str]: + """Get current JWT token from context""" + return _current_jwt_token.get() + +def get_auth_context() -> tuple[Optional[str], Optional[str]]: + """Get current authentication context (user_id, jwt_token)""" + return _current_user_id.get(), _current_jwt_token.get() \ No newline at end of file diff --git a/src/services/chat_service.py b/src/services/chat_service.py index 23105b6f..acc93296 100644 --- a/src/services/chat_service.py +++ b/src/services/chat_service.py @@ -1,13 +1,18 @@ from config.settings import clients, LANGFLOW_URL, FLOW_ID, LANGFLOW_KEY from agent import async_chat, async_langflow, async_chat_stream, async_langflow_stream +from auth_context import set_auth_context class ChatService: - async def chat(self, prompt: str, user_id: str = None, previous_response_id: str = None, stream: bool = False): + async def chat(self, prompt: str, user_id: str = None, jwt_token: str = None, previous_response_id: str = None, stream: bool = False): """Handle chat requests using the patched OpenAI client""" if not prompt: raise ValueError("Prompt is required") + # Set authentication context for this request so tools can access it + if user_id and jwt_token: + set_auth_context(user_id, jwt_token) + if stream: return async_chat_stream(clients.patched_async_client, prompt, user_id, previous_response_id=previous_response_id) else: diff --git a/src/services/search_service.py b/src/services/search_service.py index 4182a282..06e80bdb 100644 --- a/src/services/search_service.py +++ b/src/services/search_service.py @@ -1,23 +1,25 @@ from typing import Any, Dict, Optional from agentd.tool_decorator import tool from config.settings import clients, INDEX_NAME, EMBED_MODEL +from auth_context import get_auth_context class SearchService: def __init__(self, session_manager=None): self.session_manager = session_manager - @tool # TODO: This will be broken until we figure out how to pass JWT through @tool decorator - async def search_tool(self, query: str, user_id: str = None, jwt_token: str = None) -> Dict[str, Any]: + @tool + async def search_tool(self, query: str) -> Dict[str, Any]: """ Use this tool to search for documents relevant to the query. Args: query (str): query string to search the corpus - user_id (str): user ID for access control (optional) Returns: dict (str, Any): {"results": [chunks]} on success """ + # Get authentication context from the current async context + user_id, jwt_token = get_auth_context() # Embed the query resp = await clients.patched_async_client.embeddings.create(model=EMBED_MODEL, input=[query]) query_embedding = resp.data[0].embedding @@ -46,8 +48,8 @@ class SearchService: if not user_id: return {"results": [], "error": "Authentication required"} - # Get user's OpenSearch client with JWT for OIDC auth - opensearch_client = self.session_manager.get_user_opensearch_client(user_id, jwt_token) + # Get user's OpenSearch client with JWT for OIDC auth + opensearch_client = clients.create_user_opensearch_client(jwt_token) results = await opensearch_client.search(index=INDEX_NAME, body=search_body) # Transform results @@ -66,4 +68,9 @@ class SearchService: async def search(self, query: str, user_id: str = None, jwt_token: str = None) -> Dict[str, Any]: """Public search method for API endpoints""" - return await self.search_tool(query, user_id, jwt_token) \ No newline at end of file + # Set auth context if provided (for direct API calls) + if user_id and jwt_token: + from auth_context import set_auth_context + set_auth_context(user_id, jwt_token) + + return await self.search_tool(query) \ No newline at end of file