Feat: add fault-tolerant mechanism to GraphRAG
This commit is contained in:
parent
a316e6bc1b
commit
1055ec2aa2
3 changed files with 83 additions and 76 deletions
|
|
@ -114,7 +114,7 @@ class Extractor:
|
||||||
async def extract_all(doc_id, chunks, max_concurrency=MAX_CONCURRENT_PROCESS_AND_EXTRACT_CHUNK, task_id=""):
|
async def extract_all(doc_id, chunks, max_concurrency=MAX_CONCURRENT_PROCESS_AND_EXTRACT_CHUNK, task_id=""):
|
||||||
out_results = []
|
out_results = []
|
||||||
error_count = 0
|
error_count = 0
|
||||||
max_errors = 3
|
max_errors = int(os.environ.get("GRAPHRAG_MAX_ERRORS", 3))
|
||||||
|
|
||||||
limiter = trio.Semaphore(max_concurrency)
|
limiter = trio.Semaphore(max_concurrency)
|
||||||
|
|
||||||
|
|
|
||||||
154
rag/raptor.py
154
rag/raptor.py
|
|
@ -15,27 +15,35 @@
|
||||||
#
|
#
|
||||||
import logging
|
import logging
|
||||||
import re
|
import re
|
||||||
import umap
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from sklearn.mixture import GaussianMixture
|
|
||||||
import trio
|
import trio
|
||||||
|
import umap
|
||||||
|
from sklearn.mixture import GaussianMixture
|
||||||
|
|
||||||
from api.db.services.task_service import has_canceled
|
from api.db.services.task_service import has_canceled
|
||||||
from common.connection_utils import timeout
|
from common.connection_utils import timeout
|
||||||
from common.exceptions import TaskCanceledException
|
from common.exceptions import TaskCanceledException
|
||||||
|
from common.token_utils import truncate
|
||||||
from graphrag.utils import (
|
from graphrag.utils import (
|
||||||
get_llm_cache,
|
chat_limiter,
|
||||||
get_embed_cache,
|
get_embed_cache,
|
||||||
|
get_llm_cache,
|
||||||
set_embed_cache,
|
set_embed_cache,
|
||||||
set_llm_cache,
|
set_llm_cache,
|
||||||
chat_limiter,
|
|
||||||
)
|
)
|
||||||
from common.token_utils import truncate
|
|
||||||
|
|
||||||
|
|
||||||
class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
|
class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
|
||||||
def __init__(
|
def __init__(
|
||||||
self, max_cluster, llm_model, embd_model, prompt, max_token=512, threshold=0.1
|
self,
|
||||||
|
max_cluster,
|
||||||
|
llm_model,
|
||||||
|
embd_model,
|
||||||
|
prompt,
|
||||||
|
max_token=512,
|
||||||
|
threshold=0.1,
|
||||||
|
max_errors=3,
|
||||||
):
|
):
|
||||||
self._max_cluster = max_cluster
|
self._max_cluster = max_cluster
|
||||||
self._llm_model = llm_model
|
self._llm_model = llm_model
|
||||||
|
|
@ -43,31 +51,35 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
|
||||||
self._threshold = threshold
|
self._threshold = threshold
|
||||||
self._prompt = prompt
|
self._prompt = prompt
|
||||||
self._max_token = max_token
|
self._max_token = max_token
|
||||||
|
self._max_errors = max(1, max_errors)
|
||||||
|
self._error_count = 0
|
||||||
|
|
||||||
@timeout(60*20)
|
@timeout(60 * 20)
|
||||||
async def _chat(self, system, history, gen_conf):
|
async def _chat(self, system, history, gen_conf):
|
||||||
response = await trio.to_thread.run_sync(
|
cached = await trio.to_thread.run_sync(lambda: get_llm_cache(self._llm_model.llm_name, system, history, gen_conf))
|
||||||
lambda: get_llm_cache(self._llm_model.llm_name, system, history, gen_conf)
|
if cached:
|
||||||
)
|
return cached
|
||||||
|
|
||||||
if response:
|
last_exc = None
|
||||||
return response
|
for attempt in range(3):
|
||||||
response = await trio.to_thread.run_sync(
|
try:
|
||||||
lambda: self._llm_model.chat(system, history, gen_conf)
|
response = await trio.to_thread.run_sync(lambda: self._llm_model.chat(system, history, gen_conf))
|
||||||
)
|
response = re.sub(r"^.*</think>", "", response, flags=re.DOTALL)
|
||||||
response = re.sub(r"^.*</think>", "", response, flags=re.DOTALL)
|
if response.find("**ERROR**") >= 0:
|
||||||
if response.find("**ERROR**") >= 0:
|
raise Exception(response)
|
||||||
raise Exception(response)
|
await trio.to_thread.run_sync(lambda: set_llm_cache(self._llm_model.llm_name, system, response, history, gen_conf))
|
||||||
await trio.to_thread.run_sync(
|
return response
|
||||||
lambda: set_llm_cache(self._llm_model.llm_name, system, response, history, gen_conf)
|
except Exception as exc:
|
||||||
)
|
last_exc = exc
|
||||||
return response
|
logging.warning("RAPTOR LLM call failed on attempt %d/3: %s", attempt + 1, exc)
|
||||||
|
if attempt < 2:
|
||||||
|
await trio.sleep(1 + attempt)
|
||||||
|
|
||||||
|
raise last_exc if last_exc else Exception("LLM chat failed without exception")
|
||||||
|
|
||||||
@timeout(20)
|
@timeout(20)
|
||||||
async def _embedding_encode(self, txt):
|
async def _embedding_encode(self, txt):
|
||||||
response = await trio.to_thread.run_sync(
|
response = await trio.to_thread.run_sync(lambda: get_embed_cache(self._embd_model.llm_name, txt))
|
||||||
lambda: get_embed_cache(self._embd_model.llm_name, txt)
|
|
||||||
)
|
|
||||||
if response is not None:
|
if response is not None:
|
||||||
return response
|
return response
|
||||||
embds, _ = await trio.to_thread.run_sync(lambda: self._embd_model.encode([txt]))
|
embds, _ = await trio.to_thread.run_sync(lambda: self._embd_model.encode([txt]))
|
||||||
|
|
@ -82,7 +94,6 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
|
||||||
n_clusters = np.arange(1, max_clusters)
|
n_clusters = np.arange(1, max_clusters)
|
||||||
bics = []
|
bics = []
|
||||||
for n in n_clusters:
|
for n in n_clusters:
|
||||||
|
|
||||||
if task_id:
|
if task_id:
|
||||||
if has_canceled(task_id):
|
if has_canceled(task_id):
|
||||||
logging.info(f"Task {task_id} cancelled during get optimal clusters.")
|
logging.info(f"Task {task_id} cancelled during get optimal clusters.")
|
||||||
|
|
@ -101,7 +112,7 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
|
||||||
layers = [(0, len(chunks))]
|
layers = [(0, len(chunks))]
|
||||||
start, end = 0, len(chunks)
|
start, end = 0, len(chunks)
|
||||||
|
|
||||||
@timeout(60*20)
|
@timeout(60 * 20)
|
||||||
async def summarize(ck_idx: list[int]):
|
async def summarize(ck_idx: list[int]):
|
||||||
nonlocal chunks
|
nonlocal chunks
|
||||||
|
|
||||||
|
|
@ -111,47 +122,50 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
|
||||||
raise TaskCanceledException(f"Task {task_id} was cancelled")
|
raise TaskCanceledException(f"Task {task_id} was cancelled")
|
||||||
|
|
||||||
texts = [chunks[i][0] for i in ck_idx]
|
texts = [chunks[i][0] for i in ck_idx]
|
||||||
len_per_chunk = int(
|
len_per_chunk = int((self._llm_model.max_length - self._max_token) / len(texts))
|
||||||
(self._llm_model.max_length - self._max_token) / len(texts)
|
cluster_content = "\n".join([truncate(t, max(1, len_per_chunk)) for t in texts])
|
||||||
)
|
try:
|
||||||
cluster_content = "\n".join(
|
async with chat_limiter:
|
||||||
[truncate(t, max(1, len_per_chunk)) for t in texts]
|
if task_id and has_canceled(task_id):
|
||||||
)
|
logging.info(f"Task {task_id} cancelled before RAPTOR LLM call.")
|
||||||
async with chat_limiter:
|
raise TaskCanceledException(f"Task {task_id} was cancelled")
|
||||||
|
|
||||||
if task_id and has_canceled(task_id):
|
cnt = await self._chat(
|
||||||
logging.info(f"Task {task_id} cancelled before RAPTOR LLM call.")
|
"You're a helpful assistant.",
|
||||||
raise TaskCanceledException(f"Task {task_id} was cancelled")
|
[
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": self._prompt.format(cluster_content=cluster_content),
|
||||||
|
}
|
||||||
|
],
|
||||||
|
{"max_tokens": max(self._max_token, 512)}, # fix issue: #10235
|
||||||
|
)
|
||||||
|
cnt = re.sub(
|
||||||
|
"(······\n由于长度的原因,回答被截断了,要继续吗?|For the content length reason, it stopped, continue?)",
|
||||||
|
"",
|
||||||
|
cnt,
|
||||||
|
)
|
||||||
|
logging.debug(f"SUM: {cnt}")
|
||||||
|
|
||||||
cnt = await self._chat(
|
if task_id and has_canceled(task_id):
|
||||||
"You're a helpful assistant.",
|
logging.info(f"Task {task_id} cancelled before RAPTOR embedding.")
|
||||||
[
|
raise TaskCanceledException(f"Task {task_id} was cancelled")
|
||||||
{
|
|
||||||
"role": "user",
|
|
||||||
"content": self._prompt.format(
|
|
||||||
cluster_content=cluster_content
|
|
||||||
),
|
|
||||||
}
|
|
||||||
],
|
|
||||||
{"max_tokens": max(self._max_token, 512)}, # fix issue: #10235
|
|
||||||
)
|
|
||||||
cnt = re.sub(
|
|
||||||
"(······\n由于长度的原因,回答被截断了,要继续吗?|For the content length reason, it stopped, continue?)",
|
|
||||||
"",
|
|
||||||
cnt,
|
|
||||||
)
|
|
||||||
logging.debug(f"SUM: {cnt}")
|
|
||||||
|
|
||||||
if task_id and has_canceled(task_id):
|
embds = await self._embedding_encode(cnt)
|
||||||
logging.info(f"Task {task_id} cancelled before RAPTOR embedding.")
|
chunks.append((cnt, embds))
|
||||||
raise TaskCanceledException(f"Task {task_id} was cancelled")
|
except TaskCanceledException:
|
||||||
|
raise
|
||||||
embds = await self._embedding_encode(cnt)
|
except Exception as exc:
|
||||||
chunks.append((cnt, embds))
|
self._error_count += 1
|
||||||
|
warn_msg = f"[RAPTOR] Skip cluster ({len(ck_idx)} chunks) due to error: {exc}"
|
||||||
|
logging.warning(warn_msg)
|
||||||
|
if callback:
|
||||||
|
callback(msg=warn_msg)
|
||||||
|
if self._error_count >= self._max_errors:
|
||||||
|
raise RuntimeError(f"RAPTOR aborted after {self._error_count} errors. Last error: {exc}") from exc
|
||||||
|
|
||||||
labels = []
|
labels = []
|
||||||
while end - start > 1:
|
while end - start > 1:
|
||||||
|
|
||||||
if task_id:
|
if task_id:
|
||||||
if has_canceled(task_id):
|
if has_canceled(task_id):
|
||||||
logging.info(f"Task {task_id} cancelled during RAPTOR layer processing.")
|
logging.info(f"Task {task_id} cancelled during RAPTOR layer processing.")
|
||||||
|
|
@ -161,11 +175,7 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
|
||||||
if len(embeddings) == 2:
|
if len(embeddings) == 2:
|
||||||
await summarize([start, start + 1])
|
await summarize([start, start + 1])
|
||||||
if callback:
|
if callback:
|
||||||
callback(
|
callback(msg="Cluster one layer: {} -> {}".format(end - start, len(chunks) - end))
|
||||||
msg="Cluster one layer: {} -> {}".format(
|
|
||||||
end - start, len(chunks) - end
|
|
||||||
)
|
|
||||||
)
|
|
||||||
labels.extend([0, 0])
|
labels.extend([0, 0])
|
||||||
layers.append((end, len(chunks)))
|
layers.append((end, len(chunks)))
|
||||||
start = end
|
start = end
|
||||||
|
|
@ -199,17 +209,11 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
|
||||||
|
|
||||||
nursery.start_soon(summarize, ck_idx)
|
nursery.start_soon(summarize, ck_idx)
|
||||||
|
|
||||||
assert len(chunks) - end == n_clusters, "{} vs. {}".format(
|
assert len(chunks) - end == n_clusters, "{} vs. {}".format(len(chunks) - end, n_clusters)
|
||||||
len(chunks) - end, n_clusters
|
|
||||||
)
|
|
||||||
labels.extend(lbls)
|
labels.extend(lbls)
|
||||||
layers.append((end, len(chunks)))
|
layers.append((end, len(chunks)))
|
||||||
if callback:
|
if callback:
|
||||||
callback(
|
callback(msg="Cluster one layer: {} -> {}".format(end - start, len(chunks) - end))
|
||||||
msg="Cluster one layer: {} -> {}".format(
|
|
||||||
end - start, len(chunks) - end
|
|
||||||
)
|
|
||||||
)
|
|
||||||
start = end
|
start = end
|
||||||
end = len(chunks)
|
end = len(chunks)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -648,6 +648,8 @@ async def run_raptor_for_kb(row, kb_parser_config, chat_mdl, embd_mdl, vector_si
|
||||||
|
|
||||||
res = []
|
res = []
|
||||||
tk_count = 0
|
tk_count = 0
|
||||||
|
max_errors = int(os.environ.get("RAPTOR_MAX_ERRORS", 3))
|
||||||
|
|
||||||
async def generate(chunks, did):
|
async def generate(chunks, did):
|
||||||
nonlocal tk_count, res
|
nonlocal tk_count, res
|
||||||
raptor = Raptor(
|
raptor = Raptor(
|
||||||
|
|
@ -657,6 +659,7 @@ async def run_raptor_for_kb(row, kb_parser_config, chat_mdl, embd_mdl, vector_si
|
||||||
raptor_config["prompt"],
|
raptor_config["prompt"],
|
||||||
raptor_config["max_token"],
|
raptor_config["max_token"],
|
||||||
raptor_config["threshold"],
|
raptor_config["threshold"],
|
||||||
|
max_errors=max_errors,
|
||||||
)
|
)
|
||||||
original_length = len(chunks)
|
original_length = len(chunks)
|
||||||
chunks = await raptor(chunks, kb_parser_config["raptor"]["random_seed"], callback, row["id"])
|
chunks = await raptor(chunks, kb_parser_config["raptor"]["random_seed"], callback, row["id"])
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue