fix: async issue and sensitive logging
This commit is contained in:
parent
c610bb605a
commit
c6ca4a08f2
6 changed files with 527 additions and 27 deletions
|
|
@ -132,19 +132,14 @@ def delete_agent(tenant_id: str, agent_id: str):
|
|||
UserCanvasService.delete_by_id(agent_id)
|
||||
return get_json_result(data=True)
|
||||
|
||||
# In-memory rate-limit state used by _validate_rate_limit: maps "rl:<agent_id>"
# to a bucket dict {"ts": <window start, epoch seconds>, "count": <requests seen>}.
_rate_limit_cache = {}
|
||||
|
||||
@manager.route('/webhook/<agent_id>', methods=['POST']) # noqa: F821
|
||||
@token_required
|
||||
async def webhook(tenant_id: str, agent_id: str):
|
||||
req = await get_request_json()
|
||||
if not UserCanvasService.accessible(req["id"], tenant_id):
|
||||
return get_json_result(
|
||||
data=False, message='Only owner of canvas authorized for this operation.',
|
||||
code=RetCode.OPERATING_ERROR)
|
||||
|
||||
e, cvs = UserCanvasService.get_by_id(req["id"])
|
||||
if not e:
|
||||
return get_data_error_result(message="canvas not found.")
|
||||
@manager.route('/webhook/<agent_id>', methods=["POST", "GET", "PUT", "PATCH", "DELETE", "HEAD"]) # noqa: F821
|
||||
async def webhook(agent_id: str):
|
||||
# 1. Fetch canvas by agent_id
|
||||
exists, cvs = UserCanvasService.get_by_id(agent_id)
|
||||
if not exists:
|
||||
return get_data_error_result(message="Canvas not found.")
|
||||
|
||||
if not isinstance(cvs.dsl, str):
|
||||
cvs.dsl = json.dumps(cvs.dsl, ensure_ascii=False)
|
||||
|
|
@ -152,8 +147,487 @@ async def webhook(tenant_id: str, agent_id: str):
|
|||
if cvs.canvas_category == CanvasCategory.DataFlow:
|
||||
return get_data_error_result(message="Dataflow can not be triggered by webhook.")
|
||||
|
||||
# 3. Load DSL from canvas
|
||||
dsl = getattr(cvs, "dsl", None)
|
||||
if not isinstance(dsl, dict):
|
||||
return get_data_error_result(message="Invalid DSL format.")
|
||||
|
||||
# 4. Check webhook configuration in DSL
|
||||
components = dsl.get("webhook", {})
|
||||
for k, cpn in components.items():
|
||||
cpn_obj = components[k]["obj"]
|
||||
if cpn_obj.component_name.lower() == "begin" and cpn.params.mode == "webhook":
|
||||
webhook_cfg = cpn.params
|
||||
|
||||
if not webhook_cfg:
|
||||
return get_data_error_result(message="Webhook not configured for this agent.")
|
||||
|
||||
# 5. Validate request method against webhook_cfg.methods
|
||||
allowed_methods = webhook_cfg.get("methods", [])
|
||||
request_method = request.method.upper()
|
||||
if allowed_methods and request_method not in allowed_methods:
|
||||
return get_data_error_result(
|
||||
message=f"HTTP method '{request_method}' not allowed for this webhook."
|
||||
)
|
||||
|
||||
# 6. Validate webhook security
|
||||
async def validate_webhook_security(security_cfg: dict):
    """Run all configured security checks for the incoming webhook request.

    The checks run in a fixed order: body size, IP whitelist, rate limit and
    finally the authentication scheme named by ``security_cfg["auth_type"]``
    (``"none"`` when absent). Each check signals rejection by raising.

    Args:
        security_cfg: The ``security`` section of the webhook configuration.
            A falsy value means "no restrictions" and returns immediately.

    Raises:
        Exception: When any check fails or ``auth_type`` is not one of
            ``none``, ``token``, ``basic``, ``jwt`` or ``hmac``.
    """
    if not security_cfg:
        # No security config -> allowed by default.
        return

    # Checks that do not depend on the authentication scheme.
    await _validate_max_body_size(security_cfg)
    _validate_ip_whitelist(security_cfg)
    _validate_rate_limit(security_cfg)

    scheme = security_cfg.get("auth_type", "none")
    if scheme == "none":
        return

    if scheme == "token":
        _validate_token_auth(security_cfg)
        return

    if scheme == "basic":
        _validate_basic_auth(security_cfg)
        return

    if scheme == "jwt":
        _validate_jwt_auth(security_cfg)
        return

    if scheme == "hmac":
        # The HMAC check is a coroutine and must be awaited.
        await _validate_hmac_auth(security_cfg)
        return

    raise Exception(f"Unsupported auth_type: {scheme}")
|
||||
|
||||
async def _validate_max_body_size(security_cfg):
    """Reject the request when its declared size exceeds ``max_body_size``.

    ``max_body_size`` is a string made of an integer followed by a unit,
    e.g. ``"10MB"``; the units ``KB``, ``MB`` and ``GB`` (any letter case)
    are understood. Only the request's ``content_length`` is compared
    against the limit; the body itself is not read here.

    Args:
        security_cfg: Webhook security configuration. Nothing is checked
            when ``max_body_size`` is missing or empty.

    Raises:
        Exception: When the configured value cannot be parsed or the
            request is larger than the limit.
    """
    max_size = security_cfg.get("max_body_size")
    if not max_size:
        return

    # Convert "10MB" -> bytes
    units = {"KB": 1024, "MB": 1024**2, "GB": 1024**3}
    # Normalise to upper case so the value can match the upper-case unit
    # keys; lower-casing here made every configured limit unparsable.
    size_str = str(max_size).strip().upper()

    for suffix, factor in units.items():
        if size_str.endswith(suffix):
            # Slice the suffix off the end instead of replace(), which would
            # also strip the same letters anywhere else in the string.
            number = size_str[: -len(suffix)].strip()
            try:
                limit = int(number) * factor
            except ValueError:
                raise Exception("Invalid max_body_size format") from None
            break
    else:
        raise Exception("Invalid max_body_size format")

    content_length = request.content_length or 0
    if content_length > limit:
        raise Exception(f"Request body too large: {content_length} > {limit}")
|
||||
|
||||
def _validate_ip_whitelist(security_cfg):
    """Permit the request only if the client address matches the whitelist.

    Entries containing ``/`` are treated as CIDR networks; every other entry
    must equal the client address string exactly. An empty or missing
    ``ip_whitelist`` disables the check.

    Raises:
        Exception: When no whitelist entry matches ``request.remote_addr``.
    """
    allowed = security_cfg.get("ip_whitelist", [])
    if not allowed:
        return

    client_ip = request.remote_addr

    for entry in allowed:
        if "/" not in entry:
            # Single IP: plain string comparison.
            if entry == client_ip:
                return
            continue

        # CIDR notation: parse the address first, then the (non-strict) network.
        address = ipaddress.ip_address(client_ip)
        network = ipaddress.ip_network(entry, strict=False)
        if address in network:
            return

    raise Exception(f"IP {client_ip} is not allowed by whitelist")
|
||||
|
||||
def _validate_rate_limit(security_cfg):
    """Apply a simple fixed-window, in-memory rate limit for this agent.

    Reads ``security_cfg["rate_limit"]``: ``limit`` (default 60) is the number
    of requests allowed per window and ``per`` (default ``"minute"``) selects
    the window length from ``second``/``minute``/``hour``/``day``; unknown
    values fall back to 60 seconds. Counters are kept in the module-level
    ``_rate_limit_cache`` under a key derived from ``agent_id`` only, so all
    callers of the same agent share one bucket.

    Raises:
        Exception: When the request count in the current window exceeds
            ``limit``.
    """
    rl = security_cfg.get("rate_limit")
    if not rl:
        return

    limit = rl.get("limit", 60)
    per = rl.get("per", "minute")

    window = {"second": 1, "minute": 60, "hour": 3600, "day": 86400}.get(per, 60)
    key = f"rl:{agent_id}"

    now = int(time.time())
    bucket = _rate_limit_cache.get(key)

    # A window opened at ``ts`` covers [ts, ts + window). Using ``>=`` starts
    # a fresh window exactly at ts + window; the former ``>`` stretched every
    # window by one extra second (doubling a one-second window).
    if bucket is None or now - bucket["ts"] >= window:
        bucket = {"ts": now, "count": 0}

    bucket["count"] += 1
    _rate_limit_cache[key] = bucket

    if bucket["count"] > limit:
        raise Exception("Too many requests (rate limit exceeded)")
|
||||
|
||||
def _validate_token_auth(security_cfg):
    """Validate header-based token authentication.

    The expected header name and value come from
    ``security_cfg["token"]["token_header"]`` and ``["token_value"]``.

    Raises:
        Exception: When the token section is incomplete, the header is
            absent, or the supplied value differs from the configured one.
    """
    token_cfg = security_cfg.get("token") or {}
    header = token_cfg.get("token_header")
    token_value = token_cfg.get("token_value")

    # Fail closed on an incomplete configuration. Previously a missing
    # token_value compared equal to a missing header (None == None) and the
    # request was accepted without any credential.
    if not header or not token_value:
        raise Exception("Invalid token authentication")

    provided = request.headers.get(header)
    if provided is None:
        raise Exception("Invalid token authentication")

    # Compare as bytes in constant time: avoids leaking the token through
    # timing and avoids compare_digest's TypeError on non-ASCII str input.
    expected_bytes = str(token_value).encode("utf-8")
    provided_bytes = str(provided).encode("utf-8")
    if not hmac.compare_digest(provided_bytes, expected_bytes):
        raise Exception("Invalid token authentication")
|
||||
|
||||
def _validate_basic_auth(security_cfg):
    """Validate HTTP Basic Auth credentials.

    The expected credentials come from
    ``security_cfg["basic_auth"]["username"]`` and ``["password"]``.

    Raises:
        Exception: When credentials are not configured, the request has no
            Authorization data, or username/password do not match.
    """
    auth_cfg = security_cfg.get("basic_auth") or {}
    username = auth_cfg.get("username")
    password = auth_cfg.get("password")

    # Fail closed when nothing is configured; otherwise the ``None`` config
    # values could compare equal to ``None`` username/password attributes on
    # the parsed Authorization object.
    if username is None or password is None:
        raise Exception("Invalid Basic Auth credentials")

    auth = request.authorization
    if not auth:
        raise Exception("Invalid Basic Auth credentials")

    def _same(provided, expected):
        # Constant-time comparison on bytes (compare_digest rejects
        # non-ASCII str arguments).
        return hmac.compare_digest(
            str(provided if provided is not None else "").encode("utf-8"),
            str(expected).encode("utf-8"),
        )

    # Evaluate both comparisons before combining so a wrong username does
    # not short-circuit the password check.
    user_ok = _same(auth.username, username)
    password_ok = _same(auth.password, password)
    if not (user_ok and password_ok):
        raise Exception("Invalid Basic Auth credentials")
|
||||
|
||||
def _validate_jwt_auth(security_cfg):
    """Validate the JWT carried in the ``Authorization: Bearer`` header.

    Settings are read from ``security_cfg["jwt"]``: ``secret`` (verification
    key), ``algorithm`` (default ``"HS256"``), optional ``audience`` and
    ``issuer``, and ``required_claims`` (claim names that must be present in
    the decoded payload).

    Raises:
        Exception: When the secret is not configured, the Bearer token is
            missing, decoding/verification fails, or a required claim is
            absent.
    """
    jwt_cfg = security_cfg.get("jwt") or {}
    secret = jwt_cfg.get("secret")
    required_claims = jwt_cfg.get("required_claims", [])

    # Never attempt verification with an empty key.
    if not secret:
        raise Exception("Invalid JWT: secret is not configured")

    prefix = "Bearer "
    auth_header = request.headers.get("Authorization", "")
    if not auth_header.startswith(prefix):
        raise Exception("Missing Bearer token")

    # Drop only the leading scheme; replace() would also delete any later
    # "Bearer " occurrence inside the header value.
    token = auth_header[len(prefix):].strip()

    try:
        decoded = jwt.decode(
            token,
            secret,
            algorithms=[jwt_cfg.get("algorithm", "HS256")],
            audience=jwt_cfg.get("audience"),
            issuer=jwt_cfg.get("issuer"),
        )
    except Exception as e:
        # Keep the original error as the cause for debugging.
        raise Exception(f"Invalid JWT: {str(e)}") from e

    for claim in required_claims:
        if claim not in decoded:
            raise Exception(f"Missing JWT claim: {claim}")
|
||||
|
||||
async def _validate_hmac_auth(security_cfg):
    """Validate the HMAC signature of the request body.

    Settings are read from ``security_cfg["hmac"]``: ``header`` (request
    header carrying the signature), ``secret`` (shared key) and ``algorithm``
    (hash name, default ``"sha256"``). The expected signature is the hex
    digest of the raw request body.

    Raises:
        Exception: When the configuration is incomplete, the algorithm is
            not provided by ``hashlib``, the signature header is missing, or
            the signature does not match.
    """
    hmac_cfg = security_cfg.get("hmac") or {}
    header = hmac_cfg.get("header")
    secret = hmac_cfg.get("secret")
    algorithm = hmac_cfg.get("algorithm") or "sha256"

    # Fail closed instead of crashing on ``None.encode()`` or looking up a
    # ``None`` header name.
    if not header or not secret:
        raise Exception("Invalid HMAC configuration")

    # Accept only real hash names; ``getattr(hashlib, name)`` would resolve
    # any module attribute, not just digest constructors.
    if not isinstance(algorithm, str) or algorithm not in hashlib.algorithms_available:
        raise Exception(f"Unsupported HMAC algorithm: {algorithm}")

    provided_sig = request.headers.get(header)
    if not provided_sig:
        raise Exception("Missing HMAC signature header")

    body = await request.get_data()
    if body is None:
        body = b""
    elif isinstance(body, str):
        body = body.encode("utf-8")

    computed = hmac.new(str(secret).encode("utf-8"), body, algorithm).hexdigest()

    # Compare bytes: compare_digest raises TypeError for non-ASCII str, which
    # an attacker-controlled header could otherwise trigger.
    if not hmac.compare_digest(str(provided_sig).encode("utf-8"), computed.encode("utf-8")):
        raise Exception("Invalid HMAC signature")
|
||||
|
||||
try:
|
||||
canvas = Canvas(cvs.dsl, tenant_id, agent_id)
|
||||
await validate_webhook_security(webhook_cfg.get("security", {}))
|
||||
except Exception as e:
|
||||
return get_data_error_result(message=str(e))
|
||||
|
||||
# 7. Parse request body
|
||||
async def parse_webhook_request():
    """Parse the current request into query, header and body sections.

    The body is decoded according to the media type in ``Content-Type``
    (parameters such as ``charset`` are ignored when choosing the branch):

    * ``application/json``: parsed JSON, or ``None`` when parsing fails
    * ``multipart/form-data``: ``{"form": ..., "files": {name: filename}}``
    * ``application/x-www-form-urlencoded``: the form as a dict
    * ``text/plain``: the decoded text
    * anything else, including ``application/octet-stream``: the raw body

    Returns:
        dict with keys ``query``, ``headers``, ``body``, ``content_type`` and
        ``raw_files`` (uploaded file objects by field name; empty unless the
        request is multipart).
    """
    # 1. Query parameters
    query_data = {k: v for k, v in request.args.items()}

    # 2. Headers
    header_data = {k: v for k, v in request.headers.items()}

    # 3. Body, based on content type
    ctype = request.headers.get("Content-Type", "").split(";")[0].strip()
    raw_files = {}

    if ctype == "application/json":
        try:
            body_data = await request.get_json()
        except Exception:
            # Malformed JSON is reported as "no body". A bare ``except:``
            # here would also swallow task cancellation and KeyboardInterrupt.
            body_data = None

    elif ctype == "multipart/form-data":
        form = await request.form
        files = await request.files
        raw_files = {name: file for name, file in files.items()}
        body_data = {
            "form": dict(form),
            "files": {name: file.filename for name, file in files.items()},
        }

    elif ctype == "application/x-www-form-urlencoded":
        form = await request.form
        body_data = dict(form)

    elif ctype == "text/plain":
        body_data = (await request.get_data()).decode()

    else:
        # application/octet-stream and unknown content types -> raw body
        body_data = await request.get_data()

    return {
        "query": query_data,
        "headers": header_data,
        "body": body_data,
        "content_type": ctype,
        "raw_files": raw_files,
    }
|
||||
|
||||
def extract_by_schema(data, schema, name="section"):
    """Extract and validate the fields of ``data`` that ``schema`` declares.

    Only schemas with ``"type": "object"`` are processed; any other schema
    yields ``{}``. For each property: a missing required field is an error,
    a missing optional field gets the default for its type, and a present
    field is auto-cast and then type-checked.

    Args:
        data: The parsed section (query, headers or body). Anything that is
            not a dict (e.g. ``None``, raw bytes or plain text) is treated
            as having no fields.
        schema: Object schema with ``properties`` and optional ``required``.
        name: Section label used in error messages.

    Returns:
        dict containing exactly the fields listed in ``properties``.

    Raises:
        Exception: On a missing required field, a failed cast, or a type
            mismatch.
    """
    if schema.get("type") != "object":
        return {}

    # A body that failed to parse (None) or arrived as raw text/bytes has no
    # named fields. Without this guard ``field not in data`` raises TypeError
    # for None and does a substring search for str.
    if not isinstance(data, dict):
        data = {}

    props = schema.get("properties", {})
    required = schema.get("required", [])

    extracted = {}

    for field, field_schema in props.items():
        field_type = field_schema.get("type")

        if field not in data:
            # 1. Required field missing
            if field in required:
                raise Exception(f"{name} missing required field: {field}")
            # 2. Optional -> default value
            extracted[field] = default_for_type(field_type)
            continue

        raw_value = data[field]

        # 3. Auto convert value
        try:
            value = auto_cast_value(raw_value, field_type)
        except Exception as e:
            raise Exception(f"{name}.{field} auto-cast failed: {str(e)}") from e

        # 4. Type validation
        if not validate_type(value, field_type):
            raise Exception(
                f"{name}.{field} type mismatch: expected {field_type}, got {type(value).__name__}"
            )

        extracted[field] = value

    return extracted
|
||||
|
||||
|
||||
def default_for_type(t):
    """Return the default value for schema type ``t``.

    ``file`` and ``string`` default to ``""``, ``object`` to ``{}``,
    ``boolean`` to ``False``, ``number`` to ``0`` and any ``array...`` type
    to ``[]``. ``null``, unknown type names and non-string type
    specifications default to ``None``.
    """
    # A schema may omit "type" or give something other than a string; there
    # is no sensible default then, and calling str methods on it would fail.
    if not isinstance(t, str):
        return None

    if t in ("file", "string"):
        return ""
    if t == "object":
        return {}
    if t == "boolean":
        return False
    if t == "number":
        return 0
    if t.startswith("array"):
        return []
    # "null" and every unrecognised type name
    return None
|
||||
|
||||
def auto_cast_value(value, expected_type):
    """Convert a string ``value`` to ``expected_type`` when possible.

    Non-string values are returned untouched. Strings are converted for the
    types ``boolean`` (``true``/``1``/``false``/``0``, case-insensitive),
    ``number`` (int when the text is a plain signed integer, else float),
    ``object`` (JSON object) and ``array...`` (JSON array). For ``string``,
    ``file``, unknown or missing types the original string is returned.

    Raises:
        Exception: When the string cannot be converted to the requested type.
    """
    # Non-string values already good
    if not isinstance(value, str):
        return value

    v = value.strip()

    # Boolean
    if expected_type == "boolean":
        lowered = v.lower()
        if lowered in ("true", "1"):
            return True
        if lowered in ("false", "0"):
            return False
        raise Exception(f"Cannot convert '{value}' to boolean")

    # Number
    if expected_type == "number":
        # isdecimal() (unlike isdigit()) only accepts characters int() can
        # parse, so inputs such as "²" fall through to the float attempt.
        digits = v[1:] if v.startswith("-") else v
        if digits.isdecimal():
            return int(v)
        try:
            return float(v)
        except ValueError as exc:
            raise Exception(f"Cannot convert '{value}' to number") from exc

    # Object
    if expected_type == "object":
        try:
            parsed = json.loads(v)
        except ValueError as exc:  # json.JSONDecodeError is a ValueError
            raise Exception(f"Cannot convert '{value}' to object") from exc
        if not isinstance(parsed, dict):
            raise Exception(f"Cannot convert '{value}' to object")
        return parsed

    # Array <T>; guard against a missing or non-string type specification
    if isinstance(expected_type, str) and expected_type.startswith("array"):
        try:
            parsed = json.loads(v)
        except ValueError as exc:
            raise Exception(f"Cannot convert '{value}' to array") from exc
        if not isinstance(parsed, list):
            raise Exception(f"Cannot convert '{value}' to array")
        return parsed

    # String, file and anything unrecognised: keep the original text
    return value
|
||||
|
||||
|
||||
def validate_type(value, t):
    """Return True when ``value`` conforms to schema type ``t``.

    Supported types: ``file`` and ``string`` (str), ``number`` (int or float,
    but not bool), ``boolean``, ``object`` (dict) and ``array`` /
    ``array<T>`` (list, with every element validated against ``T`` when it
    is given). Unknown, missing or non-string types impose no constraint.
    """
    # No usable type specification -> nothing to enforce.
    if not isinstance(t, str):
        return True

    if t in ("file", "string"):
        return isinstance(value, str)

    if t == "number":
        # bool is a subclass of int; without the exclusion True/False would
        # be accepted as numbers.
        return isinstance(value, (int, float)) and not isinstance(value, bool)

    if t == "boolean":
        return isinstance(value, bool)

    if t == "object":
        return isinstance(value, dict)

    # array<string> / array<number> / array<object>
    if t.startswith("array"):
        if not isinstance(value, list):
            return False

        if "<" in t and ">" in t:
            # Use the last ">" so nested types such as array<array<number>>
            # keep their full inner type.
            inner = t[t.find("<") + 1 : t.rfind(">")]
            return all(validate_type(item, inner) for item in value)

        return True

    return True
|
||||
|
||||
def extract_files_by_schema(raw_files, schema, name="files"):
    """Collect uploaded files for the ``file``-typed fields of ``schema``.

    Only single-file fields (``"type": "file"``) are supported; ``array<file>``
    is not. Properties of any other type are ignored entirely, including for
    the required check.

    Args:
        raw_files: Mapping of form field name to uploaded file object.
        schema: Object schema with ``properties`` and optional ``required``.
        name: Section label used in error messages.

    Returns:
        list of ``{"field": <name>, "file": <file object>}`` entries, one per
        file field that was actually uploaded. Empty when the schema is not
        an object schema.

    Raises:
        Exception: When a required file field has no uploaded file.
    """
    if schema.get("type") != "object":
        # Same container type as the normal return value.
        return []

    props = schema.get("properties", {})
    required = schema.get("required", [])

    cleaned = []

    for field, field_schema in props.items():
        # 1. Ignore fields that are not files. This must come before the
        #    required check: a required string/number field is never present
        #    in raw_files and would otherwise be reported as a missing file.
        if field_schema.get("type") != "file":
            continue

        # 2. Required file field must exist
        if field in required and field not in raw_files:
            raise Exception(f"{name} missing required file field: {field}")

        # 3. Extract single file
        file_obj = raw_files.get(field)
        if file_obj:
            cleaned.append({"field": field, "file": file_obj})

    return cleaned
|
||||
|
||||
parsed = await parse_webhook_request()
|
||||
SCHEMA = webhook_cfg.get("schema", {"query": {}, "headers": {}, "body": {}})
|
||||
|
||||
# Extract strictly by schema
|
||||
query_clean = extract_by_schema(parsed["query"], SCHEMA.get("query", {}), name="query")
|
||||
header_clean = extract_by_schema(parsed["headers"], SCHEMA.get("headers", {}), name="headers")
|
||||
body_clean = extract_by_schema(parsed["body"], SCHEMA.get("body", {}), name="body")
|
||||
files_clean = extract_files_by_schema(parsed["raw_files"], SCHEMA.get("body", {}), name="files")
|
||||
|
||||
uploaded_files = []
|
||||
for item in files_clean: # each {field, file}
|
||||
file_obj = item["file"]
|
||||
desc = FileService.upload_info(
|
||||
cvs.user_id, # user
|
||||
file_obj, # FileStorage
|
||||
None # url (None for webhook)
|
||||
)
|
||||
uploaded_files.append(desc)
|
||||
|
||||
clean_request = {
|
||||
"query": query_clean,
|
||||
"headers": header_clean,
|
||||
"body": body_clean
|
||||
}
|
||||
|
||||
if not isinstance(cvs.dsl, str):
|
||||
dsl = json.dumps(cvs.dsl, ensure_ascii=False)
|
||||
try:
|
||||
canvas = Canvas(dsl, cvs.user_id, agent_id)
|
||||
except Exception as e:
|
||||
return get_json_result(
|
||||
data=False, message=str(e),
|
||||
|
|
|
|||
|
|
@ -1170,7 +1170,7 @@ async def mindmap():
|
|||
search_id = req.get("search_id", "")
|
||||
search_app = SearchService.get_detail(search_id) if search_id else {}
|
||||
|
||||
mind_map = gen_mindmap(req["question"], req["kb_ids"], tenant_id, search_app.get("search_config", {}))
|
||||
mind_map =await gen_mindmap(req["question"], req["kb_ids"], tenant_id, search_app.get("search_config", {}))
|
||||
if "error" in mind_map:
|
||||
return server_error_response(Exception(mind_map["error"]))
|
||||
return get_json_result(data=mind_map)
|
||||
|
|
|
|||
|
|
@ -18,6 +18,7 @@ import time
|
|||
from typing import Any, Dict, Optional
|
||||
from urllib.parse import parse_qsl, urlencode, urlparse, urlunparse
|
||||
|
||||
from common import settings
|
||||
import httpx
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
@ -73,6 +74,32 @@ def _redact_sensitive_url_params(url: str) -> str:
|
|||
except Exception:
|
||||
return url
|
||||
|
||||
def _is_sensitive_url(url: str) -> bool:
    """Tell whether ``url`` points at one of the configured OAuth endpoints.

    Endpoint URLs are collected from ``settings.GITHUB_OAUTH["url"]`` and from
    the ``app_access_token_url`` / ``user_access_token_url`` entries of
    ``settings.FEISHU_OAUTH``. Any error while reading a provider's settings
    is ignored for that provider. A URL matches when its scheme, network
    location and path equal those of an endpoint; query string and fragment
    are not compared.
    """
    # (settings attribute, keys inside it that hold endpoint URLs)
    providers = (
        ("GITHUB_OAUTH", ("url",)),
        ("FEISHU_OAUTH", ("app_access_token_url", "user_access_token_url")),
    )

    endpoints = set()
    for attr_name, url_keys in providers:
        try:
            provider_cfg = getattr(settings, attr_name)
            if provider_cfg is None:
                continue
            for url_key in url_keys:
                endpoint = provider_cfg.get(url_key)
                if endpoint:
                    endpoints.add(endpoint)
        except Exception:
            # Best effort: a broken provider entry must not break logging.
            pass

    # Defensive normalization: compare only scheme + netloc + path.
    candidate = urlparse(url)
    candidate_key = (candidate.scheme, candidate.netloc, candidate.path)
    return any(
        (known.scheme, known.netloc, known.path) == candidate_key
        for known in map(urlparse, endpoints)
    )
|
||||
|
||||
async def async_request(
|
||||
method: str,
|
||||
|
|
@ -115,20 +142,23 @@ async def async_request(
|
|||
method=method, url=url, headers=headers, **kwargs
|
||||
)
|
||||
duration = time.monotonic() - start
|
||||
log_url = "<SENSITIVE ENDPOINT>" if _is_sensitive_url else _redact_sensitive_url_params(url)
|
||||
logger.debug(
|
||||
f"async_request {method} {_redact_sensitive_url_params(url)} -> {response.status_code} in {duration:.3f}s"
|
||||
f"async_request {method} {log_url} -> {response.status_code} in {duration:.3f}s"
|
||||
)
|
||||
return response
|
||||
except httpx.RequestError as exc:
|
||||
last_exc = exc
|
||||
if attempt >= retries:
|
||||
log_url = "<SENSITIVE ENDPOINT>" if _is_sensitive_url else _redact_sensitive_url_params(url)
|
||||
logger.warning(
|
||||
f"async_request exhausted retries for {method} {_redact_sensitive_url_params(url)}: {exc}"
|
||||
f"async_request exhausted retries for {method} {log_url}"
|
||||
)
|
||||
raise
|
||||
delay = _get_delay(backoff_factor, attempt)
|
||||
log_url = "<SENSITIVE ENDPOINT>" if _is_sensitive_url else _redact_sensitive_url_params(url)
|
||||
logger.warning(
|
||||
f"async_request attempt {attempt + 1}/{retries + 1} failed for {method} {_redact_sensitive_url_params(url)}: {exc}; retrying in {delay:.2f}s"
|
||||
f"async_request attempt {attempt + 1}/{retries + 1} failed for {method} {log_url}; retrying in {delay:.2f}s"
|
||||
)
|
||||
await asyncio.sleep(delay)
|
||||
raise last_exc # pragma: no cover
|
||||
|
|
|
|||
|
|
@ -44,7 +44,7 @@ class KGSearch(Dealer):
|
|||
return response
|
||||
|
||||
def query_rewrite(self, llm, question, idxnms, kb_ids):
|
||||
ty2ents = asyncio.run(get_entity_type2samples(idxnms, kb_ids))
|
||||
ty2ents = get_entity_type2samples(idxnms, kb_ids)
|
||||
hint_prompt = PROMPTS["minirag_query2kwd"].format(query=question,
|
||||
TYPE_POOL=json.dumps(ty2ents, ensure_ascii=False, indent=2))
|
||||
result = self._chat(llm, hint_prompt, [{"role": "user", "content": "Output:"}], {})
|
||||
|
|
|
|||
|
|
@ -626,8 +626,8 @@ def merge_tuples(list1, list2):
|
|||
return result
|
||||
|
||||
|
||||
async def get_entity_type2samples(idxnms, kb_ids: list):
|
||||
es_res = await asyncio.to_thread(settings.retriever.search,{"knowledge_graph_kwd": "ty2ents", "kb_id": kb_ids, "size": 10000, "fields": ["content_with_weight"]},idxnms,kb_ids)
|
||||
def get_entity_type2samples(idxnms, kb_ids: list):
|
||||
es_res = settings.retriever.search({"knowledge_graph_kwd": "ty2ents", "kb_id": kb_ids, "size": 10000, "fields": ["content_with_weight"]},idxnms,kb_ids)
|
||||
|
||||
res = defaultdict(list)
|
||||
for id in es_res.ids:
|
||||
|
|
|
|||
|
|
@ -41,13 +41,9 @@ def get_opendal_config():
|
|||
scheme = opendal_config.get("scheme")
|
||||
config_data = opendal_config.get("config", {})
|
||||
kwargs = {"scheme": scheme, **config_data}
|
||||
redacted_kwargs = kwargs.copy()
|
||||
if 'password' in redacted_kwargs:
|
||||
redacted_kwargs['password'] = '***REDACTED***'
|
||||
if 'connection_string' in redacted_kwargs and 'password' in redacted_kwargs:
|
||||
import re
|
||||
redacted_kwargs['connection_string'] = re.sub(r':[^@]+@', ':***REDACTED***@', redacted_kwargs['connection_string'])
|
||||
logging.info("Loaded OpenDAL configuration from yaml: %s", redacted_kwargs)
|
||||
safe_log_keys=['scheme', 'host', 'port', 'database', 'table']
|
||||
loggable_kwargs = {k: v for k, v in kwargs.items() if k in safe_log_keys}
|
||||
logging.info("Loaded OpenDAL configuration(non sensitive): %s", loggable_kwargs)
|
||||
return kwargs
|
||||
except Exception as e:
|
||||
logging.error("Failed to load OpenDAL configuration from yaml: %s", str(e))
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue