Merge branch 'main' of github.com:infiniflow/ragflow into structured-output

This commit is contained in:
bill 2025-12-03 14:01:35 +08:00
commit f8f3d96b5d
7 changed files with 213 additions and 32 deletions

View file

@ -16,6 +16,7 @@
import asyncio import asyncio
import base64 import base64
import inspect import inspect
import binascii
import json import json
import logging import logging
import re import re
@ -28,7 +29,9 @@ from typing import Any, Union, Tuple
from agent.component import component_class from agent.component import component_class
from agent.component.base import ComponentBase from agent.component.base import ComponentBase
from api.db.services.file_service import FileService from api.db.services.file_service import FileService
from api.db.services.llm_service import LLMBundle
from api.db.services.task_service import has_canceled from api.db.services.task_service import has_canceled
from common.constants import LLMType
from common.misc_utils import get_uuid, hash_str2int from common.misc_utils import get_uuid, hash_str2int
from common.exceptions import TaskCanceledException from common.exceptions import TaskCanceledException
from rag.prompts.generator import chunks_format from rag.prompts.generator import chunks_format
@ -356,8 +359,6 @@ class Canvas(Graph):
self.globals[k] = "" self.globals[k] = ""
else: else:
self.globals[k] = "" self.globals[k] = ""
print(self.globals)
async def run(self, **kwargs): async def run(self, **kwargs):
st = time.perf_counter() st = time.perf_counter()
@ -456,6 +457,7 @@ class Canvas(Graph):
self.error = "" self.error = ""
idx = len(self.path) - 1 idx = len(self.path) - 1
partials = [] partials = []
tts_mdl = None
while idx < len(self.path): while idx < len(self.path):
to = len(self.path) to = len(self.path)
for i in range(idx, to): for i in range(idx, to):
@ -473,31 +475,51 @@ class Canvas(Graph):
cpn = self.get_component(self.path[i]) cpn = self.get_component(self.path[i])
cpn_obj = self.get_component_obj(self.path[i]) cpn_obj = self.get_component_obj(self.path[i])
if cpn_obj.component_name.lower() == "message": if cpn_obj.component_name.lower() == "message":
if cpn_obj.get_param("auto_play"):
tts_mdl = LLMBundle(self._tenant_id, LLMType.TTS)
if isinstance(cpn_obj.output("content"), partial): if isinstance(cpn_obj.output("content"), partial):
_m = "" _m = ""
buff_m = ""
stream = cpn_obj.output("content")() stream = cpn_obj.output("content")()
async def _process_stream(m):
nonlocal buff_m, _m, tts_mdl
if not m:
return
if m == "<think>":
return decorate("message", {"content": "", "start_to_think": True})
elif m == "</think>":
return decorate("message", {"content": "", "end_to_think": True})
buff_m += m
_m += m
if len(buff_m) > 16:
ev = decorate(
"message",
{
"content": m,
"audio_binary": self.tts(tts_mdl, buff_m)
}
)
buff_m = ""
return ev
return decorate("message", {"content": m})
if inspect.isasyncgen(stream): if inspect.isasyncgen(stream):
async for m in stream: async for m in stream:
if not m: ev= await _process_stream(m)
continue if ev:
if m == "<think>": yield ev
yield decorate("message", {"content": "", "start_to_think": True})
elif m == "</think>":
yield decorate("message", {"content": "", "end_to_think": True})
else:
yield decorate("message", {"content": m})
_m += m
else: else:
for m in stream: for m in stream:
if not m: ev= await _process_stream(m)
continue if ev:
if m == "<think>": yield ev
yield decorate("message", {"content": "", "start_to_think": True}) if buff_m:
elif m == "</think>": yield decorate("message", {"content": "", "audio_binary": self.tts(tts_mdl, buff_m)})
yield decorate("message", {"content": "", "end_to_think": True}) buff_m = ""
else:
yield decorate("message", {"content": m})
_m += m
cpn_obj.set_output("content", _m) cpn_obj.set_output("content", _m)
cite = re.search(r"\[ID:[ 0-9]+\]", _m) cite = re.search(r"\[ID:[ 0-9]+\]", _m)
else: else:
@ -618,6 +640,50 @@ class Canvas(Graph):
return False return False
return True return True
def tts(self,tts_mdl, text):
def clean_tts_text(text: str) -> str:
if not text:
return ""
text = text.encode("utf-8", "ignore").decode("utf-8", "ignore")
text = re.sub(r"[\x00-\x08\x0B-\x0C\x0E-\x1F\x7F]", "", text)
emoji_pattern = re.compile(
"[\U0001F600-\U0001F64F"
"\U0001F300-\U0001F5FF"
"\U0001F680-\U0001F6FF"
"\U0001F1E0-\U0001F1FF"
"\U00002700-\U000027BF"
"\U0001F900-\U0001F9FF"
"\U0001FA70-\U0001FAFF"
"\U0001FAD0-\U0001FAFF]+",
flags=re.UNICODE
)
text = emoji_pattern.sub("", text)
text = re.sub(r"\s+", " ", text).strip()
MAX_LEN = 500
if len(text) > MAX_LEN:
text = text[:MAX_LEN]
return text
if not tts_mdl or not text:
return None
text = clean_tts_text(text)
if not text:
return None
bin = b""
try:
for chunk in tts_mdl.tts(text):
bin += chunk
except Exception as e:
logging.error(f"TTS failed: {e}, text={text!r}")
return None
return binascii.hexlify(bin).decode("utf-8")
def get_history(self, window_size): def get_history(self, window_size):
convs = [] convs = []
if window_size <= 0: if window_size <= 0:

View file

@ -198,6 +198,7 @@ class Retrieval(ToolBase, ABC):
return return
if cks: if cks:
kbinfos["chunks"] = cks kbinfos["chunks"] = cks
kbinfos["chunks"] = settings.retriever.retrieval_by_children(kbinfos["chunks"], [kb.tenant_id for kb in kbs])
if self._param.use_kg: if self._param.use_kg:
ck = settings.kg_retriever.retrieval(query, ck = settings.kg_retriever.retrieval(query,
[kb.tenant_id for kb in kbs], [kb.tenant_id for kb in kbs],

View file

@ -761,13 +761,48 @@ Please write the SQL, only SQL, without any other explanations or text.
"prompt": sys_prompt, "prompt": sys_prompt,
} }
def clean_tts_text(text: str) -> str:
if not text:
return ""
text = text.encode("utf-8", "ignore").decode("utf-8", "ignore")
text = re.sub(r"[\x00-\x08\x0B-\x0C\x0E-\x1F\x7F]", "", text)
emoji_pattern = re.compile(
"[\U0001F600-\U0001F64F"
"\U0001F300-\U0001F5FF"
"\U0001F680-\U0001F6FF"
"\U0001F1E0-\U0001F1FF"
"\U00002700-\U000027BF"
"\U0001F900-\U0001F9FF"
"\U0001FA70-\U0001FAFF"
"\U0001FAD0-\U0001FAFF]+",
flags=re.UNICODE
)
text = emoji_pattern.sub("", text)
text = re.sub(r"\s+", " ", text).strip()
MAX_LEN = 500
if len(text) > MAX_LEN:
text = text[:MAX_LEN]
return text
def tts(tts_mdl, text): def tts(tts_mdl, text):
if not tts_mdl or not text: if not tts_mdl or not text:
return None return None
text = clean_tts_text(text)
if not text:
return None
bin = b"" bin = b""
try:
for chunk in tts_mdl.tts(text): for chunk in tts_mdl.tts(text):
bin += chunk bin += chunk
except Exception as e:
logging.error(f"TTS failed: {e}, text={text!r}")
return None
return binascii.hexlify(bin).decode("utf-8") return binascii.hexlify(bin).decode("utf-8")

View file

@ -173,8 +173,8 @@ def install_mineru() -> None:
Logging is used to indicate status. Logging is used to indicate status.
""" """
# Check if MinerU is enabled # Check if MinerU is enabled
use_mineru = os.getenv("USE_MINERU", "").strip().lower() use_mineru = os.getenv("USE_MINERU", "false").strip().lower()
if use_mineru == "false": if use_mineru != "true":
logging.info("USE_MINERU=%r. Skipping MinerU installation.", use_mineru) logging.info("USE_MINERU=%r. Skipping MinerU installation.", use_mineru)
return return

View file

@ -12,10 +12,17 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import json
import logging
import random import random
from copy import deepcopy from copy import deepcopy, copy
import trio
import xxhash
from agent.component.llm import LLMParam, LLM from agent.component.llm import LLMParam, LLM
from rag.flow.base import ProcessBase, ProcessParamBase from rag.flow.base import ProcessBase, ProcessParamBase
from rag.prompts.generator import run_toc_from_text
class ExtractorParam(ProcessParamBase, LLMParam): class ExtractorParam(ProcessParamBase, LLMParam):
@ -31,6 +38,38 @@ class ExtractorParam(ProcessParamBase, LLMParam):
class Extractor(ProcessBase, LLM): class Extractor(ProcessBase, LLM):
component_name = "Extractor" component_name = "Extractor"
def _build_TOC(self, docs):
self.callback(message="Start to generate table of content ...")
docs = sorted(docs, key=lambda d:(
d.get("page_num_int", 0)[0] if isinstance(d.get("page_num_int", 0), list) else d.get("page_num_int", 0),
d.get("top_int", 0)[0] if isinstance(d.get("top_int", 0), list) else d.get("top_int", 0)
))
toc: list[dict] = trio.run(run_toc_from_text, [d["text"] for d in docs], self.chat_mdl)
logging.info("------------ T O C -------------\n"+json.dumps(toc, ensure_ascii=False, indent=' '))
ii = 0
while ii < len(toc):
try:
idx = int(toc[ii]["chunk_id"])
del toc[ii]["chunk_id"]
toc[ii]["ids"] = [docs[idx]["id"]]
if ii == len(toc) -1:
break
for jj in range(idx+1, int(toc[ii+1]["chunk_id"])+1):
toc[ii]["ids"].append(docs[jj]["id"])
except Exception as e:
logging.exception(e)
ii += 1
if toc:
d = copy.deepcopy(docs[-1])
d["content_with_weight"] = json.dumps(toc, ensure_ascii=False)
d["toc_kwd"] = "toc"
d["available_int"] = 0
d["page_num_int"] = [100000000]
d["id"] = xxhash.xxh64((d["content_with_weight"] + str(d["doc_id"])).encode("utf-8", "surrogatepass")).hexdigest()
return d
return None
async def _invoke(self, **kwargs): async def _invoke(self, **kwargs):
self.set_output("output_format", "chunks") self.set_output("output_format", "chunks")
self.callback(random.randint(1, 5) / 100.0, "Start to generate.") self.callback(random.randint(1, 5) / 100.0, "Start to generate.")
@ -45,6 +84,12 @@ class Extractor(ProcessBase, LLM):
chunks_key = k chunks_key = k
if chunks: if chunks:
if self._param.field_name == "toc":
toc = self._build_TOC(chunks)
chunks.append(toc)
self.set_output("chunks", chunks)
return
prog = 0 prog = 0
for i, ck in enumerate(chunks): for i, ck in enumerate(chunks):
args[chunks_key] = ck["text"] args[chunks_key] = ck["text"]

View file

@ -1,7 +1,5 @@
import message from '@/components/ui/message'; import message from '@/components/ui/message';
import { Spin } from '@/components/ui/spin'; import { Spin } from '@/components/ui/spin';
import { Authorization } from '@/constants/authorization';
import { getAuthorization } from '@/utils/authorization-util';
import request from '@/utils/request'; import request from '@/utils/request';
import classNames from 'classnames'; import classNames from 'classnames';
import mammoth from 'mammoth'; import mammoth from 'mammoth';
@ -16,22 +14,55 @@ export const DocPreviewer: React.FC<DocPreviewerProps> = ({
className, className,
url, url,
}) => { }) => {
// const url = useGetDocumentUrl();
const [htmlContent, setHtmlContent] = useState<string>(''); const [htmlContent, setHtmlContent] = useState<string>('');
const [loading, setLoading] = useState(false); const [loading, setLoading] = useState(false);
const fetchDocument = async () => { const fetchDocument = async () => {
if (!url) return;
setLoading(true); setLoading(true);
const res = await request(url, { const res = await request(url, {
method: 'GET', method: 'GET',
responseType: 'blob', responseType: 'blob',
headers: { [Authorization]: getAuthorization() },
onError: () => { onError: () => {
message.error('Document parsing failed'); message.error('Document parsing failed');
console.error('Error loading document:', url); console.error('Error loading document:', url);
}, },
}); });
try { try {
const arrayBuffer = await res.data.arrayBuffer(); const blob: Blob = res.data;
const contentType: string =
blob.type || (res as any).headers?.['content-type'] || '';
// ---- Detect legacy .doc via MIME or URL ----
const cleanUrl = url.split(/[?#]/)[0].toLowerCase();
const isDocMime = /application\/msword/i.test(contentType);
const isLegacyDocByUrl =
cleanUrl.endsWith('.doc') && !cleanUrl.endsWith('.docx');
const isLegacyDoc = isDocMime || isLegacyDocByUrl;
if (isLegacyDoc) {
// Do not call mammoth and do not throw an error; instead, show a note in the preview area
setHtmlContent(`
<div class="flex h-full items-center justify-center">
<div class="border border-dashed border-border-normal rounded-xl p-8 max-w-2xl text-center">
<p class="text-2xl font-bold mb-4">
Preview not available for .doc files
</p>
<p class="italic text-sm text-muted-foreground leading-relaxed">
Mammoth does not support <code>.doc</code> documents.<br/>
Inline preview is unavailable.
</p>
</div>
</div>
`);
return;
}
// ---- Standard .docx preview path ----
const arrayBuffer = await blob.arrayBuffer();
const result = await mammoth.convertToHtml( const result = await mammoth.convertToHtml(
{ arrayBuffer }, { arrayBuffer },
{ includeDefaultStyleMap: true }, { includeDefaultStyleMap: true },
@ -43,10 +74,12 @@ export const DocPreviewer: React.FC<DocPreviewerProps> = ({
setHtmlContent(styledContent); setHtmlContent(styledContent);
} catch (err) { } catch (err) {
// Only errors from the mammoth conversion path should surface here
message.error('Document parsing failed'); message.error('Document parsing failed');
console.error('Error parsing document:', err); console.error('Error parsing document:', err);
} } finally {
setLoading(false); setLoading(false);
}
}; };
useEffect(() => { useEffect(() => {
@ -54,6 +87,7 @@ export const DocPreviewer: React.FC<DocPreviewerProps> = ({
fetchDocument(); fetchDocument();
} }
}, [url]); }, [url]);
return ( return (
<div <div
className={classNames( className={classNames(