From 9804a1885b1c1bbf9f4afc3b59c88eb0f50f6fc3 Mon Sep 17 00:00:00 2001 From: Matt23-star Date: Thu, 28 Aug 2025 16:17:35 -0700 Subject: [PATCH] feat: refactor parameter handling in database queries to use lists for improved consistency --- lightrag/kg/postgres_impl.py | 58 +++++++++++++++++++----------------- 1 file changed, 30 insertions(+), 28 deletions(-) diff --git a/lightrag/kg/postgres_impl.py b/lightrag/kg/postgres_impl.py index b238b7a1..8848f0fd 100644 --- a/lightrag/kg/postgres_impl.py +++ b/lightrag/kg/postgres_impl.py @@ -787,13 +787,13 @@ class PostgreSQLDB: FROM information_schema.columns WHERE table_name = $1 AND column_name = $2 """ - - column_info = await self.query( - check_column_sql, - { + params = { "table_name": migration["table"].lower(), "column_name": migration["column"], - }, + } + column_info = await self.query( + check_column_sql, + list(params.values()), ) if not column_info: @@ -1035,9 +1035,9 @@ class PostgreSQLDB: WHERE table_name = $1 AND table_schema = 'public' """ - + params = {"table_name": table_name.lower()} table_exists = await self.query( - check_table_sql, {"table_name": table_name.lower()} + check_table_sql, list(params.values()) ) if not table_exists: @@ -1121,7 +1121,8 @@ class PostgreSQLDB: AND indexname = $1 """ - existing = await self.query(check_sql, {"indexname": index["name"]}) + params = {"indexname": index["name"]} + existing = await self.query(check_sql, list(params.values())) if not existing: logger.info(f"Creating pagination index: {index['description']}") @@ -1217,7 +1218,7 @@ class PostgreSQLDB: async def query( self, sql: str, - params: dict[str, Any] | None = None, + params: list[Any] | None = None, multirows: bool = False, with_age: bool = False, graph_name: str | None = None, @@ -1230,7 +1231,7 @@ class PostgreSQLDB: try: if params: - rows = await connection.fetch(sql, *params.values()) + rows = await connection.fetch(sql, *params) else: rows = await connection.fetch(sql) @@ -1446,7 +1447,7 @@ class PGKVStorage(BaseKVStorage): params = {"workspace": self.workspace} try: - results = await self.db.query(sql, params, multirows=True) + results = await self.db.query(sql, list(params.values()), multirows=True) # Special handling for LLM cache to ensure compatibility with _get_cached_extraction_results if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE): @@ -1540,7 +1541,7 @@ class PGKVStorage(BaseKVStorage): """Get data by id.""" sql = SQL_TEMPLATES["get_by_id_" + self.namespace] params = {"workspace": self.workspace, "id": id} - response = await self.db.query(sql, params) + response = await self.db.query(sql, list(params.values())) if response and is_namespace(self.namespace, NameSpace.KV_STORE_TEXT_CHUNKS): # Parse llm_cache_list JSON string back to list @@ -1620,7 +1621,7 @@ class PGKVStorage(BaseKVStorage): ids=",".join([f"'{id}'" for id in ids]) ) params = {"workspace": self.workspace} - results = await self.db.query(sql, params, multirows=True) + results = await self.db.query(sql, list(params.values()), multirows=True) if results and is_namespace(self.namespace, NameSpace.KV_STORE_TEXT_CHUNKS): # Parse llm_cache_list JSON string back to list for each result @@ -1708,7 +1709,7 @@ class PGKVStorage(BaseKVStorage): ) params = {"workspace": self.workspace} try: - res = await self.db.query(sql, params, multirows=True) + res = await self.db.query(sql, list(params.values()), multirows=True) if res: exist_keys = [key["id"] for key in res] else: @@ -2019,7 +2020,7 @@ class PGVectorStorage(BaseVectorStorage): "closer_than_threshold": 1 - self.cosine_better_than_threshold, "top_k": top_k, } - results = await self.db.query(sql, params=params, multirows=True) + results = await self.db.query(sql, params=list(params.values()), multirows=True) return results async def index_done_callback(self) -> None: @@ -2116,7 +2117,7 @@ class PGVectorStorage(BaseVectorStorage): params = {"workspace": self.workspace, "id": id} try: - result = await self.db.query(query, params) + result = await self.db.query(query, list(params.values())) if result: return dict(result) return None @@ -2150,7 +2151,7 @@ class PGVectorStorage(BaseVectorStorage): params = {"workspace": self.workspace} try: - results = await self.db.query(query, params, multirows=True) + results = await self.db.query(query, list(params.values()), multirows=True) return [dict(record) for record in results] except Exception as e: logger.error( @@ -2183,7 +2184,7 @@ class PGVectorStorage(BaseVectorStorage): params = {"workspace": self.workspace} try: - results = await self.db.query(query, params, multirows=True) + results = await self.db.query(query, list(params.values()), multirows=True) vectors_dict = {} for result in results: @@ -2270,7 +2271,7 @@ class PGDocStatusStorage(DocStatusStorage): ) params = {"workspace": self.workspace} try: - res = await self.db.query(sql, params, multirows=True) + res = await self.db.query(sql, list(params.values()), multirows=True) if res: exist_keys = [key["id"] for key in res] else: @@ -2288,7 +2289,7 @@ class PGDocStatusStorage(DocStatusStorage): async def get_by_id(self, id: str) -> Union[dict[str, Any], None]: sql = "select * from LIGHTRAG_DOC_STATUS where workspace=$1 and id=$2" params = {"workspace": self.workspace, "id": id} - result = await self.db.query(sql, params, True) + result = await self.db.query(sql, list(params.values()), True) if result is None or result == []: return None else: @@ -2334,7 +2335,7 @@ class PGDocStatusStorage(DocStatusStorage): sql = "SELECT * FROM LIGHTRAG_DOC_STATUS WHERE workspace=$1 AND id = ANY($2)" params = {"workspace": self.workspace, "ids": ids} - results = await self.db.query(sql, params, True) + results = await self.db.query(sql, list(params.values()), True) if not results: return [] @@ -2385,7 +2386,8 @@ class PGDocStatusStorage(DocStatusStorage): FROM LIGHTRAG_DOC_STATUS where workspace=$1 GROUP BY STATUS """ - result = await self.db.query(sql, {"workspace": self.workspace}, True) + params = {"workspace": self.workspace} + result = await self.db.query(sql, list(params.values()), True) counts = {} for doc in result: counts[doc["status"]] = doc["count"] @@ -2397,7 +2399,7 @@ class PGDocStatusStorage(DocStatusStorage): """all documents with a specific status""" sql = "select * from LIGHTRAG_DOC_STATUS where workspace=$1 and status=$2" params = {"workspace": self.workspace, "status": status.value} - result = await self.db.query(sql, params, True) + result = await self.db.query(sql, list(params.values()), True) docs_by_status = {} for element in result: @@ -2451,7 +2453,7 @@ class PGDocStatusStorage(DocStatusStorage): """Get all documents with a specific track_id""" sql = "select * from LIGHTRAG_DOC_STATUS where workspace=$1 and track_id=$2" params = {"workspace": self.workspace, "track_id": track_id} - result = await self.db.query(sql, params, True) + result = await self.db.query(sql, list(params.values()), True) docs_by_track_id = {} for element in result: @@ -2551,7 +2553,7 @@ class PGDocStatusStorage(DocStatusStorage): # Query for total count count_sql = f"SELECT COUNT(*) as total FROM LIGHTRAG_DOC_STATUS {where_clause}" - count_result = await self.db.query(count_sql, params) + count_result = await self.db.query(count_sql, list(params.values())) total_count = count_result["total"] if count_result else 0 # Query for paginated data @@ -2564,7 +2566,7 @@ class PGDocStatusStorage(DocStatusStorage): params["limit"] = page_size params["offset"] = offset - result = await self.db.query(data_sql, params, True) + result = await self.db.query(data_sql, list(params.values()), True) # Convert to (doc_id, DocProcessingStatus) tuples documents = [] @@ -2621,7 +2623,7 @@ class PGDocStatusStorage(DocStatusStorage): GROUP BY status """ params = {"workspace": self.workspace} - result = await self.db.query(sql, params, True) + result = await self.db.query(sql, list(params.values()), True) counts = {} total_count = 0 @@ -3067,7 +3069,7 @@ class PGGraphStorage(BaseGraphStorage): if readonly: data = await self.db.query( query, - params, + list(params.values()) if params else None, multirows=True, with_age=True, graph_name=self.graph_name,