context variable fix for agentd
This commit is contained in:
parent
b0725ed597
commit
f8a583a982
5 changed files with 53 additions and 12 deletions
|
|
@ -5,8 +5,7 @@ description = "Add your description here"
|
|||
readme = "README.md"
|
||||
requires-python = ">=3.13"
|
||||
dependencies = [
|
||||
# "agentd>=0.2.2",
|
||||
"agentd>=0.2.1",
|
||||
"agentd>=0.2.2",
|
||||
"aiofiles>=24.1.0",
|
||||
"cryptography>=45.0.6",
|
||||
"docling>=2.41.0",
|
||||
|
|
@ -23,7 +22,7 @@ dependencies = [
|
|||
]
|
||||
|
||||
[tool.uv.sources]
|
||||
#agentd = { path = "/home/tato/Desktop/agentd" }
|
||||
agentd = { path = "/home/tato/Desktop/agentd" }
|
||||
torch = [
|
||||
{ index = "pytorch-cu128" },
|
||||
]
|
||||
|
|
|
|||
|
|
@ -10,13 +10,16 @@ async def chat_endpoint(request: Request, chat_service, session_manager):
|
|||
|
||||
user = request.state.user
|
||||
user_id = user.user_id
|
||||
|
||||
# Get JWT token from request cookie
|
||||
jwt_token = request.cookies.get("auth_token")
|
||||
|
||||
if not prompt:
|
||||
return JSONResponse({"error": "Prompt is required"}, status_code=400)
|
||||
|
||||
if stream:
|
||||
return StreamingResponse(
|
||||
await chat_service.chat(prompt, user_id, previous_response_id=previous_response_id, stream=True),
|
||||
await chat_service.chat(prompt, user_id, jwt_token, previous_response_id=previous_response_id, stream=True),
|
||||
media_type="text/event-stream",
|
||||
headers={
|
||||
"Cache-Control": "no-cache",
|
||||
|
|
@ -26,7 +29,7 @@ async def chat_endpoint(request: Request, chat_service, session_manager):
|
|||
}
|
||||
)
|
||||
else:
|
||||
result = await chat_service.chat(prompt, user_id, previous_response_id=previous_response_id, stream=False)
|
||||
result = await chat_service.chat(prompt, user_id, jwt_token, previous_response_id=previous_response_id, stream=False)
|
||||
return JSONResponse(result)
|
||||
|
||||
async def langflow_endpoint(request: Request, chat_service, session_manager):
|
||||
|
|
|
|||
27
src/auth_context.py
Normal file
27
src/auth_context.py
Normal file
|
|
@ -0,0 +1,27 @@
|
|||
"""
|
||||
Authentication context for tool functions.
|
||||
Uses contextvars to safely pass user auth info through async calls.
|
||||
"""
|
||||
from contextvars import ContextVar
|
||||
from typing import Optional
|
||||
|
||||
# Context variables for current request authentication
|
||||
_current_user_id: ContextVar[Optional[str]] = ContextVar('current_user_id', default=None)
|
||||
_current_jwt_token: ContextVar[Optional[str]] = ContextVar('current_jwt_token', default=None)
|
||||
|
||||
def set_auth_context(user_id: str, jwt_token: str):
|
||||
"""Set authentication context for the current async context"""
|
||||
_current_user_id.set(user_id)
|
||||
_current_jwt_token.set(jwt_token)
|
||||
|
||||
def get_current_user_id() -> Optional[str]:
|
||||
"""Get current user ID from context"""
|
||||
return _current_user_id.get()
|
||||
|
||||
def get_current_jwt_token() -> Optional[str]:
|
||||
"""Get current JWT token from context"""
|
||||
return _current_jwt_token.get()
|
||||
|
||||
def get_auth_context() -> tuple[Optional[str], Optional[str]]:
|
||||
"""Get current authentication context (user_id, jwt_token)"""
|
||||
return _current_user_id.get(), _current_jwt_token.get()
|
||||
|
|
@ -1,13 +1,18 @@
|
|||
from config.settings import clients, LANGFLOW_URL, FLOW_ID, LANGFLOW_KEY
|
||||
from agent import async_chat, async_langflow, async_chat_stream, async_langflow_stream
|
||||
from auth_context import set_auth_context
|
||||
|
||||
class ChatService:
|
||||
|
||||
async def chat(self, prompt: str, user_id: str = None, previous_response_id: str = None, stream: bool = False):
|
||||
async def chat(self, prompt: str, user_id: str = None, jwt_token: str = None, previous_response_id: str = None, stream: bool = False):
|
||||
"""Handle chat requests using the patched OpenAI client"""
|
||||
if not prompt:
|
||||
raise ValueError("Prompt is required")
|
||||
|
||||
# Set authentication context for this request so tools can access it
|
||||
if user_id and jwt_token:
|
||||
set_auth_context(user_id, jwt_token)
|
||||
|
||||
if stream:
|
||||
return async_chat_stream(clients.patched_async_client, prompt, user_id, previous_response_id=previous_response_id)
|
||||
else:
|
||||
|
|
|
|||
|
|
@ -1,23 +1,25 @@
|
|||
from typing import Any, Dict, Optional
|
||||
from agentd.tool_decorator import tool
|
||||
from config.settings import clients, INDEX_NAME, EMBED_MODEL
|
||||
from auth_context import get_auth_context
|
||||
|
||||
class SearchService:
|
||||
def __init__(self, session_manager=None):
|
||||
self.session_manager = session_manager
|
||||
|
||||
@tool # TODO: This will be broken until we figure out how to pass JWT through @tool decorator
|
||||
async def search_tool(self, query: str, user_id: str = None, jwt_token: str = None) -> Dict[str, Any]:
|
||||
@tool
|
||||
async def search_tool(self, query: str) -> Dict[str, Any]:
|
||||
"""
|
||||
Use this tool to search for documents relevant to the query.
|
||||
|
||||
Args:
|
||||
query (str): query string to search the corpus
|
||||
user_id (str): user ID for access control (optional)
|
||||
|
||||
Returns:
|
||||
dict (str, Any): {"results": [chunks]} on success
|
||||
"""
|
||||
# Get authentication context from the current async context
|
||||
user_id, jwt_token = get_auth_context()
|
||||
# Embed the query
|
||||
resp = await clients.patched_async_client.embeddings.create(model=EMBED_MODEL, input=[query])
|
||||
query_embedding = resp.data[0].embedding
|
||||
|
|
@ -46,8 +48,8 @@ class SearchService:
|
|||
if not user_id:
|
||||
return {"results": [], "error": "Authentication required"}
|
||||
|
||||
# Get user's OpenSearch client with JWT for OIDC auth
|
||||
opensearch_client = self.session_manager.get_user_opensearch_client(user_id, jwt_token)
|
||||
# Get user's OpenSearch client with JWT for OIDC auth
|
||||
opensearch_client = clients.create_user_opensearch_client(jwt_token)
|
||||
results = await opensearch_client.search(index=INDEX_NAME, body=search_body)
|
||||
|
||||
# Transform results
|
||||
|
|
@ -66,4 +68,9 @@ class SearchService:
|
|||
|
||||
async def search(self, query: str, user_id: str = None, jwt_token: str = None) -> Dict[str, Any]:
|
||||
"""Public search method for API endpoints"""
|
||||
return await self.search_tool(query, user_id, jwt_token)
|
||||
# Set auth context if provided (for direct API calls)
|
||||
if user_id and jwt_token:
|
||||
from auth_context import set_auth_context
|
||||
set_auth_context(user_id, jwt_token)
|
||||
|
||||
return await self.search_tool(query)
|
||||
Loading…
Add table
Reference in a new issue