Merge branch 'main' of https://github.com/infiniflow/ragflow into refa/async_migration_2

This commit is contained in:
yongtenglei 2025-12-03 13:29:16 +08:00
commit 775c897d80
8 changed files with 214 additions and 34 deletions

View file

@ -16,6 +16,7 @@
import asyncio
import base64
import inspect
import binascii
import json
import logging
import re
@ -28,7 +29,9 @@ from typing import Any, Union, Tuple
from agent.component import component_class
from agent.component.base import ComponentBase
from api.db.services.file_service import FileService
from api.db.services.llm_service import LLMBundle
from api.db.services.task_service import has_canceled
from common.constants import LLMType
from common.misc_utils import get_uuid, hash_str2int
from common.exceptions import TaskCanceledException
from rag.prompts.generator import chunks_format
@ -356,8 +359,6 @@ class Canvas(Graph):
self.globals[k] = ""
else:
self.globals[k] = ""
print(self.globals)
async def run(self, **kwargs):
st = time.perf_counter()
@ -467,6 +468,7 @@ class Canvas(Graph):
self.error = ""
idx = len(self.path) - 1
partials = []
tts_mdl = None
while idx < len(self.path):
to = len(self.path)
for i in range(idx, to):
@ -484,31 +486,51 @@ class Canvas(Graph):
cpn = self.get_component(self.path[i])
cpn_obj = self.get_component_obj(self.path[i])
if cpn_obj.component_name.lower() == "message":
if cpn_obj.get_param("auto_play"):
tts_mdl = LLMBundle(self._tenant_id, LLMType.TTS)
if isinstance(cpn_obj.output("content"), partial):
_m = ""
buff_m = ""
stream = cpn_obj.output("content")()
async def _process_stream(m):
nonlocal buff_m, _m, tts_mdl
if not m:
return
if m == "<think>":
return decorate("message", {"content": "", "start_to_think": True})
elif m == "</think>":
return decorate("message", {"content": "", "end_to_think": True})
buff_m += m
_m += m
if len(buff_m) > 16:
ev = decorate(
"message",
{
"content": m,
"audio_binary": self.tts(tts_mdl, buff_m)
}
)
buff_m = ""
return ev
return decorate("message", {"content": m})
if inspect.isasyncgen(stream):
async for m in stream:
if not m:
continue
if m == "<think>":
yield decorate("message", {"content": "", "start_to_think": True})
elif m == "</think>":
yield decorate("message", {"content": "", "end_to_think": True})
else:
yield decorate("message", {"content": m})
_m += m
ev= await _process_stream(m)
if ev:
yield ev
else:
for m in stream:
if not m:
continue
if m == "<think>":
yield decorate("message", {"content": "", "start_to_think": True})
elif m == "</think>":
yield decorate("message", {"content": "", "end_to_think": True})
else:
yield decorate("message", {"content": m})
_m += m
ev= await _process_stream(m)
if ev:
yield ev
if buff_m:
yield decorate("message", {"content": "", "audio_binary": self.tts(tts_mdl, buff_m)})
buff_m = ""
cpn_obj.set_output("content", _m)
cite = re.search(r"\[ID:[ 0-9]+\]", _m)
else:
@ -629,6 +651,50 @@ class Canvas(Graph):
return False
return True
def tts(self,tts_mdl, text):
def clean_tts_text(text: str) -> str:
if not text:
return ""
text = text.encode("utf-8", "ignore").decode("utf-8", "ignore")
text = re.sub(r"[\x00-\x08\x0B-\x0C\x0E-\x1F\x7F]", "", text)
emoji_pattern = re.compile(
"[\U0001F600-\U0001F64F"
"\U0001F300-\U0001F5FF"
"\U0001F680-\U0001F6FF"
"\U0001F1E0-\U0001F1FF"
"\U00002700-\U000027BF"
"\U0001F900-\U0001F9FF"
"\U0001FA70-\U0001FAFF"
"\U0001FAD0-\U0001FAFF]+",
flags=re.UNICODE
)
text = emoji_pattern.sub("", text)
text = re.sub(r"\s+", " ", text).strip()
MAX_LEN = 500
if len(text) > MAX_LEN:
text = text[:MAX_LEN]
return text
if not tts_mdl or not text:
return None
text = clean_tts_text(text)
if not text:
return None
bin = b""
try:
for chunk in tts_mdl.tts(text):
bin += chunk
except Exception as e:
logging.error(f"TTS failed: {e}, text={text!r}")
return None
return binascii.hexlify(bin).decode("utf-8")
def get_history(self, window_size):
convs = []
if window_size <= 0:

View file

@ -215,7 +215,6 @@ class LLM(ComponentBase):
yield delta(txt)
async def _generate_streamly_async(self, msg: list[dict], **kwargs) -> AsyncGenerator[str, None]:
# Prefer async chat_streamly if available
async def delta_wrapper(txt_iter):
ans = ""
last_idx = 0
@ -256,7 +255,7 @@ class LLM(ComponentBase):
yield t
return
# Fallback: run sync stream in thread, bridge results
# fallback
loop = asyncio.get_running_loop()
queue: asyncio.Queue = asyncio.Queue()

View file

@ -198,6 +198,7 @@ class Retrieval(ToolBase, ABC):
return
if cks:
kbinfos["chunks"] = cks
kbinfos["chunks"] = settings.retriever.retrieval_by_children(kbinfos["chunks"], [kb.tenant_id for kb in kbs])
if self._param.use_kg:
ck = settings.kg_retriever.retrieval(query,
[kb.tenant_id for kb in kbs],

View file

@ -761,13 +761,48 @@ Please write the SQL, only SQL, without any other explanations or text.
"prompt": sys_prompt,
}
def clean_tts_text(text: str) -> str:
if not text:
return ""
text = text.encode("utf-8", "ignore").decode("utf-8", "ignore")
text = re.sub(r"[\x00-\x08\x0B-\x0C\x0E-\x1F\x7F]", "", text)
emoji_pattern = re.compile(
"[\U0001F600-\U0001F64F"
"\U0001F300-\U0001F5FF"
"\U0001F680-\U0001F6FF"
"\U0001F1E0-\U0001F1FF"
"\U00002700-\U000027BF"
"\U0001F900-\U0001F9FF"
"\U0001FA70-\U0001FAFF"
"\U0001FAD0-\U0001FAFF]+",
flags=re.UNICODE
)
text = emoji_pattern.sub("", text)
text = re.sub(r"\s+", " ", text).strip()
MAX_LEN = 500
if len(text) > MAX_LEN:
text = text[:MAX_LEN]
return text
def tts(tts_mdl, text):
if not tts_mdl or not text:
return None
text = clean_tts_text(text)
if not text:
return None
bin = b""
for chunk in tts_mdl.tts(text):
bin += chunk
try:
for chunk in tts_mdl.tts(text):
bin += chunk
except Exception as e:
logging.error(f"TTS failed: {e}, text={text!r}")
return None
return binascii.hexlify(bin).decode("utf-8")

View file

@ -173,8 +173,8 @@ def install_mineru() -> None:
Logging is used to indicate status.
"""
# Check if MinerU is enabled
use_mineru = os.getenv("USE_MINERU", "").strip().lower()
if use_mineru == "false":
use_mineru = os.getenv("USE_MINERU", "false").strip().lower()
if use_mineru != "true":
logging.info("USE_MINERU=%r. Skipping MinerU installation.", use_mineru)
return

View file

@ -12,10 +12,17 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import logging
import random
from copy import deepcopy
from copy import deepcopy, copy
import trio
import xxhash
from agent.component.llm import LLMParam, LLM
from rag.flow.base import ProcessBase, ProcessParamBase
from rag.prompts.generator import run_toc_from_text
class ExtractorParam(ProcessParamBase, LLMParam):
@ -31,6 +38,38 @@ class ExtractorParam(ProcessParamBase, LLMParam):
class Extractor(ProcessBase, LLM):
component_name = "Extractor"
def _build_TOC(self, docs):
self.callback(message="Start to generate table of content ...")
docs = sorted(docs, key=lambda d:(
d.get("page_num_int", 0)[0] if isinstance(d.get("page_num_int", 0), list) else d.get("page_num_int", 0),
d.get("top_int", 0)[0] if isinstance(d.get("top_int", 0), list) else d.get("top_int", 0)
))
toc: list[dict] = trio.run(run_toc_from_text, [d["text"] for d in docs], self.chat_mdl)
logging.info("------------ T O C -------------\n"+json.dumps(toc, ensure_ascii=False, indent=' '))
ii = 0
while ii < len(toc):
try:
idx = int(toc[ii]["chunk_id"])
del toc[ii]["chunk_id"]
toc[ii]["ids"] = [docs[idx]["id"]]
if ii == len(toc) -1:
break
for jj in range(idx+1, int(toc[ii+1]["chunk_id"])+1):
toc[ii]["ids"].append(docs[jj]["id"])
except Exception as e:
logging.exception(e)
ii += 1
if toc:
d = copy.deepcopy(docs[-1])
d["content_with_weight"] = json.dumps(toc, ensure_ascii=False)
d["toc_kwd"] = "toc"
d["available_int"] = 0
d["page_num_int"] = [100000000]
d["id"] = xxhash.xxh64((d["content_with_weight"] + str(d["doc_id"])).encode("utf-8", "surrogatepass")).hexdigest()
return d
return None
async def _invoke(self, **kwargs):
self.set_output("output_format", "chunks")
self.callback(random.randint(1, 5) / 100.0, "Start to generate.")
@ -45,6 +84,12 @@ class Extractor(ProcessBase, LLM):
chunks_key = k
if chunks:
if self._param.field_name == "toc":
toc = self._build_TOC(chunks)
chunks.append(toc)
self.set_output("chunks", chunks)
return
prog = 0
for i, ck in enumerate(chunks):
args[chunks_key] = ck["text"]

View file

@ -944,7 +944,7 @@ async def do_handle_task(task):
logging.info(progress_message)
progress_callback(msg=progress_message)
if task["parser_id"].lower() == "naive" and task["parser_config"].get("toc_extraction", False):
toc_thread = executor.submit(build_TOC,task, chunks, progress_callback)
toc_thread = executor.submit(build_TOC, task, chunks, progress_callback)
chunk_count = len(set([chunk["id"] for chunk in chunks]))
start_ts = timer()

View file

@ -1,7 +1,5 @@
import message from '@/components/ui/message';
import { Spin } from '@/components/ui/spin';
import { Authorization } from '@/constants/authorization';
import { getAuthorization } from '@/utils/authorization-util';
import request from '@/utils/request';
import classNames from 'classnames';
import mammoth from 'mammoth';
@ -16,22 +14,55 @@ export const DocPreviewer: React.FC<DocPreviewerProps> = ({
className,
url,
}) => {
// const url = useGetDocumentUrl();
const [htmlContent, setHtmlContent] = useState<string>('');
const [loading, setLoading] = useState(false);
const fetchDocument = async () => {
if (!url) return;
setLoading(true);
const res = await request(url, {
method: 'GET',
responseType: 'blob',
headers: { [Authorization]: getAuthorization() },
onError: () => {
message.error('Document parsing failed');
console.error('Error loading document:', url);
},
});
try {
const arrayBuffer = await res.data.arrayBuffer();
const blob: Blob = res.data;
const contentType: string =
blob.type || (res as any).headers?.['content-type'] || '';
// ---- Detect legacy .doc via MIME or URL ----
const cleanUrl = url.split(/[?#]/)[0].toLowerCase();
const isDocMime = /application\/msword/i.test(contentType);
const isLegacyDocByUrl =
cleanUrl.endsWith('.doc') && !cleanUrl.endsWith('.docx');
const isLegacyDoc = isDocMime || isLegacyDocByUrl;
if (isLegacyDoc) {
// Do not call mammoth and do not throw an error; instead, show a note in the preview area
setHtmlContent(`
<div class="flex h-full items-center justify-center">
<div class="border border-dashed border-border-normal rounded-xl p-8 max-w-2xl text-center">
<p class="text-2xl font-bold mb-4">
Preview not available for .doc files
</p>
<p class="italic text-sm text-muted-foreground leading-relaxed">
Mammoth does not support <code>.doc</code> documents.<br/>
Inline preview is unavailable.
</p>
</div>
</div>
`);
return;
}
// ---- Standard .docx preview path ----
const arrayBuffer = await blob.arrayBuffer();
const result = await mammoth.convertToHtml(
{ arrayBuffer },
{ includeDefaultStyleMap: true },
@ -43,10 +74,12 @@ export const DocPreviewer: React.FC<DocPreviewerProps> = ({
setHtmlContent(styledContent);
} catch (err) {
// Only errors from the mammoth conversion path should surface here
message.error('Document parsing failed');
console.error('Error parsing document:', err);
} finally {
setLoading(false);
}
setLoading(false);
};
useEffect(() => {
@ -54,6 +87,7 @@ export const DocPreviewer: React.FC<DocPreviewerProps> = ({
fetchDocument();
}
}, [url]);
return (
<div
className={classNames(