Feat: add fault-tolerant mechanism to GraphRAG and RAPTOR

This commit is contained in:
yongtenglei 2025-11-12 13:41:53 +08:00
parent a316e6bc1b
commit 1055ec2aa2
3 changed files with 83 additions and 76 deletions

View file

@ -114,7 +114,7 @@ class Extractor:
async def extract_all(doc_id, chunks, max_concurrency=MAX_CONCURRENT_PROCESS_AND_EXTRACT_CHUNK, task_id=""): async def extract_all(doc_id, chunks, max_concurrency=MAX_CONCURRENT_PROCESS_AND_EXTRACT_CHUNK, task_id=""):
out_results = [] out_results = []
error_count = 0 error_count = 0
max_errors = 3 max_errors = int(os.environ.get("GRAPHRAG_MAX_ERRORS", 3))
limiter = trio.Semaphore(max_concurrency) limiter = trio.Semaphore(max_concurrency)

View file

@ -15,27 +15,35 @@
# #
import logging import logging
import re import re
import umap
import numpy as np import numpy as np
from sklearn.mixture import GaussianMixture
import trio import trio
import umap
from sklearn.mixture import GaussianMixture
from api.db.services.task_service import has_canceled from api.db.services.task_service import has_canceled
from common.connection_utils import timeout from common.connection_utils import timeout
from common.exceptions import TaskCanceledException from common.exceptions import TaskCanceledException
from common.token_utils import truncate
from graphrag.utils import ( from graphrag.utils import (
get_llm_cache, chat_limiter,
get_embed_cache, get_embed_cache,
get_llm_cache,
set_embed_cache, set_embed_cache,
set_llm_cache, set_llm_cache,
chat_limiter,
) )
from common.token_utils import truncate
class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval: class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
def __init__( def __init__(
self, max_cluster, llm_model, embd_model, prompt, max_token=512, threshold=0.1 self,
max_cluster,
llm_model,
embd_model,
prompt,
max_token=512,
threshold=0.1,
max_errors=3,
): ):
self._max_cluster = max_cluster self._max_cluster = max_cluster
self._llm_model = llm_model self._llm_model = llm_model
@ -43,31 +51,35 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
self._threshold = threshold self._threshold = threshold
self._prompt = prompt self._prompt = prompt
self._max_token = max_token self._max_token = max_token
self._max_errors = max(1, max_errors)
self._error_count = 0
@timeout(60*20) @timeout(60 * 20)
async def _chat(self, system, history, gen_conf): async def _chat(self, system, history, gen_conf):
response = await trio.to_thread.run_sync( cached = await trio.to_thread.run_sync(lambda: get_llm_cache(self._llm_model.llm_name, system, history, gen_conf))
lambda: get_llm_cache(self._llm_model.llm_name, system, history, gen_conf) if cached:
) return cached
if response: last_exc = None
return response for attempt in range(3):
response = await trio.to_thread.run_sync( try:
lambda: self._llm_model.chat(system, history, gen_conf) response = await trio.to_thread.run_sync(lambda: self._llm_model.chat(system, history, gen_conf))
) response = re.sub(r"^.*</think>", "", response, flags=re.DOTALL)
response = re.sub(r"^.*</think>", "", response, flags=re.DOTALL) if response.find("**ERROR**") >= 0:
if response.find("**ERROR**") >= 0: raise Exception(response)
raise Exception(response) await trio.to_thread.run_sync(lambda: set_llm_cache(self._llm_model.llm_name, system, response, history, gen_conf))
await trio.to_thread.run_sync( return response
lambda: set_llm_cache(self._llm_model.llm_name, system, response, history, gen_conf) except Exception as exc:
) last_exc = exc
return response logging.warning("RAPTOR LLM call failed on attempt %d/3: %s", attempt + 1, exc)
if attempt < 2:
await trio.sleep(1 + attempt)
raise last_exc if last_exc else Exception("LLM chat failed without exception")
@timeout(20) @timeout(20)
async def _embedding_encode(self, txt): async def _embedding_encode(self, txt):
response = await trio.to_thread.run_sync( response = await trio.to_thread.run_sync(lambda: get_embed_cache(self._embd_model.llm_name, txt))
lambda: get_embed_cache(self._embd_model.llm_name, txt)
)
if response is not None: if response is not None:
return response return response
embds, _ = await trio.to_thread.run_sync(lambda: self._embd_model.encode([txt])) embds, _ = await trio.to_thread.run_sync(lambda: self._embd_model.encode([txt]))
@ -82,7 +94,6 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
n_clusters = np.arange(1, max_clusters) n_clusters = np.arange(1, max_clusters)
bics = [] bics = []
for n in n_clusters: for n in n_clusters:
if task_id: if task_id:
if has_canceled(task_id): if has_canceled(task_id):
logging.info(f"Task {task_id} cancelled during get optimal clusters.") logging.info(f"Task {task_id} cancelled during get optimal clusters.")
@ -101,7 +112,7 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
layers = [(0, len(chunks))] layers = [(0, len(chunks))]
start, end = 0, len(chunks) start, end = 0, len(chunks)
@timeout(60*20) @timeout(60 * 20)
async def summarize(ck_idx: list[int]): async def summarize(ck_idx: list[int]):
nonlocal chunks nonlocal chunks
@ -111,47 +122,50 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
raise TaskCanceledException(f"Task {task_id} was cancelled") raise TaskCanceledException(f"Task {task_id} was cancelled")
texts = [chunks[i][0] for i in ck_idx] texts = [chunks[i][0] for i in ck_idx]
len_per_chunk = int( len_per_chunk = int((self._llm_model.max_length - self._max_token) / len(texts))
(self._llm_model.max_length - self._max_token) / len(texts) cluster_content = "\n".join([truncate(t, max(1, len_per_chunk)) for t in texts])
) try:
cluster_content = "\n".join( async with chat_limiter:
[truncate(t, max(1, len_per_chunk)) for t in texts] if task_id and has_canceled(task_id):
) logging.info(f"Task {task_id} cancelled before RAPTOR LLM call.")
async with chat_limiter: raise TaskCanceledException(f"Task {task_id} was cancelled")
if task_id and has_canceled(task_id): cnt = await self._chat(
logging.info(f"Task {task_id} cancelled before RAPTOR LLM call.") "You're a helpful assistant.",
raise TaskCanceledException(f"Task {task_id} was cancelled") [
{
"role": "user",
"content": self._prompt.format(cluster_content=cluster_content),
}
],
{"max_tokens": max(self._max_token, 512)}, # fix issue: #10235
)
cnt = re.sub(
"(······\n由于长度的原因,回答被截断了,要继续吗?|For the content length reason, it stopped, continue?)",
"",
cnt,
)
logging.debug(f"SUM: {cnt}")
cnt = await self._chat( if task_id and has_canceled(task_id):
"You're a helpful assistant.", logging.info(f"Task {task_id} cancelled before RAPTOR embedding.")
[ raise TaskCanceledException(f"Task {task_id} was cancelled")
{
"role": "user",
"content": self._prompt.format(
cluster_content=cluster_content
),
}
],
{"max_tokens": max(self._max_token, 512)}, # fix issue: #10235
)
cnt = re.sub(
"(······\n由于长度的原因,回答被截断了,要继续吗?|For the content length reason, it stopped, continue?)",
"",
cnt,
)
logging.debug(f"SUM: {cnt}")
if task_id and has_canceled(task_id): embds = await self._embedding_encode(cnt)
logging.info(f"Task {task_id} cancelled before RAPTOR embedding.") chunks.append((cnt, embds))
raise TaskCanceledException(f"Task {task_id} was cancelled") except TaskCanceledException:
raise
embds = await self._embedding_encode(cnt) except Exception as exc:
chunks.append((cnt, embds)) self._error_count += 1
warn_msg = f"[RAPTOR] Skip cluster ({len(ck_idx)} chunks) due to error: {exc}"
logging.warning(warn_msg)
if callback:
callback(msg=warn_msg)
if self._error_count >= self._max_errors:
raise RuntimeError(f"RAPTOR aborted after {self._error_count} errors. Last error: {exc}") from exc
labels = [] labels = []
while end - start > 1: while end - start > 1:
if task_id: if task_id:
if has_canceled(task_id): if has_canceled(task_id):
logging.info(f"Task {task_id} cancelled during RAPTOR layer processing.") logging.info(f"Task {task_id} cancelled during RAPTOR layer processing.")
@ -161,11 +175,7 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
if len(embeddings) == 2: if len(embeddings) == 2:
await summarize([start, start + 1]) await summarize([start, start + 1])
if callback: if callback:
callback( callback(msg="Cluster one layer: {} -> {}".format(end - start, len(chunks) - end))
msg="Cluster one layer: {} -> {}".format(
end - start, len(chunks) - end
)
)
labels.extend([0, 0]) labels.extend([0, 0])
layers.append((end, len(chunks))) layers.append((end, len(chunks)))
start = end start = end
@ -199,17 +209,11 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
nursery.start_soon(summarize, ck_idx) nursery.start_soon(summarize, ck_idx)
assert len(chunks) - end == n_clusters, "{} vs. {}".format( assert len(chunks) - end == n_clusters, "{} vs. {}".format(len(chunks) - end, n_clusters)
len(chunks) - end, n_clusters
)
labels.extend(lbls) labels.extend(lbls)
layers.append((end, len(chunks))) layers.append((end, len(chunks)))
if callback: if callback:
callback( callback(msg="Cluster one layer: {} -> {}".format(end - start, len(chunks) - end))
msg="Cluster one layer: {} -> {}".format(
end - start, len(chunks) - end
)
)
start = end start = end
end = len(chunks) end = len(chunks)

View file

@ -648,6 +648,8 @@ async def run_raptor_for_kb(row, kb_parser_config, chat_mdl, embd_mdl, vector_si
res = [] res = []
tk_count = 0 tk_count = 0
max_errors = int(os.environ.get("RAPTOR_MAX_ERRORS", 3))
async def generate(chunks, did): async def generate(chunks, did):
nonlocal tk_count, res nonlocal tk_count, res
raptor = Raptor( raptor = Raptor(
@ -657,6 +659,7 @@ async def run_raptor_for_kb(row, kb_parser_config, chat_mdl, embd_mdl, vector_si
raptor_config["prompt"], raptor_config["prompt"],
raptor_config["max_token"], raptor_config["max_token"],
raptor_config["threshold"], raptor_config["threshold"],
max_errors=max_errors,
) )
original_length = len(chunks) original_length = len(chunks)
chunks = await raptor(chunks, kb_parser_config["raptor"]["random_seed"], callback, row["id"]) chunks = await raptor(chunks, kb_parser_config["raptor"]["random_seed"], callback, row["id"])