Merge branch 'main' into feature/OND211-2329-Check-existing-REST-endponts-and-extend-with-new-requested-endpoints

This commit is contained in:
hetavi-bluexkye 2025-12-01 10:59:50 +05:30 committed by GitHub
commit 8a9ada69dc
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
22 changed files with 280 additions and 80 deletions

View file

@ -194,7 +194,7 @@ releases! 🌟
# git checkout v0.22.1
# Optional: use a stable tag (see releases: https://github.com/infiniflow/ragflow/releases)
# This steps ensures the **entrypoint.sh** file in the code matches the Docker image version.
# This step ensures the **entrypoint.sh** file in the code matches the Docker image version.
# Use CPU for DeepDoc tasks:
$ docker compose -f docker-compose.yml up -d

View file

@ -69,7 +69,7 @@ class CodeExecParam(ToolParamBase):
self.meta: ToolMeta = {
"name": "execute_code",
"description": """
This tool has a sandbox that can execute code written in 'Python'/'Javascript'. It recieves a piece of code and return a Json string.
This tool has a sandbox that can execute code written in 'Python'/'Javascript'. It receives a piece of code and return a Json string.
Here's a code example for Python(`main` function MUST be included):
def main() -> dict:
\"\"\"

View file

@ -82,6 +82,11 @@ app.url_map.strict_slashes = False
app.json_encoder = CustomJSONEncoder
app.errorhandler(Exception)(server_error_response)
# Configure Quart timeouts for slow LLM responses (e.g., local Ollama on CPU)
# Default Quart timeouts are 60 seconds which is too short for many LLM backends
app.config["RESPONSE_TIMEOUT"] = int(os.environ.get("QUART_RESPONSE_TIMEOUT", 600))
app.config["BODY_TIMEOUT"] = int(os.environ.get("QUART_BODY_TIMEOUT", 600))
## convenience for dev and debug
# app.config["LOGIN_DISABLED"] = True
app.config["SESSION_PERMANENT"] = False

View file

@ -706,6 +706,7 @@ async def set_meta():
except Exception as e:
return server_error_response(e)
@manager.route("/upload_info", methods=["POST"]) # noqa: F821
async def upload_info():
files = await request.files

View file

@ -676,7 +676,11 @@ Please write the SQL, only SQL, without any other explanations or text.
if kb_ids:
kb_filter = "(" + " OR ".join([f"kb_id = '{kb_id}'" for kb_id in kb_ids]) + ")"
if "where" not in sql.lower():
sql += f" WHERE {kb_filter}"
o = sql.lower().split("order by")
if len(o) > 1:
sql = o[0] + f" WHERE {kb_filter} order by " + o[1]
else:
sql += f" WHERE {kb_filter}"
else:
sql += f" AND {kb_filter}"
@ -684,10 +688,9 @@ Please write the SQL, only SQL, without any other explanations or text.
tried_times += 1
return settings.retriever.sql_retrieval(sql, format="json"), sql
tbl, sql = get_table()
if tbl is None:
return None
if tbl.get("error") and tried_times <= 2:
try:
tbl, sql = get_table()
except Exception as e:
user_prompt = """
Table name: {};
Table of database fields are as follows:
@ -701,16 +704,14 @@ Please write the SQL, only SQL, without any other explanations or text.
The SQL error you provided last time is as follows:
{}
Error issued by database as follows:
{}
Please correct the error and write SQL again, only SQL, without any other explanations or text.
""".format(index_name(tenant_id), "\n".join([f"{k}: {v}" for k, v in field_map.items()]), question, sql, tbl["error"])
tbl, sql = get_table()
logging.debug("TRY it again: {}".format(sql))
""".format(index_name(tenant_id), "\n".join([f"{k}: {v}" for k, v in field_map.items()]), question, e)
try:
tbl, sql = get_table()
except Exception:
return
logging.debug("GET table: {}".format(tbl))
if tbl.get("error") or len(tbl["rows"]) == 0:
if len(tbl["rows"]) == 0:
return None
docid_idx = set([ii for ii, c in enumerate(tbl["columns"]) if c["name"] == "doc_id"])

View file

@ -14,6 +14,7 @@
# limitations under the License.
#
from collections import Counter
import string
from typing import Annotated, Any, Literal
from uuid import UUID
@ -25,6 +26,7 @@ from pydantic import (
StringConstraints,
ValidationError,
field_validator,
model_validator,
)
from pydantic_core import PydanticCustomError
from werkzeug.exceptions import BadRequest, UnsupportedMediaType
@ -362,10 +364,9 @@ class CreateDatasetReq(Base):
embedding_model: Annotated[str | None, Field(default=None, max_length=255, serialization_alias="embd_id")]
permission: Annotated[Literal["me", "team"], Field(default="me", min_length=1, max_length=16)]
shared_tenant_id: Annotated[str | None, Field(default=None, max_length=32, description="Specific tenant ID to share with when permission is 'team'")]
chunk_method: Annotated[
Literal["naive", "book", "email", "laws", "manual", "one", "paper", "picture", "presentation", "qa", "table", "tag"],
Field(default="naive", min_length=1, max_length=32, serialization_alias="parser_id"),
]
chunk_method: Annotated[str | None, Field(default=None, serialization_alias="parser_id")]
parse_type: Annotated[int | None, Field(default=None, ge=0, le=64)]
pipeline_id: Annotated[str | None, Field(default=None, min_length=32, max_length=32, serialization_alias="pipeline_id")]
parser_config: Annotated[ParserConfig | None, Field(default=None)]
@field_validator("shared_tenant_id", mode="after")
@ -535,6 +536,93 @@ class CreateDatasetReq(Base):
raise PydanticCustomError("string_too_long", "Parser config exceeds size limit (max 65,535 characters). Current size: {actual}", {"actual": len(json_str)})
return v
@field_validator("pipeline_id", mode="after")
@classmethod
def validate_pipeline_id(cls, v: str | None) -> str | None:
    """Normalize and validate an optional pipeline_id.

    Rules:
    - None or empty string is treated as "not set" and normalized to None.
    - Otherwise the value must be exactly 32 characters long.
    - Every character must be a hex digit (0-9a-fA-F); the result is
      returned lowercased.

    Raises:
        PydanticCustomError: code 'format_invalid' on length or charset violation.
    """
    if not v:
        # Both None and "" mean the field was not provided.
        return None
    if len(v) != 32:
        raise PydanticCustomError("format_invalid", "pipeline_id must be 32 hex characters")
    if not all(ch in string.hexdigits for ch in v):
        raise PydanticCustomError("format_invalid", "pipeline_id must be hexadecimal")
    return v.lower()
@model_validator(mode="after")
def validate_parser_dependency(self) -> "CreateDatasetReq":
    """
    Enforce that exactly one ingestion mode is configured.

    - chunk_method omitted (None AND not in model_fields_set):
        * parse_type and pipeline_id both absent -> default chunk_method = "naive"
        * parse_type and pipeline_id both present -> ingestion-pipeline mode
        * only one of the pair present -> dependency_error (no partial config)
    - chunk_method provided (a valid value, per its own validator):
        * parse_type and pipeline_id must both be None -> dependency_error otherwise

    Raises:
        PydanticCustomError with code 'dependency_error' on violation.
    """
    # chunk_method was never supplied by the caller: value is None AND the
    # field is absent from model_fields_set (explicitly-passed fields).
    if self.chunk_method is None and "chunk_method" not in self.model_fields_set:
        # All three absent → default to the "naive" chunking method.
        if self.parse_type is None and self.pipeline_id is None:
            # object.__setattr__ so the assignment succeeds even if the
            # model is configured immutable during "after" validation.
            object.__setattr__(self, "chunk_method", "naive")
            return self
        # chunk_method omitted: require BOTH parse_type & pipeline_id
        # (a partial pipeline configuration is rejected).
        if self.parse_type is None or self.pipeline_id is None:
            missing = []
            if self.parse_type is None:
                missing.append("parse_type")
            if self.pipeline_id is None:
                missing.append("pipeline_id")
            raise PydanticCustomError(
                "dependency_error",
                "parser_id omitted → required fields missing: {fields}",
                {"fields": ", ".join(missing)},
            )
        # Both provided → ingestion-pipeline mode is valid as-is.
        return self
    # chunk_method explicitly provided: the pipeline fields are disallowed.
    if isinstance(self.chunk_method, str):
        if self.parse_type is not None or self.pipeline_id is not None:
            invalid = []
            if self.parse_type is not None:
                invalid.append("parse_type")
            if self.pipeline_id is not None:
                invalid.append("pipeline_id")
            raise PydanticCustomError(
                "dependency_error",
                "parser_id provided → disallowed fields present: {fields}",
                {"fields": ", ".join(invalid)},
            )
    return self
@field_validator("chunk_method", mode="wrap")
@classmethod
def validate_chunk_method(cls, v: Any, handler) -> Any:
    """Wrap-mode validator that collapses every failure into one uniform message.

    mode="wrap" lets this intercept type errors (e.g. a list being passed)
    as well as value errors, so callers always see the same 'literal_error'
    text regardless of how validation failed.
    """
    allowed = {"naive", "book", "email", "laws", "manual", "one", "paper", "picture", "presentation", "qa", "table", "tag"}
    error_msg = "Input should be 'naive', 'book', 'email', 'laws', 'manual', 'one', 'paper', 'picture', 'presentation', 'qa', 'table' or 'tag'"
    # An explicit None is rejected here with the unified message; a field
    # that was omitted entirely is handled by validate_parser_dependency.
    # NOTE(review): assumes this validator does not run for omitted fields
    # (no validate_default on the field) — confirm against the field definition.
    if v is None:
        raise PydanticCustomError("literal_error", error_msg)
    try:
        # Run the inner (schema) validation — primarily type checking.
        result = handler(v)
    except Exception:
        # Deliberately broad: any inner failure becomes the same message.
        raise PydanticCustomError("literal_error", error_msg)
    # After the handler succeeds, enforce membership in the allowed set.
    if not isinstance(result, str) or result == "" or result not in allowed:
        raise PydanticCustomError("literal_error", error_msg)
    return result
class UpdateDatasetReq(CreateDatasetReq):
dataset_id: Annotated[str, Field(...)]

View file

@ -38,6 +38,7 @@ oceanbase:
port: 2881
redis:
db: 1
username: ''
password: 'infini_rag_flow'
host: 'localhost:6379'
task_executor:

View file

@ -190,7 +190,7 @@ class MinerUParser(RAGFlowPdfParser):
self._run_mineru_executable(input_path, output_dir, method, backend, lang, server_url, callback)
def _run_mineru_api(self, input_path: Path, output_dir: Path, method: str = "auto", backend: str = "pipeline", lang: Optional[str] = None, callback: Optional[Callable] = None):
OUTPUT_ZIP_PATH = os.path.join(str(output_dir), "output.zip")
output_zip_path = os.path.join(str(output_dir), "output.zip")
pdf_file_path = str(input_path)
@ -230,16 +230,16 @@ class MinerUParser(RAGFlowPdfParser):
response.raise_for_status()
if response.headers.get("Content-Type") == "application/zip":
self.logger.info(f"[MinerU] zip file returned, saving to {OUTPUT_ZIP_PATH}...")
self.logger.info(f"[MinerU] zip file returned, saving to {output_zip_path}...")
if callback:
callback(0.30, f"[MinerU] zip file returned, saving to {OUTPUT_ZIP_PATH}...")
callback(0.30, f"[MinerU] zip file returned, saving to {output_zip_path}...")
with open(OUTPUT_ZIP_PATH, "wb") as f:
with open(output_zip_path, "wb") as f:
f.write(response.content)
self.logger.info(f"[MinerU] Unzip to {output_path}...")
self._extract_zip_no_root(OUTPUT_ZIP_PATH, output_path, pdf_file_name + "/")
self._extract_zip_no_root(output_zip_path, output_path, pdf_file_name + "/")
if callback:
callback(0.40, f"[MinerU] Unzip to {output_path}...")
@ -459,13 +459,36 @@ class MinerUParser(RAGFlowPdfParser):
return poss
def _read_output(self, output_dir: Path, file_stem: str, method: str = "auto", backend: str = "pipeline") -> list[dict[str, Any]]:
subdir = output_dir / file_stem / method
if backend.startswith("vlm-"):
subdir = output_dir / file_stem / "vlm"
json_file = subdir / f"{file_stem}_content_list.json"
candidates = []
seen = set()
if not json_file.exists():
raise FileNotFoundError(f"[MinerU] Missing output file: {json_file}")
def add_candidate_path(p: Path):
if p not in seen:
seen.add(p)
candidates.append(p)
if backend.startswith("vlm-"):
add_candidate_path(output_dir / file_stem / "vlm")
if method:
add_candidate_path(output_dir / file_stem / method)
add_candidate_path(output_dir / file_stem / "auto")
else:
if method:
add_candidate_path(output_dir / file_stem / method)
add_candidate_path(output_dir / file_stem / "vlm")
add_candidate_path(output_dir / file_stem / "auto")
json_file = None
subdir = None
for sub in candidates:
jf = sub / f"{file_stem}_content_list.json"
if jf.exists():
subdir = sub
json_file = jf
break
if not json_file:
raise FileNotFoundError(f"[MinerU] Missing output file, tried: {', '.join(str(c / (file_stem + '_content_list.json')) for c in candidates)}")
with open(json_file, "r", encoding="utf-8") as f:
data = json.load(f)
@ -520,7 +543,7 @@ class MinerUParser(RAGFlowPdfParser):
method: str = "auto",
server_url: Optional[str] = None,
delete_output: bool = True,
parse_method: str = "raw"
parse_method: str = "raw",
) -> tuple:
import shutil
@ -570,7 +593,7 @@ class MinerUParser(RAGFlowPdfParser):
self.logger.info(f"[MinerU] Parsed {len(outputs)} blocks from PDF.")
if callback:
callback(0.75, f"[MinerU] Parsed {len(outputs)} blocks from PDF.")
return self._transfer_to_sections(outputs, parse_method), self._transfer_to_tables(outputs)
finally:
if temp_pdf and temp_pdf.exists():

View file

@ -38,6 +38,7 @@ oceanbase:
port: ${OCEANBASE_PORT:-2881}
redis:
db: 1
username: '${REDIS_USERNAME:-}'
password: '${REDIS_PASSWORD:-infini_rag_flow}'
host: '${REDIS_HOST:-redis}:6379'
user_default_llm:

View file

@ -89,6 +89,8 @@ RAGFlow utilizes MinIO as its object storage solution, leveraging its scalabilit
- `REDIS_PORT`
The port used to expose the Redis service to the host machine, allowing **external** access to the Redis service running inside the Docker container. Defaults to `6379`.
- `REDIS_USERNAME`
Optional Redis ACL username when using Redis 6+ authentication.
- `REDIS_PASSWORD`
The password for Redis.
@ -160,6 +162,13 @@ If you cannot download the RAGFlow Docker image, try the following mirrors.
- `password`: The password for MinIO.
- `host`: The MinIO serving IP *and* port inside the Docker container. Defaults to `minio:9000`.
### `redis`
- `host`: The Redis serving IP *and* port inside the Docker container. Defaults to `redis:6379`.
- `db`: The Redis database index to use. Defaults to `1`.
- `username`: Optional Redis ACL username (Redis 6+).
- `password`: The password for the specified Redis user.
### `oauth`
The OAuth configuration for signing up or signing in to RAGFlow using a third-party account.

View file

@ -314,35 +314,3 @@ To enable IPEX-LLM accelerated Ollama in RAGFlow, you must also complete the con
3. [Update System Model Settings](#6-update-system-model-settings)
4. [Update Chat Configuration](#7-update-chat-configuration)
## Deploy a local model using jina
To deploy a local model, e.g., **gpt2**, using jina:
### 1. Check firewall settings
Ensure that your host machine's firewall allows inbound connections on port 12345.
```bash
sudo ufw allow 12345/tcp
```
### 2. Install jina package
```bash
pip install jina
```
### 3. Deploy a local model
Step 1: Navigate to the **rag/svr** directory.
```bash
cd rag/svr
```
Step 2: Run **jina_server.py**, specifying either the model's name or its local directory:
```bash
python jina_server.py --model_name gpt2
```
> The script only supports models downloaded from Hugging Face.

View file

@ -419,7 +419,15 @@ Creates a dataset.
- `"embedding_model"`: `string`
- `"permission"`: `string`
- `"chunk_method"`: `string`
- `"parser_config"`: `object`
- `"parser_config"`: `object`
- `"parse_type"`: `int`
- `"pipeline_id"`: `string`
Note: Choose exactly one ingestion mode when creating a dataset.
- Chunking method: provide `"chunk_method"` (optionally with `"parser_config"`).
- Ingestion pipeline: provide both `"parse_type"` and `"pipeline_id"` and do not provide `"chunk_method"`.
These options are mutually exclusive. If all three of `chunk_method`, `parse_type`, and `pipeline_id` are omitted, the system defaults to `chunk_method = "naive"`.
##### Request example
@ -433,6 +441,26 @@ curl --request POST \
}'
```
##### Request example (ingestion pipeline)
Use this form when specifying an ingestion pipeline (do not include `chunk_method`).
```bash
curl --request POST \
--url http://{address}/api/v1/datasets \
--header 'Content-Type: application/json' \
--header 'Authorization: Bearer <YOUR_API_KEY>' \
--data '{
"name": "test-sdk",
"parse_type": <NUMBER_OF_FORMATS_IN_PARSE>,
"pipeline_id": "<PIPELINE_ID_32_HEX>"
}'
```
Notes:
- `parse_type` is an integer. Replace `<NUMBER_OF_FORMATS_IN_PARSE>` with your pipeline's parse-type value.
- `pipeline_id` must be a 32-character hexadecimal string (it is normalized to lowercase).
##### Request parameters
- `"name"`: (*Body parameter*), `string`, *Required*
@ -473,6 +501,7 @@ curl --request POST \
- `"qa"`: Q&A
- `"table"`: Table
- `"tag"`: Tag
- Mutually exclusive with `parse_type` and `pipeline_id`. If you set `chunk_method`, do not include `parse_type` or `pipeline_id`.
- `"parser_config"`: (*Body parameter*), `object`
The configuration settings for the dataset parser. The attributes in this JSON object vary with the selected `"chunk_method"`:
@ -509,6 +538,15 @@ curl --request POST \
- Defaults to: `{"use_raptor": false}`.
- If `"chunk_method"` is `"table"`, `"picture"`, `"one"`, or `"email"`, `"parser_config"` is an empty JSON object.
- `"parse_type"`: (*Body parameter*), `int`
  The ingestion pipeline parse-type identifier. Required if and only if you are using an ingestion pipeline (together with `"pipeline_id"`). Must not be provided when `"chunk_method"` is set.
- `"pipeline_id"`: (*Body parameter*), `string`
  The ingestion pipeline ID. Required if and only if you are using an ingestion pipeline (together with `"parse_type"`). Must not be provided when `"chunk_method"` is set.
Note: If none of `chunk_method`, `parse_type`, and `pipeline_id` are provided, the system defaults to `chunk_method = "naive"`.
#### Response
Success:

View file

@ -33,9 +33,9 @@ from openai.lib.azure import AzureOpenAI
from strenum import StrEnum
from zhipuai import ZhipuAI
from common.token_utils import num_tokens_from_string, total_token_count_from_response
from rag.llm import FACTORY_DEFAULT_BASE_URL, LITELLM_PROVIDER_PREFIX, SupportedLiteLLMProvider
from rag.nlp import is_chinese, is_english
from common.token_utils import num_tokens_from_string, total_token_count_from_response
# Error message constants
@ -66,7 +66,7 @@ LENGTH_NOTIFICATION_EN = "...\nThe answer is truncated by your chosen LLM due to
class Base(ABC):
def __init__(self, key, model_name, base_url, **kwargs):
timeout = int(os.environ.get("LM_TIMEOUT_SECONDS", 600))
timeout = int(os.environ.get("LLM_TIMEOUT_SECONDS", 600))
self.client = OpenAI(api_key=key, base_url=base_url, timeout=timeout)
self.model_name = model_name
# Configure retry parameters
@ -127,7 +127,7 @@ class Base(ABC):
"tool_choice",
"logprobs",
"top_logprobs",
"extra_headers"
"extra_headers",
}
gen_conf = {k: v for k, v in gen_conf.items() if k in allowed_conf}
@ -1213,7 +1213,7 @@ class GoogleChat(Base):
# Build GenerateContentConfig
try:
from google.genai.types import GenerateContentConfig, ThinkingConfig, Content, Part
from google.genai.types import Content, GenerateContentConfig, Part, ThinkingConfig
except ImportError as e:
logging.error(f"[GoogleChat] Failed to import google-genai: {e}. Please install: pip install google-genai>=1.41.0")
raise
@ -1242,14 +1242,14 @@ class GoogleChat(Base):
role = "model" if item["role"] == "assistant" else item["role"]
content = Content(
role=role,
parts=[Part(text=item["content"])]
parts=[Part(text=item["content"])],
)
contents.append(content)
response = self.client.models.generate_content(
model=self.model_name,
contents=contents,
config=config
config=config,
)
ans = response.text
@ -1299,7 +1299,7 @@ class GoogleChat(Base):
# Build GenerateContentConfig
try:
from google.genai.types import GenerateContentConfig, ThinkingConfig, Content, Part
from google.genai.types import Content, GenerateContentConfig, Part, ThinkingConfig
except ImportError as e:
logging.error(f"[GoogleChat] Failed to import google-genai: {e}. Please install: pip install google-genai>=1.41.0")
raise
@ -1326,7 +1326,7 @@ class GoogleChat(Base):
role = "model" if item["role"] == "assistant" else item["role"]
content = Content(
role=role,
parts=[Part(text=item["content"])]
parts=[Part(text=item["content"])],
)
contents.append(content)
@ -1334,7 +1334,7 @@ class GoogleChat(Base):
for chunk in self.client.models.generate_content_stream(
model=self.model_name,
contents=contents,
config=config
config=config,
):
text = chunk.text
ans = text
@ -1406,7 +1406,7 @@ class LiteLLMBase(ABC):
]
def __init__(self, key, model_name, base_url=None, **kwargs):
self.timeout = int(os.environ.get("LM_TIMEOUT_SECONDS", 600))
self.timeout = int(os.environ.get("LLM_TIMEOUT_SECONDS", 600))
self.provider = kwargs.get("provider", "")
self.prefix = LITELLM_PROVIDER_PREFIX.get(self.provider, "")
self.model_name = f"{self.prefix}{model_name}"
@ -1625,6 +1625,7 @@ class LiteLLMBase(ABC):
if self.provider == SupportedLiteLLMProvider.OpenRouter:
if self.provider_order:
def _to_order_list(x):
if x is None:
return []
@ -1633,6 +1634,7 @@ class LiteLLMBase(ABC):
if isinstance(x, (list, tuple)):
return [str(s).strip() for s in x if str(s).strip()]
return []
extra_body = {}
provider_cfg = {}
provider_order = _to_order_list(self.provider_order)

View file

@ -106,4 +106,4 @@ REMEMBER:
- Each citation supports the ENTIRE sentence
- When in doubt, ask: "Would a fact-checker need to verify this?"
- Place citations at sentence end, before punctuation
- Format likes this is FORBIDDEN: [ID:0, ID:5, ID:...]. It MUST be seperated like, [ID:0][ID:5]...
- Format likes this is FORBIDDEN: [ID:0, ID:5, ID:...]. It MUST be separated like, [ID:0][ID:5]...

View file

@ -575,9 +575,9 @@ class ESConnection(DocStoreConnection):
time.sleep(3)
self._connect()
continue
except Exception:
logger.exception("ESConnection.sql got exception")
break
except Exception as e:
logger.exception(f"ESConnection.sql got exception. SQL:\n{sql}")
raise Exception(f"SQL error: {e}\n\nSQL: {sql}")
logger.error(f"ESConnection.sql timeout for {ATTEMPT_TIME} times!")
return None

View file

@ -86,6 +86,9 @@ class RedisDB:
"db": int(self.config.get("db", 1)),
"decode_responses": True,
}
username = self.config.get("username")
if username:
conn_params["username"] = username
password = self.config.get("password")
if password:
conn_params["password"] = password

View file

@ -22,6 +22,7 @@ import { SharedFrom } from '@/constants/chat';
import {
LanguageAbbreviation,
LanguageAbbreviationMap,
ThemeEnum,
} from '@/constants/common';
import { useTranslate } from '@/hooks/common-hooks';
import { IModalProps } from '@/interfaces/common';
@ -36,6 +37,7 @@ const FormSchema = z.object({
locale: z.string(),
embedType: z.enum(['fullscreen', 'widget']),
enableStreaming: z.boolean(),
theme: z.enum([ThemeEnum.Light, ThemeEnum.Dark]),
});
type IProps = IModalProps<any> & {
@ -61,6 +63,7 @@ function EmbedDialog({
locale: '',
embedType: 'fullscreen' as const,
enableStreaming: false,
theme: ThemeEnum.Light,
},
});
@ -74,7 +77,7 @@ function EmbedDialog({
}, []);
const generateIframeSrc = useCallback(() => {
const { visibleAvatar, locale, embedType, enableStreaming } = values;
const { visibleAvatar, locale, embedType, enableStreaming, theme } = values;
const baseRoute =
embedType === 'widget'
? Routes.ChatWidget
@ -91,6 +94,9 @@ function EmbedDialog({
if (enableStreaming) {
src += '&streaming=true';
}
if (theme && embedType === 'fullscreen') {
src += `&theme=${theme}`;
}
return src;
}, [beta, from, token, values]);
@ -181,6 +187,41 @@ function EmbedDialog({
</FormItem>
)}
/>
{values.embedType === 'fullscreen' && (
<FormField
control={form.control}
name="theme"
render={({ field }) => (
<FormItem>
<FormLabel>Theme</FormLabel>
<FormControl>
<RadioGroup
onValueChange={field.onChange}
value={field.value}
className="flex flex-row space-x-4"
>
<div className="flex items-center space-x-2">
<RadioGroupItem
value={ThemeEnum.Light}
id="light"
/>
<Label htmlFor="light" className="text-sm">
Light
</Label>
</div>
<div className="flex items-center space-x-2">
<RadioGroupItem value={ThemeEnum.Dark} id="dark" />
<Label htmlFor="dark" className="text-sm">
Dark
</Label>
</div>
</RadioGroup>
</FormControl>
<FormMessage />
</FormItem>
)}
/>
)}
<FormField
control={form.control}
name="visibleAvatar"

View file

@ -71,3 +71,13 @@ export function useSwitchToDarkThemeOnMount() {
setTheme(ThemeEnum.Dark);
}, [setTheme]);
}
/**
 * Hook that applies a theme passed in from URL search params (e.g. `?theme=dark`).
 *
 * Only the known ThemeEnum values (Light / Dark) are honored; a null or
 * unrecognized value leaves the current theme untouched.
 */
export function useSyncThemeFromParams(theme: string | null) {
  const { setTheme } = useTheme();
  useEffect(() => {
    if (theme && (theme === ThemeEnum.Light || theme === ThemeEnum.Dark)) {
      setTheme(theme as ThemeEnum);
    }
  }, [theme, setTheme]);
}

View file

@ -29,6 +29,7 @@ export const useGetSharedChatSearchParams = () => {
from: searchParams.get('from') as SharedFrom,
sharedId: searchParams.get('shared_id'),
locale: searchParams.get('locale'),
theme: searchParams.get('theme'),
data: data,
visibleAvatar: searchParams.get('visible_avatar')
? searchParams.get('visible_avatar') !== '1'

View file

@ -4,6 +4,7 @@ import { NextMessageInput } from '@/components/message-input/next';
import MessageItem from '@/components/next-message-item';
import PdfSheet from '@/components/pdf-drawer';
import { useClickDrawer } from '@/components/pdf-drawer/hooks';
import { useSyncThemeFromParams } from '@/components/theme-provider';
import { MessageType } from '@/constants/chat';
import { useUploadCanvasFileWithProgress } from '@/hooks/use-agent-request';
import { cn } from '@/lib/utils';
@ -25,8 +26,10 @@ const ChatContainer = () => {
const {
sharedId: conversationId,
locale,
theme,
visibleAvatar,
} = useGetSharedChatSearchParams();
useSyncThemeFromParams(theme);
const { visible, hideModal, documentId, selectedChunk, clickDocumentButton } =
useClickDrawer();

View file

@ -33,6 +33,7 @@ export const useGetSharedChatSearchParams = () => {
from: searchParams.get('from') as SharedFrom,
sharedId: searchParams.get('shared_id'),
locale: searchParams.get('locale'),
theme: searchParams.get('theme'),
data: data,
visibleAvatar: searchParams.get('visible_avatar')
? searchParams.get('visible_avatar') !== '1'

View file

@ -3,6 +3,7 @@ import { NextMessageInput } from '@/components/message-input/next';
import MessageItem from '@/components/message-item';
import PdfSheet from '@/components/pdf-drawer';
import { useClickDrawer } from '@/components/pdf-drawer/hooks';
import { useSyncThemeFromParams } from '@/components/theme-provider';
import { MessageType, SharedFrom } from '@/constants/chat';
import { useFetchNextConversationSSE } from '@/hooks/chat-hooks';
import { useFetchFlowSSE } from '@/hooks/flow-hooks';
@ -22,8 +23,10 @@ const ChatContainer = () => {
sharedId: conversationId,
from,
locale,
theme,
visibleAvatar,
} = useGetSharedChatSearchParams();
useSyncThemeFromParams(theme);
const { visible, hideModal, documentId, selectedChunk, clickDocumentButton } =
useClickDrawer();
@ -52,6 +55,7 @@ const ChatContainer = () => {
i18n.changeLanguage(locale);
}
}, [locale, visibleAvatar]);
const { data: avatarData } = useFetchAvatar();
if (!conversationId) {