Merge branch 'main' of github.com:infiniflow/ragflow into feature/1203
This commit is contained in:
commit
ec02428dbd
36 changed files with 604 additions and 979 deletions
|
|
@ -78,12 +78,12 @@ RUN --mount=type=cache,id=ragflow_apt,target=/var/cache/apt,sharing=locked \
|
||||||
# A modern version of cargo is needed for the latest version of the Rust compiler.
|
# A modern version of cargo is needed for the latest version of the Rust compiler.
|
||||||
RUN apt update && apt install -y curl build-essential \
|
RUN apt update && apt install -y curl build-essential \
|
||||||
&& if [ "$NEED_MIRROR" == "1" ]; then \
|
&& if [ "$NEED_MIRROR" == "1" ]; then \
|
||||||
# Use TUNA mirrors for rustup/rust dist files
|
# Use TUNA mirrors for rustup/rust dist files \
|
||||||
export RUSTUP_DIST_SERVER="https://mirrors.tuna.tsinghua.edu.cn/rustup"; \
|
export RUSTUP_DIST_SERVER="https://mirrors.tuna.tsinghua.edu.cn/rustup"; \
|
||||||
export RUSTUP_UPDATE_ROOT="https://mirrors.tuna.tsinghua.edu.cn/rustup/rustup"; \
|
export RUSTUP_UPDATE_ROOT="https://mirrors.tuna.tsinghua.edu.cn/rustup/rustup"; \
|
||||||
echo "Using TUNA mirrors for Rustup."; \
|
echo "Using TUNA mirrors for Rustup."; \
|
||||||
fi; \
|
fi; \
|
||||||
# Force curl to use HTTP/1.1
|
# Force curl to use HTTP/1.1 \
|
||||||
curl --proto '=https' --tlsv1.2 --http1.1 -sSf https://sh.rustup.rs | bash -s -- -y --profile minimal \
|
curl --proto '=https' --tlsv1.2 --http1.1 -sSf https://sh.rustup.rs | bash -s -- -y --profile minimal \
|
||||||
&& echo 'export PATH="/root/.cargo/bin:${PATH}"' >> /root/.bashrc
|
&& echo 'export PATH="/root/.cargo/bin:${PATH}"' >> /root/.bashrc
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -534,10 +534,12 @@ class Canvas(Graph):
|
||||||
yield decorate("message", {"content": cpn_obj.output("content")})
|
yield decorate("message", {"content": cpn_obj.output("content")})
|
||||||
cite = re.search(r"\[ID:[ 0-9]+\]", cpn_obj.output("content"))
|
cite = re.search(r"\[ID:[ 0-9]+\]", cpn_obj.output("content"))
|
||||||
|
|
||||||
if isinstance(cpn_obj.output("attachment"), tuple):
|
message_end = {}
|
||||||
yield decorate("message", {"attachment": cpn_obj.output("attachment")})
|
if isinstance(cpn_obj.output("attachment"), dict):
|
||||||
|
message_end["attachment"] = cpn_obj.output("attachment")
|
||||||
yield decorate("message_end", {"reference": self.get_reference() if cite else None})
|
if cite:
|
||||||
|
message_end["reference"] = self.get_reference()
|
||||||
|
yield decorate("message_end", message_end)
|
||||||
|
|
||||||
while partials:
|
while partials:
|
||||||
_cpn_obj = self.get_component_obj(partials[0])
|
_cpn_obj = self.get_component_obj(partials[0])
|
||||||
|
|
|
||||||
|
|
@ -18,6 +18,7 @@ import re
|
||||||
from functools import partial
|
from functools import partial
|
||||||
|
|
||||||
from agent.component.base import ComponentParamBase, ComponentBase
|
from agent.component.base import ComponentParamBase, ComponentBase
|
||||||
|
from api.db.services.file_service import FileService
|
||||||
|
|
||||||
|
|
||||||
class UserFillUpParam(ComponentParamBase):
|
class UserFillUpParam(ComponentParamBase):
|
||||||
|
|
@ -63,6 +64,13 @@ class UserFillUp(ComponentBase):
|
||||||
for k, v in kwargs.get("inputs", {}).items():
|
for k, v in kwargs.get("inputs", {}).items():
|
||||||
if self.check_if_canceled("UserFillUp processing"):
|
if self.check_if_canceled("UserFillUp processing"):
|
||||||
return
|
return
|
||||||
|
if isinstance(v, dict) and v.get("type", "").lower().find("file") >=0:
|
||||||
|
if v.get("optional") and v.get("value", None) is None:
|
||||||
|
v = None
|
||||||
|
else:
|
||||||
|
v = FileService.get_files([v["value"]])
|
||||||
|
else:
|
||||||
|
v = v.get("value")
|
||||||
self.set_output(k, v)
|
self.set_output(k, v)
|
||||||
|
|
||||||
def thoughts(self) -> str:
|
def thoughts(self) -> str:
|
||||||
|
|
|
||||||
|
|
@ -23,7 +23,7 @@ from quart import Response, request
|
||||||
from api.apps import current_user, login_required
|
from api.apps import current_user, login_required
|
||||||
from api.db.db_models import APIToken
|
from api.db.db_models import APIToken
|
||||||
from api.db.services.conversation_service import ConversationService, structure_answer
|
from api.db.services.conversation_service import ConversationService, structure_answer
|
||||||
from api.db.services.dialog_service import DialogService, ask, chat, gen_mindmap
|
from api.db.services.dialog_service import DialogService, async_ask, async_chat, gen_mindmap
|
||||||
from api.db.services.llm_service import LLMBundle
|
from api.db.services.llm_service import LLMBundle
|
||||||
from api.db.services.search_service import SearchService
|
from api.db.services.search_service import SearchService
|
||||||
from api.db.services.tenant_llm_service import TenantLLMService
|
from api.db.services.tenant_llm_service import TenantLLMService
|
||||||
|
|
@ -218,10 +218,10 @@ async def completion():
|
||||||
dia.llm_setting = chat_model_config
|
dia.llm_setting = chat_model_config
|
||||||
|
|
||||||
is_embedded = bool(chat_model_id)
|
is_embedded = bool(chat_model_id)
|
||||||
def stream():
|
async def stream():
|
||||||
nonlocal dia, msg, req, conv
|
nonlocal dia, msg, req, conv
|
||||||
try:
|
try:
|
||||||
for ans in chat(dia, msg, True, **req):
|
async for ans in async_chat(dia, msg, True, **req):
|
||||||
ans = structure_answer(conv, ans, message_id, conv.id)
|
ans = structure_answer(conv, ans, message_id, conv.id)
|
||||||
yield "data:" + json.dumps({"code": 0, "message": "", "data": ans}, ensure_ascii=False) + "\n\n"
|
yield "data:" + json.dumps({"code": 0, "message": "", "data": ans}, ensure_ascii=False) + "\n\n"
|
||||||
if not is_embedded:
|
if not is_embedded:
|
||||||
|
|
@ -241,7 +241,7 @@ async def completion():
|
||||||
|
|
||||||
else:
|
else:
|
||||||
answer = None
|
answer = None
|
||||||
for ans in chat(dia, msg, **req):
|
async for ans in async_chat(dia, msg, **req):
|
||||||
answer = structure_answer(conv, ans, message_id, conv.id)
|
answer = structure_answer(conv, ans, message_id, conv.id)
|
||||||
if not is_embedded:
|
if not is_embedded:
|
||||||
ConversationService.update_by_id(conv.id, conv.to_dict())
|
ConversationService.update_by_id(conv.id, conv.to_dict())
|
||||||
|
|
@ -406,10 +406,10 @@ async def ask_about():
|
||||||
if search_app:
|
if search_app:
|
||||||
search_config = search_app.get("search_config", {})
|
search_config = search_app.get("search_config", {})
|
||||||
|
|
||||||
def stream():
|
async def stream():
|
||||||
nonlocal req, uid
|
nonlocal req, uid
|
||||||
try:
|
try:
|
||||||
for ans in ask(req["question"], req["kb_ids"], uid, search_config=search_config):
|
async for ans in async_ask(req["question"], req["kb_ids"], uid, search_config=search_config):
|
||||||
yield "data:" + json.dumps({"code": 0, "message": "", "data": ans}, ensure_ascii=False) + "\n\n"
|
yield "data:" + json.dumps({"code": 0, "message": "", "data": ans}, ensure_ascii=False) + "\n\n"
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
yield "data:" + json.dumps({"code": 500, "message": str(e), "data": {"answer": "**ERROR**: " + str(e), "reference": []}}, ensure_ascii=False) + "\n\n"
|
yield "data:" + json.dumps({"code": 500, "message": str(e), "data": {"answer": "**ERROR**: " + str(e), "reference": []}}, ensure_ascii=False) + "\n\n"
|
||||||
|
|
|
||||||
|
|
@ -34,8 +34,9 @@ async def set_api_key():
|
||||||
if not all([secret_key, public_key, host]):
|
if not all([secret_key, public_key, host]):
|
||||||
return get_error_data_result(message="Missing required fields")
|
return get_error_data_result(message="Missing required fields")
|
||||||
|
|
||||||
|
current_user_id = current_user.id
|
||||||
langfuse_keys = dict(
|
langfuse_keys = dict(
|
||||||
tenant_id=current_user.id,
|
tenant_id=current_user_id,
|
||||||
secret_key=secret_key,
|
secret_key=secret_key,
|
||||||
public_key=public_key,
|
public_key=public_key,
|
||||||
host=host,
|
host=host,
|
||||||
|
|
@ -45,23 +46,24 @@ async def set_api_key():
|
||||||
if not langfuse.auth_check():
|
if not langfuse.auth_check():
|
||||||
return get_error_data_result(message="Invalid Langfuse keys")
|
return get_error_data_result(message="Invalid Langfuse keys")
|
||||||
|
|
||||||
langfuse_entry = TenantLangfuseService.filter_by_tenant(tenant_id=current_user.id)
|
langfuse_entry = TenantLangfuseService.filter_by_tenant(tenant_id=current_user_id)
|
||||||
with DB.atomic():
|
with DB.atomic():
|
||||||
try:
|
try:
|
||||||
if not langfuse_entry:
|
if not langfuse_entry:
|
||||||
TenantLangfuseService.save(**langfuse_keys)
|
TenantLangfuseService.save(**langfuse_keys)
|
||||||
else:
|
else:
|
||||||
TenantLangfuseService.update_by_tenant(tenant_id=current_user.id, langfuse_keys=langfuse_keys)
|
TenantLangfuseService.update_by_tenant(tenant_id=current_user_id, langfuse_keys=langfuse_keys)
|
||||||
return get_json_result(data=langfuse_keys)
|
return get_json_result(data=langfuse_keys)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
server_error_response(e)
|
return server_error_response(e)
|
||||||
|
|
||||||
|
|
||||||
@manager.route("/api_key", methods=["GET"]) # noqa: F821
|
@manager.route("/api_key", methods=["GET"]) # noqa: F821
|
||||||
@login_required
|
@login_required
|
||||||
@validate_request()
|
@validate_request()
|
||||||
def get_api_key():
|
def get_api_key():
|
||||||
langfuse_entry = TenantLangfuseService.filter_by_tenant_with_info(tenant_id=current_user.id)
|
current_user_id = current_user.id
|
||||||
|
langfuse_entry = TenantLangfuseService.filter_by_tenant_with_info(tenant_id=current_user_id)
|
||||||
if not langfuse_entry:
|
if not langfuse_entry:
|
||||||
return get_json_result(message="Have not record any Langfuse keys.")
|
return get_json_result(message="Have not record any Langfuse keys.")
|
||||||
|
|
||||||
|
|
@ -72,7 +74,7 @@ def get_api_key():
|
||||||
except langfuse.api.core.api_error.ApiError as api_err:
|
except langfuse.api.core.api_error.ApiError as api_err:
|
||||||
return get_json_result(message=f"Error from Langfuse: {api_err}")
|
return get_json_result(message=f"Error from Langfuse: {api_err}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
server_error_response(e)
|
return server_error_response(e)
|
||||||
|
|
||||||
langfuse_entry["project_id"] = langfuse.api.projects.get().dict()["data"][0]["id"]
|
langfuse_entry["project_id"] = langfuse.api.projects.get().dict()["data"][0]["id"]
|
||||||
langfuse_entry["project_name"] = langfuse.api.projects.get().dict()["data"][0]["name"]
|
langfuse_entry["project_name"] = langfuse.api.projects.get().dict()["data"][0]["name"]
|
||||||
|
|
@ -84,7 +86,8 @@ def get_api_key():
|
||||||
@login_required
|
@login_required
|
||||||
@validate_request()
|
@validate_request()
|
||||||
def delete_api_key():
|
def delete_api_key():
|
||||||
langfuse_entry = TenantLangfuseService.filter_by_tenant(tenant_id=current_user.id)
|
current_user_id = current_user.id
|
||||||
|
langfuse_entry = TenantLangfuseService.filter_by_tenant(tenant_id=current_user_id)
|
||||||
if not langfuse_entry:
|
if not langfuse_entry:
|
||||||
return get_json_result(message="Have not record any Langfuse keys.")
|
return get_json_result(message="Have not record any Langfuse keys.")
|
||||||
|
|
||||||
|
|
@ -93,4 +96,4 @@ def delete_api_key():
|
||||||
TenantLangfuseService.delete_model(langfuse_entry)
|
TenantLangfuseService.delete_model(langfuse_entry)
|
||||||
return get_json_result(data=True)
|
return get_json_result(data=True)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
server_error_response(e)
|
return server_error_response(e)
|
||||||
|
|
|
||||||
|
|
@ -74,7 +74,7 @@ async def set_api_key():
|
||||||
assert factory in ChatModel, f"Chat model from {factory} is not supported yet."
|
assert factory in ChatModel, f"Chat model from {factory} is not supported yet."
|
||||||
mdl = ChatModel[factory](req["api_key"], llm.llm_name, base_url=req.get("base_url"), **extra)
|
mdl = ChatModel[factory](req["api_key"], llm.llm_name, base_url=req.get("base_url"), **extra)
|
||||||
try:
|
try:
|
||||||
m, tc = mdl.chat(None, [{"role": "user", "content": "Hello! How are you doing!"}], {"temperature": 0.9, "max_tokens": 50})
|
m, tc = await mdl.async_chat(None, [{"role": "user", "content": "Hello! How are you doing!"}], {"temperature": 0.9, "max_tokens": 50})
|
||||||
if m.find("**ERROR**") >= 0:
|
if m.find("**ERROR**") >= 0:
|
||||||
raise Exception(m)
|
raise Exception(m)
|
||||||
chat_passed = True
|
chat_passed = True
|
||||||
|
|
@ -217,7 +217,7 @@ async def add_llm():
|
||||||
**extra,
|
**extra,
|
||||||
)
|
)
|
||||||
try:
|
try:
|
||||||
m, tc = mdl.chat(None, [{"role": "user", "content": "Hello! How are you doing!"}], {"temperature": 0.9})
|
m, tc = await mdl.async_chat(None, [{"role": "user", "content": "Hello! How are you doing!"}], {"temperature": 0.9})
|
||||||
if not tc and m.find("**ERROR**:") >= 0:
|
if not tc and m.find("**ERROR**:") >= 0:
|
||||||
raise Exception(m)
|
raise Exception(m)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|
|
||||||
|
|
@ -14,7 +14,7 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
#
|
#
|
||||||
|
|
||||||
|
import asyncio
|
||||||
import pathlib
|
import pathlib
|
||||||
import re
|
import re
|
||||||
from quart import request, make_response
|
from quart import request, make_response
|
||||||
|
|
@ -29,6 +29,7 @@ from api.db import FileType
|
||||||
from api.db.services import duplicate_name
|
from api.db.services import duplicate_name
|
||||||
from api.db.services.file_service import FileService
|
from api.db.services.file_service import FileService
|
||||||
from api.utils.file_utils import filename_type
|
from api.utils.file_utils import filename_type
|
||||||
|
from api.utils.web_utils import CONTENT_TYPE_MAP
|
||||||
from common import settings
|
from common import settings
|
||||||
from common.constants import RetCode
|
from common.constants import RetCode
|
||||||
|
|
||||||
|
|
@ -629,6 +630,19 @@ async def get(tenant_id, file_id):
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return server_error_response(e)
|
return server_error_response(e)
|
||||||
|
|
||||||
|
@manager.route("/file/download/<attachment_id>", methods=["GET"]) # noqa: F821
|
||||||
|
@token_required
|
||||||
|
async def download_attachment(tenant_id,attachment_id):
|
||||||
|
try:
|
||||||
|
ext = request.args.get("ext", "markdown")
|
||||||
|
data = await asyncio.to_thread(settings.STORAGE_IMPL.get, tenant_id, attachment_id)
|
||||||
|
response = await make_response(data)
|
||||||
|
response.headers.set("Content-Type", CONTENT_TYPE_MAP.get(ext, f"application/{ext}"))
|
||||||
|
|
||||||
|
return response
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
return server_error_response(e)
|
||||||
|
|
||||||
@manager.route('/file/mv', methods=['POST']) # noqa: F821
|
@manager.route('/file/mv', methods=['POST']) # noqa: F821
|
||||||
@token_required
|
@token_required
|
||||||
|
|
|
||||||
|
|
@ -26,9 +26,10 @@ from api.db.db_models import APIToken
|
||||||
from api.db.services.api_service import API4ConversationService
|
from api.db.services.api_service import API4ConversationService
|
||||||
from api.db.services.canvas_service import UserCanvasService, completion_openai
|
from api.db.services.canvas_service import UserCanvasService, completion_openai
|
||||||
from api.db.services.canvas_service import completion as agent_completion
|
from api.db.services.canvas_service import completion as agent_completion
|
||||||
from api.db.services.conversation_service import ConversationService, iframe_completion
|
from api.db.services.conversation_service import ConversationService
|
||||||
from api.db.services.conversation_service import completion as rag_completion
|
from api.db.services.conversation_service import async_iframe_completion as iframe_completion
|
||||||
from api.db.services.dialog_service import DialogService, ask, chat, gen_mindmap, meta_filter
|
from api.db.services.conversation_service import async_completion as rag_completion
|
||||||
|
from api.db.services.dialog_service import DialogService, async_ask, async_chat, gen_mindmap, meta_filter
|
||||||
from api.db.services.document_service import DocumentService
|
from api.db.services.document_service import DocumentService
|
||||||
from api.db.services.knowledgebase_service import KnowledgebaseService
|
from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||||
from api.db.services.llm_service import LLMBundle
|
from api.db.services.llm_service import LLMBundle
|
||||||
|
|
@ -141,7 +142,7 @@ async def chat_completion(tenant_id, chat_id):
|
||||||
return resp
|
return resp
|
||||||
else:
|
else:
|
||||||
answer = None
|
answer = None
|
||||||
for ans in rag_completion(tenant_id, chat_id, **req):
|
async for ans in rag_completion(tenant_id, chat_id, **req):
|
||||||
answer = ans
|
answer = ans
|
||||||
break
|
break
|
||||||
return get_result(data=answer)
|
return get_result(data=answer)
|
||||||
|
|
@ -245,7 +246,7 @@ async def chat_completion_openai_like(tenant_id, chat_id):
|
||||||
# The value for the usage field on all chunks except for the last one will be null.
|
# The value for the usage field on all chunks except for the last one will be null.
|
||||||
# The usage field on the last chunk contains token usage statistics for the entire request.
|
# The usage field on the last chunk contains token usage statistics for the entire request.
|
||||||
# The choices field on the last chunk will always be an empty array [].
|
# The choices field on the last chunk will always be an empty array [].
|
||||||
def streamed_response_generator(chat_id, dia, msg):
|
async def streamed_response_generator(chat_id, dia, msg):
|
||||||
token_used = 0
|
token_used = 0
|
||||||
answer_cache = ""
|
answer_cache = ""
|
||||||
reasoning_cache = ""
|
reasoning_cache = ""
|
||||||
|
|
@ -274,7 +275,7 @@ async def chat_completion_openai_like(tenant_id, chat_id):
|
||||||
}
|
}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
for ans in chat(dia, msg, True, toolcall_session=toolcall_session, tools=tools, quote=need_reference):
|
async for ans in async_chat(dia, msg, True, toolcall_session=toolcall_session, tools=tools, quote=need_reference):
|
||||||
last_ans = ans
|
last_ans = ans
|
||||||
answer = ans["answer"]
|
answer = ans["answer"]
|
||||||
|
|
||||||
|
|
@ -342,7 +343,7 @@ async def chat_completion_openai_like(tenant_id, chat_id):
|
||||||
return resp
|
return resp
|
||||||
else:
|
else:
|
||||||
answer = None
|
answer = None
|
||||||
for ans in chat(dia, msg, False, toolcall_session=toolcall_session, tools=tools, quote=need_reference):
|
async for ans in async_chat(dia, msg, False, toolcall_session=toolcall_session, tools=tools, quote=need_reference):
|
||||||
# focus answer content only
|
# focus answer content only
|
||||||
answer = ans
|
answer = ans
|
||||||
break
|
break
|
||||||
|
|
@ -733,10 +734,10 @@ async def ask_about(tenant_id):
|
||||||
return get_error_data_result(f"The dataset {kb_id} doesn't own parsed file")
|
return get_error_data_result(f"The dataset {kb_id} doesn't own parsed file")
|
||||||
uid = tenant_id
|
uid = tenant_id
|
||||||
|
|
||||||
def stream():
|
async def stream():
|
||||||
nonlocal req, uid
|
nonlocal req, uid
|
||||||
try:
|
try:
|
||||||
for ans in ask(req["question"], req["kb_ids"], uid):
|
async for ans in async_ask(req["question"], req["kb_ids"], uid):
|
||||||
yield "data:" + json.dumps({"code": 0, "message": "", "data": ans}, ensure_ascii=False) + "\n\n"
|
yield "data:" + json.dumps({"code": 0, "message": "", "data": ans}, ensure_ascii=False) + "\n\n"
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
yield "data:" + json.dumps(
|
yield "data:" + json.dumps(
|
||||||
|
|
@ -827,7 +828,7 @@ async def chatbot_completions(dialog_id):
|
||||||
resp.headers.add_header("Content-Type", "text/event-stream; charset=utf-8")
|
resp.headers.add_header("Content-Type", "text/event-stream; charset=utf-8")
|
||||||
return resp
|
return resp
|
||||||
|
|
||||||
for answer in iframe_completion(dialog_id, **req):
|
async for answer in iframe_completion(dialog_id, **req):
|
||||||
return get_result(data=answer)
|
return get_result(data=answer)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -918,10 +919,10 @@ async def ask_about_embedded():
|
||||||
if search_app := SearchService.get_detail(search_id):
|
if search_app := SearchService.get_detail(search_id):
|
||||||
search_config = search_app.get("search_config", {})
|
search_config = search_app.get("search_config", {})
|
||||||
|
|
||||||
def stream():
|
async def stream():
|
||||||
nonlocal req, uid
|
nonlocal req, uid
|
||||||
try:
|
try:
|
||||||
for ans in ask(req["question"], req["kb_ids"], uid, search_config=search_config):
|
async for ans in async_ask(req["question"], req["kb_ids"], uid, search_config=search_config):
|
||||||
yield "data:" + json.dumps({"code": 0, "message": "", "data": ans}, ensure_ascii=False) + "\n\n"
|
yield "data:" + json.dumps({"code": 0, "message": "", "data": ans}, ensure_ascii=False) + "\n\n"
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
yield "data:" + json.dumps(
|
yield "data:" + json.dumps(
|
||||||
|
|
|
||||||
|
|
@ -19,7 +19,7 @@ from common.constants import StatusEnum
|
||||||
from api.db.db_models import Conversation, DB
|
from api.db.db_models import Conversation, DB
|
||||||
from api.db.services.api_service import API4ConversationService
|
from api.db.services.api_service import API4ConversationService
|
||||||
from api.db.services.common_service import CommonService
|
from api.db.services.common_service import CommonService
|
||||||
from api.db.services.dialog_service import DialogService, chat
|
from api.db.services.dialog_service import DialogService, async_chat
|
||||||
from common.misc_utils import get_uuid
|
from common.misc_utils import get_uuid
|
||||||
import json
|
import json
|
||||||
|
|
||||||
|
|
@ -89,8 +89,7 @@ def structure_answer(conv, ans, message_id, session_id):
|
||||||
conv.reference[-1] = reference
|
conv.reference[-1] = reference
|
||||||
return ans
|
return ans
|
||||||
|
|
||||||
|
async def async_completion(tenant_id, chat_id, question, name="New session", session_id=None, stream=True, **kwargs):
|
||||||
def completion(tenant_id, chat_id, question, name="New session", session_id=None, stream=True, **kwargs):
|
|
||||||
assert name, "`name` can not be empty."
|
assert name, "`name` can not be empty."
|
||||||
dia = DialogService.query(id=chat_id, tenant_id=tenant_id, status=StatusEnum.VALID.value)
|
dia = DialogService.query(id=chat_id, tenant_id=tenant_id, status=StatusEnum.VALID.value)
|
||||||
assert dia, "You do not own the chat."
|
assert dia, "You do not own the chat."
|
||||||
|
|
@ -148,7 +147,7 @@ def completion(tenant_id, chat_id, question, name="New session", session_id=None
|
||||||
|
|
||||||
if stream:
|
if stream:
|
||||||
try:
|
try:
|
||||||
for ans in chat(dia, msg, True, **kwargs):
|
async for ans in async_chat(dia, msg, True, **kwargs):
|
||||||
ans = structure_answer(conv, ans, message_id, session_id)
|
ans = structure_answer(conv, ans, message_id, session_id)
|
||||||
yield "data:" + json.dumps({"code": 0, "data": ans}, ensure_ascii=False) + "\n\n"
|
yield "data:" + json.dumps({"code": 0, "data": ans}, ensure_ascii=False) + "\n\n"
|
||||||
ConversationService.update_by_id(conv.id, conv.to_dict())
|
ConversationService.update_by_id(conv.id, conv.to_dict())
|
||||||
|
|
@ -160,14 +159,13 @@ def completion(tenant_id, chat_id, question, name="New session", session_id=None
|
||||||
|
|
||||||
else:
|
else:
|
||||||
answer = None
|
answer = None
|
||||||
for ans in chat(dia, msg, False, **kwargs):
|
async for ans in async_chat(dia, msg, False, **kwargs):
|
||||||
answer = structure_answer(conv, ans, message_id, session_id)
|
answer = structure_answer(conv, ans, message_id, session_id)
|
||||||
ConversationService.update_by_id(conv.id, conv.to_dict())
|
ConversationService.update_by_id(conv.id, conv.to_dict())
|
||||||
break
|
break
|
||||||
yield answer
|
yield answer
|
||||||
|
|
||||||
|
async def async_iframe_completion(dialog_id, question, session_id=None, stream=True, **kwargs):
|
||||||
def iframe_completion(dialog_id, question, session_id=None, stream=True, **kwargs):
|
|
||||||
e, dia = DialogService.get_by_id(dialog_id)
|
e, dia = DialogService.get_by_id(dialog_id)
|
||||||
assert e, "Dialog not found"
|
assert e, "Dialog not found"
|
||||||
if not session_id:
|
if not session_id:
|
||||||
|
|
@ -222,7 +220,7 @@ def iframe_completion(dialog_id, question, session_id=None, stream=True, **kwarg
|
||||||
|
|
||||||
if stream:
|
if stream:
|
||||||
try:
|
try:
|
||||||
for ans in chat(dia, msg, True, **kwargs):
|
async for ans in async_chat(dia, msg, True, **kwargs):
|
||||||
ans = structure_answer(conv, ans, message_id, session_id)
|
ans = structure_answer(conv, ans, message_id, session_id)
|
||||||
yield "data:" + json.dumps({"code": 0, "message": "", "data": ans},
|
yield "data:" + json.dumps({"code": 0, "message": "", "data": ans},
|
||||||
ensure_ascii=False) + "\n\n"
|
ensure_ascii=False) + "\n\n"
|
||||||
|
|
@ -235,7 +233,7 @@ def iframe_completion(dialog_id, question, session_id=None, stream=True, **kwarg
|
||||||
|
|
||||||
else:
|
else:
|
||||||
answer = None
|
answer = None
|
||||||
for ans in chat(dia, msg, False, **kwargs):
|
async for ans in async_chat(dia, msg, False, **kwargs):
|
||||||
answer = structure_answer(conv, ans, message_id, session_id)
|
answer = structure_answer(conv, ans, message_id, session_id)
|
||||||
API4ConversationService.append_message(conv.id, conv.to_dict())
|
API4ConversationService.append_message(conv.id, conv.to_dict())
|
||||||
break
|
break
|
||||||
|
|
|
||||||
|
|
@ -178,7 +178,8 @@ class DialogService(CommonService):
|
||||||
offset += limit
|
offset += limit
|
||||||
return res
|
return res
|
||||||
|
|
||||||
def chat_solo(dialog, messages, stream=True):
|
|
||||||
|
async def async_chat_solo(dialog, messages, stream=True):
|
||||||
attachments = ""
|
attachments = ""
|
||||||
if "files" in messages[-1]:
|
if "files" in messages[-1]:
|
||||||
attachments = "\n\n".join(FileService.get_files(messages[-1]["files"]))
|
attachments = "\n\n".join(FileService.get_files(messages[-1]["files"]))
|
||||||
|
|
@ -197,7 +198,8 @@ def chat_solo(dialog, messages, stream=True):
|
||||||
if stream:
|
if stream:
|
||||||
last_ans = ""
|
last_ans = ""
|
||||||
delta_ans = ""
|
delta_ans = ""
|
||||||
for ans in chat_mdl.chat_streamly(prompt_config.get("system", ""), msg, dialog.llm_setting):
|
answer = ""
|
||||||
|
async for ans in chat_mdl.async_chat_streamly(prompt_config.get("system", ""), msg, dialog.llm_setting):
|
||||||
answer = ans
|
answer = ans
|
||||||
delta_ans = ans[len(last_ans):]
|
delta_ans = ans[len(last_ans):]
|
||||||
if num_tokens_from_string(delta_ans) < 16:
|
if num_tokens_from_string(delta_ans) < 16:
|
||||||
|
|
@ -208,7 +210,7 @@ def chat_solo(dialog, messages, stream=True):
|
||||||
if delta_ans:
|
if delta_ans:
|
||||||
yield {"answer": answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans), "prompt": "", "created_at": time.time()}
|
yield {"answer": answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans), "prompt": "", "created_at": time.time()}
|
||||||
else:
|
else:
|
||||||
answer = chat_mdl.chat(prompt_config.get("system", ""), msg, dialog.llm_setting)
|
answer = await chat_mdl.async_chat(prompt_config.get("system", ""), msg, dialog.llm_setting)
|
||||||
user_content = msg[-1].get("content", "[content not available]")
|
user_content = msg[-1].get("content", "[content not available]")
|
||||||
logging.debug("User: {}|Assistant: {}".format(user_content, answer))
|
logging.debug("User: {}|Assistant: {}".format(user_content, answer))
|
||||||
yield {"answer": answer, "reference": {}, "audio_binary": tts(tts_mdl, answer), "prompt": "", "created_at": time.time()}
|
yield {"answer": answer, "reference": {}, "audio_binary": tts(tts_mdl, answer), "prompt": "", "created_at": time.time()}
|
||||||
|
|
@ -347,13 +349,12 @@ def meta_filter(metas: dict, filters: list[dict], logic: str = "and"):
|
||||||
return []
|
return []
|
||||||
return list(doc_ids)
|
return list(doc_ids)
|
||||||
|
|
||||||
|
async def async_chat(dialog, messages, stream=True, **kwargs):
|
||||||
def chat(dialog, messages, stream=True, **kwargs):
|
|
||||||
assert messages[-1]["role"] == "user", "The last content of this conversation is not from user."
|
assert messages[-1]["role"] == "user", "The last content of this conversation is not from user."
|
||||||
if not dialog.kb_ids and not dialog.prompt_config.get("tavily_api_key"):
|
if not dialog.kb_ids and not dialog.prompt_config.get("tavily_api_key"):
|
||||||
for ans in chat_solo(dialog, messages, stream):
|
async for ans in async_chat_solo(dialog, messages, stream):
|
||||||
yield ans
|
yield ans
|
||||||
return None
|
return
|
||||||
|
|
||||||
chat_start_ts = timer()
|
chat_start_ts = timer()
|
||||||
|
|
||||||
|
|
@ -400,7 +401,7 @@ def chat(dialog, messages, stream=True, **kwargs):
|
||||||
ans = use_sql(questions[-1], field_map, dialog.tenant_id, chat_mdl, prompt_config.get("quote", True), dialog.kb_ids)
|
ans = use_sql(questions[-1], field_map, dialog.tenant_id, chat_mdl, prompt_config.get("quote", True), dialog.kb_ids)
|
||||||
if ans:
|
if ans:
|
||||||
yield ans
|
yield ans
|
||||||
return None
|
return
|
||||||
|
|
||||||
for p in prompt_config["parameters"]:
|
for p in prompt_config["parameters"]:
|
||||||
if p["key"] == "knowledge":
|
if p["key"] == "knowledge":
|
||||||
|
|
@ -508,7 +509,8 @@ def chat(dialog, messages, stream=True, **kwargs):
|
||||||
empty_res = prompt_config["empty_response"]
|
empty_res = prompt_config["empty_response"]
|
||||||
yield {"answer": empty_res, "reference": kbinfos, "prompt": "\n\n### Query:\n%s" % " ".join(questions),
|
yield {"answer": empty_res, "reference": kbinfos, "prompt": "\n\n### Query:\n%s" % " ".join(questions),
|
||||||
"audio_binary": tts(tts_mdl, empty_res)}
|
"audio_binary": tts(tts_mdl, empty_res)}
|
||||||
return {"answer": prompt_config["empty_response"], "reference": kbinfos}
|
yield {"answer": prompt_config["empty_response"], "reference": kbinfos}
|
||||||
|
return
|
||||||
|
|
||||||
kwargs["knowledge"] = "\n------\n" + "\n\n------\n\n".join(knowledges)
|
kwargs["knowledge"] = "\n------\n" + "\n\n------\n\n".join(knowledges)
|
||||||
gen_conf = dialog.llm_setting
|
gen_conf = dialog.llm_setting
|
||||||
|
|
@ -612,7 +614,7 @@ def chat(dialog, messages, stream=True, **kwargs):
|
||||||
if stream:
|
if stream:
|
||||||
last_ans = ""
|
last_ans = ""
|
||||||
answer = ""
|
answer = ""
|
||||||
for ans in chat_mdl.chat_streamly(prompt + prompt4citation, msg[1:], gen_conf):
|
async for ans in chat_mdl.async_chat_streamly(prompt + prompt4citation, msg[1:], gen_conf):
|
||||||
if thought:
|
if thought:
|
||||||
ans = re.sub(r"^.*</think>", "", ans, flags=re.DOTALL)
|
ans = re.sub(r"^.*</think>", "", ans, flags=re.DOTALL)
|
||||||
answer = ans
|
answer = ans
|
||||||
|
|
@ -626,14 +628,14 @@ def chat(dialog, messages, stream=True, **kwargs):
|
||||||
yield {"answer": thought + answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans)}
|
yield {"answer": thought + answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans)}
|
||||||
yield decorate_answer(thought + answer)
|
yield decorate_answer(thought + answer)
|
||||||
else:
|
else:
|
||||||
answer = chat_mdl.chat(prompt + prompt4citation, msg[1:], gen_conf)
|
answer = await chat_mdl.async_chat(prompt + prompt4citation, msg[1:], gen_conf)
|
||||||
user_content = msg[-1].get("content", "[content not available]")
|
user_content = msg[-1].get("content", "[content not available]")
|
||||||
logging.debug("User: {}|Assistant: {}".format(user_content, answer))
|
logging.debug("User: {}|Assistant: {}".format(user_content, answer))
|
||||||
res = decorate_answer(answer)
|
res = decorate_answer(answer)
|
||||||
res["audio_binary"] = tts(tts_mdl, answer)
|
res["audio_binary"] = tts(tts_mdl, answer)
|
||||||
yield res
|
yield res
|
||||||
|
|
||||||
return None
|
return
|
||||||
|
|
||||||
|
|
||||||
def use_sql(question, field_map, tenant_id, chat_mdl, quota=True, kb_ids=None):
|
def use_sql(question, field_map, tenant_id, chat_mdl, quota=True, kb_ids=None):
|
||||||
|
|
@ -805,8 +807,7 @@ def tts(tts_mdl, text):
|
||||||
return None
|
return None
|
||||||
return binascii.hexlify(bin).decode("utf-8")
|
return binascii.hexlify(bin).decode("utf-8")
|
||||||
|
|
||||||
|
async def async_ask(question, kb_ids, tenant_id, chat_llm_name=None, search_config={}):
|
||||||
def ask(question, kb_ids, tenant_id, chat_llm_name=None, search_config={}):
|
|
||||||
doc_ids = search_config.get("doc_ids", [])
|
doc_ids = search_config.get("doc_ids", [])
|
||||||
rerank_mdl = None
|
rerank_mdl = None
|
||||||
kb_ids = search_config.get("kb_ids", kb_ids)
|
kb_ids = search_config.get("kb_ids", kb_ids)
|
||||||
|
|
@ -880,7 +881,7 @@ def ask(question, kb_ids, tenant_id, chat_llm_name=None, search_config={}):
|
||||||
return {"answer": answer, "reference": refs}
|
return {"answer": answer, "reference": refs}
|
||||||
|
|
||||||
answer = ""
|
answer = ""
|
||||||
for ans in chat_mdl.chat_streamly(sys_prompt, msg, {"temperature": 0.1}):
|
async for ans in chat_mdl.async_chat_streamly(sys_prompt, msg, {"temperature": 0.1}):
|
||||||
answer = ans
|
answer = ans
|
||||||
yield {"answer": answer, "reference": {}}
|
yield {"answer": answer, "reference": {}}
|
||||||
yield decorate_answer(answer)
|
yield decorate_answer(answer)
|
||||||
|
|
|
||||||
|
|
@ -25,14 +25,17 @@ Provides functionality for evaluating RAG system performance including:
|
||||||
- Configuration recommendations
|
- Configuration recommendations
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
|
import queue
|
||||||
|
import threading
|
||||||
from typing import List, Dict, Any, Optional, Tuple
|
from typing import List, Dict, Any, Optional, Tuple
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from timeit import default_timer as timer
|
from timeit import default_timer as timer
|
||||||
|
|
||||||
from api.db.db_models import EvaluationDataset, EvaluationCase, EvaluationRun, EvaluationResult
|
from api.db.db_models import EvaluationDataset, EvaluationCase, EvaluationRun, EvaluationResult
|
||||||
from api.db.services.common_service import CommonService
|
from api.db.services.common_service import CommonService
|
||||||
from api.db.services.dialog_service import DialogService, chat
|
from api.db.services.dialog_service import DialogService
|
||||||
from common.misc_utils import get_uuid
|
from common.misc_utils import get_uuid
|
||||||
from common.time_utils import current_timestamp
|
from common.time_utils import current_timestamp
|
||||||
from common.constants import StatusEnum
|
from common.constants import StatusEnum
|
||||||
|
|
@ -357,6 +360,42 @@ class EvaluationService(CommonService):
|
||||||
answer = ""
|
answer = ""
|
||||||
retrieved_chunks = []
|
retrieved_chunks = []
|
||||||
|
|
||||||
|
|
||||||
|
def _sync_from_async_gen(async_gen):
|
||||||
|
result_queue: queue.Queue = queue.Queue()
|
||||||
|
|
||||||
|
def runner():
|
||||||
|
loop = asyncio.new_event_loop()
|
||||||
|
asyncio.set_event_loop(loop)
|
||||||
|
|
||||||
|
async def consume():
|
||||||
|
try:
|
||||||
|
async for item in async_gen:
|
||||||
|
result_queue.put(item)
|
||||||
|
except Exception as e:
|
||||||
|
result_queue.put(e)
|
||||||
|
finally:
|
||||||
|
result_queue.put(StopIteration)
|
||||||
|
|
||||||
|
loop.run_until_complete(consume())
|
||||||
|
loop.close()
|
||||||
|
|
||||||
|
threading.Thread(target=runner, daemon=True).start()
|
||||||
|
|
||||||
|
while True:
|
||||||
|
item = result_queue.get()
|
||||||
|
if item is StopIteration:
|
||||||
|
break
|
||||||
|
if isinstance(item, Exception):
|
||||||
|
raise item
|
||||||
|
yield item
|
||||||
|
|
||||||
|
|
||||||
|
def chat(dialog, messages, stream=True, **kwargs):
|
||||||
|
from api.db.services.dialog_service import async_chat
|
||||||
|
|
||||||
|
return _sync_from_async_gen(async_chat(dialog, messages, stream=stream, **kwargs))
|
||||||
|
|
||||||
for ans in chat(dialog, messages, stream=False):
|
for ans in chat(dialog, messages, stream=False):
|
||||||
if isinstance(ans, dict):
|
if isinstance(ans, dict):
|
||||||
answer = ans.get("answer", "")
|
answer = ans.get("answer", "")
|
||||||
|
|
|
||||||
|
|
@ -16,15 +16,17 @@
|
||||||
import asyncio
|
import asyncio
|
||||||
import inspect
|
import inspect
|
||||||
import logging
|
import logging
|
||||||
|
import queue
|
||||||
import re
|
import re
|
||||||
import threading
|
import threading
|
||||||
from common.token_utils import num_tokens_from_string
|
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from typing import Generator
|
from typing import Generator
|
||||||
from common.constants import LLMType
|
|
||||||
from api.db.db_models import LLM
|
from api.db.db_models import LLM
|
||||||
from api.db.services.common_service import CommonService
|
from api.db.services.common_service import CommonService
|
||||||
from api.db.services.tenant_llm_service import LLM4Tenant, TenantLLMService
|
from api.db.services.tenant_llm_service import LLM4Tenant, TenantLLMService
|
||||||
|
from common.constants import LLMType
|
||||||
|
from common.token_utils import num_tokens_from_string
|
||||||
|
|
||||||
|
|
||||||
class LLMService(CommonService):
|
class LLMService(CommonService):
|
||||||
|
|
@ -33,6 +35,7 @@ class LLMService(CommonService):
|
||||||
|
|
||||||
def get_init_tenant_llm(user_id):
|
def get_init_tenant_llm(user_id):
|
||||||
from common import settings
|
from common import settings
|
||||||
|
|
||||||
tenant_llm = []
|
tenant_llm = []
|
||||||
|
|
||||||
model_configs = {
|
model_configs = {
|
||||||
|
|
@ -193,7 +196,7 @@ class LLMBundle(LLM4Tenant):
|
||||||
generation = self.langfuse.start_generation(
|
generation = self.langfuse.start_generation(
|
||||||
trace_context=self.trace_context,
|
trace_context=self.trace_context,
|
||||||
name="stream_transcription",
|
name="stream_transcription",
|
||||||
metadata={"model": self.llm_name}
|
metadata={"model": self.llm_name},
|
||||||
)
|
)
|
||||||
final_text = ""
|
final_text = ""
|
||||||
used_tokens = 0
|
used_tokens = 0
|
||||||
|
|
@ -217,32 +220,34 @@ class LLMBundle(LLM4Tenant):
|
||||||
if self.langfuse:
|
if self.langfuse:
|
||||||
generation.update(
|
generation.update(
|
||||||
output={"output": final_text},
|
output={"output": final_text},
|
||||||
usage_details={"total_tokens": used_tokens}
|
usage_details={"total_tokens": used_tokens},
|
||||||
)
|
)
|
||||||
generation.end()
|
generation.end()
|
||||||
|
|
||||||
return
|
return
|
||||||
|
|
||||||
if self.langfuse:
|
if self.langfuse:
|
||||||
generation = self.langfuse.start_generation(trace_context=self.trace_context, name="stream_transcription", metadata={"model": self.llm_name})
|
generation = self.langfuse.start_generation(
|
||||||
full_text, used_tokens = mdl.transcription(audio)
|
trace_context=self.trace_context,
|
||||||
if not TenantLLMService.increase_usage(
|
name="stream_transcription",
|
||||||
self.tenant_id, self.llm_type, used_tokens
|
metadata={"model": self.llm_name},
|
||||||
):
|
|
||||||
logging.error(
|
|
||||||
f"LLMBundle.stream_transcription can't update token usage for {self.tenant_id}/SEQUENCE2TXT used_tokens: {used_tokens}"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
full_text, used_tokens = mdl.transcription(audio)
|
||||||
|
if not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens):
|
||||||
|
logging.error(f"LLMBundle.stream_transcription can't update token usage for {self.tenant_id}/SEQUENCE2TXT used_tokens: {used_tokens}")
|
||||||
|
|
||||||
if self.langfuse:
|
if self.langfuse:
|
||||||
generation.update(
|
generation.update(
|
||||||
output={"output": full_text},
|
output={"output": full_text},
|
||||||
usage_details={"total_tokens": used_tokens}
|
usage_details={"total_tokens": used_tokens},
|
||||||
)
|
)
|
||||||
generation.end()
|
generation.end()
|
||||||
|
|
||||||
yield {
|
yield {
|
||||||
"event": "final",
|
"event": "final",
|
||||||
"text": full_text,
|
"text": full_text,
|
||||||
"streaming": False
|
"streaming": False,
|
||||||
}
|
}
|
||||||
|
|
||||||
def tts(self, text: str) -> Generator[bytes, None, None]:
|
def tts(self, text: str) -> Generator[bytes, None, None]:
|
||||||
|
|
@ -289,61 +294,79 @@ class LLMBundle(LLM4Tenant):
|
||||||
return kwargs
|
return kwargs
|
||||||
else:
|
else:
|
||||||
return {k: v for k, v in kwargs.items() if k in allowed_params}
|
return {k: v for k, v in kwargs.items() if k in allowed_params}
|
||||||
|
|
||||||
|
def _run_coroutine_sync(self, coro):
|
||||||
|
try:
|
||||||
|
asyncio.get_running_loop()
|
||||||
|
except RuntimeError:
|
||||||
|
return asyncio.run(coro)
|
||||||
|
|
||||||
|
result_queue: queue.Queue = queue.Queue()
|
||||||
|
|
||||||
|
def runner():
|
||||||
|
try:
|
||||||
|
result_queue.put((True, asyncio.run(coro)))
|
||||||
|
except Exception as e:
|
||||||
|
result_queue.put((False, e))
|
||||||
|
|
||||||
|
thread = threading.Thread(target=runner, daemon=True)
|
||||||
|
thread.start()
|
||||||
|
thread.join()
|
||||||
|
|
||||||
|
success, value = result_queue.get_nowait()
|
||||||
|
if success:
|
||||||
|
return value
|
||||||
|
raise value
|
||||||
|
|
||||||
def chat(self, system: str, history: list, gen_conf: dict = {}, **kwargs) -> str:
|
def chat(self, system: str, history: list, gen_conf: dict = {}, **kwargs) -> str:
|
||||||
if self.langfuse:
|
return self._run_coroutine_sync(self.async_chat(system, history, gen_conf, **kwargs))
|
||||||
generation = self.langfuse.start_generation(trace_context=self.trace_context, name="chat", model=self.llm_name, input={"system": system, "history": history})
|
|
||||||
|
|
||||||
chat_partial = partial(self.mdl.chat, system, history, gen_conf, **kwargs)
|
def _sync_from_async_stream(self, async_gen_fn, *args, **kwargs):
|
||||||
if self.is_tools and self.mdl.is_tools:
|
result_queue: queue.Queue = queue.Queue()
|
||||||
chat_partial = partial(self.mdl.chat_with_tools, system, history, gen_conf, **kwargs)
|
|
||||||
|
|
||||||
use_kwargs = self._clean_param(chat_partial, **kwargs)
|
def runner():
|
||||||
txt, used_tokens = chat_partial(**use_kwargs)
|
loop = asyncio.new_event_loop()
|
||||||
txt = self._remove_reasoning_content(txt)
|
asyncio.set_event_loop(loop)
|
||||||
|
|
||||||
if not self.verbose_tool_use:
|
async def consume():
|
||||||
txt = re.sub(r"<tool_call>.*?</tool_call>", "", txt, flags=re.DOTALL)
|
try:
|
||||||
|
async for item in async_gen_fn(*args, **kwargs):
|
||||||
|
result_queue.put(item)
|
||||||
|
except Exception as e:
|
||||||
|
result_queue.put(e)
|
||||||
|
finally:
|
||||||
|
result_queue.put(StopIteration)
|
||||||
|
|
||||||
if used_tokens and not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens, self.llm_name):
|
loop.run_until_complete(consume())
|
||||||
logging.error("LLMBundle.chat can't update token usage for {}/CHAT llm_name: {}, used_tokens: {}".format(self.tenant_id, self.llm_name, used_tokens))
|
loop.close()
|
||||||
|
|
||||||
if self.langfuse:
|
threading.Thread(target=runner, daemon=True).start()
|
||||||
generation.update(output={"output": txt}, usage_details={"total_tokens": used_tokens})
|
|
||||||
generation.end()
|
|
||||||
|
|
||||||
return txt
|
while True:
|
||||||
|
item = result_queue.get()
|
||||||
|
if item is StopIteration:
|
||||||
|
break
|
||||||
|
if isinstance(item, Exception):
|
||||||
|
raise item
|
||||||
|
yield item
|
||||||
|
|
||||||
def chat_streamly(self, system: str, history: list, gen_conf: dict = {}, **kwargs):
|
def chat_streamly(self, system: str, history: list, gen_conf: dict = {}, **kwargs):
|
||||||
if self.langfuse:
|
|
||||||
generation = self.langfuse.start_generation(trace_context=self.trace_context, name="chat_streamly", model=self.llm_name, input={"system": system, "history": history})
|
|
||||||
|
|
||||||
ans = ""
|
ans = ""
|
||||||
chat_partial = partial(self.mdl.chat_streamly, system, history, gen_conf)
|
for txt in self._sync_from_async_stream(self.async_chat_streamly, system, history, gen_conf, **kwargs):
|
||||||
total_tokens = 0
|
|
||||||
if self.is_tools and self.mdl.is_tools:
|
|
||||||
chat_partial = partial(self.mdl.chat_streamly_with_tools, system, history, gen_conf)
|
|
||||||
use_kwargs = self._clean_param(chat_partial, **kwargs)
|
|
||||||
for txt in chat_partial(**use_kwargs):
|
|
||||||
if isinstance(txt, int):
|
if isinstance(txt, int):
|
||||||
total_tokens = txt
|
|
||||||
if self.langfuse:
|
|
||||||
generation.update(output={"output": ans})
|
|
||||||
generation.end()
|
|
||||||
break
|
break
|
||||||
|
|
||||||
if txt.endswith("</think>"):
|
if txt.endswith("</think>"):
|
||||||
ans = ans[: -len("</think>")]
|
ans = txt[: -len("</think>")]
|
||||||
|
continue
|
||||||
|
|
||||||
if not self.verbose_tool_use:
|
if not self.verbose_tool_use:
|
||||||
txt = re.sub(r"<tool_call>.*?</tool_call>", "", txt, flags=re.DOTALL)
|
txt = re.sub(r"<tool_call>.*?</tool_call>", "", txt, flags=re.DOTALL)
|
||||||
|
|
||||||
ans += txt
|
# cancatination has beend done in async_chat_streamly
|
||||||
|
ans = txt
|
||||||
yield ans
|
yield ans
|
||||||
|
|
||||||
if total_tokens > 0:
|
|
||||||
if not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, total_tokens, self.llm_name):
|
|
||||||
logging.error("LLMBundle.chat_streamly can't update token usage for {}/CHAT llm_name: {}, used_tokens: {}".format(self.tenant_id, self.llm_name, total_tokens))
|
|
||||||
|
|
||||||
def _bridge_sync_stream(self, gen):
|
def _bridge_sync_stream(self, gen):
|
||||||
loop = asyncio.get_running_loop()
|
loop = asyncio.get_running_loop()
|
||||||
queue: asyncio.Queue = asyncio.Queue()
|
queue: asyncio.Queue = asyncio.Queue()
|
||||||
|
|
@ -352,7 +375,7 @@ class LLMBundle(LLM4Tenant):
|
||||||
try:
|
try:
|
||||||
for item in gen:
|
for item in gen:
|
||||||
loop.call_soon_threadsafe(queue.put_nowait, item)
|
loop.call_soon_threadsafe(queue.put_nowait, item)
|
||||||
except Exception as e: # pragma: no cover
|
except Exception as e:
|
||||||
loop.call_soon_threadsafe(queue.put_nowait, e)
|
loop.call_soon_threadsafe(queue.put_nowait, e)
|
||||||
finally:
|
finally:
|
||||||
loop.call_soon_threadsafe(queue.put_nowait, StopAsyncIteration)
|
loop.call_soon_threadsafe(queue.put_nowait, StopAsyncIteration)
|
||||||
|
|
@ -361,18 +384,27 @@ class LLMBundle(LLM4Tenant):
|
||||||
return queue
|
return queue
|
||||||
|
|
||||||
async def async_chat(self, system: str, history: list, gen_conf: dict = {}, **kwargs):
|
async def async_chat(self, system: str, history: list, gen_conf: dict = {}, **kwargs):
|
||||||
chat_partial = partial(self.mdl.chat, system, history, gen_conf, **kwargs)
|
if self.is_tools and getattr(self.mdl, "is_tools", False) and hasattr(self.mdl, "async_chat_with_tools"):
|
||||||
if self.is_tools and self.mdl.is_tools and hasattr(self.mdl, "chat_with_tools"):
|
base_fn = self.mdl.async_chat_with_tools
|
||||||
chat_partial = partial(self.mdl.chat_with_tools, system, history, gen_conf, **kwargs)
|
elif hasattr(self.mdl, "async_chat"):
|
||||||
|
base_fn = self.mdl.async_chat
|
||||||
|
else:
|
||||||
|
raise RuntimeError(f"Model {self.mdl} does not implement async_chat or async_chat_with_tools")
|
||||||
|
|
||||||
|
generation = None
|
||||||
|
if self.langfuse:
|
||||||
|
generation = self.langfuse.start_generation(trace_context=self.trace_context, name="chat", model=self.llm_name, input={"system": system, "history": history})
|
||||||
|
|
||||||
|
chat_partial = partial(base_fn, system, history, gen_conf)
|
||||||
use_kwargs = self._clean_param(chat_partial, **kwargs)
|
use_kwargs = self._clean_param(chat_partial, **kwargs)
|
||||||
|
|
||||||
if hasattr(self.mdl, "async_chat_with_tools") and self.is_tools and self.mdl.is_tools:
|
try:
|
||||||
txt, used_tokens = await self.mdl.async_chat_with_tools(system, history, gen_conf, **use_kwargs)
|
txt, used_tokens = await chat_partial(**use_kwargs)
|
||||||
elif hasattr(self.mdl, "async_chat"):
|
except Exception as e:
|
||||||
txt, used_tokens = await self.mdl.async_chat(system, history, gen_conf, **use_kwargs)
|
if generation:
|
||||||
else:
|
generation.update(output={"error": str(e)})
|
||||||
txt, used_tokens = await asyncio.to_thread(chat_partial, **use_kwargs)
|
generation.end()
|
||||||
|
raise
|
||||||
|
|
||||||
txt = self._remove_reasoning_content(txt)
|
txt = self._remove_reasoning_content(txt)
|
||||||
if not self.verbose_tool_use:
|
if not self.verbose_tool_use:
|
||||||
|
|
@ -381,19 +413,30 @@ class LLMBundle(LLM4Tenant):
|
||||||
if used_tokens and not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens, self.llm_name):
|
if used_tokens and not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens, self.llm_name):
|
||||||
logging.error("LLMBundle.async_chat can't update token usage for {}/CHAT llm_name: {}, used_tokens: {}".format(self.tenant_id, self.llm_name, used_tokens))
|
logging.error("LLMBundle.async_chat can't update token usage for {}/CHAT llm_name: {}, used_tokens: {}".format(self.tenant_id, self.llm_name, used_tokens))
|
||||||
|
|
||||||
|
if generation:
|
||||||
|
generation.update(output={"output": txt}, usage_details={"total_tokens": used_tokens})
|
||||||
|
generation.end()
|
||||||
|
|
||||||
return txt
|
return txt
|
||||||
|
|
||||||
async def async_chat_streamly(self, system: str, history: list, gen_conf: dict = {}, **kwargs):
|
async def async_chat_streamly(self, system: str, history: list, gen_conf: dict = {}, **kwargs):
|
||||||
total_tokens = 0
|
total_tokens = 0
|
||||||
ans = ""
|
ans = ""
|
||||||
if self.is_tools and self.mdl.is_tools:
|
if self.is_tools and getattr(self.mdl, "is_tools", False) and hasattr(self.mdl, "async_chat_streamly_with_tools"):
|
||||||
stream_fn = getattr(self.mdl, "async_chat_streamly_with_tools", None)
|
stream_fn = getattr(self.mdl, "async_chat_streamly_with_tools", None)
|
||||||
else:
|
elif hasattr(self.mdl, "async_chat_streamly"):
|
||||||
stream_fn = getattr(self.mdl, "async_chat_streamly", None)
|
stream_fn = getattr(self.mdl, "async_chat_streamly", None)
|
||||||
|
else:
|
||||||
|
raise RuntimeError(f"Model {self.mdl} does not implement async_chat or async_chat_with_tools")
|
||||||
|
|
||||||
|
generation = None
|
||||||
|
if self.langfuse:
|
||||||
|
generation = self.langfuse.start_generation(trace_context=self.trace_context, name="chat_streamly", model=self.llm_name, input={"system": system, "history": history})
|
||||||
|
|
||||||
if stream_fn:
|
if stream_fn:
|
||||||
chat_partial = partial(stream_fn, system, history, gen_conf)
|
chat_partial = partial(stream_fn, system, history, gen_conf)
|
||||||
use_kwargs = self._clean_param(chat_partial, **kwargs)
|
use_kwargs = self._clean_param(chat_partial, **kwargs)
|
||||||
|
try:
|
||||||
async for txt in chat_partial(**use_kwargs):
|
async for txt in chat_partial(**use_kwargs):
|
||||||
if isinstance(txt, int):
|
if isinstance(txt, int):
|
||||||
total_tokens = txt
|
total_tokens = txt
|
||||||
|
|
@ -407,23 +450,14 @@ class LLMBundle(LLM4Tenant):
|
||||||
|
|
||||||
ans += txt
|
ans += txt
|
||||||
yield ans
|
yield ans
|
||||||
|
except Exception as e:
|
||||||
|
if generation:
|
||||||
|
generation.update(output={"error": str(e)})
|
||||||
|
generation.end()
|
||||||
|
raise
|
||||||
if total_tokens and not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, total_tokens, self.llm_name):
|
if total_tokens and not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, total_tokens, self.llm_name):
|
||||||
logging.error("LLMBundle.async_chat_streamly can't update token usage for {}/CHAT llm_name: {}, used_tokens: {}".format(self.tenant_id, self.llm_name, total_tokens))
|
logging.error("LLMBundle.async_chat_streamly can't update token usage for {}/CHAT llm_name: {}, used_tokens: {}".format(self.tenant_id, self.llm_name, total_tokens))
|
||||||
|
if generation:
|
||||||
|
generation.update(output={"output": ans}, usage_details={"total_tokens": total_tokens})
|
||||||
|
generation.end()
|
||||||
return
|
return
|
||||||
|
|
||||||
chat_partial = partial(self.mdl.chat_streamly_with_tools if (self.is_tools and self.mdl.is_tools) else self.mdl.chat_streamly, system, history, gen_conf)
|
|
||||||
use_kwargs = self._clean_param(chat_partial, **kwargs)
|
|
||||||
queue = self._bridge_sync_stream(chat_partial(**use_kwargs))
|
|
||||||
while True:
|
|
||||||
item = await queue.get()
|
|
||||||
if item is StopAsyncIteration:
|
|
||||||
break
|
|
||||||
if isinstance(item, Exception):
|
|
||||||
raise item
|
|
||||||
if isinstance(item, int):
|
|
||||||
total_tokens = item
|
|
||||||
break
|
|
||||||
yield item
|
|
||||||
|
|
||||||
if total_tokens and not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, total_tokens, self.llm_name):
|
|
||||||
logging.error("LLMBundle.async_chat_streamly can't update token usage for {}/CHAT llm_name: {}, used_tokens: {}".format(self.tenant_id, self.llm_name, total_tokens))
|
|
||||||
|
|
|
||||||
|
|
@ -33,7 +33,7 @@ def _convert_message_to_document(
|
||||||
metadata: dict[str, str | list[str]] = {}
|
metadata: dict[str, str | list[str]] = {}
|
||||||
semantic_substring = ""
|
semantic_substring = ""
|
||||||
|
|
||||||
# Only messages from TextChannels will make it here but we have to check for it anyways
|
# Only messages from TextChannels will make it here, but we have to check for it anyway
|
||||||
if isinstance(message.channel, TextChannel) and (channel_name := message.channel.name):
|
if isinstance(message.channel, TextChannel) and (channel_name := message.channel.name):
|
||||||
metadata["Channel"] = channel_name
|
metadata["Channel"] = channel_name
|
||||||
semantic_substring += f" in Channel: #{channel_name}"
|
semantic_substring += f" in Channel: #{channel_name}"
|
||||||
|
|
@ -176,7 +176,7 @@ def _manage_async_retrieval(
|
||||||
# parse requested_start_date_string to datetime
|
# parse requested_start_date_string to datetime
|
||||||
pull_date: datetime | None = datetime.strptime(requested_start_date_string, "%Y-%m-%d").replace(tzinfo=timezone.utc) if requested_start_date_string else None
|
pull_date: datetime | None = datetime.strptime(requested_start_date_string, "%Y-%m-%d").replace(tzinfo=timezone.utc) if requested_start_date_string else None
|
||||||
|
|
||||||
# Set start_time to the later of start and pull_date, or whichever is provided
|
# Set start_time to the most recent of start and pull_date, or whichever is provided
|
||||||
start_time = max(filter(None, [start, pull_date])) if start or pull_date else None
|
start_time = max(filter(None, [start, pull_date])) if start or pull_date else None
|
||||||
|
|
||||||
end_time: datetime | None = end
|
end_time: datetime | None = end
|
||||||
|
|
|
||||||
|
|
@ -151,7 +151,7 @@ class RAGFlowHtmlParser:
|
||||||
block_content = []
|
block_content = []
|
||||||
current_content = ""
|
current_content = ""
|
||||||
table_info_list = []
|
table_info_list = []
|
||||||
lask_block_id = None
|
last_block_id = None
|
||||||
for item in parser_result:
|
for item in parser_result:
|
||||||
content = item.get("content")
|
content = item.get("content")
|
||||||
tag_name = item.get("tag_name")
|
tag_name = item.get("tag_name")
|
||||||
|
|
@ -160,11 +160,11 @@ class RAGFlowHtmlParser:
|
||||||
if block_id:
|
if block_id:
|
||||||
if title_flag:
|
if title_flag:
|
||||||
content = f"{TITLE_TAGS[tag_name]} {content}"
|
content = f"{TITLE_TAGS[tag_name]} {content}"
|
||||||
if lask_block_id != block_id:
|
if last_block_id != block_id:
|
||||||
if lask_block_id is not None:
|
if last_block_id is not None:
|
||||||
block_content.append(current_content)
|
block_content.append(current_content)
|
||||||
current_content = content
|
current_content = content
|
||||||
lask_block_id = block_id
|
last_block_id = block_id
|
||||||
else:
|
else:
|
||||||
current_content += (" " if current_content else "") + content
|
current_content += (" " if current_content else "") + content
|
||||||
else:
|
else:
|
||||||
|
|
|
||||||
|
|
@ -582,7 +582,7 @@ class OCR:
|
||||||
self.crop_image_res_index = 0
|
self.crop_image_res_index = 0
|
||||||
|
|
||||||
def get_rotate_crop_image(self, img, points):
|
def get_rotate_crop_image(self, img, points):
|
||||||
'''
|
"""
|
||||||
img_height, img_width = img.shape[0:2]
|
img_height, img_width = img.shape[0:2]
|
||||||
left = int(np.min(points[:, 0]))
|
left = int(np.min(points[:, 0]))
|
||||||
right = int(np.max(points[:, 0]))
|
right = int(np.max(points[:, 0]))
|
||||||
|
|
@ -591,7 +591,7 @@ class OCR:
|
||||||
img_crop = img[top:bottom, left:right, :].copy()
|
img_crop = img[top:bottom, left:right, :].copy()
|
||||||
points[:, 0] = points[:, 0] - left
|
points[:, 0] = points[:, 0] - left
|
||||||
points[:, 1] = points[:, 1] - top
|
points[:, 1] = points[:, 1] - top
|
||||||
'''
|
"""
|
||||||
assert len(points) == 4, "shape of points must be 4*2"
|
assert len(points) == 4, "shape of points must be 4*2"
|
||||||
img_crop_width = int(
|
img_crop_width = int(
|
||||||
max(
|
max(
|
||||||
|
|
|
||||||
|
|
@ -67,10 +67,10 @@ class DBPostProcess:
|
||||||
[[1, 1], [1, 1]])
|
[[1, 1], [1, 1]])
|
||||||
|
|
||||||
def polygons_from_bitmap(self, pred, _bitmap, dest_width, dest_height):
|
def polygons_from_bitmap(self, pred, _bitmap, dest_width, dest_height):
|
||||||
'''
|
"""
|
||||||
_bitmap: single map with shape (1, H, W),
|
_bitmap: single map with shape (1, H, W),
|
||||||
whose values are binarized as {0, 1}
|
whose values are binarized as {0, 1}
|
||||||
'''
|
"""
|
||||||
|
|
||||||
bitmap = _bitmap
|
bitmap = _bitmap
|
||||||
height, width = bitmap.shape
|
height, width = bitmap.shape
|
||||||
|
|
@ -114,10 +114,10 @@ class DBPostProcess:
|
||||||
return boxes, scores
|
return boxes, scores
|
||||||
|
|
||||||
def boxes_from_bitmap(self, pred, _bitmap, dest_width, dest_height):
|
def boxes_from_bitmap(self, pred, _bitmap, dest_width, dest_height):
|
||||||
'''
|
"""
|
||||||
_bitmap: single map with shape (1, H, W),
|
_bitmap: single map with shape (1, H, W),
|
||||||
whose values are binarized as {0, 1}
|
whose values are binarized as {0, 1}
|
||||||
'''
|
"""
|
||||||
|
|
||||||
bitmap = _bitmap
|
bitmap = _bitmap
|
||||||
height, width = bitmap.shape
|
height, width = bitmap.shape
|
||||||
|
|
@ -192,9 +192,9 @@ class DBPostProcess:
|
||||||
return box, min(bounding_box[1])
|
return box, min(bounding_box[1])
|
||||||
|
|
||||||
def box_score_fast(self, bitmap, _box):
|
def box_score_fast(self, bitmap, _box):
|
||||||
'''
|
"""
|
||||||
box_score_fast: use bbox mean score as the mean score
|
box_score_fast: use bbox mean score as the mean score
|
||||||
'''
|
"""
|
||||||
h, w = bitmap.shape[:2]
|
h, w = bitmap.shape[:2]
|
||||||
box = _box.copy()
|
box = _box.copy()
|
||||||
xmin = np.clip(np.floor(box[:, 0].min()).astype("int32"), 0, w - 1)
|
xmin = np.clip(np.floor(box[:, 0].min()).astype("int32"), 0, w - 1)
|
||||||
|
|
@ -209,9 +209,9 @@ class DBPostProcess:
|
||||||
return cv2.mean(bitmap[ymin:ymax + 1, xmin:xmax + 1], mask)[0]
|
return cv2.mean(bitmap[ymin:ymax + 1, xmin:xmax + 1], mask)[0]
|
||||||
|
|
||||||
def box_score_slow(self, bitmap, contour):
|
def box_score_slow(self, bitmap, contour):
|
||||||
'''
|
"""
|
||||||
box_score_slow: use polyon mean score as the mean score
|
box_score_slow: use polygon mean score as the mean score
|
||||||
'''
|
"""
|
||||||
h, w = bitmap.shape[:2]
|
h, w = bitmap.shape[:2]
|
||||||
contour = contour.copy()
|
contour = contour.copy()
|
||||||
contour = np.reshape(contour, (-1, 2))
|
contour = np.reshape(contour, (-1, 2))
|
||||||
|
|
|
||||||
|
|
@ -25,7 +25,7 @@ services:
|
||||||
# - --no-transport-streamable-http-enabled # Disable Streamable HTTP transport (/mcp endpoint)
|
# - --no-transport-streamable-http-enabled # Disable Streamable HTTP transport (/mcp endpoint)
|
||||||
# - --no-json-response # Disable JSON response mode in Streamable HTTP transport (instead of SSE over HTTP)
|
# - --no-json-response # Disable JSON response mode in Streamable HTTP transport (instead of SSE over HTTP)
|
||||||
|
|
||||||
# Example configration to start Admin server:
|
# Example configuration to start Admin server:
|
||||||
# command:
|
# command:
|
||||||
# - --enable-adminserver
|
# - --enable-adminserver
|
||||||
ports:
|
ports:
|
||||||
|
|
@ -74,7 +74,7 @@ services:
|
||||||
# - --no-transport-streamable-http-enabled # Disable Streamable HTTP transport (/mcp endpoint)
|
# - --no-transport-streamable-http-enabled # Disable Streamable HTTP transport (/mcp endpoint)
|
||||||
# - --no-json-response # Disable JSON response mode in Streamable HTTP transport (instead of SSE over HTTP)
|
# - --no-json-response # Disable JSON response mode in Streamable HTTP transport (instead of SSE over HTTP)
|
||||||
|
|
||||||
# Example configration to start Admin server:
|
# Example configuration to start Admin server:
|
||||||
# command:
|
# command:
|
||||||
# - --enable-adminserver
|
# - --enable-adminserver
|
||||||
ports:
|
ports:
|
||||||
|
|
|
||||||
|
|
@ -151,7 +151,7 @@ See [Build a RAGFlow Docker image](./develop/build_docker_image.mdx).
|
||||||
|
|
||||||
### Cannot access https://huggingface.co
|
### Cannot access https://huggingface.co
|
||||||
|
|
||||||
A locally deployed RAGflow downloads OCR models from [Huggingface website](https://huggingface.co) by default. If your machine is unable to access this site, the following error occurs and PDF parsing fails:
|
A locally deployed RAGFlow downloads OCR models from [Huggingface website](https://huggingface.co) by default. If your machine is unable to access this site, the following error occurs and PDF parsing fails:
|
||||||
|
|
||||||
```
|
```
|
||||||
FileNotFoundError: [Errno 2] No such file or directory: '/root/.cache/huggingface/hub/models--InfiniFlow--deepdoc/snapshots/be0c1e50eef6047b412d1800aa89aba4d275f997/ocr.res'
|
FileNotFoundError: [Errno 2] No such file or directory: '/root/.cache/huggingface/hub/models--InfiniFlow--deepdoc/snapshots/be0c1e50eef6047b412d1800aa89aba4d275f997/ocr.res'
|
||||||
|
|
|
||||||
|
|
@ -45,13 +45,13 @@ Click the light bulb icon above the *current* dialogue and scroll down the popup
|
||||||
|
|
||||||
|
|
||||||
| Item name | Description |
|
| Item name | Description |
|
||||||
| ----------------- | --------------------------------------------------------------------------------------------- |
|
| ----------------- |-----------------------------------------------------------------------------------------------|
|
||||||
| Total | Total time spent on this conversation round, including chunk retrieval and answer generation. |
|
| Total | Total time spent on this conversation round, including chunk retrieval and answer generation. |
|
||||||
| Check LLM | Time to validate the specified LLM. |
|
| Check LLM | Time to validate the specified LLM. |
|
||||||
| Create retriever | Time to create a chunk retriever. |
|
| Create retriever | Time to create a chunk retriever. |
|
||||||
| Bind embedding | Time to initialize an embedding model instance. |
|
| Bind embedding | Time to initialize an embedding model instance. |
|
||||||
| Bind LLM | Time to initialize an LLM instance. |
|
| Bind LLM | Time to initialize an LLM instance. |
|
||||||
| Tune question | Time to optimize the user query using the context of the mult-turn conversation. |
|
| Tune question | Time to optimize the user query using the context of the multi-turn conversation. |
|
||||||
| Bind reranker | Time to initialize an reranker model instance for chunk retrieval. |
|
| Bind reranker | Time to initialize an reranker model instance for chunk retrieval. |
|
||||||
| Generate keywords | Time to extract keywords from the user query. |
|
| Generate keywords | Time to extract keywords from the user query. |
|
||||||
| Retrieval | Time to retrieve the chunks. |
|
| Retrieval | Time to retrieve the chunks. |
|
||||||
|
|
|
||||||
|
|
@ -37,7 +37,7 @@ Please note that rerank models are essential in certain scenarios. There is alwa
|
||||||
| Create retriever | Time to create a chunk retriever. |
|
| Create retriever | Time to create a chunk retriever. |
|
||||||
| Bind embedding | Time to initialize an embedding model instance. |
|
| Bind embedding | Time to initialize an embedding model instance. |
|
||||||
| Bind LLM | Time to initialize an LLM instance. |
|
| Bind LLM | Time to initialize an LLM instance. |
|
||||||
| Tune question | Time to optimize the user query using the context of the mult-turn conversation. |
|
| Tune question | Time to optimize the user query using the context of the multi-turn conversation. |
|
||||||
| Bind reranker | Time to initialize an reranker model instance for chunk retrieval. |
|
| Bind reranker | Time to initialize an reranker model instance for chunk retrieval. |
|
||||||
| Generate keywords | Time to extract keywords from the user query. |
|
| Generate keywords | Time to extract keywords from the user query. |
|
||||||
| Retrieval | Time to retrieve the chunks. |
|
| Retrieval | Time to retrieve the chunks. |
|
||||||
|
|
|
||||||
|
|
@ -8,7 +8,7 @@ slug: /manage_users_and_services
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
The Admin CLI and Admin Service form a client-server architectural suite for RAGflow system administration. The Admin CLI serves as an interactive command-line interface that receives instructions and displays execution results from the Admin Service in real-time. This duo enables real-time monitoring of system operational status, supporting visibility into RAGflow Server services and dependent components including MySQL, Elasticsearch, Redis, and MinIO. In administrator mode, they provide user management capabilities that allow viewing users and performing critical operations—such as user creation, password updates, activation status changes, and comprehensive user data deletion—even when corresponding web interface functionalities are disabled.
|
The Admin CLI and Admin Service form a client-server architectural suite for RAGFlow system administration. The Admin CLI serves as an interactive command-line interface that receives instructions and displays execution results from the Admin Service in real-time. This duo enables real-time monitoring of system operational status, supporting visibility into RAGFlow Server services and dependent components including MySQL, Elasticsearch, Redis, and MinIO. In administrator mode, they provide user management capabilities that allow viewing users and performing critical operations—such as user creation, password updates, activation status changes, and comprehensive user data deletion—even when corresponding web interface functionalities are disabled.
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -305,7 +305,7 @@ With the Ollama service running, open a new terminal and run `./ollama pull <mod
|
||||||
</TabItem>
|
</TabItem>
|
||||||
</Tabs>
|
</Tabs>
|
||||||
|
|
||||||
### 4. Configure RAGflow
|
### 4. Configure RAGFlow
|
||||||
|
|
||||||
To enable IPEX-LLM accelerated Ollama in RAGFlow, you must also complete the configurations in RAGFlow. The steps are identical to those outlined in the *Deploy a local model using Ollama* section:
|
To enable IPEX-LLM accelerated Ollama in RAGFlow, you must also complete the configurations in RAGFlow. The steps are identical to those outlined in the *Deploy a local model using Ollama* section:
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -5,6 +5,7 @@
|
||||||
# requires-python = ">=3.10"
|
# requires-python = ">=3.10"
|
||||||
# dependencies = [
|
# dependencies = [
|
||||||
# "nltk",
|
# "nltk",
|
||||||
|
# "huggingface-hub"
|
||||||
# ]
|
# ]
|
||||||
# ///
|
# ///
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -14,9 +14,9 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
#
|
#
|
||||||
|
|
||||||
'''
|
"""
|
||||||
The example is about CRUD operations (Create, Read, Update, Delete) on a dataset.
|
The example is about CRUD operations (Create, Read, Update, Delete) on a dataset.
|
||||||
'''
|
"""
|
||||||
|
|
||||||
from ragflow_sdk import RAGFlow
|
from ragflow_sdk import RAGFlow
|
||||||
import sys
|
import sys
|
||||||
|
|
|
||||||
|
|
@ -57,7 +57,6 @@ JSON_RESPONSE = True
|
||||||
|
|
||||||
class RAGFlowConnector:
|
class RAGFlowConnector:
|
||||||
_MAX_DATASET_CACHE = 32
|
_MAX_DATASET_CACHE = 32
|
||||||
_MAX_DOCUMENT_CACHE = 128
|
|
||||||
_CACHE_TTL = 300
|
_CACHE_TTL = 300
|
||||||
|
|
||||||
_dataset_metadata_cache: OrderedDict[str, tuple[dict, float | int]] = OrderedDict() # "dataset_id" -> (metadata, expiry_ts)
|
_dataset_metadata_cache: OrderedDict[str, tuple[dict, float | int]] = OrderedDict() # "dataset_id" -> (metadata, expiry_ts)
|
||||||
|
|
@ -116,8 +115,6 @@ class RAGFlowConnector:
|
||||||
def _set_cached_document_metadata_by_dataset(self, dataset_id, doc_id_meta_list):
|
def _set_cached_document_metadata_by_dataset(self, dataset_id, doc_id_meta_list):
|
||||||
self._document_metadata_cache[dataset_id] = (doc_id_meta_list, self._get_expiry_timestamp())
|
self._document_metadata_cache[dataset_id] = (doc_id_meta_list, self._get_expiry_timestamp())
|
||||||
self._document_metadata_cache.move_to_end(dataset_id)
|
self._document_metadata_cache.move_to_end(dataset_id)
|
||||||
if len(self._document_metadata_cache) > self._MAX_DOCUMENT_CACHE:
|
|
||||||
self._document_metadata_cache.popitem(last=False)
|
|
||||||
|
|
||||||
def list_datasets(self, page: int = 1, page_size: int = 1000, orderby: str = "create_time", desc: bool = True, id: str | None = None, name: str | None = None):
|
def list_datasets(self, page: int = 1, page_size: int = 1000, orderby: str = "create_time", desc: bool = True, id: str | None = None, name: str | None = None):
|
||||||
res = self._get("/datasets", {"page": page, "page_size": page_size, "orderby": orderby, "desc": desc, "id": id, "name": name})
|
res = self._get("/datasets", {"page": page, "page_size": page_size, "orderby": orderby, "desc": desc, "id": id, "name": name})
|
||||||
|
|
@ -240,11 +237,14 @@ class RAGFlowConnector:
|
||||||
|
|
||||||
docs = None if force_refresh else self._get_cached_document_metadata_by_dataset(dataset_id)
|
docs = None if force_refresh else self._get_cached_document_metadata_by_dataset(dataset_id)
|
||||||
if docs is None:
|
if docs is None:
|
||||||
docs_res = self._get(f"/datasets/{dataset_id}/documents")
|
page = 1
|
||||||
docs_data = docs_res.json()
|
page_size = 30
|
||||||
if docs_data.get("code") == 0 and docs_data.get("data", {}).get("docs"):
|
|
||||||
doc_id_meta_list = []
|
doc_id_meta_list = []
|
||||||
docs = {}
|
docs = {}
|
||||||
|
while page:
|
||||||
|
docs_res = self._get(f"/datasets/{dataset_id}/documents?page={page}")
|
||||||
|
docs_data = docs_res.json()
|
||||||
|
if docs_data.get("code") == 0 and docs_data.get("data", {}).get("docs"):
|
||||||
for doc in docs_data["data"]["docs"]:
|
for doc in docs_data["data"]["docs"]:
|
||||||
doc_id = doc.get("id")
|
doc_id = doc.get("id")
|
||||||
if not doc_id:
|
if not doc_id:
|
||||||
|
|
@ -256,30 +256,27 @@ class RAGFlowConnector:
|
||||||
"type": doc.get("type", ""),
|
"type": doc.get("type", ""),
|
||||||
"size": doc.get("size"),
|
"size": doc.get("size"),
|
||||||
"chunk_count": doc.get("chunk_count"),
|
"chunk_count": doc.get("chunk_count"),
|
||||||
# "chunk_method": doc.get("chunk_method", ""),
|
|
||||||
"create_date": doc.get("create_date", ""),
|
"create_date": doc.get("create_date", ""),
|
||||||
"update_date": doc.get("update_date", ""),
|
"update_date": doc.get("update_date", ""),
|
||||||
# "process_begin_at": doc.get("process_begin_at", ""),
|
|
||||||
# "process_duration": doc.get("process_duration"),
|
|
||||||
# "progress": doc.get("progress"),
|
|
||||||
# "progress_msg": doc.get("progress_msg", ""),
|
|
||||||
# "status": doc.get("status", ""),
|
|
||||||
# "run": doc.get("run", ""),
|
|
||||||
"token_count": doc.get("token_count"),
|
"token_count": doc.get("token_count"),
|
||||||
# "source_type": doc.get("source_type", ""),
|
|
||||||
"thumbnail": doc.get("thumbnail", ""),
|
"thumbnail": doc.get("thumbnail", ""),
|
||||||
"dataset_id": doc.get("dataset_id", dataset_id),
|
"dataset_id": doc.get("dataset_id", dataset_id),
|
||||||
"meta_fields": doc.get("meta_fields", {}),
|
"meta_fields": doc.get("meta_fields", {}),
|
||||||
# "parser_config": doc.get("parser_config", {})
|
|
||||||
}
|
}
|
||||||
doc_id_meta_list.append((doc_id, doc_meta))
|
doc_id_meta_list.append((doc_id, doc_meta))
|
||||||
docs[doc_id] = doc_meta
|
docs[doc_id] = doc_meta
|
||||||
|
|
||||||
|
page += 1
|
||||||
|
if docs_data.get("data", {}).get("total", 0) - page * page_size <= 0:
|
||||||
|
page = None
|
||||||
|
|
||||||
self._set_cached_document_metadata_by_dataset(dataset_id, doc_id_meta_list)
|
self._set_cached_document_metadata_by_dataset(dataset_id, doc_id_meta_list)
|
||||||
if docs:
|
if docs:
|
||||||
document_cache.update(docs)
|
document_cache.update(docs)
|
||||||
|
|
||||||
except Exception:
|
except Exception as e:
|
||||||
# Gracefully handle metadata cache failures
|
# Gracefully handle metadata cache failures
|
||||||
|
logging.error(f"Problem building the document metadata cache: {str(e)}")
|
||||||
pass
|
pass
|
||||||
|
|
||||||
return document_cache, dataset_cache
|
return document_cache, dataset_cache
|
||||||
|
|
|
||||||
|
|
@ -143,6 +143,7 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
|
||||||
|
|
||||||
elif re.search(r"\.doc$", filename, re.IGNORECASE):
|
elif re.search(r"\.doc$", filename, re.IGNORECASE):
|
||||||
callback(0.1, "Start to parse.")
|
callback(0.1, "Start to parse.")
|
||||||
|
with BytesIO(binary) as binary:
|
||||||
binary = BytesIO(binary)
|
binary = BytesIO(binary)
|
||||||
doc_parsed = parser.from_buffer(binary)
|
doc_parsed = parser.from_buffer(binary)
|
||||||
sections = doc_parsed['content'].split('\n')
|
sections = doc_parsed['content'].split('\n')
|
||||||
|
|
|
||||||
|
|
@ -219,23 +219,27 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _normalize_section(section):
|
def _normalize_section(section):
|
||||||
# pad section to length 3: (txt, sec_id, poss)
|
# Pad/normalize to (txt, layout, positions)
|
||||||
if len(section) == 1:
|
if not isinstance(section, (list, tuple)):
|
||||||
|
section = (section, "", [])
|
||||||
|
elif len(section) == 1:
|
||||||
section = (section[0], "", [])
|
section = (section[0], "", [])
|
||||||
elif len(section) == 2:
|
elif len(section) == 2:
|
||||||
section = (section[0], "", section[1])
|
section = (section[0], "", section[1])
|
||||||
elif len(section) != 3:
|
else:
|
||||||
raise ValueError(f"Unexpected section length: {len(section)} (value={section!r})")
|
section = (section[0], section[1], section[2])
|
||||||
|
|
||||||
txt, layoutno, poss = section
|
txt, layoutno, poss = section
|
||||||
if isinstance(poss, str):
|
if isinstance(poss, str):
|
||||||
poss = pdf_parser.extract_positions(poss)
|
poss = pdf_parser.extract_positions(poss)
|
||||||
|
if poss:
|
||||||
first = poss[0] # tuple: ([pn], x1, x2, y1, y2)
|
first = poss[0] # tuple: ([pn], x1, x2, y1, y2)
|
||||||
pn = first[0]
|
pn = first[0]
|
||||||
|
if isinstance(pn, list) and pn:
|
||||||
if isinstance(pn, list):
|
|
||||||
pn = pn[0] # [pn] -> pn
|
pn = pn[0] # [pn] -> pn
|
||||||
poss[0] = (pn, *first[1:])
|
poss[0] = (pn, *first[1:])
|
||||||
|
if not poss:
|
||||||
|
poss = []
|
||||||
|
|
||||||
return (txt, layoutno, poss)
|
return (txt, layoutno, poss)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -15,9 +15,8 @@
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import random
|
import random
|
||||||
from copy import deepcopy, copy
|
from copy import deepcopy
|
||||||
|
|
||||||
import trio
|
|
||||||
import xxhash
|
import xxhash
|
||||||
|
|
||||||
from agent.component.llm import LLMParam, LLM
|
from agent.component.llm import LLMParam, LLM
|
||||||
|
|
@ -38,13 +37,13 @@ class ExtractorParam(ProcessParamBase, LLMParam):
|
||||||
class Extractor(ProcessBase, LLM):
|
class Extractor(ProcessBase, LLM):
|
||||||
component_name = "Extractor"
|
component_name = "Extractor"
|
||||||
|
|
||||||
def _build_TOC(self, docs):
|
async def _build_TOC(self, docs):
|
||||||
self.callback(message="Start to generate table of content ...")
|
self.callback(0.2,message="Start to generate table of content ...")
|
||||||
docs = sorted(docs, key=lambda d:(
|
docs = sorted(docs, key=lambda d:(
|
||||||
d.get("page_num_int", 0)[0] if isinstance(d.get("page_num_int", 0), list) else d.get("page_num_int", 0),
|
d.get("page_num_int", 0)[0] if isinstance(d.get("page_num_int", 0), list) else d.get("page_num_int", 0),
|
||||||
d.get("top_int", 0)[0] if isinstance(d.get("top_int", 0), list) else d.get("top_int", 0)
|
d.get("top_int", 0)[0] if isinstance(d.get("top_int", 0), list) else d.get("top_int", 0)
|
||||||
))
|
))
|
||||||
toc: list[dict] = trio.run(run_toc_from_text, [d["text"] for d in docs], self.chat_mdl)
|
toc = await run_toc_from_text([d["text"] for d in docs], self.chat_mdl)
|
||||||
logging.info("------------ T O C -------------\n"+json.dumps(toc, ensure_ascii=False, indent=' '))
|
logging.info("------------ T O C -------------\n"+json.dumps(toc, ensure_ascii=False, indent=' '))
|
||||||
ii = 0
|
ii = 0
|
||||||
while ii < len(toc):
|
while ii < len(toc):
|
||||||
|
|
@ -61,7 +60,8 @@ class Extractor(ProcessBase, LLM):
|
||||||
ii += 1
|
ii += 1
|
||||||
|
|
||||||
if toc:
|
if toc:
|
||||||
d = copy.deepcopy(docs[-1])
|
d = deepcopy(docs[-1])
|
||||||
|
d["doc_id"] = self._canvas._doc_id
|
||||||
d["content_with_weight"] = json.dumps(toc, ensure_ascii=False)
|
d["content_with_weight"] = json.dumps(toc, ensure_ascii=False)
|
||||||
d["toc_kwd"] = "toc"
|
d["toc_kwd"] = "toc"
|
||||||
d["available_int"] = 0
|
d["available_int"] = 0
|
||||||
|
|
@ -85,7 +85,10 @@ class Extractor(ProcessBase, LLM):
|
||||||
|
|
||||||
if chunks:
|
if chunks:
|
||||||
if self._param.field_name == "toc":
|
if self._param.field_name == "toc":
|
||||||
toc = self._build_TOC(chunks)
|
for ck in chunks:
|
||||||
|
ck["doc_id"] = self._canvas._doc_id
|
||||||
|
ck["id"] = xxhash.xxh64((ck["text"] + str(ck["doc_id"])).encode("utf-8")).hexdigest()
|
||||||
|
toc =await self._build_TOC(chunks)
|
||||||
chunks.append(toc)
|
chunks.append(toc)
|
||||||
self.set_output("chunks", chunks)
|
self.set_output("chunks", chunks)
|
||||||
return
|
return
|
||||||
|
|
|
||||||
|
|
@ -125,7 +125,7 @@ class Splitter(ProcessBase):
|
||||||
{
|
{
|
||||||
"text": RAGFlowPdfParser.remove_tag(c),
|
"text": RAGFlowPdfParser.remove_tag(c),
|
||||||
"image": img,
|
"image": img,
|
||||||
"positions": [[pos[0][-1]+1, *pos[1:]] for pos in RAGFlowPdfParser.extract_positions(c)]
|
"positions": [[pos[0][-1], *pos[1:]] for pos in RAGFlowPdfParser.extract_positions(c)]
|
||||||
}
|
}
|
||||||
for c, img in zip(chunks, images) if c.strip()
|
for c, img in zip(chunks, images) if c.strip()
|
||||||
]
|
]
|
||||||
|
|
|
||||||
|
|
@ -52,6 +52,8 @@ class SupportedLiteLLMProvider(StrEnum):
|
||||||
JiekouAI = "Jiekou.AI"
|
JiekouAI = "Jiekou.AI"
|
||||||
ZHIPU_AI = "ZHIPU-AI"
|
ZHIPU_AI = "ZHIPU-AI"
|
||||||
MiniMax = "MiniMax"
|
MiniMax = "MiniMax"
|
||||||
|
DeerAPI = "DeerAPI"
|
||||||
|
GPUStack = "GPUStack"
|
||||||
|
|
||||||
|
|
||||||
FACTORY_DEFAULT_BASE_URL = {
|
FACTORY_DEFAULT_BASE_URL = {
|
||||||
|
|
@ -75,6 +77,7 @@ FACTORY_DEFAULT_BASE_URL = {
|
||||||
SupportedLiteLLMProvider.JiekouAI: "https://api.jiekou.ai/openai",
|
SupportedLiteLLMProvider.JiekouAI: "https://api.jiekou.ai/openai",
|
||||||
SupportedLiteLLMProvider.ZHIPU_AI: "https://open.bigmodel.cn/api/paas/v4",
|
SupportedLiteLLMProvider.ZHIPU_AI: "https://open.bigmodel.cn/api/paas/v4",
|
||||||
SupportedLiteLLMProvider.MiniMax: "https://api.minimaxi.com/v1",
|
SupportedLiteLLMProvider.MiniMax: "https://api.minimaxi.com/v1",
|
||||||
|
SupportedLiteLLMProvider.DeerAPI: "https://api.deerapi.com/v1",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -108,6 +111,8 @@ LITELLM_PROVIDER_PREFIX = {
|
||||||
SupportedLiteLLMProvider.JiekouAI: "openai/",
|
SupportedLiteLLMProvider.JiekouAI: "openai/",
|
||||||
SupportedLiteLLMProvider.ZHIPU_AI: "openai/",
|
SupportedLiteLLMProvider.ZHIPU_AI: "openai/",
|
||||||
SupportedLiteLLMProvider.MiniMax: "openai/",
|
SupportedLiteLLMProvider.MiniMax: "openai/",
|
||||||
|
SupportedLiteLLMProvider.DeerAPI: "openai/",
|
||||||
|
SupportedLiteLLMProvider.GPUStack: "openai/",
|
||||||
}
|
}
|
||||||
|
|
||||||
ChatModel = globals().get("ChatModel", {})
|
ChatModel = globals().get("ChatModel", {})
|
||||||
|
|
|
||||||
|
|
@ -19,7 +19,6 @@ import logging
|
||||||
import os
|
import os
|
||||||
import random
|
import random
|
||||||
import re
|
import re
|
||||||
import threading
|
|
||||||
import time
|
import time
|
||||||
from abc import ABC
|
from abc import ABC
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
|
|
@ -78,11 +77,9 @@ class Base(ABC):
|
||||||
self.toolcall_sessions = {}
|
self.toolcall_sessions = {}
|
||||||
|
|
||||||
def _get_delay(self):
|
def _get_delay(self):
|
||||||
"""Calculate retry delay time"""
|
|
||||||
return self.base_delay * random.uniform(10, 150)
|
return self.base_delay * random.uniform(10, 150)
|
||||||
|
|
||||||
def _classify_error(self, error):
|
def _classify_error(self, error):
|
||||||
"""Classify error based on error message content"""
|
|
||||||
error_str = str(error).lower()
|
error_str = str(error).lower()
|
||||||
|
|
||||||
keywords_mapping = [
|
keywords_mapping = [
|
||||||
|
|
@ -139,89 +136,7 @@ class Base(ABC):
|
||||||
|
|
||||||
return gen_conf
|
return gen_conf
|
||||||
|
|
||||||
def _bridge_sync_stream(self, gen):
|
async def _async_chat_streamly(self, history, gen_conf, **kwargs):
|
||||||
"""Run a sync generator in a thread and yield asynchronously."""
|
|
||||||
loop = asyncio.get_running_loop()
|
|
||||||
queue: asyncio.Queue = asyncio.Queue()
|
|
||||||
|
|
||||||
def worker():
|
|
||||||
try:
|
|
||||||
for item in gen:
|
|
||||||
loop.call_soon_threadsafe(queue.put_nowait, item)
|
|
||||||
except Exception as exc: # pragma: no cover - defensive
|
|
||||||
loop.call_soon_threadsafe(queue.put_nowait, exc)
|
|
||||||
finally:
|
|
||||||
loop.call_soon_threadsafe(queue.put_nowait, StopAsyncIteration)
|
|
||||||
|
|
||||||
threading.Thread(target=worker, daemon=True).start()
|
|
||||||
return queue
|
|
||||||
|
|
||||||
def _chat(self, history, gen_conf, **kwargs):
|
|
||||||
logging.info("[HISTORY]" + json.dumps(history, ensure_ascii=False, indent=2))
|
|
||||||
if self.model_name.lower().find("qwq") >= 0:
|
|
||||||
logging.info(f"[INFO] {self.model_name} detected as reasoning model, using _chat_streamly")
|
|
||||||
|
|
||||||
final_ans = ""
|
|
||||||
tol_token = 0
|
|
||||||
for delta, tol in self._chat_streamly(history, gen_conf, with_reasoning=False, **kwargs):
|
|
||||||
if delta.startswith("<think>") or delta.endswith("</think>"):
|
|
||||||
continue
|
|
||||||
final_ans += delta
|
|
||||||
tol_token = tol
|
|
||||||
|
|
||||||
if len(final_ans.strip()) == 0:
|
|
||||||
final_ans = "**ERROR**: Empty response from reasoning model"
|
|
||||||
|
|
||||||
return final_ans.strip(), tol_token
|
|
||||||
|
|
||||||
if self.model_name.lower().find("qwen3") >= 0:
|
|
||||||
kwargs["extra_body"] = {"enable_thinking": False}
|
|
||||||
|
|
||||||
response = self.client.chat.completions.create(model=self.model_name, messages=history, **gen_conf, **kwargs)
|
|
||||||
|
|
||||||
if not response.choices or not response.choices[0].message or not response.choices[0].message.content:
|
|
||||||
return "", 0
|
|
||||||
ans = response.choices[0].message.content.strip()
|
|
||||||
if response.choices[0].finish_reason == "length":
|
|
||||||
ans = self._length_stop(ans)
|
|
||||||
return ans, total_token_count_from_response(response)
|
|
||||||
|
|
||||||
def _chat_streamly(self, history, gen_conf, **kwargs):
|
|
||||||
logging.info("[HISTORY STREAMLY]" + json.dumps(history, ensure_ascii=False, indent=4))
|
|
||||||
reasoning_start = False
|
|
||||||
|
|
||||||
if kwargs.get("stop") or "stop" in gen_conf:
|
|
||||||
response = self.client.chat.completions.create(model=self.model_name, messages=history, stream=True, **gen_conf, stop=kwargs.get("stop"))
|
|
||||||
else:
|
|
||||||
response = self.client.chat.completions.create(model=self.model_name, messages=history, stream=True, **gen_conf)
|
|
||||||
|
|
||||||
for resp in response:
|
|
||||||
if not resp.choices:
|
|
||||||
continue
|
|
||||||
if not resp.choices[0].delta.content:
|
|
||||||
resp.choices[0].delta.content = ""
|
|
||||||
if kwargs.get("with_reasoning", True) and hasattr(resp.choices[0].delta, "reasoning_content") and resp.choices[0].delta.reasoning_content:
|
|
||||||
ans = ""
|
|
||||||
if not reasoning_start:
|
|
||||||
reasoning_start = True
|
|
||||||
ans = "<think>"
|
|
||||||
ans += resp.choices[0].delta.reasoning_content + "</think>"
|
|
||||||
else:
|
|
||||||
reasoning_start = False
|
|
||||||
ans = resp.choices[0].delta.content
|
|
||||||
|
|
||||||
tol = total_token_count_from_response(resp)
|
|
||||||
if not tol:
|
|
||||||
tol = num_tokens_from_string(resp.choices[0].delta.content)
|
|
||||||
|
|
||||||
if resp.choices[0].finish_reason == "length":
|
|
||||||
if is_chinese(ans):
|
|
||||||
ans += LENGTH_NOTIFICATION_CN
|
|
||||||
else:
|
|
||||||
ans += LENGTH_NOTIFICATION_EN
|
|
||||||
yield ans, tol
|
|
||||||
|
|
||||||
async def _async_chat_stream(self, history, gen_conf, **kwargs):
|
|
||||||
logging.info("[HISTORY STREAMLY]" + json.dumps(history, ensure_ascii=False, indent=4))
|
logging.info("[HISTORY STREAMLY]" + json.dumps(history, ensure_ascii=False, indent=4))
|
||||||
reasoning_start = False
|
reasoning_start = False
|
||||||
|
|
||||||
|
|
@ -265,13 +180,19 @@ class Base(ABC):
|
||||||
gen_conf = self._clean_conf(gen_conf)
|
gen_conf = self._clean_conf(gen_conf)
|
||||||
ans = ""
|
ans = ""
|
||||||
total_tokens = 0
|
total_tokens = 0
|
||||||
|
|
||||||
|
for attempt in range(self.max_retries + 1):
|
||||||
try:
|
try:
|
||||||
async for delta_ans, tol in self._async_chat_stream(history, gen_conf, **kwargs):
|
async for delta_ans, tol in self._async_chat_streamly(history, gen_conf, **kwargs):
|
||||||
ans = delta_ans
|
ans = delta_ans
|
||||||
total_tokens += tol
|
total_tokens += tol
|
||||||
yield delta_ans
|
yield ans
|
||||||
except openai.APIError as e:
|
except Exception as e:
|
||||||
yield ans + "\n**ERROR**: " + str(e)
|
e = await self._exceptions_async(e, attempt)
|
||||||
|
if e:
|
||||||
|
yield e
|
||||||
|
yield total_tokens
|
||||||
|
return
|
||||||
|
|
||||||
yield total_tokens
|
yield total_tokens
|
||||||
|
|
||||||
|
|
@ -307,7 +228,7 @@ class Base(ABC):
|
||||||
logging.error(f"sync base giving up: {msg}")
|
logging.error(f"sync base giving up: {msg}")
|
||||||
return msg
|
return msg
|
||||||
|
|
||||||
async def _exceptions_async(self, e, attempt) -> str | None:
|
async def _exceptions_async(self, e, attempt):
|
||||||
logging.exception("OpenAI async completion")
|
logging.exception("OpenAI async completion")
|
||||||
error_code = self._classify_error(e)
|
error_code = self._classify_error(e)
|
||||||
if attempt == self.max_retries:
|
if attempt == self.max_retries:
|
||||||
|
|
@ -357,61 +278,6 @@ class Base(ABC):
|
||||||
self.toolcall_session = toolcall_session
|
self.toolcall_session = toolcall_session
|
||||||
self.tools = tools
|
self.tools = tools
|
||||||
|
|
||||||
def chat_with_tools(self, system: str, history: list, gen_conf: dict = {}):
|
|
||||||
gen_conf = self._clean_conf(gen_conf)
|
|
||||||
if system and history and history[0].get("role") != "system":
|
|
||||||
history.insert(0, {"role": "system", "content": system})
|
|
||||||
|
|
||||||
ans = ""
|
|
||||||
tk_count = 0
|
|
||||||
hist = deepcopy(history)
|
|
||||||
# Implement exponential backoff retry strategy
|
|
||||||
for attempt in range(self.max_retries + 1):
|
|
||||||
history = hist
|
|
||||||
try:
|
|
||||||
for _ in range(self.max_rounds + 1):
|
|
||||||
logging.info(f"{self.tools=}")
|
|
||||||
response = self.client.chat.completions.create(model=self.model_name, messages=history, tools=self.tools, tool_choice="auto", **gen_conf)
|
|
||||||
tk_count += total_token_count_from_response(response)
|
|
||||||
if any([not response.choices, not response.choices[0].message]):
|
|
||||||
raise Exception(f"500 response structure error. Response: {response}")
|
|
||||||
|
|
||||||
if not hasattr(response.choices[0].message, "tool_calls") or not response.choices[0].message.tool_calls:
|
|
||||||
if hasattr(response.choices[0].message, "reasoning_content") and response.choices[0].message.reasoning_content:
|
|
||||||
ans += "<think>" + response.choices[0].message.reasoning_content + "</think>"
|
|
||||||
|
|
||||||
ans += response.choices[0].message.content
|
|
||||||
if response.choices[0].finish_reason == "length":
|
|
||||||
ans = self._length_stop(ans)
|
|
||||||
|
|
||||||
return ans, tk_count
|
|
||||||
|
|
||||||
for tool_call in response.choices[0].message.tool_calls:
|
|
||||||
logging.info(f"Response {tool_call=}")
|
|
||||||
name = tool_call.function.name
|
|
||||||
try:
|
|
||||||
args = json_repair.loads(tool_call.function.arguments)
|
|
||||||
tool_response = self.toolcall_session.tool_call(name, args)
|
|
||||||
history = self._append_history(history, tool_call, tool_response)
|
|
||||||
ans += self._verbose_tool_use(name, args, tool_response)
|
|
||||||
except Exception as e:
|
|
||||||
logging.exception(msg=f"Wrong JSON argument format in LLM tool call response: {tool_call}")
|
|
||||||
history.append({"role": "tool", "tool_call_id": tool_call.id, "content": f"Tool call error: \n{tool_call}\nException:\n" + str(e)})
|
|
||||||
ans += self._verbose_tool_use(name, {}, str(e))
|
|
||||||
|
|
||||||
logging.warning(f"Exceed max rounds: {self.max_rounds}")
|
|
||||||
history.append({"role": "user", "content": f"Exceed max rounds: {self.max_rounds}"})
|
|
||||||
response, token_count = self._chat(history, gen_conf)
|
|
||||||
ans += response
|
|
||||||
tk_count += token_count
|
|
||||||
return ans, tk_count
|
|
||||||
except Exception as e:
|
|
||||||
e = self._exceptions(e, attempt)
|
|
||||||
if e:
|
|
||||||
return e, tk_count
|
|
||||||
|
|
||||||
assert False, "Shouldn't be here."
|
|
||||||
|
|
||||||
async def async_chat_with_tools(self, system: str, history: list, gen_conf: dict = {}):
|
async def async_chat_with_tools(self, system: str, history: list, gen_conf: dict = {}):
|
||||||
gen_conf = self._clean_conf(gen_conf)
|
gen_conf = self._clean_conf(gen_conf)
|
||||||
if system and history and history[0].get("role") != "system":
|
if system and history and history[0].get("role") != "system":
|
||||||
|
|
@ -466,140 +332,6 @@ class Base(ABC):
|
||||||
|
|
||||||
assert False, "Shouldn't be here."
|
assert False, "Shouldn't be here."
|
||||||
|
|
||||||
def chat(self, system, history, gen_conf={}, **kwargs):
|
|
||||||
if system and history and history[0].get("role") != "system":
|
|
||||||
history.insert(0, {"role": "system", "content": system})
|
|
||||||
gen_conf = self._clean_conf(gen_conf)
|
|
||||||
|
|
||||||
# Implement exponential backoff retry strategy
|
|
||||||
for attempt in range(self.max_retries + 1):
|
|
||||||
try:
|
|
||||||
return self._chat(history, gen_conf, **kwargs)
|
|
||||||
except Exception as e:
|
|
||||||
e = self._exceptions(e, attempt)
|
|
||||||
if e:
|
|
||||||
return e, 0
|
|
||||||
assert False, "Shouldn't be here."
|
|
||||||
|
|
||||||
def _wrap_toolcall_message(self, stream):
|
|
||||||
final_tool_calls = {}
|
|
||||||
|
|
||||||
for chunk in stream:
|
|
||||||
for tool_call in chunk.choices[0].delta.tool_calls or []:
|
|
||||||
index = tool_call.index
|
|
||||||
|
|
||||||
if index not in final_tool_calls:
|
|
||||||
final_tool_calls[index] = tool_call
|
|
||||||
|
|
||||||
final_tool_calls[index].function.arguments += tool_call.function.arguments
|
|
||||||
|
|
||||||
return final_tool_calls
|
|
||||||
|
|
||||||
def chat_streamly_with_tools(self, system: str, history: list, gen_conf: dict = {}):
|
|
||||||
gen_conf = self._clean_conf(gen_conf)
|
|
||||||
tools = self.tools
|
|
||||||
if system and history and history[0].get("role") != "system":
|
|
||||||
history.insert(0, {"role": "system", "content": system})
|
|
||||||
|
|
||||||
total_tokens = 0
|
|
||||||
hist = deepcopy(history)
|
|
||||||
# Implement exponential backoff retry strategy
|
|
||||||
for attempt in range(self.max_retries + 1):
|
|
||||||
history = hist
|
|
||||||
try:
|
|
||||||
for _ in range(self.max_rounds + 1):
|
|
||||||
reasoning_start = False
|
|
||||||
logging.info(f"{tools=}")
|
|
||||||
response = self.client.chat.completions.create(model=self.model_name, messages=history, stream=True, tools=tools, tool_choice="auto", **gen_conf)
|
|
||||||
final_tool_calls = {}
|
|
||||||
answer = ""
|
|
||||||
for resp in response:
|
|
||||||
if resp.choices[0].delta.tool_calls:
|
|
||||||
for tool_call in resp.choices[0].delta.tool_calls or []:
|
|
||||||
index = tool_call.index
|
|
||||||
|
|
||||||
if index not in final_tool_calls:
|
|
||||||
if not tool_call.function.arguments:
|
|
||||||
tool_call.function.arguments = ""
|
|
||||||
final_tool_calls[index] = tool_call
|
|
||||||
else:
|
|
||||||
final_tool_calls[index].function.arguments += tool_call.function.arguments if tool_call.function.arguments else ""
|
|
||||||
continue
|
|
||||||
|
|
||||||
if any([not resp.choices, not resp.choices[0].delta, not hasattr(resp.choices[0].delta, "content")]):
|
|
||||||
raise Exception("500 response structure error.")
|
|
||||||
|
|
||||||
if not resp.choices[0].delta.content:
|
|
||||||
resp.choices[0].delta.content = ""
|
|
||||||
|
|
||||||
if hasattr(resp.choices[0].delta, "reasoning_content") and resp.choices[0].delta.reasoning_content:
|
|
||||||
ans = ""
|
|
||||||
if not reasoning_start:
|
|
||||||
reasoning_start = True
|
|
||||||
ans = "<think>"
|
|
||||||
ans += resp.choices[0].delta.reasoning_content + "</think>"
|
|
||||||
yield ans
|
|
||||||
else:
|
|
||||||
reasoning_start = False
|
|
||||||
answer += resp.choices[0].delta.content
|
|
||||||
yield resp.choices[0].delta.content
|
|
||||||
|
|
||||||
tol = total_token_count_from_response(resp)
|
|
||||||
if not tol:
|
|
||||||
total_tokens += num_tokens_from_string(resp.choices[0].delta.content)
|
|
||||||
else:
|
|
||||||
total_tokens = tol
|
|
||||||
|
|
||||||
finish_reason = resp.choices[0].finish_reason if hasattr(resp.choices[0], "finish_reason") else ""
|
|
||||||
if finish_reason == "length":
|
|
||||||
yield self._length_stop("")
|
|
||||||
|
|
||||||
if answer:
|
|
||||||
yield total_tokens
|
|
||||||
return
|
|
||||||
|
|
||||||
for tool_call in final_tool_calls.values():
|
|
||||||
name = tool_call.function.name
|
|
||||||
try:
|
|
||||||
args = json_repair.loads(tool_call.function.arguments)
|
|
||||||
yield self._verbose_tool_use(name, args, "Begin to call...")
|
|
||||||
tool_response = self.toolcall_session.tool_call(name, args)
|
|
||||||
history = self._append_history(history, tool_call, tool_response)
|
|
||||||
yield self._verbose_tool_use(name, args, tool_response)
|
|
||||||
except Exception as e:
|
|
||||||
logging.exception(msg=f"Wrong JSON argument format in LLM tool call response: {tool_call}")
|
|
||||||
history.append({"role": "tool", "tool_call_id": tool_call.id, "content": f"Tool call error: \n{tool_call}\nException:\n" + str(e)})
|
|
||||||
yield self._verbose_tool_use(name, {}, str(e))
|
|
||||||
|
|
||||||
logging.warning(f"Exceed max rounds: {self.max_rounds}")
|
|
||||||
history.append({"role": "user", "content": f"Exceed max rounds: {self.max_rounds}"})
|
|
||||||
response = self.client.chat.completions.create(model=self.model_name, messages=history, stream=True, **gen_conf)
|
|
||||||
for resp in response:
|
|
||||||
if any([not resp.choices, not resp.choices[0].delta, not hasattr(resp.choices[0].delta, "content")]):
|
|
||||||
raise Exception("500 response structure error.")
|
|
||||||
if not resp.choices[0].delta.content:
|
|
||||||
resp.choices[0].delta.content = ""
|
|
||||||
continue
|
|
||||||
tol = total_token_count_from_response(resp)
|
|
||||||
if not tol:
|
|
||||||
total_tokens += num_tokens_from_string(resp.choices[0].delta.content)
|
|
||||||
else:
|
|
||||||
total_tokens = tol
|
|
||||||
answer += resp.choices[0].delta.content
|
|
||||||
yield resp.choices[0].delta.content
|
|
||||||
|
|
||||||
yield total_tokens
|
|
||||||
return
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
e = self._exceptions(e, attempt)
|
|
||||||
if e:
|
|
||||||
yield e
|
|
||||||
yield total_tokens
|
|
||||||
return
|
|
||||||
|
|
||||||
assert False, "Shouldn't be here."
|
|
||||||
|
|
||||||
async def async_chat_streamly_with_tools(self, system: str, history: list, gen_conf: dict = {}):
|
async def async_chat_streamly_with_tools(self, system: str, history: list, gen_conf: dict = {}):
|
||||||
gen_conf = self._clean_conf(gen_conf)
|
gen_conf = self._clean_conf(gen_conf)
|
||||||
tools = self.tools
|
tools = self.tools
|
||||||
|
|
@ -715,9 +447,10 @@ class Base(ABC):
|
||||||
logging.info("[HISTORY]" + json.dumps(history, ensure_ascii=False, indent=2))
|
logging.info("[HISTORY]" + json.dumps(history, ensure_ascii=False, indent=2))
|
||||||
if self.model_name.lower().find("qwq") >= 0:
|
if self.model_name.lower().find("qwq") >= 0:
|
||||||
logging.info(f"[INFO] {self.model_name} detected as reasoning model, using async_chat_streamly")
|
logging.info(f"[INFO] {self.model_name} detected as reasoning model, using async_chat_streamly")
|
||||||
|
|
||||||
final_ans = ""
|
final_ans = ""
|
||||||
tol_token = 0
|
tol_token = 0
|
||||||
async for delta, tol in self._async_chat_stream(history, gen_conf, with_reasoning=False, **kwargs):
|
async for delta, tol in self._async_chat_streamly(history, gen_conf, with_reasoning=False, **kwargs):
|
||||||
if delta.startswith("<think>") or delta.endswith("</think>"):
|
if delta.startswith("<think>") or delta.endswith("</think>"):
|
||||||
continue
|
continue
|
||||||
final_ans += delta
|
final_ans += delta
|
||||||
|
|
@ -754,57 +487,6 @@ class Base(ABC):
|
||||||
return e, 0
|
return e, 0
|
||||||
assert False, "Shouldn't be here."
|
assert False, "Shouldn't be here."
|
||||||
|
|
||||||
def chat_streamly(self, system, history, gen_conf: dict = {}, **kwargs):
|
|
||||||
if system and history and history[0].get("role") != "system":
|
|
||||||
history.insert(0, {"role": "system", "content": system})
|
|
||||||
gen_conf = self._clean_conf(gen_conf)
|
|
||||||
ans = ""
|
|
||||||
total_tokens = 0
|
|
||||||
try:
|
|
||||||
for delta_ans, tol in self._chat_streamly(history, gen_conf, **kwargs):
|
|
||||||
yield delta_ans
|
|
||||||
total_tokens += tol
|
|
||||||
except openai.APIError as e:
|
|
||||||
yield ans + "\n**ERROR**: " + str(e)
|
|
||||||
|
|
||||||
yield total_tokens
|
|
||||||
|
|
||||||
def _calculate_dynamic_ctx(self, history):
|
|
||||||
"""Calculate dynamic context window size"""
|
|
||||||
|
|
||||||
def count_tokens(text):
|
|
||||||
"""Calculate token count for text"""
|
|
||||||
# Simple calculation: 1 token per ASCII character
|
|
||||||
# 2 tokens for non-ASCII characters (Chinese, Japanese, Korean, etc.)
|
|
||||||
total = 0
|
|
||||||
for char in text:
|
|
||||||
if ord(char) < 128: # ASCII characters
|
|
||||||
total += 1
|
|
||||||
else: # Non-ASCII characters (Chinese, Japanese, Korean, etc.)
|
|
||||||
total += 2
|
|
||||||
return total
|
|
||||||
|
|
||||||
# Calculate total tokens for all messages
|
|
||||||
total_tokens = 0
|
|
||||||
for message in history:
|
|
||||||
content = message.get("content", "")
|
|
||||||
# Calculate content tokens
|
|
||||||
content_tokens = count_tokens(content)
|
|
||||||
# Add role marker token overhead
|
|
||||||
role_tokens = 4
|
|
||||||
total_tokens += content_tokens + role_tokens
|
|
||||||
|
|
||||||
# Apply 1.2x buffer ratio
|
|
||||||
total_tokens_with_buffer = int(total_tokens * 1.2)
|
|
||||||
|
|
||||||
if total_tokens_with_buffer <= 8192:
|
|
||||||
ctx_size = 8192
|
|
||||||
else:
|
|
||||||
ctx_multiplier = (total_tokens_with_buffer // 8192) + 1
|
|
||||||
ctx_size = ctx_multiplier * 8192
|
|
||||||
|
|
||||||
return ctx_size
|
|
||||||
|
|
||||||
|
|
||||||
class GptTurbo(Base):
|
class GptTurbo(Base):
|
||||||
_FACTORY_NAME = "OpenAI"
|
_FACTORY_NAME = "OpenAI"
|
||||||
|
|
@ -1504,16 +1186,6 @@ class GoogleChat(Base):
|
||||||
yield total_tokens
|
yield total_tokens
|
||||||
|
|
||||||
|
|
||||||
class GPUStackChat(Base):
|
|
||||||
_FACTORY_NAME = "GPUStack"
|
|
||||||
|
|
||||||
def __init__(self, key=None, model_name="", base_url="", **kwargs):
|
|
||||||
if not base_url:
|
|
||||||
raise ValueError("Local llm url cannot be None")
|
|
||||||
base_url = urljoin(base_url, "v1")
|
|
||||||
super().__init__(key, model_name, base_url, **kwargs)
|
|
||||||
|
|
||||||
|
|
||||||
class TokenPonyChat(Base):
|
class TokenPonyChat(Base):
|
||||||
_FACTORY_NAME = "TokenPony"
|
_FACTORY_NAME = "TokenPony"
|
||||||
|
|
||||||
|
|
@ -1523,15 +1195,6 @@ class TokenPonyChat(Base):
|
||||||
super().__init__(key, model_name, base_url, **kwargs)
|
super().__init__(key, model_name, base_url, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
class DeerAPIChat(Base):
|
|
||||||
_FACTORY_NAME = "DeerAPI"
|
|
||||||
|
|
||||||
def __init__(self, key, model_name, base_url="https://api.deerapi.com/v1", **kwargs):
|
|
||||||
if not base_url:
|
|
||||||
base_url = "https://api.deerapi.com/v1"
|
|
||||||
super().__init__(key, model_name, base_url, **kwargs)
|
|
||||||
|
|
||||||
|
|
||||||
class LiteLLMBase(ABC):
|
class LiteLLMBase(ABC):
|
||||||
_FACTORY_NAME = [
|
_FACTORY_NAME = [
|
||||||
"Tongyi-Qianwen",
|
"Tongyi-Qianwen",
|
||||||
|
|
@ -1562,6 +1225,8 @@ class LiteLLMBase(ABC):
|
||||||
"Jiekou.AI",
|
"Jiekou.AI",
|
||||||
"ZHIPU-AI",
|
"ZHIPU-AI",
|
||||||
"MiniMax",
|
"MiniMax",
|
||||||
|
"DeerAPI",
|
||||||
|
"GPUStack",
|
||||||
]
|
]
|
||||||
|
|
||||||
def __init__(self, key, model_name, base_url=None, **kwargs):
|
def __init__(self, key, model_name, base_url=None, **kwargs):
|
||||||
|
|
@ -1589,11 +1254,9 @@ class LiteLLMBase(ABC):
|
||||||
self.provider_order = json.loads(key).get("provider_order", "")
|
self.provider_order = json.loads(key).get("provider_order", "")
|
||||||
|
|
||||||
def _get_delay(self):
|
def _get_delay(self):
|
||||||
"""Calculate retry delay time"""
|
|
||||||
return self.base_delay * random.uniform(10, 150)
|
return self.base_delay * random.uniform(10, 150)
|
||||||
|
|
||||||
def _classify_error(self, error):
|
def _classify_error(self, error):
|
||||||
"""Classify error based on error message content"""
|
|
||||||
error_str = str(error).lower()
|
error_str = str(error).lower()
|
||||||
|
|
||||||
keywords_mapping = [
|
keywords_mapping = [
|
||||||
|
|
@ -1619,72 +1282,6 @@ class LiteLLMBase(ABC):
|
||||||
del gen_conf["max_tokens"]
|
del gen_conf["max_tokens"]
|
||||||
return gen_conf
|
return gen_conf
|
||||||
|
|
||||||
def _chat(self, history, gen_conf, **kwargs):
|
|
||||||
logging.info("[HISTORY]" + json.dumps(history, ensure_ascii=False, indent=2))
|
|
||||||
if self.model_name.lower().find("qwen3") >= 0:
|
|
||||||
kwargs["extra_body"] = {"enable_thinking": False}
|
|
||||||
|
|
||||||
completion_args = self._construct_completion_args(history=history, stream=False, tools=False, **gen_conf)
|
|
||||||
response = litellm.completion(
|
|
||||||
**completion_args,
|
|
||||||
drop_params=True,
|
|
||||||
timeout=self.timeout,
|
|
||||||
)
|
|
||||||
# response = self.client.chat.completions.create(model=self.model_name, messages=history, **gen_conf, **kwargs)
|
|
||||||
if any([not response.choices, not response.choices[0].message, not response.choices[0].message.content]):
|
|
||||||
return "", 0
|
|
||||||
ans = response.choices[0].message.content.strip()
|
|
||||||
if response.choices[0].finish_reason == "length":
|
|
||||||
ans = self._length_stop(ans)
|
|
||||||
|
|
||||||
return ans, total_token_count_from_response(response)
|
|
||||||
|
|
||||||
def _chat_streamly(self, history, gen_conf, **kwargs):
|
|
||||||
logging.info("[HISTORY STREAMLY]" + json.dumps(history, ensure_ascii=False, indent=4))
|
|
||||||
gen_conf = self._clean_conf(gen_conf)
|
|
||||||
reasoning_start = False
|
|
||||||
|
|
||||||
completion_args = self._construct_completion_args(history=history, stream=True, tools=False, **gen_conf)
|
|
||||||
stop = kwargs.get("stop")
|
|
||||||
if stop:
|
|
||||||
completion_args["stop"] = stop
|
|
||||||
response = litellm.completion(
|
|
||||||
**completion_args,
|
|
||||||
drop_params=True,
|
|
||||||
timeout=self.timeout,
|
|
||||||
)
|
|
||||||
|
|
||||||
for resp in response:
|
|
||||||
if not hasattr(resp, "choices") or not resp.choices:
|
|
||||||
continue
|
|
||||||
|
|
||||||
delta = resp.choices[0].delta
|
|
||||||
if not hasattr(delta, "content") or delta.content is None:
|
|
||||||
delta.content = ""
|
|
||||||
|
|
||||||
if kwargs.get("with_reasoning", True) and hasattr(delta, "reasoning_content") and delta.reasoning_content:
|
|
||||||
ans = ""
|
|
||||||
if not reasoning_start:
|
|
||||||
reasoning_start = True
|
|
||||||
ans = "<think>"
|
|
||||||
ans += delta.reasoning_content + "</think>"
|
|
||||||
else:
|
|
||||||
reasoning_start = False
|
|
||||||
ans = delta.content
|
|
||||||
|
|
||||||
tol = total_token_count_from_response(resp)
|
|
||||||
if not tol:
|
|
||||||
tol = num_tokens_from_string(delta.content)
|
|
||||||
|
|
||||||
finish_reason = resp.choices[0].finish_reason if hasattr(resp.choices[0], "finish_reason") else ""
|
|
||||||
if finish_reason == "length":
|
|
||||||
if is_chinese(ans):
|
|
||||||
ans += LENGTH_NOTIFICATION_CN
|
|
||||||
else:
|
|
||||||
ans += LENGTH_NOTIFICATION_EN
|
|
||||||
|
|
||||||
yield ans, tol
|
|
||||||
|
|
||||||
async def async_chat(self, system, history, gen_conf, **kwargs):
|
async def async_chat(self, system, history, gen_conf, **kwargs):
|
||||||
hist = list(history) if history else []
|
hist = list(history) if history else []
|
||||||
if system:
|
if system:
|
||||||
|
|
@ -1795,22 +1392,7 @@ class LiteLLMBase(ABC):
|
||||||
def _should_retry(self, error_code: str) -> bool:
|
def _should_retry(self, error_code: str) -> bool:
|
||||||
return error_code in self._retryable_errors
|
return error_code in self._retryable_errors
|
||||||
|
|
||||||
def _exceptions(self, e, attempt) -> str | None:
|
async def _exceptions_async(self, e, attempt):
|
||||||
logging.exception("OpenAI chat_with_tools")
|
|
||||||
# Classify the error
|
|
||||||
error_code = self._classify_error(e)
|
|
||||||
if attempt == self.max_retries:
|
|
||||||
error_code = LLMErrorCode.ERROR_MAX_RETRIES
|
|
||||||
|
|
||||||
if self._should_retry(error_code):
|
|
||||||
delay = self._get_delay()
|
|
||||||
logging.warning(f"Error: {error_code}. Retrying in {delay:.2f} seconds... (Attempt {attempt + 1}/{self.max_retries})")
|
|
||||||
time.sleep(delay)
|
|
||||||
return None
|
|
||||||
|
|
||||||
return f"{ERROR_PREFIX}: {error_code} - {str(e)}"
|
|
||||||
|
|
||||||
async def _exceptions_async(self, e, attempt) -> str | None:
|
|
||||||
logging.exception("LiteLLMBase async completion")
|
logging.exception("LiteLLMBase async completion")
|
||||||
error_code = self._classify_error(e)
|
error_code = self._classify_error(e)
|
||||||
if attempt == self.max_retries:
|
if attempt == self.max_retries:
|
||||||
|
|
@ -1859,71 +1441,7 @@ class LiteLLMBase(ABC):
|
||||||
self.toolcall_session = toolcall_session
|
self.toolcall_session = toolcall_session
|
||||||
self.tools = tools
|
self.tools = tools
|
||||||
|
|
||||||
def _construct_completion_args(self, history, stream: bool, tools: bool, **kwargs):
|
async def async_chat_with_tools(self, system: str, history: list, gen_conf: dict = {}):
|
||||||
completion_args = {
|
|
||||||
"model": self.model_name,
|
|
||||||
"messages": history,
|
|
||||||
"api_key": self.api_key,
|
|
||||||
"num_retries": self.max_retries,
|
|
||||||
**kwargs,
|
|
||||||
}
|
|
||||||
if stream:
|
|
||||||
completion_args.update(
|
|
||||||
{
|
|
||||||
"stream": stream,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
if tools and self.tools:
|
|
||||||
completion_args.update(
|
|
||||||
{
|
|
||||||
"tools": self.tools,
|
|
||||||
"tool_choice": "auto",
|
|
||||||
}
|
|
||||||
)
|
|
||||||
if self.provider in FACTORY_DEFAULT_BASE_URL:
|
|
||||||
completion_args.update({"api_base": self.base_url})
|
|
||||||
elif self.provider == SupportedLiteLLMProvider.Bedrock:
|
|
||||||
completion_args.pop("api_key", None)
|
|
||||||
completion_args.pop("api_base", None)
|
|
||||||
completion_args.update(
|
|
||||||
{
|
|
||||||
"aws_access_key_id": self.bedrock_ak,
|
|
||||||
"aws_secret_access_key": self.bedrock_sk,
|
|
||||||
"aws_region_name": self.bedrock_region,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
if self.provider == SupportedLiteLLMProvider.OpenRouter:
|
|
||||||
if self.provider_order:
|
|
||||||
|
|
||||||
def _to_order_list(x):
|
|
||||||
if x is None:
|
|
||||||
return []
|
|
||||||
if isinstance(x, str):
|
|
||||||
return [s.strip() for s in x.split(",") if s.strip()]
|
|
||||||
if isinstance(x, (list, tuple)):
|
|
||||||
return [str(s).strip() for s in x if str(s).strip()]
|
|
||||||
return []
|
|
||||||
|
|
||||||
extra_body = {}
|
|
||||||
provider_cfg = {}
|
|
||||||
provider_order = _to_order_list(self.provider_order)
|
|
||||||
provider_cfg["order"] = provider_order
|
|
||||||
provider_cfg["allow_fallbacks"] = False
|
|
||||||
extra_body["provider"] = provider_cfg
|
|
||||||
completion_args.update({"extra_body": extra_body})
|
|
||||||
|
|
||||||
# Ollama deployments commonly sit behind a reverse proxy that enforces
|
|
||||||
# Bearer auth. Ensure the Authorization header is set when an API key
|
|
||||||
# is provided, while respecting any user-supplied headers. #11350
|
|
||||||
extra_headers = deepcopy(completion_args.get("extra_headers") or {})
|
|
||||||
if self.provider == SupportedLiteLLMProvider.Ollama and self.api_key and "Authorization" not in extra_headers:
|
|
||||||
extra_headers["Authorization"] = f"Bearer {self.api_key}"
|
|
||||||
if extra_headers:
|
|
||||||
completion_args["extra_headers"] = extra_headers
|
|
||||||
return completion_args
|
|
||||||
|
|
||||||
def chat_with_tools(self, system: str, history: list, gen_conf: dict = {}):
|
|
||||||
gen_conf = self._clean_conf(gen_conf)
|
gen_conf = self._clean_conf(gen_conf)
|
||||||
if system and history and history[0].get("role") != "system":
|
if system and history and history[0].get("role") != "system":
|
||||||
history.insert(0, {"role": "system", "content": system})
|
history.insert(0, {"role": "system", "content": system})
|
||||||
|
|
@ -1931,16 +1449,14 @@ class LiteLLMBase(ABC):
|
||||||
ans = ""
|
ans = ""
|
||||||
tk_count = 0
|
tk_count = 0
|
||||||
hist = deepcopy(history)
|
hist = deepcopy(history)
|
||||||
|
|
||||||
# Implement exponential backoff retry strategy
|
|
||||||
for attempt in range(self.max_retries + 1):
|
for attempt in range(self.max_retries + 1):
|
||||||
history = deepcopy(hist) # deepcopy is required here
|
history = deepcopy(hist)
|
||||||
try:
|
try:
|
||||||
for _ in range(self.max_rounds + 1):
|
for _ in range(self.max_rounds + 1):
|
||||||
logging.info(f"{self.tools=}")
|
logging.info(f"{self.tools=}")
|
||||||
|
|
||||||
completion_args = self._construct_completion_args(history=history, stream=False, tools=True, **gen_conf)
|
completion_args = self._construct_completion_args(history=history, stream=False, tools=True, **gen_conf)
|
||||||
response = litellm.completion(
|
response = await litellm.acompletion(
|
||||||
**completion_args,
|
**completion_args,
|
||||||
drop_params=True,
|
drop_params=True,
|
||||||
timeout=self.timeout,
|
timeout=self.timeout,
|
||||||
|
|
@ -1966,7 +1482,7 @@ class LiteLLMBase(ABC):
|
||||||
name = tool_call.function.name
|
name = tool_call.function.name
|
||||||
try:
|
try:
|
||||||
args = json_repair.loads(tool_call.function.arguments)
|
args = json_repair.loads(tool_call.function.arguments)
|
||||||
tool_response = self.toolcall_session.tool_call(name, args)
|
tool_response = await asyncio.to_thread(self.toolcall_session.tool_call, name, args)
|
||||||
history = self._append_history(history, tool_call, tool_response)
|
history = self._append_history(history, tool_call, tool_response)
|
||||||
ans += self._verbose_tool_use(name, args, tool_response)
|
ans += self._verbose_tool_use(name, args, tool_response)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|
@ -1977,49 +1493,19 @@ class LiteLLMBase(ABC):
|
||||||
logging.warning(f"Exceed max rounds: {self.max_rounds}")
|
logging.warning(f"Exceed max rounds: {self.max_rounds}")
|
||||||
history.append({"role": "user", "content": f"Exceed max rounds: {self.max_rounds}"})
|
history.append({"role": "user", "content": f"Exceed max rounds: {self.max_rounds}"})
|
||||||
|
|
||||||
response, token_count = self._chat(history, gen_conf)
|
response, token_count = await self.async_chat("", history, gen_conf)
|
||||||
ans += response
|
ans += response
|
||||||
tk_count += token_count
|
tk_count += token_count
|
||||||
return ans, tk_count
|
return ans, tk_count
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
e = self._exceptions(e, attempt)
|
e = await self._exceptions_async(e, attempt)
|
||||||
if e:
|
if e:
|
||||||
return e, tk_count
|
return e, tk_count
|
||||||
|
|
||||||
assert False, "Shouldn't be here."
|
assert False, "Shouldn't be here."
|
||||||
|
|
||||||
def chat(self, system, history, gen_conf={}, **kwargs):
|
async def async_chat_streamly_with_tools(self, system: str, history: list, gen_conf: dict = {}):
|
||||||
if system and history and history[0].get("role") != "system":
|
|
||||||
history.insert(0, {"role": "system", "content": system})
|
|
||||||
gen_conf = self._clean_conf(gen_conf)
|
|
||||||
|
|
||||||
# Implement exponential backoff retry strategy
|
|
||||||
for attempt in range(self.max_retries + 1):
|
|
||||||
try:
|
|
||||||
response = self._chat(history, gen_conf, **kwargs)
|
|
||||||
return response
|
|
||||||
except Exception as e:
|
|
||||||
e = self._exceptions(e, attempt)
|
|
||||||
if e:
|
|
||||||
return e, 0
|
|
||||||
assert False, "Shouldn't be here."
|
|
||||||
|
|
||||||
def _wrap_toolcall_message(self, stream):
|
|
||||||
final_tool_calls = {}
|
|
||||||
|
|
||||||
for chunk in stream:
|
|
||||||
for tool_call in chunk.choices[0].delta.tool_calls or []:
|
|
||||||
index = tool_call.index
|
|
||||||
|
|
||||||
if index not in final_tool_calls:
|
|
||||||
final_tool_calls[index] = tool_call
|
|
||||||
|
|
||||||
final_tool_calls[index].function.arguments += tool_call.function.arguments
|
|
||||||
|
|
||||||
return final_tool_calls
|
|
||||||
|
|
||||||
def chat_streamly_with_tools(self, system: str, history: list, gen_conf: dict = {}):
|
|
||||||
gen_conf = self._clean_conf(gen_conf)
|
gen_conf = self._clean_conf(gen_conf)
|
||||||
tools = self.tools
|
tools = self.tools
|
||||||
if system and history and history[0].get("role") != "system":
|
if system and history and history[0].get("role") != "system":
|
||||||
|
|
@ -2028,16 +1514,15 @@ class LiteLLMBase(ABC):
|
||||||
total_tokens = 0
|
total_tokens = 0
|
||||||
hist = deepcopy(history)
|
hist = deepcopy(history)
|
||||||
|
|
||||||
# Implement exponential backoff retry strategy
|
|
||||||
for attempt in range(self.max_retries + 1):
|
for attempt in range(self.max_retries + 1):
|
||||||
history = deepcopy(hist) # deepcopy is required here
|
history = deepcopy(hist)
|
||||||
try:
|
try:
|
||||||
for _ in range(self.max_rounds + 1):
|
for _ in range(self.max_rounds + 1):
|
||||||
reasoning_start = False
|
reasoning_start = False
|
||||||
logging.info(f"{tools=}")
|
logging.info(f"{tools=}")
|
||||||
|
|
||||||
completion_args = self._construct_completion_args(history=history, stream=True, tools=True, **gen_conf)
|
completion_args = self._construct_completion_args(history=history, stream=True, tools=True, **gen_conf)
|
||||||
response = litellm.completion(
|
response = await litellm.acompletion(
|
||||||
**completion_args,
|
**completion_args,
|
||||||
drop_params=True,
|
drop_params=True,
|
||||||
timeout=self.timeout,
|
timeout=self.timeout,
|
||||||
|
|
@ -2046,7 +1531,7 @@ class LiteLLMBase(ABC):
|
||||||
final_tool_calls = {}
|
final_tool_calls = {}
|
||||||
answer = ""
|
answer = ""
|
||||||
|
|
||||||
for resp in response:
|
async for resp in response:
|
||||||
if not hasattr(resp, "choices") or not resp.choices:
|
if not hasattr(resp, "choices") or not resp.choices:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
|
@ -2082,7 +1567,7 @@ class LiteLLMBase(ABC):
|
||||||
if not tol:
|
if not tol:
|
||||||
total_tokens += num_tokens_from_string(delta.content)
|
total_tokens += num_tokens_from_string(delta.content)
|
||||||
else:
|
else:
|
||||||
total_tokens += tol
|
total_tokens = tol
|
||||||
|
|
||||||
finish_reason = getattr(resp.choices[0], "finish_reason", "")
|
finish_reason = getattr(resp.choices[0], "finish_reason", "")
|
||||||
if finish_reason == "length":
|
if finish_reason == "length":
|
||||||
|
|
@ -2097,31 +1582,25 @@ class LiteLLMBase(ABC):
|
||||||
try:
|
try:
|
||||||
args = json_repair.loads(tool_call.function.arguments)
|
args = json_repair.loads(tool_call.function.arguments)
|
||||||
yield self._verbose_tool_use(name, args, "Begin to call...")
|
yield self._verbose_tool_use(name, args, "Begin to call...")
|
||||||
tool_response = self.toolcall_session.tool_call(name, args)
|
tool_response = await asyncio.to_thread(self.toolcall_session.tool_call, name, args)
|
||||||
history = self._append_history(history, tool_call, tool_response)
|
history = self._append_history(history, tool_call, tool_response)
|
||||||
yield self._verbose_tool_use(name, args, tool_response)
|
yield self._verbose_tool_use(name, args, tool_response)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.exception(msg=f"Wrong JSON argument format in LLM tool call response: {tool_call}")
|
logging.exception(msg=f"Wrong JSON argument format in LLM tool call response: {tool_call}")
|
||||||
history.append(
|
history.append({"role": "tool", "tool_call_id": tool_call.id, "content": f"Tool call error: \n{tool_call}\nException:\n" + str(e)})
|
||||||
{
|
|
||||||
"role": "tool",
|
|
||||||
"tool_call_id": tool_call.id,
|
|
||||||
"content": f"Tool call error: \n{tool_call}\nException:\n{str(e)}",
|
|
||||||
}
|
|
||||||
)
|
|
||||||
yield self._verbose_tool_use(name, {}, str(e))
|
yield self._verbose_tool_use(name, {}, str(e))
|
||||||
|
|
||||||
logging.warning(f"Exceed max rounds: {self.max_rounds}")
|
logging.warning(f"Exceed max rounds: {self.max_rounds}")
|
||||||
history.append({"role": "user", "content": f"Exceed max rounds: {self.max_rounds}"})
|
history.append({"role": "user", "content": f"Exceed max rounds: {self.max_rounds}"})
|
||||||
|
|
||||||
completion_args = self._construct_completion_args(history=history, stream=True, tools=True, **gen_conf)
|
completion_args = self._construct_completion_args(history=history, stream=True, tools=True, **gen_conf)
|
||||||
response = litellm.completion(
|
response = await litellm.acompletion(
|
||||||
**completion_args,
|
**completion_args,
|
||||||
drop_params=True,
|
drop_params=True,
|
||||||
timeout=self.timeout,
|
timeout=self.timeout,
|
||||||
)
|
)
|
||||||
|
|
||||||
for resp in response:
|
async for resp in response:
|
||||||
if not hasattr(resp, "choices") or not resp.choices:
|
if not hasattr(resp, "choices") or not resp.choices:
|
||||||
continue
|
continue
|
||||||
delta = resp.choices[0].delta
|
delta = resp.choices[0].delta
|
||||||
|
|
@ -2131,14 +1610,14 @@ class LiteLLMBase(ABC):
|
||||||
if not tol:
|
if not tol:
|
||||||
total_tokens += num_tokens_from_string(delta.content)
|
total_tokens += num_tokens_from_string(delta.content)
|
||||||
else:
|
else:
|
||||||
total_tokens += tol
|
total_tokens = tol
|
||||||
yield delta.content
|
yield delta.content
|
||||||
|
|
||||||
yield total_tokens
|
yield total_tokens
|
||||||
return
|
return
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
e = self._exceptions(e, attempt)
|
e = await self._exceptions_async(e, attempt)
|
||||||
if e:
|
if e:
|
||||||
yield e
|
yield e
|
||||||
yield total_tokens
|
yield total_tokens
|
||||||
|
|
@ -2146,53 +1625,71 @@ class LiteLLMBase(ABC):
|
||||||
|
|
||||||
assert False, "Shouldn't be here."
|
assert False, "Shouldn't be here."
|
||||||
|
|
||||||
def chat_streamly(self, system, history, gen_conf: dict = {}, **kwargs):
|
def _construct_completion_args(self, history, stream: bool, tools: bool, **kwargs):
|
||||||
if system and history and history[0].get("role") != "system":
|
completion_args = {
|
||||||
history.insert(0, {"role": "system", "content": system})
|
"model": self.model_name,
|
||||||
gen_conf = self._clean_conf(gen_conf)
|
"messages": history,
|
||||||
ans = ""
|
"api_key": self.api_key,
|
||||||
total_tokens = 0
|
"num_retries": self.max_retries,
|
||||||
try:
|
**kwargs,
|
||||||
for delta_ans, tol in self._chat_streamly(history, gen_conf, **kwargs):
|
}
|
||||||
yield delta_ans
|
if stream:
|
||||||
total_tokens += tol
|
completion_args.update(
|
||||||
except openai.APIError as e:
|
{
|
||||||
yield ans + "\n**ERROR**: " + str(e)
|
"stream": stream,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
if tools and self.tools:
|
||||||
|
completion_args.update(
|
||||||
|
{
|
||||||
|
"tools": self.tools,
|
||||||
|
"tool_choice": "auto",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
if self.provider in FACTORY_DEFAULT_BASE_URL:
|
||||||
|
completion_args.update({"api_base": self.base_url})
|
||||||
|
elif self.provider == SupportedLiteLLMProvider.Bedrock:
|
||||||
|
completion_args.pop("api_key", None)
|
||||||
|
completion_args.pop("api_base", None)
|
||||||
|
completion_args.update(
|
||||||
|
{
|
||||||
|
"aws_access_key_id": self.bedrock_ak,
|
||||||
|
"aws_secret_access_key": self.bedrock_sk,
|
||||||
|
"aws_region_name": self.bedrock_region,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
elif self.provider == SupportedLiteLLMProvider.OpenRouter:
|
||||||
|
if self.provider_order:
|
||||||
|
|
||||||
yield total_tokens
|
def _to_order_list(x):
|
||||||
|
if x is None:
|
||||||
|
return []
|
||||||
|
if isinstance(x, str):
|
||||||
|
return [s.strip() for s in x.split(",") if s.strip()]
|
||||||
|
if isinstance(x, (list, tuple)):
|
||||||
|
return [str(s).strip() for s in x if str(s).strip()]
|
||||||
|
return []
|
||||||
|
|
||||||
def _calculate_dynamic_ctx(self, history):
|
extra_body = {}
|
||||||
"""Calculate dynamic context window size"""
|
provider_cfg = {}
|
||||||
|
provider_order = _to_order_list(self.provider_order)
|
||||||
|
provider_cfg["order"] = provider_order
|
||||||
|
provider_cfg["allow_fallbacks"] = False
|
||||||
|
extra_body["provider"] = provider_cfg
|
||||||
|
completion_args.update({"extra_body": extra_body})
|
||||||
|
elif self.provider == SupportedLiteLLMProvider.GPUStack:
|
||||||
|
completion_args.update(
|
||||||
|
{
|
||||||
|
"api_base": self.base_url,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
def count_tokens(text):
|
# Ollama deployments commonly sit behind a reverse proxy that enforces
|
||||||
"""Calculate token count for text"""
|
# Bearer auth. Ensure the Authorization header is set when an API key
|
||||||
# Simple calculation: 1 token per ASCII character
|
# is provided, while respecting any user-supplied headers. #11350
|
||||||
# 2 tokens for non-ASCII characters (Chinese, Japanese, Korean, etc.)
|
extra_headers = deepcopy(completion_args.get("extra_headers") or {})
|
||||||
total = 0
|
if self.provider == SupportedLiteLLMProvider.Ollama and self.api_key and "Authorization" not in extra_headers:
|
||||||
for char in text:
|
extra_headers["Authorization"] = f"Bearer {self.api_key}"
|
||||||
if ord(char) < 128: # ASCII characters
|
if extra_headers:
|
||||||
total += 1
|
completion_args["extra_headers"] = extra_headers
|
||||||
else: # Non-ASCII characters (Chinese, Japanese, Korean, etc.)
|
return completion_args
|
||||||
total += 2
|
|
||||||
return total
|
|
||||||
|
|
||||||
# Calculate total tokens for all messages
|
|
||||||
total_tokens = 0
|
|
||||||
for message in history:
|
|
||||||
content = message.get("content", "")
|
|
||||||
# Calculate content tokens
|
|
||||||
content_tokens = count_tokens(content)
|
|
||||||
# Add role marker token overhead
|
|
||||||
role_tokens = 4
|
|
||||||
total_tokens += content_tokens + role_tokens
|
|
||||||
|
|
||||||
# Apply 1.2x buffer ratio
|
|
||||||
total_tokens_with_buffer = int(total_tokens * 1.2)
|
|
||||||
|
|
||||||
if total_tokens_with_buffer <= 8192:
|
|
||||||
ctx_size = 8192
|
|
||||||
else:
|
|
||||||
ctx_multiplier = (total_tokens_with_buffer // 8192) + 1
|
|
||||||
ctx_size = ctx_multiplier * 8192
|
|
||||||
|
|
||||||
return ctx_size
|
|
||||||
|
|
|
||||||
|
|
@ -592,6 +592,7 @@ async def run_dataflow(task: dict):
|
||||||
ck["docnm_kwd"] = task["name"]
|
ck["docnm_kwd"] = task["name"]
|
||||||
ck["create_time"] = str(datetime.now()).replace("T", " ")[:19]
|
ck["create_time"] = str(datetime.now()).replace("T", " ")[:19]
|
||||||
ck["create_timestamp_flt"] = datetime.now().timestamp()
|
ck["create_timestamp_flt"] = datetime.now().timestamp()
|
||||||
|
if not ck.get("id"):
|
||||||
ck["id"] = xxhash.xxh64((ck["text"] + str(ck["doc_id"])).encode("utf-8")).hexdigest()
|
ck["id"] = xxhash.xxh64((ck["text"] + str(ck["doc_id"])).encode("utf-8")).hexdigest()
|
||||||
if "questions" in ck:
|
if "questions" in ck:
|
||||||
if "question_tks" not in ck:
|
if "question_tks" not in ck:
|
||||||
|
|
|
||||||
|
|
@ -122,15 +122,15 @@ async def create_container(name: str, language: SupportLanguage) -> bool:
|
||||||
logger.info(f"Sandbox config:\n\t {create_args}")
|
logger.info(f"Sandbox config:\n\t {create_args}")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
returncode, _, stderr = await async_run_command(*create_args, timeout=10)
|
return_code, _, stderr = await async_run_command(*create_args, timeout=10)
|
||||||
if returncode != 0:
|
if return_code != 0:
|
||||||
logger.error(f"❌ Container creation failed {name}: {stderr}")
|
logger.error(f"❌ Container creation failed {name}: {stderr}")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
if language == SupportLanguage.NODEJS:
|
if language == SupportLanguage.NODEJS:
|
||||||
copy_cmd = ["docker", "exec", name, "bash", "-c", "cp -a /app/node_modules /workspace/"]
|
copy_cmd = ["docker", "exec", name, "bash", "-c", "cp -a /app/node_modules /workspace/"]
|
||||||
returncode, _, stderr = await async_run_command(*copy_cmd, timeout=10)
|
return_code, _, stderr = await async_run_command(*copy_cmd, timeout=10)
|
||||||
if returncode != 0:
|
if return_code != 0:
|
||||||
logger.error(f"❌ Failed to prepare dependencies for {name}: {stderr}")
|
logger.error(f"❌ Failed to prepare dependencies for {name}: {stderr}")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
@ -185,7 +185,7 @@ async def allocate_container_blocking(language: SupportLanguage, timeout=10) ->
|
||||||
async def container_is_running(name: str) -> bool:
|
async def container_is_running(name: str) -> bool:
|
||||||
"""Asynchronously check the container status"""
|
"""Asynchronously check the container status"""
|
||||||
try:
|
try:
|
||||||
returncode, stdout, _ = await async_run_command("docker", "inspect", "-f", "{{.State.Running}}", name, timeout=2)
|
return_code, stdout, _ = await async_run_command("docker", "inspect", "-f", "{{.State.Running}}", name, timeout=2)
|
||||||
return returncode == 0 and stdout.strip() == "true"
|
return return_code == 0 and stdout.strip() == "true"
|
||||||
except Exception:
|
except Exception:
|
||||||
return False
|
return False
|
||||||
|
|
|
||||||
|
|
@ -3,10 +3,7 @@
|
||||||
// Inspired by react-hot-toast library
|
// Inspired by react-hot-toast library
|
||||||
import * as React from 'react';
|
import * as React from 'react';
|
||||||
|
|
||||||
import type {
|
import type { ToastActionElement, ToastProps } from '@/components/ui/toast';
|
||||||
ToastActionElement,
|
|
||||||
ToastProps,
|
|
||||||
} from '@/registry/default/ui/toast';
|
|
||||||
|
|
||||||
const TOAST_LIMIT = 1;
|
const TOAST_LIMIT = 1;
|
||||||
const TOAST_REMOVE_DELAY = 1000000;
|
const TOAST_REMOVE_DELAY = 1000000;
|
||||||
|
|
|
||||||
|
|
@ -24,6 +24,15 @@ import { buildLlmUuid } from '@/utils/llm-util';
|
||||||
|
|
||||||
export const enum LLMApiAction {
|
export const enum LLMApiAction {
|
||||||
LlmList = 'llmList',
|
LlmList = 'llmList',
|
||||||
|
MyLlmList = 'myLlmList',
|
||||||
|
MyLlmListDetailed = 'myLlmListDetailed',
|
||||||
|
FactoryList = 'factoryList',
|
||||||
|
SaveApiKey = 'saveApiKey',
|
||||||
|
SaveTenantInfo = 'saveTenantInfo',
|
||||||
|
AddLlm = 'addLlm',
|
||||||
|
DeleteLlm = 'deleteLlm',
|
||||||
|
EnableLlm = 'enableLlm',
|
||||||
|
DeleteFactory = 'deleteFactory',
|
||||||
}
|
}
|
||||||
|
|
||||||
export const useFetchLlmList = (modelType?: LlmModelType) => {
|
export const useFetchLlmList = (modelType?: LlmModelType) => {
|
||||||
|
|
@ -177,7 +186,7 @@ export const useComposeLlmOptionsByModelTypes = (
|
||||||
|
|
||||||
export const useFetchLlmFactoryList = (): ResponseGetType<IFactory[]> => {
|
export const useFetchLlmFactoryList = (): ResponseGetType<IFactory[]> => {
|
||||||
const { data, isFetching: loading } = useQuery({
|
const { data, isFetching: loading } = useQuery({
|
||||||
queryKey: ['factoryList'],
|
queryKey: [LLMApiAction.FactoryList],
|
||||||
initialData: [],
|
initialData: [],
|
||||||
gcTime: 0,
|
gcTime: 0,
|
||||||
queryFn: async () => {
|
queryFn: async () => {
|
||||||
|
|
@ -196,7 +205,7 @@ export const useFetchMyLlmList = (): ResponseGetType<
|
||||||
Record<string, IMyLlmValue>
|
Record<string, IMyLlmValue>
|
||||||
> => {
|
> => {
|
||||||
const { data, isFetching: loading } = useQuery({
|
const { data, isFetching: loading } = useQuery({
|
||||||
queryKey: ['myLlmList'],
|
queryKey: [LLMApiAction.MyLlmList],
|
||||||
initialData: {},
|
initialData: {},
|
||||||
gcTime: 0,
|
gcTime: 0,
|
||||||
queryFn: async () => {
|
queryFn: async () => {
|
||||||
|
|
@ -213,7 +222,7 @@ export const useFetchMyLlmListDetailed = (): ResponseGetType<
|
||||||
Record<string, any>
|
Record<string, any>
|
||||||
> => {
|
> => {
|
||||||
const { data, isFetching: loading } = useQuery({
|
const { data, isFetching: loading } = useQuery({
|
||||||
queryKey: ['myLlmListDetailed'],
|
queryKey: [LLMApiAction.MyLlmListDetailed],
|
||||||
initialData: {},
|
initialData: {},
|
||||||
gcTime: 0,
|
gcTime: 0,
|
||||||
queryFn: async () => {
|
queryFn: async () => {
|
||||||
|
|
@ -271,14 +280,16 @@ export const useSaveApiKey = () => {
|
||||||
isPending: loading,
|
isPending: loading,
|
||||||
mutateAsync,
|
mutateAsync,
|
||||||
} = useMutation({
|
} = useMutation({
|
||||||
mutationKey: ['saveApiKey'],
|
mutationKey: [LLMApiAction.SaveApiKey],
|
||||||
mutationFn: async (params: IApiKeySavingParams) => {
|
mutationFn: async (params: IApiKeySavingParams) => {
|
||||||
const { data } = await userService.set_api_key(params);
|
const { data } = await userService.set_api_key(params);
|
||||||
if (data.code === 0) {
|
if (data.code === 0) {
|
||||||
message.success(t('message.modified'));
|
message.success(t('message.modified'));
|
||||||
queryClient.invalidateQueries({ queryKey: ['myLlmList'] });
|
queryClient.invalidateQueries({ queryKey: [LLMApiAction.MyLlmList] });
|
||||||
queryClient.invalidateQueries({ queryKey: ['myLlmListDetailed'] });
|
queryClient.invalidateQueries({
|
||||||
queryClient.invalidateQueries({ queryKey: ['factoryList'] });
|
queryKey: [LLMApiAction.MyLlmListDetailed],
|
||||||
|
});
|
||||||
|
queryClient.invalidateQueries({ queryKey: [LLMApiAction.FactoryList] });
|
||||||
}
|
}
|
||||||
return data.code;
|
return data.code;
|
||||||
},
|
},
|
||||||
|
|
@ -303,7 +314,7 @@ export const useSaveTenantInfo = () => {
|
||||||
isPending: loading,
|
isPending: loading,
|
||||||
mutateAsync,
|
mutateAsync,
|
||||||
} = useMutation({
|
} = useMutation({
|
||||||
mutationKey: ['saveTenantInfo'],
|
mutationKey: [LLMApiAction.SaveTenantInfo],
|
||||||
mutationFn: async (params: ISystemModelSettingSavingParams) => {
|
mutationFn: async (params: ISystemModelSettingSavingParams) => {
|
||||||
const { data } = await userService.set_tenant_info(params);
|
const { data } = await userService.set_tenant_info(params);
|
||||||
if (data.code === 0) {
|
if (data.code === 0) {
|
||||||
|
|
@ -324,13 +335,16 @@ export const useAddLlm = () => {
|
||||||
isPending: loading,
|
isPending: loading,
|
||||||
mutateAsync,
|
mutateAsync,
|
||||||
} = useMutation({
|
} = useMutation({
|
||||||
mutationKey: ['addLlm'],
|
mutationKey: [LLMApiAction.AddLlm],
|
||||||
mutationFn: async (params: IAddLlmRequestBody) => {
|
mutationFn: async (params: IAddLlmRequestBody) => {
|
||||||
const { data } = await userService.add_llm(params);
|
const { data } = await userService.add_llm(params);
|
||||||
if (data.code === 0) {
|
if (data.code === 0) {
|
||||||
queryClient.invalidateQueries({ queryKey: ['myLlmList'] });
|
queryClient.invalidateQueries({ queryKey: [LLMApiAction.MyLlmList] });
|
||||||
queryClient.invalidateQueries({ queryKey: ['myLlmListDetailed'] });
|
queryClient.invalidateQueries({
|
||||||
queryClient.invalidateQueries({ queryKey: ['factoryList'] });
|
queryKey: [LLMApiAction.MyLlmListDetailed],
|
||||||
|
});
|
||||||
|
queryClient.invalidateQueries({ queryKey: [LLMApiAction.FactoryList] });
|
||||||
|
queryClient.invalidateQueries({ queryKey: [LLMApiAction.LlmList] });
|
||||||
message.success(t('message.modified'));
|
message.success(t('message.modified'));
|
||||||
}
|
}
|
||||||
return data.code;
|
return data.code;
|
||||||
|
|
@ -348,13 +362,15 @@ export const useDeleteLlm = () => {
|
||||||
isPending: loading,
|
isPending: loading,
|
||||||
mutateAsync,
|
mutateAsync,
|
||||||
} = useMutation({
|
} = useMutation({
|
||||||
mutationKey: ['deleteLlm'],
|
mutationKey: [LLMApiAction.DeleteLlm],
|
||||||
mutationFn: async (params: IDeleteLlmRequestBody) => {
|
mutationFn: async (params: IDeleteLlmRequestBody) => {
|
||||||
const { data } = await userService.delete_llm(params);
|
const { data } = await userService.delete_llm(params);
|
||||||
if (data.code === 0) {
|
if (data.code === 0) {
|
||||||
queryClient.invalidateQueries({ queryKey: ['myLlmList'] });
|
queryClient.invalidateQueries({ queryKey: [LLMApiAction.MyLlmList] });
|
||||||
queryClient.invalidateQueries({ queryKey: ['myLlmListDetailed'] });
|
queryClient.invalidateQueries({
|
||||||
queryClient.invalidateQueries({ queryKey: ['factoryList'] });
|
queryKey: [LLMApiAction.MyLlmListDetailed],
|
||||||
|
});
|
||||||
|
queryClient.invalidateQueries({ queryKey: [LLMApiAction.FactoryList] });
|
||||||
message.success(t('message.deleted'));
|
message.success(t('message.deleted'));
|
||||||
}
|
}
|
||||||
return data.code;
|
return data.code;
|
||||||
|
|
@ -372,7 +388,7 @@ export const useEnableLlm = () => {
|
||||||
isPending: loading,
|
isPending: loading,
|
||||||
mutateAsync,
|
mutateAsync,
|
||||||
} = useMutation({
|
} = useMutation({
|
||||||
mutationKey: ['enableLlm'],
|
mutationKey: [LLMApiAction.EnableLlm],
|
||||||
mutationFn: async (params: IDeleteLlmRequestBody & { enable: boolean }) => {
|
mutationFn: async (params: IDeleteLlmRequestBody & { enable: boolean }) => {
|
||||||
const reqParam: IDeleteLlmRequestBody & {
|
const reqParam: IDeleteLlmRequestBody & {
|
||||||
enable?: boolean;
|
enable?: boolean;
|
||||||
|
|
@ -381,9 +397,11 @@ export const useEnableLlm = () => {
|
||||||
delete reqParam.enable;
|
delete reqParam.enable;
|
||||||
const { data } = await userService.enable_llm(reqParam);
|
const { data } = await userService.enable_llm(reqParam);
|
||||||
if (data.code === 0) {
|
if (data.code === 0) {
|
||||||
queryClient.invalidateQueries({ queryKey: ['myLlmList'] });
|
queryClient.invalidateQueries({ queryKey: [LLMApiAction.MyLlmList] });
|
||||||
queryClient.invalidateQueries({ queryKey: ['myLlmListDetailed'] });
|
queryClient.invalidateQueries({
|
||||||
queryClient.invalidateQueries({ queryKey: ['factoryList'] });
|
queryKey: [LLMApiAction.MyLlmListDetailed],
|
||||||
|
});
|
||||||
|
queryClient.invalidateQueries({ queryKey: [LLMApiAction.FactoryList] });
|
||||||
message.success(t('message.modified'));
|
message.success(t('message.modified'));
|
||||||
}
|
}
|
||||||
return data.code;
|
return data.code;
|
||||||
|
|
@ -401,14 +419,16 @@ export const useDeleteFactory = () => {
|
||||||
isPending: loading,
|
isPending: loading,
|
||||||
mutateAsync,
|
mutateAsync,
|
||||||
} = useMutation({
|
} = useMutation({
|
||||||
mutationKey: ['deleteFactory'],
|
mutationKey: [LLMApiAction.DeleteFactory],
|
||||||
mutationFn: async (params: IDeleteLlmRequestBody) => {
|
mutationFn: async (params: IDeleteLlmRequestBody) => {
|
||||||
const { data } = await userService.deleteFactory(params);
|
const { data } = await userService.deleteFactory(params);
|
||||||
if (data.code === 0) {
|
if (data.code === 0) {
|
||||||
queryClient.invalidateQueries({ queryKey: ['myLlmList'] });
|
queryClient.invalidateQueries({ queryKey: [LLMApiAction.MyLlmList] });
|
||||||
queryClient.invalidateQueries({ queryKey: ['myLlmListDetailed'] });
|
queryClient.invalidateQueries({
|
||||||
queryClient.invalidateQueries({ queryKey: ['factoryList'] });
|
queryKey: [LLMApiAction.MyLlmListDetailed],
|
||||||
queryClient.invalidateQueries({ queryKey: ['llmList'] });
|
});
|
||||||
|
queryClient.invalidateQueries({ queryKey: [LLMApiAction.FactoryList] });
|
||||||
|
queryClient.invalidateQueries({ queryKey: [LLMApiAction.LlmList] });
|
||||||
message.success(t('message.deleted'));
|
message.success(t('message.deleted'));
|
||||||
}
|
}
|
||||||
return data.code;
|
return data.code;
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,5 @@
|
||||||
import api from '@/utils/api';
|
import api from '@/utils/api';
|
||||||
import registerServer from '@/utils/register-server';
|
import { registerNextServer } from '@/utils/register-server';
|
||||||
import request from '@/utils/request';
|
|
||||||
|
|
||||||
const {
|
const {
|
||||||
createSearch,
|
createSearch,
|
||||||
|
|
@ -49,6 +48,6 @@ const methods = {
|
||||||
method: 'get',
|
method: 'get',
|
||||||
},
|
},
|
||||||
} as const;
|
} as const;
|
||||||
const searchService = registerServer<keyof typeof methods>(methods, request);
|
const searchService = registerNextServer<keyof typeof methods>(methods);
|
||||||
|
|
||||||
export default searchService;
|
export default searchService;
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue