context variable fix for agentd
This commit is contained in:
parent
b0725ed597
commit
f8a583a982
5 changed files with 53 additions and 12 deletions
|
|
@ -5,8 +5,7 @@ description = "Add your description here"
|
||||||
readme = "README.md"
|
readme = "README.md"
|
||||||
requires-python = ">=3.13"
|
requires-python = ">=3.13"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
# "agentd>=0.2.2",
|
"agentd>=0.2.2",
|
||||||
"agentd>=0.2.1",
|
|
||||||
"aiofiles>=24.1.0",
|
"aiofiles>=24.1.0",
|
||||||
"cryptography>=45.0.6",
|
"cryptography>=45.0.6",
|
||||||
"docling>=2.41.0",
|
"docling>=2.41.0",
|
||||||
|
|
@ -23,7 +22,7 @@ dependencies = [
|
||||||
]
|
]
|
||||||
|
|
||||||
[tool.uv.sources]
|
[tool.uv.sources]
|
||||||
#agentd = { path = "/home/tato/Desktop/agentd" }
|
agentd = { path = "/home/tato/Desktop/agentd" }
|
||||||
torch = [
|
torch = [
|
||||||
{ index = "pytorch-cu128" },
|
{ index = "pytorch-cu128" },
|
||||||
]
|
]
|
||||||
|
|
|
||||||
|
|
@ -10,13 +10,16 @@ async def chat_endpoint(request: Request, chat_service, session_manager):
|
||||||
|
|
||||||
user = request.state.user
|
user = request.state.user
|
||||||
user_id = user.user_id
|
user_id = user.user_id
|
||||||
|
|
||||||
|
# Get JWT token from request cookie
|
||||||
|
jwt_token = request.cookies.get("auth_token")
|
||||||
|
|
||||||
if not prompt:
|
if not prompt:
|
||||||
return JSONResponse({"error": "Prompt is required"}, status_code=400)
|
return JSONResponse({"error": "Prompt is required"}, status_code=400)
|
||||||
|
|
||||||
if stream:
|
if stream:
|
||||||
return StreamingResponse(
|
return StreamingResponse(
|
||||||
await chat_service.chat(prompt, user_id, previous_response_id=previous_response_id, stream=True),
|
await chat_service.chat(prompt, user_id, jwt_token, previous_response_id=previous_response_id, stream=True),
|
||||||
media_type="text/event-stream",
|
media_type="text/event-stream",
|
||||||
headers={
|
headers={
|
||||||
"Cache-Control": "no-cache",
|
"Cache-Control": "no-cache",
|
||||||
|
|
@ -26,7 +29,7 @@ async def chat_endpoint(request: Request, chat_service, session_manager):
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
result = await chat_service.chat(prompt, user_id, previous_response_id=previous_response_id, stream=False)
|
result = await chat_service.chat(prompt, user_id, jwt_token, previous_response_id=previous_response_id, stream=False)
|
||||||
return JSONResponse(result)
|
return JSONResponse(result)
|
||||||
|
|
||||||
async def langflow_endpoint(request: Request, chat_service, session_manager):
|
async def langflow_endpoint(request: Request, chat_service, session_manager):
|
||||||
|
|
|
||||||
27
src/auth_context.py
Normal file
27
src/auth_context.py
Normal file
|
|
@ -0,0 +1,27 @@
|
||||||
|
"""
|
||||||
|
Authentication context for tool functions.
|
||||||
|
Uses contextvars to safely pass user auth info through async calls.
|
||||||
|
"""
|
||||||
|
from contextvars import ContextVar
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
# Context variables for current request authentication
|
||||||
|
_current_user_id: ContextVar[Optional[str]] = ContextVar('current_user_id', default=None)
|
||||||
|
_current_jwt_token: ContextVar[Optional[str]] = ContextVar('current_jwt_token', default=None)
|
||||||
|
|
||||||
|
def set_auth_context(user_id: str, jwt_token: str):
|
||||||
|
"""Set authentication context for the current async context"""
|
||||||
|
_current_user_id.set(user_id)
|
||||||
|
_current_jwt_token.set(jwt_token)
|
||||||
|
|
||||||
|
def get_current_user_id() -> Optional[str]:
|
||||||
|
"""Get current user ID from context"""
|
||||||
|
return _current_user_id.get()
|
||||||
|
|
||||||
|
def get_current_jwt_token() -> Optional[str]:
|
||||||
|
"""Get current JWT token from context"""
|
||||||
|
return _current_jwt_token.get()
|
||||||
|
|
||||||
|
def get_auth_context() -> tuple[Optional[str], Optional[str]]:
|
||||||
|
"""Get current authentication context (user_id, jwt_token)"""
|
||||||
|
return _current_user_id.get(), _current_jwt_token.get()
|
||||||
|
|
@ -1,13 +1,18 @@
|
||||||
from config.settings import clients, LANGFLOW_URL, FLOW_ID, LANGFLOW_KEY
|
from config.settings import clients, LANGFLOW_URL, FLOW_ID, LANGFLOW_KEY
|
||||||
from agent import async_chat, async_langflow, async_chat_stream, async_langflow_stream
|
from agent import async_chat, async_langflow, async_chat_stream, async_langflow_stream
|
||||||
|
from auth_context import set_auth_context
|
||||||
|
|
||||||
class ChatService:
|
class ChatService:
|
||||||
|
|
||||||
async def chat(self, prompt: str, user_id: str = None, previous_response_id: str = None, stream: bool = False):
|
async def chat(self, prompt: str, user_id: str = None, jwt_token: str = None, previous_response_id: str = None, stream: bool = False):
|
||||||
"""Handle chat requests using the patched OpenAI client"""
|
"""Handle chat requests using the patched OpenAI client"""
|
||||||
if not prompt:
|
if not prompt:
|
||||||
raise ValueError("Prompt is required")
|
raise ValueError("Prompt is required")
|
||||||
|
|
||||||
|
# Set authentication context for this request so tools can access it
|
||||||
|
if user_id and jwt_token:
|
||||||
|
set_auth_context(user_id, jwt_token)
|
||||||
|
|
||||||
if stream:
|
if stream:
|
||||||
return async_chat_stream(clients.patched_async_client, prompt, user_id, previous_response_id=previous_response_id)
|
return async_chat_stream(clients.patched_async_client, prompt, user_id, previous_response_id=previous_response_id)
|
||||||
else:
|
else:
|
||||||
|
|
|
||||||
|
|
@ -1,23 +1,25 @@
|
||||||
from typing import Any, Dict, Optional
|
from typing import Any, Dict, Optional
|
||||||
from agentd.tool_decorator import tool
|
from agentd.tool_decorator import tool
|
||||||
from config.settings import clients, INDEX_NAME, EMBED_MODEL
|
from config.settings import clients, INDEX_NAME, EMBED_MODEL
|
||||||
|
from auth_context import get_auth_context
|
||||||
|
|
||||||
class SearchService:
|
class SearchService:
|
||||||
def __init__(self, session_manager=None):
|
def __init__(self, session_manager=None):
|
||||||
self.session_manager = session_manager
|
self.session_manager = session_manager
|
||||||
|
|
||||||
@tool # TODO: This will be broken until we figure out how to pass JWT through @tool decorator
|
@tool
|
||||||
async def search_tool(self, query: str, user_id: str = None, jwt_token: str = None) -> Dict[str, Any]:
|
async def search_tool(self, query: str) -> Dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
Use this tool to search for documents relevant to the query.
|
Use this tool to search for documents relevant to the query.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
query (str): query string to search the corpus
|
query (str): query string to search the corpus
|
||||||
user_id (str): user ID for access control (optional)
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
dict (str, Any): {"results": [chunks]} on success
|
dict (str, Any): {"results": [chunks]} on success
|
||||||
"""
|
"""
|
||||||
|
# Get authentication context from the current async context
|
||||||
|
user_id, jwt_token = get_auth_context()
|
||||||
# Embed the query
|
# Embed the query
|
||||||
resp = await clients.patched_async_client.embeddings.create(model=EMBED_MODEL, input=[query])
|
resp = await clients.patched_async_client.embeddings.create(model=EMBED_MODEL, input=[query])
|
||||||
query_embedding = resp.data[0].embedding
|
query_embedding = resp.data[0].embedding
|
||||||
|
|
@ -46,8 +48,8 @@ class SearchService:
|
||||||
if not user_id:
|
if not user_id:
|
||||||
return {"results": [], "error": "Authentication required"}
|
return {"results": [], "error": "Authentication required"}
|
||||||
|
|
||||||
# Get user's OpenSearch client with JWT for OIDC auth
|
# Get user's OpenSearch client with JWT for OIDC auth
|
||||||
opensearch_client = self.session_manager.get_user_opensearch_client(user_id, jwt_token)
|
opensearch_client = clients.create_user_opensearch_client(jwt_token)
|
||||||
results = await opensearch_client.search(index=INDEX_NAME, body=search_body)
|
results = await opensearch_client.search(index=INDEX_NAME, body=search_body)
|
||||||
|
|
||||||
# Transform results
|
# Transform results
|
||||||
|
|
@ -66,4 +68,9 @@ class SearchService:
|
||||||
|
|
||||||
async def search(self, query: str, user_id: str = None, jwt_token: str = None) -> Dict[str, Any]:
|
async def search(self, query: str, user_id: str = None, jwt_token: str = None) -> Dict[str, Any]:
|
||||||
"""Public search method for API endpoints"""
|
"""Public search method for API endpoints"""
|
||||||
return await self.search_tool(query, user_id, jwt_token)
|
# Set auth context if provided (for direct API calls)
|
||||||
|
if user_id and jwt_token:
|
||||||
|
from auth_context import set_auth_context
|
||||||
|
set_auth_context(user_id, jwt_token)
|
||||||
|
|
||||||
|
return await self.search_tool(query)
|
||||||
Loading…
Add table
Reference in a new issue