Preserve ordering in get_by_ids methods across all storage implementations
- Fix result ordering in vector stores
- Update KV storage get_by_ids methods
- Maintain order in doc status storage
- Return None for missing IDs
(cherry picked from commit 9be22dd666)
This commit is contained in:
parent
de2713ca93
commit
770fd64c70
8 changed files with 292 additions and 514 deletions
|
|
@ -295,17 +295,23 @@ class ChromaVectorDBStorage(BaseVectorStorage):
|
||||||
if not result or not result["ids"] or len(result["ids"]) == 0:
|
if not result or not result["ids"] or len(result["ids"]) == 0:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
# Format the results to match the expected structure
|
# Format the results to match the expected structure and preserve ordering
|
||||||
return [
|
formatted_map: dict[str, dict[str, Any]] = {}
|
||||||
{
|
for i, result_id in enumerate(result["ids"]):
|
||||||
"id": result["ids"][i],
|
record = {
|
||||||
|
"id": result_id,
|
||||||
"vector": result["embeddings"][i],
|
"vector": result["embeddings"][i],
|
||||||
"content": result["documents"][i],
|
"content": result["documents"][i],
|
||||||
"created_at": result["metadatas"][i].get("created_at"),
|
"created_at": result["metadatas"][i].get("created_at"),
|
||||||
**result["metadatas"][i],
|
**result["metadatas"][i],
|
||||||
}
|
}
|
||||||
for i in range(len(result["ids"]))
|
formatted_map[str(result_id)] = record
|
||||||
]
|
|
||||||
|
ordered_results: list[dict[str, Any] | None] = []
|
||||||
|
for requested_id in ids:
|
||||||
|
ordered_results.append(formatted_map.get(str(requested_id)))
|
||||||
|
|
||||||
|
return ordered_results
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error retrieving vector data for IDs {ids}: {e}")
|
logger.error(f"Error retrieving vector data for IDs {ids}: {e}")
|
||||||
return []
|
return []
|
||||||
|
|
|
||||||
|
|
@ -78,15 +78,17 @@ class JsonDocStatusStorage(DocStatusStorage):
|
||||||
return set(keys) - set(self._data.keys())
|
return set(keys) - set(self._data.keys())
|
||||||
|
|
||||||
async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
|
async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
|
||||||
result: list[dict[str, Any]] = []
|
ordered_results: list[dict[str, Any] | None] = []
|
||||||
if self._storage_lock is None:
|
if self._storage_lock is None:
|
||||||
raise StorageNotInitializedError("JsonDocStatusStorage")
|
raise StorageNotInitializedError("JsonDocStatusStorage")
|
||||||
async with self._storage_lock:
|
async with self._storage_lock:
|
||||||
for id in ids:
|
for id in ids:
|
||||||
data = self._data.get(id, None)
|
data = self._data.get(id, None)
|
||||||
if data:
|
if data:
|
||||||
result.append(data)
|
ordered_results.append(data.copy())
|
||||||
return result
|
else:
|
||||||
|
ordered_results.append(None)
|
||||||
|
return ordered_results
|
||||||
|
|
||||||
async def get_status_counts(self) -> dict[str, int]:
|
async def get_status_counts(self) -> dict[str, int]:
|
||||||
"""Get counts of documents in each status"""
|
"""Get counts of documents in each status"""
|
||||||
|
|
|
||||||
|
|
@ -1258,7 +1258,22 @@ class MilvusVectorDBStorage(BaseVectorStorage):
|
||||||
output_fields=output_fields,
|
output_fields=output_fields,
|
||||||
)
|
)
|
||||||
|
|
||||||
return result or []
|
if not result:
|
||||||
|
return []
|
||||||
|
|
||||||
|
result_map: dict[str, dict[str, Any]] = {}
|
||||||
|
for row in result:
|
||||||
|
if not row:
|
||||||
|
continue
|
||||||
|
row_id = row.get("id")
|
||||||
|
if row_id is not None:
|
||||||
|
result_map[str(row_id)] = row
|
||||||
|
|
||||||
|
ordered_results: list[dict[str, Any] | None] = []
|
||||||
|
for requested_id in ids:
|
||||||
|
ordered_results.append(result_map.get(str(requested_id)))
|
||||||
|
|
||||||
|
return ordered_results
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(
|
logger.error(
|
||||||
f"[{self.workspace}] Error retrieving vector data for IDs {ids}: {e}"
|
f"[{self.workspace}] Error retrieving vector data for IDs {ids}: {e}"
|
||||||
|
|
|
||||||
File diff suppressed because it is too large
Load diff
|
|
@ -332,14 +332,25 @@ class NanoVectorDBStorage(BaseVectorStorage):
|
||||||
|
|
||||||
client = await self._get_client()
|
client = await self._get_client()
|
||||||
results = client.get(ids)
|
results = client.get(ids)
|
||||||
return [
|
result_map: dict[str, dict[str, Any]] = {}
|
||||||
{
|
|
||||||
|
for dp in results:
|
||||||
|
if not dp:
|
||||||
|
continue
|
||||||
|
record = {
|
||||||
**{k: v for k, v in dp.items() if k != "vector"},
|
**{k: v for k, v in dp.items() if k != "vector"},
|
||||||
"id": dp.get("__id__"),
|
"id": dp.get("__id__"),
|
||||||
"created_at": dp.get("__created_at__"),
|
"created_at": dp.get("__created_at__"),
|
||||||
}
|
}
|
||||||
for dp in results
|
key = record.get("id")
|
||||||
]
|
if key is not None:
|
||||||
|
result_map[str(key)] = record
|
||||||
|
|
||||||
|
ordered_results: list[dict[str, Any] | None] = []
|
||||||
|
for requested_id in ids:
|
||||||
|
ordered_results.append(result_map.get(str(requested_id)))
|
||||||
|
|
||||||
|
return ordered_results
|
||||||
|
|
||||||
async def get_vectors_by_ids(self, ids: list[str]) -> dict[str, list[float]]:
|
async def get_vectors_by_ids(self, ids: list[str]) -> dict[str, list[float]]:
|
||||||
"""Get vectors by their IDs, returning only ID and vector data for efficiency
|
"""Get vectors by their IDs, returning only ID and vector data for efficiency
|
||||||
|
|
|
||||||
|
|
@ -1819,6 +1819,26 @@ class PGKVStorage(BaseKVStorage):
|
||||||
params = {"workspace": self.workspace}
|
params = {"workspace": self.workspace}
|
||||||
results = await self.db.query(sql, list(params.values()), multirows=True)
|
results = await self.db.query(sql, list(params.values()), multirows=True)
|
||||||
|
|
||||||
|
def _order_results(
|
||||||
|
rows: list[dict[str, Any]] | None,
|
||||||
|
) -> list[dict[str, Any] | None]:
|
||||||
|
"""Preserve the caller requested ordering for bulk id lookups."""
|
||||||
|
if not rows:
|
||||||
|
return [None for _ in ids]
|
||||||
|
|
||||||
|
id_map: dict[str, dict[str, Any]] = {}
|
||||||
|
for row in rows:
|
||||||
|
if row is None:
|
||||||
|
continue
|
||||||
|
row_id = row.get("id")
|
||||||
|
if row_id is not None:
|
||||||
|
id_map[str(row_id)] = row
|
||||||
|
|
||||||
|
ordered: list[dict[str, Any] | None] = []
|
||||||
|
for requested_id in ids:
|
||||||
|
ordered.append(id_map.get(str(requested_id)))
|
||||||
|
return ordered
|
||||||
|
|
||||||
if results and is_namespace(self.namespace, NameSpace.KV_STORE_TEXT_CHUNKS):
|
if results and is_namespace(self.namespace, NameSpace.KV_STORE_TEXT_CHUNKS):
|
||||||
# Parse llm_cache_list JSON string back to list for each result
|
# Parse llm_cache_list JSON string back to list for each result
|
||||||
for result in results:
|
for result in results:
|
||||||
|
|
@ -1861,7 +1881,7 @@ class PGKVStorage(BaseKVStorage):
|
||||||
"update_time": create_time if update_time == 0 else update_time,
|
"update_time": create_time if update_time == 0 else update_time,
|
||||||
}
|
}
|
||||||
processed_results.append(processed_row)
|
processed_results.append(processed_row)
|
||||||
return processed_results
|
return _order_results(processed_results)
|
||||||
|
|
||||||
# Special handling for FULL_ENTITIES namespace
|
# Special handling for FULL_ENTITIES namespace
|
||||||
if results and is_namespace(self.namespace, NameSpace.KV_STORE_FULL_ENTITIES):
|
if results and is_namespace(self.namespace, NameSpace.KV_STORE_FULL_ENTITIES):
|
||||||
|
|
@ -1895,7 +1915,7 @@ class PGKVStorage(BaseKVStorage):
|
||||||
result["create_time"] = create_time
|
result["create_time"] = create_time
|
||||||
result["update_time"] = create_time if update_time == 0 else update_time
|
result["update_time"] = create_time if update_time == 0 else update_time
|
||||||
|
|
||||||
return results if results else []
|
return _order_results(results)
|
||||||
|
|
||||||
async def filter_keys(self, keys: set[str]) -> set[str]:
|
async def filter_keys(self, keys: set[str]) -> set[str]:
|
||||||
"""Filter out duplicated content"""
|
"""Filter out duplicated content"""
|
||||||
|
|
@ -2353,7 +2373,23 @@ class PGVectorStorage(BaseVectorStorage):
|
||||||
|
|
||||||
try:
|
try:
|
||||||
results = await self.db.query(query, list(params.values()), multirows=True)
|
results = await self.db.query(query, list(params.values()), multirows=True)
|
||||||
return [dict(record) for record in results]
|
if not results:
|
||||||
|
return []
|
||||||
|
|
||||||
|
# Preserve caller requested ordering while normalizing asyncpg rows to dicts.
|
||||||
|
id_map: dict[str, dict[str, Any]] = {}
|
||||||
|
for record in results:
|
||||||
|
if record is None:
|
||||||
|
continue
|
||||||
|
record_dict = dict(record)
|
||||||
|
row_id = record_dict.get("id")
|
||||||
|
if row_id is not None:
|
||||||
|
id_map[str(row_id)] = record_dict
|
||||||
|
|
||||||
|
ordered_results: list[dict[str, Any] | None] = []
|
||||||
|
for requested_id in ids:
|
||||||
|
ordered_results.append(id_map.get(str(requested_id)))
|
||||||
|
return ordered_results
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(
|
logger.error(
|
||||||
f"[{self.workspace}] Error retrieving vector data for IDs {ids}: {e}"
|
f"[{self.workspace}] Error retrieving vector data for IDs {ids}: {e}"
|
||||||
|
|
@ -2541,7 +2577,7 @@ class PGDocStatusStorage(DocStatusStorage):
|
||||||
if not results:
|
if not results:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
processed_results = []
|
processed_map: dict[str, dict[str, Any]] = {}
|
||||||
for row in results:
|
for row in results:
|
||||||
# Parse chunks_list JSON string back to list
|
# Parse chunks_list JSON string back to list
|
||||||
chunks_list = row.get("chunks_list", [])
|
chunks_list = row.get("chunks_list", [])
|
||||||
|
|
@ -2563,23 +2599,25 @@ class PGDocStatusStorage(DocStatusStorage):
|
||||||
created_at = self._format_datetime_with_timezone(row["created_at"])
|
created_at = self._format_datetime_with_timezone(row["created_at"])
|
||||||
updated_at = self._format_datetime_with_timezone(row["updated_at"])
|
updated_at = self._format_datetime_with_timezone(row["updated_at"])
|
||||||
|
|
||||||
processed_results.append(
|
processed_map[str(row.get("id"))] = {
|
||||||
{
|
"content_length": row["content_length"],
|
||||||
"content_length": row["content_length"],
|
"content_summary": row["content_summary"],
|
||||||
"content_summary": row["content_summary"],
|
"status": row["status"],
|
||||||
"status": row["status"],
|
"chunks_count": row["chunks_count"],
|
||||||
"chunks_count": row["chunks_count"],
|
"created_at": created_at,
|
||||||
"created_at": created_at,
|
"updated_at": updated_at,
|
||||||
"updated_at": updated_at,
|
"file_path": row["file_path"],
|
||||||
"file_path": row["file_path"],
|
"chunks_list": chunks_list,
|
||||||
"chunks_list": chunks_list,
|
"metadata": metadata,
|
||||||
"metadata": metadata,
|
"error_msg": row.get("error_msg"),
|
||||||
"error_msg": row.get("error_msg"),
|
"track_id": row.get("track_id"),
|
||||||
"track_id": row.get("track_id"),
|
}
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
return processed_results
|
ordered_results: list[dict[str, Any] | None] = []
|
||||||
|
for requested_id in ids:
|
||||||
|
ordered_results.append(processed_map.get(str(requested_id)))
|
||||||
|
|
||||||
|
return ordered_results
|
||||||
|
|
||||||
async def get_doc_by_file_path(self, file_path: str) -> Union[dict[str, Any], None]:
|
async def get_doc_by_file_path(self, file_path: str) -> Union[dict[str, Any], None]:
|
||||||
"""Get document by file path
|
"""Get document by file path
|
||||||
|
|
|
||||||
|
|
@ -409,15 +409,31 @@ class QdrantVectorDBStorage(BaseVectorStorage):
|
||||||
with_payload=True,
|
with_payload=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Ensure each result contains created_at field
|
# Ensure each result contains created_at field and preserve caller ordering
|
||||||
payloads = []
|
payload_by_original_id: dict[str, dict[str, Any]] = {}
|
||||||
|
payload_by_qdrant_id: dict[str, dict[str, Any]] = {}
|
||||||
|
|
||||||
for point in results:
|
for point in results:
|
||||||
payload = point.payload
|
payload = dict(point.payload or {})
|
||||||
if "created_at" not in payload:
|
if "created_at" not in payload:
|
||||||
payload["created_at"] = None
|
payload["created_at"] = None
|
||||||
payloads.append(payload)
|
|
||||||
|
|
||||||
return payloads
|
qdrant_point_id = str(point.id) if point.id is not None else ""
|
||||||
|
if qdrant_point_id:
|
||||||
|
payload_by_qdrant_id[qdrant_point_id] = payload
|
||||||
|
|
||||||
|
original_id = payload.get("id")
|
||||||
|
if original_id is not None:
|
||||||
|
payload_by_original_id[str(original_id)] = payload
|
||||||
|
|
||||||
|
ordered_payloads: list[dict[str, Any] | None] = []
|
||||||
|
for requested_id, qdrant_id in zip(ids, qdrant_ids):
|
||||||
|
payload = payload_by_original_id.get(str(requested_id))
|
||||||
|
if payload is None:
|
||||||
|
payload = payload_by_qdrant_id.get(str(qdrant_id))
|
||||||
|
ordered_payloads.append(payload)
|
||||||
|
|
||||||
|
return ordered_payloads
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(
|
logger.error(
|
||||||
f"[{self.workspace}] Error retrieving vector data for IDs {ids}: {e}"
|
f"[{self.workspace}] Error retrieving vector data for IDs {ids}: {e}"
|
||||||
|
|
|
||||||
|
|
@ -736,7 +736,7 @@ class RedisDocStatusStorage(DocStatusStorage):
|
||||||
return set(keys) - existing_ids
|
return set(keys) - existing_ids
|
||||||
|
|
||||||
async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
|
async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
|
||||||
result: list[dict[str, Any]] = []
|
ordered_results: list[dict[str, Any] | None] = []
|
||||||
async with self._get_redis_connection() as redis:
|
async with self._get_redis_connection() as redis:
|
||||||
try:
|
try:
|
||||||
pipe = redis.pipeline()
|
pipe = redis.pipeline()
|
||||||
|
|
@ -747,15 +747,17 @@ class RedisDocStatusStorage(DocStatusStorage):
|
||||||
for result_data in results:
|
for result_data in results:
|
||||||
if result_data:
|
if result_data:
|
||||||
try:
|
try:
|
||||||
result.append(json.loads(result_data))
|
ordered_results.append(json.loads(result_data))
|
||||||
except json.JSONDecodeError as e:
|
except json.JSONDecodeError as e:
|
||||||
logger.error(
|
logger.error(
|
||||||
f"[{self.workspace}] JSON decode error in get_by_ids: {e}"
|
f"[{self.workspace}] JSON decode error in get_by_ids: {e}"
|
||||||
)
|
)
|
||||||
continue
|
ordered_results.append(None)
|
||||||
|
else:
|
||||||
|
ordered_results.append(None)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"[{self.workspace}] Error in get_by_ids: {e}")
|
logger.error(f"[{self.workspace}] Error in get_by_ids: {e}")
|
||||||
return result
|
return ordered_results
|
||||||
|
|
||||||
async def get_status_counts(self) -> dict[str, int]:
|
async def get_status_counts(self) -> dict[str, int]:
|
||||||
"""Get counts of documents in each status"""
|
"""Get counts of documents in each status"""
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue