diff --git a/agent/component/agent_with_tools.py b/agent/component/agent_with_tools.py index 29bd599c6..988351cf6 100644 --- a/agent/component/agent_with_tools.py +++ b/agent/component/agent_with_tools.py @@ -271,7 +271,7 @@ class Agent(LLM, ToolBase): last_calling = "" if len(hist) > 3: st = timer() - user_request = await asyncio.to_thread(full_question, messages=history, chat_mdl=self.chat_mdl) + user_request = await full_question(messages=history, chat_mdl=self.chat_mdl) self.callback("Multi-turn conversation optimization", {}, user_request, elapsed_time=timer()-st) else: user_request = history[-1]["content"] @@ -309,7 +309,7 @@ class Agent(LLM, ToolBase): if len(hist) > 12: _hist = [hist[0], hist[1], *hist[-10:]] entire_txt = "" - async for delta_ans in self._generate_streamly_async(_hist): + async for delta_ans in self._generate_streamly(_hist): if not need2cite or cited: yield delta_ans, 0 entire_txt += delta_ans @@ -397,7 +397,7 @@ Respond immediately with your final comprehensive answer. retrievals = self._canvas.get_reference() retrievals = {"chunks": list(retrievals["chunks"].values()), "doc_aggs": list(retrievals["doc_aggs"].values())} formated_refer = kb_prompt(retrievals, self.chat_mdl.max_length, True) - async for delta_ans in self._generate_streamly_async([{"role": "system", "content": citation_plus("\n\n".join(formated_refer))}, + async for delta_ans in self._generate_streamly([{"role": "system", "content": citation_plus("\n\n".join(formated_refer))}, {"role": "user", "content": text} ]): yield delta_ans diff --git a/agent/component/categorize.py b/agent/component/categorize.py index 1333889bb..27cffb91c 100644 --- a/agent/component/categorize.py +++ b/agent/component/categorize.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # +import asyncio import logging import os import re @@ -97,7 +98,7 @@ class Categorize(LLM, ABC): component_name = "Categorize" @timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60))) - def _invoke(self, **kwargs): + async def _invoke_async(self, **kwargs): if self.check_if_canceled("Categorize processing"): return @@ -121,7 +122,7 @@ class Categorize(LLM, ABC): if self.check_if_canceled("Categorize processing"): return - ans = chat_mdl.chat(self._param.sys_prompt, [{"role": "user", "content": user_prompt}], self._param.gen_conf()) + ans = await chat_mdl.async_chat(self._param.sys_prompt, [{"role": "user", "content": user_prompt}], self._param.gen_conf()) logging.info(f"input: {user_prompt}, answer: {str(ans)}") if ERROR_PREFIX in ans: raise Exception(ans) @@ -144,5 +145,9 @@ class Categorize(LLM, ABC): self.set_output("category_name", max_category) self.set_output("_next", cpn_ids) + @timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60))) + def _invoke(self, **kwargs): + return asyncio.run(self._invoke_async(**kwargs)) + def thoughts(self) -> str: return "Which should it falls into {}? 
...".format(",".join([f"`{c}`" for c, _ in self._param.category_description.items()])) diff --git a/agent/component/llm.py b/agent/component/llm.py index a437025e9..39e043aeb 100644 --- a/agent/component/llm.py +++ b/agent/component/llm.py @@ -18,9 +18,8 @@ import json import logging import os import re -import threading from copy import deepcopy -from typing import Any, Generator, AsyncGenerator +from typing import Any, AsyncGenerator import json_repair from functools import partial from common.constants import LLMType @@ -168,53 +167,12 @@ class LLM(ComponentBase): sys_prompt = re.sub(rf"<{tag}>(.*?)", "", sys_prompt, flags=re.DOTALL|re.IGNORECASE) return pts, sys_prompt - def _generate(self, msg:list[dict], **kwargs) -> str: - if not self.imgs: - return self.chat_mdl.chat(msg[0]["content"], msg[1:], self._param.gen_conf(), **kwargs) - return self.chat_mdl.chat(msg[0]["content"], msg[1:], self._param.gen_conf(), images=self.imgs, **kwargs) - async def _generate_async(self, msg: list[dict], **kwargs) -> str: - if not self.imgs and hasattr(self.chat_mdl, "async_chat"): - return await self.chat_mdl.async_chat(msg[0]["content"], msg[1:], self._param.gen_conf(), **kwargs) - if self.imgs and hasattr(self.chat_mdl, "async_chat"): - return await self.chat_mdl.async_chat(msg[0]["content"], msg[1:], self._param.gen_conf(), images=self.imgs, **kwargs) - return await asyncio.to_thread(self._generate, msg, **kwargs) - - def _generate_streamly(self, msg:list[dict], **kwargs) -> Generator[str, None, None]: - ans = "" - last_idx = 0 - endswith_think = False - def delta(txt): - nonlocal ans, last_idx, endswith_think - delta_ans = txt[last_idx:] - ans = txt - - if delta_ans.find("") == 0: - last_idx += len("") - return "" - elif delta_ans.find("") > 0: - delta_ans = txt[last_idx:last_idx+delta_ans.find("")] - last_idx += delta_ans.find("") - return delta_ans - elif delta_ans.endswith(""): - endswith_think = True - elif endswith_think: - endswith_think = False - return "" - - last_idx = len(ans) - if ans.endswith(""): - last_idx -= len("") - return re.sub(r"(|)", "", delta_ans) - if not self.imgs: - for txt in self.chat_mdl.chat_streamly(msg[0]["content"], msg[1:], self._param.gen_conf(), **kwargs): - yield delta(txt) - else: - for txt in self.chat_mdl.chat_streamly(msg[0]["content"], msg[1:], self._param.gen_conf(), images=self.imgs, **kwargs): - yield delta(txt) + return await self.chat_mdl.async_chat(msg[0]["content"], msg[1:], self._param.gen_conf(), **kwargs) + return await self.chat_mdl.async_chat(msg[0]["content"], msg[1:], self._param.gen_conf(), images=self.imgs, **kwargs) - async def _generate_streamly_async(self, msg: list[dict], **kwargs) -> AsyncGenerator[str, None]: + async def _generate_streamly(self, msg: list[dict], **kwargs) -> AsyncGenerator[str, None]: async def delta_wrapper(txt_iter): ans = "" last_idx = 0 @@ -246,36 +204,13 @@ class LLM(ComponentBase): async for t in txt_iter: yield delta(t) - if not self.imgs and hasattr(self.chat_mdl, "async_chat_streamly"): + if not self.imgs: async for t in delta_wrapper(self.chat_mdl.async_chat_streamly(msg[0]["content"], msg[1:], self._param.gen_conf(), **kwargs)): yield t return - if self.imgs and hasattr(self.chat_mdl, "async_chat_streamly"): - async for t in delta_wrapper(self.chat_mdl.async_chat_streamly(msg[0]["content"], msg[1:], self._param.gen_conf(), images=self.imgs, **kwargs)): - yield t - return - # fallback - loop = asyncio.get_running_loop() - queue: asyncio.Queue = asyncio.Queue() - - def worker(): - try: - for item in 
self._generate_streamly(msg, **kwargs): - loop.call_soon_threadsafe(queue.put_nowait, item) - except Exception as e: - loop.call_soon_threadsafe(queue.put_nowait, e) - finally: - loop.call_soon_threadsafe(queue.put_nowait, StopAsyncIteration) - - threading.Thread(target=worker, daemon=True).start() - while True: - item = await queue.get() - if item is StopAsyncIteration: - break - if isinstance(item, Exception): - raise item - yield item + async for t in delta_wrapper(self.chat_mdl.async_chat_streamly(msg[0]["content"], msg[1:], self._param.gen_conf(), images=self.imgs, **kwargs)): + yield t async def _stream_output_async(self, prompt, msg): _, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(self.chat_mdl.max_length * 0.97)) @@ -407,8 +342,8 @@ class LLM(ComponentBase): def _invoke(self, **kwargs): return asyncio.run(self._invoke_async(**kwargs)) - def add_memory(self, user:str, assist:str, func_name: str, params: dict, results: str, user_defined_prompt:dict={}): - summ = tool_call_summary(self.chat_mdl, func_name, params, results, user_defined_prompt) + async def add_memory(self, user:str, assist:str, func_name: str, params: dict, results: str, user_defined_prompt:dict={}): + summ = await tool_call_summary(self.chat_mdl, func_name, params, results, user_defined_prompt) logging.info(f"[MEMORY]: {summ}") self._canvas.add_memory(user, assist, summ) diff --git a/agent/tools/retrieval.py b/agent/tools/retrieval.py index 26753f8b2..60e98d627 100644 --- a/agent/tools/retrieval.py +++ b/agent/tools/retrieval.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # +import asyncio from functools import partial import json import os @@ -81,7 +82,7 @@ class Retrieval(ToolBase, ABC): component_name = "Retrieval" @timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12))) - def _invoke(self, **kwargs): + async def _invoke_async(self, **kwargs): if self.check_if_canceled("Retrieval processing"): return @@ -174,7 +175,7 @@ class Retrieval(ToolBase, ABC): ) if self._param.cross_languages: - query = cross_languages(kbs[0].tenant_id, None, query, self._param.cross_languages) + query = await cross_languages(kbs[0].tenant_id, None, query, self._param.cross_languages) if kbs: query = re.sub(r"^user[::\s]*", "", query, flags=re.IGNORECASE) @@ -247,6 +248,10 @@ class Retrieval(ToolBase, ABC): return form_cnt + @timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12))) + def _invoke(self, **kwargs): + return asyncio.run(self._invoke_async(**kwargs)) + def thoughts(self) -> str: return """ Keywords: {} diff --git a/agentic_reasoning/deep_research.py b/agentic_reasoning/deep_research.py index d7121245f..20f7017f4 100644 --- a/agentic_reasoning/deep_research.py +++ b/agentic_reasoning/deep_research.py @@ -51,7 +51,7 @@ class DeepResearcher: """Remove Result Tags""" return DeepResearcher._remove_tags(text, BEGIN_SEARCH_RESULT, END_SEARCH_RESULT) - def _generate_reasoning(self, msg_history): + async def _generate_reasoning(self, msg_history): """Generate reasoning steps""" query_think = "" if msg_history[-1]["role"] != "user": @@ -59,13 +59,14 @@ class DeepResearcher: else: msg_history[-1]["content"] += "\n\nContinues reasoning with the new information.\n" - for ans in self.chat_mdl.chat_streamly(REASON_PROMPT, msg_history, {"temperature": 0.7}): + async for ans in self.chat_mdl.async_chat_streamly(REASON_PROMPT, msg_history, {"temperature": 0.7}): ans = re.sub(r"^.*", "", ans, flags=re.DOTALL) if not ans: continue 
query_think = ans yield query_think - return query_think + query_think = "" + yield query_think def _extract_search_queries(self, query_think, question, step_index): """Extract search queries from thinking""" @@ -143,10 +144,10 @@ class DeepResearcher: if d["doc_id"] not in dids: chunk_info["doc_aggs"].append(d) - def _extract_relevant_info(self, truncated_prev_reasoning, search_query, kbinfos): + async def _extract_relevant_info(self, truncated_prev_reasoning, search_query, kbinfos): """Extract and summarize relevant information""" summary_think = "" - for ans in self.chat_mdl.chat_streamly( + async for ans in self.chat_mdl.async_chat_streamly( RELEVANT_EXTRACTION_PROMPT.format( prev_reasoning=truncated_prev_reasoning, search_query=search_query, @@ -160,10 +161,11 @@ class DeepResearcher: continue summary_think = ans yield summary_think + summary_think = "" - return summary_think + yield summary_think - def thinking(self, chunk_info: dict, question: str): + async def thinking(self, chunk_info: dict, question: str): executed_search_queries = [] msg_history = [{"role": "user", "content": f'Question:\"{question}\"\n'}] all_reasoning_steps = [] @@ -180,7 +182,7 @@ class DeepResearcher: # Step 1: Generate reasoning query_think = "" - for ans in self._generate_reasoning(msg_history): + async for ans in self._generate_reasoning(msg_history): query_think = ans yield {"answer": think + self._remove_query_tags(query_think) + "", "reference": {}, "audio_binary": None} @@ -223,7 +225,7 @@ class DeepResearcher: # Step 6: Extract relevant information think += "\n\n" summary_think = "" - for ans in self._extract_relevant_info(truncated_prev_reasoning, search_query, kbinfos): + async for ans in self._extract_relevant_info(truncated_prev_reasoning, search_query, kbinfos): summary_think = ans yield {"answer": think + self._remove_result_tags(summary_think) + "", "reference": {}, "audio_binary": None} diff --git a/api/apps/chunk_app.py b/api/apps/chunk_app.py index 494678c8b..7380f3524 100644 --- a/api/apps/chunk_app.py +++ b/api/apps/chunk_app.py @@ -313,7 +313,7 @@ async def retrieval_test(): langs = req.get("cross_languages", []) user_id = current_user.id - def _retrieval_sync(): + async def _retrieval(): local_doc_ids = list(doc_ids) if doc_ids else [] tenant_ids = [] @@ -351,7 +351,7 @@ async def retrieval_test(): _question = question if langs: - _question = cross_languages(kb.tenant_id, None, _question, langs) + _question = await cross_languages(kb.tenant_id, None, _question, langs) embd_mdl = LLMBundle(kb.tenant_id, LLMType.EMBEDDING.value, llm_name=kb.embd_id) @@ -361,7 +361,7 @@ async def retrieval_test(): if req.get("keyword", False): chat_mdl = LLMBundle(kb.tenant_id, LLMType.CHAT) - _question += keyword_extraction(chat_mdl, _question) + _question += await keyword_extraction(chat_mdl, _question) labels = label_question(_question, [kb]) ranks = settings.retriever.retrieval(_question, embd_mdl, tenant_ids, kb_ids, page, size, @@ -388,7 +388,7 @@ async def retrieval_test(): return get_json_result(data=ranks) try: - return await asyncio.to_thread(_retrieval_sync) + return await _retrieval() except Exception as e: if str(e).find("not_found") > 0: return get_json_result(data=False, message='No chunk found! 
Check the chunk status please!', diff --git a/api/apps/llm_app.py b/api/apps/llm_app.py index 8caaaffad..1f46a4098 100644 --- a/api/apps/llm_app.py +++ b/api/apps/llm_app.py @@ -192,6 +192,9 @@ async def add_llm(): elif factory == "OpenRouter": api_key = apikey_json(["api_key", "provider_order"]) + elif factory == "MinerU": + api_key = apikey_json(["api_key", "provider_order"]) + llm = { "tenant_id": current_user.id, "llm_factory": factory, diff --git a/api/apps/sdk/doc.py b/api/apps/sdk/doc.py index e03e09957..a5c120d31 100644 --- a/api/apps/sdk/doc.py +++ b/api/apps/sdk/doc.py @@ -1549,11 +1549,11 @@ async def retrieval_test(tenant_id): rerank_mdl = LLMBundle(kb.tenant_id, LLMType.RERANK, llm_name=req["rerank_id"]) if langs: - question = cross_languages(kb.tenant_id, None, question, langs) + question = await cross_languages(kb.tenant_id, None, question, langs) if req.get("keyword", False): chat_mdl = LLMBundle(kb.tenant_id, LLMType.CHAT) - question += keyword_extraction(chat_mdl, question) + question += await keyword_extraction(chat_mdl, question) ranks = settings.retriever.retrieval( question, diff --git a/api/apps/sdk/files.py b/api/apps/sdk/files.py index 2e9fd6df3..8bac19ccd 100644 --- a/api/apps/sdk/files.py +++ b/api/apps/sdk/files.py @@ -33,6 +33,7 @@ from api.utils.web_utils import CONTENT_TYPE_MAP from common import settings from common.constants import RetCode + @manager.route('/file/upload', methods=['POST']) # noqa: F821 @token_required async def upload(tenant_id): diff --git a/api/apps/sdk/session.py b/api/apps/sdk/session.py index df4a3416f..224a27ccd 100644 --- a/api/apps/sdk/session.py +++ b/api/apps/sdk/session.py @@ -13,7 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# -import asyncio import json import re import time @@ -45,6 +44,7 @@ from rag.prompts.generator import cross_languages, keyword_extraction, chunks_fo from common.constants import RetCode, LLMType, StatusEnum from common import settings + @manager.route("/chats//sessions", methods=["POST"]) # noqa: F821 @token_required async def create(tenant_id, chat_id): @@ -970,7 +970,7 @@ async def retrieval_test_embedded(): if not tenant_id: return get_error_data_result(message="permission denined.") - def _retrieval_sync(): + async def _retrieval(): local_doc_ids = list(doc_ids) if doc_ids else [] tenant_ids = [] _question = question @@ -991,7 +991,6 @@ async def retrieval_test_embedded(): metas = DocumentService.get_meta_by_kbs(kb_ids) local_doc_ids = apply_meta_data_filter(meta_data_filter, metas, _question, chat_mdl, local_doc_ids) - tenants = UserTenantService.query(user_id=tenant_id) for kb_id in kb_ids: for tenant in tenants: @@ -1007,7 +1006,7 @@ async def retrieval_test_embedded(): return get_error_data_result(message="Knowledgebase not found!") if langs: - _question = cross_languages(kb.tenant_id, None, _question, langs) + _question = await cross_languages(kb.tenant_id, None, _question, langs) embd_mdl = LLMBundle(kb.tenant_id, LLMType.EMBEDDING.value, llm_name=kb.embd_id) @@ -1017,7 +1016,7 @@ async def retrieval_test_embedded(): if req.get("keyword", False): chat_mdl = LLMBundle(kb.tenant_id, LLMType.CHAT) - _question += keyword_extraction(chat_mdl, _question) + _question += await keyword_extraction(chat_mdl, _question) labels = label_question(_question, [kb]) ranks = settings.retriever.retrieval( @@ -1037,7 +1036,7 @@ async def retrieval_test_embedded(): return get_json_result(data=ranks) try: - return await asyncio.to_thread(_retrieval_sync) + return await _retrieval() except Exception as e: if str(e).find("not_found") > 0: return get_json_result(data=False, message="No chunk found! Check the chunk status please!", @@ -1138,7 +1137,7 @@ async def mindmap(): search_id = req.get("search_id", "") search_app = SearchService.get_detail(search_id) if search_id else {} - mind_map = gen_mindmap(req["question"], req["kb_ids"], tenant_id, search_app.get("search_config", {})) + mind_map =await gen_mindmap(req["question"], req["kb_ids"], tenant_id, search_app.get("search_config", {})) if "error" in mind_map: return server_error_response(Exception(mind_map["error"])) return get_json_result(data=mind_map) diff --git a/api/db/init_data.py b/api/db/init_data.py index 7454965eb..1ebc306d3 100644 --- a/api/db/init_data.py +++ b/api/db/init_data.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # +import asyncio import logging import json import os @@ -76,8 +77,7 @@ def init_superuser(nickname=DEFAULT_SUPERUSER_NICKNAME, email=DEFAULT_SUPERUSER_ f"Super user initialized. email: {email},A default password has been set; changing the password after login is strongly recommended.") chat_mdl = LLMBundle(tenant["id"], LLMType.CHAT, tenant["llm_id"]) - msg = chat_mdl.chat(system="", history=[ - {"role": "user", "content": "Hello!"}], gen_conf={}) + msg = asyncio.run(chat_mdl.async_chat(system="", history=[{"role": "user", "content": "Hello!"}], gen_conf={})) if msg.find("ERROR: ") == 0: logging.error( "'{}' doesn't work. 
{}".format( diff --git a/api/db/services/dialog_service.py b/api/db/services/dialog_service.py index 3eaa1cbe5..d5180e195 100644 --- a/api/db/services/dialog_service.py +++ b/api/db/services/dialog_service.py @@ -327,7 +327,7 @@ async def async_chat(dialog, messages, stream=True, **kwargs): # try to use sql if field mapping is good to go if field_map: logging.debug("Use SQL to retrieval:{}".format(questions[-1])) - ans = use_sql(questions[-1], field_map, dialog.tenant_id, chat_mdl, prompt_config.get("quote", True), dialog.kb_ids) + ans = await use_sql(questions[-1], field_map, dialog.tenant_id, chat_mdl, prompt_config.get("quote", True), dialog.kb_ids) if ans: yield ans return @@ -341,12 +341,12 @@ async def async_chat(dialog, messages, stream=True, **kwargs): prompt_config["system"] = prompt_config["system"].replace("{%s}" % p["key"], " ") if len(questions) > 1 and prompt_config.get("refine_multiturn"): - questions = [full_question(dialog.tenant_id, dialog.llm_id, messages)] + questions = [await full_question(dialog.tenant_id, dialog.llm_id, messages)] else: questions = questions[-1:] if prompt_config.get("cross_languages"): - questions = [cross_languages(dialog.tenant_id, dialog.llm_id, questions[0], prompt_config["cross_languages"])] + questions = [await cross_languages(dialog.tenant_id, dialog.llm_id, questions[0], prompt_config["cross_languages"])] if dialog.meta_data_filter: metas = DocumentService.get_meta_by_kbs(dialog.kb_ids) @@ -359,7 +359,7 @@ async def async_chat(dialog, messages, stream=True, **kwargs): ) if prompt_config.get("keyword", False): - questions[-1] += keyword_extraction(chat_mdl, questions[-1]) + questions[-1] += await keyword_extraction(chat_mdl, questions[-1]) refine_question_ts = timer() @@ -387,7 +387,7 @@ async def async_chat(dialog, messages, stream=True, **kwargs): ), ) - for think in reasoner.thinking(kbinfos, attachments_ + " ".join(questions)): + async for think in reasoner.thinking(kbinfos, attachments_ + " ".join(questions)): if isinstance(think, str): thought = think knowledges = [t for t in think.split("\n") if t] @@ -564,7 +564,7 @@ async def async_chat(dialog, messages, stream=True, **kwargs): return -def use_sql(question, field_map, tenant_id, chat_mdl, quota=True, kb_ids=None): +async def use_sql(question, field_map, tenant_id, chat_mdl, quota=True, kb_ids=None): sys_prompt = """ You are a Database Administrator. You need to check the fields of the following tables based on the user's list of questions and write the SQL corresponding to the last question. Ensure that: @@ -582,9 +582,9 @@ Please write the SQL, only SQL, without any other explanations or text. """.format(index_name(tenant_id), "\n".join([f"{k}: {v}" for k, v in field_map.items()]), question) tried_times = 0 - def get_table(): + async def get_table(): nonlocal sys_prompt, user_prompt, question, tried_times - sql = chat_mdl.chat(sys_prompt, [{"role": "user", "content": user_prompt}], {"temperature": 0.06}) + sql = await chat_mdl.async_chat(sys_prompt, [{"role": "user", "content": user_prompt}], {"temperature": 0.06}) sql = re.sub(r"^.*", "", sql, flags=re.DOTALL) logging.debug(f"{question} ==> {user_prompt} get SQL: {sql}") sql = re.sub(r"[\r\n]+", " ", sql.lower()) @@ -623,7 +623,7 @@ Please write the SQL, only SQL, without any other explanations or text. 
return settings.retriever.sql_retrieval(sql, format="json"), sql try: - tbl, sql = get_table() + tbl, sql = await get_table() except Exception as e: user_prompt = """ Table name: {}; @@ -641,7 +641,7 @@ Please write the SQL, only SQL, without any other explanations or text. Please correct the error and write SQL again, only SQL, without any other explanations or text. """.format(index_name(tenant_id), "\n".join([f"{k}: {v}" for k, v in field_map.items()]), question, e) try: - tbl, sql = get_table() + tbl, sql = await get_table() except Exception: return diff --git a/api/db/services/llm_service.py b/api/db/services/llm_service.py index e4bf64aac..e5505af88 100644 --- a/api/db/services/llm_service.py +++ b/api/db/services/llm_service.py @@ -318,9 +318,6 @@ class LLMBundle(LLM4Tenant): return value raise value - def chat(self, system: str, history: list, gen_conf: dict = {}, **kwargs) -> str: - return self._run_coroutine_sync(self.async_chat(system, history, gen_conf, **kwargs)) - def _sync_from_async_stream(self, async_gen_fn, *args, **kwargs): result_queue: queue.Queue = queue.Queue() @@ -350,23 +347,6 @@ class LLMBundle(LLM4Tenant): raise item yield item - def chat_streamly(self, system: str, history: list, gen_conf: dict = {}, **kwargs): - ans = "" - for txt in self._sync_from_async_stream(self.async_chat_streamly, system, history, gen_conf, **kwargs): - if isinstance(txt, int): - break - - if txt.endswith(""): - ans = txt[: -len("")] - continue - - if not self.verbose_tool_use: - txt = re.sub(r".*?", "", txt, flags=re.DOTALL) - - # cancatination has beend done in async_chat_streamly - ans = txt - yield ans - def _bridge_sync_stream(self, gen): loop = asyncio.get_running_loop() queue: asyncio.Queue = asyncio.Queue() diff --git a/api/db/services/tenant_llm_service.py b/api/db/services/tenant_llm_service.py index 84658d246..88689fdab 100644 --- a/api/db/services/tenant_llm_service.py +++ b/api/db/services/tenant_llm_service.py @@ -16,6 +16,7 @@ import os import json import logging +from peewee import IntegrityError from langfuse import Langfuse from common import settings from common.constants import MINERU_DEFAULT_CONFIG, MINERU_ENV_KEYS, LLMType @@ -274,21 +275,28 @@ class TenantLLMService(CommonService): used_names = {item.llm_name for item in saved_mineru_models} idx = 1 base_name = "mineru-from-env" - candidate = f"{base_name}-{idx}" - while candidate in used_names: - idx += 1 + while True: candidate = f"{base_name}-{idx}" + if candidate in used_names: + idx += 1 + continue - cls.save( - tenant_id=tenant_id, - llm_factory="MinerU", - llm_name=candidate, - model_type=LLMType.OCR.value, - api_key=json.dumps(cfg), - api_base="", - max_tokens=0, - ) - return candidate + try: + cls.save( + tenant_id=tenant_id, + llm_factory="MinerU", + llm_name=candidate, + model_type=LLMType.OCR.value, + api_key=json.dumps(cfg), + api_base="", + max_tokens=0, + ) + return candidate + except IntegrityError: + logging.warning("MinerU env model %s already exists for tenant %s, retry with next name", candidate, tenant_id) + used_names.add(candidate) + idx += 1 + continue @classmethod @DB.connection_context() diff --git a/common/http_client.py b/common/http_client.py index 5c57f8638..5c633d78d 100644 --- a/common/http_client.py +++ b/common/http_client.py @@ -18,6 +18,7 @@ import time from typing import Any, Dict, Optional from urllib.parse import parse_qsl, urlencode, urlparse, urlunparse +from common import settings import httpx logger = logging.getLogger(__name__) @@ -73,6 +74,34 @@ def 
_redact_sensitive_url_params(url: str) -> str: except Exception: return url +def _is_sensitive_url(url: str) -> bool: + """Return True if URL is one of the configured OAuth endpoints.""" + # Collect known sensitive endpoint URLs from settings + oauth_urls = set() + # GitHub OAuth endpoints + try: + if settings.GITHUB_OAUTH is not None: + url_val = settings.GITHUB_OAUTH.get("url") + if url_val: + oauth_urls.add(url_val) + except Exception: + pass + # Feishu OAuth endpoints + try: + if settings.FEISHU_OAUTH is not None: + for k in ("app_access_token_url", "user_access_token_url"): + url_val = settings.FEISHU_OAUTH.get(k) + if url_val: + oauth_urls.add(url_val) + except Exception: + pass + # Defensive normalization: compare only scheme+netloc+path + url_obj = urlparse(url) + for sensitive_url in oauth_urls: + sensitive_obj = urlparse(sensitive_url) + if (url_obj.scheme, url_obj.netloc, url_obj.path) == (sensitive_obj.scheme, sensitive_obj.netloc, sensitive_obj.path): + return True + return False async def async_request( method: str, @@ -115,20 +144,23 @@ async def async_request( method=method, url=url, headers=headers, **kwargs ) duration = time.monotonic() - start + log_url = "" if _is_sensitive_url(url) else _redact_sensitive_url_params(url) logger.debug( - f"async_request {method} {_redact_sensitive_url_params(url)} -> {response.status_code} in {duration:.3f}s" + f"async_request {method} {log_url} -> {response.status_code} in {duration:.3f}s" ) return response except httpx.RequestError as exc: last_exc = exc if attempt >= retries: + log_url = "" if _is_sensitive_url(url) else _redact_sensitive_url_params(url) logger.warning( - f"async_request exhausted retries for {method} {_redact_sensitive_url_params(url)}: {exc}" + f"async_request exhausted retries for {method} {log_url}" ) raise delay = _get_delay(backoff_factor, attempt) + log_url = "" if _is_sensitive_url(url) else _redact_sensitive_url_params(url) logger.warning( - f"async_request attempt {attempt + 1}/{retries + 1} failed for {method} {_redact_sensitive_url_params(url)}: {exc}; retrying in {delay:.2f}s" + f"async_request attempt {attempt + 1}/{retries + 1} failed for {method} {log_url}; retrying in {delay:.2f}s" ) await asyncio.sleep(delay) raise last_exc # pragma: no cover diff --git a/conf/llm_factories.json b/conf/llm_factories.json index 4d784d33c..d363e7f06 100644 --- a/conf/llm_factories.json +++ b/conf/llm_factories.json @@ -369,6 +369,13 @@ "model_type": "chat", "is_tools": true }, + { + "llm_name": "deepseek-v3.2", + "tags": "LLM,CHAT,128K", + "max_tokens": 128000, + "model_type": "chat", + "is_tools": true + }, { "llm_name": "deepseek-r1", "tags": "LLM,CHAT,64K", diff --git a/conf/service_conf.yaml b/conf/service_conf.yaml index 82f2e9248..afd9b98bc 100644 --- a/conf/service_conf.yaml +++ b/conf/service_conf.yaml @@ -17,6 +17,8 @@ minio: user: 'rag_flow' password: 'infini_rag_flow' host: 'localhost:9000' + bucket: '' + prefix_path: '' es: hosts: 'http://localhost:1200' username: 'elastic' diff --git a/deepdoc/parser/mineru_parser.py b/deepdoc/parser/mineru_parser.py index 57840ebb8..2883bf881 100644 --- a/deepdoc/parser/mineru_parser.py +++ b/deepdoc/parser/mineru_parser.py @@ -54,7 +54,7 @@ class MinerUContentType(StrEnum): class MinerUParser(RAGFlowPdfParser): - def __init__(self, mineru_path: str = "mineru", mineru_api: str = "http://host.docker.internal:9987", mineru_server_url: str = ""): + def __init__(self, mineru_path: str = "mineru", mineru_api: str = "", mineru_server_url: str = ""): self.mineru_path = Path(mineru_path) 
self.mineru_api = mineru_api.rstrip("/") self.mineru_server_url = mineru_server_url.rstrip("/") @@ -176,7 +176,9 @@ class MinerUParser(RAGFlowPdfParser): self.using_api = openapi_exists return openapi_exists, reason else: - self.logger.info("[MinerU] api not exists.") + reason = "[MinerU] api not exists. Setting MINERU_SERVER_URL if your backend is vlm-http-client." + self.logger.info(reason) + return False, reason except Exception as e: reason = f"[MinerU] Unexpected error during api check: {e}" self.logger.error(f"[MinerU] Unexpected error during api check: {e}") diff --git a/docker/.env b/docker/.env index 3d90d2c55..51d2cf73b 100644 --- a/docker/.env +++ b/docker/.env @@ -236,10 +236,11 @@ USE_DOCLING=false # Enable Mineru USE_MINERU=false MINERU_EXECUTABLE="$HOME/uv_tools/.venv/bin/mineru" -MINERU_DELETE_OUTPUT=0 # keep output directory -MINERU_BACKEND=pipeline # or another backend you prefer +# Uncommenting these lines will automatically add MinerU to the model provider whenever possible. +# MINERU_DELETE_OUTPUT=0 # keep output directory +# MINERU_BACKEND=pipeline # or another backend you prefer # pptx support -DOTNET_SYSTEM_GLOBALIZATION_INVARIANT=1 \ No newline at end of file +DOTNET_SYSTEM_GLOBALIZATION_INVARIANT=1 diff --git a/docker/.env.single-bucket-example b/docker/.env.single-bucket-example new file mode 100644 index 000000000..96120b388 --- /dev/null +++ b/docker/.env.single-bucket-example @@ -0,0 +1,108 @@ +# Example: Single Bucket Mode Configuration +# +# This file shows how to configure RAGFlow to use a single MinIO/S3 bucket +# with directory structure instead of creating multiple buckets. + +# ============================================================================ +# MinIO/S3 Configuration for Single Bucket Mode +# ============================================================================ + +# MinIO/S3 Endpoint (with port if not default) +# For HTTPS (port 443), the connection will automatically use secure=True +export MINIO_HOST=minio.example.com:443 + +# Access credentials +export MINIO_USER=your-access-key +export MINIO_PASSWORD=your-secret-password-here + +# Single Bucket Configuration (NEW!) +# If set, all data will be stored in this bucket instead of creating +# separate buckets for each knowledge base +export MINIO_BUCKET=ragflow-bucket + +# Optional: Prefix path within the bucket (NEW!) 
+# If set, all files will be stored under this prefix +# Example: bucket/prefix_path/kb_id/file.pdf +export MINIO_PREFIX_PATH=ragflow + +# ============================================================================ +# Alternative: Multi-Bucket Mode (Default) +# ============================================================================ +# +# To use the original multi-bucket mode, simply don't set MINIO_BUCKET +# and MINIO_PREFIX_PATH: +# +# export MINIO_HOST=minio.local +# export MINIO_USER=admin +# export MINIO_PASSWORD=password +# # MINIO_BUCKET not set +# # MINIO_PREFIX_PATH not set + +# ============================================================================ +# Storage Mode Selection (Environment Variable) +# ============================================================================ +# +# Make sure this is set to use MinIO (default) +export STORAGE_IMPL=MINIO + +# ============================================================================ +# Example Path Structures +# ============================================================================ +# +# Multi-Bucket Mode (default): +# bucket: kb_12345/file.pdf +# bucket: kb_67890/file.pdf +# bucket: folder_abc/file.txt +# +# Single Bucket Mode (MINIO_BUCKET set): +# bucket: ragflow-bucket/kb_12345/file.pdf +# bucket: ragflow-bucket/kb_67890/file.pdf +# bucket: ragflow-bucket/folder_abc/file.txt +# +# Single Bucket with Prefix (both set): +# bucket: ragflow-bucket/ragflow/kb_12345/file.pdf +# bucket: ragflow-bucket/ragflow/kb_67890/file.pdf +# bucket: ragflow-bucket/ragflow/folder_abc/file.txt + +# ============================================================================ +# IAM Policy for Single Bucket Mode +# ============================================================================ +# +# When using single bucket mode, you only need permissions for one bucket: +# +# { +# "Version": "2012-10-17", +# "Statement": [ +# { +# "Effect": "Allow", +# "Action": ["s3:*"], +# "Resource": [ +# "arn:aws:s3:::ragflow-bucket", +# "arn:aws:s3:::ragflow-bucket/*" +# ] +# } +# ] +# } + +# ============================================================================ +# Testing the Configuration +# ============================================================================ +# +# After setting these variables, you can test with MinIO Client (mc): +# +# # Configure mc alias +# mc alias set ragflow https://minio.example.com:443 \ +# your-access-key \ +# your-secret-password-here +# +# # List bucket contents +# mc ls ragflow/ragflow-bucket/ +# +# # If prefix is set, check the prefix +# mc ls ragflow/ragflow-bucket/ragflow/ +# +# # Test write permission +# echo "test" | mc pipe ragflow/ragflow-bucket/ragflow/_test.txt +# +# # Clean up test file +# mc rm ragflow/ragflow-bucket/ragflow/_test.txt diff --git a/docker/service_conf.yaml.template b/docker/service_conf.yaml.template index 72e7a6d73..1500c2eaf 100644 --- a/docker/service_conf.yaml.template +++ b/docker/service_conf.yaml.template @@ -17,6 +17,8 @@ minio: user: '${MINIO_USER:-rag_flow}' password: '${MINIO_PASSWORD:-infini_rag_flow}' host: '${MINIO_HOST:-minio}:9000' + bucket: '${MINIO_BUCKET:-}' + prefix_path: '${MINIO_PREFIX_PATH:-}' es: hosts: 'http://${ES_HOST:-es01}:9200' username: '${ES_USER:-elastic}' diff --git a/docs/single-bucket-mode.md b/docs/single-bucket-mode.md new file mode 100644 index 000000000..08179cfcf --- /dev/null +++ b/docs/single-bucket-mode.md @@ -0,0 +1,162 @@ +# Single Bucket Mode for MinIO/S3 + +## Overview + +By default, RAGFlow creates one bucket per Knowledge Base 
(dataset) and one bucket per user folder. This can be problematic when: + +- Your cloud provider charges per bucket +- Your IAM policy restricts bucket creation +- You want all data organized in a single bucket with directory structure + +The **Single Bucket Mode** allows you to configure RAGFlow to use a single bucket with a directory structure instead of multiple buckets. + +## How It Works + +### Default Mode (Multiple Buckets) + +``` +bucket: kb_12345/ + └── document_1.pdf +bucket: kb_67890/ + └── document_2.pdf +bucket: folder_abc/ + └── file_3.txt +``` + +### Single Bucket Mode (with prefix_path) + +``` +bucket: ragflow-bucket/ + └── ragflow/ + ├── kb_12345/ + │ └── document_1.pdf + ├── kb_67890/ + │ └── document_2.pdf + └── folder_abc/ + └── file_3.txt +``` + +## Configuration + +### MinIO Configuration + +Edit your `service_conf.yaml` or set environment variables: + +```yaml +minio: + user: "your-access-key" + password: "your-secret-key" + host: "minio.example.com:443" + bucket: "ragflow-bucket" # Default bucket name + prefix_path: "ragflow" # Optional prefix path +``` + +Or using environment variables: + +```bash +export MINIO_USER=your-access-key +export MINIO_PASSWORD=your-secret-key +export MINIO_HOST=minio.example.com:443 +export MINIO_BUCKET=ragflow-bucket +export MINIO_PREFIX_PATH=ragflow +``` + +### S3 Configuration (already supported) + +```yaml +s3: + access_key: "your-access-key" + secret_key: "your-secret-key" + endpoint_url: "https://s3.amazonaws.com" + bucket: "my-ragflow-bucket" + prefix_path: "production" + region: "us-east-1" +``` + +## IAM Policy Example + +When using single bucket mode, you only need permissions for one bucket: + +```json +{ + "Version": "2012-10-17", + "Statement": [ + { + "Effect": "Allow", + "Action": ["s3:*"], + "Resource": [ + "arn:aws:s3:::ragflow-bucket", + "arn:aws:s3:::ragflow-bucket/*" + ] + } + ] +} +``` + +## Migration from Multi-Bucket to Single Bucket + +If you're migrating from multi-bucket mode to single-bucket mode: + +1. **Set environment variables** for the new configuration +2. **Restart RAGFlow** services +3. **Migrate existing data** (optional): + +```bash +# Example using mc (MinIO Client) +mc alias set old-minio http://old-minio:9000 ACCESS_KEY SECRET_KEY +mc alias set new-minio https://new-minio:443 ACCESS_KEY SECRET_KEY + +# List all knowledge base buckets +mc ls old-minio/ | grep kb_ | while read -r line; do + bucket=$(echo $line | awk '{print $5}') + # Copy each bucket to the new structure + mc cp --recursive old-minio/$bucket/ new-minio/ragflow-bucket/ragflow/$bucket/ +done +``` + +## Toggle Between Modes + +### Enable Single Bucket Mode + +```yaml +minio: + bucket: "my-single-bucket" + prefix_path: "ragflow" +``` + +### Disable (Use Multi-Bucket Mode) + +```yaml +minio: + # Leave bucket and prefix_path empty or commented out + # bucket: '' + # prefix_path: '' +``` + +## Troubleshooting + +### Issue: Access Denied errors + +**Solution**: Ensure your IAM policy grants access to the bucket specified in the configuration. + +### Issue: Files not found after switching modes + +**Solution**: The path structure changes between modes. You'll need to migrate existing data. + +### Issue: Connection fails with HTTPS + +**Solution**: Ensure `secure: True` is set in the MinIO connection (automatically handled for port 443). 
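+
+## Object Key Layout (Illustrative)
+
+The sketch below is a minimal illustration of the path mapping described above. The helper `object_key` is hypothetical and is not part of RAGFlow's API (the real mapping is applied by decorators in `rag/utils/minio_conn.py`), but it shows how the logical bucket (a knowledge-base or folder id) becomes a directory inside the single configured bucket when single-bucket mode is enabled.
+
+```python
+def object_key(logical_bucket: str, filename: str,
+               default_bucket: str | None = None,
+               prefix_path: str | None = None) -> tuple[str, str]:
+    """Return (physical_bucket, key) for a stored file."""
+    if not default_bucket:
+        # Multi-bucket mode: one bucket per KB/folder, the key is just the file name.
+        return logical_bucket, filename
+    # Single-bucket mode: prepend the optional prefix and the logical bucket id.
+    parts = [p for p in (prefix_path, logical_bucket, filename) if p]
+    return default_bucket, "/".join(parts)
+
+print(object_key("kb_12345", "document_1.pdf"))
+# ('kb_12345', 'document_1.pdf')                        <- multi-bucket mode (default)
+print(object_key("kb_12345", "document_1.pdf", "ragflow-bucket", "ragflow"))
+# ('ragflow-bucket', 'ragflow/kb_12345/document_1.pdf') <- single bucket with prefix
+```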
+ +## Storage Backends Supported + +- ✅ **MinIO** - Full support with single bucket mode +- ✅ **AWS S3** - Full support with single bucket mode +- ✅ **Alibaba OSS** - Full support with single bucket mode +- ✅ **Azure Blob** - Uses container-based structure (different paradigm) +- ⚠️ **OpenDAL** - Depends on underlying storage backend + +## Performance Considerations + +- **Single bucket mode** may have slightly better performance for bucket listing operations +- **Multi-bucket mode** provides better isolation and organization for large deployments +- Choose based on your specific requirements and infrastructure constraints diff --git a/graphrag/search.py b/graphrag/search.py index 7399ea393..860c58906 100644 --- a/graphrag/search.py +++ b/graphrag/search.py @@ -13,7 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. # -import asyncio import json import logging from collections import defaultdict @@ -44,7 +43,7 @@ class KGSearch(Dealer): return response def query_rewrite(self, llm, question, idxnms, kb_ids): - ty2ents = asyncio.run(get_entity_type2samples(idxnms, kb_ids)) + ty2ents = get_entity_type2samples(idxnms, kb_ids) hint_prompt = PROMPTS["minirag_query2kwd"].format(query=question, TYPE_POOL=json.dumps(ty2ents, ensure_ascii=False, indent=2)) result = self._chat(llm, hint_prompt, [{"role": "user", "content": "Output:"}], {}) diff --git a/graphrag/utils.py b/graphrag/utils.py index 9b3dc2c2b..7e3fec1a9 100644 --- a/graphrag/utils.py +++ b/graphrag/utils.py @@ -626,8 +626,8 @@ def merge_tuples(list1, list2): return result -async def get_entity_type2samples(idxnms, kb_ids: list): - es_res = await asyncio.to_thread(settings.retriever.search,{"knowledge_graph_kwd": "ty2ents", "kb_id": kb_ids, "size": 10000, "fields": ["content_with_weight"]},idxnms,kb_ids) +def get_entity_type2samples(idxnms, kb_ids: list): + es_res = settings.retriever.search({"knowledge_graph_kwd": "ty2ents", "kb_id": kb_ids, "size": 10000, "fields": ["content_with_weight"]},idxnms,kb_ids) res = defaultdict(list) for id in es_res.ids: diff --git a/rag/app/naive.py b/rag/app/naive.py index 8315f801f..353504d77 100644 --- a/rag/app/naive.py +++ b/rag/app/naive.py @@ -68,7 +68,7 @@ def by_mineru(filename, binary=None, from_page=0, to_page=100000, lang="Chinese" from api.db.services.tenant_llm_service import TenantLLMService env_name = TenantLLMService.ensure_mineru_from_env(tenant_id) - candidates = TenantLLMService.query(tenant_id=tenant_id, llm_factory="MinerU", model_type=LLMType.OCR.value) + candidates = TenantLLMService.query(tenant_id=tenant_id, llm_factory="MinerU", model_type=LLMType.OCR) if candidates: mineru_llm_name = candidates[0].llm_name elif env_name: @@ -78,7 +78,7 @@ def by_mineru(filename, binary=None, from_page=0, to_page=100000, lang="Chinese" if mineru_llm_name: try: - ocr_model = LLMBundle(tenant_id, LLMType.OCR, llm_name=mineru_llm_name, lang=lang) + ocr_model = LLMBundle(tenant_id=tenant_id, llm_type=LLMType.OCR, llm_name=mineru_llm_name, lang=lang) pdf_parser = ocr_model.mdl sections, tables = pdf_parser.parse_pdf( filepath=filename, @@ -711,8 +711,8 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", ca layout_recognizer = layout_recognizer_raw if isinstance(layout_recognizer_raw, str): lowered = layout_recognizer_raw.lower() - if lowered.startswith("mineru@"): - parser_model_name = layout_recognizer_raw.split("@", 1)[1] + if lowered.endswith("@mineru"): + parser_model_name = layout_recognizer_raw.split("@", 
1)[0] layout_recognizer = "MinerU" if parser_config.get("analyze_hyperlink", False) and is_root: diff --git a/rag/flow/extractor/extractor.py b/rag/flow/extractor/extractor.py index 1b97fd1ee..8061086a7 100644 --- a/rag/flow/extractor/extractor.py +++ b/rag/flow/extractor/extractor.py @@ -98,7 +98,7 @@ class Extractor(ProcessBase, LLM): args[chunks_key] = ck["text"] msg, sys_prompt = self._sys_prompt_and_msg([], args) msg.insert(0, {"role": "system", "content": sys_prompt}) - ck[self._param.field_name] = self._generate(msg) + ck[self._param.field_name] = await self._generate_async(msg) prog += 1./len(chunks) if i % (len(chunks)//100+1) == 1: self.callback(prog, f"{i+1} / {len(chunks)}") @@ -106,6 +106,6 @@ class Extractor(ProcessBase, LLM): else: msg, sys_prompt = self._sys_prompt_and_msg([], args) msg.insert(0, {"role": "system", "content": sys_prompt}) - self.set_output("chunks", [{self._param.field_name: self._generate(msg)}]) + self.set_output("chunks", [{self._param.field_name: await self._generate_async(msg)}]) diff --git a/rag/flow/parser/parser.py b/rag/flow/parser/parser.py index 319a16d88..f32fb1719 100644 --- a/rag/flow/parser/parser.py +++ b/rag/flow/parser/parser.py @@ -240,10 +240,7 @@ class Parser(ProcessBase): parse_method = parse_method or "" if isinstance(raw_parse_method, str): lowered = raw_parse_method.lower() - if lowered.startswith("mineru@"): - parser_model_name = raw_parse_method.split("@", 1)[1] - parse_method = "MinerU" - elif lowered.endswith("@mineru"): + if lowered.endswith("@mineru"): parser_model_name = raw_parse_method.rsplit("@", 1)[0] parse_method = "MinerU" @@ -853,4 +850,4 @@ class Parser(ProcessBase): for t in tasks: t.cancel() await asyncio.gather(*tasks, return_exceptions=True) - raise \ No newline at end of file + raise diff --git a/rag/llm/ocr_model.py b/rag/llm/ocr_model.py index 183ef2041..b18a16a36 100644 --- a/rag/llm/ocr_model.py +++ b/rag/llm/ocr_model.py @@ -22,7 +22,7 @@ from deepdoc.parser.mineru_parser import MinerUParser class Base: - def __init__(self, key: str, model_name: str, **kwargs): + def __init__(self, key: str | dict, model_name: str, **kwargs): self.model_name = model_name def parse_pdf(self, filepath: str, binary=None, **kwargs) -> Tuple[Any, Any]: @@ -32,23 +32,23 @@ class Base: class MinerUOcrModel(Base, MinerUParser): _FACTORY_NAME = "MinerU" - def __init__(self, key: str, model_name: str, **kwargs): + def __init__(self, key: str | dict, model_name: str, **kwargs): Base.__init__(self, key, model_name, **kwargs) - cfg = {} + config = {} if key: try: - cfg = json.loads(key) + config = json.loads(key) except Exception: - cfg = {} - - self.mineru_api = cfg.get("MINERU_APISERVER", os.environ.get("MINERU_APISERVER", "http://host.docker.internal:9987")) - self.mineru_output_dir = cfg.get("MINERU_OUTPUT_DIR", os.environ.get("MINERU_OUTPUT_DIR", "")) - self.mineru_backend = cfg.get("MINERU_BACKEND", os.environ.get("MINERU_BACKEND", "pipeline")) - self.mineru_server_url = cfg.get("MINERU_SERVER_URL", os.environ.get("MINERU_SERVER_URL", "")) - self.mineru_delete_output = bool(int(cfg.get("MINERU_DELETE_OUTPUT", os.environ.get("MINERU_DELETE_OUTPUT", 1)))) + config = {} + config = config["api_key"] + self.mineru_api = config.get("mineru_apiserver", os.environ.get("MINERU_APISERVER", "")) + self.mineru_output_dir = config.get("mineru_output_dir", os.environ.get("MINERU_OUTPUT_DIR", "")) + self.mineru_backend = config.get("mineru_backend", os.environ.get("MINERU_BACKEND", "pipeline")) + self.mineru_server_url = 
config.get("mineru_server_url", os.environ.get("MINERU_SERVER_URL", "")) + self.mineru_delete_output = bool(int(config.get("mineru_delete_output", os.environ.get("MINERU_DELETE_OUTPUT", 1)))) self.mineru_executable = os.environ.get("MINERU_EXECUTABLE", "mineru") - logging.info(f"Parsered MinerU config: {cfg}") + logging.info(f"Parsed MinerU config: {config}") MinerUParser.__init__(self, mineru_path=self.mineru_executable, mineru_api=self.mineru_api, mineru_server_url=self.mineru_server_url) diff --git a/rag/nlp/rag_tokenizer.py b/rag/nlp/rag_tokenizer.py index c50e84ebc..494e1915b 100644 --- a/rag/nlp/rag_tokenizer.py +++ b/rag/nlp/rag_tokenizer.py @@ -33,6 +33,22 @@ class RagTokenizer(infinity.rag_tokenizer.RagTokenizer): return super().fine_grained_tokenize(tks) +def is_chinese(s): + return infinity.rag_tokenizer.is_chinese(s) + + +def is_number(s): + return infinity.rag_tokenizer.is_number(s) + + +def is_alphabet(s): + return infinity.rag_tokenizer.is_alphabet(s) + + +def naive_qie(txt): + return infinity.rag_tokenizer.naive_qie(txt) + + tokenizer = RagTokenizer() tokenize = tokenizer.tokenize fine_grained_tokenize = tokenizer.fine_grained_tokenize diff --git a/rag/nlp/search.py b/rag/nlp/search.py index f5dd2d4de..d2129e77f 100644 --- a/rag/nlp/search.py +++ b/rag/nlp/search.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # +import asyncio import json import logging import re @@ -607,7 +608,7 @@ class Dealer: if not toc: return chunks - ids = relevant_chunks_with_toc(query, toc, chat_mdl, topn*2) + ids = asyncio.run(relevant_chunks_with_toc(query, toc, chat_mdl, topn*2)) if not ids: return chunks diff --git a/rag/prompts/generator.py b/rag/prompts/generator.py index 523935277..621a460ad 100644 --- a/rag/prompts/generator.py +++ b/rag/prompts/generator.py @@ -170,13 +170,13 @@ def citation_plus(sources: str) -> str: return template.render(example=citation_prompt(), sources=sources) -def keyword_extraction(chat_mdl, content, topn=3): +async def keyword_extraction(chat_mdl, content, topn=3): template = PROMPT_JINJA_ENV.from_string(KEYWORD_PROMPT_TEMPLATE) rendered_prompt = template.render(content=content, topn=topn) msg = [{"role": "system", "content": rendered_prompt}, {"role": "user", "content": "Output: "}] _, msg = message_fit_in(msg, chat_mdl.max_length) - kwd = chat_mdl.chat(rendered_prompt, msg[1:], {"temperature": 0.2}) + kwd = await chat_mdl.async_chat(rendered_prompt, msg[1:], {"temperature": 0.2}) if isinstance(kwd, tuple): kwd = kwd[0] kwd = re.sub(r"^.*", "", kwd, flags=re.DOTALL) @@ -185,13 +185,13 @@ def keyword_extraction(chat_mdl, content, topn=3): return kwd -def question_proposal(chat_mdl, content, topn=3): +async def question_proposal(chat_mdl, content, topn=3): template = PROMPT_JINJA_ENV.from_string(QUESTION_PROMPT_TEMPLATE) rendered_prompt = template.render(content=content, topn=topn) msg = [{"role": "system", "content": rendered_prompt}, {"role": "user", "content": "Output: "}] _, msg = message_fit_in(msg, chat_mdl.max_length) - kwd = chat_mdl.chat(rendered_prompt, msg[1:], {"temperature": 0.2}) + kwd = await chat_mdl.async_chat(rendered_prompt, msg[1:], {"temperature": 0.2}) if isinstance(kwd, tuple): kwd = kwd[0] kwd = re.sub(r"^.*", "", kwd, flags=re.DOTALL) @@ -200,7 +200,7 @@ def question_proposal(chat_mdl, content, topn=3): return kwd -def full_question(tenant_id=None, llm_id=None, messages=[], language=None, chat_mdl=None): +async def full_question(tenant_id=None, llm_id=None, 
messages=[], language=None, chat_mdl=None): from common.constants import LLMType from api.db.services.llm_service import LLMBundle from api.db.services.tenant_llm_service import TenantLLMService @@ -229,12 +229,12 @@ def full_question(tenant_id=None, llm_id=None, messages=[], language=None, chat_ language=language, ) - ans = chat_mdl.chat(rendered_prompt, [{"role": "user", "content": "Output: "}]) + ans = await chat_mdl.async_chat(rendered_prompt, [{"role": "user", "content": "Output: "}]) ans = re.sub(r"^.*", "", ans, flags=re.DOTALL) return ans if ans.find("**ERROR**") < 0 else messages[-1]["content"] -def cross_languages(tenant_id, llm_id, query, languages=[]): +async def cross_languages(tenant_id, llm_id, query, languages=[]): from common.constants import LLMType from api.db.services.llm_service import LLMBundle from api.db.services.tenant_llm_service import TenantLLMService @@ -247,14 +247,14 @@ def cross_languages(tenant_id, llm_id, query, languages=[]): rendered_sys_prompt = PROMPT_JINJA_ENV.from_string(CROSS_LANGUAGES_SYS_PROMPT_TEMPLATE).render() rendered_user_prompt = PROMPT_JINJA_ENV.from_string(CROSS_LANGUAGES_USER_PROMPT_TEMPLATE).render(query=query, languages=languages) - ans = chat_mdl.chat(rendered_sys_prompt, [{"role": "user", "content": rendered_user_prompt}], {"temperature": 0.2}) + ans = await chat_mdl.async_chat(rendered_sys_prompt, [{"role": "user", "content": rendered_user_prompt}], {"temperature": 0.2}) ans = re.sub(r"^.*", "", ans, flags=re.DOTALL) if ans.find("**ERROR**") >= 0: return query return "\n".join([a for a in re.sub(r"(^Output:|\n+)", "", ans, flags=re.DOTALL).split("===") if a.strip()]) -def content_tagging(chat_mdl, content, all_tags, examples, topn=3): +async def content_tagging(chat_mdl, content, all_tags, examples, topn=3): template = PROMPT_JINJA_ENV.from_string(CONTENT_TAGGING_PROMPT_TEMPLATE) for ex in examples: @@ -269,7 +269,7 @@ def content_tagging(chat_mdl, content, all_tags, examples, topn=3): msg = [{"role": "system", "content": rendered_prompt}, {"role": "user", "content": "Output: "}] _, msg = message_fit_in(msg, chat_mdl.max_length) - kwd = chat_mdl.chat(rendered_prompt, msg[1:], {"temperature": 0.5}) + kwd = await chat_mdl.async_chat(rendered_prompt, msg[1:], {"temperature": 0.5}) if isinstance(kwd, tuple): kwd = kwd[0] kwd = re.sub(r"^.*", "", kwd, flags=re.DOTALL) @@ -352,7 +352,7 @@ async def analyze_task_async(chat_mdl, prompt, task_name, tools_description: lis else: template = PROMPT_JINJA_ENV.from_string(ANALYZE_TASK_SYSTEM + "\n\n" + ANALYZE_TASK_USER) context = template.render(task=task_name, context=context, agent_prompt=prompt, tools_desc=tools_desc) - kwd = await _chat_async(chat_mdl, context, [{"role": "user", "content": "Please analyze it."}]) + kwd = await chat_mdl.async_chat(context, [{"role": "user", "content": "Please analyze it."}]) if isinstance(kwd, tuple): kwd = kwd[0] kwd = re.sub(r"^.*", "", kwd, flags=re.DOTALL) @@ -361,14 +361,6 @@ async def analyze_task_async(chat_mdl, prompt, task_name, tools_description: lis return kwd -async def _chat_async(chat_mdl, system: str, history: list, **kwargs): - chat_async = getattr(chat_mdl, "async_chat", None) - if chat_async and asyncio.iscoroutinefunction(chat_async): - return await chat_async(system, history, **kwargs) - return await asyncio.to_thread(chat_mdl.chat, system, history, **kwargs) - - - async def next_step_async(chat_mdl, history:list, tools_description: list[dict], task_desc, user_defined_prompts: dict={}): if not tools_description: return "", 0 @@ -380,8 +372,7 
@@ async def next_step_async(chat_mdl, history:list, tools_description: list[dict], hist[-1]["content"] += user_prompt else: hist.append({"role": "user", "content": user_prompt}) - json_str = await _chat_async( - chat_mdl, + json_str = await chat_mdl.async_chat( template.render(task_analysis=task_desc, desc=desc, today=datetime.datetime.now().strftime("%Y-%m-%d")), hist[1:], stop=["<|stop|>"], @@ -402,7 +393,7 @@ async def reflect_async(chat_mdl, history: list[dict], tool_call_res: list[Tuple else: hist.append({"role": "user", "content": user_prompt}) _, msg = message_fit_in(hist, chat_mdl.max_length) - ans = await _chat_async(chat_mdl, msg[0]["content"], msg[1:]) + ans = await chat_mdl.async_chat(msg[0]["content"], msg[1:]) ans = re.sub(r"^.*", "", ans, flags=re.DOTALL) return """ **Observation** @@ -422,14 +413,14 @@ def structured_output_prompt(schema=None) -> str: return template.render(schema=schema) -def tool_call_summary(chat_mdl, name: str, params: dict, result: str, user_defined_prompts: dict={}) -> str: +async def tool_call_summary(chat_mdl, name: str, params: dict, result: str, user_defined_prompts: dict={}) -> str: template = PROMPT_JINJA_ENV.from_string(SUMMARY4MEMORY) system_prompt = template.render(name=name, params=json.dumps(params, ensure_ascii=False, indent=2), result=result) user_prompt = "→ Summary: " _, msg = message_fit_in(form_message(system_prompt, user_prompt), chat_mdl.max_length) - ans = chat_mdl.chat(msg[0]["content"], msg[1:]) + ans = await chat_mdl.async_chat(msg[0]["content"], msg[1:]) return re.sub(r"^.*", "", ans, flags=re.DOTALL) @@ -438,11 +429,11 @@ async def rank_memories_async(chat_mdl, goal:str, sub_goal:str, tool_call_summar system_prompt = template.render(goal=goal, sub_goal=sub_goal, results=[{"i": i, "content": s} for i,s in enumerate(tool_call_summaries)]) user_prompt = " → rank: " _, msg = message_fit_in(form_message(system_prompt, user_prompt), chat_mdl.max_length) - ans = await _chat_async(chat_mdl, msg[0]["content"], msg[1:], stop="<|stop|>") + ans = await chat_mdl.async_chat(msg[0]["content"], msg[1:], stop="<|stop|>") return re.sub(r"^.*", "", ans, flags=re.DOTALL) -def gen_meta_filter(chat_mdl, meta_data:dict, query: str) -> dict: +async def gen_meta_filter(chat_mdl, meta_data:dict, query: str) -> dict: meta_data_structure = {} for key, values in meta_data.items(): meta_data_structure[key] = list(values.keys()) if isinstance(values, dict) else values @@ -453,7 +444,7 @@ def gen_meta_filter(chat_mdl, meta_data:dict, query: str) -> dict: user_question=query ) user_prompt = "Generate filters:" - ans = chat_mdl.chat(sys_prompt, [{"role": "user", "content": user_prompt}]) + ans = await chat_mdl.async_chat(sys_prompt, [{"role": "user", "content": user_prompt}]) ans = re.sub(r"(^.*|```json\n|```\n*$)", "", ans, flags=re.DOTALL) try: ans = json_repair.loads(ans) @@ -466,13 +457,13 @@ def gen_meta_filter(chat_mdl, meta_data:dict, query: str) -> dict: return {"conditions": []} -def gen_json(system_prompt:str, user_prompt:str, chat_mdl, gen_conf = None): +async def gen_json(system_prompt:str, user_prompt:str, chat_mdl, gen_conf = None): from graphrag.utils import get_llm_cache, set_llm_cache cached = get_llm_cache(chat_mdl.llm_name, system_prompt, user_prompt, gen_conf) if cached: return json_repair.loads(cached) _, msg = message_fit_in(form_message(system_prompt, user_prompt), chat_mdl.max_length) - ans = chat_mdl.chat(msg[0]["content"], msg[1:],gen_conf=gen_conf) + ans = await chat_mdl.async_chat(msg[0]["content"], msg[1:],gen_conf=gen_conf) ans = 
re.sub(r"(^.*|```json\n|```\n*$)", "", ans, flags=re.DOTALL) try: res = json_repair.loads(ans) @@ -483,10 +474,10 @@ def gen_json(system_prompt:str, user_prompt:str, chat_mdl, gen_conf = None): TOC_DETECTION = load_prompt("toc_detection") -def detect_table_of_contents(page_1024:list[str], chat_mdl): +async def detect_table_of_contents(page_1024:list[str], chat_mdl): toc_secs = [] for i, sec in enumerate(page_1024[:22]): - ans = gen_json(PROMPT_JINJA_ENV.from_string(TOC_DETECTION).render(page_txt=sec), "Only JSON please.", chat_mdl) + ans = await gen_json(PROMPT_JINJA_ENV.from_string(TOC_DETECTION).render(page_txt=sec), "Only JSON please.", chat_mdl) if toc_secs and not ans["exists"]: break toc_secs.append(sec) @@ -495,14 +486,14 @@ def detect_table_of_contents(page_1024:list[str], chat_mdl): TOC_EXTRACTION = load_prompt("toc_extraction") TOC_EXTRACTION_CONTINUE = load_prompt("toc_extraction_continue") -def extract_table_of_contents(toc_pages, chat_mdl): +async def extract_table_of_contents(toc_pages, chat_mdl): if not toc_pages: return [] - return gen_json(PROMPT_JINJA_ENV.from_string(TOC_EXTRACTION).render(toc_page="\n".join(toc_pages)), "Only JSON please.", chat_mdl) + return await gen_json(PROMPT_JINJA_ENV.from_string(TOC_EXTRACTION).render(toc_page="\n".join(toc_pages)), "Only JSON please.", chat_mdl) -def toc_index_extractor(toc:list[dict], content:str, chat_mdl): +async def toc_index_extractor(toc:list[dict], content:str, chat_mdl): tob_extractor_prompt = """ You are given a table of contents in a json format and several pages of a document, your job is to add the physical_index to the table of contents in the json format. @@ -525,11 +516,11 @@ def toc_index_extractor(toc:list[dict], content:str, chat_mdl): Directly return the final JSON structure. Do not output anything else.""" prompt = tob_extractor_prompt + '\nTable of contents:\n' + json.dumps(toc, ensure_ascii=False, indent=2) + '\nDocument pages:\n' + content - return gen_json(prompt, "Only JSON please.", chat_mdl) + return await gen_json(prompt, "Only JSON please.", chat_mdl) TOC_INDEX = load_prompt("toc_index") -def table_of_contents_index(toc_arr: list[dict], sections: list[str], chat_mdl): +async def table_of_contents_index(toc_arr: list[dict], sections: list[str], chat_mdl): if not toc_arr or not sections: return [] @@ -601,7 +592,7 @@ def table_of_contents_index(toc_arr: list[dict], sections: list[str], chat_mdl): e = toc_arr[e]["indices"][0] for j in range(st_i, min(e+1, len(sections))): - ans = gen_json(PROMPT_JINJA_ENV.from_string(TOC_INDEX).render( + ans = await gen_json(PROMPT_JINJA_ENV.from_string(TOC_INDEX).render( structure=it["structure"], title=it["title"], text=sections[j]), "Only JSON please.", chat_mdl) @@ -614,7 +605,7 @@ def table_of_contents_index(toc_arr: list[dict], sections: list[str], chat_mdl): return toc_arr -def check_if_toc_transformation_is_complete(content, toc, chat_mdl): +async def check_if_toc_transformation_is_complete(content, toc, chat_mdl): prompt = """ You are given a raw table of contents and a table of contents. Your job is to check if the table of contents is complete. @@ -627,11 +618,11 @@ def check_if_toc_transformation_is_complete(content, toc, chat_mdl): Directly return the final JSON structure. 
Do not output anything else.""" prompt = prompt + '\n Raw Table of contents:\n' + content + '\n Cleaned Table of contents:\n' + toc - response = gen_json(prompt, "Only JSON please.", chat_mdl) + response = await gen_json(prompt, "Only JSON please.", chat_mdl) return response['completed'] -def toc_transformer(toc_pages, chat_mdl): +async def toc_transformer(toc_pages, chat_mdl): init_prompt = """ You are given a table of contents, You job is to transform the whole table of content into a JSON format included table_of_contents. @@ -654,8 +645,8 @@ def toc_transformer(toc_pages, chat_mdl): def clean_toc(arr): for a in arr: a["title"] = re.sub(r"[.·….]{2,}", "", a["title"]) - last_complete = gen_json(prompt, "Only JSON please.", chat_mdl) - if_complete = check_if_toc_transformation_is_complete(toc_content, json.dumps(last_complete, ensure_ascii=False, indent=2), chat_mdl) + last_complete = await gen_json(prompt, "Only JSON please.", chat_mdl) + if_complete = await check_if_toc_transformation_is_complete(toc_content, json.dumps(last_complete, ensure_ascii=False, indent=2), chat_mdl) clean_toc(last_complete) if if_complete == "yes": return last_complete @@ -672,21 +663,21 @@ def toc_transformer(toc_pages, chat_mdl): {json.dumps(last_complete[-24:], ensure_ascii=False, indent=2)} Please continue the json structure, directly output the remaining part of the json structure.""" - new_complete = gen_json(prompt, "Only JSON please.", chat_mdl) + new_complete = await gen_json(prompt, "Only JSON please.", chat_mdl) if not new_complete or str(last_complete).find(str(new_complete)) >= 0: break clean_toc(new_complete) last_complete.extend(new_complete) - if_complete = check_if_toc_transformation_is_complete(toc_content, json.dumps(last_complete, ensure_ascii=False, indent=2), chat_mdl) + if_complete = await check_if_toc_transformation_is_complete(toc_content, json.dumps(last_complete, ensure_ascii=False, indent=2), chat_mdl) return last_complete TOC_LEVELS = load_prompt("assign_toc_levels") -def assign_toc_levels(toc_secs, chat_mdl, gen_conf = {"temperature": 0.2}): +async def assign_toc_levels(toc_secs, chat_mdl, gen_conf = {"temperature": 0.2}): if not toc_secs: return [] - return gen_json( + return await gen_json( PROMPT_JINJA_ENV.from_string(TOC_LEVELS).render(), str(toc_secs), chat_mdl, @@ -699,7 +690,7 @@ TOC_FROM_TEXT_USER = load_prompt("toc_from_text_user") # Generate TOC from text chunks with text llms async def gen_toc_from_text(txt_info: dict, chat_mdl, callback=None): try: - ans = gen_json( + ans = await gen_json( PROMPT_JINJA_ENV.from_string(TOC_FROM_TEXT_SYSTEM).render(), PROMPT_JINJA_ENV.from_string(TOC_FROM_TEXT_USER).render(text="\n".join([json.dumps(d, ensure_ascii=False) for d in txt_info["chunks"]])), chat_mdl, @@ -782,7 +773,7 @@ async def run_toc_from_text(chunks, chat_mdl, callback=None): raw_structure = [x.get("title", "") for x in filtered] # Assign hierarchy levels using LLM - toc_with_levels = assign_toc_levels(raw_structure, chat_mdl, {"temperature": 0.0, "top_p": 0.9}) + toc_with_levels = await assign_toc_levels(raw_structure, chat_mdl, {"temperature": 0.0, "top_p": 0.9}) if not toc_with_levels: return [] @@ -807,10 +798,10 @@ async def run_toc_from_text(chunks, chat_mdl, callback=None): TOC_RELEVANCE_SYSTEM = load_prompt("toc_relevance_system") TOC_RELEVANCE_USER = load_prompt("toc_relevance_user") -def relevant_chunks_with_toc(query: str, toc:list[dict], chat_mdl, topn: int=6): +async def relevant_chunks_with_toc(query: str, toc:list[dict], chat_mdl, topn: int=6): import 
numpy as np try: - ans = gen_json( + ans = await gen_json( PROMPT_JINJA_ENV.from_string(TOC_RELEVANCE_SYSTEM).render(), PROMPT_JINJA_ENV.from_string(TOC_RELEVANCE_USER).render(query=query, toc_json="[\n%s\n]\n"%"\n".join([json.dumps({"level": d["level"], "title":d["title"]}, ensure_ascii=False) for d in toc])), chat_mdl, diff --git a/rag/svr/task_executor.py b/rag/svr/task_executor.py index 0094c081c..1a0c51600 100644 --- a/rag/svr/task_executor.py +++ b/rag/svr/task_executor.py @@ -323,12 +323,7 @@ async def build_chunks(task, progress_callback): cached = get_llm_cache(chat_mdl.llm_name, d["content_with_weight"], "keywords", {"topn": topn}) if not cached: async with chat_limiter: - cached = await asyncio.to_thread( - keyword_extraction, - chat_mdl, - d["content_with_weight"], - topn, - ) + cached = await keyword_extraction(chat_mdl, d["content_with_weight"], topn) set_llm_cache(chat_mdl.llm_name, d["content_with_weight"], cached, "keywords", {"topn": topn}) if cached: d["important_kwd"] = cached.split(",") @@ -356,12 +351,7 @@ async def build_chunks(task, progress_callback): cached = get_llm_cache(chat_mdl.llm_name, d["content_with_weight"], "question", {"topn": topn}) if not cached: async with chat_limiter: - cached = await asyncio.to_thread( - question_proposal, - chat_mdl, - d["content_with_weight"], - topn, - ) + cached = await question_proposal(chat_mdl, d["content_with_weight"], topn) set_llm_cache(chat_mdl.llm_name, d["content_with_weight"], cached, "question", {"topn": topn}) if cached: d["question_kwd"] = cached.split("\n") @@ -414,8 +404,7 @@ async def build_chunks(task, progress_callback): if not picked_examples: picked_examples.append({"content": "This is an example", TAG_FLD: {'example': 1}}) async with chat_limiter: - cached = await asyncio.to_thread( - content_tagging, + cached = await content_tagging( chat_mdl, d["content_with_weight"], all_tags, diff --git a/rag/utils/minio_conn.py b/rag/utils/minio_conn.py index e0913e98b..a81fb38ab 100644 --- a/rag/utils/minio_conn.py +++ b/rag/utils/minio_conn.py @@ -28,8 +28,51 @@ from common import settings class RAGFlowMinio: def __init__(self): self.conn = None + # Use `or None` to convert empty strings to None, ensuring single-bucket + # mode is truly disabled when not configured + self.bucket = settings.MINIO.get('bucket', None) or None + self.prefix_path = settings.MINIO.get('prefix_path', None) or None self.__open__() + @staticmethod + def use_default_bucket(method): + def wrapper(self, bucket, *args, **kwargs): + # If there is a default bucket, use the default bucket + # but preserve the original bucket identifier so it can be + # used as a path prefix inside the physical/default bucket. + original_bucket = bucket + actual_bucket = self.bucket if self.bucket else bucket + if self.bucket: + # pass original identifier forward for use by other decorators + kwargs['_orig_bucket'] = original_bucket + return method(self, actual_bucket, *args, **kwargs) + return wrapper + + @staticmethod + def use_prefix_path(method): + def wrapper(self, bucket, fnm, *args, **kwargs): + # If a default MINIO bucket is configured, the use_default_bucket + # decorator will have replaced the `bucket` arg with the physical + # bucket name and forwarded the original identifier as `_orig_bucket`. + # Prefer that original identifier when constructing the key path so + # objects are stored under //... 
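+ # Illustrative key layout (the concrete values below are assumptions for illustration only,
+ # e.g. a configured default bucket "ragflow-bucket", prefix_path "tenants",
+ # a logical bucket "kb-123" and an object "doc.pdf"):
+ #   prefix_path set   -> physical bucket "ragflow-bucket", key "tenants/kb-123/doc.pdf"
+ #   prefix_path unset -> physical bucket "ragflow-bucket", key "kb-123/doc.pdf"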
+ orig_bucket = kwargs.pop('_orig_bucket', None) + + if self.prefix_path: + # If a prefix_path is configured, include it and then the identifier + if orig_bucket: + fnm = f"{self.prefix_path}/{orig_bucket}/{fnm}" + else: + fnm = f"{self.prefix_path}/{fnm}" + else: + # No prefix_path configured. If orig_bucket exists and the + # physical bucket equals configured default, use orig_bucket as a path. + if orig_bucket and bucket == self.bucket: + fnm = f"{orig_bucket}/{fnm}" + + return method(self, bucket, fnm, *args, **kwargs) + return wrapper + def __open__(self): try: if self.conn: @@ -52,19 +95,27 @@ class RAGFlowMinio: self.conn = None def health(self): - bucket, fnm, binary = "txtxtxtxt1", "txtxtxtxt1", b"_t@@@1" - if not self.conn.bucket_exists(bucket): - self.conn.make_bucket(bucket) + bucket = self.bucket if self.bucket else "ragflow-bucket" + fnm = "_health_check" + if self.prefix_path: + fnm = f"{self.prefix_path}/{fnm}" + binary = b"_t@@@1" + # Don't try to create bucket - it should already exist + # if not self.conn.bucket_exists(bucket): + # self.conn.make_bucket(bucket) r = self.conn.put_object(bucket, fnm, BytesIO(binary), len(binary) ) return r + @use_default_bucket + @use_prefix_path def put(self, bucket, fnm, binary, tenant_id=None): for _ in range(3): try: - if not self.conn.bucket_exists(bucket): + # Note: bucket must already exist - we don't have permission to create buckets + if not self.bucket and not self.conn.bucket_exists(bucket): self.conn.make_bucket(bucket) r = self.conn.put_object(bucket, fnm, @@ -77,12 +128,16 @@ class RAGFlowMinio: self.__open__() time.sleep(1) + @use_default_bucket + @use_prefix_path def rm(self, bucket, fnm, tenant_id=None): try: self.conn.remove_object(bucket, fnm) except Exception: logging.exception(f"Fail to remove {bucket}/{fnm}:") + @use_default_bucket + @use_prefix_path def get(self, bucket, filename, tenant_id=None): for _ in range(1): try: @@ -92,8 +147,10 @@ class RAGFlowMinio: logging.exception(f"Fail to get {bucket}/{filename}") self.__open__() time.sleep(1) - return None + return + @use_default_bucket + @use_prefix_path def obj_exist(self, bucket, filename, tenant_id=None): try: if not self.conn.bucket_exists(bucket): @@ -109,6 +166,7 @@ class RAGFlowMinio: logging.exception(f"obj_exist {bucket}/{filename} got exception") return False + @use_default_bucket def bucket_exists(self, bucket): try: if not self.conn.bucket_exists(bucket): @@ -122,6 +180,8 @@ class RAGFlowMinio: logging.exception(f"bucket_exist {bucket} got exception") return False + @use_default_bucket + @use_prefix_path def get_presigned_url(self, bucket, fnm, expires, tenant_id=None): for _ in range(10): try: @@ -130,20 +190,50 @@ class RAGFlowMinio: logging.exception(f"Fail to get_presigned {bucket}/{fnm}:") self.__open__() time.sleep(1) - return None + return - def remove_bucket(self, bucket): + @use_default_bucket + def remove_bucket(self, bucket, **kwargs): + orig_bucket = kwargs.pop('_orig_bucket', None) try: - if self.conn.bucket_exists(bucket): - objects_to_delete = self.conn.list_objects(bucket, recursive=True) + if self.bucket: + # Single bucket mode: remove objects with prefix + prefix = "" + if self.prefix_path: + prefix = f"{self.prefix_path}/" + if orig_bucket: + prefix += f"{orig_bucket}/" + + # List objects with prefix + objects_to_delete = self.conn.list_objects(bucket, prefix=prefix, recursive=True) for obj in objects_to_delete: self.conn.remove_object(bucket, obj.object_name) - self.conn.remove_bucket(bucket) + # Do NOT remove the physical bucket + 
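+ # Example with assumed values prefix_path="tenants" and logical bucket "kb-123":
+ # only keys under "tenants/kb-123/" are deleted; the shared physical bucket is preserved.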
else: + if self.conn.bucket_exists(bucket): + objects_to_delete = self.conn.list_objects(bucket, recursive=True) + for obj in objects_to_delete: + self.conn.remove_object(bucket, obj.object_name) + self.conn.remove_bucket(bucket) except Exception: logging.exception(f"Fail to remove bucket {bucket}") + def _resolve_bucket_and_path(self, bucket, fnm): + if self.bucket: + if self.prefix_path: + fnm = f"{self.prefix_path}/{bucket}/{fnm}" + else: + fnm = f"{bucket}/{fnm}" + bucket = self.bucket + elif self.prefix_path: + fnm = f"{self.prefix_path}/{fnm}" + return bucket, fnm + def copy(self, src_bucket, src_path, dest_bucket, dest_path): try: + src_bucket, src_path = self._resolve_bucket_and_path(src_bucket, src_path) + dest_bucket, dest_path = self._resolve_bucket_and_path(dest_bucket, dest_path) + if not self.conn.bucket_exists(dest_bucket): self.conn.make_bucket(dest_bucket) diff --git a/rag/utils/opendal_conn.py b/rag/utils/opendal_conn.py index a260daebc..1f52f6f63 100644 --- a/rag/utils/opendal_conn.py +++ b/rag/utils/opendal_conn.py @@ -41,13 +41,9 @@ def get_opendal_config(): scheme = opendal_config.get("scheme") config_data = opendal_config.get("config", {}) kwargs = {"scheme": scheme, **config_data} - redacted_kwargs = kwargs.copy() - if 'password' in redacted_kwargs: - redacted_kwargs['password'] = '***REDACTED***' - if 'connection_string' in redacted_kwargs and 'password' in redacted_kwargs: - import re - redacted_kwargs['connection_string'] = re.sub(r':[^@]+@', ':***REDACTED***@', redacted_kwargs['connection_string']) - logging.info("Loaded OpenDAL configuration from yaml: %s", redacted_kwargs) + safe_log_keys=['scheme', 'host', 'port', 'database', 'table'] + loggable_kwargs = {k: v for k, v in kwargs.items() if k in safe_log_keys} + logging.info("Loaded OpenDAL configuration(non sensitive): %s", loggable_kwargs) return kwargs except Exception as e: logging.error("Failed to load OpenDAL configuration from yaml: %s", str(e)) diff --git a/web/src/assets/svg/llm/mineru-bright.svg b/web/src/assets/svg/llm/mineru-bright.svg new file mode 100644 index 000000000..7b4c3257b --- /dev/null +++ b/web/src/assets/svg/llm/mineru-bright.svg @@ -0,0 +1,22 @@ + + + + + + + + + + + + + + + + + + + + + + diff --git a/web/src/assets/svg/llm/mineru-dark.svg b/web/src/assets/svg/llm/mineru-dark.svg new file mode 100644 index 000000000..755fe0f3c --- /dev/null +++ b/web/src/assets/svg/llm/mineru-dark.svg @@ -0,0 +1,22 @@ + + + + + + + + + + + + + + + + + + + + + + diff --git a/web/src/components/layout-recognize-form-field.tsx b/web/src/components/layout-recognize-form-field.tsx index 2b2dc7eda..efaaba52f 100644 --- a/web/src/components/layout-recognize-form-field.tsx +++ b/web/src/components/layout-recognize-form-field.tsx @@ -17,7 +17,6 @@ import { export const enum ParseDocumentType { DeepDOC = 'DeepDOC', PlainText = 'Plain Text', - MinerU = 'MinerU', Docling = 'Docling', TCADPParser = 'TCADP Parser', } @@ -44,7 +43,6 @@ export function LayoutRecognizeFormField({ : [ ParseDocumentType.DeepDOC, ParseDocumentType.PlainText, - ParseDocumentType.MinerU, ParseDocumentType.Docling, ParseDocumentType.TCADPParser, ].map((x) => ({ @@ -52,7 +50,10 @@ export function LayoutRecognizeFormField({ value: x, })); - const image2TextList = allOptions[LlmModelType.Image2text].map((x) => { + const image2TextList = [ + ...allOptions[LlmModelType.Image2text], + ...allOptions[LlmModelType.Ocr], + ].map((x) => { return { ...x, options: x.options.map((y) => { diff --git a/web/src/components/svg-icon.tsx 
b/web/src/components/svg-icon.tsx index d21e55ace..756336621 100644 --- a/web/src/components/svg-icon.tsx +++ b/web/src/components/svg-icon.tsx @@ -69,6 +69,7 @@ export const LlmIcon = ({ LLMFactory.TogetherAI, LLMFactory.Meituan, LLMFactory.Longcat, + LLMFactory.MinerU, ]; let icon = useMemo(() => { const icontemp = IconMap[name as keyof typeof IconMap]; @@ -88,6 +89,7 @@ export const LlmIcon = ({ // LLMFactory.MiniMax, LLMFactory.Gemini, LLMFactory.StepFun, + LLMFactory.MinerU, // LLMFactory.DeerAPI, ]; if (svgIcons.includes(name as LLMFactory)) { diff --git a/web/src/constants/knowledge.ts b/web/src/constants/knowledge.ts index 130b7ed91..afd2e218b 100644 --- a/web/src/constants/knowledge.ts +++ b/web/src/constants/knowledge.ts @@ -62,6 +62,7 @@ export enum LlmModelType { Speech2text = 'speech2text', Rerank = 'rerank', TTS = 'tts', + Ocr = 'ocr', } export enum KnowledgeSearchParams { diff --git a/web/src/constants/llm.ts b/web/src/constants/llm.ts index a5f5e4b82..1ff5f5387 100644 --- a/web/src/constants/llm.ts +++ b/web/src/constants/llm.ts @@ -60,6 +60,7 @@ export enum LLMFactory { DeerAPI = 'DeerAPI', JiekouAI = 'Jiekou.AI', Builtin = 'Builtin', + MinerU = 'MinerU', } // Please lowercase the file name @@ -125,6 +126,7 @@ export const IconMap = { [LLMFactory.DeerAPI]: 'deerapi', [LLMFactory.JiekouAI]: 'jiekouai', [LLMFactory.Builtin]: 'builtin', + [LLMFactory.MinerU]: 'mineru', }; export const APIMapUrl = { diff --git a/web/src/hooks/use-llm-request.tsx b/web/src/hooks/use-llm-request.tsx index 3436b7506..cdd46c222 100644 --- a/web/src/hooks/use-llm-request.tsx +++ b/web/src/hooks/use-llm-request.tsx @@ -147,6 +147,7 @@ export const useSelectLlmOptionsByModelType = () => { ), [LlmModelType.Rerank]: groupOptionsByModelType(LlmModelType.Rerank), [LlmModelType.TTS]: groupOptionsByModelType(LlmModelType.TTS), + [LlmModelType.Ocr]: groupOptionsByModelType(LlmModelType.Ocr), }; }; @@ -245,7 +246,7 @@ export const useSelectLlmList = () => { name: key, logo: factoryList.find((x) => x.name === key)?.logo ?? '', ...value, - llm: value.llm.map((x) => ({ ...x, name: x.name })), + llm: value.llm?.map((x) => ({ ...x, name: x.name })), })); }, [myLlmList, factoryList]); diff --git a/web/src/interfaces/request/llm.ts b/web/src/interfaces/request/llm.ts index 05f8f470e..a5ca42fdc 100644 --- a/web/src/interfaces/request/llm.ts +++ b/web/src/interfaces/request/llm.ts @@ -3,7 +3,7 @@ export interface IAddLlmRequestBody { llm_name: string; model_type: string; api_base?: string; // chat|embedding|speech2text|image2text - api_key: string; + api_key: string | Record; max_tokens: number; } diff --git a/web/src/locales/en.ts b/web/src/locales/en.ts index 6d1049eda..00b8552ca 100644 --- a/web/src/locales/en.ts +++ b/web/src/locales/en.ts @@ -1064,6 +1064,21 @@ Example: Virtual Hosted Style`, modelsToBeAddedTooltip: 'If your model provider is not listed but claims to be "OpenAI-compatible", select the OpenAI-API-compatible card to add the relevant model(s). 
', mcp: 'MCP', + mineru: { + modelNameRequired: 'Model name is required', + apiserver: 'MinerU API Server Configuration', + outputDir: 'MinerU Output Directory Path', + backend: 'MinerU Processing Backend Type', + serverUrl: 'MinerU Server URL Address', + deleteOutput: 'Delete Output Files After Processing', + selectBackend: 'Select processing backend', + backendOptions: { + pipeline: 'Standard Pipeline Processing', + vlmTransformers: 'Vision Language Model with Transformers', + vlmVllmEngine: 'Vision Language Model with vLLM Engine', + vlmHttpClient: 'Vision Language Model via HTTP Client', + }, + }, }, message: { registered: 'Registered!', diff --git a/web/src/locales/zh.ts b/web/src/locales/zh.ts index c06ec2886..b09b6ca21 100644 --- a/web/src/locales/zh.ts +++ b/web/src/locales/zh.ts @@ -936,6 +936,21 @@ General:实体和关系提取提示来自 GitHub - microsoft/graphrag:基于 modelsToBeAddedTooltip: '如果你的模型供应商在这里没有列出,但是宣称 OpenAI-compatible,可以通过选择卡片 OpenAI-API-compatible 设置相关模型。', mcp: 'MCP', + mineru: { + modelNameRequired: '模型名称为必填项', + apiserver: 'MinerU API服务器配置', + outputDir: 'MinerU输出目录路径', + backend: 'MinerU处理后端类型', + serverUrl: 'MinerU服务器URL地址', + deleteOutput: '处理完成后删除输出文件', + selectBackend: '选择处理后端', + backendOptions: { + pipeline: '标准流水线处理', + vlmTransformers: '基于Transformers的视觉语言模型', + vlmVllmEngine: '基于vLLM引擎的视觉语言模型', + vlmHttpClient: '通过HTTP客户端连接的视觉语言模型', + }, + }, }, message: { registered: '注册成功', diff --git a/web/src/pages/user-setting/setting-model/components/modal-card.tsx b/web/src/pages/user-setting/setting-model/components/modal-card.tsx index 3eac19d92..70a232147 100644 --- a/web/src/pages/user-setting/setting-model/components/modal-card.tsx +++ b/web/src/pages/user-setting/setting-model/components/modal-card.tsx @@ -73,7 +73,7 @@ export const ModelProviderCard: FC = ({ {/* Header */}
- +
{item.name} diff --git a/web/src/pages/user-setting/setting-model/components/used-model.tsx b/web/src/pages/user-setting/setting-model/components/used-model.tsx index 3973f6fb6..91c7bc066 100644 --- a/web/src/pages/user-setting/setting-model/components/used-model.tsx +++ b/web/src/pages/user-setting/setting-model/components/used-model.tsx @@ -9,7 +9,7 @@ export const UsedModel = ({ handleAddModel: (factory: string) => void; handleEditModel: (model: any, factory: LlmItem) => void; }) => { - const { factoryList, myLlmList: llmList, loading } = useSelectLlmList(); + const { myLlmList: llmList } = useSelectLlmList(); return (
diff --git a/web/src/pages/user-setting/setting-model/hooks.tsx b/web/src/pages/user-setting/setting-model/hooks.tsx index 4d0708a42..9fc620d3d 100644 --- a/web/src/pages/user-setting/setting-model/hooks.tsx +++ b/web/src/pages/user-setting/setting-model/hooks.tsx @@ -1,3 +1,4 @@ +import { LLMFactory } from '@/constants/llm'; import { useSetModalState, useShowDeleteConfirm } from '@/hooks/common-hooks'; import { IApiKeySavingParams, @@ -16,6 +17,7 @@ import { getRealModelName } from '@/utils/llm-util'; import { useQueryClient } from '@tanstack/react-query'; import { useCallback, useState } from 'react'; import { ApiKeyPostBody } from '../interface'; +import { MinerUFormValues } from './modal/mineru-modal'; type SavingParamsState = Omit; @@ -459,3 +461,42 @@ export const useHandleDeleteFactory = (llmFactory: string) => { return { handleDeleteFactory, deleteFactory }; }; + +export const useSubmitMinerU = () => { + const { addLlm, loading } = useAddLlm(); + const { + visible: mineruVisible, + hideModal: hideMineruModal, + showModal: showMineruModal, + } = useSetModalState(); + + const onMineruOk = useCallback( + async (payload: MinerUFormValues) => { + const cfg = { + ...payload, + mineru_delete_output: payload.mineru_delete_output ?? true ? '1' : '0', + }; + const req: IAddLlmRequestBody = { + llm_factory: LLMFactory.MinerU, + llm_name: payload.llm_name, + model_type: 'ocr', + api_key: cfg, + api_base: '', + max_tokens: 0, + }; + const ret = await addLlm(req); + if (ret === 0) { + hideMineruModal(); + } + }, + [addLlm, hideMineruModal], + ); + + return { + mineruVisible, + hideMineruModal, + showMineruModal, + onMineruOk, + mineruLoading: loading, + }; +}; diff --git a/web/src/pages/user-setting/setting-model/index.tsx b/web/src/pages/user-setting/setting-model/index.tsx index 1e7086019..af7907bb0 100644 --- a/web/src/pages/user-setting/setting-model/index.tsx +++ b/web/src/pages/user-setting/setting-model/index.tsx @@ -13,6 +13,7 @@ import { useSubmitFishAudio, useSubmitGoogle, useSubmitHunyuan, + useSubmitMinerU, useSubmitOllama, useSubmitSpark, useSubmitSystemModelSetting, @@ -26,6 +27,7 @@ import BedrockModal from './modal/bedrock-modal'; import FishAudioModal from './modal/fish-audio-modal'; import GoogleModal from './modal/google-modal'; import HunyuanModal from './modal/hunyuan-modal'; +import MinerUModal from './modal/mineru-modal'; import TencentCloudModal from './modal/next-tencent-modal'; import OllamaModal from './modal/ollama-modal'; import SparkModal from './modal/spark-modal'; @@ -128,6 +130,14 @@ const ModelProviders = () => { AzureAddingLoading, } = useSubmitAzure(); + const { + mineruVisible, + hideMineruModal, + showMineruModal, + onMineruOk, + mineruLoading, + } = useSubmitMinerU(); + const ModalMap = useMemo( () => ({ [LLMFactory.Bedrock]: showBedrockAddingModal, @@ -139,17 +149,19 @@ const ModelProviders = () => { [LLMFactory.TencentCloud]: showTencentCloudAddingModal, [LLMFactory.GoogleCloud]: showGoogleAddingModal, [LLMFactory.AzureOpenAI]: showAzureAddingModal, + [LLMFactory.MinerU]: showMineruModal, }), [ showBedrockAddingModal, showVolcAddingModal, showHunyuanAddingModal, - showTencentCloudAddingModal, showSparkAddingModal, showyiyanAddingModal, showFishAudioAddingModal, + showTencentCloudAddingModal, showGoogleAddingModal, showAzureAddingModal, + showMineruModal, ], ); @@ -289,6 +301,12 @@ const ModelProviders = () => { loading={AzureAddingLoading} llmFactory={LLMFactory.AzureOpenAI} > +
); }; diff --git a/web/src/pages/user-setting/setting-model/modal/mineru-modal/index.tsx b/web/src/pages/user-setting/setting-model/modal/mineru-modal/index.tsx new file mode 100644 index 000000000..7833467db --- /dev/null +++ b/web/src/pages/user-setting/setting-model/modal/mineru-modal/index.tsx @@ -0,0 +1,148 @@ +import { RAGFlowFormItem } from '@/components/ragflow-form'; +import { ButtonLoading } from '@/components/ui/button'; +import { + Dialog, + DialogContent, + DialogFooter, + DialogHeader, + DialogTitle, +} from '@/components/ui/dialog'; +import { Form } from '@/components/ui/form'; +import { Input } from '@/components/ui/input'; +import { RAGFlowSelect } from '@/components/ui/select'; +import { Switch } from '@/components/ui/switch'; +import { LLMFactory } from '@/constants/llm'; +import { IModalProps } from '@/interfaces/common'; +import { buildOptions } from '@/utils/form'; +import { zodResolver } from '@hookform/resolvers/zod'; +import { t } from 'i18next'; +import { useForm } from 'react-hook-form'; +import { useTranslation } from 'react-i18next'; +import { z } from 'zod'; +import { LLMHeader } from '../../components/llm-header'; + +const FormSchema = z.object({ + llm_name: z.string().min(1, { + message: t('setting.mineru.modelNameRequired'), + }), + mineru_apiserver: z.string().optional(), + mineru_output_dir: z.string().optional(), + mineru_backend: z.enum([ + 'pipeline', + 'vlm-transformers', + 'vlm-vllm-engine', + 'vlm-http-client', + ]), + mineru_server_url: z.string().optional(), + mineru_delete_output: z.boolean(), +}); + +export type MinerUFormValues = z.infer<typeof FormSchema>; + +const MinerUModal = ({ + visible, + hideModal, + onOk, + loading, +}: IModalProps) => { + const { t } = useTranslation(); + + const backendOptions = buildOptions([ + 'pipeline', + 'vlm-transformers', + 'vlm-vllm-engine', + 'vlm-http-client', + ]); + + const form = useForm<MinerUFormValues>({ + resolver: zodResolver(FormSchema), + defaultValues: { + mineru_backend: 'pipeline', + mineru_delete_output: true, + }, + }); + + const handleOk = async (values: MinerUFormValues) => { + const ret = await onOk?.(values as any); + if (ret) { + hideModal?.(); + } + }; + + return ( + + + + + + + +
+ + + + + + + + + + + + {(field) => ( + + )} + + + + + + {(field) => ( + + )} + +
+ + + + {t('common.save', 'Save')} + + +
+
+ ); +}; + +export default MinerUModal;