feat: refactor parameter handling in database queries to use lists for improved consistency
This commit is contained in:
parent
015e9ae3dd
commit
9804a1885b
1 changed files with 30 additions and 28 deletions
|
|
@ -787,13 +787,13 @@ class PostgreSQLDB:
|
|||
FROM information_schema.columns
|
||||
WHERE table_name = $1 AND column_name = $2
|
||||
"""
|
||||
|
||||
column_info = await self.query(
|
||||
check_column_sql,
|
||||
{
|
||||
params = {
|
||||
"table_name": migration["table"].lower(),
|
||||
"column_name": migration["column"],
|
||||
},
|
||||
}
|
||||
column_info = await self.query(
|
||||
check_column_sql,
|
||||
list(params.values()),
|
||||
)
|
||||
|
||||
if not column_info:
|
||||
|
|
@ -1035,9 +1035,9 @@ class PostgreSQLDB:
|
|||
WHERE table_name = $1
|
||||
AND table_schema = 'public'
|
||||
"""
|
||||
|
||||
params = {"table_name": table_name.lower()}
|
||||
table_exists = await self.query(
|
||||
check_table_sql, {"table_name": table_name.lower()}
|
||||
check_table_sql, list(params.values())
|
||||
)
|
||||
|
||||
if not table_exists:
|
||||
|
|
@ -1121,7 +1121,8 @@ class PostgreSQLDB:
|
|||
AND indexname = $1
|
||||
"""
|
||||
|
||||
existing = await self.query(check_sql, {"indexname": index["name"]})
|
||||
params = {"indexname": index["name"]}
|
||||
existing = await self.query(check_sql, list(params.values()))
|
||||
|
||||
if not existing:
|
||||
logger.info(f"Creating pagination index: {index['description']}")
|
||||
|
|
@ -1217,7 +1218,7 @@ class PostgreSQLDB:
|
|||
async def query(
|
||||
self,
|
||||
sql: str,
|
||||
params: dict[str, Any] | None = None,
|
||||
params: list[Any] | None = None,
|
||||
multirows: bool = False,
|
||||
with_age: bool = False,
|
||||
graph_name: str | None = None,
|
||||
|
|
@ -1230,7 +1231,7 @@ class PostgreSQLDB:
|
|||
|
||||
try:
|
||||
if params:
|
||||
rows = await connection.fetch(sql, *params.values())
|
||||
rows = await connection.fetch(sql, *params)
|
||||
else:
|
||||
rows = await connection.fetch(sql)
|
||||
|
||||
|
|
@ -1446,7 +1447,7 @@ class PGKVStorage(BaseKVStorage):
|
|||
params = {"workspace": self.workspace}
|
||||
|
||||
try:
|
||||
results = await self.db.query(sql, params, multirows=True)
|
||||
results = await self.db.query(sql, list(params.values()), multirows=True)
|
||||
|
||||
# Special handling for LLM cache to ensure compatibility with _get_cached_extraction_results
|
||||
if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE):
|
||||
|
|
@ -1540,7 +1541,7 @@ class PGKVStorage(BaseKVStorage):
|
|||
"""Get data by id."""
|
||||
sql = SQL_TEMPLATES["get_by_id_" + self.namespace]
|
||||
params = {"workspace": self.workspace, "id": id}
|
||||
response = await self.db.query(sql, params)
|
||||
response = await self.db.query(sql, list(params.values()))
|
||||
|
||||
if response and is_namespace(self.namespace, NameSpace.KV_STORE_TEXT_CHUNKS):
|
||||
# Parse llm_cache_list JSON string back to list
|
||||
|
|
@ -1620,7 +1621,7 @@ class PGKVStorage(BaseKVStorage):
|
|||
ids=",".join([f"'{id}'" for id in ids])
|
||||
)
|
||||
params = {"workspace": self.workspace}
|
||||
results = await self.db.query(sql, params, multirows=True)
|
||||
results = await self.db.query(sql, list(params.values()), multirows=True)
|
||||
|
||||
if results and is_namespace(self.namespace, NameSpace.KV_STORE_TEXT_CHUNKS):
|
||||
# Parse llm_cache_list JSON string back to list for each result
|
||||
|
|
@ -1708,7 +1709,7 @@ class PGKVStorage(BaseKVStorage):
|
|||
)
|
||||
params = {"workspace": self.workspace}
|
||||
try:
|
||||
res = await self.db.query(sql, params, multirows=True)
|
||||
res = await self.db.query(sql, list(params.values()), multirows=True)
|
||||
if res:
|
||||
exist_keys = [key["id"] for key in res]
|
||||
else:
|
||||
|
|
@ -2019,7 +2020,7 @@ class PGVectorStorage(BaseVectorStorage):
|
|||
"closer_than_threshold": 1 - self.cosine_better_than_threshold,
|
||||
"top_k": top_k,
|
||||
}
|
||||
results = await self.db.query(sql, params=params, multirows=True)
|
||||
results = await self.db.query(sql, params=list(params.values()), multirows=True)
|
||||
return results
|
||||
|
||||
async def index_done_callback(self) -> None:
|
||||
|
|
@ -2116,7 +2117,7 @@ class PGVectorStorage(BaseVectorStorage):
|
|||
params = {"workspace": self.workspace, "id": id}
|
||||
|
||||
try:
|
||||
result = await self.db.query(query, params)
|
||||
result = await self.db.query(query, list(params.values()))
|
||||
if result:
|
||||
return dict(result)
|
||||
return None
|
||||
|
|
@ -2150,7 +2151,7 @@ class PGVectorStorage(BaseVectorStorage):
|
|||
params = {"workspace": self.workspace}
|
||||
|
||||
try:
|
||||
results = await self.db.query(query, params, multirows=True)
|
||||
results = await self.db.query(query, list(params.values()), multirows=True)
|
||||
return [dict(record) for record in results]
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
|
|
@ -2183,7 +2184,7 @@ class PGVectorStorage(BaseVectorStorage):
|
|||
params = {"workspace": self.workspace}
|
||||
|
||||
try:
|
||||
results = await self.db.query(query, params, multirows=True)
|
||||
results = await self.db.query(query, list(params.values()), multirows=True)
|
||||
vectors_dict = {}
|
||||
|
||||
for result in results:
|
||||
|
|
@ -2270,7 +2271,7 @@ class PGDocStatusStorage(DocStatusStorage):
|
|||
)
|
||||
params = {"workspace": self.workspace}
|
||||
try:
|
||||
res = await self.db.query(sql, params, multirows=True)
|
||||
res = await self.db.query(sql, list(params.values()), multirows=True)
|
||||
if res:
|
||||
exist_keys = [key["id"] for key in res]
|
||||
else:
|
||||
|
|
@ -2288,7 +2289,7 @@ class PGDocStatusStorage(DocStatusStorage):
|
|||
async def get_by_id(self, id: str) -> Union[dict[str, Any], None]:
|
||||
sql = "select * from LIGHTRAG_DOC_STATUS where workspace=$1 and id=$2"
|
||||
params = {"workspace": self.workspace, "id": id}
|
||||
result = await self.db.query(sql, params, True)
|
||||
result = await self.db.query(sql, list(params.values()), True)
|
||||
if result is None or result == []:
|
||||
return None
|
||||
else:
|
||||
|
|
@ -2334,7 +2335,7 @@ class PGDocStatusStorage(DocStatusStorage):
|
|||
sql = "SELECT * FROM LIGHTRAG_DOC_STATUS WHERE workspace=$1 AND id = ANY($2)"
|
||||
params = {"workspace": self.workspace, "ids": ids}
|
||||
|
||||
results = await self.db.query(sql, params, True)
|
||||
results = await self.db.query(sql, list(params.values()), True)
|
||||
|
||||
if not results:
|
||||
return []
|
||||
|
|
@ -2385,7 +2386,8 @@ class PGDocStatusStorage(DocStatusStorage):
|
|||
FROM LIGHTRAG_DOC_STATUS
|
||||
where workspace=$1 GROUP BY STATUS
|
||||
"""
|
||||
result = await self.db.query(sql, {"workspace": self.workspace}, True)
|
||||
params = {"workspace": self.workspace}
|
||||
result = await self.db.query(sql, list(params.values()), True)
|
||||
counts = {}
|
||||
for doc in result:
|
||||
counts[doc["status"]] = doc["count"]
|
||||
|
|
@ -2397,7 +2399,7 @@ class PGDocStatusStorage(DocStatusStorage):
|
|||
"""all documents with a specific status"""
|
||||
sql = "select * from LIGHTRAG_DOC_STATUS where workspace=$1 and status=$2"
|
||||
params = {"workspace": self.workspace, "status": status.value}
|
||||
result = await self.db.query(sql, params, True)
|
||||
result = await self.db.query(sql, list(params.values()), True)
|
||||
|
||||
docs_by_status = {}
|
||||
for element in result:
|
||||
|
|
@ -2451,7 +2453,7 @@ class PGDocStatusStorage(DocStatusStorage):
|
|||
"""Get all documents with a specific track_id"""
|
||||
sql = "select * from LIGHTRAG_DOC_STATUS where workspace=$1 and track_id=$2"
|
||||
params = {"workspace": self.workspace, "track_id": track_id}
|
||||
result = await self.db.query(sql, params, True)
|
||||
result = await self.db.query(sql, list(params.values()), True)
|
||||
|
||||
docs_by_track_id = {}
|
||||
for element in result:
|
||||
|
|
@ -2551,7 +2553,7 @@ class PGDocStatusStorage(DocStatusStorage):
|
|||
|
||||
# Query for total count
|
||||
count_sql = f"SELECT COUNT(*) as total FROM LIGHTRAG_DOC_STATUS {where_clause}"
|
||||
count_result = await self.db.query(count_sql, params)
|
||||
count_result = await self.db.query(count_sql, list(params.values()))
|
||||
total_count = count_result["total"] if count_result else 0
|
||||
|
||||
# Query for paginated data
|
||||
|
|
@ -2564,7 +2566,7 @@ class PGDocStatusStorage(DocStatusStorage):
|
|||
params["limit"] = page_size
|
||||
params["offset"] = offset
|
||||
|
||||
result = await self.db.query(data_sql, params, True)
|
||||
result = await self.db.query(data_sql, list(params.values()), True)
|
||||
|
||||
# Convert to (doc_id, DocProcessingStatus) tuples
|
||||
documents = []
|
||||
|
|
@ -2621,7 +2623,7 @@ class PGDocStatusStorage(DocStatusStorage):
|
|||
GROUP BY status
|
||||
"""
|
||||
params = {"workspace": self.workspace}
|
||||
result = await self.db.query(sql, params, True)
|
||||
result = await self.db.query(sql, list(params.values()), True)
|
||||
|
||||
counts = {}
|
||||
total_count = 0
|
||||
|
|
@ -3067,7 +3069,7 @@ class PGGraphStorage(BaseGraphStorage):
|
|||
if readonly:
|
||||
data = await self.db.query(
|
||||
query,
|
||||
params,
|
||||
list(params.values()) if params else None,
|
||||
multirows=True,
|
||||
with_age=True,
|
||||
graph_name=self.graph_name,
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue