Merge branch 'security/fix-sql-injection-postgres'

This commit is contained in:
yangdx 2025-10-18 11:45:13 +08:00
commit c0f69395c7

View file

@@ -1843,10 +1843,11 @@ class PGKVStorage(BaseKVStorage):
# Query by id
async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
"""Get data by ids"""
sql = SQL_TEMPLATES["get_by_ids_" + self.namespace].format(
ids=",".join([f"'{id}'" for id in ids])
)
params = {"workspace": self.workspace}
if not ids:
return []
sql = SQL_TEMPLATES["get_by_ids_" + self.namespace]
params = {"workspace": self.workspace, "ids": ids}
results = await self.db.query(sql, list(params.values()), multirows=True)
def _order_results(
@@ -1949,11 +1950,12 @@ class PGKVStorage(BaseKVStorage):
async def filter_keys(self, keys: set[str]) -> set[str]:
"""Filter out duplicated content"""
sql = SQL_TEMPLATES["filter_keys"].format(
table_name=namespace_to_table_name(self.namespace),
ids=",".join([f"'{id}'" for id in keys]),
)
params = {"workspace": self.workspace}
if not keys:
return set()
table_name = namespace_to_table_name(self.namespace)
sql = f"SELECT id FROM {table_name} WHERE workspace=$1 AND id = ANY($2)"
params = {"workspace": self.workspace, "ids": list(keys)}
try:
res = await self.db.query(sql, list(params.values()), multirows=True)
if res:
@@ -2532,11 +2534,12 @@ class PGDocStatusStorage(DocStatusStorage):
async def filter_keys(self, keys: set[str]) -> set[str]:
"""Filter out duplicated content"""
sql = SQL_TEMPLATES["filter_keys"].format(
table_name=namespace_to_table_name(self.namespace),
ids=",".join([f"'{id}'" for id in keys]),
)
params = {"workspace": self.workspace}
if not keys:
return set()
table_name = namespace_to_table_name(self.namespace)
sql = f"SELECT id FROM {table_name} WHERE workspace=$1 AND id = ANY($2)"
params = {"workspace": self.workspace, "ids": list(keys)}
try:
res = await self.db.query(sql, list(params.values()), multirows=True)
if res:
@@ -2849,26 +2852,33 @@ class PGDocStatusStorage(DocStatusStorage):
elif page_size > 200:
page_size = 200
if sort_field not in ["created_at", "updated_at", "id", "file_path"]:
# Whitelist validation for sort_field to prevent SQL injection
allowed_sort_fields = {"created_at", "updated_at", "id", "file_path"}
if sort_field not in allowed_sort_fields:
sort_field = "updated_at"
# Whitelist validation for sort_direction to prevent SQL injection
if sort_direction.lower() not in ["asc", "desc"]:
sort_direction = "desc"
else:
sort_direction = sort_direction.lower()
# Calculate offset
offset = (page - 1) * page_size
# Build WHERE clause
where_clause = "WHERE workspace=$1"
# Build parameterized query components
params = {"workspace": self.workspace}
param_count = 1
# Build WHERE clause with parameterized query
if status_filter is not None:
param_count += 1
where_clause += f" AND status=${param_count}"
where_clause = "WHERE workspace=$1 AND status=$2"
params["status"] = status_filter.value
else:
where_clause = "WHERE workspace=$1"
# Build ORDER BY clause
# Build ORDER BY clause using validated whitelist values
order_clause = f"ORDER BY {sort_field} {sort_direction.upper()}"
# Query for total count
@@ -2876,7 +2886,7 @@ class PGDocStatusStorage(DocStatusStorage):
count_result = await self.db.query(count_sql, list(params.values()))
total_count = count_result["total"] if count_result else 0
# Query for paginated data
# Query for paginated data with parameterized LIMIT and OFFSET
data_sql = f"""
SELECT * FROM LIGHTRAG_DOC_STATUS
{where_clause}
@@ -4874,19 +4884,19 @@ SQL_TEMPLATES = {
""",
"get_by_ids_full_docs": """SELECT id, COALESCE(content, '') as content,
COALESCE(doc_name, '') as file_path
FROM LIGHTRAG_DOC_FULL WHERE workspace=$1 AND id IN ({ids})
FROM LIGHTRAG_DOC_FULL WHERE workspace=$1 AND id = ANY($2)
""",
"get_by_ids_text_chunks": """SELECT id, tokens, COALESCE(content, '') as content,
chunk_order_index, full_doc_id, file_path,
COALESCE(llm_cache_list, '[]'::jsonb) as llm_cache_list,
EXTRACT(EPOCH FROM create_time)::BIGINT as create_time,
EXTRACT(EPOCH FROM update_time)::BIGINT as update_time
FROM LIGHTRAG_DOC_CHUNKS WHERE workspace=$1 AND id IN ({ids})
FROM LIGHTRAG_DOC_CHUNKS WHERE workspace=$1 AND id = ANY($2)
""",
"get_by_ids_llm_response_cache": """SELECT id, original_prompt, return_value, chunk_id, cache_type, queryparam,
EXTRACT(EPOCH FROM create_time)::BIGINT as create_time,
EXTRACT(EPOCH FROM update_time)::BIGINT as update_time
FROM LIGHTRAG_LLM_CACHE WHERE workspace=$1 AND id IN ({ids})
FROM LIGHTRAG_LLM_CACHE WHERE workspace=$1 AND id = ANY($2)
""",
"get_by_id_full_entities": """SELECT id, entity_names, count,
EXTRACT(EPOCH FROM create_time)::BIGINT as create_time,
@@ -4901,12 +4911,12 @@ SQL_TEMPLATES = {
"get_by_ids_full_entities": """SELECT id, entity_names, count,
EXTRACT(EPOCH FROM create_time)::BIGINT as create_time,
EXTRACT(EPOCH FROM update_time)::BIGINT as update_time
FROM LIGHTRAG_FULL_ENTITIES WHERE workspace=$1 AND id IN ({ids})
FROM LIGHTRAG_FULL_ENTITIES WHERE workspace=$1 AND id = ANY($2)
""",
"get_by_ids_full_relations": """SELECT id, relation_pairs, count,
EXTRACT(EPOCH FROM create_time)::BIGINT as create_time,
EXTRACT(EPOCH FROM update_time)::BIGINT as update_time
FROM LIGHTRAG_FULL_RELATIONS WHERE workspace=$1 AND id IN ({ids})
FROM LIGHTRAG_FULL_RELATIONS WHERE workspace=$1 AND id = ANY($2)
""",
"filter_keys": "SELECT id FROM {table_name} WHERE workspace=$1 AND id IN ({ids})",
"upsert_doc_full": """INSERT INTO LIGHTRAG_DOC_FULL (id, content, doc_name, workspace)