diff --git a/lightrag/kg/deprecated/chroma_impl.py b/lightrag/kg/deprecated/chroma_impl.py
index ebdd4593..75a7d4bf 100644
--- a/lightrag/kg/deprecated/chroma_impl.py
+++ b/lightrag/kg/deprecated/chroma_impl.py
@@ -164,9 +164,7 @@ class ChromaVectorDBStorage(BaseVectorStorage):
             logger.error(f"Error during ChromaDB upsert: {str(e)}")
             raise
 
-    async def query(
-        self, query: str, top_k: int, ids: list[str] | None = None
-    ) -> list[dict[str, Any]]:
+    async def query(self, query: str, top_k: int) -> list[dict[str, Any]]:
         try:
             embedding = await self.embedding_func(
                 [query], _priority=5
diff --git a/lightrag/kg/postgres_impl.py b/lightrag/kg/postgres_impl.py
index 03a26f54..5e4a4813 100644
--- a/lightrag/kg/postgres_impl.py
+++ b/lightrag/kg/postgres_impl.py
@@ -787,13 +787,13 @@ class PostgreSQLDB:
             FROM information_schema.columns
             WHERE table_name = $1 AND column_name = $2
             """
-
+            params = {
+                "table_name": migration["table"].lower(),
+                "column_name": migration["column"],
+            }
             column_info = await self.query(
                 check_column_sql,
-                {
-                    "table_name": migration["table"].lower(),
-                    "column_name": migration["column"],
-                },
+                list(params.values()),
             )
 
             if not column_info:
@@ -1035,10 +1035,8 @@ class PostgreSQLDB:
             WHERE table_name = $1
             AND table_schema = 'public'
             """
-
-            table_exists = await self.query(
-                check_table_sql, {"table_name": table_name.lower()}
-            )
+            params = {"table_name": table_name.lower()}
+            table_exists = await self.query(check_table_sql, list(params.values()))
 
             if not table_exists:
                 logger.info(f"Creating table {table_name}")
@@ -1121,7 +1119,8 @@ class PostgreSQLDB:
                 AND indexname = $1
                 """
 
-                existing = await self.query(check_sql, {"indexname": index["name"]})
+                params = {"indexname": index["name"]}
+                existing = await self.query(check_sql, list(params.values()))
 
                 if not existing:
                     logger.info(f"Creating pagination index: {index['description']}")
@@ -1217,7 +1216,7 @@ class PostgreSQLDB:
     async def query(
         self,
         sql: str,
-        params: dict[str, Any] | None = None,
+        params: list[Any] | None = None,
         multirows: bool = False,
         with_age: bool = False,
         graph_name: str | None = None,
@@ -1230,7 +1229,7 @@ class PostgreSQLDB:
         try:
             if params:
-                rows = await connection.fetch(sql, *params.values())
+                rows = await connection.fetch(sql, *params)
             else:
                 rows = await connection.fetch(sql)
 
@@ -1446,7 +1445,7 @@ class PGKVStorage(BaseKVStorage):
         params = {"workspace": self.workspace}
 
         try:
-            results = await self.db.query(sql, params, multirows=True)
+            results = await self.db.query(sql, list(params.values()), multirows=True)
 
             # Special handling for LLM cache to ensure compatibility with _get_cached_extraction_results
             if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE):
@@ -1540,7 +1539,7 @@ class PGKVStorage(BaseKVStorage):
         """Get data by id."""
         sql = SQL_TEMPLATES["get_by_id_" + self.namespace]
         params = {"workspace": self.workspace, "id": id}
-        response = await self.db.query(sql, params)
+        response = await self.db.query(sql, list(params.values()))
 
         if response and is_namespace(self.namespace, NameSpace.KV_STORE_TEXT_CHUNKS):
             # Parse llm_cache_list JSON string back to list
@@ -1620,7 +1619,7 @@ class PGKVStorage(BaseKVStorage):
             ids=",".join([f"'{id}'" for id in ids])
         )
         params = {"workspace": self.workspace}
-        results = await self.db.query(sql, params, multirows=True)
+        results = await self.db.query(sql, list(params.values()), multirows=True)
 
         if results and is_namespace(self.namespace, NameSpace.KV_STORE_TEXT_CHUNKS):
            # Parse llm_cache_list JSON string back to list for each result
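The thread running through the hunks above and below: `PostgreSQLDB.query` now takes its parameters as a positional list, matching asyncpg's `fetch(sql, *args)` signature, and every call site converts its dict with `list(params.values())`. A minimal sketch of the new convention, assuming an initialized `PostgreSQLDB` instance `db` (the wrapper function, SQL, and values here are illustrative, not code from the diff):

```python
# asyncpg binds $1, $2, ... positionally, so the list must be ordered the
# same way the placeholders appear in the SQL; list(params.values())
# preserves the dict's insertion order, which the call sites rely on.
async def fetch_doc_status(db, workspace: str, doc_id: str):
    sql = "SELECT * FROM LIGHTRAG_DOC_STATUS WHERE workspace=$1 AND id=$2"
    params = {"workspace": workspace, "id": doc_id}  # insertion order == $1, $2
    return await db.query(sql, list(params.values()))
```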
@@ -1708,7 +1707,7 @@ class PGKVStorage(BaseKVStorage):
         )
         params = {"workspace": self.workspace}
         try:
-            res = await self.db.query(sql, params, multirows=True)
+            res = await self.db.query(sql, list(params.values()), multirows=True)
             if res:
                 exist_keys = [key["id"] for key in res]
             else:
@@ -2023,7 +2022,7 @@ class PGVectorStorage(BaseVectorStorage):
             "closer_than_threshold": 1 - self.cosine_better_than_threshold,
             "top_k": top_k,
         }
-        results = await self.db.query(sql, params=params, multirows=True)
+        results = await self.db.query(sql, params=list(params.values()), multirows=True)
         return results
 
     async def index_done_callback(self) -> None:
@@ -2120,7 +2119,7 @@ class PGVectorStorage(BaseVectorStorage):
         params = {"workspace": self.workspace, "id": id}
 
         try:
-            result = await self.db.query(query, params)
+            result = await self.db.query(query, list(params.values()))
             if result:
                 return dict(result)
             return None
@@ -2154,7 +2153,7 @@ class PGVectorStorage(BaseVectorStorage):
         params = {"workspace": self.workspace}
 
         try:
-            results = await self.db.query(query, params, multirows=True)
+            results = await self.db.query(query, list(params.values()), multirows=True)
             return [dict(record) for record in results]
         except Exception as e:
             logger.error(
@@ -2187,7 +2186,7 @@ class PGVectorStorage(BaseVectorStorage):
         params = {"workspace": self.workspace}
 
         try:
-            results = await self.db.query(query, params, multirows=True)
+            results = await self.db.query(query, list(params.values()), multirows=True)
 
             vectors_dict = {}
             for result in results:
@@ -2274,7 +2273,7 @@ class PGDocStatusStorage(DocStatusStorage):
         )
         params = {"workspace": self.workspace}
         try:
-            res = await self.db.query(sql, params, multirows=True)
+            res = await self.db.query(sql, list(params.values()), multirows=True)
             if res:
                 exist_keys = [key["id"] for key in res]
             else:
@@ -2292,7 +2291,7 @@ class PGDocStatusStorage(DocStatusStorage):
     async def get_by_id(self, id: str) -> Union[dict[str, Any], None]:
         sql = "select * from LIGHTRAG_DOC_STATUS where workspace=$1 and id=$2"
         params = {"workspace": self.workspace, "id": id}
-        result = await self.db.query(sql, params, True)
+        result = await self.db.query(sql, list(params.values()), True)
         if result is None or result == []:
             return None
         else:
@@ -2338,7 +2337,7 @@ class PGDocStatusStorage(DocStatusStorage):
         sql = "SELECT * FROM LIGHTRAG_DOC_STATUS WHERE workspace=$1 AND id = ANY($2)"
         params = {"workspace": self.workspace, "ids": ids}
 
-        results = await self.db.query(sql, params, True)
+        results = await self.db.query(sql, list(params.values()), True)
 
         if not results:
             return []
@@ -2389,7 +2388,8 @@ class PGDocStatusStorage(DocStatusStorage):
                          FROM LIGHTRAG_DOC_STATUS
                         where workspace=$1 GROUP BY STATUS
                  """
-        result = await self.db.query(sql, {"workspace": self.workspace}, True)
+        params = {"workspace": self.workspace}
+        result = await self.db.query(sql, list(params.values()), True)
         counts = {}
         for doc in result:
             counts[doc["status"]] = doc["count"]
@@ -2401,7 +2401,7 @@ class PGDocStatusStorage(DocStatusStorage):
         """all documents with a specific status"""
        sql = "select * from LIGHTRAG_DOC_STATUS where workspace=$1 and status=$2"
         params = {"workspace": self.workspace, "status": status.value}
-        result = await self.db.query(sql, params, True)
+        result = await self.db.query(sql, list(params.values()), True)
 
         docs_by_status = {}
         for element in result:
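One subtlety in the `get_by_ids` hunk above: asyncpg adapts a Python list to a Postgres array, so the whole `ids` list binds to the single `ANY($2)` placeholder as one positional argument rather than being unpacked. A hedged sketch of that pattern (`db` is an initialized `PostgreSQLDB`; the wrapper function is illustrative):

```python
# list(params.values()) == [workspace, ids] -> $1 = workspace, $2 = ids;
# the list is ONE positional argument, encoded as a Postgres array.
async def fetch_by_ids(db, workspace: str, ids: list[str]):
    sql = "SELECT * FROM LIGHTRAG_DOC_STATUS WHERE workspace=$1 AND id = ANY($2)"
    params = {"workspace": workspace, "ids": ids}
    return await db.query(sql, list(params.values()), True)
```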
@@ -2455,7 +2455,7 @@ class PGDocStatusStorage(DocStatusStorage):
         """Get all documents with a specific track_id"""
         sql = "select * from LIGHTRAG_DOC_STATUS where workspace=$1 and track_id=$2"
         params = {"workspace": self.workspace, "track_id": track_id}
-        result = await self.db.query(sql, params, True)
+        result = await self.db.query(sql, list(params.values()), True)
 
         docs_by_track_id = {}
         for element in result:
@@ -2555,7 +2555,7 @@ class PGDocStatusStorage(DocStatusStorage):
         # Query for total count
         count_sql = f"SELECT COUNT(*) as total FROM LIGHTRAG_DOC_STATUS {where_clause}"
 
-        count_result = await self.db.query(count_sql, params)
+        count_result = await self.db.query(count_sql, list(params.values()))
         total_count = count_result["total"] if count_result else 0
 
         # Query for paginated data
@@ -2568,7 +2568,7 @@ class PGDocStatusStorage(DocStatusStorage):
         params["limit"] = page_size
         params["offset"] = offset
 
-        result = await self.db.query(data_sql, params, True)
+        result = await self.db.query(data_sql, list(params.values()), True)
 
         # Convert to (doc_id, DocProcessingStatus) tuples
         documents = []
@@ -2625,7 +2625,7 @@ class PGDocStatusStorage(DocStatusStorage):
             GROUP BY status
         """
         params = {"workspace": self.workspace}
-        result = await self.db.query(sql, params, True)
+        result = await self.db.query(sql, list(params.values()), True)
 
         counts = {}
         total_count = 0
@@ -3071,7 +3071,7 @@ class PGGraphStorage(BaseGraphStorage):
             if readonly:
                 data = await self.db.query(
                     query,
-                    params,
+                    list(params.values()) if params else None,
                     multirows=True,
                     with_age=True,
                     graph_name=self.graph_name,
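Because `list(params.values())` relies on dict insertion order (guaranteed since Python 3.7), the pagination hunks above only work if `limit` and `offset` are appended to `params` after the filter values, with the SQL numbering its placeholders in the same order. A sketch of that constraint (`db` is an initialized `PostgreSQLDB`; the wrapper function and SQL are illustrative, not the exact code):

```python
# Values travel positionally: filters first, then LIMIT $n, OFFSET $n+1.
async def page_docs(db, workspace: str, page_size: int, offset: int):
    params = {"workspace": workspace}  # becomes $1
    data_sql = (
        "SELECT * FROM LIGHTRAG_DOC_STATUS WHERE workspace=$1 "
        "ORDER BY updated_at DESC LIMIT $2 OFFSET $3"
    )
    params["limit"] = page_size        # becomes $2
    params["offset"] = offset          # becomes $3
    return await db.query(data_sql, list(params.values()), True)
```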
@@ -3102,114 +3102,92 @@ class PGGraphStorage(BaseGraphStorage):
         return result
 
     async def has_node(self, node_id: str) -> bool:
-        entity_name_label = self._normalize_node_id(node_id)
+        query = f"""
+        SELECT EXISTS (
+            SELECT 1
+            FROM {self.graph_name}.base
+            WHERE ag_catalog.agtype_access_operator(
+                VARIADIC ARRAY[properties, '"entity_id"'::agtype]
+            ) = (to_json($1::text)::text)::agtype
+            LIMIT 1
+        ) AS node_exists;
+        """
 
-        query = """SELECT * FROM cypher('%s', $$
-            MATCH (n:base {entity_id: "%s"})
-            RETURN count(n) > 0 AS node_exists
-        $$) AS (node_exists bool)""" % (self.graph_name, entity_name_label)
-
-        single_result = (await self._query(query))[0]
-
-        return single_result["node_exists"]
+        params = {"node_id": node_id}
+        row = (await self._query(query, params=params))[0]
+        return bool(row["node_exists"])
 
     async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
-        src_label = self._normalize_node_id(source_node_id)
-        tgt_label = self._normalize_node_id(target_node_id)
-
-        query = """SELECT * FROM cypher('%s', $$
-            MATCH (a:base {entity_id: "%s"})-[r]-(b:base {entity_id: "%s"})
-            RETURN COUNT(r) > 0 AS edge_exists
-        $$) AS (edge_exists bool)""" % (
-            self.graph_name,
-            src_label,
-            tgt_label,
-        )
-
-        single_result = (await self._query(query))[0]
-
-        return single_result["edge_exists"]
+        query = f"""
+        WITH a AS (
+            SELECT id AS vid
+            FROM {self.graph_name}.base
+            WHERE ag_catalog.agtype_access_operator(
+                VARIADIC ARRAY[properties, '"entity_id"'::agtype]
+            ) = (to_json($1::text)::text)::agtype
+        ),
+        b AS (
+            SELECT id AS vid
+            FROM {self.graph_name}.base
+            WHERE ag_catalog.agtype_access_operator(
+                VARIADIC ARRAY[properties, '"entity_id"'::agtype]
+            ) = (to_json($2::text)::text)::agtype
+        )
+        SELECT EXISTS (
+            SELECT 1
+            FROM {self.graph_name}."DIRECTED" d
+            JOIN a ON d.start_id = a.vid
+            JOIN b ON d.end_id = b.vid
+            LIMIT 1
+        )
+        OR EXISTS (
+            SELECT 1
+            FROM {self.graph_name}."DIRECTED" d
+            JOIN a ON d.end_id = a.vid
+            JOIN b ON d.start_id = b.vid
+            LIMIT 1
+        ) AS edge_exists;
+        """
+        params = {
+            "source_node_id": source_node_id,
+            "target_node_id": target_node_id,
+        }
+        row = (await self._query(query, params=params))[0]
+        return bool(row["edge_exists"])
 
     async def get_node(self, node_id: str) -> dict[str, str] | None:
         """Get node by its label identifier, return only node properties"""
         label = self._normalize_node_id(node_id)
-        query = """SELECT * FROM cypher('%s', $$
-            MATCH (n:base {entity_id: "%s"})
-            RETURN n
-        $$) AS (n agtype)""" % (self.graph_name, label)
-        record = await self._query(query)
-        if record:
-            node = record[0]
-            node_dict = node["n"]["properties"]
-            # Process string result, parse it to JSON dictionary
-            if isinstance(node_dict, str):
-                try:
-                    node_dict = json.loads(node_dict)
-                except json.JSONDecodeError:
-                    logger.warning(
-                        f"[{self.workspace}] Failed to parse node string: {node_dict}"
-                    )
-
-            return node_dict
+        result = await self.get_nodes_batch(node_ids=[label])
+        if result and label in result:
+            return result[label]
         return None
 
     async def node_degree(self, node_id: str) -> int:
         label = self._normalize_node_id(node_id)
-        query = """SELECT * FROM cypher('%s', $$
-            MATCH (n:base {entity_id: "%s"})-[r]-()
-            RETURN count(r) AS total_edge_count
-        $$) AS (total_edge_count integer)""" % (self.graph_name, label)
-        record = (await self._query(query))[0]
-        if record:
-            edge_count = int(record["total_edge_count"])
-            return edge_count
+        result = await self.node_degrees_batch(node_ids=[label])
+        if result and label in result:
+            return result[label]
+        return 0
 
     async def edge_degree(self, src_id: str, tgt_id: str) -> int:
-        src_degree = await self.node_degree(src_id)
-        trg_degree = await self.node_degree(tgt_id)
-
-        # Convert None to 0 for addition
-        src_degree = 0 if src_degree is None else src_degree
-        trg_degree = 0 if trg_degree is None else trg_degree
-
-        degrees = int(src_degree) + int(trg_degree)
-
-        return degrees
+        result = await self.edge_degrees_batch(edges=[(src_id, tgt_id)])
+        if result and (src_id, tgt_id) in result:
+            return result[(src_id, tgt_id)]
+        return 0
 
     async def get_edge(
         self, source_node_id: str, target_node_id: str
     ) -> dict[str, str] | None:
         """Get edge properties between two nodes"""
         src_label = self._normalize_node_id(source_node_id)
         tgt_label = self._normalize_node_id(target_node_id)
-        query = """SELECT * FROM cypher('%s', $$
-            MATCH (a:base {entity_id: "%s"})-[r]-(b:base {entity_id: "%s"})
-            RETURN properties(r) as edge_properties
-            LIMIT 1
-        $$) AS (edge_properties agtype)""" % (
-            self.graph_name,
-            src_label,
-            tgt_label,
-        )
-        record = await self._query(query)
-        if record and record[0] and record[0]["edge_properties"]:
-            result = record[0]["edge_properties"]
-
-            # Process string result, parse it to JSON dictionary
-            if isinstance(result, str):
-                try:
-                    result = json.loads(result)
-                except json.JSONDecodeError:
-                    logger.warning(
-                        f"[{self.workspace}] Failed to parse edge string: {result}"
-                    )
-
-            return result
+        result = await self.get_edges_batch([{"src": src_label, "tgt": tgt_label}])
+        if result and (src_label, tgt_label) in result:
+            return result[(src_label, tgt_label)]
+        return None
 
     async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None:
         """
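The graph hunk above replaces string-interpolated `cypher()` calls with parameterized SQL against the tables Apache AGE maintains for the graph (`{graph_name}.base` for vertices, `"DIRECTED"` for edges), and routes single-item reads through the batch methods. A standalone sketch of the existence-check pattern, runnable against AGE with asyncpg (the connection setup and function name are illustrative; the query mirrors the diff):

```python
import asyncpg

async def has_node(dsn: str, graph_name: str, node_id: str) -> bool:
    # to_json($1::text) produces a JSON-quoted string (e.g. '"node-1"'),
    # which casts to the same agtype string that agtype_access_operator
    # returns for properties -> 'entity_id', so the comparison is exact
    # and the id never has to be interpolated into the query text.
    conn = await asyncpg.connect(dsn)
    try:
        await conn.execute("LOAD 'age'; SET search_path = ag_catalog, public;")
        row = await conn.fetchrow(
            f"""
            SELECT EXISTS (
                SELECT 1
                FROM {graph_name}.base
                WHERE ag_catalog.agtype_access_operator(
                    VARIADIC ARRAY[properties, '"entity_id"'::agtype]
                ) = (to_json($1::text)::text)::agtype
            ) AS node_exists
            """,
            node_id,
        )
        return bool(row["node_exists"])
    finally:
        await conn.close()
```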