Merge branch 'security/fix-sql-injection-postgres'
This commit is contained in:
commit
c0f69395c7
1 changed files with 35 additions and 25 deletions
|
|
@ -1843,10 +1843,11 @@ class PGKVStorage(BaseKVStorage):
|
||||||
# Query by id
|
# Query by id
|
||||||
async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
|
async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
|
||||||
"""Get data by ids"""
|
"""Get data by ids"""
|
||||||
sql = SQL_TEMPLATES["get_by_ids_" + self.namespace].format(
|
if not ids:
|
||||||
ids=",".join([f"'{id}'" for id in ids])
|
return []
|
||||||
)
|
|
||||||
params = {"workspace": self.workspace}
|
sql = SQL_TEMPLATES["get_by_ids_" + self.namespace]
|
||||||
|
params = {"workspace": self.workspace, "ids": ids}
|
||||||
results = await self.db.query(sql, list(params.values()), multirows=True)
|
results = await self.db.query(sql, list(params.values()), multirows=True)
|
||||||
|
|
||||||
def _order_results(
|
def _order_results(
|
||||||
|
|
@ -1949,11 +1950,12 @@ class PGKVStorage(BaseKVStorage):
|
||||||
|
|
||||||
async def filter_keys(self, keys: set[str]) -> set[str]:
|
async def filter_keys(self, keys: set[str]) -> set[str]:
|
||||||
"""Filter out duplicated content"""
|
"""Filter out duplicated content"""
|
||||||
sql = SQL_TEMPLATES["filter_keys"].format(
|
if not keys:
|
||||||
table_name=namespace_to_table_name(self.namespace),
|
return set()
|
||||||
ids=",".join([f"'{id}'" for id in keys]),
|
|
||||||
)
|
table_name = namespace_to_table_name(self.namespace)
|
||||||
params = {"workspace": self.workspace}
|
sql = f"SELECT id FROM {table_name} WHERE workspace=$1 AND id = ANY($2)"
|
||||||
|
params = {"workspace": self.workspace, "ids": list(keys)}
|
||||||
try:
|
try:
|
||||||
res = await self.db.query(sql, list(params.values()), multirows=True)
|
res = await self.db.query(sql, list(params.values()), multirows=True)
|
||||||
if res:
|
if res:
|
||||||
|
|
@ -2532,11 +2534,12 @@ class PGDocStatusStorage(DocStatusStorage):
|
||||||
|
|
||||||
async def filter_keys(self, keys: set[str]) -> set[str]:
|
async def filter_keys(self, keys: set[str]) -> set[str]:
|
||||||
"""Filter out duplicated content"""
|
"""Filter out duplicated content"""
|
||||||
sql = SQL_TEMPLATES["filter_keys"].format(
|
if not keys:
|
||||||
table_name=namespace_to_table_name(self.namespace),
|
return set()
|
||||||
ids=",".join([f"'{id}'" for id in keys]),
|
|
||||||
)
|
table_name = namespace_to_table_name(self.namespace)
|
||||||
params = {"workspace": self.workspace}
|
sql = f"SELECT id FROM {table_name} WHERE workspace=$1 AND id = ANY($2)"
|
||||||
|
params = {"workspace": self.workspace, "ids": list(keys)}
|
||||||
try:
|
try:
|
||||||
res = await self.db.query(sql, list(params.values()), multirows=True)
|
res = await self.db.query(sql, list(params.values()), multirows=True)
|
||||||
if res:
|
if res:
|
||||||
|
|
@ -2849,26 +2852,33 @@ class PGDocStatusStorage(DocStatusStorage):
|
||||||
elif page_size > 200:
|
elif page_size > 200:
|
||||||
page_size = 200
|
page_size = 200
|
||||||
|
|
||||||
if sort_field not in ["created_at", "updated_at", "id", "file_path"]:
|
# Whitelist validation for sort_field to prevent SQL injection
|
||||||
|
allowed_sort_fields = {"created_at", "updated_at", "id", "file_path"}
|
||||||
|
if sort_field not in allowed_sort_fields:
|
||||||
sort_field = "updated_at"
|
sort_field = "updated_at"
|
||||||
|
|
||||||
|
# Whitelist validation for sort_direction to prevent SQL injection
|
||||||
if sort_direction.lower() not in ["asc", "desc"]:
|
if sort_direction.lower() not in ["asc", "desc"]:
|
||||||
sort_direction = "desc"
|
sort_direction = "desc"
|
||||||
|
else:
|
||||||
|
sort_direction = sort_direction.lower()
|
||||||
|
|
||||||
# Calculate offset
|
# Calculate offset
|
||||||
offset = (page - 1) * page_size
|
offset = (page - 1) * page_size
|
||||||
|
|
||||||
# Build WHERE clause
|
# Build parameterized query components
|
||||||
where_clause = "WHERE workspace=$1"
|
|
||||||
params = {"workspace": self.workspace}
|
params = {"workspace": self.workspace}
|
||||||
param_count = 1
|
param_count = 1
|
||||||
|
|
||||||
|
# Build WHERE clause with parameterized query
|
||||||
if status_filter is not None:
|
if status_filter is not None:
|
||||||
param_count += 1
|
param_count += 1
|
||||||
where_clause += f" AND status=${param_count}"
|
where_clause = "WHERE workspace=$1 AND status=$2"
|
||||||
params["status"] = status_filter.value
|
params["status"] = status_filter.value
|
||||||
|
else:
|
||||||
|
where_clause = "WHERE workspace=$1"
|
||||||
|
|
||||||
# Build ORDER BY clause
|
# Build ORDER BY clause using validated whitelist values
|
||||||
order_clause = f"ORDER BY {sort_field} {sort_direction.upper()}"
|
order_clause = f"ORDER BY {sort_field} {sort_direction.upper()}"
|
||||||
|
|
||||||
# Query for total count
|
# Query for total count
|
||||||
|
|
@ -2876,7 +2886,7 @@ class PGDocStatusStorage(DocStatusStorage):
|
||||||
count_result = await self.db.query(count_sql, list(params.values()))
|
count_result = await self.db.query(count_sql, list(params.values()))
|
||||||
total_count = count_result["total"] if count_result else 0
|
total_count = count_result["total"] if count_result else 0
|
||||||
|
|
||||||
# Query for paginated data
|
# Query for paginated data with parameterized LIMIT and OFFSET
|
||||||
data_sql = f"""
|
data_sql = f"""
|
||||||
SELECT * FROM LIGHTRAG_DOC_STATUS
|
SELECT * FROM LIGHTRAG_DOC_STATUS
|
||||||
{where_clause}
|
{where_clause}
|
||||||
|
|
@ -4874,19 +4884,19 @@ SQL_TEMPLATES = {
|
||||||
""",
|
""",
|
||||||
"get_by_ids_full_docs": """SELECT id, COALESCE(content, '') as content,
|
"get_by_ids_full_docs": """SELECT id, COALESCE(content, '') as content,
|
||||||
COALESCE(doc_name, '') as file_path
|
COALESCE(doc_name, '') as file_path
|
||||||
FROM LIGHTRAG_DOC_FULL WHERE workspace=$1 AND id IN ({ids})
|
FROM LIGHTRAG_DOC_FULL WHERE workspace=$1 AND id = ANY($2)
|
||||||
""",
|
""",
|
||||||
"get_by_ids_text_chunks": """SELECT id, tokens, COALESCE(content, '') as content,
|
"get_by_ids_text_chunks": """SELECT id, tokens, COALESCE(content, '') as content,
|
||||||
chunk_order_index, full_doc_id, file_path,
|
chunk_order_index, full_doc_id, file_path,
|
||||||
COALESCE(llm_cache_list, '[]'::jsonb) as llm_cache_list,
|
COALESCE(llm_cache_list, '[]'::jsonb) as llm_cache_list,
|
||||||
EXTRACT(EPOCH FROM create_time)::BIGINT as create_time,
|
EXTRACT(EPOCH FROM create_time)::BIGINT as create_time,
|
||||||
EXTRACT(EPOCH FROM update_time)::BIGINT as update_time
|
EXTRACT(EPOCH FROM update_time)::BIGINT as update_time
|
||||||
FROM LIGHTRAG_DOC_CHUNKS WHERE workspace=$1 AND id IN ({ids})
|
FROM LIGHTRAG_DOC_CHUNKS WHERE workspace=$1 AND id = ANY($2)
|
||||||
""",
|
""",
|
||||||
"get_by_ids_llm_response_cache": """SELECT id, original_prompt, return_value, chunk_id, cache_type, queryparam,
|
"get_by_ids_llm_response_cache": """SELECT id, original_prompt, return_value, chunk_id, cache_type, queryparam,
|
||||||
EXTRACT(EPOCH FROM create_time)::BIGINT as create_time,
|
EXTRACT(EPOCH FROM create_time)::BIGINT as create_time,
|
||||||
EXTRACT(EPOCH FROM update_time)::BIGINT as update_time
|
EXTRACT(EPOCH FROM update_time)::BIGINT as update_time
|
||||||
FROM LIGHTRAG_LLM_CACHE WHERE workspace=$1 AND id IN ({ids})
|
FROM LIGHTRAG_LLM_CACHE WHERE workspace=$1 AND id = ANY($2)
|
||||||
""",
|
""",
|
||||||
"get_by_id_full_entities": """SELECT id, entity_names, count,
|
"get_by_id_full_entities": """SELECT id, entity_names, count,
|
||||||
EXTRACT(EPOCH FROM create_time)::BIGINT as create_time,
|
EXTRACT(EPOCH FROM create_time)::BIGINT as create_time,
|
||||||
|
|
@ -4901,12 +4911,12 @@ SQL_TEMPLATES = {
|
||||||
"get_by_ids_full_entities": """SELECT id, entity_names, count,
|
"get_by_ids_full_entities": """SELECT id, entity_names, count,
|
||||||
EXTRACT(EPOCH FROM create_time)::BIGINT as create_time,
|
EXTRACT(EPOCH FROM create_time)::BIGINT as create_time,
|
||||||
EXTRACT(EPOCH FROM update_time)::BIGINT as update_time
|
EXTRACT(EPOCH FROM update_time)::BIGINT as update_time
|
||||||
FROM LIGHTRAG_FULL_ENTITIES WHERE workspace=$1 AND id IN ({ids})
|
FROM LIGHTRAG_FULL_ENTITIES WHERE workspace=$1 AND id = ANY($2)
|
||||||
""",
|
""",
|
||||||
"get_by_ids_full_relations": """SELECT id, relation_pairs, count,
|
"get_by_ids_full_relations": """SELECT id, relation_pairs, count,
|
||||||
EXTRACT(EPOCH FROM create_time)::BIGINT as create_time,
|
EXTRACT(EPOCH FROM create_time)::BIGINT as create_time,
|
||||||
EXTRACT(EPOCH FROM update_time)::BIGINT as update_time
|
EXTRACT(EPOCH FROM update_time)::BIGINT as update_time
|
||||||
FROM LIGHTRAG_FULL_RELATIONS WHERE workspace=$1 AND id IN ({ids})
|
FROM LIGHTRAG_FULL_RELATIONS WHERE workspace=$1 AND id = ANY($2)
|
||||||
""",
|
""",
|
||||||
"filter_keys": "SELECT id FROM {table_name} WHERE workspace=$1 AND id IN ({ids})",
|
"filter_keys": "SELECT id FROM {table_name} WHERE workspace=$1 AND id IN ({ids})",
|
||||||
"upsert_doc_full": """INSERT INTO LIGHTRAG_DOC_FULL (id, content, doc_name, workspace)
|
"upsert_doc_full": """INSERT INTO LIGHTRAG_DOC_FULL (id, content, doc_name, workspace)
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue