Preserve ordering in get_by_ids methods across all storage implementations

- Fix result ordering in vector stores
- Update KV storage get_by_ids methods
- Maintain order in doc status storage
- Return None for missing IDs
This commit is contained in:
yangdx 2025-10-11 12:37:59 +08:00
parent 49326f2b14
commit 9be22dd666
8 changed files with 165 additions and 51 deletions

View file

@@ -295,17 +295,23 @@ class ChromaVectorDBStorage(BaseVectorStorage):
if not result or not result["ids"] or len(result["ids"]) == 0: if not result or not result["ids"] or len(result["ids"]) == 0:
return [] return []
# Format the results to match the expected structure # Format the results to match the expected structure and preserve ordering
return [ formatted_map: dict[str, dict[str, Any]] = {}
{ for i, result_id in enumerate(result["ids"]):
"id": result["ids"][i], record = {
"id": result_id,
"vector": result["embeddings"][i], "vector": result["embeddings"][i],
"content": result["documents"][i], "content": result["documents"][i],
"created_at": result["metadatas"][i].get("created_at"), "created_at": result["metadatas"][i].get("created_at"),
**result["metadatas"][i], **result["metadatas"][i],
} }
for i in range(len(result["ids"])) formatted_map[str(result_id)] = record
]
ordered_results: list[dict[str, Any] | None] = []
for requested_id in ids:
ordered_results.append(formatted_map.get(str(requested_id)))
return ordered_results
except Exception as e: except Exception as e:
logger.error(f"Error retrieving vector data for IDs {ids}: {e}") logger.error(f"Error retrieving vector data for IDs {ids}: {e}")
return [] return []

View file

@@ -72,15 +72,17 @@ class JsonDocStatusStorage(DocStatusStorage):
return set(keys) - set(self._data.keys()) return set(keys) - set(self._data.keys())
async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]: async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
result: list[dict[str, Any]] = [] ordered_results: list[dict[str, Any] | None] = []
if self._storage_lock is None: if self._storage_lock is None:
raise StorageNotInitializedError("JsonDocStatusStorage") raise StorageNotInitializedError("JsonDocStatusStorage")
async with self._storage_lock: async with self._storage_lock:
for id in ids: for id in ids:
data = self._data.get(id, None) data = self._data.get(id, None)
if data: if data:
result.append(data) ordered_results.append(data.copy())
return result else:
ordered_results.append(None)
return ordered_results
async def get_status_counts(self) -> dict[str, int]: async def get_status_counts(self) -> dict[str, int]:
"""Get counts of documents in each status""" """Get counts of documents in each status"""

View file

@@ -1252,7 +1252,22 @@ class MilvusVectorDBStorage(BaseVectorStorage):
output_fields=output_fields, output_fields=output_fields,
) )
return result or [] if not result:
return []
result_map: dict[str, dict[str, Any]] = {}
for row in result:
if not row:
continue
row_id = row.get("id")
if row_id is not None:
result_map[str(row_id)] = row
ordered_results: list[dict[str, Any] | None] = []
for requested_id in ids:
ordered_results.append(result_map.get(str(requested_id)))
return ordered_results
except Exception as e: except Exception as e:
logger.error( logger.error(
f"[{self.workspace}] Error retrieving vector data for IDs {ids}: {e}" f"[{self.workspace}] Error retrieving vector data for IDs {ids}: {e}"

View file

@ -155,12 +155,20 @@ class MongoKVStorage(BaseKVStorage):
async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
    """Bulk-fetch KV documents, keeping the caller's id ordering.

    Ids with no matching document produce ``None`` in the corresponding
    slot of the returned list; every returned document is guaranteed to
    carry ``create_time`` and ``update_time`` fields.
    """
    cursor = self._data.find({"_id": {"$in": ids}})
    fetched = await cursor.to_list(length=None)

    # Index fetched documents by stringified _id; Mongo returns them in
    # arbitrary order, so the final ordering is rebuilt from `ids` below.
    by_id: dict[str, dict[str, Any]] = {}
    for document in fetched:
        if not document:
            continue
        # Backfill time fields so every returned document has them.
        document.setdefault("create_time", 0)
        document.setdefault("update_time", 0)
        by_id[str(document.get("_id"))] = document

    return [by_id.get(str(requested)) for requested in ids]
async def filter_keys(self, keys: set[str]) -> set[str]: async def filter_keys(self, keys: set[str]) -> set[str]:
cursor = self._data.find({"_id": {"$in": list(keys)}}, {"_id": 1}) cursor = self._data.find({"_id": {"$in": list(keys)}}, {"_id": 1})
@ -375,7 +383,18 @@ class MongoDocStatusStorage(DocStatusStorage):
async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
    """Fetch doc-status records for *ids* in the caller's order.

    Positions whose id has no stored document are filled with ``None``.
    """
    cursor = self._data.find({"_id": {"$in": ids}})
    rows = await cursor.to_list(length=None)
    # Drop falsy rows, then key by stringified _id so the requested
    # ordering can be reconstructed regardless of Mongo's return order.
    lookup: dict[str, dict[str, Any]] = {
        str(row.get("_id")): row for row in rows if row
    }
    return [lookup.get(str(wanted)) for wanted in ids]
async def filter_keys(self, data: set[str]) -> set[str]: async def filter_keys(self, data: set[str]) -> set[str]:
cursor = self._data.find({"_id": {"$in": list(data)}}, {"_id": 1}) cursor = self._data.find({"_id": {"$in": list(data)}}, {"_id": 1})
@ -2403,15 +2422,20 @@ class MongoVectorDBStorage(BaseVectorStorage):
cursor = self._data.find({"_id": {"$in": ids}}) cursor = self._data.find({"_id": {"$in": ids}})
results = await cursor.to_list(length=None) results = await cursor.to_list(length=None)
# Format results to include id field expected by API # Format results to include id field expected by API and preserve ordering
formatted_results = [] formatted_map: dict[str, dict[str, Any]] = {}
for result in results: for result in results:
result_dict = dict(result) result_dict = dict(result)
if "_id" in result_dict and "id" not in result_dict: if "_id" in result_dict and "id" not in result_dict:
result_dict["id"] = result_dict["_id"] result_dict["id"] = result_dict["_id"]
formatted_results.append(result_dict) key = str(result_dict.get("id", result_dict.get("_id")))
formatted_map[key] = result_dict
return formatted_results ordered_results: list[dict[str, Any] | None] = []
for id_value in ids:
ordered_results.append(formatted_map.get(str(id_value)))
return ordered_results
except Exception as e: except Exception as e:
logger.error( logger.error(
f"[{self.workspace}] Error retrieving vector data for IDs {ids}: {e}" f"[{self.workspace}] Error retrieving vector data for IDs {ids}: {e}"

View file

@ -326,14 +326,25 @@ class NanoVectorDBStorage(BaseVectorStorage):
client = await self._get_client() client = await self._get_client()
results = client.get(ids) results = client.get(ids)
return [ result_map: dict[str, dict[str, Any]] = {}
{
for dp in results:
if not dp:
continue
record = {
**{k: v for k, v in dp.items() if k != "vector"}, **{k: v for k, v in dp.items() if k != "vector"},
"id": dp.get("__id__"), "id": dp.get("__id__"),
"created_at": dp.get("__created_at__"), "created_at": dp.get("__created_at__"),
} }
for dp in results key = record.get("id")
] if key is not None:
result_map[str(key)] = record
ordered_results: list[dict[str, Any] | None] = []
for requested_id in ids:
ordered_results.append(result_map.get(str(requested_id)))
return ordered_results
async def get_vectors_by_ids(self, ids: list[str]) -> dict[str, list[float]]: async def get_vectors_by_ids(self, ids: list[str]) -> dict[str, list[float]]:
"""Get vectors by their IDs, returning only ID and vector data for efficiency """Get vectors by their IDs, returning only ID and vector data for efficiency

View file

@ -1849,6 +1849,26 @@ class PGKVStorage(BaseKVStorage):
params = {"workspace": self.workspace} params = {"workspace": self.workspace}
results = await self.db.query(sql, list(params.values()), multirows=True) results = await self.db.query(sql, list(params.values()), multirows=True)
def _order_results(
rows: list[dict[str, Any]] | None,
) -> list[dict[str, Any] | None]:
"""Preserve the caller requested ordering for bulk id lookups."""
if not rows:
return [None for _ in ids]
id_map: dict[str, dict[str, Any]] = {}
for row in rows:
if row is None:
continue
row_id = row.get("id")
if row_id is not None:
id_map[str(row_id)] = row
ordered: list[dict[str, Any] | None] = []
for requested_id in ids:
ordered.append(id_map.get(str(requested_id)))
return ordered
if results and is_namespace(self.namespace, NameSpace.KV_STORE_TEXT_CHUNKS): if results and is_namespace(self.namespace, NameSpace.KV_STORE_TEXT_CHUNKS):
# Parse llm_cache_list JSON string back to list for each result # Parse llm_cache_list JSON string back to list for each result
for result in results: for result in results:
@ -1891,7 +1911,7 @@ class PGKVStorage(BaseKVStorage):
"update_time": create_time if update_time == 0 else update_time, "update_time": create_time if update_time == 0 else update_time,
} }
processed_results.append(processed_row) processed_results.append(processed_row)
return processed_results return _order_results(processed_results)
# Special handling for FULL_ENTITIES namespace # Special handling for FULL_ENTITIES namespace
if results and is_namespace(self.namespace, NameSpace.KV_STORE_FULL_ENTITIES): if results and is_namespace(self.namespace, NameSpace.KV_STORE_FULL_ENTITIES):
@ -1925,7 +1945,7 @@ class PGKVStorage(BaseKVStorage):
result["create_time"] = create_time result["create_time"] = create_time
result["update_time"] = create_time if update_time == 0 else update_time result["update_time"] = create_time if update_time == 0 else update_time
return results if results else [] return _order_results(results)
async def filter_keys(self, keys: set[str]) -> set[str]: async def filter_keys(self, keys: set[str]) -> set[str]:
"""Filter out duplicated content""" """Filter out duplicated content"""
@ -2383,7 +2403,23 @@ class PGVectorStorage(BaseVectorStorage):
try: try:
results = await self.db.query(query, list(params.values()), multirows=True) results = await self.db.query(query, list(params.values()), multirows=True)
return [dict(record) for record in results] if not results:
return []
# Preserve caller requested ordering while normalizing asyncpg rows to dicts.
id_map: dict[str, dict[str, Any]] = {}
for record in results:
if record is None:
continue
record_dict = dict(record)
row_id = record_dict.get("id")
if row_id is not None:
id_map[str(row_id)] = record_dict
ordered_results: list[dict[str, Any] | None] = []
for requested_id in ids:
ordered_results.append(id_map.get(str(requested_id)))
return ordered_results
except Exception as e: except Exception as e:
logger.error( logger.error(
f"[{self.workspace}] Error retrieving vector data for IDs {ids}: {e}" f"[{self.workspace}] Error retrieving vector data for IDs {ids}: {e}"
@ -2571,7 +2607,7 @@ class PGDocStatusStorage(DocStatusStorage):
if not results: if not results:
return [] return []
processed_results = [] processed_map: dict[str, dict[str, Any]] = {}
for row in results: for row in results:
# Parse chunks_list JSON string back to list # Parse chunks_list JSON string back to list
chunks_list = row.get("chunks_list", []) chunks_list = row.get("chunks_list", [])
@ -2593,23 +2629,25 @@ class PGDocStatusStorage(DocStatusStorage):
created_at = self._format_datetime_with_timezone(row["created_at"]) created_at = self._format_datetime_with_timezone(row["created_at"])
updated_at = self._format_datetime_with_timezone(row["updated_at"]) updated_at = self._format_datetime_with_timezone(row["updated_at"])
processed_results.append( processed_map[str(row.get("id"))] = {
{ "content_length": row["content_length"],
"content_length": row["content_length"], "content_summary": row["content_summary"],
"content_summary": row["content_summary"], "status": row["status"],
"status": row["status"], "chunks_count": row["chunks_count"],
"chunks_count": row["chunks_count"], "created_at": created_at,
"created_at": created_at, "updated_at": updated_at,
"updated_at": updated_at, "file_path": row["file_path"],
"file_path": row["file_path"], "chunks_list": chunks_list,
"chunks_list": chunks_list, "metadata": metadata,
"metadata": metadata, "error_msg": row.get("error_msg"),
"error_msg": row.get("error_msg"), "track_id": row.get("track_id"),
"track_id": row.get("track_id"), }
}
)
return processed_results ordered_results: list[dict[str, Any] | None] = []
for requested_id in ids:
ordered_results.append(processed_map.get(str(requested_id)))
return ordered_results
async def get_doc_by_file_path(self, file_path: str) -> Union[dict[str, Any], None]: async def get_doc_by_file_path(self, file_path: str) -> Union[dict[str, Any], None]:
"""Get document by file path """Get document by file path

View file

@ -404,15 +404,31 @@ class QdrantVectorDBStorage(BaseVectorStorage):
with_payload=True, with_payload=True,
) )
# Ensure each result contains created_at field # Ensure each result contains created_at field and preserve caller ordering
payloads = [] payload_by_original_id: dict[str, dict[str, Any]] = {}
payload_by_qdrant_id: dict[str, dict[str, Any]] = {}
for point in results: for point in results:
payload = point.payload payload = dict(point.payload or {})
if "created_at" not in payload: if "created_at" not in payload:
payload["created_at"] = None payload["created_at"] = None
payloads.append(payload)
return payloads qdrant_point_id = str(point.id) if point.id is not None else ""
if qdrant_point_id:
payload_by_qdrant_id[qdrant_point_id] = payload
original_id = payload.get("id")
if original_id is not None:
payload_by_original_id[str(original_id)] = payload
ordered_payloads: list[dict[str, Any] | None] = []
for requested_id, qdrant_id in zip(ids, qdrant_ids):
payload = payload_by_original_id.get(str(requested_id))
if payload is None:
payload = payload_by_qdrant_id.get(str(qdrant_id))
ordered_payloads.append(payload)
return ordered_payloads
except Exception as e: except Exception as e:
logger.error( logger.error(
f"[{self.workspace}] Error retrieving vector data for IDs {ids}: {e}" f"[{self.workspace}] Error retrieving vector data for IDs {ids}: {e}"

View file

@ -693,7 +693,7 @@ class RedisDocStatusStorage(DocStatusStorage):
return set(keys) - existing_ids return set(keys) - existing_ids
async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]: async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
result: list[dict[str, Any]] = [] ordered_results: list[dict[str, Any] | None] = []
async with self._get_redis_connection() as redis: async with self._get_redis_connection() as redis:
try: try:
pipe = redis.pipeline() pipe = redis.pipeline()
@ -704,15 +704,17 @@ class RedisDocStatusStorage(DocStatusStorage):
for result_data in results: for result_data in results:
if result_data: if result_data:
try: try:
result.append(json.loads(result_data)) ordered_results.append(json.loads(result_data))
except json.JSONDecodeError as e: except json.JSONDecodeError as e:
logger.error( logger.error(
f"[{self.workspace}] JSON decode error in get_by_ids: {e}" f"[{self.workspace}] JSON decode error in get_by_ids: {e}"
) )
continue ordered_results.append(None)
else:
ordered_results.append(None)
except Exception as e: except Exception as e:
logger.error(f"[{self.workspace}] Error in get_by_ids: {e}") logger.error(f"[{self.workspace}] Error in get_by_ids: {e}")
return result return ordered_results
async def get_status_counts(self) -> dict[str, int]: async def get_status_counts(self) -> dict[str, int]:
"""Get counts of documents in each status""" """Get counts of documents in each status"""