Merge pull request #1959 from danielaskdd/pick-trunk-by-vector
Feat: add KG related chunks selection by vector similarity
This commit is contained in:
commit
bdd1169cfb
14 changed files with 742 additions and 161 deletions
1
.gitignore
vendored
1
.gitignore
vendored
|
|
@ -72,3 +72,4 @@ test_*
|
||||||
# Cline files
|
# Cline files
|
||||||
memory-bank
|
memory-bank
|
||||||
memory-bank/
|
memory-bank/
|
||||||
|
.clinerules
|
||||||
|
|
|
||||||
16
env.example
16
env.example
|
|
@ -71,10 +71,20 @@ ENABLE_LLM_CACHE=true
|
||||||
# MAX_RELATION_TOKENS=10000
|
# MAX_RELATION_TOKENS=10000
|
||||||
### control the maximum tokens send to LLM (include entities, raltions and chunks)
|
### control the maximum tokens send to LLM (include entities, raltions and chunks)
|
||||||
# MAX_TOTAL_TOKENS=30000
|
# MAX_TOTAL_TOKENS=30000
|
||||||
### maximum number of related chunks per source entity or relation (higher values increase re-ranking time)
|
|
||||||
|
### maximum number of related chunks per source entity or relation
|
||||||
|
### The chunk picker uses this value to determine the total number of chunks selected from KG(knowledge graph)
|
||||||
|
### Higher values increase re-ranking time
|
||||||
# RELATED_CHUNK_NUMBER=5
|
# RELATED_CHUNK_NUMBER=5
|
||||||
|
|
||||||
### Reranker configuration (Set ENABLE_RERANK to true in reranking model is configed)
|
### chunk selection strategies
|
||||||
|
### VECTOR: Pick KG chunks by vector similarity, delivered chunks to the LLM aligning more closely with naive retrieval
|
||||||
|
### WEIGHT: Pick KG chunks by entity and chunk weight, delivered more solely KG related chunks to the LLM
|
||||||
|
### If reranking is enabled, the impact of chunk selection strategies will be diminished.
|
||||||
|
# KG_CHUNK_PICK_METHOD=VECTOR
|
||||||
|
|
||||||
|
### Reranking configuration
|
||||||
|
### Reranker Set ENABLE_RERANK to true in reranking model is configed
|
||||||
# ENABLE_RERANK=True
|
# ENABLE_RERANK=True
|
||||||
### Minimum rerank score for document chunk exclusion (set to 0.0 to keep all chunks, 0.6 or above if LLM is not strong enought)
|
### Minimum rerank score for document chunk exclusion (set to 0.0 to keep all chunks, 0.6 or above if LLM is not strong enought)
|
||||||
# MIN_RERANK_SCORE=0.0
|
# MIN_RERANK_SCORE=0.0
|
||||||
|
|
@ -258,7 +268,7 @@ POSTGRES_IVFFLAT_LISTS=100
|
||||||
NEO4J_URI=neo4j+s://xxxxxxxx.databases.neo4j.io
|
NEO4J_URI=neo4j+s://xxxxxxxx.databases.neo4j.io
|
||||||
NEO4J_USERNAME=neo4j
|
NEO4J_USERNAME=neo4j
|
||||||
NEO4J_PASSWORD='your_password'
|
NEO4J_PASSWORD='your_password'
|
||||||
# NEO4J_DATABASE=chunk_entity_relation
|
# NEO4J_DATABASE=chunk-entity-relation
|
||||||
NEO4J_MAX_CONNECTION_POOL_SIZE=100
|
NEO4J_MAX_CONNECTION_POOL_SIZE=100
|
||||||
NEO4J_CONNECTION_TIMEOUT=30
|
NEO4J_CONNECTION_TIMEOUT=30
|
||||||
NEO4J_CONNECTION_ACQUISITION_TIMEOUT=30
|
NEO4J_CONNECTION_ACQUISITION_TIMEOUT=30
|
||||||
|
|
|
||||||
|
|
@ -290,6 +290,19 @@ class BaseVectorStorage(StorageNameSpace, ABC):
|
||||||
ids: List of vector IDs to be deleted
|
ids: List of vector IDs to be deleted
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def get_vectors_by_ids(self, ids: list[str]) -> dict[str, list[float]]:
|
||||||
|
"""Get vectors by their IDs, returning only ID and vector data for efficiency
|
||||||
|
|
||||||
|
Args:
|
||||||
|
ids: List of unique identifiers
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary mapping IDs to their vector embeddings
|
||||||
|
Format: {id: [vector_values], ...}
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class BaseKVStorage(StorageNameSpace, ABC):
|
class BaseKVStorage(StorageNameSpace, ABC):
|
||||||
|
|
|
||||||
|
|
@ -27,6 +27,7 @@ DEFAULT_MAX_RELATION_TOKENS = 10000
|
||||||
DEFAULT_MAX_TOTAL_TOKENS = 30000
|
DEFAULT_MAX_TOTAL_TOKENS = 30000
|
||||||
DEFAULT_COSINE_THRESHOLD = 0.2
|
DEFAULT_COSINE_THRESHOLD = 0.2
|
||||||
DEFAULT_RELATED_CHUNK_NUMBER = 5
|
DEFAULT_RELATED_CHUNK_NUMBER = 5
|
||||||
|
DEFAULT_KG_CHUNK_PICK_METHOD = "VECTOR"
|
||||||
# Deprated: history message have negtive effect on query performance
|
# Deprated: history message have negtive effect on query performance
|
||||||
DEFAULT_HISTORY_TURNS = 0
|
DEFAULT_HISTORY_TURNS = 0
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -210,9 +210,11 @@ class FaissVectorDBStorage(BaseVectorStorage):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
meta = self._id_to_meta.get(idx, {})
|
meta = self._id_to_meta.get(idx, {})
|
||||||
|
# Filter out __vector__ from query results to avoid returning large vector data
|
||||||
|
filtered_meta = {k: v for k, v in meta.items() if k != "__vector__"}
|
||||||
results.append(
|
results.append(
|
||||||
{
|
{
|
||||||
**meta,
|
**filtered_meta,
|
||||||
"id": meta.get("__id__"),
|
"id": meta.get("__id__"),
|
||||||
"distance": float(dist),
|
"distance": float(dist),
|
||||||
"created_at": meta.get("__created_at__"),
|
"created_at": meta.get("__created_at__"),
|
||||||
|
|
@ -424,8 +426,10 @@ class FaissVectorDBStorage(BaseVectorStorage):
|
||||||
if not metadata:
|
if not metadata:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
# Filter out __vector__ from metadata to avoid returning large vector data
|
||||||
|
filtered_metadata = {k: v for k, v in metadata.items() if k != "__vector__"}
|
||||||
return {
|
return {
|
||||||
**metadata,
|
**filtered_metadata,
|
||||||
"id": metadata.get("__id__"),
|
"id": metadata.get("__id__"),
|
||||||
"created_at": metadata.get("__created_at__"),
|
"created_at": metadata.get("__created_at__"),
|
||||||
}
|
}
|
||||||
|
|
@ -448,9 +452,13 @@ class FaissVectorDBStorage(BaseVectorStorage):
|
||||||
if fid is not None:
|
if fid is not None:
|
||||||
metadata = self._id_to_meta.get(fid, {})
|
metadata = self._id_to_meta.get(fid, {})
|
||||||
if metadata:
|
if metadata:
|
||||||
|
# Filter out __vector__ from metadata to avoid returning large vector data
|
||||||
|
filtered_metadata = {
|
||||||
|
k: v for k, v in metadata.items() if k != "__vector__"
|
||||||
|
}
|
||||||
results.append(
|
results.append(
|
||||||
{
|
{
|
||||||
**metadata,
|
**filtered_metadata,
|
||||||
"id": metadata.get("__id__"),
|
"id": metadata.get("__id__"),
|
||||||
"created_at": metadata.get("__created_at__"),
|
"created_at": metadata.get("__created_at__"),
|
||||||
}
|
}
|
||||||
|
|
@ -458,6 +466,31 @@ class FaissVectorDBStorage(BaseVectorStorage):
|
||||||
|
|
||||||
return results
|
return results
|
||||||
|
|
||||||
|
async def get_vectors_by_ids(self, ids: list[str]) -> dict[str, list[float]]:
|
||||||
|
"""Get vectors by their IDs, returning only ID and vector data for efficiency
|
||||||
|
|
||||||
|
Args:
|
||||||
|
ids: List of unique identifiers
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary mapping IDs to their vector embeddings
|
||||||
|
Format: {id: [vector_values], ...}
|
||||||
|
"""
|
||||||
|
if not ids:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
vectors_dict = {}
|
||||||
|
for id in ids:
|
||||||
|
# Find the Faiss internal ID for the custom ID
|
||||||
|
fid = self._find_faiss_id_by_custom_id(id)
|
||||||
|
if fid is not None and fid in self._id_to_meta:
|
||||||
|
metadata = self._id_to_meta[fid]
|
||||||
|
# Get the stored vector from metadata
|
||||||
|
if "__vector__" in metadata:
|
||||||
|
vectors_dict[id] = metadata["__vector__"]
|
||||||
|
|
||||||
|
return vectors_dict
|
||||||
|
|
||||||
async def drop(self) -> dict[str, str]:
|
async def drop(self) -> dict[str, str]:
|
||||||
"""Drop all vector data from storage and clean up resources
|
"""Drop all vector data from storage and clean up resources
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1018,6 +1018,50 @@ class MilvusVectorDBStorage(BaseVectorStorage):
|
||||||
)
|
)
|
||||||
return []
|
return []
|
||||||
|
|
||||||
|
async def get_vectors_by_ids(self, ids: list[str]) -> dict[str, list[float]]:
|
||||||
|
"""Get vectors by their IDs, returning only ID and vector data for efficiency
|
||||||
|
|
||||||
|
Args:
|
||||||
|
ids: List of unique identifiers
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary mapping IDs to their vector embeddings
|
||||||
|
Format: {id: [vector_values], ...}
|
||||||
|
"""
|
||||||
|
if not ids:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Ensure collection is loaded before querying
|
||||||
|
self._ensure_collection_loaded()
|
||||||
|
|
||||||
|
# Prepare the ID filter expression
|
||||||
|
id_list = '", "'.join(ids)
|
||||||
|
filter_expr = f'id in ["{id_list}"]'
|
||||||
|
|
||||||
|
# Query Milvus with the filter, requesting only vector field
|
||||||
|
result = self._client.query(
|
||||||
|
collection_name=self.final_namespace,
|
||||||
|
filter=filter_expr,
|
||||||
|
output_fields=["vector"],
|
||||||
|
)
|
||||||
|
|
||||||
|
vectors_dict = {}
|
||||||
|
for item in result:
|
||||||
|
if item and "vector" in item and "id" in item:
|
||||||
|
# Convert numpy array to list if needed
|
||||||
|
vector_data = item["vector"]
|
||||||
|
if isinstance(vector_data, np.ndarray):
|
||||||
|
vector_data = vector_data.tolist()
|
||||||
|
vectors_dict[item["id"]] = vector_data
|
||||||
|
|
||||||
|
return vectors_dict
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(
|
||||||
|
f"[{self.workspace}] Error retrieving vectors by IDs from {self.namespace}: {e}"
|
||||||
|
)
|
||||||
|
return {}
|
||||||
|
|
||||||
async def drop(self) -> dict[str, str]:
|
async def drop(self) -> dict[str, str]:
|
||||||
"""Drop all vector data from storage and clean up resources
|
"""Drop all vector data from storage and clean up resources
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1967,6 +1967,37 @@ class MongoVectorDBStorage(BaseVectorStorage):
|
||||||
)
|
)
|
||||||
return []
|
return []
|
||||||
|
|
||||||
|
async def get_vectors_by_ids(self, ids: list[str]) -> dict[str, list[float]]:
|
||||||
|
"""Get vectors by their IDs, returning only ID and vector data for efficiency
|
||||||
|
|
||||||
|
Args:
|
||||||
|
ids: List of unique identifiers
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary mapping IDs to their vector embeddings
|
||||||
|
Format: {id: [vector_values], ...}
|
||||||
|
"""
|
||||||
|
if not ids:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Query MongoDB for the specified IDs, only retrieving the vector field
|
||||||
|
cursor = self._data.find({"_id": {"$in": ids}}, {"vector": 1})
|
||||||
|
results = await cursor.to_list(length=None)
|
||||||
|
|
||||||
|
vectors_dict = {}
|
||||||
|
for result in results:
|
||||||
|
if result and "vector" in result and "_id" in result:
|
||||||
|
# MongoDB stores vectors as arrays, so they should already be lists
|
||||||
|
vectors_dict[result["_id"]] = result["vector"]
|
||||||
|
|
||||||
|
return vectors_dict
|
||||||
|
except PyMongoError as e:
|
||||||
|
logger.error(
|
||||||
|
f"[{self.workspace}] Error retrieving vectors by IDs from {self.namespace}: {e}"
|
||||||
|
)
|
||||||
|
return {}
|
||||||
|
|
||||||
async def drop(self) -> dict[str, str]:
|
async def drop(self) -> dict[str, str]:
|
||||||
"""Drop the storage by removing all documents in the collection and recreating vector index.
|
"""Drop the storage by removing all documents in the collection and recreating vector index.
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,7 @@
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import base64
|
||||||
import os
|
import os
|
||||||
|
import zlib
|
||||||
from typing import Any, final
|
from typing import Any, final
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
@ -93,8 +95,7 @@ class NanoVectorDBStorage(BaseVectorStorage):
|
||||||
2. Only one process should updating the storage at a time before index_done_callback,
|
2. Only one process should updating the storage at a time before index_done_callback,
|
||||||
KG-storage-log should be used to avoid data corruption
|
KG-storage-log should be used to avoid data corruption
|
||||||
"""
|
"""
|
||||||
|
# logger.debug(f"[{self.workspace}] Inserting {len(data)} to {self.namespace}")
|
||||||
logger.debug(f"[{self.workspace}] Inserting {len(data)} to {self.namespace}")
|
|
||||||
if not data:
|
if not data:
|
||||||
return
|
return
|
||||||
|
|
||||||
|
|
@ -120,6 +121,11 @@ class NanoVectorDBStorage(BaseVectorStorage):
|
||||||
embeddings = np.concatenate(embeddings_list)
|
embeddings = np.concatenate(embeddings_list)
|
||||||
if len(embeddings) == len(list_data):
|
if len(embeddings) == len(list_data):
|
||||||
for i, d in enumerate(list_data):
|
for i, d in enumerate(list_data):
|
||||||
|
# Compress vector using Float16 + zlib + Base64 for storage optimization
|
||||||
|
vector_f16 = embeddings[i].astype(np.float16)
|
||||||
|
compressed_vector = zlib.compress(vector_f16.tobytes())
|
||||||
|
encoded_vector = base64.b64encode(compressed_vector).decode("utf-8")
|
||||||
|
d["vector"] = encoded_vector
|
||||||
d["__vector__"] = embeddings[i]
|
d["__vector__"] = embeddings[i]
|
||||||
client = await self._get_client()
|
client = await self._get_client()
|
||||||
results = client.upsert(datas=list_data)
|
results = client.upsert(datas=list_data)
|
||||||
|
|
@ -147,7 +153,7 @@ class NanoVectorDBStorage(BaseVectorStorage):
|
||||||
)
|
)
|
||||||
results = [
|
results = [
|
||||||
{
|
{
|
||||||
**dp,
|
**{k: v for k, v in dp.items() if k != "vector"},
|
||||||
"id": dp["__id__"],
|
"id": dp["__id__"],
|
||||||
"distance": dp["__metrics__"],
|
"distance": dp["__metrics__"],
|
||||||
"created_at": dp.get("__created_at__"),
|
"created_at": dp.get("__created_at__"),
|
||||||
|
|
@ -296,7 +302,7 @@ class NanoVectorDBStorage(BaseVectorStorage):
|
||||||
if result:
|
if result:
|
||||||
dp = result[0]
|
dp = result[0]
|
||||||
return {
|
return {
|
||||||
**dp,
|
**{k: v for k, v in dp.items() if k != "vector"},
|
||||||
"id": dp.get("__id__"),
|
"id": dp.get("__id__"),
|
||||||
"created_at": dp.get("__created_at__"),
|
"created_at": dp.get("__created_at__"),
|
||||||
}
|
}
|
||||||
|
|
@ -318,13 +324,41 @@ class NanoVectorDBStorage(BaseVectorStorage):
|
||||||
results = client.get(ids)
|
results = client.get(ids)
|
||||||
return [
|
return [
|
||||||
{
|
{
|
||||||
**dp,
|
**{k: v for k, v in dp.items() if k != "vector"},
|
||||||
"id": dp.get("__id__"),
|
"id": dp.get("__id__"),
|
||||||
"created_at": dp.get("__created_at__"),
|
"created_at": dp.get("__created_at__"),
|
||||||
}
|
}
|
||||||
for dp in results
|
for dp in results
|
||||||
]
|
]
|
||||||
|
|
||||||
|
async def get_vectors_by_ids(self, ids: list[str]) -> dict[str, list[float]]:
|
||||||
|
"""Get vectors by their IDs, returning only ID and vector data for efficiency
|
||||||
|
|
||||||
|
Args:
|
||||||
|
ids: List of unique identifiers
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary mapping IDs to their vector embeddings
|
||||||
|
Format: {id: [vector_values], ...}
|
||||||
|
"""
|
||||||
|
if not ids:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
client = await self._get_client()
|
||||||
|
results = client.get(ids)
|
||||||
|
|
||||||
|
vectors_dict = {}
|
||||||
|
for result in results:
|
||||||
|
if result and "vector" in result and "__id__" in result:
|
||||||
|
# Decompress vector data (Base64 + zlib + Float16 compressed)
|
||||||
|
decoded = base64.b64decode(result["vector"])
|
||||||
|
decompressed = zlib.decompress(decoded)
|
||||||
|
vector_f16 = np.frombuffer(decompressed, dtype=np.float16)
|
||||||
|
vector_f32 = vector_f16.astype(np.float32).tolist()
|
||||||
|
vectors_dict[result["__id__"]] = vector_f32
|
||||||
|
|
||||||
|
return vectors_dict
|
||||||
|
|
||||||
async def drop(self) -> dict[str, str]:
|
async def drop(self) -> dict[str, str]:
|
||||||
"""Drop all vector data from storage and clean up resources
|
"""Drop all vector data from storage and clean up resources
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -155,7 +155,7 @@ class Neo4JStorage(BaseGraphStorage):
|
||||||
except neo4jExceptions.ServiceUnavailable as e:
|
except neo4jExceptions.ServiceUnavailable as e:
|
||||||
logger.error(
|
logger.error(
|
||||||
f"[{self.workspace}] "
|
f"[{self.workspace}] "
|
||||||
+ f"{database} at {URI} is not available".capitalize()
|
+ f"Database {database} at {URI} is not available"
|
||||||
)
|
)
|
||||||
raise e
|
raise e
|
||||||
except neo4jExceptions.AuthError as e:
|
except neo4jExceptions.AuthError as e:
|
||||||
|
|
@ -167,7 +167,7 @@ class Neo4JStorage(BaseGraphStorage):
|
||||||
if e.code == "Neo.ClientError.Database.DatabaseNotFound":
|
if e.code == "Neo.ClientError.Database.DatabaseNotFound":
|
||||||
logger.info(
|
logger.info(
|
||||||
f"[{self.workspace}] "
|
f"[{self.workspace}] "
|
||||||
+ f"{database} at {URI} not found. Try to create specified database.".capitalize()
|
+ f"Database {database} at {URI} not found. Try to create specified database."
|
||||||
)
|
)
|
||||||
try:
|
try:
|
||||||
async with self._driver.session() as session:
|
async with self._driver.session() as session:
|
||||||
|
|
@ -177,7 +177,7 @@ class Neo4JStorage(BaseGraphStorage):
|
||||||
await result.consume() # Ensure result is consumed
|
await result.consume() # Ensure result is consumed
|
||||||
logger.info(
|
logger.info(
|
||||||
f"[{self.workspace}] "
|
f"[{self.workspace}] "
|
||||||
+ f"{database} at {URI} created".capitalize()
|
+ f"Database {database} at {URI} created"
|
||||||
)
|
)
|
||||||
connected = True
|
connected = True
|
||||||
except (
|
except (
|
||||||
|
|
|
||||||
|
|
@ -2158,6 +2158,53 @@ class PGVectorStorage(BaseVectorStorage):
|
||||||
)
|
)
|
||||||
return []
|
return []
|
||||||
|
|
||||||
|
async def get_vectors_by_ids(self, ids: list[str]) -> dict[str, list[float]]:
|
||||||
|
"""Get vectors by their IDs, returning only ID and vector data for efficiency
|
||||||
|
|
||||||
|
Args:
|
||||||
|
ids: List of unique identifiers
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary mapping IDs to their vector embeddings
|
||||||
|
Format: {id: [vector_values], ...}
|
||||||
|
"""
|
||||||
|
if not ids:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
table_name = namespace_to_table_name(self.namespace)
|
||||||
|
if not table_name:
|
||||||
|
logger.error(
|
||||||
|
f"[{self.workspace}] Unknown namespace for vector lookup: {self.namespace}"
|
||||||
|
)
|
||||||
|
return {}
|
||||||
|
|
||||||
|
ids_str = ",".join([f"'{id}'" for id in ids])
|
||||||
|
query = f"SELECT id, content_vector FROM {table_name} WHERE workspace=$1 AND id IN ({ids_str})"
|
||||||
|
params = {"workspace": self.workspace}
|
||||||
|
|
||||||
|
try:
|
||||||
|
results = await self.db.query(query, params, multirows=True)
|
||||||
|
vectors_dict = {}
|
||||||
|
|
||||||
|
for result in results:
|
||||||
|
if result and "content_vector" in result and "id" in result:
|
||||||
|
try:
|
||||||
|
# Parse JSON string to get vector as list of floats
|
||||||
|
vector_data = json.loads(result["content_vector"])
|
||||||
|
if isinstance(vector_data, list):
|
||||||
|
vectors_dict[result["id"]] = vector_data
|
||||||
|
except (json.JSONDecodeError, TypeError) as e:
|
||||||
|
logger.warning(
|
||||||
|
f"[{self.workspace}] Failed to parse vector data for ID {result['id']}: {e}"
|
||||||
|
)
|
||||||
|
|
||||||
|
return vectors_dict
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(
|
||||||
|
f"[{self.workspace}] Error retrieving vectors by IDs from {self.namespace}: {e}"
|
||||||
|
)
|
||||||
|
return {}
|
||||||
|
|
||||||
async def drop(self) -> dict[str, str]:
|
async def drop(self) -> dict[str, str]:
|
||||||
"""Drop the storage"""
|
"""Drop the storage"""
|
||||||
async with get_storage_lock():
|
async with get_storage_lock():
|
||||||
|
|
|
||||||
|
|
@ -402,6 +402,50 @@ class QdrantVectorDBStorage(BaseVectorStorage):
|
||||||
)
|
)
|
||||||
return []
|
return []
|
||||||
|
|
||||||
|
async def get_vectors_by_ids(self, ids: list[str]) -> dict[str, list[float]]:
|
||||||
|
"""Get vectors by their IDs, returning only ID and vector data for efficiency
|
||||||
|
|
||||||
|
Args:
|
||||||
|
ids: List of unique identifiers
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary mapping IDs to their vector embeddings
|
||||||
|
Format: {id: [vector_values], ...}
|
||||||
|
"""
|
||||||
|
if not ids:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Convert to Qdrant compatible IDs
|
||||||
|
qdrant_ids = [compute_mdhash_id_for_qdrant(id) for id in ids]
|
||||||
|
|
||||||
|
# Retrieve the points by IDs with vectors
|
||||||
|
results = self._client.retrieve(
|
||||||
|
collection_name=self.final_namespace,
|
||||||
|
ids=qdrant_ids,
|
||||||
|
with_vectors=True, # Important: request vectors
|
||||||
|
with_payload=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
vectors_dict = {}
|
||||||
|
for point in results:
|
||||||
|
if point and point.vector is not None and point.payload:
|
||||||
|
# Get original ID from payload
|
||||||
|
original_id = point.payload.get("id")
|
||||||
|
if original_id:
|
||||||
|
# Convert numpy array to list if needed
|
||||||
|
vector_data = point.vector
|
||||||
|
if isinstance(vector_data, np.ndarray):
|
||||||
|
vector_data = vector_data.tolist()
|
||||||
|
vectors_dict[original_id] = vector_data
|
||||||
|
|
||||||
|
return vectors_dict
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(
|
||||||
|
f"[{self.workspace}] Error retrieving vectors by IDs from {self.namespace}: {e}"
|
||||||
|
)
|
||||||
|
return {}
|
||||||
|
|
||||||
async def drop(self) -> dict[str, str]:
|
async def drop(self) -> dict[str, str]:
|
||||||
"""Drop all vector data from storage and clean up resources
|
"""Drop all vector data from storage and clean up resources
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -31,6 +31,7 @@ from lightrag.constants import (
|
||||||
DEFAULT_MAX_TOTAL_TOKENS,
|
DEFAULT_MAX_TOTAL_TOKENS,
|
||||||
DEFAULT_COSINE_THRESHOLD,
|
DEFAULT_COSINE_THRESHOLD,
|
||||||
DEFAULT_RELATED_CHUNK_NUMBER,
|
DEFAULT_RELATED_CHUNK_NUMBER,
|
||||||
|
DEFAULT_KG_CHUNK_PICK_METHOD,
|
||||||
DEFAULT_MIN_RERANK_SCORE,
|
DEFAULT_MIN_RERANK_SCORE,
|
||||||
DEFAULT_SUMMARY_MAX_TOKENS,
|
DEFAULT_SUMMARY_MAX_TOKENS,
|
||||||
DEFAULT_MAX_ASYNC,
|
DEFAULT_MAX_ASYNC,
|
||||||
|
|
@ -175,6 +176,11 @@ class LightRAG:
|
||||||
)
|
)
|
||||||
"""Number of related chunks to grab from single entity or relation."""
|
"""Number of related chunks to grab from single entity or relation."""
|
||||||
|
|
||||||
|
kg_chunk_pick_method: str = field(
|
||||||
|
default=get_env_value("KG_CHUNK_PICK_METHOD", DEFAULT_KG_CHUNK_PICK_METHOD, str)
|
||||||
|
)
|
||||||
|
"""Method for selecting text chunks: 'WEIGHT' for weight-based selection, 'VECTOR' for embedding similarity-based selection."""
|
||||||
|
|
||||||
# Entity extraction
|
# Entity extraction
|
||||||
# ---
|
# ---
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -27,7 +27,8 @@ from .utils import (
|
||||||
use_llm_func_with_cache,
|
use_llm_func_with_cache,
|
||||||
update_chunk_cache_list,
|
update_chunk_cache_list,
|
||||||
remove_think_tags,
|
remove_think_tags,
|
||||||
linear_gradient_weighted_polling,
|
pick_by_weighted_polling,
|
||||||
|
pick_by_vector_similarity,
|
||||||
process_chunks_unified,
|
process_chunks_unified,
|
||||||
build_file_path,
|
build_file_path,
|
||||||
)
|
)
|
||||||
|
|
@ -45,6 +46,7 @@ from .constants import (
|
||||||
DEFAULT_MAX_RELATION_TOKENS,
|
DEFAULT_MAX_RELATION_TOKENS,
|
||||||
DEFAULT_MAX_TOTAL_TOKENS,
|
DEFAULT_MAX_TOTAL_TOKENS,
|
||||||
DEFAULT_RELATED_CHUNK_NUMBER,
|
DEFAULT_RELATED_CHUNK_NUMBER,
|
||||||
|
DEFAULT_KG_CHUNK_PICK_METHOD,
|
||||||
)
|
)
|
||||||
from .kg.shared_storage import get_storage_keyed_lock
|
from .kg.shared_storage import get_storage_keyed_lock
|
||||||
import time
|
import time
|
||||||
|
|
@ -2105,6 +2107,9 @@ async def _build_query_context(
|
||||||
global_entities = []
|
global_entities = []
|
||||||
global_relations = []
|
global_relations = []
|
||||||
|
|
||||||
|
# Track chunk sources and metadata for final logging
|
||||||
|
chunk_tracking = {} # chunk_id -> {source, frequency, order}
|
||||||
|
|
||||||
# Handle local and global modes
|
# Handle local and global modes
|
||||||
if query_param.mode == "local":
|
if query_param.mode == "local":
|
||||||
local_entities, local_relations = await _get_node_data(
|
local_entities, local_relations = await _get_node_data(
|
||||||
|
|
@ -2143,6 +2148,15 @@ async def _build_query_context(
|
||||||
chunks_vdb,
|
chunks_vdb,
|
||||||
query_param,
|
query_param,
|
||||||
)
|
)
|
||||||
|
# Track vector chunks with source metadata
|
||||||
|
for i, chunk in enumerate(vector_chunks):
|
||||||
|
chunk_id = chunk.get("chunk_id") or chunk.get("id")
|
||||||
|
if chunk_id:
|
||||||
|
chunk_tracking[chunk_id] = {
|
||||||
|
"source": "C",
|
||||||
|
"frequency": 1, # Vector chunks always have frequency 1
|
||||||
|
"order": i + 1, # 1-based order in vector search results
|
||||||
|
}
|
||||||
|
|
||||||
# Use round-robin merge to combine local and global data fairly
|
# Use round-robin merge to combine local and global data fairly
|
||||||
final_entities = []
|
final_entities = []
|
||||||
|
|
@ -2340,20 +2354,29 @@ async def _build_query_context(
|
||||||
seen_edges.add(pair)
|
seen_edges.add(pair)
|
||||||
|
|
||||||
# Get text chunks based on final filtered data
|
# Get text chunks based on final filtered data
|
||||||
|
# To preserve the influence of entity order, entiy-based chunks should not be deduplcicated by vector_chunks
|
||||||
if final_node_datas:
|
if final_node_datas:
|
||||||
entity_chunks = await _find_most_related_text_unit_from_entities(
|
entity_chunks = await _find_related_text_unit_from_entities(
|
||||||
final_node_datas,
|
final_node_datas,
|
||||||
query_param,
|
query_param,
|
||||||
text_chunks_db,
|
text_chunks_db,
|
||||||
knowledge_graph_inst,
|
knowledge_graph_inst,
|
||||||
|
query,
|
||||||
|
chunks_vdb,
|
||||||
|
chunk_tracking=chunk_tracking,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Find deduplcicated chunks from edge
|
||||||
|
# Deduplication cause chunks solely relation-based to be prioritized and sent to the LLM when re-ranking is disabled
|
||||||
if final_edge_datas:
|
if final_edge_datas:
|
||||||
relation_chunks = await _find_related_text_unit_from_relationships(
|
relation_chunks = await _find_related_text_unit_from_relations(
|
||||||
final_edge_datas,
|
final_edge_datas,
|
||||||
query_param,
|
query_param,
|
||||||
text_chunks_db,
|
text_chunks_db,
|
||||||
entity_chunks,
|
entity_chunks,
|
||||||
|
query,
|
||||||
|
chunks_vdb,
|
||||||
|
chunk_tracking=chunk_tracking,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Round-robin merge chunks from different sources with deduplication by chunk_id
|
# Round-robin merge chunks from different sources with deduplication by chunk_id
|
||||||
|
|
@ -2373,6 +2396,7 @@ async def _build_query_context(
|
||||||
{
|
{
|
||||||
"content": chunk["content"],
|
"content": chunk["content"],
|
||||||
"file_path": chunk.get("file_path", "unknown_source"),
|
"file_path": chunk.get("file_path", "unknown_source"),
|
||||||
|
"chunk_id": chunk_id,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -2386,6 +2410,7 @@ async def _build_query_context(
|
||||||
{
|
{
|
||||||
"content": chunk["content"],
|
"content": chunk["content"],
|
||||||
"file_path": chunk.get("file_path", "unknown_source"),
|
"file_path": chunk.get("file_path", "unknown_source"),
|
||||||
|
"chunk_id": chunk_id,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -2399,10 +2424,11 @@ async def _build_query_context(
|
||||||
{
|
{
|
||||||
"content": chunk["content"],
|
"content": chunk["content"],
|
||||||
"file_path": chunk.get("file_path", "unknown_source"),
|
"file_path": chunk.get("file_path", "unknown_source"),
|
||||||
|
"chunk_id": chunk_id,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.debug(
|
logger.info(
|
||||||
f"Round-robin merged total chunks from {origin_len} to {len(merged_chunks)}"
|
f"Round-robin merged total chunks from {origin_len} to {len(merged_chunks)}"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -2517,6 +2543,24 @@ async def _build_query_context(
|
||||||
if not entities_context and not relations_context:
|
if not entities_context and not relations_context:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
# output chunks tracking infomations
|
||||||
|
# format: <source><frequency>/<order> (e.g., E5/2 R2/1 C1/1)
|
||||||
|
if truncated_chunks and chunk_tracking:
|
||||||
|
chunk_tracking_log = []
|
||||||
|
for chunk in truncated_chunks:
|
||||||
|
chunk_id = chunk.get("chunk_id")
|
||||||
|
if chunk_id and chunk_id in chunk_tracking:
|
||||||
|
tracking_info = chunk_tracking[chunk_id]
|
||||||
|
source = tracking_info["source"]
|
||||||
|
frequency = tracking_info["frequency"]
|
||||||
|
order = tracking_info["order"]
|
||||||
|
chunk_tracking_log.append(f"{source}{frequency}/{order}")
|
||||||
|
else:
|
||||||
|
chunk_tracking_log.append("?0/0")
|
||||||
|
|
||||||
|
if chunk_tracking_log:
|
||||||
|
logger.info(f"chunks: {' '.join(chunk_tracking_log)}")
|
||||||
|
|
||||||
entities_str = json.dumps(entities_context, ensure_ascii=False)
|
entities_str = json.dumps(entities_context, ensure_ascii=False)
|
||||||
relations_str = json.dumps(relations_context, ensure_ascii=False)
|
relations_str = json.dumps(relations_context, ensure_ascii=False)
|
||||||
text_units_str = json.dumps(text_units_context, ensure_ascii=False)
|
text_units_str = json.dumps(text_units_context, ensure_ascii=False)
|
||||||
|
|
@ -2603,104 +2647,6 @@ async def _get_node_data(
|
||||||
return node_datas, use_relations
|
return node_datas, use_relations
|
||||||
|
|
||||||
|
|
||||||
async def _find_most_related_text_unit_from_entities(
|
|
||||||
node_datas: list[dict],
|
|
||||||
query_param: QueryParam,
|
|
||||||
text_chunks_db: BaseKVStorage,
|
|
||||||
knowledge_graph_inst: BaseGraphStorage,
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Find text chunks related to entities using linear gradient weighted polling algorithm.
|
|
||||||
|
|
||||||
This function implements the optimized text chunk selection strategy:
|
|
||||||
1. Sort text chunks for each entity by occurrence count in other entities
|
|
||||||
2. Use linear gradient weighted polling to select chunks fairly
|
|
||||||
"""
|
|
||||||
logger.debug(f"Searching text chunks for {len(node_datas)} entities")
|
|
||||||
|
|
||||||
if not node_datas:
|
|
||||||
return []
|
|
||||||
|
|
||||||
# Step 1: Collect all text chunks for each entity
|
|
||||||
entities_with_chunks = []
|
|
||||||
for entity in node_datas:
|
|
||||||
if entity.get("source_id"):
|
|
||||||
chunks = split_string_by_multi_markers(
|
|
||||||
entity["source_id"], [GRAPH_FIELD_SEP]
|
|
||||||
)
|
|
||||||
if chunks:
|
|
||||||
entities_with_chunks.append(
|
|
||||||
{
|
|
||||||
"entity_name": entity["entity_name"],
|
|
||||||
"chunks": chunks,
|
|
||||||
"entity_data": entity,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
if not entities_with_chunks:
|
|
||||||
logger.warning("No entities with text chunks found")
|
|
||||||
return []
|
|
||||||
|
|
||||||
# Step 2: Count chunk occurrences and deduplicate (keep chunks from earlier positioned entities)
|
|
||||||
chunk_occurrence_count = {}
|
|
||||||
for entity_info in entities_with_chunks:
|
|
||||||
deduplicated_chunks = []
|
|
||||||
for chunk_id in entity_info["chunks"]:
|
|
||||||
chunk_occurrence_count[chunk_id] = (
|
|
||||||
chunk_occurrence_count.get(chunk_id, 0) + 1
|
|
||||||
)
|
|
||||||
|
|
||||||
# If this is the first occurrence (count == 1), keep it; otherwise skip (duplicate from later position)
|
|
||||||
if chunk_occurrence_count[chunk_id] == 1:
|
|
||||||
deduplicated_chunks.append(chunk_id)
|
|
||||||
# count > 1 means this chunk appeared in an earlier entity, so skip it
|
|
||||||
|
|
||||||
# Update entity's chunks to deduplicated chunks
|
|
||||||
entity_info["chunks"] = deduplicated_chunks
|
|
||||||
|
|
||||||
# Step 3: Sort chunks for each entity by occurrence count (higher count = higher priority)
|
|
||||||
for entity_info in entities_with_chunks:
|
|
||||||
sorted_chunks = sorted(
|
|
||||||
entity_info["chunks"],
|
|
||||||
key=lambda chunk_id: chunk_occurrence_count.get(chunk_id, 0),
|
|
||||||
reverse=True,
|
|
||||||
)
|
|
||||||
entity_info["sorted_chunks"] = sorted_chunks
|
|
||||||
|
|
||||||
# Step 4: Apply linear gradient weighted polling algorithm
|
|
||||||
max_related_chunks = text_chunks_db.global_config.get(
|
|
||||||
"related_chunk_number", DEFAULT_RELATED_CHUNK_NUMBER
|
|
||||||
)
|
|
||||||
|
|
||||||
selected_chunk_ids = linear_gradient_weighted_polling(
|
|
||||||
entities_with_chunks, max_related_chunks, min_related_chunks=1
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.debug(
|
|
||||||
f"Found {len(selected_chunk_ids)} entity-related chunks using linear gradient weighted polling"
|
|
||||||
)
|
|
||||||
|
|
||||||
if not selected_chunk_ids:
|
|
||||||
return []
|
|
||||||
|
|
||||||
# Step 5: Batch retrieve chunk data
|
|
||||||
unique_chunk_ids = list(
|
|
||||||
dict.fromkeys(selected_chunk_ids)
|
|
||||||
) # Remove duplicates while preserving order
|
|
||||||
chunk_data_list = await text_chunks_db.get_by_ids(unique_chunk_ids)
|
|
||||||
|
|
||||||
# Step 6: Build result chunks with valid data
|
|
||||||
result_chunks = []
|
|
||||||
for chunk_id, chunk_data in zip(unique_chunk_ids, chunk_data_list):
|
|
||||||
if chunk_data is not None and "content" in chunk_data:
|
|
||||||
chunk_data_copy = chunk_data.copy()
|
|
||||||
chunk_data_copy["source_type"] = "entity"
|
|
||||||
chunk_data_copy["chunk_id"] = chunk_id # Add chunk_id for deduplication
|
|
||||||
result_chunks.append(chunk_data_copy)
|
|
||||||
|
|
||||||
return result_chunks
|
|
||||||
|
|
||||||
|
|
||||||
async def _find_most_related_edges_from_entities(
|
async def _find_most_related_edges_from_entities(
|
||||||
node_datas: list[dict],
|
node_datas: list[dict],
|
||||||
query_param: QueryParam,
|
query_param: QueryParam,
|
||||||
|
|
@ -2757,6 +2703,167 @@ async def _find_most_related_edges_from_entities(
|
||||||
return all_edges_data
|
return all_edges_data
|
||||||
|
|
||||||
|
|
||||||
|
async def _find_related_text_unit_from_entities(
|
||||||
|
node_datas: list[dict],
|
||||||
|
query_param: QueryParam,
|
||||||
|
text_chunks_db: BaseKVStorage,
|
||||||
|
knowledge_graph_inst: BaseGraphStorage,
|
||||||
|
query: str = None,
|
||||||
|
chunks_vdb: BaseVectorStorage = None,
|
||||||
|
chunk_tracking: dict = None,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Find text chunks related to entities using configurable chunk selection method.
|
||||||
|
|
||||||
|
This function supports two chunk selection strategies:
|
||||||
|
1. WEIGHT: Linear gradient weighted polling based on chunk occurrence count
|
||||||
|
2. VECTOR: Vector similarity-based selection using embedding cosine similarity
|
||||||
|
"""
|
||||||
|
logger.debug(f"Finding text chunks from {len(node_datas)} entities")
|
||||||
|
|
||||||
|
if not node_datas:
|
||||||
|
return []
|
||||||
|
|
||||||
|
# Step 1: Collect all text chunks for each entity
|
||||||
|
entities_with_chunks = []
|
||||||
|
for entity in node_datas:
|
||||||
|
if entity.get("source_id"):
|
||||||
|
chunks = split_string_by_multi_markers(
|
||||||
|
entity["source_id"], [GRAPH_FIELD_SEP]
|
||||||
|
)
|
||||||
|
if chunks:
|
||||||
|
entities_with_chunks.append(
|
||||||
|
{
|
||||||
|
"entity_name": entity["entity_name"],
|
||||||
|
"chunks": chunks,
|
||||||
|
"entity_data": entity,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
if not entities_with_chunks:
|
||||||
|
logger.warning("No entities with text chunks found")
|
||||||
|
return []
|
||||||
|
|
||||||
|
kg_chunk_pick_method = text_chunks_db.global_config.get(
|
||||||
|
"kg_chunk_pick_method", DEFAULT_KG_CHUNK_PICK_METHOD
|
||||||
|
)
|
||||||
|
max_related_chunks = text_chunks_db.global_config.get(
|
||||||
|
"related_chunk_number", DEFAULT_RELATED_CHUNK_NUMBER
|
||||||
|
)
|
||||||
|
|
||||||
|
# Step 2: Count chunk occurrences and deduplicate (keep chunks from earlier positioned entities)
|
||||||
|
chunk_occurrence_count = {}
|
||||||
|
for entity_info in entities_with_chunks:
|
||||||
|
deduplicated_chunks = []
|
||||||
|
for chunk_id in entity_info["chunks"]:
|
||||||
|
chunk_occurrence_count[chunk_id] = (
|
||||||
|
chunk_occurrence_count.get(chunk_id, 0) + 1
|
||||||
|
)
|
||||||
|
|
||||||
|
# If this is the first occurrence (count == 1), keep it; otherwise skip (duplicate from later position)
|
||||||
|
if chunk_occurrence_count[chunk_id] == 1:
|
||||||
|
deduplicated_chunks.append(chunk_id)
|
||||||
|
# count > 1 means this chunk appeared in an earlier entity, so skip it
|
||||||
|
|
||||||
|
# Update entity's chunks to deduplicated chunks
|
||||||
|
entity_info["chunks"] = deduplicated_chunks
|
||||||
|
|
||||||
|
# Step 3: Sort chunks for each entity by occurrence count (higher count = higher priority)
|
||||||
|
total_entity_chunks = 0
|
||||||
|
for entity_info in entities_with_chunks:
|
||||||
|
sorted_chunks = sorted(
|
||||||
|
entity_info["chunks"],
|
||||||
|
key=lambda chunk_id: chunk_occurrence_count.get(chunk_id, 0),
|
||||||
|
reverse=True,
|
||||||
|
)
|
||||||
|
entity_info["sorted_chunks"] = sorted_chunks
|
||||||
|
total_entity_chunks += len(sorted_chunks)
|
||||||
|
|
||||||
|
selected_chunk_ids = [] # Initialize to avoid UnboundLocalError
|
||||||
|
|
||||||
|
# Step 4: Apply the selected chunk selection algorithm
|
||||||
|
# Pick by vector similarity:
|
||||||
|
# The order of text chunks aligns with the naive retrieval's destination.
|
||||||
|
# When reranking is disabled, the text chunks delivered to the LLM tend to favor naive retrieval.
|
||||||
|
if kg_chunk_pick_method == "VECTOR" and query and chunks_vdb:
|
||||||
|
num_of_chunks = int(max_related_chunks * len(entities_with_chunks) / 2)
|
||||||
|
|
||||||
|
# Get embedding function from global config
|
||||||
|
embedding_func_config = text_chunks_db.embedding_func
|
||||||
|
if not embedding_func_config:
|
||||||
|
logger.warning("No embedding function found, falling back to WEIGHT method")
|
||||||
|
kg_chunk_pick_method = "WEIGHT"
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
actual_embedding_func = embedding_func_config.func
|
||||||
|
|
||||||
|
selected_chunk_ids = None
|
||||||
|
if actual_embedding_func:
|
||||||
|
selected_chunk_ids = await pick_by_vector_similarity(
|
||||||
|
query=query,
|
||||||
|
text_chunks_storage=text_chunks_db,
|
||||||
|
chunks_vdb=chunks_vdb,
|
||||||
|
num_of_chunks=num_of_chunks,
|
||||||
|
entity_info=entities_with_chunks,
|
||||||
|
embedding_func=actual_embedding_func,
|
||||||
|
)
|
||||||
|
|
||||||
|
if selected_chunk_ids == []:
|
||||||
|
kg_chunk_pick_method = "WEIGHT"
|
||||||
|
logger.warning(
|
||||||
|
"No entity-related chunks selected by vector similarity, falling back to WEIGHT method"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logger.info(
|
||||||
|
f"Selecting {len(selected_chunk_ids)} from {total_entity_chunks} entity-related chunks by vector similarity"
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(
|
||||||
|
f"Error in vector similarity sorting: {e}, falling back to WEIGHT method"
|
||||||
|
)
|
||||||
|
kg_chunk_pick_method = "WEIGHT"
|
||||||
|
|
||||||
|
if kg_chunk_pick_method == "WEIGHT":
|
||||||
|
# Pick by entity and chunk weight:
|
||||||
|
# When reranking is disabled, delivered more solely KG related chunks to the LLM
|
||||||
|
selected_chunk_ids = pick_by_weighted_polling(
|
||||||
|
entities_with_chunks, max_related_chunks, min_related_chunks=1
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"Selecting {len(selected_chunk_ids)} from {total_entity_chunks} entity-related chunks by weighted polling"
|
||||||
|
)
|
||||||
|
|
||||||
|
if not selected_chunk_ids:
|
||||||
|
return []
|
||||||
|
|
||||||
|
# Step 5: Batch retrieve chunk data
|
||||||
|
unique_chunk_ids = list(
|
||||||
|
dict.fromkeys(selected_chunk_ids)
|
||||||
|
) # Remove duplicates while preserving order
|
||||||
|
chunk_data_list = await text_chunks_db.get_by_ids(unique_chunk_ids)
|
||||||
|
|
||||||
|
# Step 6: Build result chunks with valid data and update chunk tracking
|
||||||
|
result_chunks = []
|
||||||
|
for i, (chunk_id, chunk_data) in enumerate(zip(unique_chunk_ids, chunk_data_list)):
|
||||||
|
if chunk_data is not None and "content" in chunk_data:
|
||||||
|
chunk_data_copy = chunk_data.copy()
|
||||||
|
chunk_data_copy["source_type"] = "entity"
|
||||||
|
chunk_data_copy["chunk_id"] = chunk_id # Add chunk_id for deduplication
|
||||||
|
result_chunks.append(chunk_data_copy)
|
||||||
|
|
||||||
|
# Update chunk tracking if provided
|
||||||
|
if chunk_tracking is not None:
|
||||||
|
chunk_tracking[chunk_id] = {
|
||||||
|
"source": "E",
|
||||||
|
"frequency": chunk_occurrence_count.get(chunk_id, 1),
|
||||||
|
"order": i + 1, # 1-based order in final entity-related results
|
||||||
|
}
|
||||||
|
|
||||||
|
return result_chunks
|
||||||
|
|
||||||
|
|
||||||
async def _get_edge_data(
|
async def _get_edge_data(
|
||||||
keywords,
|
keywords,
|
||||||
knowledge_graph_inst: BaseGraphStorage,
|
knowledge_graph_inst: BaseGraphStorage,
|
||||||
|
|
@ -2848,20 +2955,23 @@ async def _find_most_related_entities_from_relationships(
|
||||||
return node_datas
|
return node_datas
|
||||||
|
|
||||||
|
|
||||||
async def _find_related_text_unit_from_relationships(
|
async def _find_related_text_unit_from_relations(
|
||||||
edge_datas: list[dict],
|
edge_datas: list[dict],
|
||||||
query_param: QueryParam,
|
query_param: QueryParam,
|
||||||
text_chunks_db: BaseKVStorage,
|
text_chunks_db: BaseKVStorage,
|
||||||
entity_chunks: list[dict] = None,
|
entity_chunks: list[dict] = None,
|
||||||
|
query: str = None,
|
||||||
|
chunks_vdb: BaseVectorStorage = None,
|
||||||
|
chunk_tracking: dict = None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Find text chunks related to relationships using linear gradient weighted polling algorithm.
|
Find text chunks related to relationships using configurable chunk selection method.
|
||||||
|
|
||||||
This function implements the optimized text chunk selection strategy:
|
This function supports two chunk selection strategies:
|
||||||
1. Sort text chunks for each relationship by occurrence count in other relationships
|
1. WEIGHT: Linear gradient weighted polling based on chunk occurrence count
|
||||||
2. Use linear gradient weighted polling to select chunks fairly
|
2. VECTOR: Vector similarity-based selection using embedding cosine similarity
|
||||||
"""
|
"""
|
||||||
logger.debug(f"Searching text chunks for {len(edge_datas)} relationships")
|
logger.debug(f"Finding text chunks from {len(edge_datas)} relations")
|
||||||
|
|
||||||
if not edge_datas:
|
if not edge_datas:
|
||||||
return []
|
return []
|
||||||
|
|
@ -2891,14 +3001,40 @@ async def _find_related_text_unit_from_relationships(
|
||||||
)
|
)
|
||||||
|
|
||||||
if not relations_with_chunks:
|
if not relations_with_chunks:
|
||||||
logger.warning("No relationships with text chunks found")
|
logger.warning("No relation-related chunks found")
|
||||||
return []
|
return []
|
||||||
|
|
||||||
|
kg_chunk_pick_method = text_chunks_db.global_config.get(
|
||||||
|
"kg_chunk_pick_method", DEFAULT_KG_CHUNK_PICK_METHOD
|
||||||
|
)
|
||||||
|
max_related_chunks = text_chunks_db.global_config.get(
|
||||||
|
"related_chunk_number", DEFAULT_RELATED_CHUNK_NUMBER
|
||||||
|
)
|
||||||
|
|
||||||
# Step 2: Count chunk occurrences and deduplicate (keep chunks from earlier positioned relationships)
|
# Step 2: Count chunk occurrences and deduplicate (keep chunks from earlier positioned relationships)
|
||||||
|
# Also remove duplicates with entity_chunks
|
||||||
|
|
||||||
|
# Extract chunk IDs from entity_chunks for deduplication
|
||||||
|
entity_chunk_ids = set()
|
||||||
|
if entity_chunks:
|
||||||
|
for chunk in entity_chunks:
|
||||||
|
chunk_id = chunk.get("chunk_id")
|
||||||
|
if chunk_id:
|
||||||
|
entity_chunk_ids.add(chunk_id)
|
||||||
|
|
||||||
chunk_occurrence_count = {}
|
chunk_occurrence_count = {}
|
||||||
|
# Track unique chunk_ids that have been removed to avoid double counting
|
||||||
|
removed_entity_chunk_ids = set()
|
||||||
|
|
||||||
for relation_info in relations_with_chunks:
|
for relation_info in relations_with_chunks:
|
||||||
deduplicated_chunks = []
|
deduplicated_chunks = []
|
||||||
for chunk_id in relation_info["chunks"]:
|
for chunk_id in relation_info["chunks"]:
|
||||||
|
# Skip chunks that already exist in entity_chunks
|
||||||
|
if chunk_id in entity_chunk_ids:
|
||||||
|
# Only count each unique chunk_id once
|
||||||
|
removed_entity_chunk_ids.add(chunk_id)
|
||||||
|
continue
|
||||||
|
|
||||||
chunk_occurrence_count[chunk_id] = (
|
chunk_occurrence_count[chunk_id] = (
|
||||||
chunk_occurrence_count.get(chunk_id, 0) + 1
|
chunk_occurrence_count.get(chunk_id, 0) + 1
|
||||||
)
|
)
|
||||||
|
|
@ -2911,7 +3047,21 @@ async def _find_related_text_unit_from_relationships(
|
||||||
# Update relationship's chunks to deduplicated chunks
|
# Update relationship's chunks to deduplicated chunks
|
||||||
relation_info["chunks"] = deduplicated_chunks
|
relation_info["chunks"] = deduplicated_chunks
|
||||||
|
|
||||||
|
# Check if any relations still have chunks after deduplication
|
||||||
|
relations_with_chunks = [
|
||||||
|
relation_info
|
||||||
|
for relation_info in relations_with_chunks
|
||||||
|
if relation_info["chunks"]
|
||||||
|
]
|
||||||
|
|
||||||
|
if not relations_with_chunks:
|
||||||
|
logger.info(
|
||||||
|
f"Find no additional relations-related chunks from {len(edge_datas)} relations"
|
||||||
|
)
|
||||||
|
return []
|
||||||
|
|
||||||
# Step 3: Sort chunks for each relationship by occurrence count (higher count = higher priority)
|
# Step 3: Sort chunks for each relationship by occurrence count (higher count = higher priority)
|
||||||
|
total_relation_chunks = 0
|
||||||
for relation_info in relations_with_chunks:
|
for relation_info in relations_with_chunks:
|
||||||
sorted_chunks = sorted(
|
sorted_chunks = sorted(
|
||||||
relation_info["chunks"],
|
relation_info["chunks"],
|
||||||
|
|
@ -2919,66 +3069,93 @@ async def _find_related_text_unit_from_relationships(
|
||||||
reverse=True,
|
reverse=True,
|
||||||
)
|
)
|
||||||
relation_info["sorted_chunks"] = sorted_chunks
|
relation_info["sorted_chunks"] = sorted_chunks
|
||||||
|
total_relation_chunks += len(sorted_chunks)
|
||||||
|
|
||||||
# Step 4: Apply linear gradient weighted polling algorithm
|
logger.info(
|
||||||
max_related_chunks = text_chunks_db.global_config.get(
|
f"Find {total_relation_chunks} additional chunks in {len(relations_with_chunks)} relations ({len(removed_entity_chunk_ids)} duplicated chunks removed)"
|
||||||
"related_chunk_number", DEFAULT_RELATED_CHUNK_NUMBER
|
|
||||||
)
|
)
|
||||||
|
|
||||||
selected_chunk_ids = linear_gradient_weighted_polling(
|
# Step 4: Apply the selected chunk selection algorithm
|
||||||
relations_with_chunks, max_related_chunks, min_related_chunks=1
|
selected_chunk_ids = [] # Initialize to avoid UnboundLocalError
|
||||||
)
|
|
||||||
|
if kg_chunk_pick_method == "VECTOR" and query and chunks_vdb:
|
||||||
|
num_of_chunks = int(max_related_chunks * len(relations_with_chunks) / 2)
|
||||||
|
|
||||||
|
# Get embedding function from global config
|
||||||
|
embedding_func_config = text_chunks_db.embedding_func
|
||||||
|
if not embedding_func_config:
|
||||||
|
logger.warning("No embedding function found, falling back to WEIGHT method")
|
||||||
|
kg_chunk_pick_method = "WEIGHT"
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
actual_embedding_func = embedding_func_config.func
|
||||||
|
|
||||||
|
if actual_embedding_func:
|
||||||
|
selected_chunk_ids = await pick_by_vector_similarity(
|
||||||
|
query=query,
|
||||||
|
text_chunks_storage=text_chunks_db,
|
||||||
|
chunks_vdb=chunks_vdb,
|
||||||
|
num_of_chunks=num_of_chunks,
|
||||||
|
entity_info=relations_with_chunks,
|
||||||
|
embedding_func=actual_embedding_func,
|
||||||
|
)
|
||||||
|
|
||||||
|
if selected_chunk_ids == []:
|
||||||
|
kg_chunk_pick_method = "WEIGHT"
|
||||||
|
logger.warning(
|
||||||
|
"No relation-related chunks selected by vector similarity, falling back to WEIGHT method"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logger.info(
|
||||||
|
f"Selecting {len(selected_chunk_ids)} from {total_relation_chunks} relation-related chunks by vector similarity"
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(
|
||||||
|
f"Error in vector similarity sorting: {e}, falling back to WEIGHT method"
|
||||||
|
)
|
||||||
|
kg_chunk_pick_method = "WEIGHT"
|
||||||
|
|
||||||
|
if kg_chunk_pick_method == "WEIGHT":
|
||||||
|
# Apply linear gradient weighted polling algorithm
|
||||||
|
selected_chunk_ids = pick_by_weighted_polling(
|
||||||
|
relations_with_chunks, max_related_chunks, min_related_chunks=1
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"Selecting {len(selected_chunk_ids)} from {total_relation_chunks} relation-related chunks by weighted polling"
|
||||||
|
)
|
||||||
|
|
||||||
logger.debug(
|
logger.debug(
|
||||||
f"Found {len(selected_chunk_ids)} relationship-related chunks using linear gradient weighted polling"
|
|
||||||
)
|
|
||||||
logger.info(
|
|
||||||
f"KG related chunks: {len(entity_chunks)} from entitys, {len(selected_chunk_ids)} from relations"
|
f"KG related chunks: {len(entity_chunks)} from entitys, {len(selected_chunk_ids)} from relations"
|
||||||
)
|
)
|
||||||
|
|
||||||
if not selected_chunk_ids:
|
if not selected_chunk_ids:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
# Step 4.5: Remove duplicates with entity_chunks before batch retrieval
|
|
||||||
if entity_chunks:
|
|
||||||
# Extract chunk IDs from entity_chunks
|
|
||||||
entity_chunk_ids = set()
|
|
||||||
for chunk in entity_chunks:
|
|
||||||
chunk_id = chunk.get("chunk_id")
|
|
||||||
if chunk_id:
|
|
||||||
entity_chunk_ids.add(chunk_id)
|
|
||||||
|
|
||||||
# Filter out duplicate chunk IDs
|
|
||||||
original_count = len(selected_chunk_ids)
|
|
||||||
selected_chunk_ids = [
|
|
||||||
chunk_id
|
|
||||||
for chunk_id in selected_chunk_ids
|
|
||||||
if chunk_id not in entity_chunk_ids
|
|
||||||
]
|
|
||||||
|
|
||||||
logger.debug(
|
|
||||||
f"Deduplication relation-chunks with entity-chunks: {original_count} -> {len(selected_chunk_ids)} chunks "
|
|
||||||
)
|
|
||||||
|
|
||||||
# Early return if no chunks remain after deduplication
|
|
||||||
if not selected_chunk_ids:
|
|
||||||
return []
|
|
||||||
|
|
||||||
# Step 5: Batch retrieve chunk data
|
# Step 5: Batch retrieve chunk data
|
||||||
unique_chunk_ids = list(
|
unique_chunk_ids = list(
|
||||||
dict.fromkeys(selected_chunk_ids)
|
dict.fromkeys(selected_chunk_ids)
|
||||||
) # Remove duplicates while preserving order
|
) # Remove duplicates while preserving order
|
||||||
chunk_data_list = await text_chunks_db.get_by_ids(unique_chunk_ids)
|
chunk_data_list = await text_chunks_db.get_by_ids(unique_chunk_ids)
|
||||||
|
|
||||||
# Step 6: Build result chunks with valid data
|
# Step 6: Build result chunks with valid data and update chunk tracking
|
||||||
result_chunks = []
|
result_chunks = []
|
||||||
for chunk_id, chunk_data in zip(unique_chunk_ids, chunk_data_list):
|
for i, (chunk_id, chunk_data) in enumerate(zip(unique_chunk_ids, chunk_data_list)):
|
||||||
if chunk_data is not None and "content" in chunk_data:
|
if chunk_data is not None and "content" in chunk_data:
|
||||||
chunk_data_copy = chunk_data.copy()
|
chunk_data_copy = chunk_data.copy()
|
||||||
chunk_data_copy["source_type"] = "relationship"
|
chunk_data_copy["source_type"] = "relationship"
|
||||||
chunk_data_copy["chunk_id"] = chunk_id # Add chunk_id for deduplication
|
chunk_data_copy["chunk_id"] = chunk_id # Add chunk_id for deduplication
|
||||||
result_chunks.append(chunk_data_copy)
|
result_chunks.append(chunk_data_copy)
|
||||||
|
|
||||||
|
# Update chunk tracking if provided
|
||||||
|
if chunk_tracking is not None:
|
||||||
|
chunk_tracking[chunk_id] = {
|
||||||
|
"source": "R",
|
||||||
|
"frequency": chunk_occurrence_count.get(chunk_id, 1),
|
||||||
|
"order": i + 1, # 1-based order in final relation-related results
|
||||||
|
}
|
||||||
|
|
||||||
return result_chunks
|
return result_chunks
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -60,7 +60,7 @@ def get_env_value(
|
||||||
|
|
||||||
# Use TYPE_CHECKING to avoid circular imports
|
# Use TYPE_CHECKING to avoid circular imports
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from lightrag.base import BaseKVStorage, QueryParam
|
from lightrag.base import BaseKVStorage, BaseVectorStorage, QueryParam
|
||||||
|
|
||||||
# use the .env that is inside the current folder
|
# use the .env that is inside the current folder
|
||||||
# allows to use different .env file for each lightrag instance
|
# allows to use different .env file for each lightrag instance
|
||||||
|
|
@ -1570,7 +1570,7 @@ def check_storage_env_vars(storage_name: str) -> None:
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def linear_gradient_weighted_polling(
|
def pick_by_weighted_polling(
|
||||||
entities_or_relations: list[dict],
|
entities_or_relations: list[dict],
|
||||||
max_related_chunks: int,
|
max_related_chunks: int,
|
||||||
min_related_chunks: int = 1,
|
min_related_chunks: int = 1,
|
||||||
|
|
@ -1650,6 +1650,120 @@ def linear_gradient_weighted_polling(
|
||||||
return selected_chunks
|
return selected_chunks
|
||||||
|
|
||||||
|
|
||||||
|
async def pick_by_vector_similarity(
|
||||||
|
query: str,
|
||||||
|
text_chunks_storage: "BaseKVStorage",
|
||||||
|
chunks_vdb: "BaseVectorStorage",
|
||||||
|
num_of_chunks: int,
|
||||||
|
entity_info: list[dict[str, Any]],
|
||||||
|
embedding_func: callable,
|
||||||
|
) -> list[str]:
|
||||||
|
"""
|
||||||
|
Vector similarity-based text chunk selection algorithm.
|
||||||
|
|
||||||
|
This algorithm selects text chunks based on cosine similarity between
|
||||||
|
the query embedding and text chunk embeddings.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query: User's original query string
|
||||||
|
text_chunks_storage: Text chunks storage instance
|
||||||
|
chunks_vdb: Vector database storage for chunks
|
||||||
|
num_of_chunks: Number of chunks to select
|
||||||
|
entity_info: List of entity information containing chunk IDs
|
||||||
|
embedding_func: Embedding function to compute query embedding
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of selected text chunk IDs sorted by similarity (highest first)
|
||||||
|
"""
|
||||||
|
logger.debug(
|
||||||
|
f"Vector similarity chunk selection: num_of_chunks={num_of_chunks}, entity_info_count={len(entity_info) if entity_info else 0}"
|
||||||
|
)
|
||||||
|
|
||||||
|
if not entity_info or num_of_chunks <= 0:
|
||||||
|
return []
|
||||||
|
|
||||||
|
# Collect all unique chunk IDs from entity info
|
||||||
|
all_chunk_ids = set()
|
||||||
|
for i, entity in enumerate(entity_info):
|
||||||
|
chunk_ids = entity.get("sorted_chunks", [])
|
||||||
|
all_chunk_ids.update(chunk_ids)
|
||||||
|
|
||||||
|
if not all_chunk_ids:
|
||||||
|
logger.warning(
|
||||||
|
"Vector similarity chunk selection: no chunk IDs found in entity_info"
|
||||||
|
)
|
||||||
|
return []
|
||||||
|
|
||||||
|
logger.debug(
|
||||||
|
f"Vector similarity chunk selection: {len(all_chunk_ids)} unique chunk IDs collected"
|
||||||
|
)
|
||||||
|
|
||||||
|
all_chunk_ids = list(all_chunk_ids)
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Get query embedding
|
||||||
|
query_embedding = await embedding_func([query])
|
||||||
|
query_embedding = query_embedding[
|
||||||
|
0
|
||||||
|
] # Extract first embedding from batch result
|
||||||
|
|
||||||
|
# Get chunk embeddings from vector database
|
||||||
|
chunk_vectors = await chunks_vdb.get_vectors_by_ids(all_chunk_ids)
|
||||||
|
logger.debug(
|
||||||
|
f"Vector similarity chunk selection: {len(chunk_vectors)} chunk vectors Retrieved"
|
||||||
|
)
|
||||||
|
|
||||||
|
if not chunk_vectors or len(chunk_vectors) != len(all_chunk_ids):
|
||||||
|
if not chunk_vectors:
|
||||||
|
logger.warning(
|
||||||
|
"Vector similarity chunk selection: no vectors retrieved from chunks_vdb"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logger.warning(
|
||||||
|
f"Vector similarity chunk selection: found {len(chunk_vectors)} but expecting {len(all_chunk_ids)}"
|
||||||
|
)
|
||||||
|
return []
|
||||||
|
|
||||||
|
# Calculate cosine similarities
|
||||||
|
similarities = []
|
||||||
|
valid_vectors = 0
|
||||||
|
for chunk_id in all_chunk_ids:
|
||||||
|
if chunk_id in chunk_vectors:
|
||||||
|
chunk_embedding = chunk_vectors[chunk_id]
|
||||||
|
try:
|
||||||
|
# Calculate cosine similarity
|
||||||
|
similarity = cosine_similarity(query_embedding, chunk_embedding)
|
||||||
|
similarities.append((chunk_id, similarity))
|
||||||
|
valid_vectors += 1
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(
|
||||||
|
f"Vector similarity chunk selection: failed to calculate similarity for chunk {chunk_id}: {e}"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logger.warning(
|
||||||
|
f"Vector similarity chunk selection: no vector found for chunk {chunk_id}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Sort by similarity (highest first) and select top num_of_chunks
|
||||||
|
similarities.sort(key=lambda x: x[1], reverse=True)
|
||||||
|
selected_chunks = [chunk_id for chunk_id, _ in similarities[:num_of_chunks]]
|
||||||
|
|
||||||
|
logger.debug(
|
||||||
|
f"Vector similarity chunk selection: {len(selected_chunks)} chunks from {len(all_chunk_ids)} candidates"
|
||||||
|
)
|
||||||
|
|
||||||
|
return selected_chunks
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"[VECTOR_SIMILARITY] Error in vector similarity sorting: {e}")
|
||||||
|
import traceback
|
||||||
|
|
||||||
|
logger.error(f"[VECTOR_SIMILARITY] Traceback: {traceback.format_exc()}")
|
||||||
|
# Fallback to simple truncation
|
||||||
|
logger.debug("[VECTOR_SIMILARITY] Falling back to simple truncation")
|
||||||
|
return all_chunk_ids[:num_of_chunks]
|
||||||
|
|
||||||
|
|
||||||
class TokenTracker:
|
class TokenTracker:
|
||||||
"""Track token usage for LLM calls."""
|
"""Track token usage for LLM calls."""
|
||||||
|
|
||||||
|
|
@ -1787,6 +1901,13 @@ async def process_chunks_unified(
|
||||||
|
|
||||||
# 1. Apply reranking if enabled and query is provided
|
# 1. Apply reranking if enabled and query is provided
|
||||||
if query_param.enable_rerank and query and unique_chunks:
|
if query_param.enable_rerank and query and unique_chunks:
|
||||||
|
# 保存 chunk_id 字段,因为 rerank 可能会丢失这个字段
|
||||||
|
chunk_ids = {}
|
||||||
|
for chunk in unique_chunks:
|
||||||
|
chunk_id = chunk.get("chunk_id")
|
||||||
|
if chunk_id:
|
||||||
|
chunk_ids[id(chunk)] = chunk_id
|
||||||
|
|
||||||
rerank_top_k = query_param.chunk_top_k or len(unique_chunks)
|
rerank_top_k = query_param.chunk_top_k or len(unique_chunks)
|
||||||
unique_chunks = await apply_rerank_if_enabled(
|
unique_chunks = await apply_rerank_if_enabled(
|
||||||
query=query,
|
query=query,
|
||||||
|
|
@ -1796,6 +1917,11 @@ async def process_chunks_unified(
|
||||||
top_n=rerank_top_k,
|
top_n=rerank_top_k,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# 恢复 chunk_id 字段
|
||||||
|
for chunk in unique_chunks:
|
||||||
|
if id(chunk) in chunk_ids:
|
||||||
|
chunk["chunk_id"] = chunk_ids[id(chunk)]
|
||||||
|
|
||||||
# 2. Filter by minimum rerank score if reranking is enabled
|
# 2. Filter by minimum rerank score if reranking is enabled
|
||||||
if query_param.enable_rerank and unique_chunks:
|
if query_param.enable_rerank and unique_chunks:
|
||||||
min_rerank_score = global_config.get("min_rerank_score", 0.5)
|
min_rerank_score = global_config.get("min_rerank_score", 0.5)
|
||||||
|
|
@ -1842,12 +1968,26 @@ async def process_chunks_unified(
|
||||||
)
|
)
|
||||||
|
|
||||||
original_count = len(unique_chunks)
|
original_count = len(unique_chunks)
|
||||||
|
|
||||||
|
# Keep chunk_id field, cause truncate_list_by_token_size will lose it
|
||||||
|
chunk_ids_map = {}
|
||||||
|
for i, chunk in enumerate(unique_chunks):
|
||||||
|
chunk_id = chunk.get("chunk_id")
|
||||||
|
if chunk_id:
|
||||||
|
chunk_ids_map[i] = chunk_id
|
||||||
|
|
||||||
unique_chunks = truncate_list_by_token_size(
|
unique_chunks = truncate_list_by_token_size(
|
||||||
unique_chunks,
|
unique_chunks,
|
||||||
key=lambda x: x.get("content", ""),
|
key=lambda x: x.get("content", ""),
|
||||||
max_token_size=chunk_token_limit,
|
max_token_size=chunk_token_limit,
|
||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# restore chunk_id feiled
|
||||||
|
for i, chunk in enumerate(unique_chunks):
|
||||||
|
if i in chunk_ids_map:
|
||||||
|
chunk["chunk_id"] = chunk_ids_map[i]
|
||||||
|
|
||||||
logger.debug(
|
logger.debug(
|
||||||
f"Token truncation: {len(unique_chunks)} chunks from {original_count} "
|
f"Token truncation: {len(unique_chunks)} chunks from {original_count} "
|
||||||
f"(chunk available tokens: {chunk_token_limit}, source: {source_type})"
|
f"(chunk available tokens: {chunk_token_limit}, source: {source_type})"
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue