From 1055ec2aa2a532f9f0070f7a8b2a8e48ce75b6f2 Mon Sep 17 00:00:00 2001 From: yongtenglei Date: Wed, 12 Nov 2025 13:41:53 +0800 Subject: [PATCH] Feat: add fault-tolerant mechanism to GraphRAG --- graphrag/general/extractor.py | 2 +- rag/raptor.py | 154 +++++++++++++++++----------------- rag/svr/task_executor.py | 3 + 3 files changed, 83 insertions(+), 76 deletions(-) diff --git a/graphrag/general/extractor.py b/graphrag/general/extractor.py index 1df38ed1c..495e562ed 100644 --- a/graphrag/general/extractor.py +++ b/graphrag/general/extractor.py @@ -114,7 +114,7 @@ class Extractor: async def extract_all(doc_id, chunks, max_concurrency=MAX_CONCURRENT_PROCESS_AND_EXTRACT_CHUNK, task_id=""): out_results = [] error_count = 0 - max_errors = 3 + max_errors = int(os.environ.get("GRAPHRAG_MAX_ERRORS", 3)) limiter = trio.Semaphore(max_concurrency) diff --git a/rag/raptor.py b/rag/raptor.py index e6efe3504..a455d0127 100644 --- a/rag/raptor.py +++ b/rag/raptor.py @@ -15,27 +15,35 @@ # import logging import re -import umap + import numpy as np -from sklearn.mixture import GaussianMixture import trio +import umap +from sklearn.mixture import GaussianMixture from api.db.services.task_service import has_canceled from common.connection_utils import timeout from common.exceptions import TaskCanceledException +from common.token_utils import truncate from graphrag.utils import ( - get_llm_cache, + chat_limiter, get_embed_cache, + get_llm_cache, set_embed_cache, set_llm_cache, - chat_limiter, ) -from common.token_utils import truncate class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval: def __init__( - self, max_cluster, llm_model, embd_model, prompt, max_token=512, threshold=0.1 + self, + max_cluster, + llm_model, + embd_model, + prompt, + max_token=512, + threshold=0.1, + max_errors=3, ): self._max_cluster = max_cluster self._llm_model = llm_model @@ -43,31 +51,35 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval: self._threshold = threshold self._prompt = prompt 
self._max_token = max_token + self._max_errors = max(1, max_errors) + self._error_count = 0 - @timeout(60*20) + @timeout(60 * 20) async def _chat(self, system, history, gen_conf): - response = await trio.to_thread.run_sync( - lambda: get_llm_cache(self._llm_model.llm_name, system, history, gen_conf) - ) + cached = await trio.to_thread.run_sync(lambda: get_llm_cache(self._llm_model.llm_name, system, history, gen_conf)) + if cached: + return cached - if response: - return response - response = await trio.to_thread.run_sync( - lambda: self._llm_model.chat(system, history, gen_conf) - ) - response = re.sub(r"^.*</think>", "", response, flags=re.DOTALL) - if response.find("**ERROR**") >= 0: - raise Exception(response) - await trio.to_thread.run_sync( - lambda: set_llm_cache(self._llm_model.llm_name, system, response, history, gen_conf) - ) - return response + last_exc = None + for attempt in range(3): + try: + response = await trio.to_thread.run_sync(lambda: self._llm_model.chat(system, history, gen_conf)) + response = re.sub(r"^.*</think>", "", response, flags=re.DOTALL) + if response.find("**ERROR**") >= 0: + raise Exception(response) + await trio.to_thread.run_sync(lambda: set_llm_cache(self._llm_model.llm_name, system, response, history, gen_conf)) + return response + except Exception as exc: + last_exc = exc + logging.warning("RAPTOR LLM call failed on attempt %d/3: %s", attempt + 1, exc) + if attempt < 2: + await trio.sleep(1 + attempt) + + raise last_exc if last_exc else Exception("LLM chat failed without exception") @timeout(20) async def _embedding_encode(self, txt): - response = await trio.to_thread.run_sync( - lambda: get_embed_cache(self._embd_model.llm_name, txt) - ) + response = await trio.to_thread.run_sync(lambda: get_embed_cache(self._embd_model.llm_name, txt)) if response is not None: return response embds, _ = await trio.to_thread.run_sync(lambda: self._embd_model.encode([txt])) @@ -82,7 +94,6 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval: n_clusters 
= np.arange(1, max_clusters) bics = [] for n in n_clusters: - if task_id: if has_canceled(task_id): logging.info(f"Task {task_id} cancelled during get optimal clusters.") @@ -101,7 +112,7 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval: layers = [(0, len(chunks))] start, end = 0, len(chunks) - @timeout(60*20) + @timeout(60 * 20) async def summarize(ck_idx: list[int]): nonlocal chunks @@ -111,47 +122,50 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval: raise TaskCanceledException(f"Task {task_id} was cancelled") texts = [chunks[i][0] for i in ck_idx] - len_per_chunk = int( - (self._llm_model.max_length - self._max_token) / len(texts) - ) - cluster_content = "\n".join( - [truncate(t, max(1, len_per_chunk)) for t in texts] - ) - async with chat_limiter: + len_per_chunk = int((self._llm_model.max_length - self._max_token) / len(texts)) + cluster_content = "\n".join([truncate(t, max(1, len_per_chunk)) for t in texts]) + try: + async with chat_limiter: + if task_id and has_canceled(task_id): + logging.info(f"Task {task_id} cancelled before RAPTOR LLM call.") + raise TaskCanceledException(f"Task {task_id} was cancelled") - if task_id and has_canceled(task_id): - logging.info(f"Task {task_id} cancelled before RAPTOR LLM call.") - raise TaskCanceledException(f"Task {task_id} was cancelled") + cnt = await self._chat( + "You're a helpful assistant.", + [ + { + "role": "user", + "content": self._prompt.format(cluster_content=cluster_content), + } + ], + {"max_tokens": max(self._max_token, 512)}, # fix issue: #10235 + ) + cnt = re.sub( + "(······\n由于长度的原因,回答被截断了,要继续吗?|For the content length reason, it stopped, continue?)", + "", + cnt, + ) + logging.debug(f"SUM: {cnt}") - cnt = await self._chat( - "You're a helpful assistant.", - [ - { - "role": "user", - "content": self._prompt.format( - cluster_content=cluster_content - ), - } - ], - {"max_tokens": max(self._max_token, 512)}, # fix issue: #10235 - ) - cnt = re.sub( - 
"(······\n由于长度的原因,回答被截断了,要继续吗?|For the content length reason, it stopped, continue?)", - "", - cnt, - ) - logging.debug(f"SUM: {cnt}") + if task_id and has_canceled(task_id): + logging.info(f"Task {task_id} cancelled before RAPTOR embedding.") + raise TaskCanceledException(f"Task {task_id} was cancelled") - if task_id and has_canceled(task_id): - logging.info(f"Task {task_id} cancelled before RAPTOR embedding.") - raise TaskCanceledException(f"Task {task_id} was cancelled") - - embds = await self._embedding_encode(cnt) - chunks.append((cnt, embds)) + embds = await self._embedding_encode(cnt) + chunks.append((cnt, embds)) + except TaskCanceledException: + raise + except Exception as exc: + self._error_count += 1 + warn_msg = f"[RAPTOR] Skip cluster ({len(ck_idx)} chunks) due to error: {exc}" + logging.warning(warn_msg) + if callback: + callback(msg=warn_msg) + if self._error_count >= self._max_errors: + raise RuntimeError(f"RAPTOR aborted after {self._error_count} errors. Last error: {exc}") from exc labels = [] while end - start > 1: - if task_id: if has_canceled(task_id): logging.info(f"Task {task_id} cancelled during RAPTOR layer processing.") @@ -161,11 +175,7 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval: if len(embeddings) == 2: await summarize([start, start + 1]) if callback: - callback( - msg="Cluster one layer: {} -> {}".format( - end - start, len(chunks) - end - ) - ) + callback(msg="Cluster one layer: {} -> {}".format(end - start, len(chunks) - end)) labels.extend([0, 0]) layers.append((end, len(chunks))) start = end @@ -199,17 +209,11 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval: nursery.start_soon(summarize, ck_idx) - assert len(chunks) - end == n_clusters, "{} vs. {}".format( - len(chunks) - end, n_clusters - ) + assert len(chunks) - end == n_clusters, "{} vs. 
{}".format(len(chunks) - end, n_clusters) labels.extend(lbls) layers.append((end, len(chunks))) if callback: - callback( - msg="Cluster one layer: {} -> {}".format( - end - start, len(chunks) - end - ) - ) + callback(msg="Cluster one layer: {} -> {}".format(end - start, len(chunks) - end)) start = end end = len(chunks) diff --git a/rag/svr/task_executor.py b/rag/svr/task_executor.py index af8dfc186..7befc6110 100644 --- a/rag/svr/task_executor.py +++ b/rag/svr/task_executor.py @@ -648,6 +648,8 @@ async def run_raptor_for_kb(row, kb_parser_config, chat_mdl, embd_mdl, vector_si res = [] tk_count = 0 + max_errors = int(os.environ.get("RAPTOR_MAX_ERRORS", 3)) + async def generate(chunks, did): nonlocal tk_count, res raptor = Raptor( @@ -657,6 +659,7 @@ async def run_raptor_for_kb(row, kb_parser_config, chat_mdl, embd_mdl, vector_si raptor_config["prompt"], raptor_config["max_token"], raptor_config["threshold"], + max_errors=max_errors, ) original_length = len(chunks) chunks = await raptor(chunks, kb_parser_config["raptor"]["random_seed"], callback, row["id"])