diff --git a/sdk/python/ragflow_sdk/modules/agent.py b/sdk/python/ragflow_sdk/modules/agent.py index 42b97a88e..c6c693177 100644 --- a/sdk/python/ragflow_sdk/modules/agent.py +++ b/sdk/python/ragflow_sdk/modules/agent.py @@ -14,12 +14,31 @@ # limitations under the License. # +from typing import Any, Optional, TYPE_CHECKING from .base import Base from .session import Session +if TYPE_CHECKING: + from ..ragflow import RAGFlow + +__all__ = 'Agent', class Agent(Base): - def __init__(self, rag, res_dict): + __slots__ = ( + 'id', + 'avatar', + 'canvas_type', + 'description', + 'dsl', + ) + + id: Optional[str] + avatar: Optional[str] + canvas_type: Optional[str] + description: Optional[str] + dsl: Optional["Agent.Dsl"] + + def __init__(self, rag: "RAGFlow", res_dict: dict[str, Any]) -> None: self.id = None self.avatar = None self.canvas_type = None @@ -28,7 +47,25 @@ class Agent(Base): super().__init__(rag, res_dict) class Dsl(Base): - def __init__(self, rag, res_dict): + __slots__ = ( + 'answer', + 'components', + 'graph', + 'history', + 'messages', + 'path', + 'reference', + ) + # TODO: Proper typing including TypedDict for the dicts. Where is the specification of the DSL? + answer: list[Any] + components: dict[str, Any] + graph: dict[str, Any] + history: list[Any] + messages: list[Any] + path: list[Any] + reference: list[Any] + + def __init__(self, rag: "RAGFlow", res_dict: dict[str, Any]) -> None: self.answer = [] self.components = { "begin": { @@ -65,8 +102,8 @@ class Agent(Base): self.reference = [] super().__init__(rag, res_dict) - - def create_session(self, **kwargs) -> Session: + # TODO: Proper typing of kwargs. Where are these parameters defined? 
+ def create_session(self, **kwargs: Any) -> Session: res = self.post(f"/agents/{self.id}/sessions", json=kwargs) res = res.json() if res.get("code") == 0: @@ -75,7 +112,7 @@ class Agent(Base): def list_sessions(self, page: int = 1, page_size: int = 30, orderby: str = "create_time", desc: bool = True, - id: str = None) -> list[Session]: + id: Optional[str] = None) -> list[Session]: res = self.get(f"/agents/{self.id}/sessions", {"page": page, "page_size": page_size, "orderby": orderby, "desc": desc, "id": id}) res = res.json() @@ -87,8 +124,8 @@ class Agent(Base): return result_list raise Exception(res.get("message")) - def delete_sessions(self, ids: list[str] | None = None): + def delete_sessions(self, ids: list[str] | None = None) -> None: res = self.rm(f"/agents/{self.id}/sessions", {"ids": ids}) res = res.json() if res.get("code") != 0: - raise Exception(res.get("message")) \ No newline at end of file + raise Exception(res.get("message")) diff --git a/sdk/python/ragflow_sdk/modules/base.py b/sdk/python/ragflow_sdk/modules/base.py index 6b958fb8d..875fb4934 100644 --- a/sdk/python/ragflow_sdk/modules/base.py +++ b/sdk/python/ragflow_sdk/modules/base.py @@ -14,21 +14,31 @@ # limitations under the License.
# +from typing import Any, Optional, TYPE_CHECKING + +if TYPE_CHECKING: + from requests import Response + from requests.sessions import _Files, _Params + from ..ragflow import RAGFlow class Base: - def __init__(self, rag, res_dict): + __slots__ = 'rag', + + rag: "RAGFlow" + + def __init__(self, rag: "RAGFlow", res_dict: dict[str, Any]) -> None: self.rag = rag self._update_from_dict(rag, res_dict) - def _update_from_dict(self, rag, res_dict): + def _update_from_dict(self, rag: "RAGFlow", res_dict: dict[str, Any]) -> None: for k, v in res_dict.items(): if isinstance(v, dict): - self.__dict__[k] = Base(rag, v) + setattr(self, k, Base(rag, v)) else: - self.__dict__[k] = v + setattr(self, k, v) - def to_json(self): - pr = {} + def to_json(self) -> dict[str, Any]: + pr: dict[str, Any] = {} for name in dir(self): value = getattr(self, name) if not name.startswith("__") and not callable(value) and name != "rag": @@ -38,21 +48,21 @@ class Base: pr[name] = value return pr - def post(self, path, json=None, stream=False, files=None): + def post(self, path: str, json: Any=None, stream: bool=False, files: Optional["_Files"]=None) -> "Response": res = self.rag.post(path, json, stream=stream, files=files) return res - def get(self, path, params=None): + def get(self, path: str, params: Optional["_Params"]=None) -> "Response": res = self.rag.get(path, params) return res - def rm(self, path, json): + def rm(self, path: str, json: Any) -> "Response": res = self.rag.delete(path, json) return res - def put(self, path, json): + def put(self, path: str, json: Any) -> "Response": res = self.rag.put(path, json) return res - def __str__(self): + def __str__(self) -> str: return str(self.to_json()) diff --git a/sdk/python/ragflow_sdk/modules/chat.py b/sdk/python/ragflow_sdk/modules/chat.py index 5935b5b70..4a32a1782 100644 --- a/sdk/python/ragflow_sdk/modules/chat.py +++ b/sdk/python/ragflow_sdk/modules/chat.py @@ -15,12 +15,62 @@ # +from typing import Any, NotRequired, Optional, 
TYPE_CHECKING, TypedDict from .base import Base from .session import Session +if TYPE_CHECKING: + from ..ragflow import RAGFlow + +__all__ = 'Chat', + +class Variable(TypedDict): + key: str + optional: NotRequired[bool] + +LLMUpdateMessage = TypedDict('LLMUpdateMessage', { + "model_name": NotRequired[str], + "temperature": NotRequired[float], + "top_p": NotRequired[float], + "presence_penalty": NotRequired[float], + "frequency_penalty": NotRequired[float], +}) + +class PromptUpdateMessage(TypedDict): + similarity_threshold: NotRequired[float] + keywords_similarity_weight: NotRequired[float] + top_n: NotRequired[int] + variables: NotRequired[list[Variable]] + rerank_model: NotRequired[str] + empty_response: NotRequired[str] + opener: NotRequired[str] + show_quote: NotRequired[bool] + prompt: NotRequired[str] + +class UpdateMessage(TypedDict): + name: NotRequired[str] + avatar: NotRequired[str] + dataset_ids: NotRequired[list[str]] + llm: NotRequired[LLMUpdateMessage] + prompt: NotRequired[PromptUpdateMessage] + class Chat(Base): - def __init__(self, rag, res_dict): + __slots__ = ( + 'id', + 'name', + 'avatar', + 'llm', + 'prompt', + ) + + id: str + name: str + avatar: str + llm: "Chat.LLM" + prompt: "Chat.Prompt" + + def __init__(self, rag: "RAGFlow", res_dict: dict[str, Any]) -> None: self.id = "" self.name = "assistant" self.avatar = "path/to/avatar" @@ -29,7 +79,23 @@ class Chat(Base): super().__init__(rag, res_dict) class LLM(Base): - def __init__(self, rag, res_dict): + __slots__ = ( + 'model_name', + 'temperature', + 'top_p', + 'presence_penalty', + 'frequency_penalty', + 'max_tokens', + ) + + model_name: Optional[str] + temperature: float + top_p: float + presence_penalty: float + frequency_penalty: float + max_tokens: int + + def __init__(self, rag: "RAGFlow", res_dict: dict[str, Any]) -> None: self.model_name = None self.temperature = 0.1 self.top_p = 0.3 @@ -39,7 +105,31 @@ class Chat(Base): super().__init__(rag, res_dict) class Prompt(Base): - def
__init__(self, rag, res_dict): + __slots__ = ( + 'similarity_threshold', + 'keywords_similarity_weight', + 'top_n', + 'top_k', + 'variables', + 'rerank_model', + 'empty_response', + 'opener', + 'show_quote', + 'prompt', + ) + + similarity_threshold: float + keywords_similarity_weight: float + top_n: int + top_k: int + variables: list[Variable] + rerank_model: str + empty_response: Optional[str] + opener: str + show_quote: bool + prompt: str + + def __init__(self, rag: "RAGFlow", res_dict: dict[str, Any]) -> None: self.similarity_threshold = 0.2 self.keywords_similarity_weight = 0.7 self.top_n = 8 @@ -57,7 +147,7 @@ class Chat(Base): ) super().__init__(rag, res_dict) - def update(self, update_message: dict): + def update(self, update_message: UpdateMessage) -> None: res = self.put(f"/chats/{self.id}", update_message) res = res.json() if res.get("code") != 0: @@ -70,7 +160,7 @@ class Chat(Base): return Session(self.rag, res["data"]) raise Exception(res["message"]) - def list_sessions(self, page: int = 1, page_size: int = 30, orderby: str = "create_time", desc: bool = True, id: str = None, name: str = None) -> list[Session]: + def list_sessions(self, page: int = 1, page_size: int = 30, orderby: str = "create_time", desc: bool = True, id: Optional[str] = None, name: Optional[str] = None) -> list[Session]: res = self.get(f"/chats/{self.id}/sessions", {"page": page, "page_size": page_size, "orderby": orderby, "desc": desc, "id": id, "name": name}) res = res.json() if res.get("code") == 0: @@ -80,7 +170,7 @@ class Chat(Base): return result_list raise Exception(res["message"]) - def delete_sessions(self, ids: list[str] | None = None): + def delete_sessions(self, ids: list[str] | None = None) -> None: res = self.rm(f"/chats/{self.id}/sessions", {"ids": ids}) res = res.json() if res.get("code") != 0: diff --git a/sdk/python/ragflow_sdk/modules/chunk.py b/sdk/python/ragflow_sdk/modules/chunk.py index b71314d8a..33e25b709 100644 --- a/sdk/python/ragflow_sdk/modules/chunk.py 
+++ b/sdk/python/ragflow_sdk/modules/chunk.py @@ -14,17 +14,72 @@ # limitations under the License. # +from typing import Any, NotRequired, Optional, TYPE_CHECKING, TypedDict from .base import Base +if TYPE_CHECKING: + from ..ragflow import RAGFlow + +__all__ = 'Chunk', + +class UpdateMessage(TypedDict): + content: NotRequired[str] + important_keywords: NotRequired[list[str]] + available: NotRequired[bool] + class ChunkUpdateError(Exception): - def __init__(self, code=None, message=None, details=None): + __slots__ = ( + 'code', + 'message', + 'details', + ) + + code: Optional[int] + message: Optional[str] + details: Optional[str] + + def __init__(self, code: Optional[int]=None, message: Optional[str]=None, details: Optional[str]=None): self.code = code self.message = message self.details = details super().__init__(message) class Chunk(Base): - def __init__(self, rag, res_dict): + __slots__ = ( + 'id', + 'content', + 'important_keywords', + 'questions', + 'create_time', + 'create_timestamp', + 'dataset_id', + 'document_name', + 'document_id', + 'available', + 'similarity', + 'vector_similarity', + 'term_similarity', + 'positions', + 'doc_type', + ) + + id: str + content: str + important_keywords: list[str] + questions: list[str] + create_time: str + create_timestamp: float + dataset_id: Optional[str] + document_name: str + document_id: str + available: bool + similarity: float + vector_similarity: float + term_similarity: float + positions: list[str] + doc_type: str + + def __init__(self, rag: "RAGFlow", res_dict: dict[str, Any]) -> None: self.id = "" self.content = "" self.important_keywords = [] @@ -42,11 +97,11 @@ class Chunk(Base): self.positions = [] self.doc_type = "" for k in list(res_dict.keys()): - if k not in self.__dict__: + if not hasattr(self, k): res_dict.pop(k) super().__init__(rag, res_dict) - def update(self, update_message: dict): + def update(self, update_message: UpdateMessage) -> None: res = 
self.put(f"/datasets/{self.dataset_id}/documents/{self.document_id}/chunks/{self.id}", update_message) res = res.json() if res.get("code") != 0: @@ -54,4 +109,4 @@ class Chunk(Base): code=res.get("code"), message=res.get("message"), details=res.get("details") - ) \ No newline at end of file + ) diff --git a/sdk/python/ragflow_sdk/modules/dataset.py b/sdk/python/ragflow_sdk/modules/dataset.py index d2d689da3..e2eddbba3 100644 --- a/sdk/python/ragflow_sdk/modules/dataset.py +++ b/sdk/python/ragflow_sdk/modules/dataset.py @@ -14,16 +14,68 @@ # limitations under the License. # +from typing import Any, Literal, NamedTuple, NotRequired, Optional, TYPE_CHECKING, TypedDict from .base import Base -from .document import Document +from .document import Document, ChunkMethod +if TYPE_CHECKING: + from ..ragflow import RAGFlow + +__all__ = 'DataSet', + +class DocumentStatus(NamedTuple): + document_id: str + run: str + chunk_count: int + token_count: int + + +class DocumentParams(TypedDict): + display_name: str + blob: str|bytes + +Permission = Literal["me", "team"] + +class UpdateMessage(TypedDict): + name: NotRequired[str] + avatar: NotRequired[str] + embedding_model: NotRequired[str] + permission: NotRequired[Permission] + pagerank: NotRequired[int] + chunk_method: NotRequired[ChunkMethod] class DataSet(Base): + __slots__ = ( + 'id', + 'name', + 'avatar', + 'tenant_id', + 'description', + 'embedding_model', + 'permission', + 'chunk_method', + 'parser_config', + 'pagerank', + ) + class ParserConfig(Base): - def __init__(self, rag, res_dict): + # TODO: Proper typing of parser config. 
+ + def __init__(self, rag: "RAGFlow", res_dict: dict[str, Any]) -> None: super().__init__(rag, res_dict) - def __init__(self, rag, res_dict): + id: str + name: str + avatar: str + tenant_id: Optional[str] + description: str + embedding_model: str + permission: str + chunk_method: str + parser_config: Optional[ParserConfig] + pagerank: int + + def __init__(self, rag: "RAGFlow", res_dict: dict[str, Any]) -> None: self.id = "" self.name = "" self.avatar = "" @@ -37,11 +89,11 @@ class DataSet(Base): self.parser_config = None self.pagerank = 0 for k in list(res_dict.keys()): - if k not in self.__dict__: + if not hasattr(self, k): res_dict.pop(k) super().__init__(rag, res_dict) - def update(self, update_message: dict): + def update(self, update_message: UpdateMessage) -> "DataSet": res = self.put(f"/datasets/{self.id}", update_message) res = res.json() if res.get("code") != 0: @@ -50,13 +102,13 @@ class DataSet(Base): self._update_from_dict(self.rag, res.get("data", {})) return self - def upload_documents(self, document_list: list[dict]): + def upload_documents(self, document_list: list[DocumentParams]) -> list[Document]: url = f"/datasets/{self.id}/documents" files = [("file", (ele["display_name"], ele["blob"])) for ele in document_list] res = self.post(path=url, json=None, files=files) res = res.json() if res.get("code") == 0: - doc_list = [] + doc_list: list[Document] = [] for doc in res["data"]: document = Document(self.rag, doc) doc_list.append(document) @@ -74,7 +126,7 @@ class DataSet(Base): desc: bool = True, create_time_from: int = 0, create_time_to: int = 0, - ): + ) -> list[Document]: params = { "id": id, "name": name, @@ -88,26 +140,26 @@ class DataSet(Base): } res = self.get(f"/datasets/{self.id}/documents", params=params) res = res.json() - documents = [] + documents: list[Document] = [] if res.get("code") == 0: for document in res["data"].get("docs"): documents.append(Document(self.rag, document)) return documents raise Exception(res["message"]) - def 
delete_documents(self, ids: list[str] | None = None): + def delete_documents(self, ids: list[str] | None = None) -> None: res = self.rm(f"/datasets/{self.id}/documents", {"ids": ids}) res = res.json() if res.get("code") != 0: raise Exception(res["message"]) - def _get_documents_status(self, document_ids): + def _get_documents_status(self, document_ids: list[str]) -> list[DocumentStatus]: import time terminal_states = {"DONE", "FAIL", "CANCEL"} interval_sec = 1 pending = set(document_ids) - finished = [] + finished: list[DocumentStatus] = [] while pending: for doc_id in list(pending): def fetch_doc(doc_id: str) -> Document | None: @@ -120,23 +172,23 @@ class DataSet(Base): if doc is None: continue if isinstance(doc.run, str) and doc.run.upper() in terminal_states: - finished.append((doc_id, doc.run, doc.chunk_count, doc.token_count)) + finished.append(DocumentStatus(doc_id, doc.run, doc.chunk_count, doc.token_count)) pending.discard(doc_id) elif float(doc.progress or 0.0) >= 1.0: - finished.append((doc_id, "DONE", doc.chunk_count, doc.token_count)) + finished.append(DocumentStatus(doc_id, "DONE", doc.chunk_count, doc.token_count)) pending.discard(doc_id) if pending: time.sleep(interval_sec) return finished - def async_parse_documents(self, document_ids): + def async_parse_documents(self, document_ids: list[str]) -> None: res = self.post(f"/datasets/{self.id}/chunks", {"document_ids": document_ids}) res = res.json() if res.get("code") != 0: raise Exception(res.get("message")) - def parse_documents(self, document_ids): + def parse_documents(self, document_ids: list[str]) -> list[DocumentStatus]: try: self.async_parse_documents(document_ids) self._get_documents_status(document_ids) @@ -146,7 +198,7 @@ class DataSet(Base): return self._get_documents_status(document_ids) - def async_cancel_parse_documents(self, document_ids): + def async_cancel_parse_documents(self, document_ids: list[str]) -> None: res = self.rm(f"/datasets/{self.id}/chunks", {"document_ids": 
document_ids}) res = res.json() if res.get("code") != 0: diff --git a/sdk/python/ragflow_sdk/modules/document.py b/sdk/python/ragflow_sdk/modules/document.py index 70c1ac842..fbbfe66de 100644 --- a/sdk/python/ragflow_sdk/modules/document.py +++ b/sdk/python/ragflow_sdk/modules/document.py @@ -14,18 +14,97 @@ # limitations under the License. # +from typing import Any, Literal, NotRequired, Optional, TYPE_CHECKING, TypedDict + import json from .base import Base from .chunk import Chunk +if TYPE_CHECKING: + from ..ragflow import RAGFlow + +__all__ = 'Document', + +ChunkMethod = Literal["naive", "manual", "qa", "table", "paper", "book", "laws", "presentation", "picture", "one", "email"] + +LayoutRecognize = Literal["DeepDOC", "Plain Text", "Naive"] + +class RaptorParams(TypedDict): + use_raptor: NotRequired[bool] + +class GraphragParams(TypedDict): + use_graphrag: NotRequired[bool] + +class ParserConfigParams(TypedDict): + filename_embd_weight: NotRequired[int|float] + + # chunk_method=naive + chunk_token_num: NotRequired[int] + delimiter: NotRequired[str] + html4excel: NotRequired[bool] + layout_recognize: NotRequired[LayoutRecognize|bool] + + # chunk_method=raptor + raptor: NotRequired[RaptorParams] + + # chunk_method=knowledge-graph + entity_types: NotRequired[list[str]] + + graphrag: NotRequired[GraphragParams] + +class UpdateMessage(TypedDict): + display_name: NotRequired[str] + meta_fields: NotRequired[dict[str, Any]] + chunk_method: NotRequired[ChunkMethod] + parser_config: NotRequired[ParserConfigParams] class Document(Base): + __slots__ = ( + 'id', + 'name', + 'thumbnail', + 'dataset_id', + 'chunk_method', + 'parser_config', + 'source_type', + 'type', + 'created_by', + 'progress', + 'progress_msg', + 'process_begin_at', + 'process_duration', + 'run', + 'status', + 'meta_fields', + 'blob', + 'keywords', + ) + class ParserConfig(Base): - def __init__(self, rag, res_dict): + def __init__(self, rag: "RAGFlow", res_dict: dict[str, Any]) -> None: 
super().__init__(rag, res_dict) - def __init__(self, rag, res_dict): + id: str + name: str + thumbnail: Optional[str] + dataset_id: Optional[str] + chunk_method: Optional[str] + parser_config: dict[str, Any] + source_type: str + type: str + created_by: str + progress: float + progress_msg: str + process_begin_at: Optional[str] + process_duration: float + run: str + status: str + meta_fields: dict[str, Any] + blob: bytes + keywords: set[str] + + def __init__(self, rag: "RAGFlow", res_dict: dict[str, Any]) -> None: self.id = "" self.name = "" self.thumbnail = None @@ -46,11 +125,11 @@ class Document(Base): self.status = "1" self.meta_fields = {} for k in list(res_dict.keys()): - if k not in self.__dict__: + if not hasattr(self, k): res_dict.pop(k) super().__init__(rag, res_dict) - def update(self, update_message: dict): + def update(self, update_message: UpdateMessage) -> "Document": if "meta_fields" in update_message: if not isinstance(update_message["meta_fields"], dict): raise Exception("meta_fields must be a dictionary") @@ -69,32 +148,32 @@ class Document(Base): response = res.json() actual_keys = set(response.keys()) if actual_keys == error_keys: - raise Exception(res.get("message")) + raise Exception(response.get("message")) else: return res.content except json.JSONDecodeError: return res.content - def list_chunks(self, page=1, page_size=30, keywords="", id=""): + def list_chunks(self, page: int=1, page_size: int=30, keywords: str="", id: str="") -> list[Chunk]: data = {"keywords": keywords, "page": page, "page_size": page_size, "id": id} res = self.get(f"/datasets/{self.dataset_id}/documents/{self.id}/chunks", data) res = res.json() if res.get("code") == 0: - chunks = [] + chunks: list[Chunk] = [] for data in res["data"].get("chunks"): chunk = Chunk(self.rag, data) chunks.append(chunk) return chunks raise Exception(res.get("message")) - def add_chunk(self, content: str, important_keywords: list[str] = [], questions: list[str] = []): + def add_chunk(self, 
content: str, important_keywords: list[str] = [], questions: list[str] = []) -> Chunk: res = self.post(f"/datasets/{self.dataset_id}/documents/{self.id}/chunks", {"content": content, "important_keywords": important_keywords, "questions": questions}) res = res.json() if res.get("code") == 0: return Chunk(self.rag, res["data"].get("chunk")) raise Exception(res.get("message")) - def delete_chunks(self, ids: list[str] | None = None): + def delete_chunks(self, ids: list[str] | None = None) -> None: res = self.rm(f"/datasets/{self.dataset_id}/documents/{self.id}/chunks", {"chunk_ids": ids}) res = res.json() if res.get("code") != 0: diff --git a/sdk/python/ragflow_sdk/modules/session.py b/sdk/python/ragflow_sdk/modules/session.py index 12141afed..1026b0230 100644 --- a/sdk/python/ragflow_sdk/modules/session.py +++ b/sdk/python/ragflow_sdk/modules/session.py @@ -14,12 +14,45 @@ # limitations under the License. # +from typing import Any, Literal, Optional, TYPE_CHECKING, TypedDict + import json + from .base import Base +if TYPE_CHECKING: + from ..ragflow import RAGFlow + +__all__ = 'Session', 'Message' + +Role = Literal["assistant", "user"] +SessionType = Literal["chat", "agent"] + +class MessageDict(TypedDict): + role: Role + content: str + +class UpdateMessage(TypedDict): + name: str class Session(Base): - def __init__(self, rag, res_dict): + __slots__ = ( + 'id', + 'name', + 'messages', + 'chat_id', + 'agent_id', + '__session_type', + ) + + id: Optional[str] + name: str + messages: list[MessageDict] + chat_id: Optional[str] + agent_id: Optional[str] + __session_type: SessionType + + def __init__(self, rag: "RAGFlow", res_dict: dict[str, Any]) -> None: self.id = None self.name = "New session" self.messages = [{"role": "assistant", "content": "Hi! I am your assistant, can I help you?"}] @@ -33,7 +66,8 @@ class Session(Base): super().__init__(rag, res_dict) - def ask(self, question="", stream=False, **kwargs): + # TODO: Proper typing of kwargs. 
What are the agent/chat specific kwargs? + def ask(self, question: str="", stream: bool=False, **kwargs): """ Ask a question to the session. If stream=True, yields Message objects as they arrive (SSE streaming). If stream=False, returns a single Message object for the final answer. @@ -79,12 +113,14 @@ class Session(Base): yield self._structure_answer(json_data["data"]) - def _structure_answer(self, json_data): + def _structure_answer(self, json_data) -> "Message": if self.__session_type == "agent": answer = json_data["data"]["content"] elif self.__session_type == "chat": answer = json_data["answer"] - reference = json_data.get("reference", {}) + else: + raise Exception(f"Unknown session type: {self.__session_type}") + reference = json_data.get("reference") temp_dict = { "content": answer, "role": "assistant" @@ -109,16 +145,41 @@ class Session(Base): json_data, stream=stream) return res - def update(self, update_message): + def update(self, update_message: UpdateMessage) -> None: res = self.put(f"/chats/{self.chat_id}/sessions/{self.id}", update_message) res = res.json() if res.get("code") != 0: raise Exception(res.get("message")) +class ChunkDict(TypedDict): + id: str + content: str + image_id: Optional[str] + document_id: str + document_name: str + position: list[str] + dataset_id: str + similarity: float + vector_similarity: float + term_similarity: float class Message(Base): - def __init__(self, rag, res_dict): + __slots__ = ( + 'content', + 'reference', + 'role', + 'prompt', + 'id', + ) + + content: str + reference: Optional[list[ChunkDict]] + role: Role + prompt: Optional[str] + id: Optional[str] + + def __init__(self, rag: "RAGFlow", res_dict: dict[str, Any]) -> None: self.content = "Hi! I am your assistant, can I help you?" 
self.reference = None self.role = "assistant" diff --git a/sdk/python/ragflow_sdk/ragflow.py b/sdk/python/ragflow_sdk/ragflow.py index f200a6b5c..acde45a4b 100644 --- a/sdk/python/ragflow_sdk/ragflow.py +++ b/sdk/python/ragflow_sdk/ragflow.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional +from typing import Any, Literal, Optional, TYPE_CHECKING import requests @@ -22,9 +22,25 @@ from .modules.chat import Chat from .modules.chunk import Chunk from .modules.dataset import DataSet +if TYPE_CHECKING: + from requests.sessions import _Files, _Params + +__all__ = 'RAGFlow', + +OrderBy = Literal["create_time", "update_time"] class RAGFlow: - def __init__(self, api_key, base_url, version="v1"): + __slots__ = ( + 'user_key', + 'api_url', + 'authorization_header', + ) + + user_key: str + api_url: str + authorization_header: dict[str, str] + + def __init__(self, api_key: str, base_url: str, version: str="v1") -> None: """ api_url: http:///api/v1 """ @@ -32,19 +48,19 @@ class RAGFlow: self.api_url = f"{base_url}/api/{version}" self.authorization_header = {"Authorization": "{} {}".format("Bearer", self.user_key)} - def post(self, path, json=None, stream=False, files=None): + def post(self, path: str, json: Any=None, stream: bool=False, files: Optional["_Files"]=None) -> requests.Response: res = requests.post(url=self.api_url + path, json=json, headers=self.authorization_header, stream=stream, files=files) return res - def get(self, path, params=None, json=None): + def get(self, path: str, params: Optional["_Params"]=None, json: Any=None) -> requests.Response: res = requests.get(url=self.api_url + path, params=params, headers=self.authorization_header, json=json) return res - def delete(self, path, json): + def delete(self, path: str, json: Any=None) -> requests.Response: res = requests.delete(url=self.api_url + path, json=json, headers=self.authorization_header) return res - 
def put(self, path, json): + def put(self, path: str, json: Any=None) -> requests.Response: res = requests.put(url=self.api_url + path, json=json, headers=self.authorization_header) return res @@ -75,19 +91,19 @@ class RAGFlow: return DataSet(self, res["data"]) raise Exception(res["message"]) - def delete_datasets(self, ids: list[str] | None = None): + def delete_datasets(self, ids: list[str] | None = None) -> None: res = self.delete("/datasets", {"ids": ids}) res = res.json() if res.get("code") != 0: raise Exception(res["message"]) - def get_dataset(self, name: str): + def get_dataset(self, name: str) -> DataSet: _list = self.list_datasets(name=name) if len(_list) > 0: return _list[0] raise Exception("Dataset %s not found" % name) - def list_datasets(self, page: int = 1, page_size: int = 30, orderby: str = "create_time", desc: bool = True, id: str | None = None, name: str | None = None) -> list[DataSet]: + def list_datasets(self, page: int = 1, page_size: int = 30, orderby: OrderBy = "create_time", desc: bool = True, id: str | None = None, name: str | None = None) -> list[DataSet]: res = self.get( "/datasets", { @@ -107,7 +123,7 @@ class RAGFlow: return result_list raise Exception(res["message"]) - def create_chat(self, name: str, avatar: str = "", dataset_ids=None, llm: Chat.LLM | None = None, prompt: Chat.Prompt | None = None) -> Chat: + def create_chat(self, name: str, avatar: str = "", dataset_ids: Optional[list[str]]=None, llm: Chat.LLM | None = None, prompt: Chat.Prompt | None = None) -> Chat: if dataset_ids is None: dataset_ids = [] dataset_list = [] @@ -159,7 +175,7 @@ class RAGFlow: return Chat(self, res["data"]) raise Exception(res["message"]) - def delete_chats(self, ids: list[str] | None = None): + def delete_chats(self, ids: list[str] | None = None) -> None: res = self.delete("/chats", {"ids": ids}) res = res.json() if res.get("code") != 0: @@ -187,19 +203,19 @@ class RAGFlow: def retrieve( self, - dataset_ids, - document_ids=None, - question="", - 
page=1, - page_size=30, - similarity_threshold=0.2, - vector_similarity_weight=0.3, - top_k=1024, + dataset_ids: list[str], + document_ids: Optional[list[str]]=None, + question: str="", + page: int=1, + page_size: int=30, + similarity_threshold: float=0.2, + vector_similarity_weight: float=0.3, + top_k: int=1024, rerank_id: str | None = None, keyword: bool = False, cross_languages: list[str]|None = None, - metadata_condition: dict | None = None, - ): + metadata_condition: dict[str, Any] | None = None, + ) -> list[Chunk]: if document_ids is None: document_ids = [] data_json = { @@ -220,7 +236,7 @@ class RAGFlow: res = self.post("/retrieval", json=data_json) res = res.json() if res.get("code") == 0: - chunks = [] + chunks: list[Chunk] = [] for chunk_data in res["data"].get("chunks"): chunk = Chunk(self, chunk_data) chunks.append(chunk) @@ -240,7 +256,7 @@ class RAGFlow: }, ) res = res.json() - result_list = [] + result_list: list[Agent] = [] if res.get("code") == 0: for data in res["data"]: result_list.append(Agent(self, data))