From 874ddda60543f5e25a3a72aaebbf628e5a9052b2 Mon Sep 17 00:00:00 2001
From: Matt23-star
Date: Wed, 20 Aug 2025 15:59:05 +0800
Subject: [PATCH 1/4] feat: remove unused parameter from query methods across
 multiple implementations

---
 lightrag/base.py                      |  2 +-
 lightrag/kg/deprecated/chroma_impl.py |  2 +-
 lightrag/kg/faiss_impl.py             |  2 +-
 lightrag/kg/milvus_impl.py            |  2 +-
 lightrag/kg/mongo_impl.py             |  2 +-
 lightrag/kg/nano_vector_db_impl.py    |  2 +-
 lightrag/kg/postgres_impl.py          | 93 ++++++---------------------
 lightrag/kg/qdrant_impl.py            |  2 +-
 lightrag/operate.py                   |  6 +-
 9 files changed, 29 insertions(+), 84 deletions(-)

diff --git a/lightrag/base.py b/lightrag/base.py
index 9ba34280..dacfbd90 100644
--- a/lightrag/base.py
+++ b/lightrag/base.py
@@ -219,7 +219,7 @@ class BaseVectorStorage(StorageNameSpace, ABC):
 
     @abstractmethod
     async def query(
-        self, query: str, top_k: int, ids: list[str] | None = None
+        self, query: str, top_k: int
     ) -> list[dict[str, Any]]:
         """Query the vector storage and retrieve top_k results."""
 
diff --git a/lightrag/kg/deprecated/chroma_impl.py b/lightrag/kg/deprecated/chroma_impl.py
index ebdd4593..a6c43504 100644
--- a/lightrag/kg/deprecated/chroma_impl.py
+++ b/lightrag/kg/deprecated/chroma_impl.py
@@ -165,7 +165,7 @@ class ChromaVectorDBStorage(BaseVectorStorage):
             raise
 
     async def query(
-        self, query: str, top_k: int, ids: list[str] | None = None
+        self, query: str, top_k: int
     ) -> list[dict[str, Any]]:
         try:
             embedding = await self.embedding_func(
diff --git a/lightrag/kg/faiss_impl.py b/lightrag/kg/faiss_impl.py
index 5098ebf7..5687834d 100644
--- a/lightrag/kg/faiss_impl.py
+++ b/lightrag/kg/faiss_impl.py
@@ -180,7 +180,7 @@ class FaissVectorDBStorage(BaseVectorStorage):
         return [m["__id__"] for m in list_data]
 
     async def query(
-        self, query: str, top_k: int, ids: list[str] | None = None
+        self, query: str, top_k: int
     ) -> list[dict[str, Any]]:
         """
         Search by a textual query; returns top_k results with their metadata + similarity distance.
diff --git a/lightrag/kg/milvus_impl.py b/lightrag/kg/milvus_impl.py
index 4d927353..6747bb2d 100644
--- a/lightrag/kg/milvus_impl.py
+++ b/lightrag/kg/milvus_impl.py
@@ -810,7 +810,7 @@ class MilvusVectorDBStorage(BaseVectorStorage):
         return results
 
     async def query(
-        self, query: str, top_k: int, ids: list[str] | None = None
+        self, query: str, top_k: int
     ) -> list[dict[str, Any]]:
         # Ensure collection is loaded before querying
         self._ensure_collection_loaded()
diff --git a/lightrag/kg/mongo_impl.py b/lightrag/kg/mongo_impl.py
index 8fa53c60..0c164bd2 100644
--- a/lightrag/kg/mongo_impl.py
+++ b/lightrag/kg/mongo_impl.py
@@ -1771,7 +1771,7 @@ class MongoVectorDBStorage(BaseVectorStorage):
         return list_data
 
     async def query(
-        self, query: str, top_k: int, ids: list[str] | None = None
+        self, query: str, top_k: int
     ) -> list[dict[str, Any]]:
         """Queries the vector database using Atlas Vector Search."""
         # Generate the embedding
diff --git a/lightrag/kg/nano_vector_db_impl.py b/lightrag/kg/nano_vector_db_impl.py
index 5bec06f4..19352a4a 100644
--- a/lightrag/kg/nano_vector_db_impl.py
+++ b/lightrag/kg/nano_vector_db_impl.py
@@ -137,7 +137,7 @@ class NanoVectorDBStorage(BaseVectorStorage):
         )
 
     async def query(
-        self, query: str, top_k: int, ids: list[str] | None = None
+        self, query: str, top_k: int
     ) -> list[dict[str, Any]]:
         # Execute embedding outside of lock to improve concurrency
         embedding = await self.embedding_func(
diff --git a/lightrag/kg/postgres_impl.py b/lightrag/kg/postgres_impl.py
index 3ab8bfb8..46e8e6e6 100644
--- a/lightrag/kg/postgres_impl.py
+++ b/lightrag/kg/postgres_impl.py
@@ -2005,7 +2005,7 @@ class PGVectorStorage(BaseVectorStorage):
 
     #################### query method ###############
     async def query(
-        self, query: str, top_k: int, ids: list[str] | None = None
+        self, query: str, top_k: int
     ) -> list[dict[str, Any]]:
         embeddings = await self.embedding_func(
             [query], _priority=5
@@ -2016,7 +2016,6 @@
         sql = SQL_TEMPLATES[self.namespace].format(embedding_string=embedding_string)
         params = {
             "workspace": self.workspace,
-            "doc_ids": ids,
             "closer_than_threshold": 1 - self.cosine_better_than_threshold,
             "top_k": top_k,
         }
@@ -4578,85 +4577,31 @@
                       update_time = EXCLUDED.update_time
     """,
     "relationships": """
-    WITH relevant_chunks AS (SELECT id as chunk_id
-                             FROM LIGHTRAG_VDB_CHUNKS
-                             WHERE $2
-        :: varchar [] IS NULL OR full_doc_id = ANY ($2:: varchar [])
-        )
-       , rc AS (
-        SELECT array_agg(chunk_id) AS chunk_arr
-        FROM relevant_chunks
-    ), cand AS (
-        SELECT
-            r.id, r.source_id AS src_id, r.target_id AS tgt_id, r.chunk_ids, r.create_time, r.content_vector <=> '[{embedding_string}]'::vector AS dist
-        FROM LIGHTRAG_VDB_RELATION r
-        WHERE r.workspace = $1
-        ORDER BY r.content_vector <=> '[{embedding_string}]'::vector
-        LIMIT ($4 * 50)
-    )
-    SELECT c.src_id,
-           c.tgt_id,
-           EXTRACT(EPOCH FROM c.create_time) ::BIGINT AS created_at
-    FROM cand c
-             JOIN rc ON TRUE
-    WHERE c.dist < $3
-      AND c.chunk_ids && (rc.chunk_arr::varchar[])
-    ORDER BY c.dist, c.id
-    LIMIT $4;
+    SELECT r.source_id as src_id, r.target_id as tgt_id,
+           EXTRACT(EPOCH FROM r.create_time)::BIGINT as created_at
+    FROM LIGHTRAG_VDB_RELATION r
+    WHERE r.workspace = $1
+      AND r.content_vector <=> '[{embedding_string}]'::vector < $2
+    ORDER BY r.content_vector <=> '[{embedding_string}]'::vector
+    LIMIT $3
    """,
     "entities": """
-    WITH relevant_chunks AS (SELECT id as chunk_id
-                             FROM LIGHTRAG_VDB_CHUNKS
-                             WHERE $2
-        :: varchar [] IS NULL OR full_doc_id = ANY ($2:: varchar [])
-        )
-       , rc AS (
-        SELECT array_agg(chunk_id) AS chunk_arr
-        FROM relevant_chunks
-    ), cand AS (
-        SELECT
-            e.id, e.entity_name, e.chunk_ids, e.create_time, e.content_vector <=> '[{embedding_string}]'::vector AS dist
+    SELECT e.entity_name,
+           EXTRACT(EPOCH FROM e.create_time)::BIGINT as created_at
     FROM LIGHTRAG_VDB_ENTITY e
     WHERE e.workspace = $1
+      AND e.content_vector <=> '[{embedding_string}]'::vector < $2
     ORDER BY e.content_vector <=> '[{embedding_string}]'::vector
-        LIMIT ($4 * 50)
-    )
-    SELECT c.entity_name,
-           EXTRACT(EPOCH FROM c.create_time) ::BIGINT AS created_at
-    FROM cand c
-             JOIN rc ON TRUE
-    WHERE c.dist < $3
-      AND c.chunk_ids && (rc.chunk_arr::varchar[])
-    ORDER BY c.dist, c.id
-    LIMIT $4;
+    LIMIT $3
     """,
     "chunks": """
-    WITH relevant_chunks AS (SELECT id as chunk_id
-                             FROM LIGHTRAG_VDB_CHUNKS
-                             WHERE $2
-        :: varchar [] IS NULL OR full_doc_id = ANY ($2:: varchar [])
-        )
-       , rc AS (
-        SELECT array_agg(chunk_id) AS chunk_arr
-        FROM relevant_chunks
-    ), cand AS (
-        SELECT
-            id, content, file_path, create_time, content_vector <=> '[{embedding_string}]'::vector AS dist
-        FROM LIGHTRAG_VDB_CHUNKS
-        WHERE workspace = $1
-        ORDER BY content_vector <=> '[{embedding_string}]'::vector
-        LIMIT ($4 * 50)
-    )
-    SELECT c.id,
-           c.content,
-           c.file_path,
-           EXTRACT(EPOCH FROM c.create_time) ::BIGINT AS created_at
-    FROM cand c
-             JOIN rc ON TRUE
-    WHERE c.dist < $3
-      AND c.id = ANY (rc.chunk_arr)
-    ORDER BY c.dist, c.id
-    LIMIT $4;
+    SELECT id, content, file_path,
+           EXTRACT(EPOCH FROM create_time)::BIGINT as created_at
+    FROM LIGHTRAG_VDB_CHUNKS
+    WHERE workspace = $1
+      AND content_vector <=> '[{embedding_string}]'::vector < $2
+    ORDER BY content_vector <=> '[{embedding_string}]'::vector
+    LIMIT $3
     """,
     # DROP tables
     "drop_specifiy_table_workspace": """
diff --git a/lightrag/kg/qdrant_impl.py b/lightrag/kg/qdrant_impl.py
index 4ece163c..e8565ac7 100644
--- a/lightrag/kg/qdrant_impl.py
+++ b/lightrag/kg/qdrant_impl.py
@@ -200,7 +200,7 @@ class QdrantVectorDBStorage(BaseVectorStorage):
         return results
 
     async def query(
-        self, query: str, top_k: int, ids: list[str] | None = None
+        self, query: str, top_k: int
     ) -> list[dict[str, Any]]:
         embedding = await self.embedding_func(
             [query], _priority=5
diff --git a/lightrag/operate.py b/lightrag/operate.py
index acb75f0f..0876e06c 100644
--- a/lightrag/operate.py
+++ b/lightrag/operate.py
@@ -2055,7 +2055,7 @@ async def _get_vector_context(
     # Use chunk_top_k if specified, otherwise fall back to top_k
     search_top_k = query_param.chunk_top_k or query_param.top_k
 
-    results = await chunks_vdb.query(query, top_k=search_top_k, ids=query_param.ids)
+    results = await chunks_vdb.query(query, top_k=search_top_k)
     if not results:
         return []
 
@@ -2599,7 +2599,7 @@ async def _get_node_data(
     )
 
     results = await entities_vdb.query(
-        query, top_k=query_param.top_k, ids=query_param.ids
+        query, top_k=query_param.top_k
     )
 
     if not len(results):
@@ -2875,7 +2875,7 @@ async def _get_edge_data(
     )
 
     results = await relationships_vdb.query(
-        keywords, top_k=query_param.top_k, ids=query_param.ids
+        keywords, top_k=query_param.top_k
     )
 
     if not len(results):
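
Note: after patch 1, every BaseVectorStorage backend exposes the same
two-argument query signature, and the per-request `ids` document filter is
gone (scoping is handled inside the storage layer instead). A minimal
illustrative sketch of a call site; the helper function below is hypothetical,
only BaseVectorStorage.query is from this series:

    from typing import Any

    from lightrag.base import BaseVectorStorage


    async def fetch_top_chunks(
        vdb: BaseVectorStorage, question: str, k: int = 5
    ) -> list[dict[str, Any]]:
        # The unified post-patch signature: query text plus top_k, nothing else.
        return await vdb.query(question, top_k=k)
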
From 9804a1885b1c1bbf9f4afc3b59c88eb0f50f6fc3 Mon Sep 17 00:00:00 2001
From: Matt23-star
Date: Thu, 28 Aug 2025 16:17:35 -0700
Subject: [PATCH 2/4] feat: refactor parameter handling in database queries to
 use lists for improved consistency

---
 lightrag/kg/postgres_impl.py | 58 +++++++++++++++++++-----------------
 1 file changed, 30 insertions(+), 28 deletions(-)

diff --git a/lightrag/kg/postgres_impl.py b/lightrag/kg/postgres_impl.py
index b238b7a1..8848f0fd 100644
--- a/lightrag/kg/postgres_impl.py
+++ b/lightrag/kg/postgres_impl.py
@@ -787,13 +787,13 @@ class PostgreSQLDB:
                 FROM information_schema.columns
                 WHERE table_name = $1 AND column_name = $2
             """
-
-            column_info = await self.query(
-                check_column_sql,
-                {
+            params = {
                     "table_name": migration["table"].lower(),
                     "column_name": migration["column"],
-                },
+                }
+            column_info = await self.query(
+                check_column_sql,
+                list(params.values()),
             )
 
             if not column_info:
@@ -1035,9 +1035,9 @@ class PostgreSQLDB:
                 WHERE table_name = $1
                 AND table_schema = 'public'
             """
-
+            params = {"table_name": table_name.lower()}
             table_exists = await self.query(
-                check_table_sql, {"table_name": table_name.lower()}
+                check_table_sql, list(params.values())
             )
 
             if not table_exists:
@@ -1121,7 +1121,8 @@ class PostgreSQLDB:
             AND indexname = $1
             """
 
-            existing = await self.query(check_sql, {"indexname": index["name"]})
+            params = {"indexname": index["name"]}
+            existing = await self.query(check_sql, list(params.values()))
 
             if not existing:
                 logger.info(f"Creating pagination index: {index['description']}")
@@ -1217,7 +1218,7 @@ class PostgreSQLDB:
     async def query(
         self,
         sql: str,
-        params: dict[str, Any] | None = None,
+        params: list[Any] | None = None,
         multirows: bool = False,
         with_age: bool = False,
         graph_name: str | None = None,
@@ -1230,7 +1231,7 @@ class PostgreSQLDB:
 
         try:
             if params:
-                rows = await connection.fetch(sql, *params.values())
+                rows = await connection.fetch(sql, *params)
             else:
                 rows = await connection.fetch(sql)
 
@@ -1446,7 +1447,7 @@ class PGKVStorage(BaseKVStorage):
         params = {"workspace": self.workspace}
 
         try:
-            results = await self.db.query(sql, params, multirows=True)
+            results = await self.db.query(sql, list(params.values()), multirows=True)
 
             # Special handling for LLM cache to ensure compatibility with _get_cached_extraction_results
             if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE):
@@ -1540,7 +1541,7 @@ class PGKVStorage(BaseKVStorage):
         """Get data by id."""
         sql = SQL_TEMPLATES["get_by_id_" + self.namespace]
         params = {"workspace": self.workspace, "id": id}
-        response = await self.db.query(sql, params)
+        response = await self.db.query(sql, list(params.values()))
 
         if response and is_namespace(self.namespace, NameSpace.KV_STORE_TEXT_CHUNKS):
             # Parse llm_cache_list JSON string back to list
@@ -1620,7 +1621,7 @@ class PGKVStorage(BaseKVStorage):
             ids=",".join([f"'{id}'" for id in ids])
         )
         params = {"workspace": self.workspace}
-        results = await self.db.query(sql, params, multirows=True)
+        results = await self.db.query(sql, list(params.values()), multirows=True)
 
         if results and is_namespace(self.namespace, NameSpace.KV_STORE_TEXT_CHUNKS):
             # Parse llm_cache_list JSON string back to list for each result
@@ -1708,7 +1709,7 @@ class PGKVStorage(BaseKVStorage):
         )
         params = {"workspace": self.workspace}
         try:
-            res = await self.db.query(sql, params, multirows=True)
+            res = await self.db.query(sql, list(params.values()), multirows=True)
             if res:
                 exist_keys = [key["id"] for key in res]
             else:
@@ -2019,7 +2020,7 @@ class PGVectorStorage(BaseVectorStorage):
             "closer_than_threshold": 1 - self.cosine_better_than_threshold,
             "top_k": top_k,
         }
-        results = await self.db.query(sql, params=params, multirows=True)
+        results = await self.db.query(sql, params=list(params.values()), multirows=True)
         return results
 
     async def index_done_callback(self) -> None:
@@ -2116,7 +2117,7 @@ class PGVectorStorage(BaseVectorStorage):
         params = {"workspace": self.workspace, "id": id}
 
         try:
-            result = await self.db.query(query, params)
+            result = await self.db.query(query, list(params.values()))
             if result:
                 return dict(result)
             return None
@@ -2150,7 +2151,7 @@ class PGVectorStorage(BaseVectorStorage):
         params = {"workspace": self.workspace}
 
         try:
-            results = await self.db.query(query, params, multirows=True)
+            results = await self.db.query(query, list(params.values()), multirows=True)
             return [dict(record) for record in results]
         except Exception as e:
             logger.error(
@@ -2183,7 +2184,7 @@ class PGVectorStorage(BaseVectorStorage):
         params = {"workspace": self.workspace}
 
         try:
-            results = await self.db.query(query, params, multirows=True)
+            results = await self.db.query(query, list(params.values()), multirows=True)
 
             vectors_dict = {}
             for result in results:
@@ -2270,7 +2271,7 @@ class PGDocStatusStorage(DocStatusStorage):
         )
         params = {"workspace": self.workspace}
         try:
-            res = await self.db.query(sql, params, multirows=True)
+            res = await self.db.query(sql, list(params.values()), multirows=True)
             if res:
                 exist_keys = [key["id"] for key in res]
             else:
@@ -2288,7 +2289,7 @@ class PGDocStatusStorage(DocStatusStorage):
     async def get_by_id(self, id: str) -> Union[dict[str, Any], None]:
         sql = "select * from LIGHTRAG_DOC_STATUS where workspace=$1 and id=$2"
         params = {"workspace": self.workspace, "id": id}
-        result = await self.db.query(sql, params, True)
+        result = await self.db.query(sql, list(params.values()), True)
         if result is None or result == []:
             return None
         else:
@@ -2334,7 +2335,7 @@ class PGDocStatusStorage(DocStatusStorage):
 
         sql = "SELECT * FROM LIGHTRAG_DOC_STATUS WHERE workspace=$1 AND id = ANY($2)"
         params = {"workspace": self.workspace, "ids": ids}
-        results = await self.db.query(sql, params, True)
+        results = await self.db.query(sql, list(params.values()), True)
 
         if not results:
             return []
@@ -2385,7 +2386,8 @@ class PGDocStatusStorage(DocStatusStorage):
                  FROM LIGHTRAG_DOC_STATUS
                  where workspace=$1 GROUP BY STATUS
                  """
-        result = await self.db.query(sql, {"workspace": self.workspace}, True)
+        params = {"workspace": self.workspace}
+        result = await self.db.query(sql, list(params.values()), True)
         counts = {}
         for doc in result:
             counts[doc["status"]] = doc["count"]
@@ -2397,7 +2399,7 @@ class PGDocStatusStorage(DocStatusStorage):
         """all documents with a specific status"""
         sql = "select * from LIGHTRAG_DOC_STATUS where workspace=$1 and status=$2"
         params = {"workspace": self.workspace, "status": status.value}
-        result = await self.db.query(sql, params, True)
+        result = await self.db.query(sql, list(params.values()), True)
         docs_by_status = {}
 
         for element in result:
@@ -2451,7 +2453,7 @@ class PGDocStatusStorage(DocStatusStorage):
         """Get all documents with a specific track_id"""
         sql = "select * from LIGHTRAG_DOC_STATUS where workspace=$1 and track_id=$2"
         params = {"workspace": self.workspace, "track_id": track_id}
-        result = await self.db.query(sql, params, True)
+        result = await self.db.query(sql, list(params.values()), True)
 
         docs_by_track_id = {}
         for element in result:
@@ -2551,7 +2553,7 @@ class PGDocStatusStorage(DocStatusStorage):
 
         # Query for total count
         count_sql = f"SELECT COUNT(*) as total FROM LIGHTRAG_DOC_STATUS {where_clause}"
-        count_result = await self.db.query(count_sql, params)
+        count_result = await self.db.query(count_sql, list(params.values()))
         total_count = count_result["total"] if count_result else 0
 
         # Query for paginated data
@@ -2564,7 +2566,7 @@ class PGDocStatusStorage(DocStatusStorage):
         params["limit"] = page_size
         params["offset"] = offset
 
-        result = await self.db.query(data_sql, params, True)
+        result = await self.db.query(data_sql, list(params.values()), True)
 
         # Convert to (doc_id, DocProcessingStatus) tuples
         documents = []
@@ -2621,7 +2623,7 @@ class PGDocStatusStorage(DocStatusStorage):
             GROUP BY status
         """
         params = {"workspace": self.workspace}
-        result = await self.db.query(sql, params, True)
+        result = await self.db.query(sql, list(params.values()), True)
 
         counts = {}
         total_count = 0
@@ -3067,7 +3069,7 @@ class PGGraphStorage(BaseGraphStorage):
         if readonly:
             data = await self.db.query(
                 query,
-                params,
+                list(params.values()) if params else None,
                 multirows=True,
                 with_age=True,
                 graph_name=self.graph_name,
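
Note: the convention patch 2 standardizes is positional placeholders ($1, $2,
...) in the SQL, bound from a dict flattened with list(params.values()). That
only works because the dict's insertion order matches the placeholder order
(guaranteed since Python 3.7). A hedged sketch of the pattern in isolation;
the `db` handle stands in for this PR's PostgreSQLDB and is assumed:

    async def load_doc_status(db, workspace: str, doc_id: str) -> list[dict]:
        # Insertion order is the contract: "workspace" binds $1, "id" binds $2.
        params = {"workspace": workspace, "id": doc_id}
        sql = "SELECT * FROM LIGHTRAG_DOC_STATUS WHERE workspace=$1 AND id=$2"
        return await db.query(sql, list(params.values()), multirows=True)

Reordering the dict keys without reordering the placeholders would silently
bind the wrong values, which is the main thing to watch when extending these
queries.
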
From aa1ef3f05344169aacf9c0082b56ecc829a573ba Mon Sep 17 00:00:00 2001
From: Matt23-star
Date: Thu, 28 Aug 2025 16:18:15 -0700
Subject: [PATCH 3/4] feat: optimize database query methods for improved
 performance and readability

---
 lightrag/kg/postgres_impl.py | 145 +++++++++++++++--------------------
 1 file changed, 62 insertions(+), 83 deletions(-)

diff --git a/lightrag/kg/postgres_impl.py b/lightrag/kg/postgres_impl.py
index 8848f0fd..dd1ff8ae 100644
--- a/lightrag/kg/postgres_impl.py
+++ b/lightrag/kg/postgres_impl.py
@@ -3100,114 +3100,93 @@ class PGGraphStorage(BaseGraphStorage):
         return result
 
     async def has_node(self, node_id: str) -> bool:
-        entity_name_label = self._normalize_node_id(node_id)
+        query = f"""
+            SELECT EXISTS (
+                SELECT 1
+                FROM {self.graph_name}.base
+                WHERE ag_catalog.agtype_access_operator(
+                    VARIADIC ARRAY[properties, '"entity_id"'::agtype]
+                ) = (to_json($1::text)::text)::agtype
+                LIMIT 1
+            ) AS node_exists;
+        """
 
-        query = """SELECT * FROM cypher('%s', $$
-                     MATCH (n:base {entity_id: "%s"})
-                     RETURN count(n) > 0 AS node_exists
-                   $$) AS (node_exists bool)""" % (self.graph_name, entity_name_label)
-
-        single_result = (await self._query(query))[0]
-
-        return single_result["node_exists"]
+        params = {"node_id": node_id}
+        row = (await self._query(query, params=params))[0]
+        return bool(row["node_exists"])
 
     async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
-        src_label = self._normalize_node_id(source_node_id)
-        tgt_label = self._normalize_node_id(target_node_id)
-
-        query = """SELECT * FROM cypher('%s', $$
-                     MATCH (a:base {entity_id: "%s"})-[r]-(b:base {entity_id: "%s"})
-                     RETURN COUNT(r) > 0 AS edge_exists
-                   $$) AS (edge_exists bool)""" % (
-            self.graph_name,
-            src_label,
-            tgt_label,
-        )
-
-        single_result = (await self._query(query))[0]
-
-        return single_result["edge_exists"]
+        query = f"""
+            WITH a AS (
+                SELECT id AS vid
+                FROM {self.graph_name}.base
+                WHERE ag_catalog.agtype_access_operator(
+                    VARIADIC ARRAY[properties, '"entity_id"'::agtype]
+                ) = (to_json($1::text)::text)::agtype
+            ),
+            b AS (
+                SELECT id AS vid
+                FROM {self.graph_name}.base
+                WHERE ag_catalog.agtype_access_operator(
+                    VARIADIC ARRAY[properties, '"entity_id"'::agtype]
+                ) = (to_json($2::text)::text)::agtype
+            )
+            SELECT EXISTS (
+                SELECT 1
+                FROM {self.graph_name}."DIRECTED" d
+                JOIN a ON d.start_id = a.vid
+                JOIN b ON d.end_id = b.vid
+                LIMIT 1
+            )
+            OR EXISTS (
+                SELECT 1
+                FROM {self.graph_name}."DIRECTED" d
+                JOIN a ON d.end_id = a.vid
+                JOIN b ON d.start_id = b.vid
+                LIMIT 1
+            ) AS edge_exists;
+        """
+        params = {
+            "source_node_id": source_node_id,
+            "target_node_id": target_node_id,
+        }
+        row = (await self._query(query, params=params))[0]
+        return bool(row["edge_exists"])
 
     async def get_node(self, node_id: str) -> dict[str, str] | None:
         """Get node by its label identifier, return only node properties"""
         label = self._normalize_node_id(node_id)
 
-        query = """SELECT * FROM cypher('%s', $$
-                     MATCH (n:base {entity_id: "%s"})
-                     RETURN n
-                   $$) AS (n agtype)""" % (self.graph_name, label)
-        record = await self._query(query)
-        if record:
-            node = record[0]
-            node_dict = node["n"]["properties"]
-
-            # Process string result, parse it to JSON dictionary
-            if isinstance(node_dict, str):
-                try:
-                    node_dict = json.loads(node_dict)
-                except json.JSONDecodeError:
-                    logger.warning(
-                        f"[{self.workspace}] Failed to parse node string: {node_dict}"
-                    )
-
-            return node_dict
+        result = await self.get_nodes_batch(node_ids=[label])
+        if result and node_id in result:
+            return result[node_id]
         return None
 
     async def node_degree(self, node_id: str) -> int:
         label = self._normalize_node_id(node_id)
 
-        query = """SELECT * FROM cypher('%s', $$
-                     MATCH (n:base {entity_id: "%s"})-[r]-()
-                     RETURN count(r) AS total_edge_count
-                   $$) AS (total_edge_count integer)""" % (self.graph_name, label)
-        record = (await self._query(query))[0]
-        if record:
-            edge_count = int(record["total_edge_count"])
-            return edge_count
+        result = await self.node_degrees_batch(node_ids=[label])
+        if result and node_id in result:
+            return result[node_id]
+        # Fall back to 0 on a miss so the declared int return type holds
+        return 0
 
     async def edge_degree(self, src_id: str, tgt_id: str) -> int:
 
-        src_degree = await self.node_degree(src_id)
-        trg_degree = await self.node_degree(tgt_id)
-
-        # Convert None to 0 for addition
-        src_degree = 0 if src_degree is None else src_degree
-        trg_degree = 0 if trg_degree is None else trg_degree
-
-        degrees = int(src_degree) + int(trg_degree)
-
-        return degrees
+        result = await self.edge_degrees_batch(edges=[(src_id, tgt_id)])
+        if result and (src_id, tgt_id) in result:
+            return result[(src_id, tgt_id)]
+        # Fall back to 0 on a miss so the declared int return type holds
+        return 0
 
     async def get_edge(
         self, source_node_id: str, target_node_id: str
     ) -> dict[str, str] | None:
         """Get edge properties between two nodes"""
-
         src_label = self._normalize_node_id(source_node_id)
         tgt_label = self._normalize_node_id(target_node_id)
 
-        query = """SELECT * FROM cypher('%s', $$
-                     MATCH (a:base {entity_id: "%s"})-[r]-(b:base {entity_id: "%s"})
-                     RETURN properties(r) as edge_properties
-                     LIMIT 1
-                   $$) AS (edge_properties agtype)""" % (
-            self.graph_name,
-            src_label,
-            tgt_label,
-        )
-
-        record = await self._query(query)
-        if record and record[0] and record[0]["edge_properties"]:
-            result = record[0]["edge_properties"]
-
-            # Process string result, parse it to JSON dictionary
-            if isinstance(result, str):
-                try:
-                    result = json.loads(result)
-                except json.JSONDecodeError:
-                    logger.warning(
-                        f"[{self.workspace}] Failed to parse edge string: {result}"
-                    )
-
-            return result
+        result = await self.get_edges_batch([{"src": src_label, "tgt": tgt_label}])
+        if result and (src_label, tgt_label) in result:
+            return result[(src_label, tgt_label)]
+        return None
 
     async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None:
         """
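
Note: patch 3 does two things. The existence checks become plain SQL over the
AGE-backed tables, where SELECT EXISTS can stop at the first match instead of
counting every row through a cypher() round trip. The remaining singular
lookups become thin wrappers over the existing batch helpers, so both paths
share one query shape. The wrapper pattern, sketched generically with
hypothetical names:

    async def get_one(self, key: str) -> dict | None:
        # Delegate to the batch code path for a single key; one SQL shape,
        # one place to maintain it.
        result = await self.get_many(keys=[key])
        return result.get(key) if result else None
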
From 24cb11f3f5ace4a2d3f36cde9d62b277c80b0c50 Mon Sep 17 00:00:00 2001
From: Matt23-star
Date: Fri, 29 Aug 2025 21:09:14 -0700
Subject: [PATCH 4/4] style: ruff-format

---
 lightrag/kg/deprecated/chroma_impl.py |  4 +---
 lightrag/kg/postgres_impl.py          | 11 ++++-------
 2 files changed, 5 insertions(+), 10 deletions(-)

diff --git a/lightrag/kg/deprecated/chroma_impl.py b/lightrag/kg/deprecated/chroma_impl.py
index a6c43504..75a7d4bf 100644
--- a/lightrag/kg/deprecated/chroma_impl.py
+++ b/lightrag/kg/deprecated/chroma_impl.py
@@ -164,9 +164,7 @@ class ChromaVectorDBStorage(BaseVectorStorage):
             logger.error(f"Error during ChromaDB upsert: {str(e)}")
             raise
 
-    async def query(
-        self, query: str, top_k: int
-    ) -> list[dict[str, Any]]:
+    async def query(self, query: str, top_k: int) -> list[dict[str, Any]]:
         try:
             embedding = await self.embedding_func(
                 [query], _priority=5
diff --git a/lightrag/kg/postgres_impl.py b/lightrag/kg/postgres_impl.py
index 55cc6e06..5e4a4813 100644
--- a/lightrag/kg/postgres_impl.py
+++ b/lightrag/kg/postgres_impl.py
@@ -788,9 +788,9 @@ class PostgreSQLDB:
                 WHERE table_name = $1 AND column_name = $2
             """
             params = {
-                    "table_name": migration["table"].lower(),
-                    "column_name": migration["column"],
-                }
+                "table_name": migration["table"].lower(),
+                "column_name": migration["column"],
+            }
             column_info = await self.query(
                 check_column_sql,
                 list(params.values()),
@@ -1036,9 +1036,7 @@ class PostgreSQLDB:
                 AND table_schema = 'public'
             """
             params = {"table_name": table_name.lower()}
-            table_exists = await self.query(
-                check_table_sql, list(params.values())
-            )
+            table_exists = await self.query(check_table_sql, list(params.values()))
 
             if not table_exists:
                 logger.info(f"Creating table {table_name}")
@@ -3175,7 +3173,6 @@ class PGGraphStorage(BaseGraphStorage):
         return 0
 
     async def edge_degree(self, src_id: str, tgt_id: str) -> int:
-
         result = await self.edge_degrees_batch(edges=[(src_id, tgt_id)])
         if result and (src_id, tgt_id) in result:
             return result[(src_id, tgt_id)]