diff --git a/api/apps/sdk/agents.py b/api/apps/sdk/agents.py index 20e897388..6d3660ecf 100644 --- a/api/apps/sdk/agents.py +++ b/api/apps/sdk/agents.py @@ -132,19 +132,14 @@ def delete_agent(tenant_id: str, agent_id: str): UserCanvasService.delete_by_id(agent_id) return get_json_result(data=True) +_rate_limit_cache = {} -@manager.route('/webhook/', methods=['POST']) # noqa: F821 -@token_required -async def webhook(tenant_id: str, agent_id: str): - req = await get_request_json() - if not UserCanvasService.accessible(req["id"], tenant_id): - return get_json_result( - data=False, message='Only owner of canvas authorized for this operation.', - code=RetCode.OPERATING_ERROR) - - e, cvs = UserCanvasService.get_by_id(req["id"]) - if not e: - return get_data_error_result(message="canvas not found.") +@manager.route('/webhook/', methods=["POST", "GET", "PUT", "PATCH", "DELETE", "HEAD"]) # noqa: F821 +async def webhook(agent_id: str): + # 1. Fetch canvas by agent_id + exists, cvs = UserCanvasService.get_by_id(agent_id) + if not exists: + return get_data_error_result(message="Canvas not found.") if not isinstance(cvs.dsl, str): cvs.dsl = json.dumps(cvs.dsl, ensure_ascii=False) @@ -152,8 +147,487 @@ async def webhook(tenant_id: str, agent_id: str): if cvs.canvas_category == CanvasCategory.DataFlow: return get_data_error_result(message="Dataflow can not be triggered by webhook.") + # 3. Load DSL from canvas + dsl = getattr(cvs, "dsl", None) + if not isinstance(dsl, dict): + return get_data_error_result(message="Invalid DSL format.") + + webhook_cfg = None # 4. Check webhook configuration in DSL + components = dsl.get("webhook", {}) + for k, cpn in components.items(): + cpn_obj = components[k]["obj"] + if cpn_obj.component_name.lower() == "begin" and cpn_obj.params.mode == "webhook": + webhook_cfg = cpn_obj.params + + if not webhook_cfg: + return get_data_error_result(message="Webhook not configured for this agent.") + + # 5. 
Validate request method against webhook_cfg.methods + allowed_methods = webhook_cfg.get("methods", []) + request_method = request.method.upper() + if allowed_methods and request_method not in allowed_methods: + return get_data_error_result( + message=f"HTTP method '{request_method}' not allowed for this webhook." + ) + + # 6. Validate webhook security + async def validate_webhook_security(security_cfg: dict): + """Validate webhook security rules based on security configuration.""" + + if not security_cfg: + return # No security config → allowed by default + + # 1. Validate max body size + await _validate_max_body_size(security_cfg) + + # 2. Validate IP whitelist + _validate_ip_whitelist(security_cfg) + + # 3. Validate rate limiting + _validate_rate_limit(security_cfg) + + # 4. Validate authentication + auth_type = security_cfg.get("auth_type", "none") + + if auth_type == "none": + return + + if auth_type == "token": + _validate_token_auth(security_cfg) + + elif auth_type == "basic": + _validate_basic_auth(security_cfg) + + elif auth_type == "jwt": + _validate_jwt_auth(security_cfg) + + elif auth_type == "hmac": + await _validate_hmac_auth(security_cfg) + + else: + raise Exception(f"Unsupported auth_type: {auth_type}") + + async def _validate_max_body_size(security_cfg): + """Check request size does not exceed max_body_size.""" + max_size = security_cfg.get("max_body_size") + if not max_size: + return + + # Convert "10MB" → bytes + units = {"KB": 1024, "MB": 1024**2, "GB": 1024**3} + size_str = max_size.upper() + + for suffix, factor in units.items(): + if size_str.endswith(suffix): + limit = int(size_str.replace(suffix, "")) * factor + break + else: + raise Exception("Invalid max_body_size format") + + content_length = request.content_length or 0 + if content_length > limit: + raise Exception(f"Request body too large: {content_length} > {limit}") + + def _validate_ip_whitelist(security_cfg): + """Allow only IPs listed in ip_whitelist.""" + whitelist = 
security_cfg.get("ip_whitelist", []) + if not whitelist: + return + + client_ip = request.remote_addr + + + for rule in whitelist: + if "/" in rule: + # CIDR notation + if ipaddress.ip_address(client_ip) in ipaddress.ip_network(rule, strict=False): + return + else: + # Single IP + if client_ip == rule: + return + + raise Exception(f"IP {client_ip} is not allowed by whitelist") + + def _validate_rate_limit(security_cfg): + """Simple in-memory rate limiting.""" + rl = security_cfg.get("rate_limit") + if not rl: + return + + limit = rl.get("limit", 60) + per = rl.get("per", "minute") + + window = {"second": 1, "minute": 60, "hour": 3600, "day": 86400}.get(per, 60) + key = f"rl:{agent_id}" + + now = int(time.time()) + bucket = _rate_limit_cache.get(key, {"ts": now, "count": 0}) + + # Reset window + if now - bucket["ts"] > window: + bucket = {"ts": now, "count": 0} + + bucket["count"] += 1 + _rate_limit_cache[key] = bucket + + if bucket["count"] > limit: + raise Exception("Too many requests (rate limit exceeded)") + + def _validate_token_auth(security_cfg): + """Validate header-based token authentication.""" + token_cfg = security_cfg.get("token",{}) + header = token_cfg.get("token_header") + token_value = token_cfg.get("token_value") + + provided = request.headers.get(header) + if provided != token_value: + raise Exception("Invalid token authentication") + + def _validate_basic_auth(security_cfg): + """Validate HTTP Basic Auth credentials.""" + auth_cfg = security_cfg.get("basic_auth", {}) + username = auth_cfg.get("username") + password = auth_cfg.get("password") + + auth = request.authorization + if not auth or auth.username != username or auth.password != password: + raise Exception("Invalid Basic Auth credentials") + + def _validate_jwt_auth(security_cfg): + """Validate JWT token in Authorization header.""" + jwt_cfg = security_cfg.get("jwt", {}) + secret = jwt_cfg.get("secret") + required_claims = jwt_cfg.get("required_claims", []) + + auth_header = 
request.headers.get("Authorization", "") + if not auth_header.startswith("Bearer "): + raise Exception("Missing Bearer token") + + token = auth_header.replace("Bearer ", "") + + try: + decoded = jwt.decode( + token, + secret, + algorithms=[jwt_cfg.get("algorithm", "HS256")], + audience=jwt_cfg.get("audience"), + issuer=jwt_cfg.get("issuer"), + ) + except Exception as e: + raise Exception(f"Invalid JWT: {str(e)}") + + for claim in required_claims: + if claim not in decoded: + raise Exception(f"Missing JWT claim: {claim}") + + async def _validate_hmac_auth(security_cfg): + """Validate HMAC signature from header.""" + hmac_cfg = security_cfg.get("hmac", {}) + header = hmac_cfg.get("header") + secret = hmac_cfg.get("secret") + algorithm = hmac_cfg.get("algorithm", "sha256") + + provided_sig = request.headers.get(header) + if not provided_sig: + raise Exception("Missing HMAC signature header") + + body = await request.get_data() + if body is None: + body = b"" + elif isinstance(body, str): + body = body.encode("utf-8") + + computed = hmac.new(secret.encode(), body, getattr(hashlib, algorithm)).hexdigest() + + if not hmac.compare_digest(provided_sig, computed): + raise Exception("Invalid HMAC signature") + try: - canvas = Canvas(cvs.dsl, tenant_id, agent_id) + await validate_webhook_security(webhook_cfg.get("security", {})) + except Exception as e: + return get_data_error_result(message=str(e)) + + # 7. Parse request body + async def parse_webhook_request(): + """Parse request based on content-type and return structured data.""" + + # 1. Parse query parameters + query_data = {} + for k, v in request.args.items(): + query_data[k] = v + + # 2. Parse headers + header_data = {} + for k, v in request.headers.items(): + header_data[k] = v + + # 3. 
Parse body based on content-type + ctype = request.headers.get("Content-Type", "").split(";")[0].strip() + raw_files = {} + + if ctype == "application/json": + try: + body_data = await request.get_json() + except: + body_data = None + + elif ctype == "multipart/form-data": + form = await request.form + files = await request.files + raw_files = {name: file for name, file in files.items()} + body_data = { + "form": dict(form), + "files": {name: file.filename for name, file in files.items()}, + } + + elif ctype == "application/x-www-form-urlencoded": + form = await request.form + body_data = dict(form) + + elif ctype == "text/plain": + body_data = (await request.get_data()).decode() + + elif ctype == "application/octet-stream": + body_data = await request.get_data() # raw binary + + else: + # unknown content type → raw body + body_data = await request.get_data() + + return { + "query": query_data, + "headers": header_data, + "body": body_data, + "content_type": ctype, + "raw_files": raw_files + } + + def extract_by_schema(data, schema, name="section"): + """ + Extract only fields defined in schema. + Required fields must exist. + Optional fields default to type-based default values. + Type validation included. + """ + if schema.get("type") != "object": + return {} + + props = schema.get("properties", {}) + required = schema.get("required", []) + + extracted = {} + + for field, field_schema in props.items(): + field_type = field_schema.get("type") + + # 1. Required field missing + if field in required and field not in data: + raise Exception(f"{name} missing required field: {field}") + + # 2. Optional → default value + if field not in data: + extracted[field] = default_for_type(field_type) + continue + + raw_value = data[field] + + # 3. Auto convert value + try: + value = auto_cast_value(raw_value, field_type) + except Exception as e: + raise Exception(f"{name}.{field} auto-cast failed: {str(e)}") + + # 4. 
Type validation + if not validate_type(value, field_type): + raise Exception( + f"{name}.{field} type mismatch: expected {field_type}, got {type(value).__name__}" + ) + + extracted[field] = value + + return extracted + + + def default_for_type(t): + """Return default value for the given schema type.""" + if t == "file": + return "" + if t == "object": + return {} + if t == "boolean": + return False + if t == "number": + return 0 + if t == "string": + return "" + if t and t.startswith("array"): + return [] + if t == "null": + return None + return None + + def auto_cast_value(value, expected_type): + """Convert string values into schema type when possible.""" + + # Non-string values already good + if not isinstance(value, str): + return value + + v = value.strip() + + # Boolean + if expected_type == "boolean": + if v.lower() in ["true", "1"]: + return True + if v.lower() in ["false", "0"]: + return False + raise Exception(f"Cannot convert '{value}' to boolean") + + # Number + if expected_type == "number": + # integer + if v.isdigit() or (v.startswith("-") and v[1:].isdigit()): + return int(v) + + # float + try: + return float(v) + except: + raise Exception(f"Cannot convert '{value}' to number") + + # Object + if expected_type == "object": + try: + parsed = json.loads(v) + if isinstance(parsed, dict): + return parsed + else: + raise Exception("JSON is not an object") + except: + raise Exception(f"Cannot convert '{value}' to object") + + # Array + if expected_type.startswith("array"): + try: + parsed = json.loads(v) + if isinstance(parsed, list): + return parsed + else: + raise Exception("JSON is not an array") + except: + raise Exception(f"Cannot convert '{value}' to array") + + # String (accept original) + if expected_type == "string": + return value + + # File + if expected_type == "file": + return value + # Default: do nothing + return value + + + def validate_type(value, t): + """Validate value type against schema type t.""" + if t == "file": + return 
isinstance(value, str) + + if t == "string": + return isinstance(value, str) + + if t == "number": + return isinstance(value, (int, float)) + + if t == "boolean": + return isinstance(value, bool) + + if t == "object": + return isinstance(value, dict) + + # array / array / array + if t.startswith("array"): + if not isinstance(value, list): + return False + + if "<" in t and ">" in t: + inner = t[t.find("<") + 1 : t.find(">")] + + # Check each element type + for item in value: + if not validate_type(item, inner): + return False + + return True + + return True + + def extract_files_by_schema(raw_files, schema, name="files"): + """ + Extract and validate files based on schema. + Only supports type = file (single file). + Does NOT support array. + """ + + if schema.get("type") != "object": + return {} + + props = schema.get("properties", {}) + required = schema.get("required", []) + + cleaned = [] + + for field, field_schema in props.items(): + field_type = field_schema.get("type") + + # 1. Required field must exist + if field in required and field not in raw_files: + raise Exception(f"{name} missing required file field: {field}") + + # 2. Ignore fields that are not file + if field_type != "file": + continue + + # 3. 
Extract single file + file_obj = raw_files.get(field) + + if file_obj: + cleaned.append({ + "field": field, + "file": file_obj + }) + return cleaned + + parsed = await parse_webhook_request() + SCHEMA = webhook_cfg.get("schema", {"query": {}, "headers": {}, "body": {}}) + + # Extract strictly by schema + query_clean = extract_by_schema(parsed["query"], SCHEMA.get("query", {}), name="query") + header_clean = extract_by_schema(parsed["headers"], SCHEMA.get("headers", {}), name="headers") + body_clean = extract_by_schema(parsed["body"], SCHEMA.get("body", {}), name="body") + files_clean = extract_files_by_schema(parsed["raw_files"], SCHEMA.get("body", {}), name="files") + + uploaded_files = [] + for item in files_clean: # each {field, file} + file_obj = item["file"] + desc = FileService.upload_info( + cvs.user_id, # user + file_obj, # FileStorage + None # url (None for webhook) + ) + uploaded_files.append(desc) + + clean_request = { + "query": query_clean, + "headers": header_clean, + "body": body_clean + } + + if not isinstance(cvs.dsl, str): + dsl = json.dumps(cvs.dsl, ensure_ascii=False) + try: + canvas = Canvas(dsl, cvs.user_id, agent_id) except Exception as e: return get_json_result( data=False, message=str(e), diff --git a/api/apps/sdk/session.py b/api/apps/sdk/session.py index 27760f1a8..55c290710 100644 --- a/api/apps/sdk/session.py +++ b/api/apps/sdk/session.py @@ -1170,7 +1170,7 @@ async def mindmap(): search_id = req.get("search_id", "") search_app = SearchService.get_detail(search_id) if search_id else {} - mind_map = gen_mindmap(req["question"], req["kb_ids"], tenant_id, search_app.get("search_config", {})) + mind_map =await gen_mindmap(req["question"], req["kb_ids"], tenant_id, search_app.get("search_config", {})) if "error" in mind_map: return server_error_response(Exception(mind_map["error"])) return get_json_result(data=mind_map) diff --git a/common/http_client.py b/common/http_client.py index 5c57f8638..f9c0de2b4 100644 --- a/common/http_client.py 
+++ b/common/http_client.py @@ -18,6 +18,7 @@ import time from typing import Any, Dict, Optional from urllib.parse import parse_qsl, urlencode, urlparse, urlunparse +from common import settings import httpx logger = logging.getLogger(__name__) @@ -73,6 +74,32 @@ def _redact_sensitive_url_params(url: str) -> str: except Exception: return url +def _is_sensitive_url(url: str) -> bool: + """Return True if URL is one of the configured OAuth endpoints.""" + # Collect known sensitive endpoint URLs from settings + oauth_urls = set() + # GitHub OAuth endpoints + try: + if settings.GITHUB_OAUTH is not None: + url_val = settings.GITHUB_OAUTH.get("url") + if url_val: oauth_urls.add(url_val) + except Exception: + pass + # Feishu OAuth endpoints + try: + if settings.FEISHU_OAUTH is not None: + for k in ("app_access_token_url", "user_access_token_url"): + url_val = settings.FEISHU_OAUTH.get(k) + if url_val: oauth_urls.add(url_val) + except Exception: + pass + # Defensive normalization: compare only scheme+netloc+path + url_obj = urlparse(url) + for sensitive_url in oauth_urls: + sensitive_obj = urlparse(sensitive_url) + if (url_obj.scheme, url_obj.netloc, url_obj.path) == (sensitive_obj.scheme, sensitive_obj.netloc, sensitive_obj.path): + return True + return False async def async_request( method: str, @@ -115,20 +142,23 @@ async def async_request( method=method, url=url, headers=headers, **kwargs ) duration = time.monotonic() - start + log_url = "" if _is_sensitive_url(url) else _redact_sensitive_url_params(url) logger.debug( - f"async_request {method} {_redact_sensitive_url_params(url)} -> {response.status_code} in {duration:.3f}s" + f"async_request {method} {log_url} -> {response.status_code} in {duration:.3f}s" ) return response except httpx.RequestError as exc: last_exc = exc if attempt >= retries: + log_url = "" if _is_sensitive_url(url) else _redact_sensitive_url_params(url) logger.warning( - f"async_request exhausted retries for {method} {_redact_sensitive_url_params(url)}: {exc}" 
+ f"async_request exhausted retries for {method} {log_url}: {exc}" ) raise delay = _get_delay(backoff_factor, attempt) + log_url = "" if _is_sensitive_url(url) else _redact_sensitive_url_params(url) logger.warning( - f"async_request attempt {attempt + 1}/{retries + 1} failed for {method} {_redact_sensitive_url_params(url)}: {exc}; retrying in {delay:.2f}s" + f"async_request attempt {attempt + 1}/{retries + 1} failed for {method} {log_url}; retrying in {delay:.2f}s" ) await asyncio.sleep(delay) raise last_exc # pragma: no cover diff --git a/graphrag/search.py b/graphrag/search.py index 7399ea393..c21a0d827 100644 --- a/graphrag/search.py +++ b/graphrag/search.py @@ -44,7 +44,7 @@ class KGSearch(Dealer): return response def query_rewrite(self, llm, question, idxnms, kb_ids): - ty2ents = asyncio.run(get_entity_type2samples(idxnms, kb_ids)) + ty2ents = get_entity_type2samples(idxnms, kb_ids) hint_prompt = PROMPTS["minirag_query2kwd"].format(query=question, TYPE_POOL=json.dumps(ty2ents, ensure_ascii=False, indent=2)) result = self._chat(llm, hint_prompt, [{"role": "user", "content": "Output:"}], {}) diff --git a/graphrag/utils.py b/graphrag/utils.py index 9b3dc2c2b..7e3fec1a9 100644 --- a/graphrag/utils.py +++ b/graphrag/utils.py @@ -626,8 +626,8 @@ def merge_tuples(list1, list2): return result -async def get_entity_type2samples(idxnms, kb_ids: list): - es_res = await asyncio.to_thread(settings.retriever.search,{"knowledge_graph_kwd": "ty2ents", "kb_id": kb_ids, "size": 10000, "fields": ["content_with_weight"]},idxnms,kb_ids) +def get_entity_type2samples(idxnms, kb_ids: list): + es_res = settings.retriever.search({"knowledge_graph_kwd": "ty2ents", "kb_id": kb_ids, "size": 10000, "fields": ["content_with_weight"]},idxnms,kb_ids) res = defaultdict(list) for id in es_res.ids: diff --git a/rag/utils/opendal_conn.py b/rag/utils/opendal_conn.py index a260daebc..1f52f6f63 100644 --- a/rag/utils/opendal_conn.py +++ b/rag/utils/opendal_conn.py @@ -41,13 +41,9 @@ def get_opendal_config(): 
scheme = opendal_config.get("scheme") config_data = opendal_config.get("config", {}) kwargs = {"scheme": scheme, **config_data} - redacted_kwargs = kwargs.copy() - if 'password' in redacted_kwargs: - redacted_kwargs['password'] = '***REDACTED***' - if 'connection_string' in redacted_kwargs and 'password' in redacted_kwargs: - import re - redacted_kwargs['connection_string'] = re.sub(r':[^@]+@', ':***REDACTED***@', redacted_kwargs['connection_string']) - logging.info("Loaded OpenDAL configuration from yaml: %s", redacted_kwargs) + safe_log_keys=['scheme', 'host', 'port', 'database', 'table'] + loggable_kwargs = {k: v for k, v in kwargs.items() if k in safe_log_keys} + logging.info("Loaded OpenDAL configuration(non sensitive): %s", loggable_kwargs) return kwargs except Exception as e: logging.error("Failed to load OpenDAL configuration from yaml: %s", str(e))