context variable fix for agentd

This commit is contained in:
phact 2025-08-11 21:57:05 -04:00
parent b0725ed597
commit f8a583a982
5 changed files with 53 additions and 12 deletions

View file

@ -5,8 +5,7 @@ description = "Add your description here"
readme = "README.md"
requires-python = ">=3.13"
dependencies = [
# "agentd>=0.2.2",
"agentd>=0.2.1",
"agentd>=0.2.2",
"aiofiles>=24.1.0",
"cryptography>=45.0.6",
"docling>=2.41.0",
@ -23,7 +22,7 @@ dependencies = [
]
[tool.uv.sources]
#agentd = { path = "/home/tato/Desktop/agentd" }
agentd = { path = "/home/tato/Desktop/agentd" }
torch = [
{ index = "pytorch-cu128" },
]

View file

@ -10,13 +10,16 @@ async def chat_endpoint(request: Request, chat_service, session_manager):
user = request.state.user
user_id = user.user_id
# Get JWT token from request cookie
jwt_token = request.cookies.get("auth_token")
if not prompt:
return JSONResponse({"error": "Prompt is required"}, status_code=400)
if stream:
return StreamingResponse(
await chat_service.chat(prompt, user_id, previous_response_id=previous_response_id, stream=True),
await chat_service.chat(prompt, user_id, jwt_token, previous_response_id=previous_response_id, stream=True),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
@ -26,7 +29,7 @@ async def chat_endpoint(request: Request, chat_service, session_manager):
}
)
else:
result = await chat_service.chat(prompt, user_id, previous_response_id=previous_response_id, stream=False)
result = await chat_service.chat(prompt, user_id, jwt_token, previous_response_id=previous_response_id, stream=False)
return JSONResponse(result)
async def langflow_endpoint(request: Request, chat_service, session_manager):

27
src/auth_context.py Normal file
View file

@ -0,0 +1,27 @@
"""
Authentication context for tool functions.
Uses contextvars to safely pass user auth info through async calls.
"""
from contextvars import ContextVar
from typing import Optional
# Context variables for current request authentication
_current_user_id: ContextVar[Optional[str]] = ContextVar('current_user_id', default=None)
_current_jwt_token: ContextVar[Optional[str]] = ContextVar('current_jwt_token', default=None)
def set_auth_context(user_id: str, jwt_token: str):
"""Set authentication context for the current async context"""
_current_user_id.set(user_id)
_current_jwt_token.set(jwt_token)
def get_current_user_id() -> Optional[str]:
"""Get current user ID from context"""
return _current_user_id.get()
def get_current_jwt_token() -> Optional[str]:
"""Get current JWT token from context"""
return _current_jwt_token.get()
def get_auth_context() -> tuple[Optional[str], Optional[str]]:
"""Get current authentication context (user_id, jwt_token)"""
return _current_user_id.get(), _current_jwt_token.get()

View file

@ -1,13 +1,18 @@
from config.settings import clients, LANGFLOW_URL, FLOW_ID, LANGFLOW_KEY
from agent import async_chat, async_langflow, async_chat_stream, async_langflow_stream
from auth_context import set_auth_context
class ChatService:
async def chat(self, prompt: str, user_id: str = None, previous_response_id: str = None, stream: bool = False):
async def chat(self, prompt: str, user_id: str = None, jwt_token: str = None, previous_response_id: str = None, stream: bool = False):
"""Handle chat requests using the patched OpenAI client"""
if not prompt:
raise ValueError("Prompt is required")
# Set authentication context for this request so tools can access it
if user_id and jwt_token:
set_auth_context(user_id, jwt_token)
if stream:
return async_chat_stream(clients.patched_async_client, prompt, user_id, previous_response_id=previous_response_id)
else:

View file

@ -1,23 +1,25 @@
from typing import Any, Dict, Optional
from agentd.tool_decorator import tool
from config.settings import clients, INDEX_NAME, EMBED_MODEL
from auth_context import get_auth_context
class SearchService:
def __init__(self, session_manager=None):
self.session_manager = session_manager
@tool # TODO: This will be broken until we figure out how to pass JWT through @tool decorator
async def search_tool(self, query: str, user_id: str = None, jwt_token: str = None) -> Dict[str, Any]:
@tool
async def search_tool(self, query: str) -> Dict[str, Any]:
"""
Use this tool to search for documents relevant to the query.
Args:
query (str): query string to search the corpus
user_id (str): user ID for access control (optional)
Returns:
dict (str, Any): {"results": [chunks]} on success
"""
# Get authentication context from the current async context
user_id, jwt_token = get_auth_context()
# Embed the query
resp = await clients.patched_async_client.embeddings.create(model=EMBED_MODEL, input=[query])
query_embedding = resp.data[0].embedding
@ -46,8 +48,8 @@ class SearchService:
if not user_id:
return {"results": [], "error": "Authentication required"}
# Get user's OpenSearch client with JWT for OIDC auth
opensearch_client = self.session_manager.get_user_opensearch_client(user_id, jwt_token)
# Get user's OpenSearch client with JWT for OIDC auth
opensearch_client = clients.create_user_opensearch_client(jwt_token)
results = await opensearch_client.search(index=INDEX_NAME, body=search_body)
# Transform results
@ -66,4 +68,9 @@ class SearchService:
async def search(self, query: str, user_id: str = None, jwt_token: str = None) -> Dict[str, Any]:
"""Public search method for API endpoints"""
return await self.search_tool(query, user_id, jwt_token)
# Set auth context if provided (for direct API calls)
if user_id and jwt_token:
from auth_context import set_auth_context
set_auth_context(user_id, jwt_token)
return await self.search_tool(query)