From 9be22dd666fe836932199e0621b850f917d5a327 Mon Sep 17 00:00:00 2001
From: yangdx
Date: Sat, 11 Oct 2025 12:37:59 +0800
Subject: [PATCH] Preserve ordering in get_by_ids methods across all storage
 implementations

- Fix result ordering in vector stores
- Update KV storage get_by_ids methods
- Maintain order in doc status storage
- Return None for missing IDs
---
 lightrag/kg/deprecated/chroma_impl.py | 18 ++++--
 lightrag/kg/json_doc_status_impl.py   |  8 +--
 lightrag/kg/milvus_impl.py            | 17 +++++-
 lightrag/kg/mongo_impl.py             | 40 +++++++++++---
 lightrag/kg/nano_vector_db_impl.py    | 19 +++++--
 lightrag/kg/postgres_impl.py          | 78 ++++++++++++++++++++-------
 lightrag/kg/qdrant_impl.py            | 26 +++++++--
 lightrag/kg/redis_impl.py             | 10 ++--
 8 files changed, 165 insertions(+), 51 deletions(-)

diff --git a/lightrag/kg/deprecated/chroma_impl.py b/lightrag/kg/deprecated/chroma_impl.py
index 75a7d4bf..54bf9037 100644
--- a/lightrag/kg/deprecated/chroma_impl.py
+++ b/lightrag/kg/deprecated/chroma_impl.py
@@ -295,17 +295,23 @@ class ChromaVectorDBStorage(BaseVectorStorage):
             if not result or not result["ids"] or len(result["ids"]) == 0:
                 return []
 
-            # Format the results to match the expected structure
-            return [
-                {
-                    "id": result["ids"][i],
+            # Format the results to match the expected structure and preserve ordering
+            formatted_map: dict[str, dict[str, Any]] = {}
+            for i, result_id in enumerate(result["ids"]):
+                record = {
+                    "id": result_id,
                     "vector": result["embeddings"][i],
                     "content": result["documents"][i],
                     "created_at": result["metadatas"][i].get("created_at"),
                     **result["metadatas"][i],
                 }
-                for i in range(len(result["ids"]))
-            ]
+                formatted_map[str(result_id)] = record
+
+            ordered_results: list[dict[str, Any] | None] = []
+            for requested_id in ids:
+                ordered_results.append(formatted_map.get(str(requested_id)))
+
+            return ordered_results
         except Exception as e:
             logger.error(f"Error retrieving vector data for IDs {ids}: {e}")
             return []
diff --git a/lightrag/kg/json_doc_status_impl.py b/lightrag/kg/json_doc_status_impl.py
index 329c61c6..e6d101a7 100644
--- a/lightrag/kg/json_doc_status_impl.py
+++ b/lightrag/kg/json_doc_status_impl.py
@@ -72,15 +72,17 @@ class JsonDocStatusStorage(DocStatusStorage):
         return set(keys) - set(self._data.keys())
 
     async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
-        result: list[dict[str, Any]] = []
+        ordered_results: list[dict[str, Any] | None] = []
         if self._storage_lock is None:
             raise StorageNotInitializedError("JsonDocStatusStorage")
         async with self._storage_lock:
             for id in ids:
                 data = self._data.get(id, None)
                 if data:
-                    result.append(data)
-        return result
+                    ordered_results.append(data.copy())
+                else:
+                    ordered_results.append(None)
+        return ordered_results
 
     async def get_status_counts(self) -> dict[str, int]:
         """Get counts of documents in each status"""
diff --git a/lightrag/kg/milvus_impl.py b/lightrag/kg/milvus_impl.py
index f2368afe..128fd65e 100644
--- a/lightrag/kg/milvus_impl.py
+++ b/lightrag/kg/milvus_impl.py
@@ -1252,7 +1252,22 @@ class MilvusVectorDBStorage(BaseVectorStorage):
                 output_fields=output_fields,
             )
 
-            return result or []
+            if not result:
+                return []
+
+            result_map: dict[str, dict[str, Any]] = {}
+            for row in result:
+                if not row:
+                    continue
+                row_id = row.get("id")
+                if row_id is not None:
+                    result_map[str(row_id)] = row
+
+            ordered_results: list[dict[str, Any] | None] = []
+            for requested_id in ids:
+                ordered_results.append(result_map.get(str(requested_id)))
+
+            return ordered_results
         except Exception as e:
             logger.error(
                 f"[{self.workspace}] Error retrieving vector data for IDs {ids}: {e}"
diff --git a/lightrag/kg/mongo_impl.py b/lightrag/kg/mongo_impl.py
index 0c11022e..a62c3031 100644
--- a/lightrag/kg/mongo_impl.py
+++ b/lightrag/kg/mongo_impl.py
@@ -155,12 +155,20 @@ class MongoKVStorage(BaseKVStorage):
 
     async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
         cursor = self._data.find({"_id": {"$in": ids}})
-        docs = await cursor.to_list()
-        # Ensure time fields are present for all documents
+        docs = await cursor.to_list(length=None)
+
+        doc_map: dict[str, dict[str, Any]] = {}
         for doc in docs:
+            if not doc:
+                continue
             doc.setdefault("create_time", 0)
             doc.setdefault("update_time", 0)
-        return docs
+            doc_map[str(doc.get("_id"))] = doc
+
+        ordered_results: list[dict[str, Any] | None] = []
+        for id_value in ids:
+            ordered_results.append(doc_map.get(str(id_value)))
+        return ordered_results
 
     async def filter_keys(self, keys: set[str]) -> set[str]:
         cursor = self._data.find({"_id": {"$in": list(keys)}}, {"_id": 1})
@@ -375,7 +383,18 @@ class MongoDocStatusStorage(DocStatusStorage):
 
     async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
         cursor = self._data.find({"_id": {"$in": ids}})
-        return await cursor.to_list()
+        docs = await cursor.to_list(length=None)
+
+        doc_map: dict[str, dict[str, Any]] = {}
+        for doc in docs:
+            if not doc:
+                continue
+            doc_map[str(doc.get("_id"))] = doc
+
+        ordered_results: list[dict[str, Any] | None] = []
+        for id_value in ids:
+            ordered_results.append(doc_map.get(str(id_value)))
+        return ordered_results
 
     async def filter_keys(self, data: set[str]) -> set[str]:
         cursor = self._data.find({"_id": {"$in": list(data)}}, {"_id": 1})
@@ -2403,15 +2422,20 @@ class MongoVectorDBStorage(BaseVectorStorage):
             cursor = self._data.find({"_id": {"$in": ids}})
             results = await cursor.to_list(length=None)
 
-            # Format results to include id field expected by API
-            formatted_results = []
+            # Format results to include id field expected by API and preserve ordering
+            formatted_map: dict[str, dict[str, Any]] = {}
             for result in results:
                 result_dict = dict(result)
                 if "_id" in result_dict and "id" not in result_dict:
                     result_dict["id"] = result_dict["_id"]
-                formatted_results.append(result_dict)
+                key = str(result_dict.get("id", result_dict.get("_id")))
+                formatted_map[key] = result_dict
 
-            return formatted_results
+            ordered_results: list[dict[str, Any] | None] = []
+            for id_value in ids:
+                ordered_results.append(formatted_map.get(str(id_value)))
+
+            return ordered_results
         except Exception as e:
             logger.error(
                 f"[{self.workspace}] Error retrieving vector data for IDs {ids}: {e}"
diff --git a/lightrag/kg/nano_vector_db_impl.py b/lightrag/kg/nano_vector_db_impl.py
index def5a83d..e598e34c 100644
--- a/lightrag/kg/nano_vector_db_impl.py
+++ b/lightrag/kg/nano_vector_db_impl.py
@@ -326,14 +326,25 @@ class NanoVectorDBStorage(BaseVectorStorage):
         client = await self._get_client()
         results = client.get(ids)
 
-        return [
-            {
+        result_map: dict[str, dict[str, Any]] = {}
+
+        for dp in results:
+            if not dp:
+                continue
+            record = {
                 **{k: v for k, v in dp.items() if k != "vector"},
                 "id": dp.get("__id__"),
                 "created_at": dp.get("__created_at__"),
             }
-            for dp in results
-        ]
+            key = record.get("id")
+            if key is not None:
+                result_map[str(key)] = record
+
+        ordered_results: list[dict[str, Any] | None] = []
+        for requested_id in ids:
+            ordered_results.append(result_map.get(str(requested_id)))
+
+        return ordered_results
 
     async def get_vectors_by_ids(self, ids: list[str]) -> dict[str, list[float]]:
         """Get vectors by their IDs, returning only ID and vector data for efficiency
diff --git a/lightrag/kg/postgres_impl.py b/lightrag/kg/postgres_impl.py
index 93f0cad7..ef25ae77 100644
--- a/lightrag/kg/postgres_impl.py
+++ b/lightrag/kg/postgres_impl.py
@@ -1849,6 +1849,26 @@ class PGKVStorage(BaseKVStorage):
         params = {"workspace": self.workspace}
         results = await self.db.query(sql, list(params.values()), multirows=True)
 
+        def _order_results(
+            rows: list[dict[str, Any]] | None,
+        ) -> list[dict[str, Any] | None]:
+            """Preserve the caller requested ordering for bulk id lookups."""
+            if not rows:
+                return [None for _ in ids]
+
+            id_map: dict[str, dict[str, Any]] = {}
+            for row in rows:
+                if row is None:
+                    continue
+                row_id = row.get("id")
+                if row_id is not None:
+                    id_map[str(row_id)] = row
+
+            ordered: list[dict[str, Any] | None] = []
+            for requested_id in ids:
+                ordered.append(id_map.get(str(requested_id)))
+            return ordered
+
         if results and is_namespace(self.namespace, NameSpace.KV_STORE_TEXT_CHUNKS):
             # Parse llm_cache_list JSON string back to list for each result
             for result in results:
@@ -1891,7 +1911,7 @@
                     "update_time": create_time if update_time == 0 else update_time,
                 }
                 processed_results.append(processed_row)
-            return processed_results
+            return _order_results(processed_results)
 
         # Special handling for FULL_ENTITIES namespace
         if results and is_namespace(self.namespace, NameSpace.KV_STORE_FULL_ENTITIES):
@@ -1925,7 +1945,7 @@
                 result["create_time"] = create_time
                 result["update_time"] = create_time if update_time == 0 else update_time
 
-        return results if results else []
+        return _order_results(results)
 
     async def filter_keys(self, keys: set[str]) -> set[str]:
         """Filter out duplicated content"""
@@ -2383,7 +2403,23 @@ class PGVectorStorage(BaseVectorStorage):
 
         try:
             results = await self.db.query(query, list(params.values()), multirows=True)
-            return [dict(record) for record in results]
+            if not results:
+                return []
+
+            # Preserve caller requested ordering while normalizing asyncpg rows to dicts.
+            id_map: dict[str, dict[str, Any]] = {}
+            for record in results:
+                if record is None:
+                    continue
+                record_dict = dict(record)
+                row_id = record_dict.get("id")
+                if row_id is not None:
+                    id_map[str(row_id)] = record_dict
+
+            ordered_results: list[dict[str, Any] | None] = []
+            for requested_id in ids:
+                ordered_results.append(id_map.get(str(requested_id)))
+            return ordered_results
         except Exception as e:
             logger.error(
                 f"[{self.workspace}] Error retrieving vector data for IDs {ids}: {e}"
@@ -2571,7 +2607,7 @@ class PGDocStatusStorage(DocStatusStorage):
         if not results:
             return []
 
-        processed_results = []
+        processed_map: dict[str, dict[str, Any]] = {}
         for row in results:
             # Parse chunks_list JSON string back to list
             chunks_list = row.get("chunks_list", [])
@@ -2593,23 +2629,25 @@
             created_at = self._format_datetime_with_timezone(row["created_at"])
             updated_at = self._format_datetime_with_timezone(row["updated_at"])
 
-            processed_results.append(
-                {
-                    "content_length": row["content_length"],
-                    "content_summary": row["content_summary"],
-                    "status": row["status"],
-                    "chunks_count": row["chunks_count"],
-                    "created_at": created_at,
-                    "updated_at": updated_at,
-                    "file_path": row["file_path"],
-                    "chunks_list": chunks_list,
-                    "metadata": metadata,
-                    "error_msg": row.get("error_msg"),
-                    "track_id": row.get("track_id"),
-                }
-            )
+            processed_map[str(row.get("id"))] = {
+                "content_length": row["content_length"],
+                "content_summary": row["content_summary"],
+                "status": row["status"],
+                "chunks_count": row["chunks_count"],
+                "created_at": created_at,
+                "updated_at": updated_at,
+                "file_path": row["file_path"],
+                "chunks_list": chunks_list,
+                "metadata": metadata,
+                "error_msg": row.get("error_msg"),
+                "track_id": row.get("track_id"),
+            }
 
-        return processed_results
+        ordered_results: list[dict[str, Any] | None] = []
+        for requested_id in ids:
+            ordered_results.append(processed_map.get(str(requested_id)))
+
+        return ordered_results
 
     async def get_doc_by_file_path(self, file_path: str) -> Union[dict[str, Any], None]:
         """Get document by file path
diff --git a/lightrag/kg/qdrant_impl.py b/lightrag/kg/qdrant_impl.py
index de1d07e7..0adfd279 100644
--- a/lightrag/kg/qdrant_impl.py
+++ b/lightrag/kg/qdrant_impl.py
@@ -404,15 +404,31 @@ class QdrantVectorDBStorage(BaseVectorStorage):
                 with_payload=True,
             )
 
-            # Ensure each result contains created_at field
-            payloads = []
+            # Ensure each result contains created_at field and preserve caller ordering
+            payload_by_original_id: dict[str, dict[str, Any]] = {}
+            payload_by_qdrant_id: dict[str, dict[str, Any]] = {}
+
             for point in results:
-                payload = point.payload
+                payload = dict(point.payload or {})
                 if "created_at" not in payload:
                     payload["created_at"] = None
-                payloads.append(payload)
 
-            return payloads
+                qdrant_point_id = str(point.id) if point.id is not None else ""
+                if qdrant_point_id:
+                    payload_by_qdrant_id[qdrant_point_id] = payload
+
+                original_id = payload.get("id")
+                if original_id is not None:
+                    payload_by_original_id[str(original_id)] = payload
+
+            ordered_payloads: list[dict[str, Any] | None] = []
+            for requested_id, qdrant_id in zip(ids, qdrant_ids):
+                payload = payload_by_original_id.get(str(requested_id))
+                if payload is None:
+                    payload = payload_by_qdrant_id.get(str(qdrant_id))
+                ordered_payloads.append(payload)
+
+            return ordered_payloads
         except Exception as e:
             logger.error(
                 f"[{self.workspace}] Error retrieving vector data for IDs {ids}: {e}"
diff --git a/lightrag/kg/redis_impl.py b/lightrag/kg/redis_impl.py
index 476344a0..56569dda 100644
--- a/lightrag/kg/redis_impl.py
+++ b/lightrag/kg/redis_impl.py
@@ -693,7 +693,7 @@ class RedisDocStatusStorage(DocStatusStorage):
         return set(keys) - existing_ids
 
     async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
-        result: list[dict[str, Any]] = []
+        ordered_results: list[dict[str, Any] | None] = []
         async with self._get_redis_connection() as redis:
             try:
                 pipe = redis.pipeline()
@@ -704,15 +704,17 @@
                 for result_data in results:
                     if result_data:
                         try:
-                            result.append(json.loads(result_data))
+                            ordered_results.append(json.loads(result_data))
                         except json.JSONDecodeError as e:
                             logger.error(
                                 f"[{self.workspace}] JSON decode error in get_by_ids: {e}"
                             )
-                            continue
+                            ordered_results.append(None)
+                    else:
+                        ordered_results.append(None)
            except Exception as e:
                logger.error(f"[{self.workspace}] Error in get_by_ids: {e}")
-        return result
+        return ordered_results
 
     async def get_status_counts(self) -> dict[str, int]:
         """Get counts of documents in each status"""
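
All eight implementations above apply the same ordering contract, so a condensed
sketch may help when reviewing the variants. The helper name and sample data
below are illustrative only and are not part of the patch: index the backend
rows by stringified id, then walk the caller's ids so the output aligns
position-for-position with the input, with None placeholders for ids the
backend did not return.

    from typing import Any

    def order_by_requested_ids(
        rows: list[dict[str, Any]],
        ids: list[str],
        key: str = "id",
    ) -> list[dict[str, Any] | None]:
        # Index rows by stringified id; the backend may return them in any order.
        row_map = {
            str(row[key]): row
            for row in rows
            if row and row.get(key) is not None
        }
        # Walk the requested ids so each result lines up with its input position.
        return [row_map.get(str(requested_id)) for requested_id in ids]

    # Example: rows arrive out of order and "x" is missing from the backend.
    rows = [{"id": "b"}, {"id": "a"}]
    assert order_by_requested_ids(rows, ["a", "x", "b"]) == [
        {"id": "a"},
        None,
        {"id": "b"},
    ]

The Qdrant variant is the one outlier: its stored point ids differ from the
original ids, so it keeps two maps (payload_by_original_id and
payload_by_qdrant_id) and, for each requested id, falls back to the
corresponding entry of qdrant_ids when the payload lacks the original id.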