Merge branch 'main' of github.com:infiniflow/ragflow into structured-output
This commit is contained in:
commit
f8f3d96b5d
7 changed files with 213 additions and 32 deletions
106
agent/canvas.py
106
agent/canvas.py
|
|
@ -16,6 +16,7 @@
|
||||||
import asyncio
|
import asyncio
|
||||||
import base64
|
import base64
|
||||||
import inspect
|
import inspect
|
||||||
|
import binascii
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import re
|
import re
|
||||||
|
|
@ -28,7 +29,9 @@ from typing import Any, Union, Tuple
|
||||||
from agent.component import component_class
|
from agent.component import component_class
|
||||||
from agent.component.base import ComponentBase
|
from agent.component.base import ComponentBase
|
||||||
from api.db.services.file_service import FileService
|
from api.db.services.file_service import FileService
|
||||||
|
from api.db.services.llm_service import LLMBundle
|
||||||
from api.db.services.task_service import has_canceled
|
from api.db.services.task_service import has_canceled
|
||||||
|
from common.constants import LLMType
|
||||||
from common.misc_utils import get_uuid, hash_str2int
|
from common.misc_utils import get_uuid, hash_str2int
|
||||||
from common.exceptions import TaskCanceledException
|
from common.exceptions import TaskCanceledException
|
||||||
from rag.prompts.generator import chunks_format
|
from rag.prompts.generator import chunks_format
|
||||||
|
|
@ -356,8 +359,6 @@ class Canvas(Graph):
|
||||||
self.globals[k] = ""
|
self.globals[k] = ""
|
||||||
else:
|
else:
|
||||||
self.globals[k] = ""
|
self.globals[k] = ""
|
||||||
print(self.globals)
|
|
||||||
|
|
||||||
|
|
||||||
async def run(self, **kwargs):
|
async def run(self, **kwargs):
|
||||||
st = time.perf_counter()
|
st = time.perf_counter()
|
||||||
|
|
@ -456,6 +457,7 @@ class Canvas(Graph):
|
||||||
self.error = ""
|
self.error = ""
|
||||||
idx = len(self.path) - 1
|
idx = len(self.path) - 1
|
||||||
partials = []
|
partials = []
|
||||||
|
tts_mdl = None
|
||||||
while idx < len(self.path):
|
while idx < len(self.path):
|
||||||
to = len(self.path)
|
to = len(self.path)
|
||||||
for i in range(idx, to):
|
for i in range(idx, to):
|
||||||
|
|
@ -473,31 +475,51 @@ class Canvas(Graph):
|
||||||
cpn = self.get_component(self.path[i])
|
cpn = self.get_component(self.path[i])
|
||||||
cpn_obj = self.get_component_obj(self.path[i])
|
cpn_obj = self.get_component_obj(self.path[i])
|
||||||
if cpn_obj.component_name.lower() == "message":
|
if cpn_obj.component_name.lower() == "message":
|
||||||
|
if cpn_obj.get_param("auto_play"):
|
||||||
|
tts_mdl = LLMBundle(self._tenant_id, LLMType.TTS)
|
||||||
if isinstance(cpn_obj.output("content"), partial):
|
if isinstance(cpn_obj.output("content"), partial):
|
||||||
_m = ""
|
_m = ""
|
||||||
|
buff_m = ""
|
||||||
stream = cpn_obj.output("content")()
|
stream = cpn_obj.output("content")()
|
||||||
|
async def _process_stream(m):
|
||||||
|
nonlocal buff_m, _m, tts_mdl
|
||||||
|
if not m:
|
||||||
|
return
|
||||||
|
if m == "<think>":
|
||||||
|
return decorate("message", {"content": "", "start_to_think": True})
|
||||||
|
|
||||||
|
elif m == "</think>":
|
||||||
|
return decorate("message", {"content": "", "end_to_think": True})
|
||||||
|
|
||||||
|
buff_m += m
|
||||||
|
_m += m
|
||||||
|
|
||||||
|
if len(buff_m) > 16:
|
||||||
|
ev = decorate(
|
||||||
|
"message",
|
||||||
|
{
|
||||||
|
"content": m,
|
||||||
|
"audio_binary": self.tts(tts_mdl, buff_m)
|
||||||
|
}
|
||||||
|
)
|
||||||
|
buff_m = ""
|
||||||
|
return ev
|
||||||
|
|
||||||
|
return decorate("message", {"content": m})
|
||||||
|
|
||||||
if inspect.isasyncgen(stream):
|
if inspect.isasyncgen(stream):
|
||||||
async for m in stream:
|
async for m in stream:
|
||||||
if not m:
|
ev= await _process_stream(m)
|
||||||
continue
|
if ev:
|
||||||
if m == "<think>":
|
yield ev
|
||||||
yield decorate("message", {"content": "", "start_to_think": True})
|
|
||||||
elif m == "</think>":
|
|
||||||
yield decorate("message", {"content": "", "end_to_think": True})
|
|
||||||
else:
|
|
||||||
yield decorate("message", {"content": m})
|
|
||||||
_m += m
|
|
||||||
else:
|
else:
|
||||||
for m in stream:
|
for m in stream:
|
||||||
if not m:
|
ev= await _process_stream(m)
|
||||||
continue
|
if ev:
|
||||||
if m == "<think>":
|
yield ev
|
||||||
yield decorate("message", {"content": "", "start_to_think": True})
|
if buff_m:
|
||||||
elif m == "</think>":
|
yield decorate("message", {"content": "", "audio_binary": self.tts(tts_mdl, buff_m)})
|
||||||
yield decorate("message", {"content": "", "end_to_think": True})
|
buff_m = ""
|
||||||
else:
|
|
||||||
yield decorate("message", {"content": m})
|
|
||||||
_m += m
|
|
||||||
cpn_obj.set_output("content", _m)
|
cpn_obj.set_output("content", _m)
|
||||||
cite = re.search(r"\[ID:[ 0-9]+\]", _m)
|
cite = re.search(r"\[ID:[ 0-9]+\]", _m)
|
||||||
else:
|
else:
|
||||||
|
|
@ -618,6 +640,50 @@ class Canvas(Graph):
|
||||||
return False
|
return False
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
def tts(self, tts_mdl, text):
    """Synthesize speech for ``text`` and return the audio hex-encoded.

    The text is sanitized before synthesis: control characters and emoji
    are removed, whitespace runs are collapsed to single spaces and the
    result is capped at 500 characters.

    Args:
        tts_mdl: Object exposing ``tts(text)`` that yields audio byte
            chunks. A falsy value disables synthesis.
        text: Raw text to speak.

    Returns:
        The concatenated audio chunks as a lower-case hexadecimal string,
        or ``None`` when there is no model, nothing speakable is left after
        cleaning, or the model raises.
    """

    def clean_tts_text(text: str) -> str:
        """Strip characters that should not be sent to the TTS model."""
        if not text:
            return ""

        # Round-trip through UTF-8 to drop code points that cannot be
        # encoded (e.g. lone surrogates).
        text = text.encode("utf-8", "ignore").decode("utf-8", "ignore")

        # Remove control characters but keep \t, \n and \r so that the
        # whitespace collapse below still separates the surrounding words.
        text = re.sub(r"[\x00-\x08\x0B-\x0C\x0E-\x1F\x7F]", "", text)

        # An emoji is removed together with any variation selectors or
        # zero-width joiners that follow it; otherwise "\u2764\ufe0f" would
        # leave an invisible U+FE0F behind and still be sent for synthesis.
        emoji_pattern = re.compile(
            "(?:["
            "\U0001F600-\U0001F64F"
            "\U0001F300-\U0001F5FF"
            "\U0001F680-\U0001F6FF"
            "\U0001F1E0-\U0001F1FF"
            "\U00002700-\U000027BF"
            "\U0001F900-\U0001F9FF"
            "\U0001FA70-\U0001FAFF"
            "]"
            "[\u200D\uFE0E\uFE0F]*)+",
            flags=re.UNICODE,
        )
        text = emoji_pattern.sub("", text)

        text = re.sub(r"\s+", " ", text).strip()

        MAX_LEN = 500
        return text[:MAX_LEN]

    if not tts_mdl or not text:
        return None
    text = clean_tts_text(text)
    if not text:
        return None

    try:
        # Join once instead of growing a bytes object chunk by chunk.
        audio = b"".join(tts_mdl.tts(text))
    except Exception as e:
        # Best effort: a TTS failure must not break the message stream.
        logging.error("TTS failed: %s, text=%r", e, text)
        return None
    return binascii.hexlify(audio).decode("utf-8")
|
||||||
|
|
||||||
def get_history(self, window_size):
|
def get_history(self, window_size):
|
||||||
convs = []
|
convs = []
|
||||||
if window_size <= 0:
|
if window_size <= 0:
|
||||||
|
|
|
||||||
|
|
@ -198,6 +198,7 @@ class Retrieval(ToolBase, ABC):
|
||||||
return
|
return
|
||||||
if cks:
|
if cks:
|
||||||
kbinfos["chunks"] = cks
|
kbinfos["chunks"] = cks
|
||||||
|
kbinfos["chunks"] = settings.retriever.retrieval_by_children(kbinfos["chunks"], [kb.tenant_id for kb in kbs])
|
||||||
if self._param.use_kg:
|
if self._param.use_kg:
|
||||||
ck = settings.kg_retriever.retrieval(query,
|
ck = settings.kg_retriever.retrieval(query,
|
||||||
[kb.tenant_id for kb in kbs],
|
[kb.tenant_id for kb in kbs],
|
||||||
|
|
|
||||||
|
|
@ -761,13 +761,48 @@ Please write the SQL, only SQL, without any other explanations or text.
|
||||||
"prompt": sys_prompt,
|
"prompt": sys_prompt,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def clean_tts_text(text: str, max_len: int = 500) -> str:
    """Return ``text`` reduced to what is worth sending to a TTS model.

    Un-encodable code points, control characters and emoji are removed,
    whitespace runs are collapsed to a single space, and the result is
    truncated to ``max_len`` characters.

    Args:
        text: Raw text; a falsy value yields an empty string.
        max_len: Maximum number of characters kept (default 500).

    Returns:
        The cleaned text, possibly empty.
    """
    if not text:
        return ""

    # Round-trip through UTF-8 to drop code points that cannot be encoded
    # (e.g. lone surrogates).
    text = text.encode("utf-8", "ignore").decode("utf-8", "ignore")

    # Remove control characters but keep \t, \n and \r so that the
    # whitespace collapse below still separates the surrounding words.
    text = re.sub(r"[\x00-\x08\x0B-\x0C\x0E-\x1F\x7F]", "", text)

    # An emoji is removed together with any variation selectors or
    # zero-width joiners that follow it; otherwise "\u2764\ufe0f" would leave
    # an invisible U+FE0F behind and the text would not count as empty.
    emoji_pattern = re.compile(
        "(?:["
        "\U0001F600-\U0001F64F"
        "\U0001F300-\U0001F5FF"
        "\U0001F680-\U0001F6FF"
        "\U0001F1E0-\U0001F1FF"
        "\U00002700-\U000027BF"
        "\U0001F900-\U0001F9FF"
        "\U0001FA70-\U0001FAFF"
        "]"
        "[\u200D\uFE0E\uFE0F]*)+",
        flags=re.UNICODE,
    )
    text = emoji_pattern.sub("", text)

    text = re.sub(r"\s+", " ", text).strip()

    return text[:max_len]
|
||||||
|
|
||||||
def tts(tts_mdl, text):
    """Synthesize speech for ``text`` and return the audio hex-encoded.

    Args:
        tts_mdl: Object exposing ``tts(text)`` that yields audio byte
            chunks. A falsy value disables synthesis.
        text: Raw text to speak; it is passed through ``clean_tts_text``
            before being sent to the model.

    Returns:
        The concatenated audio chunks as a lower-case hexadecimal string,
        or ``None`` when there is no model, the text is empty before or
        after cleaning, or the model raises.
    """
    if not tts_mdl or not text:
        return None
    text = clean_tts_text(text)
    if not text:
        return None

    try:
        # Join once instead of growing a bytes object chunk by chunk; this
        # also avoids shadowing the ``bin`` builtin with a local.
        audio = b"".join(tts_mdl.tts(text))
    except Exception as e:
        # Best effort: a TTS failure is logged and reported as "no audio".
        logging.error("TTS failed: %s, text=%r", e, text)
        return None
    return binascii.hexlify(audio).decode("utf-8")
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -173,8 +173,8 @@ def install_mineru() -> None:
|
||||||
Logging is used to indicate status.
|
Logging is used to indicate status.
|
||||||
"""
|
"""
|
||||||
# Check if MinerU is enabled
|
# Check if MinerU is enabled
|
||||||
use_mineru = os.getenv("USE_MINERU", "").strip().lower()
|
use_mineru = os.getenv("USE_MINERU", "false").strip().lower()
|
||||||
if use_mineru == "false":
|
if use_mineru != "true":
|
||||||
logging.info("USE_MINERU=%r. Skipping MinerU installation.", use_mineru)
|
logging.info("USE_MINERU=%r. Skipping MinerU installation.", use_mineru)
|
||||||
return
|
return
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -12,10 +12,17 @@
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
import random
|
import random
|
||||||
from copy import deepcopy
|
from copy import deepcopy, copy
|
||||||
|
|
||||||
|
import trio
|
||||||
|
import xxhash
|
||||||
|
|
||||||
from agent.component.llm import LLMParam, LLM
|
from agent.component.llm import LLMParam, LLM
|
||||||
from rag.flow.base import ProcessBase, ProcessParamBase
|
from rag.flow.base import ProcessBase, ProcessParamBase
|
||||||
|
from rag.prompts.generator import run_toc_from_text
|
||||||
|
|
||||||
|
|
||||||
class ExtractorParam(ProcessParamBase, LLMParam):
|
class ExtractorParam(ProcessParamBase, LLMParam):
|
||||||
|
|
@ -31,6 +38,38 @@ class ExtractorParam(ProcessParamBase, LLMParam):
|
||||||
class Extractor(ProcessBase, LLM):
|
class Extractor(ProcessBase, LLM):
|
||||||
component_name = "Extractor"
|
component_name = "Extractor"
|
||||||
|
|
||||||
|
def _build_TOC(self, docs):
    """Build a table-of-contents chunk for ``docs``.

    The chunks are ordered by ``page_num_int`` and ``top_int``, their
    ``text`` fields are handed to ``run_toc_from_text`` (run under trio
    with ``self.chat_mdl``), and every returned TOC entry has its
    ``chunk_id`` replaced by ``ids``: the ids of the chunks from that entry
    up to and including the chunk that starts the next entry.

    Args:
        docs: List of chunk dicts. Each is read for ``text`` and ``id``,
            and optionally ``page_num_int`` / ``top_int``; the last chunk
            in sorted order must also carry ``doc_id``.

    Returns:
        A copy of the last chunk in sorted order, rewritten to hold the
        JSON-encoded TOC, or ``None`` when no TOC was produced.
    """
    self.callback(message="Start to generate table of content ...")

    def _leading(value):
        # Position fields may hold either a scalar or a list; use the
        # first element of a list as the sort key.
        return value[0] if isinstance(value, list) else value

    docs = sorted(
        docs,
        key=lambda d: (
            _leading(d.get("page_num_int", 0)),
            _leading(d.get("top_int", 0)),
        ),
    )
    toc: list[dict] = trio.run(run_toc_from_text, [d["text"] for d in docs], self.chat_mdl)
    logging.info("------------ T O C -------------\n" + json.dumps(toc, ensure_ascii=False, indent=' '))

    last = len(toc) - 1
    for ii, entry in enumerate(toc):
        try:
            idx = int(entry["chunk_id"])
            del entry["chunk_id"]
            entry["ids"] = [docs[idx]["id"]]
            if ii == last:
                break
            # The next entry still holds its chunk_id here; it is only
            # removed when that entry is processed in the next iteration.
            for jj in range(idx + 1, int(toc[ii + 1]["chunk_id"]) + 1):
                entry["ids"].append(docs[jj]["id"])
        except Exception as e:
            # A malformed entry (missing or out-of-range chunk_id) is
            # logged and skipped so the remaining entries are still mapped.
            logging.exception(e)

    if not toc:
        return None

    # The file imports ``deepcopy`` and ``copy`` as functions from the copy
    # module, so ``copy.deepcopy`` would be an attribute lookup on a function
    # and raise AttributeError; call the imported ``deepcopy`` directly.
    d = deepcopy(docs[-1])
    d["content_with_weight"] = json.dumps(toc, ensure_ascii=False)
    d["toc_kwd"] = "toc"
    d["available_int"] = 0
    # NOTE(review): presumably a sentinel page number that places the TOC
    # chunk after all real pages - confirm against the consumers.
    d["page_num_int"] = [100000000]
    d["id"] = xxhash.xxh64((d["content_with_weight"] + str(d["doc_id"])).encode("utf-8", "surrogatepass")).hexdigest()
    return d
|
||||||
|
|
||||||
async def _invoke(self, **kwargs):
|
async def _invoke(self, **kwargs):
|
||||||
self.set_output("output_format", "chunks")
|
self.set_output("output_format", "chunks")
|
||||||
self.callback(random.randint(1, 5) / 100.0, "Start to generate.")
|
self.callback(random.randint(1, 5) / 100.0, "Start to generate.")
|
||||||
|
|
@ -45,6 +84,12 @@ class Extractor(ProcessBase, LLM):
|
||||||
chunks_key = k
|
chunks_key = k
|
||||||
|
|
||||||
if chunks:
|
if chunks:
|
||||||
|
if self._param.field_name == "toc":
|
||||||
|
toc = self._build_TOC(chunks)
|
||||||
|
chunks.append(toc)
|
||||||
|
self.set_output("chunks", chunks)
|
||||||
|
return
|
||||||
|
|
||||||
prog = 0
|
prog = 0
|
||||||
for i, ck in enumerate(chunks):
|
for i, ck in enumerate(chunks):
|
||||||
args[chunks_key] = ck["text"]
|
args[chunks_key] = ck["text"]
|
||||||
|
|
|
||||||
|
|
@ -944,7 +944,7 @@ async def do_handle_task(task):
|
||||||
logging.info(progress_message)
|
logging.info(progress_message)
|
||||||
progress_callback(msg=progress_message)
|
progress_callback(msg=progress_message)
|
||||||
if task["parser_id"].lower() == "naive" and task["parser_config"].get("toc_extraction", False):
|
if task["parser_id"].lower() == "naive" and task["parser_config"].get("toc_extraction", False):
|
||||||
toc_thread = executor.submit(build_TOC,task, chunks, progress_callback)
|
toc_thread = executor.submit(build_TOC, task, chunks, progress_callback)
|
||||||
|
|
||||||
chunk_count = len(set([chunk["id"] for chunk in chunks]))
|
chunk_count = len(set([chunk["id"] for chunk in chunks]))
|
||||||
start_ts = timer()
|
start_ts = timer()
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,5 @@
|
||||||
import message from '@/components/ui/message';
|
import message from '@/components/ui/message';
|
||||||
import { Spin } from '@/components/ui/spin';
|
import { Spin } from '@/components/ui/spin';
|
||||||
import { Authorization } from '@/constants/authorization';
|
|
||||||
import { getAuthorization } from '@/utils/authorization-util';
|
|
||||||
import request from '@/utils/request';
|
import request from '@/utils/request';
|
||||||
import classNames from 'classnames';
|
import classNames from 'classnames';
|
||||||
import mammoth from 'mammoth';
|
import mammoth from 'mammoth';
|
||||||
|
|
@ -16,22 +14,55 @@ export const DocPreviewer: React.FC<DocPreviewerProps> = ({
|
||||||
className,
|
className,
|
||||||
url,
|
url,
|
||||||
}) => {
|
}) => {
|
||||||
// const url = useGetDocumentUrl();
|
|
||||||
const [htmlContent, setHtmlContent] = useState<string>('');
|
const [htmlContent, setHtmlContent] = useState<string>('');
|
||||||
const [loading, setLoading] = useState(false);
|
const [loading, setLoading] = useState(false);
|
||||||
|
|
||||||
const fetchDocument = async () => {
|
const fetchDocument = async () => {
|
||||||
|
if (!url) return;
|
||||||
|
|
||||||
setLoading(true);
|
setLoading(true);
|
||||||
|
|
||||||
const res = await request(url, {
|
const res = await request(url, {
|
||||||
method: 'GET',
|
method: 'GET',
|
||||||
responseType: 'blob',
|
responseType: 'blob',
|
||||||
headers: { [Authorization]: getAuthorization() },
|
|
||||||
onError: () => {
|
onError: () => {
|
||||||
message.error('Document parsing failed');
|
message.error('Document parsing failed');
|
||||||
console.error('Error loading document:', url);
|
console.error('Error loading document:', url);
|
||||||
},
|
},
|
||||||
});
|
});
|
||||||
|
|
||||||
try {
|
try {
|
||||||
const arrayBuffer = await res.data.arrayBuffer();
|
const blob: Blob = res.data;
|
||||||
|
const contentType: string =
|
||||||
|
blob.type || (res as any).headers?.['content-type'] || '';
|
||||||
|
|
||||||
|
// ---- Detect legacy .doc via MIME or URL ----
|
||||||
|
const cleanUrl = url.split(/[?#]/)[0].toLowerCase();
|
||||||
|
const isDocMime = /application\/msword/i.test(contentType);
|
||||||
|
const isLegacyDocByUrl =
|
||||||
|
cleanUrl.endsWith('.doc') && !cleanUrl.endsWith('.docx');
|
||||||
|
const isLegacyDoc = isDocMime || isLegacyDocByUrl;
|
||||||
|
|
||||||
|
if (isLegacyDoc) {
|
||||||
|
// Do not call mammoth and do not throw an error; instead, show a note in the preview area
|
||||||
|
setHtmlContent(`
|
||||||
|
<div class="flex h-full items-center justify-center">
|
||||||
|
<div class="border border-dashed border-border-normal rounded-xl p-8 max-w-2xl text-center">
|
||||||
|
<p class="text-2xl font-bold mb-4">
|
||||||
|
Preview not available for .doc files
|
||||||
|
</p>
|
||||||
|
<p class="italic text-sm text-muted-foreground leading-relaxed">
|
||||||
|
Mammoth does not support <code>.doc</code> documents.<br/>
|
||||||
|
Inline preview is unavailable.
|
||||||
|
</p>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
`);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---- Standard .docx preview path ----
|
||||||
|
const arrayBuffer = await blob.arrayBuffer();
|
||||||
const result = await mammoth.convertToHtml(
|
const result = await mammoth.convertToHtml(
|
||||||
{ arrayBuffer },
|
{ arrayBuffer },
|
||||||
{ includeDefaultStyleMap: true },
|
{ includeDefaultStyleMap: true },
|
||||||
|
|
@ -43,10 +74,12 @@ export const DocPreviewer: React.FC<DocPreviewerProps> = ({
|
||||||
|
|
||||||
setHtmlContent(styledContent);
|
setHtmlContent(styledContent);
|
||||||
} catch (err) {
|
} catch (err) {
|
||||||
|
// Only errors from the mammoth conversion path should surface here
|
||||||
message.error('Document parsing failed');
|
message.error('Document parsing failed');
|
||||||
console.error('Error parsing document:', err);
|
console.error('Error parsing document:', err);
|
||||||
|
} finally {
|
||||||
|
setLoading(false);
|
||||||
}
|
}
|
||||||
setLoading(false);
|
|
||||||
};
|
};
|
||||||
|
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
|
|
@ -54,6 +87,7 @@ export const DocPreviewer: React.FC<DocPreviewerProps> = ({
|
||||||
fetchDocument();
|
fetchDocument();
|
||||||
}
|
}
|
||||||
}, [url]);
|
}, [url]);
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<div
|
<div
|
||||||
className={classNames(
|
className={classNames(
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue