Feat: add a fault-tolerance mechanism to GraphRAG

This commit is contained in:
yongtenglei 2025-11-12 13:41:53 +08:00
parent a316e6bc1b
commit 1055ec2aa2
3 changed files with 83 additions and 76 deletions

View file

@ -114,7 +114,7 @@ class Extractor:
async def extract_all(doc_id, chunks, max_concurrency=MAX_CONCURRENT_PROCESS_AND_EXTRACT_CHUNK, task_id=""): async def extract_all(doc_id, chunks, max_concurrency=MAX_CONCURRENT_PROCESS_AND_EXTRACT_CHUNK, task_id=""):
out_results = [] out_results = []
error_count = 0 error_count = 0
max_errors = 3 max_errors = int(os.environ.get("GRAPHRAG_MAX_ERRORS", 3))
limiter = trio.Semaphore(max_concurrency) limiter = trio.Semaphore(max_concurrency)

View file

@ -15,27 +15,35 @@
# #
import logging import logging
import re import re
import umap
import numpy as np import numpy as np
from sklearn.mixture import GaussianMixture
import trio import trio
import umap
from sklearn.mixture import GaussianMixture
from api.db.services.task_service import has_canceled from api.db.services.task_service import has_canceled
from common.connection_utils import timeout from common.connection_utils import timeout
from common.exceptions import TaskCanceledException from common.exceptions import TaskCanceledException
from common.token_utils import truncate
from graphrag.utils import ( from graphrag.utils import (
get_llm_cache, chat_limiter,
get_embed_cache, get_embed_cache,
get_llm_cache,
set_embed_cache, set_embed_cache,
set_llm_cache, set_llm_cache,
chat_limiter,
) )
from common.token_utils import truncate
class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval: class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
def __init__( def __init__(
self, max_cluster, llm_model, embd_model, prompt, max_token=512, threshold=0.1 self,
max_cluster,
llm_model,
embd_model,
prompt,
max_token=512,
threshold=0.1,
max_errors=3,
): ):
self._max_cluster = max_cluster self._max_cluster = max_cluster
self._llm_model = llm_model self._llm_model = llm_model
@ -43,31 +51,35 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
self._threshold = threshold self._threshold = threshold
self._prompt = prompt self._prompt = prompt
self._max_token = max_token self._max_token = max_token
self._max_errors = max(1, max_errors)
self._error_count = 0
@timeout(60 * 20)
async def _chat(self, system, history, gen_conf):
    """Return the LLM answer for a prompt, served from cache when possible.

    The lookup order is: LLM cache first, then up to ``max_attempts`` calls
    to ``self._llm_model.chat``. A call counts as failed when it raises or
    when the answer contains the ``**ERROR**`` marker. Failed attempts are
    logged and followed by a linearly growing pause (1s, 2s, ...).

    Args:
        system: System prompt forwarded to the chat model and the cache.
        history: Conversation history forwarded to the chat model and the cache.
        gen_conf: Generation settings forwarded to the chat model and the cache.

    Returns:
        The cached answer, or the fresh answer with everything up to and
        including the last ``</think>`` tag removed.

    Raises:
        Exception: The error from the final attempt once every attempt failed.

    NOTE(review): the ``timeout`` decorator wraps the whole method, so the
    retries and pauses presumably share its 20-minute budget -- confirm
    against ``common.connection_utils.timeout``.
    """
    # Single source of truth for the retry budget (also used in the log line).
    max_attempts = 3
    llm_name = self._llm_model.llm_name

    cached = await trio.to_thread.run_sync(lambda: get_llm_cache(llm_name, system, history, gen_conf))
    if cached:
        return cached

    response = None
    for attempt in range(1, max_attempts + 1):
        try:
            answer = await trio.to_thread.run_sync(lambda: self._llm_model.chat(system, history, gen_conf))
            # Greedy match with DOTALL: drops the reasoning preamble up to the last </think>.
            answer = re.sub(r"^.*</think>", "", answer, flags=re.DOTALL)
            if "**ERROR**" in answer:
                raise Exception(answer)
        except Exception as exc:
            logging.warning("RAPTOR LLM call failed on attempt %d/%d: %s", attempt, max_attempts, exc)
            if attempt == max_attempts:
                raise
            await trio.sleep(attempt)
        else:
            response = answer
            break

    # The cache write is kept outside the retry loop on purpose: a cache
    # failure must not throw away a good answer or trigger another LLM call.
    try:
        await trio.to_thread.run_sync(lambda: set_llm_cache(llm_name, system, response, history, gen_conf))
    except Exception as exc:
        logging.warning("RAPTOR failed to cache LLM response: %s", exc)
    return response
@timeout(20) @timeout(20)
async def _embedding_encode(self, txt): async def _embedding_encode(self, txt):
response = await trio.to_thread.run_sync( response = await trio.to_thread.run_sync(lambda: get_embed_cache(self._embd_model.llm_name, txt))
lambda: get_embed_cache(self._embd_model.llm_name, txt)
)
if response is not None: if response is not None:
return response return response
embds, _ = await trio.to_thread.run_sync(lambda: self._embd_model.encode([txt])) embds, _ = await trio.to_thread.run_sync(lambda: self._embd_model.encode([txt]))
@ -82,7 +94,6 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
n_clusters = np.arange(1, max_clusters) n_clusters = np.arange(1, max_clusters)
bics = [] bics = []
for n in n_clusters: for n in n_clusters:
if task_id: if task_id:
if has_canceled(task_id): if has_canceled(task_id):
logging.info(f"Task {task_id} cancelled during get optimal clusters.") logging.info(f"Task {task_id} cancelled during get optimal clusters.")
@ -101,7 +112,7 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
layers = [(0, len(chunks))] layers = [(0, len(chunks))]
start, end = 0, len(chunks) start, end = 0, len(chunks)
@timeout(60*20) @timeout(60 * 20)
async def summarize(ck_idx: list[int]): async def summarize(ck_idx: list[int]):
nonlocal chunks nonlocal chunks
@ -111,14 +122,10 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
raise TaskCanceledException(f"Task {task_id} was cancelled") raise TaskCanceledException(f"Task {task_id} was cancelled")
texts = [chunks[i][0] for i in ck_idx] texts = [chunks[i][0] for i in ck_idx]
len_per_chunk = int( len_per_chunk = int((self._llm_model.max_length - self._max_token) / len(texts))
(self._llm_model.max_length - self._max_token) / len(texts) cluster_content = "\n".join([truncate(t, max(1, len_per_chunk)) for t in texts])
) try:
cluster_content = "\n".join(
[truncate(t, max(1, len_per_chunk)) for t in texts]
)
async with chat_limiter: async with chat_limiter:
if task_id and has_canceled(task_id): if task_id and has_canceled(task_id):
logging.info(f"Task {task_id} cancelled before RAPTOR LLM call.") logging.info(f"Task {task_id} cancelled before RAPTOR LLM call.")
raise TaskCanceledException(f"Task {task_id} was cancelled") raise TaskCanceledException(f"Task {task_id} was cancelled")
@ -128,9 +135,7 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
[ [
{ {
"role": "user", "role": "user",
"content": self._prompt.format( "content": self._prompt.format(cluster_content=cluster_content),
cluster_content=cluster_content
),
} }
], ],
{"max_tokens": max(self._max_token, 512)}, # fix issue: #10235 {"max_tokens": max(self._max_token, 512)}, # fix issue: #10235
@ -148,10 +153,19 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
embds = await self._embedding_encode(cnt) embds = await self._embedding_encode(cnt)
chunks.append((cnt, embds)) chunks.append((cnt, embds))
except TaskCanceledException:
raise
except Exception as exc:
self._error_count += 1
warn_msg = f"[RAPTOR] Skip cluster ({len(ck_idx)} chunks) due to error: {exc}"
logging.warning(warn_msg)
if callback:
callback(msg=warn_msg)
if self._error_count >= self._max_errors:
raise RuntimeError(f"RAPTOR aborted after {self._error_count} errors. Last error: {exc}") from exc
labels = [] labels = []
while end - start > 1: while end - start > 1:
if task_id: if task_id:
if has_canceled(task_id): if has_canceled(task_id):
logging.info(f"Task {task_id} cancelled during RAPTOR layer processing.") logging.info(f"Task {task_id} cancelled during RAPTOR layer processing.")
@ -161,11 +175,7 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
if len(embeddings) == 2: if len(embeddings) == 2:
await summarize([start, start + 1]) await summarize([start, start + 1])
if callback: if callback:
callback( callback(msg="Cluster one layer: {} -> {}".format(end - start, len(chunks) - end))
msg="Cluster one layer: {} -> {}".format(
end - start, len(chunks) - end
)
)
labels.extend([0, 0]) labels.extend([0, 0])
layers.append((end, len(chunks))) layers.append((end, len(chunks)))
start = end start = end
@ -199,17 +209,11 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
nursery.start_soon(summarize, ck_idx) nursery.start_soon(summarize, ck_idx)
assert len(chunks) - end == n_clusters, "{} vs. {}".format( assert len(chunks) - end == n_clusters, "{} vs. {}".format(len(chunks) - end, n_clusters)
len(chunks) - end, n_clusters
)
labels.extend(lbls) labels.extend(lbls)
layers.append((end, len(chunks))) layers.append((end, len(chunks)))
if callback: if callback:
callback( callback(msg="Cluster one layer: {} -> {}".format(end - start, len(chunks) - end))
msg="Cluster one layer: {} -> {}".format(
end - start, len(chunks) - end
)
)
start = end start = end
end = len(chunks) end = len(chunks)

View file

@ -648,6 +648,8 @@ async def run_raptor_for_kb(row, kb_parser_config, chat_mdl, embd_mdl, vector_si
res = [] res = []
tk_count = 0 tk_count = 0
max_errors = int(os.environ.get("RAPTOR_MAX_ERRORS", 3))
async def generate(chunks, did): async def generate(chunks, did):
nonlocal tk_count, res nonlocal tk_count, res
raptor = Raptor( raptor = Raptor(
@ -657,6 +659,7 @@ async def run_raptor_for_kb(row, kb_parser_config, chat_mdl, embd_mdl, vector_si
raptor_config["prompt"], raptor_config["prompt"],
raptor_config["max_token"], raptor_config["max_token"],
raptor_config["threshold"], raptor_config["threshold"],
max_errors=max_errors,
) )
original_length = len(chunks) original_length = len(chunks)
chunks = await raptor(chunks, kb_parser_config["raptor"]["random_seed"], callback, row["id"]) chunks = await raptor(chunks, kb_parser_config["raptor"]["random_seed"], callback, row["id"])