Merge 4753e3c49c into 2a0f835ffe
This commit is contained in:
commit
0130525754
5 changed files with 564 additions and 37 deletions
|
|
@ -160,7 +160,7 @@ class Graph:
|
|||
return self._tenant_id
|
||||
|
||||
def get_value_with_variable(self,value: str) -> Any:
|
||||
pat = re.compile(r"\{* *\{([a-zA-Z:0-9]+@[A-Za-z0-9_.]+|sys\.[A-Za-z0-9_.]+|env\.[A-Za-z0-9_.]+)\} *\}*")
|
||||
pat = re.compile(r"\{* *\{([a-zA-Z:0-9]+@[A-Za-z0-9_.-]+|sys\.[A-Za-z0-9_.]+|env\.[A-Za-z0-9_.]+)\} *\}*")
|
||||
out_parts = []
|
||||
last = 0
|
||||
|
||||
|
|
@ -368,7 +368,7 @@ class Canvas(Graph):
|
|||
|
||||
if kwargs.get("webhook_payload"):
|
||||
for k, cpn in self.components.items():
|
||||
if self.components[k]["obj"].component_name.lower() == "webhook":
|
||||
if self.components[k]["obj"].component_name.lower() == "begin" and self.components[k]["obj"]._param.mode == "Webhook":
|
||||
for kk, vv in kwargs["webhook_payload"].items():
|
||||
self.components[k]["obj"].set_output(kk, vv)
|
||||
|
||||
|
|
|
|||
|
|
@ -393,7 +393,7 @@ class ComponentParamBase(ABC):
|
|||
class ComponentBase(ABC):
|
||||
component_name: str
|
||||
thread_limiter = asyncio.Semaphore(int(os.environ.get("MAX_CONCURRENT_CHATS", 10)))
|
||||
variable_ref_patt = r"\{* *\{([a-zA-Z:0-9]+@[A-Za-z0-9_.]+|sys\.[A-Za-z0-9_.]+|env\.[A-Za-z0-9_.]+)\} *\}*"
|
||||
variable_ref_patt = r"\{* *\{([a-zA-Z:0-9]+@[A-Za-z0-9_.-]+|sys\.[A-Za-z0-9_.]+|env\.[A-Za-z0-9_.]+)\} *\}*"
|
||||
|
||||
def __str__(self):
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -28,7 +28,7 @@ class BeginParam(UserFillUpParam):
|
|||
self.prologue = "Hi! I'm your smart assistant. What can I do for you?"
|
||||
|
||||
def check(self):
|
||||
self.check_valid_value(self.mode, "The 'mode' should be either `conversational` or `task`", ["conversational", "task"])
|
||||
self.check_valid_value(self.mode, "The 'mode' should be either `conversational` or `task`", ["conversational", "task","Webhook"])
|
||||
|
||||
def get_input_form(self) -> dict[str, dict]:
|
||||
return getattr(self, "inputs")
|
||||
|
|
|
|||
|
|
@ -23,6 +23,26 @@ class WebhookParam(ComponentParamBase):
|
|||
"""
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.security = {
|
||||
"auth_type": "none",
|
||||
}
|
||||
self.schema = {}
|
||||
self.execution_mode = "Immediately"
|
||||
self.response = {}
|
||||
self.outputs = {
|
||||
"query": {
|
||||
"value": {},
|
||||
"type": "object"
|
||||
},
|
||||
"headers": {
|
||||
"value": {},
|
||||
"type": "object"
|
||||
},
|
||||
"body": {
|
||||
"value": {},
|
||||
"type": "object"
|
||||
}
|
||||
}
|
||||
|
||||
def get_input_form(self) -> dict[str, dict]:
|
||||
return getattr(self, "inputs")
|
||||
|
|
|
|||
|
|
@ -14,14 +14,19 @@
|
|||
# limitations under the License.
|
||||
#
|
||||
|
||||
import asyncio
|
||||
import ipaddress
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
from typing import Any, cast
|
||||
|
||||
import jwt
|
||||
|
||||
from agent.canvas import Canvas
|
||||
from api.db import CanvasCategory
|
||||
from api.db.services.canvas_service import UserCanvasService
|
||||
from api.db.services.file_service import FileService
|
||||
from api.db.services.user_canvas_version import UserCanvasVersionService
|
||||
from common.constants import RetCode
|
||||
from common.misc_utils import get_uuid
|
||||
|
|
@ -132,48 +137,550 @@ def delete_agent(tenant_id: str, agent_id: str):
|
|||
UserCanvasService.delete_by_id(agent_id)
|
||||
return get_json_result(data=True)
|
||||
|
||||
_rate_limit_cache = {}
|
||||
|
||||
@manager.route('/webhook/<agent_id>', methods=['POST']) # noqa: F821
|
||||
@token_required
|
||||
async def webhook(tenant_id: str, agent_id: str):
|
||||
req = await get_request_json()
|
||||
if not UserCanvasService.accessible(req["id"], tenant_id):
|
||||
return get_json_result(
|
||||
data=False, message='Only owner of canvas authorized for this operation.',
|
||||
code=RetCode.OPERATING_ERROR)
|
||||
|
||||
e, cvs = UserCanvasService.get_by_id(req["id"])
|
||||
if not e:
|
||||
return get_data_error_result(message="canvas not found.")
|
||||
|
||||
if not isinstance(cvs.dsl, str):
|
||||
cvs.dsl = json.dumps(cvs.dsl, ensure_ascii=False)
|
||||
@manager.route('/webhook/<agent_id>', methods=["POST", "GET", "PUT", "PATCH", "DELETE", "HEAD"]) # noqa: F821
|
||||
async def webhook(agent_id: str):
|
||||
# 1. Fetch canvas by agent_id
|
||||
exists, cvs = UserCanvasService.get_by_id(agent_id)
|
||||
if not exists:
|
||||
return get_data_error_result(code=RetCode.BAD_REQUEST,message="Canvas not found."),RetCode.BAD_REQUEST
|
||||
|
||||
# 2. Check canvas category
|
||||
if cvs.canvas_category == CanvasCategory.DataFlow:
|
||||
return get_data_error_result(message="Dataflow can not be triggered by webhook.")
|
||||
return get_data_error_result(code=RetCode.BAD_REQUEST,message="Dataflow can not be triggered by webhook."),RetCode.BAD_REQUEST
|
||||
|
||||
# 3. Load DSL from canvas
|
||||
dsl = getattr(cvs, "dsl", None)
|
||||
if not isinstance(dsl, dict):
|
||||
return get_data_error_result(code=RetCode.BAD_REQUEST,message="Invalid DSL format."),RetCode.BAD_REQUEST
|
||||
|
||||
# 4. Check webhook configuration in DSL
|
||||
components = dsl.get("components", {})
|
||||
for k, _ in components.items():
|
||||
cpn_obj = components[k]["obj"]
|
||||
if cpn_obj["component_name"].lower() == "begin" and cpn_obj["params"]["mode"] == "Webhook":
|
||||
webhook_cfg = cpn_obj["params"]
|
||||
|
||||
if not webhook_cfg:
|
||||
return get_data_error_result(code=RetCode.BAD_REQUEST,message="Webhook not configured for this agent."),RetCode.BAD_REQUEST
|
||||
|
||||
# 5. Validate request method against webhook_cfg.methods
|
||||
allowed_methods = webhook_cfg.get("methods", [])
|
||||
request_method = request.method.upper()
|
||||
if allowed_methods and request_method not in allowed_methods:
|
||||
return get_data_error_result(
|
||||
code=RetCode.BAD_REQUEST,message=f"HTTP method '{request_method}' not allowed for this webhook."
|
||||
),RetCode.BAD_REQUEST
|
||||
|
||||
# 6. Validate webhook security
|
||||
async def validate_webhook_security(security_cfg: dict):
|
||||
"""Validate webhook security rules based on security configuration."""
|
||||
|
||||
if not security_cfg:
|
||||
return # No security config → allowed by default
|
||||
|
||||
# 1. Validate max body size
|
||||
await _validate_max_body_size(security_cfg)
|
||||
|
||||
# 2. Validate IP whitelist
|
||||
_validate_ip_whitelist(security_cfg)
|
||||
|
||||
# # 3. Validate rate limiting
|
||||
_validate_rate_limit(security_cfg)
|
||||
|
||||
# 4. Validate authentication
|
||||
auth_type = security_cfg.get("auth_type", "none")
|
||||
|
||||
if auth_type == "none":
|
||||
return
|
||||
|
||||
if auth_type == "token":
|
||||
_validate_token_auth(security_cfg)
|
||||
|
||||
elif auth_type == "basic":
|
||||
_validate_basic_auth(security_cfg)
|
||||
|
||||
elif auth_type == "jwt":
|
||||
_validate_jwt_auth(security_cfg)
|
||||
|
||||
else:
|
||||
raise Exception(f"Unsupported auth_type: {auth_type}")
|
||||
|
||||
async def _validate_max_body_size(security_cfg):
|
||||
"""Check request size does not exceed max_body_size."""
|
||||
max_size = security_cfg.get("max_body_size")
|
||||
if not max_size:
|
||||
return
|
||||
|
||||
# Convert "10MB" → bytes
|
||||
units = {"kb": 1024, "mb": 1024**2, "gb": 1024**3}
|
||||
size_str = max_size.lower()
|
||||
|
||||
for suffix, factor in units.items():
|
||||
if size_str.endswith(suffix):
|
||||
limit = int(size_str.replace(suffix, "")) * factor
|
||||
break
|
||||
else:
|
||||
raise Exception("Invalid max_body_size format")
|
||||
|
||||
content_length = request.content_length or 0
|
||||
if content_length > limit:
|
||||
raise Exception(f"Request body too large: {content_length} > {limit}")
|
||||
|
||||
def _validate_ip_whitelist(security_cfg):
|
||||
"""Allow only IPs listed in ip_whitelist."""
|
||||
whitelist = security_cfg.get("ip_whitelist", [])
|
||||
if not whitelist:
|
||||
return
|
||||
|
||||
client_ip = request.remote_addr
|
||||
|
||||
|
||||
for rule in whitelist:
|
||||
if "/" in rule:
|
||||
# CIDR notation
|
||||
if ipaddress.ip_address(client_ip) in ipaddress.ip_network(rule, strict=False):
|
||||
return
|
||||
else:
|
||||
# Single IP
|
||||
if client_ip == rule:
|
||||
return
|
||||
|
||||
raise Exception(f"IP {client_ip} is not allowed by whitelist")
|
||||
|
||||
def _validate_rate_limit(security_cfg):
|
||||
"""Simple in-memory rate limiting."""
|
||||
rl = security_cfg.get("rate_limit")
|
||||
if not rl:
|
||||
return
|
||||
|
||||
limit = rl.get("limit", 60)
|
||||
per = rl.get("per", "minute")
|
||||
|
||||
window = {"second": 1, "minute": 60, "hour": 3600, "day": 86400}.get(per, 60)
|
||||
key = f"rl:{agent_id}"
|
||||
|
||||
now = int(time.time())
|
||||
bucket = _rate_limit_cache.get(key, {"ts": now, "count": 0})
|
||||
|
||||
# Reset window
|
||||
if now - bucket["ts"] > window:
|
||||
bucket = {"ts": now, "count": 0}
|
||||
|
||||
bucket["count"] += 1
|
||||
_rate_limit_cache[key] = bucket
|
||||
|
||||
if bucket["count"] > limit:
|
||||
raise Exception("Too many requests (rate limit exceeded)")
|
||||
|
||||
def _validate_token_auth(security_cfg):
|
||||
"""Validate header-based token authentication."""
|
||||
token_cfg = security_cfg.get("token",{})
|
||||
header = token_cfg.get("token_header")
|
||||
token_value = token_cfg.get("token_value")
|
||||
|
||||
provided = request.headers.get(header)
|
||||
if provided != token_value:
|
||||
raise Exception("Invalid token authentication")
|
||||
|
||||
def _validate_basic_auth(security_cfg):
|
||||
"""Validate HTTP Basic Auth credentials."""
|
||||
auth_cfg = security_cfg.get("basic_auth", {})
|
||||
username = auth_cfg.get("username")
|
||||
password = auth_cfg.get("password")
|
||||
|
||||
auth = request.authorization
|
||||
if not auth or auth.username != username or auth.password != password:
|
||||
raise Exception("Invalid Basic Auth credentials")
|
||||
|
||||
def _validate_jwt_auth(security_cfg):
|
||||
"""Validate JWT token in Authorization header."""
|
||||
jwt_cfg = security_cfg.get("jwt", {})
|
||||
secret = jwt_cfg.get("secret")
|
||||
required_claims = jwt_cfg.get("required_claims", [])
|
||||
|
||||
auth_header = request.headers.get("Authorization", "")
|
||||
if not auth_header.startswith("Bearer "):
|
||||
raise Exception("Missing Bearer token")
|
||||
|
||||
token = auth_header[len("Bearer "):].strip()
|
||||
if not token:
|
||||
raise Exception("Empty Bearer token")
|
||||
|
||||
alg = (jwt_cfg.get("algorithm") or "HS256").upper()
|
||||
|
||||
decode_kwargs = {
|
||||
"key": secret,
|
||||
"algorithms": [alg],
|
||||
}
|
||||
options = {}
|
||||
if jwt_cfg.get("audience"):
|
||||
decode_kwargs["audience"] = jwt_cfg["audience"]
|
||||
options["verify_aud"] = True
|
||||
else:
|
||||
options["verify_aud"] = False
|
||||
|
||||
if jwt_cfg.get("issuer"):
|
||||
decode_kwargs["issuer"] = jwt_cfg["issuer"]
|
||||
options["verify_iss"] = True
|
||||
else:
|
||||
options["verify_iss"] = False
|
||||
try:
|
||||
decoded = jwt.decode(
|
||||
token,
|
||||
options=options,
|
||||
**decode_kwargs,
|
||||
)
|
||||
except Exception as e:
|
||||
raise Exception(f"Invalid JWT: {str(e)}")
|
||||
|
||||
raw_required_claims = jwt_cfg.get("required_claims", [])
|
||||
if isinstance(raw_required_claims, str):
|
||||
required_claims = [raw_required_claims]
|
||||
elif isinstance(raw_required_claims, (list, tuple, set)):
|
||||
required_claims = list(raw_required_claims)
|
||||
else:
|
||||
required_claims = []
|
||||
|
||||
required_claims = [
|
||||
c for c in required_claims
|
||||
if isinstance(c, str) and c.strip()
|
||||
]
|
||||
|
||||
RESERVED_CLAIMS = {"exp", "sub", "aud", "iss", "nbf", "iat"}
|
||||
for claim in required_claims:
|
||||
if claim in RESERVED_CLAIMS:
|
||||
raise Exception(f"Reserved JWT claim cannot be required: {claim}")
|
||||
|
||||
for claim in required_claims:
|
||||
if claim not in decoded:
|
||||
raise Exception(f"Missing JWT claim: {claim}")
|
||||
|
||||
return decoded
|
||||
|
||||
try:
|
||||
canvas = Canvas(cvs.dsl, tenant_id, agent_id)
|
||||
security_config=webhook_cfg.get("security", {})
|
||||
await validate_webhook_security(security_config)
|
||||
except Exception as e:
|
||||
return get_data_error_result(code=RetCode.BAD_REQUEST,message=str(e)),RetCode.BAD_REQUEST
|
||||
if not isinstance(cvs.dsl, str):
|
||||
dsl = json.dumps(cvs.dsl, ensure_ascii=False)
|
||||
try:
|
||||
canvas = Canvas(dsl, cvs.user_id, agent_id)
|
||||
except Exception as e:
|
||||
return get_json_result(
|
||||
data=False, message=str(e),
|
||||
code=RetCode.EXCEPTION_ERROR)
|
||||
code=RetCode.EXCEPTION_ERROR), RetCode.EXCEPTION_ERROR
|
||||
|
||||
# 7. Parse request body
|
||||
async def parse_webhook_request(content_type):
|
||||
"""Parse request based on content-type and return structured data."""
|
||||
|
||||
# 1. Query
|
||||
query_data = {k: v for k, v in request.args.items()}
|
||||
|
||||
# 2. Headers
|
||||
header_data = {k: v for k, v in request.headers.items()}
|
||||
|
||||
# 3. Body
|
||||
ctype = request.headers.get("Content-Type", "").split(";")[0].strip()
|
||||
if ctype != content_type:
|
||||
raise ValueError(
|
||||
f"Invalid Content-Type: expect '{content_type}', got '{ctype or 'empty'}'"
|
||||
)
|
||||
|
||||
body_data: dict = {}
|
||||
|
||||
async def sse():
|
||||
nonlocal canvas
|
||||
try:
|
||||
async for ans in canvas.run(query=req.get("query", ""), files=req.get("files", []), user_id=req.get("user_id", tenant_id), webhook_payload=req):
|
||||
yield "data:" + json.dumps(ans, ensure_ascii=False) + "\n\n"
|
||||
if ctype == "application/json":
|
||||
body_data = await request.get_json() or {}
|
||||
|
||||
cvs.dsl = json.loads(str(canvas))
|
||||
UserCanvasService.update_by_id(req["id"], cvs.to_dict())
|
||||
except Exception as e:
|
||||
logging.exception(e)
|
||||
yield "data:" + json.dumps({"code": 500, "message": str(e), "data": False}, ensure_ascii=False) + "\n\n"
|
||||
elif ctype == "multipart/form-data":
|
||||
nonlocal canvas
|
||||
form = await request.form
|
||||
files = await request.files
|
||||
|
||||
resp = Response(sse(), mimetype="text/event-stream")
|
||||
resp.headers.add_header("Cache-control", "no-cache")
|
||||
resp.headers.add_header("Connection", "keep-alive")
|
||||
resp.headers.add_header("X-Accel-Buffering", "no")
|
||||
resp.headers.add_header("Content-Type", "text/event-stream; charset=utf-8")
|
||||
return resp
|
||||
body_data = {}
|
||||
|
||||
for key, value in form.items():
|
||||
body_data[key] = value
|
||||
|
||||
for key, file in files.items():
|
||||
desc = FileService.upload_info(
|
||||
cvs.user_id, # user
|
||||
file, # FileStorage
|
||||
None # url (None for webhook)
|
||||
)
|
||||
file_parsed= await canvas.get_files_async([desc])
|
||||
body_data[key] = file_parsed
|
||||
|
||||
elif ctype == "application/x-www-form-urlencoded":
|
||||
form = await request.form
|
||||
body_data = dict(form)
|
||||
|
||||
else:
|
||||
# text/plain / octet-stream / empty / unknown
|
||||
raw = await request.get_data()
|
||||
if raw:
|
||||
try:
|
||||
body_data = json.loads(raw.decode("utf-8"))
|
||||
except Exception:
|
||||
body_data = {}
|
||||
else:
|
||||
body_data = {}
|
||||
|
||||
except Exception:
|
||||
body_data = {}
|
||||
|
||||
return {
|
||||
"query": query_data,
|
||||
"headers": header_data,
|
||||
"body": body_data,
|
||||
"content_type": ctype,
|
||||
}
|
||||
|
||||
def extract_by_schema(data, schema, name="section"):
|
||||
"""
|
||||
Extract only fields defined in schema.
|
||||
Required fields must exist.
|
||||
Optional fields default to type-based default values.
|
||||
Type validation included.
|
||||
"""
|
||||
props = schema.get("properties", {})
|
||||
required = schema.get("required", [])
|
||||
|
||||
extracted = {}
|
||||
|
||||
for field, field_schema in props.items():
|
||||
field_type = field_schema.get("type")
|
||||
|
||||
# 1. Required field missing
|
||||
if field in required and field not in data:
|
||||
raise Exception(f"{name} missing required field: {field}")
|
||||
|
||||
# 2. Optional → default value
|
||||
if field not in data:
|
||||
extracted[field] = default_for_type(field_type)
|
||||
continue
|
||||
|
||||
raw_value = data[field]
|
||||
|
||||
# 3. Auto convert value
|
||||
try:
|
||||
value = auto_cast_value(raw_value, field_type)
|
||||
except Exception as e:
|
||||
raise Exception(f"{name}.{field} auto-cast failed: {str(e)}")
|
||||
|
||||
# 4. Type validation
|
||||
if not validate_type(value, field_type):
|
||||
raise Exception(
|
||||
f"{name}.{field} type mismatch: expected {field_type}, got {type(value).__name__}"
|
||||
)
|
||||
|
||||
extracted[field] = value
|
||||
|
||||
return extracted
|
||||
|
||||
|
||||
def default_for_type(t):
|
||||
"""Return default value for the given schema type."""
|
||||
if t == "file":
|
||||
return []
|
||||
if t == "object":
|
||||
return {}
|
||||
if t == "boolean":
|
||||
return False
|
||||
if t == "number":
|
||||
return 0
|
||||
if t == "string":
|
||||
return ""
|
||||
if t and t.startswith("array"):
|
||||
return []
|
||||
if t == "null":
|
||||
return None
|
||||
return None
|
||||
|
||||
def auto_cast_value(value, expected_type):
|
||||
"""Convert string values into schema type when possible."""
|
||||
|
||||
# Non-string values already good
|
||||
if not isinstance(value, str):
|
||||
return value
|
||||
|
||||
v = value.strip()
|
||||
|
||||
# Boolean
|
||||
if expected_type == "boolean":
|
||||
if v.lower() in ["true", "1"]:
|
||||
return True
|
||||
if v.lower() in ["false", "0"]:
|
||||
return False
|
||||
raise Exception(f"Cannot convert '{value}' to boolean")
|
||||
|
||||
# Number
|
||||
if expected_type == "number":
|
||||
# integer
|
||||
if v.isdigit() or (v.startswith("-") and v[1:].isdigit()):
|
||||
return int(v)
|
||||
|
||||
# float
|
||||
try:
|
||||
return float(v)
|
||||
except Exception:
|
||||
raise Exception(f"Cannot convert '{value}' to number")
|
||||
|
||||
# Object
|
||||
if expected_type == "object":
|
||||
try:
|
||||
parsed = json.loads(v)
|
||||
if isinstance(parsed, dict):
|
||||
return parsed
|
||||
else:
|
||||
raise Exception("JSON is not an object")
|
||||
except Exception:
|
||||
raise Exception(f"Cannot convert '{value}' to object")
|
||||
|
||||
# Array <T>
|
||||
if expected_type.startswith("array"):
|
||||
try:
|
||||
parsed = json.loads(v)
|
||||
if isinstance(parsed, list):
|
||||
return parsed
|
||||
else:
|
||||
raise Exception("JSON is not an array")
|
||||
except Exception:
|
||||
raise Exception(f"Cannot convert '{value}' to array")
|
||||
|
||||
# String (accept original)
|
||||
if expected_type == "string":
|
||||
return value
|
||||
|
||||
# File
|
||||
if expected_type == "file":
|
||||
return value
|
||||
# Default: do nothing
|
||||
return value
|
||||
|
||||
|
||||
def validate_type(value, t):
|
||||
"""Validate value type against schema type t."""
|
||||
if t == "file":
|
||||
return isinstance(value, list)
|
||||
|
||||
if t == "string":
|
||||
return isinstance(value, str)
|
||||
|
||||
if t == "number":
|
||||
return isinstance(value, (int, float))
|
||||
|
||||
if t == "boolean":
|
||||
return isinstance(value, bool)
|
||||
|
||||
if t == "object":
|
||||
return isinstance(value, dict)
|
||||
|
||||
# array<string> / array<number> / array<object>
|
||||
if t.startswith("array"):
|
||||
if not isinstance(value, list):
|
||||
return False
|
||||
|
||||
if "<" in t and ">" in t:
|
||||
inner = t[t.find("<") + 1 : t.find(">")]
|
||||
|
||||
# Check each element type
|
||||
for item in value:
|
||||
if not validate_type(item, inner):
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
return True
|
||||
|
||||
parsed = await parse_webhook_request(webhook_cfg.get("content_types"))
|
||||
SCHEMA = webhook_cfg.get("schema", {"query": {}, "headers": {}, "body": {}})
|
||||
|
||||
# Extract strictly by schema
|
||||
try:
|
||||
query_clean = extract_by_schema(parsed["query"], SCHEMA.get("query", {}), name="query")
|
||||
header_clean = extract_by_schema(parsed["headers"], SCHEMA.get("headers", {}), name="headers")
|
||||
body_clean = extract_by_schema(parsed["body"], SCHEMA.get("body", {}), name="body")
|
||||
except Exception as e:
|
||||
return get_data_error_result(code=RetCode.BAD_REQUEST,message=str(e)),RetCode.BAD_REQUEST
|
||||
|
||||
clean_request = {
|
||||
"query": query_clean,
|
||||
"headers": header_clean,
|
||||
"body": body_clean
|
||||
}
|
||||
|
||||
execution_mode = webhook_cfg.get("execution_mode", "Immediately")
|
||||
response_cfg = webhook_cfg.get("response", {})
|
||||
|
||||
if execution_mode == "Immediately":
|
||||
status = response_cfg.get("status", 200)
|
||||
body_tpl = response_cfg.get("body_template", "")
|
||||
|
||||
def parse_body(body: str):
|
||||
if not body:
|
||||
return None, "application/json"
|
||||
|
||||
try:
|
||||
parsed = json.loads(body)
|
||||
return parsed, "application/json"
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
return body, "text/plain"
|
||||
|
||||
|
||||
body, content_type = parse_body(body_tpl)
|
||||
resp = Response(
|
||||
json.dumps(body, ensure_ascii=False) if content_type == "application/json" else body,
|
||||
status=status,
|
||||
content_type=content_type,
|
||||
)
|
||||
|
||||
async def background_run():
|
||||
try:
|
||||
async for _ in canvas.run(
|
||||
query="",
|
||||
user_id=cvs.user_id,
|
||||
webhook_payload=clean_request
|
||||
):
|
||||
pass # or log/save ans
|
||||
|
||||
cvs.dsl = json.loads(str(canvas))
|
||||
UserCanvasService.update_by_id(cvs.user_id, cvs.to_dict())
|
||||
|
||||
except Exception as e:
|
||||
logging.exception(f"Webhook background run failed: {e}")
|
||||
|
||||
asyncio.create_task(background_run())
|
||||
return resp
|
||||
else:
|
||||
async def sse():
|
||||
nonlocal canvas
|
||||
contents: list[str] = []
|
||||
|
||||
try:
|
||||
async for ans in canvas.run(
|
||||
query="",
|
||||
user_id=cvs.user_id,
|
||||
webhook_payload=clean_request,
|
||||
):
|
||||
if ans["event"] == "message":
|
||||
content = ans["data"]["content"]
|
||||
if ans["data"].get("start_to_think", False):
|
||||
content = "<think>"
|
||||
elif ans["data"].get("end_to_think", False):
|
||||
content = "</think>"
|
||||
if content:
|
||||
contents.append(content)
|
||||
|
||||
final_content = "".join(contents)
|
||||
yield json.dumps(final_content, ensure_ascii=False)
|
||||
|
||||
except Exception as e:
|
||||
yield json.dumps({"code": 500, "message": str(e)}, ensure_ascii=False)
|
||||
|
||||
resp = Response(sse(), mimetype="application/json")
|
||||
return resp
|
||||
Loading…
Add table
Reference in a new issue