#
#  Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.
#

from typing import Any, NotRequired, Optional, TYPE_CHECKING, TypedDict

from .base import Base
from .session import Session

if TYPE_CHECKING:
    from ..ragflow import RAGFlow

__all__ = ("Chat",)


class Variable(TypedDict):
    """A template variable used by a chat prompt.

    ``key`` is the placeholder name; the default prompt below, for example,
    contains ``{knowledge}`` and declares ``{"key": "knowledge"}``.
    ``optional`` is an optional boolean flag whose exact semantics are
    defined server-side (presumably "may be left unfilled" -- confirm).
    """

    key: str
    optional: NotRequired[bool]


class LLMUpdateMessage(TypedDict):
    """LLM settings that may be sent under the ``llm`` key of an update."""

    model_name: NotRequired[str]
    temperature: NotRequired[float]
    top_p: NotRequired[float]
    presence_penalty: NotRequired[float]
    # Spelled with an underscore so it matches ``Chat.LLM.frequency_penalty``
    # (it was previously declared as "frequency penalty", with a space).
    frequency_penalty: NotRequired[float]


class PromptUpdateMessage(TypedDict):
    """Prompt/retrieval settings that may be sent under the ``prompt`` key."""

    similarity_threshold: NotRequired[float]
    keywords_similarity_weight: NotRequired[float]
    top_n: NotRequired[int]
    variables: NotRequired[list[Variable]]
    rerank_model: NotRequired[str]
    empty_response: NotRequired[str]
    opener: NotRequired[str]
    show_quote: NotRequired[bool]
    prompt: NotRequired[str]


class UpdateMessage(TypedDict):
    """Request body accepted by :meth:`Chat.update`; every field is optional."""

    name: NotRequired[str]
    avatar: NotRequired[str]
    dataset_ids: NotRequired[list[str]]
    llm: NotRequired[LLMUpdateMessage]
    prompt: NotRequired[PromptUpdateMessage]


class Chat(Base):
    """Client-side handle for a chat assistant and its sessions."""

    __slots__ = (
        "id",
        "name",
        "avatar",
        "llm",
        "prompt",
    )

    id: str
    name: str
    avatar: str
    llm: "Chat.LLM"
    prompt: "Chat.Prompt"

    def __init__(self, rag: "RAGFlow", res_dict: dict[str, Any]) -> None:
        """Create a chat handle from a server response dictionary.

        Args:
            rag: The owning client, passed through to ``Base``.
            res_dict: Attribute values returned by the server.
        """
        # Defaults are assigned first; presumably ``Base.__init__`` then
        # overwrites them with whatever ``res_dict`` contains -- confirm in
        # ``.base``.
        self.id = ""
        self.name = "assistant"
        self.avatar = "path/to/avatar"
        self.llm = Chat.LLM(rag, {})
        self.prompt = Chat.Prompt(rag, {})
        super().__init__(rag, res_dict)

    class LLM(Base):
        """LLM generation settings attached to a chat assistant."""

        __slots__ = (
            "model_name",
            "temperature",
            "top_p",
            "presence_penalty",
            "frequency_penalty",
            "max_tokens",
        )

        model_name: Optional[str]
        temperature: float
        top_p: float
        presence_penalty: float
        frequency_penalty: float
        max_tokens: int

        def __init__(self, rag: "RAGFlow", res_dict: dict[str, Any]) -> None:
            """Set client-side defaults, then apply ``res_dict`` via ``Base``."""
            self.model_name = None
            self.temperature = 0.1
            self.top_p = 0.3
            self.presence_penalty = 0.4
            self.frequency_penalty = 0.7
            self.max_tokens = 512
            super().__init__(rag, res_dict)

    class Prompt(Base):
        """Prompt template and retrieval settings attached to a chat assistant."""

        __slots__ = (
            "similarity_threshold",
            "keywords_similarity_weight",
            "top_n",
            "top_k",
            "variables",
            "rerank_model",
            "empty_response",
            "opener",
            "show_quote",
            "prompt",
        )

        similarity_threshold: float
        keywords_similarity_weight: float
        top_n: int
        top_k: int
        variables: list[Variable]
        rerank_model: str
        empty_response: Optional[str]
        opener: str
        show_quote: bool
        prompt: str

        def __init__(self, rag: "RAGFlow", res_dict: dict[str, Any]) -> None:
            """Set client-side defaults, then apply ``res_dict`` via ``Base``."""
            self.similarity_threshold = 0.2
            self.keywords_similarity_weight = 0.7
            self.top_n = 8
            self.top_k = 1024
            # Declares the ``{knowledge}`` placeholder used in the default prompt.
            self.variables = [{"key": "knowledge", "optional": True}]
            self.rerank_model = ""
            self.empty_response = None
            self.opener = "Hi! I'm your assistant. What can I do for you?"
            self.show_quote = True
            self.prompt = (
                "You are an intelligent assistant. Please summarize the content of the knowledge base to answer the question. "
                "Please list the data in the knowledge base and answer in detail. When all knowledge base content is irrelevant to the question, "
                "your answer must include the sentence 'The answer you are looking for is not found in the knowledge base!' "
                "Answers need to consider chat history.\nHere is the knowledge base:\n{knowledge}\nThe above is the knowledge base."
            )
            super().__init__(rag, res_dict)

    def update(self, update_message: UpdateMessage) -> None:
        """Send ``update_message`` to ``/chats/{id}`` via ``self.put``.

        Local attributes are not modified by this method.

        Raises:
            Exception: If the response's ``code`` is not ``0``; the message is
                the response's ``message`` field (``None`` if absent).
        """
        res = self.put(f"/chats/{self.id}", update_message).json()
        if res.get("code") != 0:
            raise Exception(res.get("message"))

    def create_session(self, name: str = "New session") -> Session:
        """Create a session named ``name`` under this chat.

        Returns:
            A ``Session`` built from the response's ``data`` field.

        Raises:
            Exception: If the response's ``code`` is not ``0``.
        """
        res = self.post(f"/chats/{self.id}/sessions", {"name": name}).json()
        if res.get("code") == 0:
            return Session(self.rag, res["data"])
        raise Exception(res.get("message"))

    def list_sessions(
        self,
        page: int = 1,
        page_size: int = 30,
        orderby: str = "create_time",
        desc: bool = True,
        id: Optional[str] = None,
        name: Optional[str] = None,
    ) -> list[Session]:
        """List sessions of this chat, one page at a time.

        Args:
            page: Page number to fetch.
            page_size: Number of sessions per page.
            orderby: Field to sort by.
            desc: Sort in descending order when ``True``.
            id: Optional session-ID filter (``None`` is passed through
                unchanged; how ``self.get`` treats it is defined in ``Base``).
            name: Optional session-name filter (same ``None`` handling).

        Returns:
            One ``Session`` per entry in the response's ``data`` list.

        Raises:
            Exception: If the response's ``code`` is not ``0``.
        """
        params = {
            "page": page,
            "page_size": page_size,
            "orderby": orderby,
            "desc": desc,
            "id": id,
            "name": name,
        }
        res = self.get(f"/chats/{self.id}/sessions", params).json()
        if res.get("code") == 0:
            return [Session(self.rag, data) for data in res["data"]]
        raise Exception(res.get("message"))

    def delete_sessions(self, ids: list[str] | None = None) -> None:
        """Delete sessions of this chat by ID.

        Args:
            ids: Session IDs to delete. ``None`` is sent as-is; how the server
                interprets it is not visible here (presumably "all sessions"
                -- confirm against the API).

        Raises:
            Exception: If the response's ``code`` is not ``0``.
        """
        res = self.rm(f"/chats/{self.id}/sessions", {"ids": ids}).json()
        if res.get("code") != 0:
            raise Exception(res.get("message"))