Merge pull request #1967 from danielaskdd/pinyin-sort
Add Chinese pinyin sorting support across document operations
This commit is contained in:
commit
9cc9d62c89
5 changed files with 163 additions and 57 deletions
|
|
@ -3,8 +3,7 @@ This module contains all document-related routes for the LightRAG API.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
from pyuca import Collator
|
from lightrag.utils import logger, get_pinyin_sort_key
|
||||||
from lightrag.utils import logger
|
|
||||||
import aiofiles
|
import aiofiles
|
||||||
import shutil
|
import shutil
|
||||||
import traceback
|
import traceback
|
||||||
|
|
@ -1269,9 +1268,10 @@ async def pipeline_index_files(
|
||||||
try:
|
try:
|
||||||
enqueued = False
|
enqueued = False
|
||||||
|
|
||||||
# Create Collator for Unicode sorting
|
# Use get_pinyin_sort_key for Chinese pinyin sorting
|
||||||
collator = Collator()
|
sorted_file_paths = sorted(
|
||||||
sorted_file_paths = sorted(file_paths, key=lambda p: collator.sort_key(str(p)))
|
file_paths, key=lambda p: get_pinyin_sort_key(str(p))
|
||||||
|
)
|
||||||
|
|
||||||
# Process files sequentially with track_id
|
# Process files sequentially with track_id
|
||||||
for file_path in sorted_file_paths:
|
for file_path in sorted_file_paths:
|
||||||
|
|
|
||||||
|
|
@ -11,6 +11,7 @@ from lightrag.utils import (
|
||||||
load_json,
|
load_json,
|
||||||
logger,
|
logger,
|
||||||
write_json,
|
write_json,
|
||||||
|
get_pinyin_sort_key,
|
||||||
)
|
)
|
||||||
from .shared_storage import (
|
from .shared_storage import (
|
||||||
get_namespace_data,
|
get_namespace_data,
|
||||||
|
|
@ -241,6 +242,10 @@ class JsonDocStatusStorage(DocStatusStorage):
|
||||||
# Add sort key for sorting
|
# Add sort key for sorting
|
||||||
if sort_field == "id":
|
if sort_field == "id":
|
||||||
doc_status._sort_key = doc_id
|
doc_status._sort_key = doc_id
|
||||||
|
elif sort_field == "file_path":
|
||||||
|
# Use pinyin sorting for file_path field to support Chinese characters
|
||||||
|
file_path_value = getattr(doc_status, sort_field, "")
|
||||||
|
doc_status._sort_key = get_pinyin_sort_key(file_path_value)
|
||||||
else:
|
else:
|
||||||
doc_status._sort_key = getattr(doc_status, sort_field, "")
|
doc_status._sort_key = getattr(doc_status, sort_field, "")
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -329,11 +329,8 @@ class MongoDocStatusStorage(DocStatusStorage):
|
||||||
|
|
||||||
self._data = await get_or_create_collection(self.db, self._collection_name)
|
self._data = await get_or_create_collection(self.db, self._collection_name)
|
||||||
|
|
||||||
# Create track_id index for better query performance
|
# Create and migrate all indexes including Chinese collation for file_path
|
||||||
await self.create_track_id_index_if_not_exists()
|
await self.create_and_migrate_indexes_if_not_exists()
|
||||||
|
|
||||||
# Create pagination indexes for better query performance
|
|
||||||
await self.create_pagination_indexes_if_not_exists()
|
|
||||||
|
|
||||||
logger.debug(
|
logger.debug(
|
||||||
f"[{self.workspace}] Use MongoDB as DocStatus {self._collection_name}"
|
f"[{self.workspace}] Use MongoDB as DocStatus {self._collection_name}"
|
||||||
|
|
@ -476,39 +473,19 @@ class MongoDocStatusStorage(DocStatusStorage):
|
||||||
async def delete(self, ids: list[str]) -> None:
|
async def delete(self, ids: list[str]) -> None:
|
||||||
await self._data.delete_many({"_id": {"$in": ids}})
|
await self._data.delete_many({"_id": {"$in": ids}})
|
||||||
|
|
||||||
async def create_track_id_index_if_not_exists(self):
|
async def create_and_migrate_indexes_if_not_exists(self):
|
||||||
"""Create track_id index for better query performance"""
|
"""Create indexes to optimize pagination queries and migrate file_path indexes for Chinese collation"""
|
||||||
try:
|
|
||||||
# Check if index already exists
|
|
||||||
indexes_cursor = await self._data.list_indexes()
|
|
||||||
existing_indexes = await indexes_cursor.to_list(length=None)
|
|
||||||
track_id_index_exists = any(
|
|
||||||
"track_id" in idx.get("key", {}) for idx in existing_indexes
|
|
||||||
)
|
|
||||||
|
|
||||||
if not track_id_index_exists:
|
|
||||||
await self._data.create_index("track_id")
|
|
||||||
logger.info(
|
|
||||||
f"[{self.workspace}] Created track_id index for collection {self._collection_name}"
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
logger.debug(
|
|
||||||
f"[{self.workspace}] track_id index already exists for collection {self._collection_name}"
|
|
||||||
)
|
|
||||||
|
|
||||||
except PyMongoError as e:
|
|
||||||
logger.error(
|
|
||||||
f"[{self.workspace}] Error creating track_id index for {self._collection_name}: {e}"
|
|
||||||
)
|
|
||||||
|
|
||||||
async def create_pagination_indexes_if_not_exists(self):
|
|
||||||
"""Create indexes to optimize pagination queries"""
|
|
||||||
try:
|
try:
|
||||||
indexes_cursor = await self._data.list_indexes()
|
indexes_cursor = await self._data.list_indexes()
|
||||||
existing_indexes = await indexes_cursor.to_list(length=None)
|
existing_indexes = await indexes_cursor.to_list(length=None)
|
||||||
|
existing_index_names = {idx.get("name", "") for idx in existing_indexes}
|
||||||
|
|
||||||
# Define indexes needed for pagination
|
# Define collation configuration for Chinese pinyin sorting
|
||||||
pagination_indexes = [
|
collation_config = {"locale": "zh", "numericOrdering": True}
|
||||||
|
|
||||||
|
# 1. Define all indexes needed (including original pagination indexes and new collation indexes)
|
||||||
|
all_indexes = [
|
||||||
|
# Original pagination indexes
|
||||||
{
|
{
|
||||||
"name": "status_updated_at",
|
"name": "status_updated_at",
|
||||||
"keys": [("status", 1), ("updated_at", -1)],
|
"keys": [("status", 1), ("updated_at", -1)],
|
||||||
|
|
@ -520,27 +497,93 @@ class MongoDocStatusStorage(DocStatusStorage):
|
||||||
{"name": "updated_at", "keys": [("updated_at", -1)]},
|
{"name": "updated_at", "keys": [("updated_at", -1)]},
|
||||||
{"name": "created_at", "keys": [("created_at", -1)]},
|
{"name": "created_at", "keys": [("created_at", -1)]},
|
||||||
{"name": "id", "keys": [("_id", 1)]},
|
{"name": "id", "keys": [("_id", 1)]},
|
||||||
{"name": "file_path", "keys": [("file_path", 1)]},
|
{"name": "track_id", "keys": [("track_id", 1)]},
|
||||||
|
# New file_path indexes with Chinese collation
|
||||||
|
{
|
||||||
|
"name": "file_path_zh_collation",
|
||||||
|
"keys": [("file_path", 1)],
|
||||||
|
"collation": collation_config,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "status_file_path_zh_collation",
|
||||||
|
"keys": [("status", 1), ("file_path", 1)],
|
||||||
|
"collation": collation_config,
|
||||||
|
},
|
||||||
]
|
]
|
||||||
|
|
||||||
# Check which indexes already exist
|
# 2. Handle index migration: drop conflicting indexes with different names but same key patterns
|
||||||
existing_index_names = {idx.get("name", "") for idx in existing_indexes}
|
for index_info in all_indexes:
|
||||||
|
target_keys = index_info["keys"]
|
||||||
|
target_name = index_info["name"]
|
||||||
|
target_collation = index_info.get("collation")
|
||||||
|
|
||||||
for index_info in pagination_indexes:
|
# Find existing indexes with the same key pattern but different names or collation
|
||||||
|
conflicting_indexes = []
|
||||||
|
for idx in existing_indexes:
|
||||||
|
idx_name = idx.get("name", "")
|
||||||
|
idx_keys = list(idx.get("key", {}).items())
|
||||||
|
idx_collation = idx.get("collation")
|
||||||
|
|
||||||
|
# Skip the _id_ index (MongoDB default)
|
||||||
|
if idx_name == "_id_":
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Check if keys match but name or collation differs
|
||||||
|
if idx_keys == target_keys:
|
||||||
|
if (
|
||||||
|
idx_name != target_name
|
||||||
|
or (target_collation and not idx_collation)
|
||||||
|
or (not target_collation and idx_collation)
|
||||||
|
or (
|
||||||
|
target_collation
|
||||||
|
and idx_collation
|
||||||
|
and target_collation != idx_collation
|
||||||
|
)
|
||||||
|
):
|
||||||
|
conflicting_indexes.append(idx_name)
|
||||||
|
|
||||||
|
# Drop conflicting indexes
|
||||||
|
for conflicting_name in conflicting_indexes:
|
||||||
|
try:
|
||||||
|
await self._data.drop_index(conflicting_name)
|
||||||
|
logger.info(
|
||||||
|
f"[{self.workspace}] Migrated: dropped conflicting index '{conflicting_name}' for collection {self._collection_name}"
|
||||||
|
)
|
||||||
|
# Remove from existing_index_names to allow recreation
|
||||||
|
existing_index_names.discard(conflicting_name)
|
||||||
|
except PyMongoError as drop_error:
|
||||||
|
logger.warning(
|
||||||
|
f"[{self.workspace}] Failed to drop conflicting index '{conflicting_name}': {drop_error}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 3. Create all needed indexes
|
||||||
|
for index_info in all_indexes:
|
||||||
index_name = index_info["name"]
|
index_name = index_info["name"]
|
||||||
if index_name not in existing_index_names:
|
if index_name not in existing_index_names:
|
||||||
await self._data.create_index(index_info["keys"], name=index_name)
|
create_kwargs = {"name": index_name}
|
||||||
logger.info(
|
if "collation" in index_info:
|
||||||
f"[{self.workspace}] Created pagination index '{index_name}' for collection {self._collection_name}"
|
create_kwargs["collation"] = index_info["collation"]
|
||||||
)
|
|
||||||
|
try:
|
||||||
|
await self._data.create_index(
|
||||||
|
index_info["keys"], **create_kwargs
|
||||||
|
)
|
||||||
|
logger.info(
|
||||||
|
f"[{self.workspace}] Created index '{index_name}' for collection {self._collection_name}"
|
||||||
|
)
|
||||||
|
except PyMongoError as create_error:
|
||||||
|
# If creation still fails, log the error but continue with other indexes
|
||||||
|
logger.error(
|
||||||
|
f"[{self.workspace}] Failed to create index '{index_name}' for collection {self._collection_name}: {create_error}"
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
logger.debug(
|
logger.debug(
|
||||||
f"[{self.workspace}] Pagination index '{index_name}' already exists for collection {self._collection_name}"
|
f"[{self.workspace}] Index '{index_name}' already exists for collection {self._collection_name}"
|
||||||
)
|
)
|
||||||
|
|
||||||
except PyMongoError as e:
|
except PyMongoError as e:
|
||||||
logger.error(
|
logger.error(
|
||||||
f"[{self.workspace}] Error creating pagination indexes for {self._collection_name}: {e}"
|
f"[{self.workspace}] Error creating/migrating indexes for {self._collection_name}: {e}"
|
||||||
)
|
)
|
||||||
|
|
||||||
async def get_docs_paginated(
|
async def get_docs_paginated(
|
||||||
|
|
@ -592,13 +635,24 @@ class MongoDocStatusStorage(DocStatusStorage):
|
||||||
sort_direction_value = 1 if sort_direction.lower() == "asc" else -1
|
sort_direction_value = 1 if sort_direction.lower() == "asc" else -1
|
||||||
sort_criteria = [(sort_field, sort_direction_value)]
|
sort_criteria = [(sort_field, sort_direction_value)]
|
||||||
|
|
||||||
# Query for paginated data
|
# Query for paginated data with Chinese collation for file_path sorting
|
||||||
cursor = (
|
if sort_field == "file_path":
|
||||||
self._data.find(query_filter)
|
# Use Chinese collation for pinyin sorting
|
||||||
.sort(sort_criteria)
|
cursor = (
|
||||||
.skip(skip)
|
self._data.find(query_filter)
|
||||||
.limit(page_size)
|
.sort(sort_criteria)
|
||||||
)
|
.collation({"locale": "zh", "numericOrdering": True})
|
||||||
|
.skip(skip)
|
||||||
|
.limit(page_size)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# Use default sorting for other fields
|
||||||
|
cursor = (
|
||||||
|
self._data.find(query_filter)
|
||||||
|
.sort(sort_criteria)
|
||||||
|
.skip(skip)
|
||||||
|
.limit(page_size)
|
||||||
|
)
|
||||||
result = await cursor.to_list(length=page_size)
|
result = await cursor.to_list(length=page_size)
|
||||||
|
|
||||||
# Convert to (doc_id, DocProcessingStatus) tuples
|
# Convert to (doc_id, DocProcessingStatus) tuples
|
||||||
|
|
|
||||||
|
|
@ -12,7 +12,7 @@ if not pm.is_installed("redis"):
|
||||||
# aioredis is a deprecated library, replaced with redis
|
# aioredis is a deprecated library, replaced with redis
|
||||||
from redis.asyncio import Redis, ConnectionPool # type: ignore
|
from redis.asyncio import Redis, ConnectionPool # type: ignore
|
||||||
from redis.exceptions import RedisError, ConnectionError, TimeoutError # type: ignore
|
from redis.exceptions import RedisError, ConnectionError, TimeoutError # type: ignore
|
||||||
from lightrag.utils import logger
|
from lightrag.utils import logger, get_pinyin_sort_key
|
||||||
|
|
||||||
from lightrag.base import (
|
from lightrag.base import (
|
||||||
BaseKVStorage,
|
BaseKVStorage,
|
||||||
|
|
@ -998,6 +998,10 @@ class RedisDocStatusStorage(DocStatusStorage):
|
||||||
# Calculate sort key for sorting (but don't add to data)
|
# Calculate sort key for sorting (but don't add to data)
|
||||||
if sort_field == "id":
|
if sort_field == "id":
|
||||||
sort_key = doc_id
|
sort_key = doc_id
|
||||||
|
elif sort_field == "file_path":
|
||||||
|
# Use pinyin sorting for file_path field to support Chinese characters
|
||||||
|
file_path_value = data.get(sort_field, "")
|
||||||
|
sort_key = get_pinyin_sort_key(file_path_value)
|
||||||
else:
|
else:
|
||||||
sort_key = data.get(sort_field, "")
|
sort_key = data.get(sort_field, "")
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -17,6 +17,20 @@ from hashlib import md5
|
||||||
from typing import Any, Protocol, Callable, TYPE_CHECKING, List
|
from typing import Any, Protocol, Callable, TYPE_CHECKING, List
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
|
|
||||||
|
# Import pyuca for Chinese pinyin sorting
|
||||||
|
try:
|
||||||
|
import pyuca
|
||||||
|
|
||||||
|
_pinyin_collator = pyuca.Collator()
|
||||||
|
_pyuca_available = True
|
||||||
|
except ImportError:
|
||||||
|
_pinyin_collator = None
|
||||||
|
_pyuca_available = False
|
||||||
|
except Exception:
|
||||||
|
_pinyin_collator = None
|
||||||
|
_pyuca_available = False
|
||||||
|
|
||||||
from lightrag.constants import (
|
from lightrag.constants import (
|
||||||
DEFAULT_LOG_MAX_BYTES,
|
DEFAULT_LOG_MAX_BYTES,
|
||||||
DEFAULT_LOG_BACKUP_COUNT,
|
DEFAULT_LOG_BACKUP_COUNT,
|
||||||
|
|
@ -2059,3 +2073,32 @@ def generate_track_id(prefix: str = "upload") -> str:
|
||||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||||
unique_id = str(uuid.uuid4())[:8] # Use first 8 characters of UUID
|
unique_id = str(uuid.uuid4())[:8] # Use first 8 characters of UUID
|
||||||
return f"{prefix}_{timestamp}_{unique_id}"
|
return f"{prefix}_{timestamp}_{unique_id}"
|
||||||
|
|
||||||
|
|
||||||
|
def get_pinyin_sort_key(text: str) -> tuple:
    """Generate a sort key for text that may contain Chinese characters.

    The key is produced by the module-level pyuca collator (Unicode
    Collation Algorithm) when it was initialized successfully. When pyuca
    is unavailable, or the collator raises for this particular text, the
    key falls back to the code points of the lower-cased text.

    Every return path yields a tuple of ints. Keys are therefore always
    comparable with each other, so sorting a collection that mixes empty
    values, collated values and fallback values cannot raise ``TypeError``
    (a ``str`` key cannot be ordered against a ``tuple`` key).

    NOTE(review): the name and callers assume pinyin order; pyuca's default
    collation table is believed to order CJK ideographs by implicit
    (code-point based) weights rather than by pronunciation -- confirm
    against the pyuca documentation before relying on true pinyin order.

    Args:
        text: Text to generate a sort key for. Falsy values (``""``,
            ``None``) are treated as empty.

    Returns:
        tuple: Sort key usable with ``sorted(..., key=get_pinyin_sort_key)``.
        The empty tuple is returned for empty input, so empty values sort
        first in ascending order.
    """
    if not text:
        # An empty tuple (not "") keeps the key type uniform with the
        # collator output below.
        return ()

    # Use the globally initialized collator
    if _pyuca_available and _pinyin_collator is not None:
        try:
            return tuple(_pinyin_collator.sort_key(text))
        except Exception as e:
            logger.warning(
                f"Failed to generate pinyin sort key for '{text}': {e}. Using fallback."
            )

    # Fallback: code points of the lower-cased text. This orders exactly like
    # comparing the ``text.lower()`` strings, but stays a tuple so it can be
    # compared with collator-generated keys.
    return tuple(ord(ch) for ch in text.lower())
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue