fix: async issue and sensitive logging

This commit is contained in:
buua436 2025-12-11 11:32:20 +08:00
parent c610bb605a
commit c6ca4a08f2
6 changed files with 527 additions and 27 deletions

View file

@ -132,19 +132,14 @@ def delete_agent(tenant_id: str, agent_id: str):
UserCanvasService.delete_by_id(agent_id)
return get_json_result(data=True)
_rate_limit_cache = {}
@manager.route('/webhook/<agent_id>', methods=['POST']) # noqa: F821
@token_required
async def webhook(tenant_id: str, agent_id: str):
req = await get_request_json()
if not UserCanvasService.accessible(req["id"], tenant_id):
return get_json_result(
data=False, message='Only owner of canvas authorized for this operation.',
code=RetCode.OPERATING_ERROR)
e, cvs = UserCanvasService.get_by_id(req["id"])
if not e:
return get_data_error_result(message="canvas not found.")
@manager.route('/webhook/<agent_id>', methods=["POST", "GET", "PUT", "PATCH", "DELETE", "HEAD"]) # noqa: F821
async def webhook(agent_id: str):
# 1. Fetch canvas by agent_id
exists, cvs = UserCanvasService.get_by_id(agent_id)
if not exists:
return get_data_error_result(message="Canvas not found.")
if not isinstance(cvs.dsl, str):
cvs.dsl = json.dumps(cvs.dsl, ensure_ascii=False)
@ -152,8 +147,487 @@ async def webhook(tenant_id: str, agent_id: str):
if cvs.canvas_category == CanvasCategory.DataFlow:
return get_data_error_result(message="Dataflow can not be triggered by webhook.")
# 3. Load DSL from canvas
dsl = getattr(cvs, "dsl", None)
if not isinstance(dsl, dict):
return get_data_error_result(message="Invalid DSL format.")
# 4. Check webhook configuration in DSL
components = dsl.get("webhook", {})
for k, cpn in components.items():
cpn_obj = components[k]["obj"]
if cpn_obj.component_name.lower() == "begin" and cpn.params.mode == "webhook":
webhook_cfg = cpn.params
if not webhook_cfg:
return get_data_error_result(message="Webhook not configured for this agent.")
# 5. Validate request method against webhook_cfg.methods
allowed_methods = webhook_cfg.get("methods", [])
request_method = request.method.upper()
if allowed_methods and request_method not in allowed_methods:
return get_data_error_result(
message=f"HTTP method '{request_method}' not allowed for this webhook."
)
# 6. Validate webhook security
async def validate_webhook_security(security_cfg: dict):
    """Apply every configured security check to the current request.

    Raises Exception when any check fails; an empty/absent config means
    the request is allowed by default.
    """
    if not security_cfg:
        # No security section configured -> allowed by default.
        return

    # Structural checks always run, regardless of auth scheme.
    await _validate_max_body_size(security_cfg)
    _validate_ip_whitelist(security_cfg)
    _validate_rate_limit(security_cfg)

    # Authentication check dispatches on the configured scheme.
    auth_type = security_cfg.get("auth_type", "none")
    if auth_type == "none":
        return

    sync_validators = {
        "token": _validate_token_auth,
        "basic": _validate_basic_auth,
        "jwt": _validate_jwt_auth,
    }
    if auth_type in sync_validators:
        sync_validators[auth_type](security_cfg)
    elif auth_type == "hmac":
        await _validate_hmac_auth(security_cfg)
    else:
        raise Exception(f"Unsupported auth_type: {auth_type}")
async def _validate_max_body_size(security_cfg):
    """Reject the request when Content-Length exceeds max_body_size.

    max_body_size is a human-readable string such as "10MB"
    (KB/MB/GB, case-insensitive). Raises Exception on a malformed size
    string or an oversized request body.
    """
    max_size = security_cfg.get("max_body_size")
    if not max_size:
        return
    units = {"KB": 1024, "MB": 1024**2, "GB": 1024**3}
    # BUG FIX: the original lowercased the size string but compared it
    # against UPPERCASE suffixes, so every configured value fell through
    # to "Invalid max_body_size format". Normalize to upper case instead.
    size_str = str(max_size).strip().upper()
    for suffix, factor in units.items():
        if size_str.endswith(suffix):
            limit = int(size_str[: -len(suffix)].strip()) * factor
            break
    else:
        raise Exception("Invalid max_body_size format")
    content_length = request.content_length or 0
    if content_length > limit:
        raise Exception(f"Request body too large: {content_length} > {limit}")
def _validate_ip_whitelist(security_cfg):
    """Allow the request only when the client IP matches the whitelist.

    Whitelist entries may be exact IPs or CIDR networks ("10.0.0.0/8").
    Raises Exception when the client IP matches no entry.
    """
    whitelist = security_cfg.get("ip_whitelist", [])
    if not whitelist:
        return  # no whitelist configured -> allow everyone
    client_ip = request.remote_addr
    # Parse once, outside the loop, instead of re-parsing the same client
    # address for every CIDR rule.
    client_addr = ipaddress.ip_address(client_ip)
    for rule in whitelist:
        if "/" in rule:
            # CIDR notation
            if client_addr in ipaddress.ip_network(rule, strict=False):
                return
        elif client_ip == rule:
            # Exact single-IP match
            return
    raise Exception(f"IP {client_ip} is not allowed by whitelist")
def _validate_rate_limit(security_cfg):
    """Enforce a fixed-window, in-memory rate limit for this agent.

    Raises Exception once the request count in the current window
    exceeds the configured limit.
    """
    rl = security_cfg.get("rate_limit")
    if not rl:
        return
    limit = rl.get("limit", 60)
    period = rl.get("per", "minute")
    window_seconds = {"second": 1, "minute": 60, "hour": 3600, "day": 86400}.get(period, 60)

    cache_key = f"rl:{agent_id}"
    now_ts = int(time.time())
    entry = _rate_limit_cache.get(cache_key, {"ts": now_ts, "count": 0})
    if now_ts - entry["ts"] > window_seconds:
        # Window expired: start a fresh counting window.
        entry = {"ts": now_ts, "count": 0}
    entry["count"] += 1
    _rate_limit_cache[cache_key] = entry
    if entry["count"] > limit:
        raise Exception("Too many requests (rate limit exceeded)")
def _validate_token_auth(security_cfg):
    """Validate a shared-secret token passed in a configurable header.

    Raises Exception when the header is absent, the token is not
    configured, or the value does not match.
    """
    token_cfg = security_cfg.get("token", {})
    header = token_cfg.get("token_header")
    token_value = token_cfg.get("token_value")
    provided = request.headers.get(header)
    # BUG FIX: the original plain `!=` comparison (a) leaks the token
    # byte-by-byte via response timing and (b) ACCEPTED the request when
    # both the header and the configured token were None. Reject missing
    # values explicitly and compare in constant time.
    if provided is None or token_value is None or \
            not hmac.compare_digest(str(provided), str(token_value)):
        raise Exception("Invalid token authentication")
def _validate_basic_auth(security_cfg):
    """Validate HTTP Basic Auth credentials against the configured pair.

    Raises Exception on a missing Authorization header or a mismatch.
    """
    auth_cfg = security_cfg.get("basic_auth", {})
    username = auth_cfg.get("username")
    password = auth_cfg.get("password")
    auth = request.authorization
    if not auth:
        raise Exception("Invalid Basic Auth credentials")
    # Constant-time comparison avoids leaking credential prefixes via
    # response timing (the original used plain `!=`).
    user_ok = hmac.compare_digest(str(auth.username or ""), str(username or ""))
    pass_ok = hmac.compare_digest(str(auth.password or ""), str(password or ""))
    if not (user_ok and pass_ok):
        raise Exception("Invalid Basic Auth credentials")
def _validate_jwt_auth(security_cfg):
    """Validate a JWT bearer token from the Authorization header.

    Verifies the signature (and audience/issuer when configured) and
    checks that every required claim is present.
    Raises Exception on any failure.
    """
    jwt_cfg = security_cfg.get("jwt", {})
    secret = jwt_cfg.get("secret")
    required_claims = jwt_cfg.get("required_claims", [])
    auth_header = request.headers.get("Authorization", "")
    if not auth_header.startswith("Bearer "):
        raise Exception("Missing Bearer token")
    # BUG FIX: str.replace removed EVERY "Bearer " occurrence, which
    # would corrupt a token that happens to contain that substring;
    # strip only the leading prefix.
    token = auth_header[len("Bearer "):]
    try:
        decoded = jwt.decode(
            token,
            secret,
            algorithms=[jwt_cfg.get("algorithm", "HS256")],
            audience=jwt_cfg.get("audience"),
            issuer=jwt_cfg.get("issuer"),
        )
    except Exception as e:
        raise Exception(f"Invalid JWT: {str(e)}")
    for claim in required_claims:
        if claim not in decoded:
            raise Exception(f"Missing JWT claim: {claim}")
async def _validate_hmac_auth(security_cfg):
    """Validate an HMAC signature of the raw request body.

    The signature is read from a configurable header and compared in
    constant time against a locally computed hex digest.
    Raises Exception on a missing header, unknown algorithm, or mismatch.
    """
    hmac_cfg = security_cfg.get("hmac", {})
    header = hmac_cfg.get("header")
    secret = hmac_cfg.get("secret")
    algorithm = hmac_cfg.get("algorithm", "sha256")
    provided_sig = request.headers.get(header)
    if not provided_sig:
        raise Exception("Missing HMAC signature header")
    body = await request.get_data()
    if body is None:
        body = b""
    elif isinstance(body, str):
        body = body.encode("utf-8")
    # Surface a clear configuration error for unknown digest names
    # instead of an opaque AttributeError from getattr(hashlib, ...).
    digestmod = getattr(hashlib, algorithm, None)
    if digestmod is None:
        raise Exception(f"Unsupported HMAC algorithm: {algorithm}")
    computed = hmac.new(secret.encode(), body, digestmod).hexdigest()
    if not hmac.compare_digest(provided_sig, computed):
        raise Exception("Invalid HMAC signature")
try:
canvas = Canvas(cvs.dsl, tenant_id, agent_id)
await validate_webhook_security(webhook_cfg.get("security", {}))
except Exception as e:
return get_data_error_result(message=str(e))
# 7. Parse request body
async def parse_webhook_request():
    """Parse the current request into structured sections.

    Returns a dict with keys "query", "headers", "body", "content_type"
    and "raw_files" (uploaded file objects, multipart requests only).
    The body is decoded according to the request's media type.
    """
    # 1. Query parameters and headers, copied into plain dicts.
    query_data = dict(request.args.items())
    header_data = dict(request.headers.items())
    # 2. Dispatch body parsing on the media type (parameters stripped).
    ctype = request.headers.get("Content-Type", "").split(";")[0].strip()
    raw_files = {}
    if ctype == "application/json":
        try:
            body_data = await request.get_json()
        except Exception:
            # BUG FIX: a bare `except:` also swallowed SystemExit /
            # KeyboardInterrupt; catch only ordinary parse failures.
            body_data = None
    elif ctype == "multipart/form-data":
        form = await request.form
        files = await request.files
        raw_files = {name: file for name, file in files.items()}
        body_data = {
            "form": dict(form),
            "files": {name: file.filename for name, file in files.items()},
        }
    elif ctype == "application/x-www-form-urlencoded":
        form = await request.form
        body_data = dict(form)
    elif ctype == "text/plain":
        body_data = (await request.get_data()).decode()
    elif ctype == "application/octet-stream":
        body_data = await request.get_data()  # raw binary
    else:
        # Unknown content type -> hand back the raw body.
        body_data = await request.get_data()
    return {
        "query": query_data,
        "headers": header_data,
        "body": body_data,
        "content_type": ctype,
        "raw_files": raw_files
    }
def extract_by_schema(data, schema, name="section"):
    """Extract and validate fields of `data` according to `schema`.

    Only fields declared in schema.properties are kept; required fields
    must be present; missing optional fields get type-based defaults;
    string values are auto-cast to the declared type, then validated.

    Raises Exception (message prefixed with `name`) on any violation.
    """
    if schema.get("type") != "object":
        return {}
    # Guard: for non-JSON content types the body may be str/bytes/None.
    # Fail with a clear message instead of a TypeError from `in` /
    # indexing on a non-dict.
    if not isinstance(data, dict):
        raise Exception(f"{name} must be a JSON object")
    props = schema.get("properties", {})
    required = schema.get("required", [])
    extracted = {}
    for field, field_schema in props.items():
        field_type = field_schema.get("type")
        # 1. Required field missing -> hard error.
        if field in required and field not in data:
            raise Exception(f"{name} missing required field: {field}")
        # 2. Optional and absent -> type-based default value.
        if field not in data:
            extracted[field] = default_for_type(field_type)
            continue
        raw_value = data[field]
        # 3. Auto-convert string values to the declared type.
        try:
            value = auto_cast_value(raw_value, field_type)
        except Exception as e:
            raise Exception(f"{name}.{field} auto-cast failed: {str(e)}")
        # 4. Final type validation.
        if not validate_type(value, field_type):
            raise Exception(
                f"{name}.{field} type mismatch: expected {field_type}, got {type(value).__name__}"
            )
        extracted[field] = value
    return extracted
def default_for_type(t):
    """Return a fresh default value for an absent optional field of type *t*."""
    # Array types are spelled "array" or "array<inner>", so match by prefix.
    if t and t.startswith("array"):
        return []
    # Factories (not literals) so each call yields a fresh object.
    factories = {
        "file": str,
        "string": str,
        "object": dict,
        "boolean": bool,
        "number": int,
    }
    factory = factories.get(t)
    # "null", unknown types and None all default to None.
    return factory() if factory else None
def auto_cast_value(value, expected_type):
    """Cast a string `value` to `expected_type` where possible.

    Non-string values are returned unchanged; "string"/"file"/unknown
    types keep the original (unstripped) string.
    Raises Exception when a string cannot be converted.
    """
    if not isinstance(value, str):
        # Non-string values are already in their final form.
        return value
    v = value.strip()
    # Boolean
    if expected_type == "boolean":
        lowered = v.lower()
        if lowered in ("true", "1"):
            return True
        if lowered in ("false", "0"):
            return False
        raise Exception(f"Cannot convert '{value}' to boolean")
    # Number: prefer int for pure-integer strings, fall back to float.
    if expected_type == "number":
        if v.isdigit() or (v.startswith("-") and v[1:].isdigit()):
            return int(v)
        try:
            return float(v)
        except ValueError:
            # BUG FIX: was a bare `except:`; catch only the parse failure.
            raise Exception(f"Cannot convert '{value}' to number")
    # Object: must parse as a JSON object.
    if expected_type == "object":
        try:
            parsed = json.loads(v)
        except ValueError:
            parsed = None
        if not isinstance(parsed, dict):
            raise Exception(f"Cannot convert '{value}' to object")
        return parsed
    # Array / array<T>: must parse as a JSON array.
    # Guard against a missing type declaration (expected_type is None)
    # before calling startswith — previously an AttributeError.
    if expected_type and expected_type.startswith("array"):
        try:
            parsed = json.loads(v)
        except ValueError:
            parsed = None
        if not isinstance(parsed, list):
            raise Exception(f"Cannot convert '{value}' to array")
        return parsed
    # "string", "file" and unknown types: keep the original value.
    return value
def validate_type(value, t):
    """Return True when `value` conforms to schema type `t`.

    Supports file/string/number/boolean/object and array / array<inner>;
    unknown or missing type names are accepted permissively.
    """
    if t is None:
        # No declared type -> accept anything (matches the permissive
        # fall-through below). Previously this raised AttributeError at
        # t.startswith(...).
        return True
    if t in ("file", "string"):
        return isinstance(value, str)
    if t == "number":
        # bool is a subclass of int in Python; do not accept True/False
        # where a number is declared.
        return isinstance(value, (int, float)) and not isinstance(value, bool)
    if t == "boolean":
        return isinstance(value, bool)
    if t == "object":
        return isinstance(value, dict)
    # array<string> / array<number> / array<object> / bare array
    if t.startswith("array"):
        if not isinstance(value, list):
            return False
        if "<" in t and ">" in t:
            inner = t[t.find("<") + 1 : t.find(">")]
            # Every element must match the declared inner type.
            return all(validate_type(item, inner) for item in value)
        return True
    # Unknown type names are accepted as-is.
    return True
def extract_files_by_schema(raw_files, schema, name="files"):
    """Collect uploaded files for schema fields declared as type "file".

    Returns a list of {"field": <name>, "file": <file_obj>} entries.
    Only single files are supported (no array<file>).
    Raises Exception when a required file field has no upload.
    """
    if schema.get("type") != "object":
        # Keep the return type consistent with the normal path (a list);
        # the caller only iterates the result.
        return []
    props = schema.get("properties", {})
    required = schema.get("required", [])
    cleaned = []
    for field, field_schema in props.items():
        # BUG FIX: the required-field check previously ran BEFORE the
        # type filter, so any required NON-file body field (e.g. a
        # required string) was wrongly reported as a missing file upload.
        # Only file-typed fields are relevant here.
        if field_schema.get("type") != "file":
            continue
        if field in required and field not in raw_files:
            raise Exception(f"{name} missing required file field: {field}")
        file_obj = raw_files.get(field)
        if file_obj:
            cleaned.append({
                "field": field,
                "file": file_obj
            })
    return cleaned
parsed = await parse_webhook_request()
SCHEMA = webhook_cfg.get("schema", {"query": {}, "headers": {}, "body": {}})
# Extract strictly by schema
query_clean = extract_by_schema(parsed["query"], SCHEMA.get("query", {}), name="query")
header_clean = extract_by_schema(parsed["headers"], SCHEMA.get("headers", {}), name="headers")
body_clean = extract_by_schema(parsed["body"], SCHEMA.get("body", {}), name="body")
files_clean = extract_files_by_schema(parsed["raw_files"], SCHEMA.get("body", {}), name="files")
uploaded_files = []
for item in files_clean: # each {field, file}
file_obj = item["file"]
desc = FileService.upload_info(
cvs.user_id, # user
file_obj, # FileStorage
None # url (None for webhook)
)
uploaded_files.append(desc)
clean_request = {
"query": query_clean,
"headers": header_clean,
"body": body_clean
}
if not isinstance(cvs.dsl, str):
dsl = json.dumps(cvs.dsl, ensure_ascii=False)
try:
canvas = Canvas(dsl, cvs.user_id, agent_id)
except Exception as e:
return get_json_result(
data=False, message=str(e),

View file

@ -1170,7 +1170,7 @@ async def mindmap():
search_id = req.get("search_id", "")
search_app = SearchService.get_detail(search_id) if search_id else {}
mind_map = gen_mindmap(req["question"], req["kb_ids"], tenant_id, search_app.get("search_config", {}))
mind_map =await gen_mindmap(req["question"], req["kb_ids"], tenant_id, search_app.get("search_config", {}))
if "error" in mind_map:
return server_error_response(Exception(mind_map["error"]))
return get_json_result(data=mind_map)

View file

@ -18,6 +18,7 @@ import time
from typing import Any, Dict, Optional
from urllib.parse import parse_qsl, urlencode, urlparse, urlunparse
from common import settings
import httpx
logger = logging.getLogger(__name__)
@ -73,6 +74,32 @@ def _redact_sensitive_url_params(url: str) -> str:
except Exception:
return url
def _is_sensitive_url(url: str) -> bool:
    """Return True if `url` points at one of the configured OAuth endpoints."""
    sensitive = set()

    # GitHub OAuth endpoint.
    try:
        github_cfg = settings.GITHUB_OAUTH
        if github_cfg is not None and github_cfg.get("url"):
            sensitive.add(github_cfg["url"])
    except Exception:
        # Settings may be absent or partially configured; treat that as
        # "no known sensitive endpoint" rather than failing the request.
        pass

    # Feishu OAuth endpoints.
    try:
        feishu_cfg = settings.FEISHU_OAUTH
        if feishu_cfg is not None:
            for key in ("app_access_token_url", "user_access_token_url"):
                candidate = feishu_cfg.get(key)
                if candidate:
                    sensitive.add(candidate)
    except Exception:
        pass

    # Defensive normalization: compare scheme + host + path only, so
    # query strings and fragments never affect the match.
    target = urlparse(url)
    target_key = (target.scheme, target.netloc, target.path)
    return any(
        target_key == (parsed.scheme, parsed.netloc, parsed.path)
        for parsed in map(urlparse, sensitive)
    )
async def async_request(
method: str,
@ -115,20 +142,23 @@ async def async_request(
method=method, url=url, headers=headers, **kwargs
)
duration = time.monotonic() - start
log_url = "<SENSITIVE ENDPOINT>" if _is_sensitive_url else _redact_sensitive_url_params(url)
logger.debug(
f"async_request {method} {_redact_sensitive_url_params(url)} -> {response.status_code} in {duration:.3f}s"
f"async_request {method} {log_url} -> {response.status_code} in {duration:.3f}s"
)
return response
except httpx.RequestError as exc:
last_exc = exc
if attempt >= retries:
log_url = "<SENSITIVE ENDPOINT>" if _is_sensitive_url else _redact_sensitive_url_params(url)
logger.warning(
f"async_request exhausted retries for {method} {_redact_sensitive_url_params(url)}: {exc}"
f"async_request exhausted retries for {method} {log_url}"
)
raise
delay = _get_delay(backoff_factor, attempt)
log_url = "<SENSITIVE ENDPOINT>" if _is_sensitive_url else _redact_sensitive_url_params(url)
logger.warning(
f"async_request attempt {attempt + 1}/{retries + 1} failed for {method} {_redact_sensitive_url_params(url)}: {exc}; retrying in {delay:.2f}s"
f"async_request attempt {attempt + 1}/{retries + 1} failed for {method} {log_url}; retrying in {delay:.2f}s"
)
await asyncio.sleep(delay)
raise last_exc # pragma: no cover

View file

@ -44,7 +44,7 @@ class KGSearch(Dealer):
return response
def query_rewrite(self, llm, question, idxnms, kb_ids):
ty2ents = asyncio.run(get_entity_type2samples(idxnms, kb_ids))
ty2ents = get_entity_type2samples(idxnms, kb_ids)
hint_prompt = PROMPTS["minirag_query2kwd"].format(query=question,
TYPE_POOL=json.dumps(ty2ents, ensure_ascii=False, indent=2))
result = self._chat(llm, hint_prompt, [{"role": "user", "content": "Output:"}], {})

View file

@ -626,8 +626,8 @@ def merge_tuples(list1, list2):
return result
async def get_entity_type2samples(idxnms, kb_ids: list):
es_res = await asyncio.to_thread(settings.retriever.search,{"knowledge_graph_kwd": "ty2ents", "kb_id": kb_ids, "size": 10000, "fields": ["content_with_weight"]},idxnms,kb_ids)
def get_entity_type2samples(idxnms, kb_ids: list):
es_res = settings.retriever.search({"knowledge_graph_kwd": "ty2ents", "kb_id": kb_ids, "size": 10000, "fields": ["content_with_weight"]},idxnms,kb_ids)
res = defaultdict(list)
for id in es_res.ids:

View file

@ -41,13 +41,9 @@ def get_opendal_config():
scheme = opendal_config.get("scheme")
config_data = opendal_config.get("config", {})
kwargs = {"scheme": scheme, **config_data}
redacted_kwargs = kwargs.copy()
if 'password' in redacted_kwargs:
redacted_kwargs['password'] = '***REDACTED***'
if 'connection_string' in redacted_kwargs and 'password' in redacted_kwargs:
import re
redacted_kwargs['connection_string'] = re.sub(r':[^@]+@', ':***REDACTED***@', redacted_kwargs['connection_string'])
logging.info("Loaded OpenDAL configuration from yaml: %s", redacted_kwargs)
safe_log_keys=['scheme', 'host', 'port', 'database', 'table']
loggable_kwargs = {k: v for k, v in kwargs.items() if k in safe_log_keys}
logging.info("Loaded OpenDAL configuration(non sensitive): %s", loggable_kwargs)
return kwargs
except Exception as e:
logging.error("Failed to load OpenDAL configuration from yaml: %s", str(e))