fix:async issue and sensitive logging
This commit is contained in:
parent
c610bb605a
commit
c6ca4a08f2
6 changed files with 527 additions and 27 deletions
|
|
@ -132,19 +132,14 @@ def delete_agent(tenant_id: str, agent_id: str):
|
||||||
UserCanvasService.delete_by_id(agent_id)
|
UserCanvasService.delete_by_id(agent_id)
|
||||||
return get_json_result(data=True)
|
return get_json_result(data=True)
|
||||||
|
|
||||||
|
_rate_limit_cache = {}
|
||||||
|
|
||||||
@manager.route('/webhook/<agent_id>', methods=['POST']) # noqa: F821
|
@manager.route('/webhook/<agent_id>', methods=["POST", "GET", "PUT", "PATCH", "DELETE", "HEAD"]) # noqa: F821
|
||||||
@token_required
|
async def webhook(agent_id: str):
|
||||||
async def webhook(tenant_id: str, agent_id: str):
|
# 1. Fetch canvas by agent_id
|
||||||
req = await get_request_json()
|
exists, cvs = UserCanvasService.get_by_id(agent_id)
|
||||||
if not UserCanvasService.accessible(req["id"], tenant_id):
|
if not exists:
|
||||||
return get_json_result(
|
return get_data_error_result(message="Canvas not found.")
|
||||||
data=False, message='Only owner of canvas authorized for this operation.',
|
|
||||||
code=RetCode.OPERATING_ERROR)
|
|
||||||
|
|
||||||
e, cvs = UserCanvasService.get_by_id(req["id"])
|
|
||||||
if not e:
|
|
||||||
return get_data_error_result(message="canvas not found.")
|
|
||||||
|
|
||||||
if not isinstance(cvs.dsl, str):
|
if not isinstance(cvs.dsl, str):
|
||||||
cvs.dsl = json.dumps(cvs.dsl, ensure_ascii=False)
|
cvs.dsl = json.dumps(cvs.dsl, ensure_ascii=False)
|
||||||
|
|
@ -152,8 +147,487 @@ async def webhook(tenant_id: str, agent_id: str):
|
||||||
if cvs.canvas_category == CanvasCategory.DataFlow:
|
if cvs.canvas_category == CanvasCategory.DataFlow:
|
||||||
return get_data_error_result(message="Dataflow can not be triggered by webhook.")
|
return get_data_error_result(message="Dataflow can not be triggered by webhook.")
|
||||||
|
|
||||||
|
# 3. Load DSL from canvas
|
||||||
|
dsl = getattr(cvs, "dsl", None)
|
||||||
|
if not isinstance(dsl, dict):
|
||||||
|
return get_data_error_result(message="Invalid DSL format.")
|
||||||
|
|
||||||
|
# 4. Check webhook configuration in DSL
|
||||||
|
components = dsl.get("webhook", {})
|
||||||
|
for k, cpn in components.items():
|
||||||
|
cpn_obj = components[k]["obj"]
|
||||||
|
if cpn_obj.component_name.lower() == "begin" and cpn.params.mode == "webhook":
|
||||||
|
webhook_cfg = cpn.params
|
||||||
|
|
||||||
|
if not webhook_cfg:
|
||||||
|
return get_data_error_result(message="Webhook not configured for this agent.")
|
||||||
|
|
||||||
|
# 5. Validate request method against webhook_cfg.methods
|
||||||
|
allowed_methods = webhook_cfg.get("methods", [])
|
||||||
|
request_method = request.method.upper()
|
||||||
|
if allowed_methods and request_method not in allowed_methods:
|
||||||
|
return get_data_error_result(
|
||||||
|
message=f"HTTP method '{request_method}' not allowed for this webhook."
|
||||||
|
)
|
||||||
|
|
||||||
|
# 6. Validate webhook security
|
||||||
|
async def validate_webhook_security(security_cfg: dict):
|
||||||
|
"""Validate webhook security rules based on security configuration."""
|
||||||
|
|
||||||
|
if not security_cfg:
|
||||||
|
return # No security config → allowed by default
|
||||||
|
|
||||||
|
# 1. Validate max body size
|
||||||
|
await _validate_max_body_size(security_cfg)
|
||||||
|
|
||||||
|
# 2. Validate IP whitelist
|
||||||
|
_validate_ip_whitelist(security_cfg)
|
||||||
|
|
||||||
|
# # 3. Validate rate limiting
|
||||||
|
_validate_rate_limit(security_cfg)
|
||||||
|
|
||||||
|
# 4. Validate authentication
|
||||||
|
auth_type = security_cfg.get("auth_type", "none")
|
||||||
|
|
||||||
|
if auth_type == "none":
|
||||||
|
return
|
||||||
|
|
||||||
|
if auth_type == "token":
|
||||||
|
_validate_token_auth(security_cfg)
|
||||||
|
|
||||||
|
elif auth_type == "basic":
|
||||||
|
_validate_basic_auth(security_cfg)
|
||||||
|
|
||||||
|
elif auth_type == "jwt":
|
||||||
|
_validate_jwt_auth(security_cfg)
|
||||||
|
|
||||||
|
elif auth_type == "hmac":
|
||||||
|
await _validate_hmac_auth(security_cfg)
|
||||||
|
|
||||||
|
else:
|
||||||
|
raise Exception(f"Unsupported auth_type: {auth_type}")
|
||||||
|
|
||||||
|
async def _validate_max_body_size(security_cfg):
|
||||||
|
"""Check request size does not exceed max_body_size."""
|
||||||
|
max_size = security_cfg.get("max_body_size")
|
||||||
|
if not max_size:
|
||||||
|
return
|
||||||
|
|
||||||
|
# Convert "10MB" → bytes
|
||||||
|
units = {"KB": 1024, "MB": 1024**2, "GB": 1024**3}
|
||||||
|
size_str = max_size.lower()
|
||||||
|
|
||||||
|
for suffix, factor in units.items():
|
||||||
|
if size_str.endswith(suffix):
|
||||||
|
limit = int(size_str.replace(suffix, "")) * factor
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
raise Exception("Invalid max_body_size format")
|
||||||
|
|
||||||
|
content_length = request.content_length or 0
|
||||||
|
if content_length > limit:
|
||||||
|
raise Exception(f"Request body too large: {content_length} > {limit}")
|
||||||
|
|
||||||
|
def _validate_ip_whitelist(security_cfg):
|
||||||
|
"""Allow only IPs listed in ip_whitelist."""
|
||||||
|
whitelist = security_cfg.get("ip_whitelist", [])
|
||||||
|
if not whitelist:
|
||||||
|
return
|
||||||
|
|
||||||
|
client_ip = request.remote_addr
|
||||||
|
|
||||||
|
|
||||||
|
for rule in whitelist:
|
||||||
|
if "/" in rule:
|
||||||
|
# CIDR notation
|
||||||
|
if ipaddress.ip_address(client_ip) in ipaddress.ip_network(rule, strict=False):
|
||||||
|
return
|
||||||
|
else:
|
||||||
|
# Single IP
|
||||||
|
if client_ip == rule:
|
||||||
|
return
|
||||||
|
|
||||||
|
raise Exception(f"IP {client_ip} is not allowed by whitelist")
|
||||||
|
|
||||||
|
def _validate_rate_limit(security_cfg):
|
||||||
|
"""Simple in-memory rate limiting."""
|
||||||
|
rl = security_cfg.get("rate_limit")
|
||||||
|
if not rl:
|
||||||
|
return
|
||||||
|
|
||||||
|
limit = rl.get("limit", 60)
|
||||||
|
per = rl.get("per", "minute")
|
||||||
|
|
||||||
|
window = {"second": 1, "minute": 60, "hour": 3600, "day": 86400}.get(per, 60)
|
||||||
|
key = f"rl:{agent_id}"
|
||||||
|
|
||||||
|
now = int(time.time())
|
||||||
|
bucket = _rate_limit_cache.get(key, {"ts": now, "count": 0})
|
||||||
|
|
||||||
|
# Reset window
|
||||||
|
if now - bucket["ts"] > window:
|
||||||
|
bucket = {"ts": now, "count": 0}
|
||||||
|
|
||||||
|
bucket["count"] += 1
|
||||||
|
_rate_limit_cache[key] = bucket
|
||||||
|
|
||||||
|
if bucket["count"] > limit:
|
||||||
|
raise Exception("Too many requests (rate limit exceeded)")
|
||||||
|
|
||||||
|
def _validate_token_auth(security_cfg):
|
||||||
|
"""Validate header-based token authentication."""
|
||||||
|
token_cfg = security_cfg.get("token",{})
|
||||||
|
header = token_cfg.get("token_header")
|
||||||
|
token_value = token_cfg.get("token_value")
|
||||||
|
|
||||||
|
provided = request.headers.get(header)
|
||||||
|
if provided != token_value:
|
||||||
|
raise Exception("Invalid token authentication")
|
||||||
|
|
||||||
|
def _validate_basic_auth(security_cfg):
|
||||||
|
"""Validate HTTP Basic Auth credentials."""
|
||||||
|
auth_cfg = security_cfg.get("basic_auth", {})
|
||||||
|
username = auth_cfg.get("username")
|
||||||
|
password = auth_cfg.get("password")
|
||||||
|
|
||||||
|
auth = request.authorization
|
||||||
|
if not auth or auth.username != username or auth.password != password:
|
||||||
|
raise Exception("Invalid Basic Auth credentials")
|
||||||
|
|
||||||
|
def _validate_jwt_auth(security_cfg):
|
||||||
|
"""Validate JWT token in Authorization header."""
|
||||||
|
jwt_cfg = security_cfg.get("jwt", {})
|
||||||
|
secret = jwt_cfg.get("secret")
|
||||||
|
required_claims = jwt_cfg.get("required_claims", [])
|
||||||
|
|
||||||
|
auth_header = request.headers.get("Authorization", "")
|
||||||
|
if not auth_header.startswith("Bearer "):
|
||||||
|
raise Exception("Missing Bearer token")
|
||||||
|
|
||||||
|
token = auth_header.replace("Bearer ", "")
|
||||||
|
|
||||||
|
try:
|
||||||
|
decoded = jwt.decode(
|
||||||
|
token,
|
||||||
|
secret,
|
||||||
|
algorithms=[jwt_cfg.get("algorithm", "HS256")],
|
||||||
|
audience=jwt_cfg.get("audience"),
|
||||||
|
issuer=jwt_cfg.get("issuer"),
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
raise Exception(f"Invalid JWT: {str(e)}")
|
||||||
|
|
||||||
|
for claim in required_claims:
|
||||||
|
if claim not in decoded:
|
||||||
|
raise Exception(f"Missing JWT claim: {claim}")
|
||||||
|
|
||||||
|
async def _validate_hmac_auth(security_cfg):
|
||||||
|
"""Validate HMAC signature from header."""
|
||||||
|
hmac_cfg = security_cfg.get("hmac", {})
|
||||||
|
header = hmac_cfg.get("header")
|
||||||
|
secret = hmac_cfg.get("secret")
|
||||||
|
algorithm = hmac_cfg.get("algorithm", "sha256")
|
||||||
|
|
||||||
|
provided_sig = request.headers.get(header)
|
||||||
|
if not provided_sig:
|
||||||
|
raise Exception("Missing HMAC signature header")
|
||||||
|
|
||||||
|
body = await request.get_data()
|
||||||
|
if body is None:
|
||||||
|
body = b""
|
||||||
|
elif isinstance(body, str):
|
||||||
|
body = body.encode("utf-8")
|
||||||
|
|
||||||
|
computed = hmac.new(secret.encode(), body, getattr(hashlib, algorithm)).hexdigest()
|
||||||
|
|
||||||
|
if not hmac.compare_digest(provided_sig, computed):
|
||||||
|
raise Exception("Invalid HMAC signature")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
canvas = Canvas(cvs.dsl, tenant_id, agent_id)
|
await validate_webhook_security(webhook_cfg.get("security", {}))
|
||||||
|
except Exception as e:
|
||||||
|
return get_data_error_result(message=str(e))
|
||||||
|
|
||||||
|
# 7. Parse request body
|
||||||
|
async def parse_webhook_request():
|
||||||
|
"""Parse request based on content-type and return structured data."""
|
||||||
|
|
||||||
|
# 1. Parse query parameters
|
||||||
|
query_data = {}
|
||||||
|
for k, v in request.args.items():
|
||||||
|
query_data[k] = v
|
||||||
|
|
||||||
|
# 2. Parse headers
|
||||||
|
header_data = {}
|
||||||
|
for k, v in request.headers.items():
|
||||||
|
header_data[k] = v
|
||||||
|
|
||||||
|
# 3. Parse body based on content-type
|
||||||
|
ctype = request.headers.get("Content-Type", "").split(";")[0].strip()
|
||||||
|
raw_files = {}
|
||||||
|
|
||||||
|
if ctype == "application/json":
|
||||||
|
try:
|
||||||
|
body_data = await request.get_json()
|
||||||
|
except:
|
||||||
|
body_data = None
|
||||||
|
|
||||||
|
elif ctype == "multipart/form-data":
|
||||||
|
form = await request.form
|
||||||
|
files = await request.files
|
||||||
|
raw_files = {name: file for name, file in files.items()}
|
||||||
|
body_data = {
|
||||||
|
"form": dict(form),
|
||||||
|
"files": {name: file.filename for name, file in files.items()},
|
||||||
|
}
|
||||||
|
|
||||||
|
elif ctype == "application/x-www-form-urlencoded":
|
||||||
|
form = await request.form
|
||||||
|
body_data = dict(form)
|
||||||
|
|
||||||
|
elif ctype == "text/plain":
|
||||||
|
body_data = (await request.get_data()).decode()
|
||||||
|
|
||||||
|
elif ctype == "application/octet-stream":
|
||||||
|
body_data = await request.get_data() # raw binary
|
||||||
|
|
||||||
|
else:
|
||||||
|
# unknown content type → raw body
|
||||||
|
body_data = await request.get_data()
|
||||||
|
|
||||||
|
return {
|
||||||
|
"query": query_data,
|
||||||
|
"headers": header_data,
|
||||||
|
"body": body_data,
|
||||||
|
"content_type": ctype,
|
||||||
|
"raw_files": raw_files
|
||||||
|
}
|
||||||
|
|
||||||
|
def extract_by_schema(data, schema, name="section"):
|
||||||
|
"""
|
||||||
|
Extract only fields defined in schema.
|
||||||
|
Required fields must exist.
|
||||||
|
Optional fields default to type-based default values.
|
||||||
|
Type validation included.
|
||||||
|
"""
|
||||||
|
if schema.get("type") != "object":
|
||||||
|
return {}
|
||||||
|
|
||||||
|
props = schema.get("properties", {})
|
||||||
|
required = schema.get("required", [])
|
||||||
|
|
||||||
|
extracted = {}
|
||||||
|
|
||||||
|
for field, field_schema in props.items():
|
||||||
|
field_type = field_schema.get("type")
|
||||||
|
|
||||||
|
# 1. Required field missing
|
||||||
|
if field in required and field not in data:
|
||||||
|
raise Exception(f"{name} missing required field: {field}")
|
||||||
|
|
||||||
|
# 2. Optional → default value
|
||||||
|
if field not in data:
|
||||||
|
extracted[field] = default_for_type(field_type)
|
||||||
|
continue
|
||||||
|
|
||||||
|
raw_value = data[field]
|
||||||
|
|
||||||
|
# 3. Auto convert value
|
||||||
|
try:
|
||||||
|
value = auto_cast_value(raw_value, field_type)
|
||||||
|
except Exception as e:
|
||||||
|
raise Exception(f"{name}.{field} auto-cast failed: {str(e)}")
|
||||||
|
|
||||||
|
# 4. Type validation
|
||||||
|
if not validate_type(value, field_type):
|
||||||
|
raise Exception(
|
||||||
|
f"{name}.{field} type mismatch: expected {field_type}, got {type(value).__name__}"
|
||||||
|
)
|
||||||
|
|
||||||
|
extracted[field] = value
|
||||||
|
|
||||||
|
return extracted
|
||||||
|
|
||||||
|
|
||||||
|
def default_for_type(t):
|
||||||
|
"""Return default value for the given schema type."""
|
||||||
|
if t == "file":
|
||||||
|
return ""
|
||||||
|
if t == "object":
|
||||||
|
return {}
|
||||||
|
if t == "boolean":
|
||||||
|
return False
|
||||||
|
if t == "number":
|
||||||
|
return 0
|
||||||
|
if t == "string":
|
||||||
|
return ""
|
||||||
|
if t and t.startswith("array"):
|
||||||
|
return []
|
||||||
|
if t == "null":
|
||||||
|
return None
|
||||||
|
return None
|
||||||
|
|
||||||
|
def auto_cast_value(value, expected_type):
|
||||||
|
"""Convert string values into schema type when possible."""
|
||||||
|
|
||||||
|
# Non-string values already good
|
||||||
|
if not isinstance(value, str):
|
||||||
|
return value
|
||||||
|
|
||||||
|
v = value.strip()
|
||||||
|
|
||||||
|
# Boolean
|
||||||
|
if expected_type == "boolean":
|
||||||
|
if v.lower() in ["true", "1"]:
|
||||||
|
return True
|
||||||
|
if v.lower() in ["false", "0"]:
|
||||||
|
return False
|
||||||
|
raise Exception(f"Cannot convert '{value}' to boolean")
|
||||||
|
|
||||||
|
# Number
|
||||||
|
if expected_type == "number":
|
||||||
|
# integer
|
||||||
|
if v.isdigit() or (v.startswith("-") and v[1:].isdigit()):
|
||||||
|
return int(v)
|
||||||
|
|
||||||
|
# float
|
||||||
|
try:
|
||||||
|
return float(v)
|
||||||
|
except:
|
||||||
|
raise Exception(f"Cannot convert '{value}' to number")
|
||||||
|
|
||||||
|
# Object
|
||||||
|
if expected_type == "object":
|
||||||
|
try:
|
||||||
|
parsed = json.loads(v)
|
||||||
|
if isinstance(parsed, dict):
|
||||||
|
return parsed
|
||||||
|
else:
|
||||||
|
raise Exception("JSON is not an object")
|
||||||
|
except:
|
||||||
|
raise Exception(f"Cannot convert '{value}' to object")
|
||||||
|
|
||||||
|
# Array <T>
|
||||||
|
if expected_type.startswith("array"):
|
||||||
|
try:
|
||||||
|
parsed = json.loads(v)
|
||||||
|
if isinstance(parsed, list):
|
||||||
|
return parsed
|
||||||
|
else:
|
||||||
|
raise Exception("JSON is not an array")
|
||||||
|
except:
|
||||||
|
raise Exception(f"Cannot convert '{value}' to array")
|
||||||
|
|
||||||
|
# String (accept original)
|
||||||
|
if expected_type == "string":
|
||||||
|
return value
|
||||||
|
|
||||||
|
# File
|
||||||
|
if expected_type == "file":
|
||||||
|
return value
|
||||||
|
# Default: do nothing
|
||||||
|
return value
|
||||||
|
|
||||||
|
|
||||||
|
def validate_type(value, t):
|
||||||
|
"""Validate value type against schema type t."""
|
||||||
|
if t == "file":
|
||||||
|
return isinstance(value, str)
|
||||||
|
|
||||||
|
if t == "string":
|
||||||
|
return isinstance(value, str)
|
||||||
|
|
||||||
|
if t == "number":
|
||||||
|
return isinstance(value, (int, float))
|
||||||
|
|
||||||
|
if t == "boolean":
|
||||||
|
return isinstance(value, bool)
|
||||||
|
|
||||||
|
if t == "object":
|
||||||
|
return isinstance(value, dict)
|
||||||
|
|
||||||
|
# array<string> / array<number> / array<object>
|
||||||
|
if t.startswith("array"):
|
||||||
|
if not isinstance(value, list):
|
||||||
|
return False
|
||||||
|
|
||||||
|
if "<" in t and ">" in t:
|
||||||
|
inner = t[t.find("<") + 1 : t.find(">")]
|
||||||
|
|
||||||
|
# Check each element type
|
||||||
|
for item in value:
|
||||||
|
if not validate_type(item, inner):
|
||||||
|
return False
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
def extract_files_by_schema(raw_files, schema, name="files"):
|
||||||
|
"""
|
||||||
|
Extract and validate files based on schema.
|
||||||
|
Only supports type = file (single file).
|
||||||
|
Does NOT support array<file>.
|
||||||
|
"""
|
||||||
|
|
||||||
|
if schema.get("type") != "object":
|
||||||
|
return {}
|
||||||
|
|
||||||
|
props = schema.get("properties", {})
|
||||||
|
required = schema.get("required", [])
|
||||||
|
|
||||||
|
cleaned = []
|
||||||
|
|
||||||
|
for field, field_schema in props.items():
|
||||||
|
field_type = field_schema.get("type")
|
||||||
|
|
||||||
|
# 1. Required field must exist
|
||||||
|
if field in required and field not in raw_files:
|
||||||
|
raise Exception(f"{name} missing required file field: {field}")
|
||||||
|
|
||||||
|
# 2. Ignore fields that are not file
|
||||||
|
if field_type != "file":
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 3. Extract single file
|
||||||
|
file_obj = raw_files.get(field)
|
||||||
|
|
||||||
|
if file_obj:
|
||||||
|
cleaned.append({
|
||||||
|
"field": field,
|
||||||
|
"file": file_obj
|
||||||
|
})
|
||||||
|
return cleaned
|
||||||
|
|
||||||
|
parsed = await parse_webhook_request()
|
||||||
|
SCHEMA = webhook_cfg.get("schema", {"query": {}, "headers": {}, "body": {}})
|
||||||
|
|
||||||
|
# Extract strictly by schema
|
||||||
|
query_clean = extract_by_schema(parsed["query"], SCHEMA.get("query", {}), name="query")
|
||||||
|
header_clean = extract_by_schema(parsed["headers"], SCHEMA.get("headers", {}), name="headers")
|
||||||
|
body_clean = extract_by_schema(parsed["body"], SCHEMA.get("body", {}), name="body")
|
||||||
|
files_clean = extract_files_by_schema(parsed["raw_files"], SCHEMA.get("body", {}), name="files")
|
||||||
|
|
||||||
|
uploaded_files = []
|
||||||
|
for item in files_clean: # each {field, file}
|
||||||
|
file_obj = item["file"]
|
||||||
|
desc = FileService.upload_info(
|
||||||
|
cvs.user_id, # user
|
||||||
|
file_obj, # FileStorage
|
||||||
|
None # url (None for webhook)
|
||||||
|
)
|
||||||
|
uploaded_files.append(desc)
|
||||||
|
|
||||||
|
clean_request = {
|
||||||
|
"query": query_clean,
|
||||||
|
"headers": header_clean,
|
||||||
|
"body": body_clean
|
||||||
|
}
|
||||||
|
|
||||||
|
if not isinstance(cvs.dsl, str):
|
||||||
|
dsl = json.dumps(cvs.dsl, ensure_ascii=False)
|
||||||
|
try:
|
||||||
|
canvas = Canvas(dsl, cvs.user_id, agent_id)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return get_json_result(
|
return get_json_result(
|
||||||
data=False, message=str(e),
|
data=False, message=str(e),
|
||||||
|
|
|
||||||
|
|
@ -1170,7 +1170,7 @@ async def mindmap():
|
||||||
search_id = req.get("search_id", "")
|
search_id = req.get("search_id", "")
|
||||||
search_app = SearchService.get_detail(search_id) if search_id else {}
|
search_app = SearchService.get_detail(search_id) if search_id else {}
|
||||||
|
|
||||||
mind_map = gen_mindmap(req["question"], req["kb_ids"], tenant_id, search_app.get("search_config", {}))
|
mind_map =await gen_mindmap(req["question"], req["kb_ids"], tenant_id, search_app.get("search_config", {}))
|
||||||
if "error" in mind_map:
|
if "error" in mind_map:
|
||||||
return server_error_response(Exception(mind_map["error"]))
|
return server_error_response(Exception(mind_map["error"]))
|
||||||
return get_json_result(data=mind_map)
|
return get_json_result(data=mind_map)
|
||||||
|
|
|
||||||
|
|
@ -18,6 +18,7 @@ import time
|
||||||
from typing import Any, Dict, Optional
|
from typing import Any, Dict, Optional
|
||||||
from urllib.parse import parse_qsl, urlencode, urlparse, urlunparse
|
from urllib.parse import parse_qsl, urlencode, urlparse, urlunparse
|
||||||
|
|
||||||
|
from common import settings
|
||||||
import httpx
|
import httpx
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
@ -73,6 +74,32 @@ def _redact_sensitive_url_params(url: str) -> str:
|
||||||
except Exception:
|
except Exception:
|
||||||
return url
|
return url
|
||||||
|
|
||||||
|
def _is_sensitive_url(url: str) -> bool:
|
||||||
|
"""Return True if URL is one of the configured OAuth endpoints."""
|
||||||
|
# Collect known sensitive endpoint URLs from settings
|
||||||
|
oauth_urls = set()
|
||||||
|
# GitHub OAuth endpoints
|
||||||
|
try:
|
||||||
|
if settings.GITHUB_OAUTH is not None:
|
||||||
|
url_val = settings.GITHUB_OAUTH.get("url")
|
||||||
|
if url_val: oauth_urls.add(url_val)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
# Feishu OAuth endpoints
|
||||||
|
try:
|
||||||
|
if settings.FEISHU_OAUTH is not None:
|
||||||
|
for k in ("app_access_token_url", "user_access_token_url"):
|
||||||
|
url_val = settings.FEISHU_OAUTH.get(k)
|
||||||
|
if url_val: oauth_urls.add(url_val)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
# Defensive normalization: compare only scheme+netloc+path
|
||||||
|
url_obj = urlparse(url)
|
||||||
|
for sensitive_url in oauth_urls:
|
||||||
|
sensitive_obj = urlparse(sensitive_url)
|
||||||
|
if (url_obj.scheme, url_obj.netloc, url_obj.path) == (sensitive_obj.scheme, sensitive_obj.netloc, sensitive_obj.path):
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
async def async_request(
|
async def async_request(
|
||||||
method: str,
|
method: str,
|
||||||
|
|
@ -115,20 +142,23 @@ async def async_request(
|
||||||
method=method, url=url, headers=headers, **kwargs
|
method=method, url=url, headers=headers, **kwargs
|
||||||
)
|
)
|
||||||
duration = time.monotonic() - start
|
duration = time.monotonic() - start
|
||||||
|
log_url = "<SENSITIVE ENDPOINT>" if _is_sensitive_url else _redact_sensitive_url_params(url)
|
||||||
logger.debug(
|
logger.debug(
|
||||||
f"async_request {method} {_redact_sensitive_url_params(url)} -> {response.status_code} in {duration:.3f}s"
|
f"async_request {method} {log_url} -> {response.status_code} in {duration:.3f}s"
|
||||||
)
|
)
|
||||||
return response
|
return response
|
||||||
except httpx.RequestError as exc:
|
except httpx.RequestError as exc:
|
||||||
last_exc = exc
|
last_exc = exc
|
||||||
if attempt >= retries:
|
if attempt >= retries:
|
||||||
|
log_url = "<SENSITIVE ENDPOINT>" if _is_sensitive_url else _redact_sensitive_url_params(url)
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"async_request exhausted retries for {method} {_redact_sensitive_url_params(url)}: {exc}"
|
f"async_request exhausted retries for {method} {log_url}"
|
||||||
)
|
)
|
||||||
raise
|
raise
|
||||||
delay = _get_delay(backoff_factor, attempt)
|
delay = _get_delay(backoff_factor, attempt)
|
||||||
|
log_url = "<SENSITIVE ENDPOINT>" if _is_sensitive_url else _redact_sensitive_url_params(url)
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"async_request attempt {attempt + 1}/{retries + 1} failed for {method} {_redact_sensitive_url_params(url)}: {exc}; retrying in {delay:.2f}s"
|
f"async_request attempt {attempt + 1}/{retries + 1} failed for {method} {log_url}; retrying in {delay:.2f}s"
|
||||||
)
|
)
|
||||||
await asyncio.sleep(delay)
|
await asyncio.sleep(delay)
|
||||||
raise last_exc # pragma: no cover
|
raise last_exc # pragma: no cover
|
||||||
|
|
|
||||||
|
|
@ -44,7 +44,7 @@ class KGSearch(Dealer):
|
||||||
return response
|
return response
|
||||||
|
|
||||||
def query_rewrite(self, llm, question, idxnms, kb_ids):
|
def query_rewrite(self, llm, question, idxnms, kb_ids):
|
||||||
ty2ents = asyncio.run(get_entity_type2samples(idxnms, kb_ids))
|
ty2ents = get_entity_type2samples(idxnms, kb_ids)
|
||||||
hint_prompt = PROMPTS["minirag_query2kwd"].format(query=question,
|
hint_prompt = PROMPTS["minirag_query2kwd"].format(query=question,
|
||||||
TYPE_POOL=json.dumps(ty2ents, ensure_ascii=False, indent=2))
|
TYPE_POOL=json.dumps(ty2ents, ensure_ascii=False, indent=2))
|
||||||
result = self._chat(llm, hint_prompt, [{"role": "user", "content": "Output:"}], {})
|
result = self._chat(llm, hint_prompt, [{"role": "user", "content": "Output:"}], {})
|
||||||
|
|
|
||||||
|
|
@ -626,8 +626,8 @@ def merge_tuples(list1, list2):
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
async def get_entity_type2samples(idxnms, kb_ids: list):
|
def get_entity_type2samples(idxnms, kb_ids: list):
|
||||||
es_res = await asyncio.to_thread(settings.retriever.search,{"knowledge_graph_kwd": "ty2ents", "kb_id": kb_ids, "size": 10000, "fields": ["content_with_weight"]},idxnms,kb_ids)
|
es_res = settings.retriever.search({"knowledge_graph_kwd": "ty2ents", "kb_id": kb_ids, "size": 10000, "fields": ["content_with_weight"]},idxnms,kb_ids)
|
||||||
|
|
||||||
res = defaultdict(list)
|
res = defaultdict(list)
|
||||||
for id in es_res.ids:
|
for id in es_res.ids:
|
||||||
|
|
|
||||||
|
|
@ -41,13 +41,9 @@ def get_opendal_config():
|
||||||
scheme = opendal_config.get("scheme")
|
scheme = opendal_config.get("scheme")
|
||||||
config_data = opendal_config.get("config", {})
|
config_data = opendal_config.get("config", {})
|
||||||
kwargs = {"scheme": scheme, **config_data}
|
kwargs = {"scheme": scheme, **config_data}
|
||||||
redacted_kwargs = kwargs.copy()
|
safe_log_keys=['scheme', 'host', 'port', 'database', 'table']
|
||||||
if 'password' in redacted_kwargs:
|
loggable_kwargs = {k: v for k, v in kwargs.items() if k in safe_log_keys}
|
||||||
redacted_kwargs['password'] = '***REDACTED***'
|
logging.info("Loaded OpenDAL configuration(non sensitive): %s", loggable_kwargs)
|
||||||
if 'connection_string' in redacted_kwargs and 'password' in redacted_kwargs:
|
|
||||||
import re
|
|
||||||
redacted_kwargs['connection_string'] = re.sub(r':[^@]+@', ':***REDACTED***@', redacted_kwargs['connection_string'])
|
|
||||||
logging.info("Loaded OpenDAL configuration from yaml: %s", redacted_kwargs)
|
|
||||||
return kwargs
|
return kwargs
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.error("Failed to load OpenDAL configuration from yaml: %s", str(e))
|
logging.error("Failed to load OpenDAL configuration from yaml: %s", str(e))
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue