Preserve ordering in get_by_ids methods across all storage implementations

- Fix result ordering in vector stores
- Update KV storage get_by_ids methods
- Maintain order in doc status storage
- Return None for missing IDs

(cherry picked from commit 9be22dd666)
Author: yangdx (2025-10-11 12:37:59 +08:00), committed by Raphaël MANSUY
Parent: 17106225dd
Commit: e19a4be0af

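The diff below gives every `get_by_ids` implementation the same contract: the returned list is positionally aligned with the requested ids, and any id that was not found occupies its slot as `None`. A minimal usage sketch of that contract (the `kv_storage` instance and the ids are hypothetical):

    # Hypothetical caller: results line up one-to-one with the requested ids.
    async def show(kv_storage, ids: list[str]) -> None:
        rows = await kv_storage.get_by_ids(ids)  # one slot per requested id
        for requested_id, row in zip(ids, rows):
            print(requested_id, "->", "not found" if row is None else row["id"])
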
@@ -1819,6 +1819,26 @@ class PGKVStorage(BaseKVStorage):
         params = {"workspace": self.workspace}
         results = await self.db.query(sql, list(params.values()), multirows=True)
 
+        def _order_results(
+            rows: list[dict[str, Any]] | None,
+        ) -> list[dict[str, Any] | None]:
+            """Preserve the caller requested ordering for bulk id lookups."""
+            if not rows:
+                return [None for _ in ids]
+            id_map: dict[str, dict[str, Any]] = {}
+            for row in rows:
+                if row is None:
+                    continue
+                row_id = row.get("id")
+                if row_id is not None:
+                    id_map[str(row_id)] = row
+            ordered: list[dict[str, Any] | None] = []
+            for requested_id in ids:
+                ordered.append(id_map.get(str(requested_id)))
+            return ordered
+
         if results and is_namespace(self.namespace, NameSpace.KV_STORE_TEXT_CHUNKS):
             # Parse llm_cache_list JSON string back to list for each result
             for result in results:
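Taken in isolation the helper is easy to exercise; here is a standalone sketch of the same logic, with `ids` passed explicitly instead of captured from the enclosing method scope:

    from typing import Any

    def order_results(
        rows: list[dict[str, Any]] | None, ids: list[str]
    ) -> list[dict[str, Any] | None]:
        # Same behavior as the _order_results closure above.
        if not rows:
            return [None for _ in ids]
        id_map = {str(r["id"]): r for r in rows if r and r.get("id") is not None}
        return [id_map.get(str(i)) for i in ids]

    # The database may hand rows back in any order; the helper re-aligns them.
    rows = [{"id": "b"}, {"id": "a"}]
    assert order_results(rows, ["a", "x", "b"]) == [{"id": "a"}, None, {"id": "b"}]
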
@@ -1861,7 +1881,7 @@ class PGKVStorage(BaseKVStorage):
                     "update_time": create_time if update_time == 0 else update_time,
                 }
                 processed_results.append(processed_row)
-            return processed_results
+            return _order_results(processed_results)
 
         # Special handling for FULL_ENTITIES namespace
         if results and is_namespace(self.namespace, NameSpace.KV_STORE_FULL_ENTITIES):
@@ -1895,7 +1915,7 @@ class PGKVStorage(BaseKVStorage):
             result["create_time"] = create_time
             result["update_time"] = create_time if update_time == 0 else update_time
 
-        return results if results else []
+        return _order_results(results)
 
     async def filter_keys(self, keys: set[str]) -> set[str]:
         """Filter out duplicated content"""
@@ -2353,7 +2373,23 @@ class PGVectorStorage(BaseVectorStorage):
         try:
             results = await self.db.query(query, list(params.values()), multirows=True)
-            return [dict(record) for record in results]
+            if not results:
+                return []
+
+            # Preserve caller requested ordering while normalizing asyncpg rows to dicts.
+            id_map: dict[str, dict[str, Any]] = {}
+            for record in results:
+                if record is None:
+                    continue
+                record_dict = dict(record)
+                row_id = record_dict.get("id")
+                if row_id is not None:
+                    id_map[str(row_id)] = record_dict
+            ordered_results: list[dict[str, Any] | None] = []
+            for requested_id in ids:
+                ordered_results.append(id_map.get(str(requested_id)))
+            return ordered_results
         except Exception as e:
             logger.error(
                 f"[{self.workspace}] Error retrieving vector data for IDs {ids}: {e}"
@@ -2541,7 +2577,7 @@ class PGDocStatusStorage(DocStatusStorage):
         if not results:
             return []
 
-        processed_results = []
+        processed_map: dict[str, dict[str, Any]] = {}
         for row in results:
             # Parse chunks_list JSON string back to list
             chunks_list = row.get("chunks_list", [])
@@ -2563,23 +2599,25 @@ class PGDocStatusStorage(DocStatusStorage):
             created_at = self._format_datetime_with_timezone(row["created_at"])
             updated_at = self._format_datetime_with_timezone(row["updated_at"])
 
-            processed_results.append(
-                {
-                    "content_length": row["content_length"],
-                    "content_summary": row["content_summary"],
-                    "status": row["status"],
-                    "chunks_count": row["chunks_count"],
-                    "created_at": created_at,
-                    "updated_at": updated_at,
-                    "file_path": row["file_path"],
-                    "chunks_list": chunks_list,
-                    "metadata": metadata,
-                    "error_msg": row.get("error_msg"),
-                    "track_id": row.get("track_id"),
-                }
-            )
+            processed_map[str(row.get("id"))] = {
+                "content_length": row["content_length"],
+                "content_summary": row["content_summary"],
+                "status": row["status"],
+                "chunks_count": row["chunks_count"],
+                "created_at": created_at,
+                "updated_at": updated_at,
+                "file_path": row["file_path"],
+                "chunks_list": chunks_list,
+                "metadata": metadata,
+                "error_msg": row.get("error_msg"),
+                "track_id": row.get("track_id"),
+            }
 
-        return processed_results
+        ordered_results: list[dict[str, Any] | None] = []
+        for requested_id in ids:
+            ordered_results.append(processed_map.get(str(requested_id)))
+        return ordered_results
 
     async def get_doc_by_file_path(self, file_path: str) -> Union[dict[str, Any], None]:
         """Get document by file path