feat: refactor parameter handling in database queries to use lists for improved consistency

This commit is contained in:
Matt23-star 2025-08-28 16:17:35 -07:00
parent 015e9ae3dd
commit 9804a1885b

View file

@ -787,13 +787,13 @@ class PostgreSQLDB:
FROM information_schema.columns
WHERE table_name = $1 AND column_name = $2
"""
column_info = await self.query(
check_column_sql,
{
params = {
"table_name": migration["table"].lower(),
"column_name": migration["column"],
},
}
column_info = await self.query(
check_column_sql,
list(params.values()),
)
if not column_info:
@ -1035,9 +1035,9 @@ class PostgreSQLDB:
WHERE table_name = $1
AND table_schema = 'public'
"""
params = {"table_name": table_name.lower()}
table_exists = await self.query(
check_table_sql, {"table_name": table_name.lower()}
check_table_sql, list(params.values())
)
if not table_exists:
@ -1121,7 +1121,8 @@ class PostgreSQLDB:
AND indexname = $1
"""
existing = await self.query(check_sql, {"indexname": index["name"]})
params = {"indexname": index["name"]}
existing = await self.query(check_sql, list(params.values()))
if not existing:
logger.info(f"Creating pagination index: {index['description']}")
@ -1217,7 +1218,7 @@ class PostgreSQLDB:
async def query(
self,
sql: str,
params: dict[str, Any] | None = None,
params: list[Any] | None = None,
multirows: bool = False,
with_age: bool = False,
graph_name: str | None = None,
@ -1230,7 +1231,7 @@ class PostgreSQLDB:
try:
if params:
rows = await connection.fetch(sql, *params.values())
rows = await connection.fetch(sql, *params)
else:
rows = await connection.fetch(sql)
@ -1446,7 +1447,7 @@ class PGKVStorage(BaseKVStorage):
params = {"workspace": self.workspace}
try:
results = await self.db.query(sql, params, multirows=True)
results = await self.db.query(sql, list(params.values()), multirows=True)
# Special handling for LLM cache to ensure compatibility with _get_cached_extraction_results
if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE):
@ -1540,7 +1541,7 @@ class PGKVStorage(BaseKVStorage):
"""Get data by id."""
sql = SQL_TEMPLATES["get_by_id_" + self.namespace]
params = {"workspace": self.workspace, "id": id}
response = await self.db.query(sql, params)
response = await self.db.query(sql, list(params.values()))
if response and is_namespace(self.namespace, NameSpace.KV_STORE_TEXT_CHUNKS):
# Parse llm_cache_list JSON string back to list
@ -1620,7 +1621,7 @@ class PGKVStorage(BaseKVStorage):
ids=",".join([f"'{id}'" for id in ids])
)
params = {"workspace": self.workspace}
results = await self.db.query(sql, params, multirows=True)
results = await self.db.query(sql, list(params.values()), multirows=True)
if results and is_namespace(self.namespace, NameSpace.KV_STORE_TEXT_CHUNKS):
# Parse llm_cache_list JSON string back to list for each result
@ -1708,7 +1709,7 @@ class PGKVStorage(BaseKVStorage):
)
params = {"workspace": self.workspace}
try:
res = await self.db.query(sql, params, multirows=True)
res = await self.db.query(sql, list(params.values()), multirows=True)
if res:
exist_keys = [key["id"] for key in res]
else:
@ -2019,7 +2020,7 @@ class PGVectorStorage(BaseVectorStorage):
"closer_than_threshold": 1 - self.cosine_better_than_threshold,
"top_k": top_k,
}
results = await self.db.query(sql, params=params, multirows=True)
results = await self.db.query(sql, params=list(params.values()), multirows=True)
return results
async def index_done_callback(self) -> None:
@ -2116,7 +2117,7 @@ class PGVectorStorage(BaseVectorStorage):
params = {"workspace": self.workspace, "id": id}
try:
result = await self.db.query(query, params)
result = await self.db.query(query, list(params.values()))
if result:
return dict(result)
return None
@ -2150,7 +2151,7 @@ class PGVectorStorage(BaseVectorStorage):
params = {"workspace": self.workspace}
try:
results = await self.db.query(query, params, multirows=True)
results = await self.db.query(query, list(params.values()), multirows=True)
return [dict(record) for record in results]
except Exception as e:
logger.error(
@ -2183,7 +2184,7 @@ class PGVectorStorage(BaseVectorStorage):
params = {"workspace": self.workspace}
try:
results = await self.db.query(query, params, multirows=True)
results = await self.db.query(query, list(params.values()), multirows=True)
vectors_dict = {}
for result in results:
@ -2270,7 +2271,7 @@ class PGDocStatusStorage(DocStatusStorage):
)
params = {"workspace": self.workspace}
try:
res = await self.db.query(sql, params, multirows=True)
res = await self.db.query(sql, list(params.values()), multirows=True)
if res:
exist_keys = [key["id"] for key in res]
else:
@ -2288,7 +2289,7 @@ class PGDocStatusStorage(DocStatusStorage):
async def get_by_id(self, id: str) -> Union[dict[str, Any], None]:
sql = "select * from LIGHTRAG_DOC_STATUS where workspace=$1 and id=$2"
params = {"workspace": self.workspace, "id": id}
result = await self.db.query(sql, params, True)
result = await self.db.query(sql, list(params.values()), True)
if result is None or result == []:
return None
else:
@ -2334,7 +2335,7 @@ class PGDocStatusStorage(DocStatusStorage):
sql = "SELECT * FROM LIGHTRAG_DOC_STATUS WHERE workspace=$1 AND id = ANY($2)"
params = {"workspace": self.workspace, "ids": ids}
results = await self.db.query(sql, params, True)
results = await self.db.query(sql, list(params.values()), True)
if not results:
return []
@ -2385,7 +2386,8 @@ class PGDocStatusStorage(DocStatusStorage):
FROM LIGHTRAG_DOC_STATUS
where workspace=$1 GROUP BY STATUS
"""
result = await self.db.query(sql, {"workspace": self.workspace}, True)
params = {"workspace": self.workspace}
result = await self.db.query(sql, list(params.values()), True)
counts = {}
for doc in result:
counts[doc["status"]] = doc["count"]
@ -2397,7 +2399,7 @@ class PGDocStatusStorage(DocStatusStorage):
"""all documents with a specific status"""
sql = "select * from LIGHTRAG_DOC_STATUS where workspace=$1 and status=$2"
params = {"workspace": self.workspace, "status": status.value}
result = await self.db.query(sql, params, True)
result = await self.db.query(sql, list(params.values()), True)
docs_by_status = {}
for element in result:
@ -2451,7 +2453,7 @@ class PGDocStatusStorage(DocStatusStorage):
"""Get all documents with a specific track_id"""
sql = "select * from LIGHTRAG_DOC_STATUS where workspace=$1 and track_id=$2"
params = {"workspace": self.workspace, "track_id": track_id}
result = await self.db.query(sql, params, True)
result = await self.db.query(sql, list(params.values()), True)
docs_by_track_id = {}
for element in result:
@ -2551,7 +2553,7 @@ class PGDocStatusStorage(DocStatusStorage):
# Query for total count
count_sql = f"SELECT COUNT(*) as total FROM LIGHTRAG_DOC_STATUS {where_clause}"
count_result = await self.db.query(count_sql, params)
count_result = await self.db.query(count_sql, list(params.values()))
total_count = count_result["total"] if count_result else 0
# Query for paginated data
@ -2564,7 +2566,7 @@ class PGDocStatusStorage(DocStatusStorage):
params["limit"] = page_size
params["offset"] = offset
result = await self.db.query(data_sql, params, True)
result = await self.db.query(data_sql, list(params.values()), True)
# Convert to (doc_id, DocProcessingStatus) tuples
documents = []
@ -2621,7 +2623,7 @@ class PGDocStatusStorage(DocStatusStorage):
GROUP BY status
"""
params = {"workspace": self.workspace}
result = await self.db.query(sql, params, True)
result = await self.db.query(sql, list(params.values()), True)
counts = {}
total_count = 0
@ -3067,7 +3069,7 @@ class PGGraphStorage(BaseGraphStorage):
if readonly:
data = await self.db.query(
query,
params,
list(params.values()) if params else None,
multirows=True,
with_age=True,
graph_name=self.graph_name,