This commit is contained in:
Mathias Panzenböck 2025-11-20 10:15:00 +08:00 committed by GitHub
commit 53ade8b3d7
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
8 changed files with 484 additions and 84 deletions

View file

@ -14,12 +14,31 @@
# limitations under the License. # limitations under the License.
# #
from typing import Any, Optional, TYPE_CHECKING
from .base import Base from .base import Base
from .session import Session from .session import Session
if TYPE_CHECKING:
from ..ragflow import RAGFlow
__all__ = 'Agent',
class Agent(Base): class Agent(Base):
def __init__(self, rag, res_dict): __slots__ = (
'id',
'avatar',
'canvas_type',
'description',
'dsl',
)
id: Optional[str]
avatar: Optional[str]
canvas_type: Optional[str]
description: Optional[str]
dsl: Optional["Agent.Dsl"]
def __init__(self, rag: "RAGFlow", res_dict: dict[str, Any]) -> None:
self.id = None self.id = None
self.avatar = None self.avatar = None
self.canvas_type = None self.canvas_type = None
@ -28,7 +47,25 @@ class Agent(Base):
super().__init__(rag, res_dict) super().__init__(rag, res_dict)
class Dsl(Base): class Dsl(Base):
def __init__(self, rag, res_dict): __slots__ = (
'answer',
'components',
'graph',
'history',
'messages',
'path',
'reference',
)
# TODO: Proper typing including TypedDict for the dicts. Where is the specification of the DSL?
answer: list[Any]
components: dict[str, Any]
graph: dict[str, Any]
history: list[Any]
messages: list[Any]
path: list[Any]
reference: list[Any]
def __init__(self, rag: "RAGFlow", res_dict: dict[str, Any]) -> None:
self.answer = [] self.answer = []
self.components = { self.components = {
"begin": { "begin": {
@ -65,8 +102,8 @@ class Agent(Base):
self.reference = [] self.reference = []
super().__init__(rag, res_dict) super().__init__(rag, res_dict)
# TODO: Proper typing of kwargs. Where are these parameters defined?
def create_session(self, **kwargs) -> Session: def create_session(self, **kwargs: dict[str, Any]) -> Session:
res = self.post(f"/agents/{self.id}/sessions", json=kwargs) res = self.post(f"/agents/{self.id}/sessions", json=kwargs)
res = res.json() res = res.json()
if res.get("code") == 0: if res.get("code") == 0:
@ -75,7 +112,7 @@ class Agent(Base):
def list_sessions(self, page: int = 1, page_size: int = 30, orderby: str = "create_time", desc: bool = True, def list_sessions(self, page: int = 1, page_size: int = 30, orderby: str = "create_time", desc: bool = True,
id: str = None) -> list[Session]: id: Optional[str] = None) -> list[Session]:
res = self.get(f"/agents/{self.id}/sessions", res = self.get(f"/agents/{self.id}/sessions",
{"page": page, "page_size": page_size, "orderby": orderby, "desc": desc, "id": id}) {"page": page, "page_size": page_size, "orderby": orderby, "desc": desc, "id": id})
res = res.json() res = res.json()
@ -87,7 +124,7 @@ class Agent(Base):
return result_list return result_list
raise Exception(res.get("message")) raise Exception(res.get("message"))
def delete_sessions(self, ids: list[str] | None = None): def delete_sessions(self, ids: list[str] | None = None) -> None:
res = self.rm(f"/agents/{self.id}/sessions", {"ids": ids}) res = self.rm(f"/agents/{self.id}/sessions", {"ids": ids})
res = res.json() res = res.json()
if res.get("code") != 0: if res.get("code") != 0:

View file

@ -14,21 +14,31 @@
# limitations under the License. # limitations under the License.
# #
from typing import Any, Optional, TYPE_CHECKING
if TYPE_CHECKING:
from requests import Response
from requests.sessions import _Files, _Params
from ..ragflow import RAGFlow
class Base: class Base:
def __init__(self, rag, res_dict): __slots__ = 'rag',
rag: "RAGFlow"
def __init__(self, rag: "RAGFlow", res_dict: dict[str, Any]) -> None:
self.rag = rag self.rag = rag
self._update_from_dict(rag, res_dict) self._update_from_dict(rag, res_dict)
def _update_from_dict(self, rag, res_dict): def _update_from_dict(self, rag: "RAGFlow", res_dict: dict[str, Any]) -> None:
for k, v in res_dict.items(): for k, v in res_dict.items():
if isinstance(v, dict): if isinstance(v, dict):
self.__dict__[k] = Base(rag, v) setattr(self, k, Base(rag, v))
else: else:
self.__dict__[k] = v setattr(self, k, v)
def to_json(self): def to_json(self) -> dict[str, Any]:
pr = {} pr: dict[str, Any] = {}
for name in dir(self): for name in dir(self):
value = getattr(self, name) value = getattr(self, name)
if not name.startswith("__") and not callable(value) and name != "rag": if not name.startswith("__") and not callable(value) and name != "rag":
@ -38,21 +48,21 @@ class Base:
pr[name] = value pr[name] = value
return pr return pr
def post(self, path, json=None, stream=False, files=None): def post(self, path: str, json: Any=None, stream: bool=False, files: Optional["_Files"]=None) -> "Response":
res = self.rag.post(path, json, stream=stream, files=files) res = self.rag.post(path, json, stream=stream, files=files)
return res return res
def get(self, path, params=None): def get(self, path: str, params: Optional["_Params"]=None) -> "Response":
res = self.rag.get(path, params) res = self.rag.get(path, params)
return res return res
def rm(self, path, json): def rm(self, path: str, json: Any) -> "Response":
res = self.rag.delete(path, json) res = self.rag.delete(path, json)
return res return res
def put(self, path, json): def put(self, path: str, json: Any) -> "Response":
res = self.rag.put(path, json) res = self.rag.put(path, json)
return res return res
def __str__(self): def __str__(self) -> str:
return str(self.to_json()) return str(self.to_json())

View file

@ -15,12 +15,62 @@
# #
from typing import Any, NotRequired, Optional, TYPE_CHECKING, TypedDict
from .base import Base from .base import Base
from .session import Session from .session import Session
if TYPE_CHECKING:
from ..ragflow import RAGFlow
__all__ = 'Chat',
class Variable(TypedDict):
key: str
optional: NotRequired[bool]
LLMUpdateMessage = TypedDict('LLMUpdateMessage', {
"model_name": NotRequired[str],
"temperature": NotRequired[float],
"top_p": NotRequired[float],
"presence_penalty": NotRequired[float],
"frequency penalty": NotRequired[float],
})
class PromptUpdateMessage(TypedDict):
similarity_threshold: NotRequired[float]
keywords_similarity_weight: NotRequired[float]
top_n: NotRequired[int]
variables: NotRequired[list[Variable]]
rerank_model: NotRequired[str]
empty_response: NotRequired[str]
opener: NotRequired[str]
show_quote: NotRequired[bool]
prompt: NotRequired[str]
class UpdateMessage(TypedDict):
name: NotRequired[str]
avatar: NotRequired[str]
dataset_ids: NotRequired[list[str]]
llm: NotRequired[LLMUpdateMessage]
prompt: NotRequired[PromptUpdateMessage]
class Chat(Base): class Chat(Base):
def __init__(self, rag, res_dict): __slots__ = (
'id',
'name',
'avatar',
'llm',
'prompt',
)
id: str
name: str
avatar: str
llm: "Chat.LLM"
prompt: "Chat.Prompt"
def __init__(self, rag: "RAGFlow", res_dict: dict[str, Any]) -> None:
self.id = "" self.id = ""
self.name = "assistant" self.name = "assistant"
self.avatar = "path/to/avatar" self.avatar = "path/to/avatar"
@ -29,7 +79,23 @@ class Chat(Base):
super().__init__(rag, res_dict) super().__init__(rag, res_dict)
class LLM(Base): class LLM(Base):
def __init__(self, rag, res_dict): __slots__ = (
'model_name',
'temperature',
'top_p',
'presence_penalty',
'frequency_penalty',
'max_tokens',
)
model_name: Optional[str]
temperature: float
top_p: float
presence_penalty: float
frequency_penalty: float
max_tokens: int
def __init__(self, rag: "RAGFlow", res_dict: dict[str, Any]) -> None:
self.model_name = None self.model_name = None
self.temperature = 0.1 self.temperature = 0.1
self.top_p = 0.3 self.top_p = 0.3
@ -39,7 +105,31 @@ class Chat(Base):
super().__init__(rag, res_dict) super().__init__(rag, res_dict)
class Prompt(Base): class Prompt(Base):
def __init__(self, rag, res_dict): __slots__ = (
'similarity_threshold',
'keywords_similarity_weight',
'top_n',
'top_k',
'variables',
'rerank_model',
'empty_response',
'opener',
'show_quote',
'prompt',
)
similarity_threshold: float
keywords_similarity_weight: float
top_n: int
top_k: int
variables: list[Variable]
rerank_model: str
empty_response: Optional[str]
opener: str
show_quote: bool
prompt: str
def __init__(self, rag: "RAGFlow", res_dict: dict[str, Any]) -> None:
self.similarity_threshold = 0.2 self.similarity_threshold = 0.2
self.keywords_similarity_weight = 0.7 self.keywords_similarity_weight = 0.7
self.top_n = 8 self.top_n = 8
@ -57,7 +147,7 @@ class Chat(Base):
) )
super().__init__(rag, res_dict) super().__init__(rag, res_dict)
def update(self, update_message: dict): def update(self, update_message: UpdateMessage) -> None:
res = self.put(f"/chats/{self.id}", update_message) res = self.put(f"/chats/{self.id}", update_message)
res = res.json() res = res.json()
if res.get("code") != 0: if res.get("code") != 0:
@ -70,7 +160,7 @@ class Chat(Base):
return Session(self.rag, res["data"]) return Session(self.rag, res["data"])
raise Exception(res["message"]) raise Exception(res["message"])
def list_sessions(self, page: int = 1, page_size: int = 30, orderby: str = "create_time", desc: bool = True, id: str = None, name: str = None) -> list[Session]: def list_sessions(self, page: int = 1, page_size: int = 30, orderby: str = "create_time", desc: bool = True, id: Optional[str] = None, name: Optional[str] = None) -> list[Session]:
res = self.get(f"/chats/{self.id}/sessions", {"page": page, "page_size": page_size, "orderby": orderby, "desc": desc, "id": id, "name": name}) res = self.get(f"/chats/{self.id}/sessions", {"page": page, "page_size": page_size, "orderby": orderby, "desc": desc, "id": id, "name": name})
res = res.json() res = res.json()
if res.get("code") == 0: if res.get("code") == 0:
@ -80,7 +170,7 @@ class Chat(Base):
return result_list return result_list
raise Exception(res["message"]) raise Exception(res["message"])
def delete_sessions(self, ids: list[str] | None = None): def delete_sessions(self, ids: list[str] | None = None) -> None:
res = self.rm(f"/chats/{self.id}/sessions", {"ids": ids}) res = self.rm(f"/chats/{self.id}/sessions", {"ids": ids})
res = res.json() res = res.json()
if res.get("code") != 0: if res.get("code") != 0:

View file

@ -14,17 +14,72 @@
# limitations under the License. # limitations under the License.
# #
from typing import Any, NotRequired, Optional, TYPE_CHECKING, TypedDict
from .base import Base from .base import Base
if TYPE_CHECKING:
from ..ragflow import RAGFlow
__all__ = 'Chunk',
class UpdateMessage(TypedDict):
content: NotRequired[str]
important_keywords: NotRequired[list[str]]
available: NotRequired[bool]
class ChunkUpdateError(Exception): class ChunkUpdateError(Exception):
def __init__(self, code=None, message=None, details=None): __slots__ = (
'code',
'message',
'details',
)
code: Optional[int]
message: Optional[str]
details: Optional[str]
def __init__(self, code: Optional[int]=None, message: Optional[str]=None, details: Optional[str]=None):
self.code = code self.code = code
self.message = message self.message = message
self.details = details self.details = details
super().__init__(message) super().__init__(message)
class Chunk(Base): class Chunk(Base):
def __init__(self, rag, res_dict): __slots__ = (
'id',
'content',
'important_keywords',
'questions',
'create_time',
'create_timestamp',
'dataset_id',
'document_name',
'document_id',
'available',
'similarity',
'vector_similarity',
'term_similarity',
'positions',
'doc_type',
)
id: str
content: str
important_keywords: list[str]
questions: list[str]
create_time: str
create_timestamp: float
dataset_id: Optional[str]
document_name: str
document_id: str
available: bool
similarity: float
vector_similarity: float
term_similarity: float
positions: list[str]
doc_type: str
def __init__(self, rag: "RAGFlow", res_dict: dict[str, Any]) -> None:
self.id = "" self.id = ""
self.content = "" self.content = ""
self.important_keywords = [] self.important_keywords = []
@ -42,11 +97,11 @@ class Chunk(Base):
self.positions = [] self.positions = []
self.doc_type = "" self.doc_type = ""
for k in list(res_dict.keys()): for k in list(res_dict.keys()):
if k not in self.__dict__: if not hasattr(self, k):
res_dict.pop(k) res_dict.pop(k)
super().__init__(rag, res_dict) super().__init__(rag, res_dict)
def update(self, update_message: dict): def update(self, update_message: UpdateMessage) -> None:
res = self.put(f"/datasets/{self.dataset_id}/documents/{self.document_id}/chunks/{self.id}", update_message) res = self.put(f"/datasets/{self.dataset_id}/documents/{self.document_id}/chunks/{self.id}", update_message)
res = res.json() res = res.json()
if res.get("code") != 0: if res.get("code") != 0:

View file

@ -14,16 +14,68 @@
# limitations under the License. # limitations under the License.
# #
from typing import Any, Literal, NamedTuple, NotRequired, Optional, TYPE_CHECKING, TypedDict
from .base import Base from .base import Base
from .document import Document from .document import Document, ChunkMethod
if TYPE_CHECKING:
from ..ragflow import RAGFlow
__all__ = 'DataSet',
class DocumentStatus(NamedTuple):
document_id: str
run: str
chunk_count: int
token_count: int
class DocumentParams(TypedDict):
display_name: str
blob: str|bytes
Permission = Literal["me", "team"]
class UpdateMessage(TypedDict):
name: NotRequired[str]
avatar: NotRequired[str]
embedding_model: NotRequired[str]
permission: NotRequired[Permission]
pagerank: NotRequired[int]
chunk_method: NotRequired[ChunkMethod]
class DataSet(Base): class DataSet(Base):
__slots__ = (
'id',
'name',
'avatar',
'tenant_id',
'description',
'embedding_model',
'permission',
'chunk_method',
'parser_config',
'pagerank',
)
class ParserConfig(Base): class ParserConfig(Base):
def __init__(self, rag, res_dict): # TODO: Proper typing of parser config.
def __init__(self, rag: "RAGFlow", res_dict: dict[str, Any]) -> None:
super().__init__(rag, res_dict) super().__init__(rag, res_dict)
def __init__(self, rag, res_dict): id: str
name: str
avatar: str
tenant_id: Optional[str]
description: str
embedding_model: str
permission: str
chunk_method: str
parser_config: Optional[ParserConfig]
pagerank: int
def __init__(self, rag: "RAGFlow", res_dict: dict[str, Any]) -> None:
self.id = "" self.id = ""
self.name = "" self.name = ""
self.avatar = "" self.avatar = ""
@ -37,11 +89,11 @@ class DataSet(Base):
self.parser_config = None self.parser_config = None
self.pagerank = 0 self.pagerank = 0
for k in list(res_dict.keys()): for k in list(res_dict.keys()):
if k not in self.__dict__: if not hasattr(self, k):
res_dict.pop(k) res_dict.pop(k)
super().__init__(rag, res_dict) super().__init__(rag, res_dict)
def update(self, update_message: dict): def update(self, update_message: UpdateMessage) -> "DataSet":
res = self.put(f"/datasets/{self.id}", update_message) res = self.put(f"/datasets/{self.id}", update_message)
res = res.json() res = res.json()
if res.get("code") != 0: if res.get("code") != 0:
@ -50,13 +102,13 @@ class DataSet(Base):
self._update_from_dict(self.rag, res.get("data", {})) self._update_from_dict(self.rag, res.get("data", {}))
return self return self
def upload_documents(self, document_list: list[dict]): def upload_documents(self, document_list: list[DocumentParams]) -> list[Document]:
url = f"/datasets/{self.id}/documents" url = f"/datasets/{self.id}/documents"
files = [("file", (ele["display_name"], ele["blob"])) for ele in document_list] files = [("file", (ele["display_name"], ele["blob"])) for ele in document_list]
res = self.post(path=url, json=None, files=files) res = self.post(path=url, json=None, files=files)
res = res.json() res = res.json()
if res.get("code") == 0: if res.get("code") == 0:
doc_list = [] doc_list: list[Document] = []
for doc in res["data"]: for doc in res["data"]:
document = Document(self.rag, doc) document = Document(self.rag, doc)
doc_list.append(document) doc_list.append(document)
@ -74,7 +126,7 @@ class DataSet(Base):
desc: bool = True, desc: bool = True,
create_time_from: int = 0, create_time_from: int = 0,
create_time_to: int = 0, create_time_to: int = 0,
): ) -> list[Document]:
params = { params = {
"id": id, "id": id,
"name": name, "name": name,
@ -88,26 +140,26 @@ class DataSet(Base):
} }
res = self.get(f"/datasets/{self.id}/documents", params=params) res = self.get(f"/datasets/{self.id}/documents", params=params)
res = res.json() res = res.json()
documents = [] documents: list[Document] = []
if res.get("code") == 0: if res.get("code") == 0:
for document in res["data"].get("docs"): for document in res["data"].get("docs"):
documents.append(Document(self.rag, document)) documents.append(Document(self.rag, document))
return documents return documents
raise Exception(res["message"]) raise Exception(res["message"])
def delete_documents(self, ids: list[str] | None = None): def delete_documents(self, ids: list[str] | None = None) -> None:
res = self.rm(f"/datasets/{self.id}/documents", {"ids": ids}) res = self.rm(f"/datasets/{self.id}/documents", {"ids": ids})
res = res.json() res = res.json()
if res.get("code") != 0: if res.get("code") != 0:
raise Exception(res["message"]) raise Exception(res["message"])
def _get_documents_status(self, document_ids): def _get_documents_status(self, document_ids: list[str]) -> list[DocumentStatus]:
import time import time
terminal_states = {"DONE", "FAIL", "CANCEL"} terminal_states = {"DONE", "FAIL", "CANCEL"}
interval_sec = 1 interval_sec = 1
pending = set(document_ids) pending = set(document_ids)
finished = [] finished: list[DocumentStatus] = []
while pending: while pending:
for doc_id in list(pending): for doc_id in list(pending):
def fetch_doc(doc_id: str) -> Document | None: def fetch_doc(doc_id: str) -> Document | None:
@ -120,23 +172,23 @@ class DataSet(Base):
if doc is None: if doc is None:
continue continue
if isinstance(doc.run, str) and doc.run.upper() in terminal_states: if isinstance(doc.run, str) and doc.run.upper() in terminal_states:
finished.append((doc_id, doc.run, doc.chunk_count, doc.token_count)) finished.append(DocumentStatus(doc_id, doc.run, doc.chunk_count, doc.token_count))
pending.discard(doc_id) pending.discard(doc_id)
elif float(doc.progress or 0.0) >= 1.0: elif float(doc.progress or 0.0) >= 1.0:
finished.append((doc_id, "DONE", doc.chunk_count, doc.token_count)) finished.append(DocumentStatus(doc_id, "DONE", doc.chunk_count, doc.token_count))
pending.discard(doc_id) pending.discard(doc_id)
if pending: if pending:
time.sleep(interval_sec) time.sleep(interval_sec)
return finished return finished
def async_parse_documents(self, document_ids): def async_parse_documents(self, document_ids: list[str]) -> None:
res = self.post(f"/datasets/{self.id}/chunks", {"document_ids": document_ids}) res = self.post(f"/datasets/{self.id}/chunks", {"document_ids": document_ids})
res = res.json() res = res.json()
if res.get("code") != 0: if res.get("code") != 0:
raise Exception(res.get("message")) raise Exception(res.get("message"))
def parse_documents(self, document_ids): def parse_documents(self, document_ids: list[str]) -> list[DocumentStatus]:
try: try:
self.async_parse_documents(document_ids) self.async_parse_documents(document_ids)
self._get_documents_status(document_ids) self._get_documents_status(document_ids)
@ -146,7 +198,7 @@ class DataSet(Base):
return self._get_documents_status(document_ids) return self._get_documents_status(document_ids)
def async_cancel_parse_documents(self, document_ids): def async_cancel_parse_documents(self, document_ids: list[str]) -> None:
res = self.rm(f"/datasets/{self.id}/chunks", {"document_ids": document_ids}) res = self.rm(f"/datasets/{self.id}/chunks", {"document_ids": document_ids})
res = res.json() res = res.json()
if res.get("code") != 0: if res.get("code") != 0:

View file

@ -14,18 +14,97 @@
# limitations under the License. # limitations under the License.
# #
from typing import Any, Literal, NotRequired, Optional, TYPE_CHECKING, TypedDict
import json import json
from .base import Base from .base import Base
from .chunk import Chunk from .chunk import Chunk
if TYPE_CHECKING:
from ..ragflow import RAGFlow
__all__ = 'Document',
ChunkMethod = Literal["naive", "manual", "qa", "table", "paper", "book", "laws", "presentation", "picture", "one", "email"]
LayoutRecognize = Literal["DeepDOC", "Plain Text", "Naive"]
class RaptorParams(TypedDict):
use_raptor: NotRequired[bool]
class GraphragParams(TypedDict):
use_graphrag: NotRequired[bool]
class ParserConfigParams(TypedDict):
filename_embd_weight: NotRequired[int|float]
# chunk_method=naive
chunk_token_num: NotRequired[int]
delimiter: NotRequired[str]
html4excel: NotRequired[bool]
layout_recognize: NotRequired[LayoutRecognize|bool]
# chunk_method=raptor
raptor: NotRequired[RaptorParams]
# chunk_method=knowledge-graph
entity_types: NotRequired[list[str]]
graphrag: NotRequired[GraphragParams]
class UpdateMessage(TypedDict):
display_name: NotRequired[str]
meta_fields: NotRequired[dict[str, Any]]
chunk_method: NotRequired[ChunkMethod]
parser_config: NotRequired[ParserConfigParams]
class Document(Base): class Document(Base):
__slots__ = (
'id',
'name',
'thumbnail',
'dataset_id',
'chunk_method',
'parser_config',
'source_type',
'type',
'created_by',
'progress',
'progress_msg',
'process_begin_at',
'process_duration',
'run',
'status',
'meta_fields',
'blob',
'keywords',
)
class ParserConfig(Base): class ParserConfig(Base):
def __init__(self, rag, res_dict): def __init__(self, rag: "RAGFlow", res_dict: dict[str, Any]) -> None:
super().__init__(rag, res_dict) super().__init__(rag, res_dict)
def __init__(self, rag, res_dict): id: str
name: str
thumbnail: Optional[str]
dataset_id: Optional[str]
chunk_method: Optional[str]
parser_config: dict[str, Any]
source_type: str
type: str
created_by: str
progress: float
progress_msg: str
process_begin_at: Optional[str]
process_duration: float
run: str
status: str
meta_fields: dict[str, Any]
blob: bytes
keywords: set[str]
def __init__(self, rag: "RAGFlow", res_dict: dict[str, Any]) -> None:
self.id = "" self.id = ""
self.name = "" self.name = ""
self.thumbnail = None self.thumbnail = None
@ -46,11 +125,11 @@ class Document(Base):
self.status = "1" self.status = "1"
self.meta_fields = {} self.meta_fields = {}
for k in list(res_dict.keys()): for k in list(res_dict.keys()):
if k not in self.__dict__: if not hasattr(self, k):
res_dict.pop(k) res_dict.pop(k)
super().__init__(rag, res_dict) super().__init__(rag, res_dict)
def update(self, update_message: dict): def update(self, update_message: UpdateMessage) -> "Document":
if "meta_fields" in update_message: if "meta_fields" in update_message:
if not isinstance(update_message["meta_fields"], dict): if not isinstance(update_message["meta_fields"], dict):
raise Exception("meta_fields must be a dictionary") raise Exception("meta_fields must be a dictionary")
@ -69,32 +148,32 @@ class Document(Base):
response = res.json() response = res.json()
actual_keys = set(response.keys()) actual_keys = set(response.keys())
if actual_keys == error_keys: if actual_keys == error_keys:
raise Exception(res.get("message")) raise Exception(response.get("message"))
else: else:
return res.content return res.content
except json.JSONDecodeError: except json.JSONDecodeError:
return res.content return res.content
def list_chunks(self, page=1, page_size=30, keywords="", id=""): def list_chunks(self, page: int=1, page_size: int=30, keywords: str="", id: str="") -> list[Chunk]:
data = {"keywords": keywords, "page": page, "page_size": page_size, "id": id} data = {"keywords": keywords, "page": page, "page_size": page_size, "id": id}
res = self.get(f"/datasets/{self.dataset_id}/documents/{self.id}/chunks", data) res = self.get(f"/datasets/{self.dataset_id}/documents/{self.id}/chunks", data)
res = res.json() res = res.json()
if res.get("code") == 0: if res.get("code") == 0:
chunks = [] chunks: list[Chunk] = []
for data in res["data"].get("chunks"): for data in res["data"].get("chunks"):
chunk = Chunk(self.rag, data) chunk = Chunk(self.rag, data)
chunks.append(chunk) chunks.append(chunk)
return chunks return chunks
raise Exception(res.get("message")) raise Exception(res.get("message"))
def add_chunk(self, content: str, important_keywords: list[str] = [], questions: list[str] = []): def add_chunk(self, content: str, important_keywords: list[str] = [], questions: list[str] = []) -> Chunk:
res = self.post(f"/datasets/{self.dataset_id}/documents/{self.id}/chunks", {"content": content, "important_keywords": important_keywords, "questions": questions}) res = self.post(f"/datasets/{self.dataset_id}/documents/{self.id}/chunks", {"content": content, "important_keywords": important_keywords, "questions": questions})
res = res.json() res = res.json()
if res.get("code") == 0: if res.get("code") == 0:
return Chunk(self.rag, res["data"].get("chunk")) return Chunk(self.rag, res["data"].get("chunk"))
raise Exception(res.get("message")) raise Exception(res.get("message"))
def delete_chunks(self, ids: list[str] | None = None): def delete_chunks(self, ids: list[str] | None = None) -> None:
res = self.rm(f"/datasets/{self.dataset_id}/documents/{self.id}/chunks", {"chunk_ids": ids}) res = self.rm(f"/datasets/{self.dataset_id}/documents/{self.id}/chunks", {"chunk_ids": ids})
res = res.json() res = res.json()
if res.get("code") != 0: if res.get("code") != 0:

View file

@ -14,12 +14,45 @@
# limitations under the License. # limitations under the License.
# #
from typing import Any, Literal, Optional, TYPE_CHECKING, TypedDict
import json import json
from .base import Base from .base import Base
if TYPE_CHECKING:
from ..ragflow import RAGFlow
__all__ = 'Session', 'Message'
Role = Literal["assistant", "user"]
SessionType = Literal["chat", "agent"]
class MessageDict(TypedDict):
role: Role
content: str
class UpdateMessage(TypedDict):
name: str
class Session(Base): class Session(Base):
def __init__(self, rag, res_dict): __slots__ = (
'id',
'name',
'messages',
'chat_id',
'agent_id',
'__session_type',
)
id: Optional[str]
name: str
messages: list[MessageDict]
chat_id: Optional[str]
agent_id: Optional[str]
__session_type: SessionType
def __init__(self, rag: "RAGFlow", res_dict: dict[str, Any]) -> None:
self.id = None self.id = None
self.name = "New session" self.name = "New session"
self.messages = [{"role": "assistant", "content": "Hi! I am your assistant, can I help you?"}] self.messages = [{"role": "assistant", "content": "Hi! I am your assistant, can I help you?"}]
@ -33,7 +66,8 @@ class Session(Base):
super().__init__(rag, res_dict) super().__init__(rag, res_dict)
def ask(self, question="", stream=False, **kwargs): # TODO: Proper typing of kwargs. What are the agent/chat specific kwargs?
def ask(self, question: str="", stream: bool=False, **kwargs):
""" """
Ask a question to the session. If stream=True, yields Message objects as they arrive (SSE streaming). Ask a question to the session. If stream=True, yields Message objects as they arrive (SSE streaming).
If stream=False, returns a single Message object for the final answer. If stream=False, returns a single Message object for the final answer.
@ -79,12 +113,14 @@ class Session(Base):
yield self._structure_answer(json_data["data"]) yield self._structure_answer(json_data["data"])
def _structure_answer(self, json_data): def _structure_answer(self, json_data) -> "Message":
if self.__session_type == "agent": if self.__session_type == "agent":
answer = json_data["data"]["content"] answer = json_data["data"]["content"]
elif self.__session_type == "chat": elif self.__session_type == "chat":
answer = json_data["answer"] answer = json_data["answer"]
reference = json_data.get("reference", {}) else:
raise Exception(f"Unknown session type: {self.__session_type}")
reference = json_data.get("reference")
temp_dict = { temp_dict = {
"content": answer, "content": answer,
"role": "assistant" "role": "assistant"
@ -109,16 +145,41 @@ class Session(Base):
json_data, stream=stream) json_data, stream=stream)
return res return res
def update(self, update_message): def update(self, update_message: UpdateMessage) -> None:
res = self.put(f"/chats/{self.chat_id}/sessions/{self.id}", res = self.put(f"/chats/{self.chat_id}/sessions/{self.id}",
update_message) update_message)
res = res.json() res = res.json()
if res.get("code") != 0: if res.get("code") != 0:
raise Exception(res.get("message")) raise Exception(res.get("message"))
class ChunkDict(TypedDict):
id: str
content: str
image_id: Optional[str]
document_id: str
document_name: str
position: list[str]
dataset_id: str
similarity: float
vector_similarity: float
term_similarity: float
class Message(Base): class Message(Base):
def __init__(self, rag, res_dict): __slots__ = (
'content',
'reference',
'role',
'prompt',
'id',
)
content: str
reference: Optional[list[ChunkDict]]
role: Role
prompt: Optional[str]
id: Optional[str]
def __init__(self, rag: "RAGFlow", res_dict: dict[str, Any]) -> None:
self.content = "Hi! I am your assistant, can I help you?" self.content = "Hi! I am your assistant, can I help you?"
self.reference = None self.reference = None
self.role = "assistant" self.role = "assistant"

View file

@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import Optional from typing import Any, Literal, Optional, TYPE_CHECKING
import requests import requests
@ -22,9 +22,25 @@ from .modules.chat import Chat
from .modules.chunk import Chunk from .modules.chunk import Chunk
from .modules.dataset import DataSet from .modules.dataset import DataSet
if TYPE_CHECKING:
from requests.sessions import _Files, _Params
__all__ = 'RAGFlow',
OrderBy = Literal["create_time", "update_time"]
class RAGFlow: class RAGFlow:
def __init__(self, api_key, base_url, version="v1"): __slots__ = (
'user_key',
'api_url',
'authorization_header',
)
user_key: str
api_url: str
authorization_header: dict[str, str]
def __init__(self, api_key: str, base_url: str, version: str="v1") -> None:
""" """
api_url: http://<host_address>/api/v1 api_url: http://<host_address>/api/v1
""" """
@ -32,19 +48,19 @@ class RAGFlow:
self.api_url = f"{base_url}/api/{version}" self.api_url = f"{base_url}/api/{version}"
self.authorization_header = {"Authorization": "{} {}".format("Bearer", self.user_key)} self.authorization_header = {"Authorization": "{} {}".format("Bearer", self.user_key)}
def post(self, path, json=None, stream=False, files=None): def post(self, path: str, json: Any=None, stream: bool=False, files: Optional["_Files"]=None) -> requests.Response:
res = requests.post(url=self.api_url + path, json=json, headers=self.authorization_header, stream=stream, files=files) res = requests.post(url=self.api_url + path, json=json, headers=self.authorization_header, stream=stream, files=files)
return res return res
def get(self, path, params=None, json=None): def get(self, path: str, params: Optional["_Params"]=None, json: Any=None) -> requests.Response:
res = requests.get(url=self.api_url + path, params=params, headers=self.authorization_header, json=json) res = requests.get(url=self.api_url + path, params=params, headers=self.authorization_header, json=json)
return res return res
def delete(self, path, json): def delete(self, path: str, json: Any=None) -> requests.Response:
res = requests.delete(url=self.api_url + path, json=json, headers=self.authorization_header) res = requests.delete(url=self.api_url + path, json=json, headers=self.authorization_header)
return res return res
def put(self, path, json): def put(self, path: str, json: Any=None) -> requests.Response:
res = requests.put(url=self.api_url + path, json=json, headers=self.authorization_header) res = requests.put(url=self.api_url + path, json=json, headers=self.authorization_header)
return res return res
@ -75,19 +91,19 @@ class RAGFlow:
return DataSet(self, res["data"]) return DataSet(self, res["data"])
raise Exception(res["message"]) raise Exception(res["message"])
def delete_datasets(self, ids: list[str] | None = None): def delete_datasets(self, ids: list[str] | None = None) -> None:
res = self.delete("/datasets", {"ids": ids}) res = self.delete("/datasets", {"ids": ids})
res = res.json() res = res.json()
if res.get("code") != 0: if res.get("code") != 0:
raise Exception(res["message"]) raise Exception(res["message"])
def get_dataset(self, name: str): def get_dataset(self, name: str) -> DataSet:
_list = self.list_datasets(name=name) _list = self.list_datasets(name=name)
if len(_list) > 0: if len(_list) > 0:
return _list[0] return _list[0]
raise Exception("Dataset %s not found" % name) raise Exception("Dataset %s not found" % name)
def list_datasets(self, page: int = 1, page_size: int = 30, orderby: str = "create_time", desc: bool = True, id: str | None = None, name: str | None = None) -> list[DataSet]: def list_datasets(self, page: int = 1, page_size: int = 30, orderby: OrderBy = "create_time", desc: bool = True, id: str | None = None, name: str | None = None) -> list[DataSet]:
res = self.get( res = self.get(
"/datasets", "/datasets",
{ {
@ -107,7 +123,7 @@ class RAGFlow:
return result_list return result_list
raise Exception(res["message"]) raise Exception(res["message"])
def create_chat(self, name: str, avatar: str = "", dataset_ids=None, llm: Chat.LLM | None = None, prompt: Chat.Prompt | None = None) -> Chat: def create_chat(self, name: str, avatar: str = "", dataset_ids: Optional[list[str]]=None, llm: Chat.LLM | None = None, prompt: Chat.Prompt | None = None) -> Chat:
if dataset_ids is None: if dataset_ids is None:
dataset_ids = [] dataset_ids = []
dataset_list = [] dataset_list = []
@ -159,7 +175,7 @@ class RAGFlow:
return Chat(self, res["data"]) return Chat(self, res["data"])
raise Exception(res["message"]) raise Exception(res["message"])
def delete_chats(self, ids: list[str] | None = None): def delete_chats(self, ids: list[str] | None = None) -> None:
res = self.delete("/chats", {"ids": ids}) res = self.delete("/chats", {"ids": ids})
res = res.json() res = res.json()
if res.get("code") != 0: if res.get("code") != 0:
@ -187,19 +203,19 @@ class RAGFlow:
def retrieve( def retrieve(
self, self,
dataset_ids, dataset_ids: list[str],
document_ids=None, document_ids: Optional[list[str]]=None,
question="", question: str="",
page=1, page: int=1,
page_size=30, page_size: int=30,
similarity_threshold=0.2, similarity_threshold: float=0.2,
vector_similarity_weight=0.3, vector_similarity_weight: float=0.3,
top_k=1024, top_k: int=1024,
rerank_id: str | None = None, rerank_id: str | None = None,
keyword: bool = False, keyword: bool = False,
cross_languages: list[str]|None = None, cross_languages: list[str]|None = None,
metadata_condition: dict | None = None, metadata_condition: dict[str, Any] | None = None,
): ) -> list[Chunk]:
if document_ids is None: if document_ids is None:
document_ids = [] document_ids = []
data_json = { data_json = {
@ -220,7 +236,7 @@ class RAGFlow:
res = self.post("/retrieval", json=data_json) res = self.post("/retrieval", json=data_json)
res = res.json() res = res.json()
if res.get("code") == 0: if res.get("code") == 0:
chunks = [] chunks: list[Chunk] = []
for chunk_data in res["data"].get("chunks"): for chunk_data in res["data"].get("chunks"):
chunk = Chunk(self, chunk_data) chunk = Chunk(self, chunk_data)
chunks.append(chunk) chunks.append(chunk)
@ -240,7 +256,7 @@ class RAGFlow:
}, },
) )
res = res.json() res = res.json()
result_list = [] result_list: list[Agent] = []
if res.get("code") == 0: if res.get("code") == 0:
for data in res["data"]: for data in res["data"]:
result_list.append(Agent(self, data)) result_list.append(Agent(self, data))