Compare commits
1 commit
main
...
pensar-aut
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
1ef649aa91 |
1 changed files with 38 additions and 1 deletions
|
|
@ -79,6 +79,21 @@ def get_responses_router() -> APIRouter:
|
|||
# Use default tools if none provided
|
||||
tools = request.tools or DEFAULT_TOOLS
|
||||
|
||||
# Construct an allow-list of allowed function names, sourced from the provided tools
|
||||
allowed_function_names = set()
|
||||
for tool in tools:
|
||||
try:
|
||||
# Tool format: {'type': 'function', 'function': {'name': 'my_func', ...}}
|
||||
if isinstance(tool, dict) and tool.get("type") == "function":
|
||||
function_obj = tool.get("function", {})
|
||||
function_name = function_obj.get("name")
|
||||
if function_name:
|
||||
allowed_function_names.add(function_name)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Error extracting function name from tool {tool}: {e}"
|
||||
)
|
||||
|
||||
# Call the API
|
||||
response = await call_openai_api_for_model(
|
||||
input_text=request.input,
|
||||
|
|
@ -104,6 +119,28 @@ def get_responses_router() -> APIRouter:
|
|||
arguments_str = item.get("arguments", "{}")
|
||||
call_id = item.get("call_id", f"call_{uuid.uuid4().hex}")
|
||||
|
||||
# Check if function_name is in the allow-list
|
||||
if function_name not in allowed_function_names:
|
||||
logger.warning(
|
||||
f"Function '{function_name}' called by LLM is not in the allowed list; skipping execution."
|
||||
)
|
||||
# Create an error output for this tool call
|
||||
function_result = (
|
||||
f"Function '{function_name}' is not allowed or not available."
|
||||
)
|
||||
output_status = "error"
|
||||
processed_call = ResponseToolCall(
|
||||
id=call_id,
|
||||
type="function",
|
||||
function=FunctionCall(name=function_name, arguments=arguments_str),
|
||||
output=ToolCallOutput(
|
||||
status=output_status,
|
||||
data={"result": function_result},
|
||||
),
|
||||
)
|
||||
processed_tool_calls.append(processed_call)
|
||||
continue # Don't dispatch forbidden functions
|
||||
|
||||
# Create a format the dispatcher can handle
|
||||
tool_call = {
|
||||
"id": call_id,
|
||||
|
|
@ -146,4 +183,4 @@ def get_responses_router() -> APIRouter:
|
|||
|
||||
return response_obj
|
||||
|
||||
return router
|
||||
return router
|
||||
Loading…
Add table
Reference in a new issue