This commit is contained in:
Kevin Hu 2025-11-14 15:35:08 +08:00
commit ae8bc5d376
47 changed files with 1508 additions and 444 deletions

View file

@ -95,6 +95,38 @@ jobs:
version: ">=0.11.x"
args: "check"
- name: Check comments of changed Python files
if: ${{ false }}
run: |
if [[ ${{ github.event_name }} == 'pull_request_target' ]]; then
CHANGED_FILES=$(git diff --name-only ${{ github.event.pull_request.base.sha }}...${{ github.event.pull_request.head.sha }} \
| grep -E '\.(py)$' || true)
if [ -n "$CHANGED_FILES" ]; then
echo "Check comments of changed Python files with check_comment_ascii.py"
readarray -t files <<< "$CHANGED_FILES"
HAS_ERROR=0
for file in "${files[@]}"; do
if [ -f "$file" ]; then
if python3 check_comment_ascii.py "$file"; then
echo "✅ $file"
else
echo "❌ $file"
HAS_ERROR=1
fi
fi
done
if [ $HAS_ERROR -ne 0 ]; then
exit 1
fi
else
echo "No Python files changed"
fi
fi
- name: Build ragflow:nightly
run: |
RUNNER_WORKSPACE_PREFIX=${RUNNER_WORKSPACE_PREFIX:-${HOME}}

View file

@ -4,7 +4,7 @@
Admin Service is a dedicated management component designed to monitor, maintain, and administrate the RAGFlow system. It provides comprehensive tools for ensuring system stability, performing operational tasks, and managing users and permissions efficiently.
The service offers real-time monitoring of critical components, including the RAGFlow server, Task Executor processes, and dependent services such as MySQL, Elasticsearch, Redis, and MinIO. It automatically checks their health status, resource usage, and uptime, and performs restarts in case of failures to minimize downtime.
The service offers real-time monitoring of critical components, including the RAGFlow server, Task Executor processes, and dependent services such as MySQL, Infinity, Elasticsearch, Redis, and MinIO. It automatically checks their health status, resource usage, and uptime, and performs restarts in case of failures to minimize downtime.
For user and system management, it supports listing, creating, modifying, and deleting users and their associated resources like knowledge bases and Agents.

View file

@ -169,7 +169,7 @@ def login_verify(f):
username = auth.parameters['username']
password = auth.parameters['password']
try:
if check_admin(username, password) is False:
if not check_admin(username, password):
return jsonify({
"code": 500,
"message": "Access denied",

View file

@ -25,8 +25,21 @@ from common.config_utils import read_config
from urllib.parse import urlparse
class BaseConfig(BaseModel):
id: int
name: str
host: str
port: int
service_type: str
detail_func_name: str
def to_dict(self) -> dict[str, Any]:
return {'id': self.id, 'name': self.name, 'host': self.host, 'port': self.port,
'service_type': self.service_type}
class ServiceConfigs:
configs = dict
configs = list[BaseConfig]
def __init__(self):
self.configs = []
@ -45,19 +58,6 @@ class ServiceType(Enum):
FILE_STORE = "file_store"
class BaseConfig(BaseModel):
id: int
name: str
host: str
port: int
service_type: str
detail_func_name: str
def to_dict(self) -> dict[str, Any]:
return {'id': self.id, 'name': self.name, 'host': self.host, 'port': self.port,
'service_type': self.service_type}
class MetaConfig(BaseConfig):
meta_type: str
@ -227,7 +227,7 @@ def load_configurations(config_path: str) -> list[BaseConfig]:
ragflow_count = 0
id_count = 0
for k, v in raw_configs.items():
match (k):
match k:
case "ragflow":
name: str = f'ragflow_{ragflow_count}'
host: str = v['host']

View file

@ -13,8 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import logging
import re
from werkzeug.security import check_password_hash
from common.constants import ActiveEnum
@ -190,7 +189,8 @@ class ServiceMgr:
config_dict['status'] = service_detail['status']
else:
config_dict['status'] = 'timeout'
except Exception:
except Exception as e:
logging.warning(f"Can't get service details, error: {e}")
config_dict['status'] = 'timeout'
if not config_dict['host']:
config_dict['host'] = '-'
@ -205,17 +205,13 @@ class ServiceMgr:
@staticmethod
def get_service_details(service_id: int):
service_id = int(service_id)
service_idx = int(service_id)
configs = SERVICE_CONFIGS.configs
service_config_mapping = {
c.id: {
'name': c.name,
'detail_func_name': c.detail_func_name
} for c in configs
}
service_info = service_config_mapping.get(service_id, {})
if not service_info:
raise AdminException(f"invalid service_id: {service_id}")
if service_idx < 0 or service_idx >= len(configs):
raise AdminException(f"invalid service_index: {service_idx}")
service_config = configs[service_idx]
service_info = {'name': service_config.name, 'detail_func_name': service_config.detail_func_name}
detail_func = getattr(health_utils, service_info.get('detail_func_name'))
res = detail_func()

View file

@ -83,10 +83,10 @@
"value": []
}
},
"password": "20010812Yy!",
"password": "",
"port": 3306,
"sql": "{Agent:WickedGoatsDivide@content}",
"username": "13637682833@163.com"
"username": ""
}
},
"upstream": [
@ -527,10 +527,10 @@
"value": []
}
},
"password": "20010812Yy!",
"password": "",
"port": 3306,
"sql": "{Agent:WickedGoatsDivide@content}",
"username": "13637682833@163.com"
"username": ""
},
"label": "ExeSQL",
"name": "ExeSQL"

View file

@ -224,12 +224,12 @@ def logout_user():
return True
def search_pages_path(pages_dir):
def search_pages_path(page_path):
app_path_list = [
path for path in pages_dir.glob("*_app.py") if not path.name.startswith(".")
path for path in page_path.glob("*_app.py") if not path.name.startswith(".")
]
api_path_list = [
path for path in pages_dir.glob("*sdk/*.py") if not path.name.startswith(".")
path for path in page_path.glob("*sdk/*.py") if not path.name.startswith(".")
]
app_path_list.extend(api_path_list)
return app_path_list
@ -266,10 +266,12 @@ pages_dir = [
]
client_urls_prefix = [
register_page(path) for dir in pages_dir for path in search_pages_path(dir)
register_page(path) for directory in pages_dir for path in search_pages_path(directory)
]
@app.teardown_request
def _db_close(exc):
def _db_close(exception):
if exception:
logging.exception(f"Request failed: {exception}")
close_connection()

View file

@ -424,7 +424,6 @@ async def test_db_connect():
try:
import trino
import os
from trino.auth import BasicAuthentication
except Exception as e:
return server_error_response(f"Missing dependency 'trino'. Please install: pip install trino, detail: {e}")
@ -436,7 +435,7 @@ async def test_db_connect():
auth = None
if http_scheme == "https" and req.get("password"):
auth = BasicAuthentication(req.get("username") or "ragflow", req["password"])
auth = trino.BasicAuthentication(req.get("username") or "ragflow", req["password"])
conn = trino.dbapi.connect(
host=req["host"],
@ -469,8 +468,8 @@ async def test_db_connect():
@login_required
def getlistversion(canvas_id):
try:
list =sorted([c.to_dict() for c in UserCanvasVersionService.list_by_canvas_id(canvas_id)], key=lambda x: x["update_time"]*-1)
return get_json_result(data=list)
versions =sorted([c.to_dict() for c in UserCanvasVersionService.list_by_canvas_id(canvas_id)], key=lambda x: x["update_time"]*-1)
return get_json_result(data=versions)
except Exception as e:
return get_data_error_result(message=f"Error getting history files: {e}")

View file

@ -57,7 +57,6 @@ async def set_connector():
"timeout_secs": int(req.get("timeout_secs", 60 * 29)),
"status": TaskStatus.SCHEDULE,
}
conn["status"] = TaskStatus.SCHEDULE
ConnectorService.save(**conn)
await trio.sleep(1)

View file

@ -85,7 +85,6 @@ def get():
if not e:
return get_data_error_result(message="Conversation not found!")
tenants = UserTenantService.query(user_id=current_user.id)
avatar = None
for tenant in tenants:
dialog = DialogService.query(tenant_id=tenant.tenant_id, id=conv.dialog_id)
if dialog and len(dialog) > 0:

View file

@ -154,15 +154,15 @@ def get_kb_names(kb_ids):
@login_required
def list_dialogs():
try:
diags = DialogService.query(
conversations = DialogService.query(
tenant_id=current_user.id,
status=StatusEnum.VALID.value,
reverse=True,
order_by=DialogService.model.create_time)
diags = [d.to_dict() for d in diags]
for d in diags:
d["kb_ids"], d["kb_names"] = get_kb_names(d["kb_ids"])
return get_json_result(data=diags)
conversations = [d.to_dict() for d in conversations]
for conversation in conversations:
conversation["kb_ids"], conversation["kb_names"] = get_kb_names(conversation["kb_ids"])
return get_json_result(data=conversations)
except Exception as e:
return server_error_response(e)

View file

@ -306,7 +306,7 @@ async def get_filter():
@manager.route("/infos", methods=["POST"]) # noqa: F821
@login_required
async def docinfos():
async def doc_infos():
req = await request.json
doc_ids = req["doc_ids"]
for doc_id in doc_ids:
@ -542,6 +542,7 @@ async def change_parser():
return get_data_error_result(message="Tenant not found!")
if settings.docStoreConn.indexExist(search.index_name(tenant_id), doc.kb_id):
settings.docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), doc.kb_id)
return None
try:
if "pipeline_id" in req and req["pipeline_id"] != "":

View file

@ -244,8 +244,8 @@ async def rm():
try:
if file.location:
settings.STORAGE_IMPL.rm(file.parent_id, file.location)
except Exception:
logging.exception(f"Fail to remove object: {file.parent_id}/{file.location}")
except Exception as e:
logging.exception(f"Fail to remove object: {file.parent_id}/{file.location}, error: {e}")
informs = File2DocumentService.get_by_file_id(file.id)
for inform in informs:

View file

@ -16,6 +16,7 @@
import json
import logging
import random
import re
from quart import request
import numpy as np
@ -733,6 +734,8 @@ def delete_kb_task():
def cancel_task(task_id):
REDIS_CONN.set(f"{task_id}-cancel", "x")
kb_task_id_field: str = ""
kb_task_finish_at: str = ""
match pipeline_task_type:
case PipelineTaskType.GRAPH_RAG:
kb_task_id_field = "graphrag_task_id"
@ -847,8 +850,13 @@ async def check_embedding():
"position_int": full_doc.get("position_int"),
"top_int": full_doc.get("top_int"),
"content_with_weight": full_doc.get("content_with_weight") or "",
"question_kwd": full_doc.get("question_kwd") or []
})
return out
def _clean(s: str) -> str:
s = re.sub(r"</?(table|td|caption|tr|th)( [^<>]{0,12})?>", " ", s or "")
return s if s else "None"
req = await request.json
kb_id = req.get("kb_id", "")
embd_id = req.get("embd_id", "")
@ -861,8 +869,10 @@ async def check_embedding():
results, eff_sims = [], []
for ck in samples:
txt = (ck.get("content_with_weight") or "").strip()
if not txt:
title = ck.get("doc_name") or "Title"
txt_in = "\n".join(ck.get("question_kwd") or []) or ck.get("content_with_weight") or ""
txt_in = _clean(txt_in)
if not txt_in:
results.append({"chunk_id": ck["chunk_id"], "reason": "no_text"})
continue
@ -871,8 +881,16 @@ async def check_embedding():
continue
try:
qv, _ = emb_mdl.encode_queries(txt)
sim = _cos_sim(qv, ck["vector"])
v, _ = emb_mdl.encode([title, txt_in])
sim_content = _cos_sim(v[1], ck["vector"])
title_w = 0.1
qv_mix = title_w * v[0] + (1 - title_w) * v[1]
sim_mix = _cos_sim(qv_mix, ck["vector"])
sim = sim_content
mode = "content_only"
if sim_mix > sim:
sim = sim_mix
mode = "title+content"
except Exception:
return get_error_data_result(message="embedding failure")
@ -894,8 +912,9 @@ async def check_embedding():
"avg_cos_sim": round(float(np.mean(eff_sims)) if eff_sims else 0.0, 6),
"min_cos_sim": round(float(np.min(eff_sims)) if eff_sims else 0.0, 6),
"max_cos_sim": round(float(np.max(eff_sims)) if eff_sims else 0.0, 6),
"match_mode": mode,
}
if summary["avg_cos_sim"] > 0.99:
if summary["avg_cos_sim"] > 0.9:
return get_json_result(data={"summary": summary, "results": results})
return get_json_result(code=RetCode.NOT_EFFECTIVE, message="failed", data={"summary": summary, "results": results})

View file

@ -21,10 +21,11 @@ import json
from quart import request
from peewee import OperationalError
from api.db.db_models import File
from api.db.services.document_service import DocumentService
from api.db.services.document_service import DocumentService, queue_raptor_o_graphrag_tasks
from api.db.services.file2document_service import File2DocumentService
from api.db.services.file_service import FileService
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.task_service import GRAPH_RAPTOR_FAKE_DOC_ID, TaskService
from api.db.services.user_service import TenantService
from common.constants import RetCode, FileSource, StatusEnum
from api.utils.api_utils import (
@ -118,7 +119,6 @@ async def create(tenant_id):
req, err = await validate_and_parse_json_request(request, CreateDatasetReq)
if err is not None:
return get_error_argument_result(err)
req = KnowledgebaseService.create_with_name(
name = req.pop("name", None),
tenant_id = tenant_id,
@ -144,7 +144,6 @@ async def create(tenant_id):
ok, k = KnowledgebaseService.get_by_id(req["id"])
if not ok:
return get_error_data_result(message="Dataset created failed")
response_data = remap_dictionary_keys(k.to_dict())
return get_result(data=response_data)
except Exception as e:
@ -532,3 +531,157 @@ def delete_knowledge_graph(tenant_id, dataset_id):
search.index_name(kb.tenant_id), dataset_id)
return get_result(data=True)
@manager.route("/datasets/<dataset_id>/run_graphrag", methods=["POST"])  # noqa: F821
@token_required
def run_graphrag(tenant_id, dataset_id):
    """Queue a knowledge-graph (GraphRAG) construction task over a dataset.

    Refuses to start when a previous GraphRAG task for this dataset is still
    running, then queues one task covering every document in the dataset and
    records its id on the knowledge base.

    Returns:
        JSON result with {"graphrag_task_id": <new task id>} on success, or
        an error result (missing id, no authorization, no documents, task
        already in progress).
    """
    if not dataset_id:
        return get_error_data_result(message='Lack of "Dataset ID"')
    if not KnowledgebaseService.accessible(dataset_id, tenant_id):
        return get_result(
            data=False,
            message='No authorization.',
            code=RetCode.AUTHENTICATION_ERROR
        )
    ok, kb = KnowledgebaseService.get_by_id(dataset_id)
    if not ok:
        return get_error_data_result(message="Invalid Dataset ID")
    task_id = kb.graphrag_task_id
    if task_id:
        ok, task = TaskService.get_by_id(task_id)
        if not ok:
            logging.warning(f"A valid GraphRAG task id is expected for Dataset {dataset_id}")
        # progress of 1 means finished and -1 means failed; any other value
        # indicates the previous task is still running.
        if task and task.progress not in [-1, 1]:
            return get_error_data_result(message=f"Task {task_id} in progress with status {task.progress}. A Graph Task is already running.")
    documents, _ = DocumentService.get_by_kb_id(
        kb_id=dataset_id,
        page_number=0,
        items_per_page=0,
        orderby="create_time",
        desc=False,
        keywords="",
        run_status=[],
        types=[],
        suffix=[],
    )
    if not documents:
        return get_error_data_result(message=f"No documents in Dataset {dataset_id}")
    document_ids = [document["id"] for document in documents]
    # Fix: pass the sample document's id, not the whole document dict —
    # documents are dicts (see the "id" subscript above) and the parameter
    # name sample_doc_id expects an id.
    task_id = queue_raptor_o_graphrag_tasks(sample_doc_id=document_ids[0], ty="graphrag", priority=0, fake_doc_id=GRAPH_RAPTOR_FAKE_DOC_ID, doc_ids=list(document_ids))
    if not KnowledgebaseService.update_by_id(kb.id, {"graphrag_task_id": task_id}):
        logging.warning(f"Cannot save graphrag_task_id for Dataset {dataset_id}")
    return get_result(data={"graphrag_task_id": task_id})
@manager.route("/datasets/<dataset_id>/trace_graphrag", methods=["GET"])  # noqa: F821
@token_required
def trace_graphrag(tenant_id, dataset_id):
    """Return the status of the dataset's GraphRAG task, or {} if none exists."""
    if not dataset_id:
        return get_error_data_result(message='Lack of "Dataset ID"')
    if not KnowledgebaseService.accessible(dataset_id, tenant_id):
        return get_result(
            data=False,
            message='No authorization.',
            code=RetCode.AUTHENTICATION_ERROR
        )
    found, dataset = KnowledgebaseService.get_by_id(dataset_id)
    if not found:
        return get_error_data_result(message="Invalid Dataset ID")
    graph_task_id = dataset.graphrag_task_id
    # No task was ever queued, or the recorded task cannot be loaded:
    # report an empty status rather than an error.
    if not graph_task_id:
        return get_result(data={})
    found, graph_task = TaskService.get_by_id(graph_task_id)
    if not found:
        return get_result(data={})
    return get_result(data=graph_task.to_dict())
@manager.route("/datasets/<dataset_id>/run_raptor", methods=["POST"])  # noqa: F821
@token_required
def run_raptor(tenant_id, dataset_id):
    """Queue a RAPTOR summarization task over a dataset.

    Refuses to start when a previous RAPTOR task for this dataset is still
    running, then queues one task covering every document in the dataset and
    records its id on the knowledge base.

    Returns:
        JSON result with {"raptor_task_id": <new task id>} on success, or an
        error result (missing id, no authorization, no documents, task
        already in progress).
    """
    if not dataset_id:
        return get_error_data_result(message='Lack of "Dataset ID"')
    if not KnowledgebaseService.accessible(dataset_id, tenant_id):
        return get_result(
            data=False,
            message='No authorization.',
            code=RetCode.AUTHENTICATION_ERROR
        )
    ok, kb = KnowledgebaseService.get_by_id(dataset_id)
    if not ok:
        return get_error_data_result(message="Invalid Dataset ID")
    task_id = kb.raptor_task_id
    if task_id:
        ok, task = TaskService.get_by_id(task_id)
        if not ok:
            logging.warning(f"A valid RAPTOR task id is expected for Dataset {dataset_id}")
        # progress of 1 means finished and -1 means failed; any other value
        # indicates the previous task is still running.
        if task and task.progress not in [-1, 1]:
            return get_error_data_result(message=f"Task {task_id} in progress with status {task.progress}. A RAPTOR Task is already running.")
    documents, _ = DocumentService.get_by_kb_id(
        kb_id=dataset_id,
        page_number=0,
        items_per_page=0,
        orderby="create_time",
        desc=False,
        keywords="",
        run_status=[],
        types=[],
        suffix=[],
    )
    if not documents:
        return get_error_data_result(message=f"No documents in Dataset {dataset_id}")
    document_ids = [document["id"] for document in documents]
    # Fix: pass the sample document's id, not the whole document dict —
    # documents are dicts (see the "id" subscript above) and the parameter
    # name sample_doc_id expects an id.
    task_id = queue_raptor_o_graphrag_tasks(sample_doc_id=document_ids[0], ty="raptor", priority=0, fake_doc_id=GRAPH_RAPTOR_FAKE_DOC_ID, doc_ids=list(document_ids))
    if not KnowledgebaseService.update_by_id(kb.id, {"raptor_task_id": task_id}):
        logging.warning(f"Cannot save raptor_task_id for Dataset {dataset_id}")
    return get_result(data={"raptor_task_id": task_id})
@manager.route("/datasets/<dataset_id>/trace_raptor", methods=["GET"])  # noqa: F821
@token_required
def trace_raptor(tenant_id, dataset_id):
    """Return the status of the dataset's RAPTOR task, or {} if none was queued."""
    if not dataset_id:
        return get_error_data_result(message='Lack of "Dataset ID"')
    if not KnowledgebaseService.accessible(dataset_id, tenant_id):
        return get_result(
            data=False,
            message='No authorization.',
            code=RetCode.AUTHENTICATION_ERROR
        )
    found, dataset = KnowledgebaseService.get_by_id(dataset_id)
    if not found:
        return get_error_data_result(message="Invalid Dataset ID")
    raptor_task_id = dataset.raptor_task_id
    # No task was ever queued for this dataset: report an empty status.
    if not raptor_task_id:
        return get_result(data={})
    found, raptor_task = TaskService.get_by_id(raptor_task_id)
    if not found:
        return get_error_data_result(message="RAPTOR Task Not Found or Error Occurred")
    return get_result(data=raptor_task.to_dict())

48
check_comment_ascii.py Normal file
View file

@ -0,0 +1,48 @@
#!/usr/bin/env python3
"""
Check whether given python files contain non-ASCII comments.
How to check the whole git repo:
```
$ git ls-files -z -- '*.py' | xargs -0 python3 check_comment_ascii.py
```
"""
import sys
import tokenize
import ast
import pathlib
import re
ASCII = re.compile(r"^[\n -~]*\Z")  # Printable ASCII + newline (note: tabs are rejected too)


def check(src: str, name: str) -> int:
    """Report every non-ASCII comment and docstring in a Python source file.

    Args:
        src: Path of the file to scan.
        name: Label used in diagnostic messages (usually the same path).

    Returns:
        1 if every comment and docstring is printable ASCII, 0 otherwise.
    """
    ok = 1
    # Pass 1: scan `#` comments token by token. tokenize.open() honors the
    # file's PEP 263 encoding declaration.
    with tokenize.open(src) as fp:
        for tk in tokenize.generate_tokens(fp.readline):
            if tk.type == tokenize.COMMENT and not ASCII.fullmatch(tk.string):
                print(f"{name}:{tk.start[0]}: non-ASCII comment: {tk.string}")
                ok = 0
    # Pass 2: scan docstrings of modules, classes and (async) functions.
    # Re-open with tokenize.open() so decoding matches pass 1 instead of
    # relying on the platform default encoding, and include
    # ast.AsyncFunctionDef so `async def` docstrings are not missed.
    with tokenize.open(src) as fp:
        tree = ast.parse(fp.read(), filename=name)
    for node in ast.walk(tree):
        if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef, ast.Module)):
            if (doc := ast.get_docstring(node)) and not ASCII.fullmatch(doc):
                print(f"{name}:{node.lineno}: non-ASCII docstring: {doc}")
                ok = 0
    return ok
if __name__ == "__main__":
    # Check every file named on the command line; exit non-zero when any
    # file contains a non-ASCII comment or docstring. All files are always
    # scanned so every offending location is reported.
    results = [check(path, path) for path in sys.argv[1:]]
    sys.exit(0 if all(results) else 1)

View file

@ -3,15 +3,9 @@ import os
import threading
from typing import Any, Callable
import requests
from common.data_source.config import DocumentSource
from common.data_source.google_util.constant import GOOGLE_SCOPES
GOOGLE_DEVICE_CODE_URL = "https://oauth2.googleapis.com/device/code"
GOOGLE_DEVICE_TOKEN_URL = "https://oauth2.googleapis.com/token"
DEFAULT_DEVICE_INTERVAL = 5
def _get_requested_scopes(source: DocumentSource) -> list[str]:
"""Return the scopes to request, honoring an optional override env var."""
@ -55,62 +49,6 @@ def _run_with_timeout(func: Callable[[], Any], timeout_secs: int, timeout_messag
return result.get("value")
def _extract_client_info(credentials: dict[str, Any]) -> tuple[str, str | None]:
if "client_id" in credentials:
return credentials["client_id"], credentials.get("client_secret")
for key in ("installed", "web"):
if key in credentials and isinstance(credentials[key], dict):
nested = credentials[key]
if "client_id" not in nested:
break
return nested["client_id"], nested.get("client_secret")
raise ValueError("Provided Google OAuth credentials are missing client_id.")
def start_device_authorization_flow(
credentials: dict[str, Any],
source: DocumentSource,
) -> tuple[dict[str, Any], dict[str, Any]]:
client_id, client_secret = _extract_client_info(credentials)
data = {
"client_id": client_id,
"scope": " ".join(_get_requested_scopes(source)),
}
if client_secret:
data["client_secret"] = client_secret
resp = requests.post(GOOGLE_DEVICE_CODE_URL, data=data, timeout=15)
resp.raise_for_status()
payload = resp.json()
state = {
"client_id": client_id,
"client_secret": client_secret,
"device_code": payload.get("device_code"),
"interval": payload.get("interval", DEFAULT_DEVICE_INTERVAL),
}
response_data = {
"user_code": payload.get("user_code"),
"verification_url": payload.get("verification_url") or payload.get("verification_uri"),
"verification_url_complete": payload.get("verification_url_complete")
or payload.get("verification_uri_complete"),
"expires_in": payload.get("expires_in"),
"interval": state["interval"],
}
return state, response_data
def poll_device_authorization_flow(state: dict[str, Any]) -> dict[str, Any]:
data = {
"client_id": state["client_id"],
"device_code": state["device_code"],
"grant_type": "urn:ietf:params:oauth:grant-type:device_code",
}
if state.get("client_secret"):
data["client_secret"] = state["client_secret"]
resp = requests.post(GOOGLE_DEVICE_TOKEN_URL, data=data, timeout=20)
resp.raise_for_status()
return resp.json()
def _run_local_server_flow(client_config: dict[str, Any], source: DocumentSource) -> dict[str, Any]:
"""Launch the standard Google OAuth local-server flow to mint user tokens."""
from google_auth_oauthlib.flow import InstalledAppFlow # type: ignore
@ -125,10 +63,7 @@ def _run_local_server_flow(client_config: dict[str, Any], source: DocumentSource
preferred_port = os.environ.get("GOOGLE_OAUTH_LOCAL_SERVER_PORT")
port = int(preferred_port) if preferred_port else 0
timeout_secs = _get_oauth_timeout_secs()
timeout_message = (
f"Google OAuth verification timed out after {timeout_secs} seconds. "
"Close any pending consent windows and rerun the connector configuration to try again."
)
timeout_message = f"Google OAuth verification timed out after {timeout_secs} seconds. Close any pending consent windows and rerun the connector configuration to try again."
print("Launching Google OAuth flow. A browser window should open shortly.")
print("If it does not, copy the URL shown in the console into your browser manually.")
@ -153,11 +88,8 @@ def _run_local_server_flow(client_config: dict[str, Any], source: DocumentSource
instructions = [
"Google rejected one or more of the requested OAuth scopes.",
"Fix options:",
" 1. In Google Cloud Console, open APIs & Services > OAuth consent screen and add the missing scopes "
" (Drive metadata + Admin Directory read scopes), then re-run the flow.",
" 1. In Google Cloud Console, open APIs & Services > OAuth consent screen and add the missing scopes (Drive metadata + Admin Directory read scopes), then re-run the flow.",
" 2. Set GOOGLE_OAUTH_SCOPE_OVERRIDE to a comma-separated list of scopes you are allowed to request.",
" 3. For quick local testing only, export OAUTHLIB_RELAX_TOKEN_SCOPE=1 to accept the reduced scopes "
" (be aware the connector may lose functionality).",
]
raise RuntimeError("\n".join(instructions)) from warning
raise
@ -184,8 +116,6 @@ def ensure_oauth_token_dict(credentials: dict[str, Any], source: DocumentSource)
client_config = {"web": credentials["web"]}
if client_config is None:
raise ValueError(
"Provided Google OAuth credentials are missing both tokens and a client configuration."
)
raise ValueError("Provided Google OAuth credentials are missing both tokens and a client configuration.")
return _run_local_server_flow(client_config, source)

View file

@ -0,0 +1,8 @@
{
"label": "Add data source",
"position": 18,
"link": {
"type": "generated-index",
"description": "Add various data sources"
}
}

View file

@ -0,0 +1,137 @@
---
sidebar_position: 3
slug: /add_google_drive
---
# Add Google Drive
## 1. Create a Google Cloud Project
You can either create a dedicated project for RAGFlow or use an existing
external Google Cloud project.
**Steps:**
1. Open the project creation page\
`https://console.cloud.google.com/projectcreate`
![placeholder-image](https://github.com/infiniflow/ragflow-docs/blob/040e4acd4c1eac6dc73dc44e934a6518de78d097/images/google_drive/image1.jpeg?raw=true)
2. Select **External** as the Audience
![placeholder-image](https://github.com/infiniflow/ragflow-docs/blob/040e4acd4c1eac6dc73dc44e934a6518de78d097/images/google_drive/image2.png?raw=true)
3. Click **Create**
![placeholder-image](https://github.com/infiniflow/ragflow-docs/blob/040e4acd4c1eac6dc73dc44e934a6518de78d097/images/google_drive/image3.jpeg?raw=true)
------------------------------------------------------------------------
## 2. Configure OAuth Consent Screen
1. Go to **APIs & Services → OAuth consent screen**
2. Ensure **User Type = External**
![placeholder-image](https://github.com/infiniflow/ragflow-docs/blob/040e4acd4c1eac6dc73dc44e934a6518de78d097/images/google_drive/image4.jpeg?raw=true)
3. Add your test users under **Test Users** by entering email addresses
![placeholder-image](https://github.com/infiniflow/ragflow-docs/blob/040e4acd4c1eac6dc73dc44e934a6518de78d097/images/google_drive/image5.jpeg?raw=true)
![placeholder-image](https://github.com/infiniflow/ragflow-docs/blob/040e4acd4c1eac6dc73dc44e934a6518de78d097/images/google_drive/image6.jpeg?raw=true)
------------------------------------------------------------------------
## 3. Create OAuth Client Credentials
1. Navigate to:\
`https://console.cloud.google.com/auth/clients`
2. Create a **Web Application**
![placeholder-image](https://github.com/infiniflow/ragflow-docs/blob/040e4acd4c1eac6dc73dc44e934a6518de78d097/images/google_drive/image7.png?raw=true)
3. Enter a name for the client
4. Add the following **Authorized Redirect URIs**:
```
http://localhost:9380/v1/connector/google-drive/oauth/web/callback
```
### If using Docker deployment:
**Authorized JavaScript origin:**
```
http://localhost:80
```
![placeholder-image](https://github.com/infiniflow/ragflow-docs/blob/040e4acd4c1eac6dc73dc44e934a6518de78d097/images/google_drive/image8.png?raw=true)
### If running from source:
**Authorized JavaScript origin:**
```
http://localhost:9222
```
![placeholder-image](https://github.com/infiniflow/ragflow-docs/blob/040e4acd4c1eac6dc73dc44e934a6518de78d097/images/google_drive/image9.png?raw=true)
5. After saving, click **Download JSON**. This file will later be
uploaded into RAGFlow.
![placeholder-image](https://github.com/infiniflow/ragflow-docs/blob/040e4acd4c1eac6dc73dc44e934a6518de78d097/images/google_drive/image10.png?raw=true)
------------------------------------------------------------------------
## 4. Add Scopes
1. Open **Data Access → Add or remove scopes**
2. Paste and add the following entries:
```
https://www.googleapis.com/auth/drive.readonly
https://www.googleapis.com/auth/drive.metadata.readonly
https://www.googleapis.com/auth/admin.directory.group.readonly
https://www.googleapis.com/auth/admin.directory.user.readonly
```
![placeholder-image](https://github.com/infiniflow/ragflow-docs/blob/040e4acd4c1eac6dc73dc44e934a6518de78d097/images/google_drive/image11.jpeg?raw=true)
3. Update and Save changes
![placeholder-image](https://github.com/infiniflow/ragflow-docs/blob/040e4acd4c1eac6dc73dc44e934a6518de78d097/images/google_drive/image12.jpeg?raw=true)
![placeholder-image](https://github.com/infiniflow/ragflow-docs/blob/040e4acd4c1eac6dc73dc44e934a6518de78d097/images/google_drive/image13.jpeg?raw=true)
------------------------------------------------------------------------
## 5. Enable Required APIs
Navigate to the Google API Library:\
`https://console.cloud.google.com/apis/library`
![placeholder-image](https://github.com/infiniflow/ragflow-docs/blob/040e4acd4c1eac6dc73dc44e934a6518de78d097/images/google_drive/image14.png?raw=true)
Enable the following APIs:
- Google Drive API
- Admin SDK API
- Google Sheets API
- Google Docs API
![placeholder-image](https://github.com/infiniflow/ragflow-docs/blob/040e4acd4c1eac6dc73dc44e934a6518de78d097/images/google_drive/image15.png?raw=true)
![placeholder-image](https://github.com/infiniflow/ragflow-docs/blob/040e4acd4c1eac6dc73dc44e934a6518de78d097/images/google_drive/image16.png?raw=true)
![placeholder-image](https://github.com/infiniflow/ragflow-docs/blob/040e4acd4c1eac6dc73dc44e934a6518de78d097/images/google_drive/image17.png?raw=true)
![placeholder-image](https://github.com/infiniflow/ragflow-docs/blob/040e4acd4c1eac6dc73dc44e934a6518de78d097/images/google_drive/image18.png?raw=true)
![placeholder-image](https://github.com/infiniflow/ragflow-docs/blob/040e4acd4c1eac6dc73dc44e934a6518de78d097/images/google_drive/image19.png?raw=true)
![placeholder-image](https://github.com/infiniflow/ragflow-docs/blob/040e4acd4c1eac6dc73dc44e934a6518de78d097/images/google_drive/image21.png?raw=true)
------------------------------------------------------------------------
## 6. Add Google Drive As a Data Source in RAGFlow
1. Go to **Data Sources** inside RAGFlow
2. Select **Google Drive**
3. Upload the previously downloaded JSON credentials
![placeholder-image](https://github.com/infiniflow/ragflow-docs/blob/040e4acd4c1eac6dc73dc44e934a6518de78d097/images/google_drive/image22.jpeg?raw=true)
4. Enter the shared Google Drive folder link (https://drive.google.com/drive), such as:
![placeholder-image](https://github.com/infiniflow/ragflow-docs/blob/040e4acd4c1eac6dc73dc44e934a6518de78d097/images/google_drive/image23.png?raw=true)
5. Click **Authorize with Google**
A browser window will appear.
![placeholder-image](https://github.com/infiniflow/ragflow-docs/blob/040e4acd4c1eac6dc73dc44e934a6518de78d097/images/google_drive/image25.jpeg?raw=true)
Click **Continue**, then **Select All → Continue**. Authorization should
succeed; select **OK** to add the data source.
![placeholder-image](https://github.com/infiniflow/ragflow-docs/blob/040e4acd4c1eac6dc73dc44e934a6518de78d097/images/google_drive/image26.jpeg?raw=true)
![placeholder-image](https://github.com/infiniflow/ragflow-docs/blob/040e4acd4c1eac6dc73dc44e934a6518de78d097/images/google_drive/image27.jpeg?raw=true)
![placeholder-image](https://github.com/infiniflow/ragflow-docs/blob/040e4acd4c1eac6dc73dc44e934a6518de78d097/images/google_drive/image28.png?raw=true)
![placeholder-image](https://github.com/infiniflow/ragflow-docs/blob/040e4acd4c1eac6dc73dc44e934a6518de78d097/images/google_drive/image29.png?raw=true)

View file

@ -1,6 +1,6 @@
{
"label": "Best practices",
"position": 11,
"position": 19,
"link": {
"type": "generated-index",
"description": "Best practices on configuring a dataset."

View file

@ -64,7 +64,10 @@ The Admin CLI and Admin Service form a client-server architectural suite for RAG
- -p: RAGFlow admin server port
## Default administrative account
- Username: admin@ragflow.io
- Password: admin
## Supported Commands

View file

@ -974,6 +974,237 @@ Failure:
---
### Construct knowledge graph
**POST** `/api/v1/datasets/{dataset_id}/run_graphrag`
Constructs a knowledge graph from a specified dataset.
#### Request
- Method: POST
- URL: `/api/v1/datasets/{dataset_id}/run_graphrag`
- Headers:
- `'Authorization: Bearer <YOUR_API_KEY>'`
##### Request example
```bash
curl --request POST \
--url http://{address}/api/v1/datasets/{dataset_id}/run_graphrag \
--header 'Authorization: Bearer <YOUR_API_KEY>'
```
##### Request parameters
- `dataset_id`: (*Path parameter*)
The ID of the target dataset.
#### Response
Success:
```json
{
"code":0,
"data":{
"graphrag_task_id":"e498de54bfbb11f0ba028f704583b57b"
}
}
```
Failure:
```json
{
"code": 102,
"message": "Invalid Dataset ID"
}
```
---
### Get knowledge graph construction status
**GET** `/api/v1/datasets/{dataset_id}/trace_graphrag`
Retrieves the knowledge graph construction status for a specified dataset.
#### Request
- Method: GET
- URL: `/api/v1/datasets/{dataset_id}/trace_graphrag`
- Headers:
- `'Authorization: Bearer <YOUR_API_KEY>'`
##### Request example
```bash
curl --request GET \
--url http://{address}/api/v1/datasets/{dataset_id}/trace_graphrag \
--header 'Authorization: Bearer <YOUR_API_KEY>'
```
##### Request parameters
- `dataset_id`: (*Path parameter*)
The ID of the target dataset.
#### Response
Success:
```json
{
"code":0,
"data":{
"begin_at":"Wed, 12 Nov 2025 19:36:56 GMT",
"chunk_ids":"",
"create_date":"Wed, 12 Nov 2025 19:36:56 GMT",
"create_time":1762947416350,
"digest":"39e43572e3dcd84f",
"doc_id":"44661c10bde211f0bc93c164a47ffc40",
"from_page":100000000,
"id":"e498de54bfbb11f0ba028f704583b57b",
"priority":0,
"process_duration":2.45419,
"progress":1.0,
"progress_msg":"19:36:56 created task graphrag\n19:36:57 Task has been received.\n19:36:58 [GraphRAG] doc:083661febe2411f0bc79456921e5745f has no available chunks, skip generation.\n19:36:58 [GraphRAG] build_subgraph doc:44661c10bde211f0bc93c164a47ffc40 start (chunks=1, timeout=10000000000s)\n19:36:58 Graph already contains 44661c10bde211f0bc93c164a47ffc40\n19:36:58 [GraphRAG] build_subgraph doc:44661c10bde211f0bc93c164a47ffc40 empty\n19:36:58 [GraphRAG] kb:33137ed0bde211f0bc93c164a47ffc40 no subgraphs generated successfully, end.\n19:36:58 Knowledge Graph done (0.72s)",
"retry_count":1,
"task_type":"graphrag",
"to_page":100000000,
"update_date":"Wed, 12 Nov 2025 19:36:58 GMT",
"update_time":1762947418454
}
}
```
Failure:
```json
{
"code": 102,
"message": "Invalid Dataset ID"
}
```
---
### Construct RAPTOR
**POST** `/api/v1/datasets/{dataset_id}/run_raptor`
Constructs a RAPTOR tree from a specified dataset.
#### Request
- Method: POST
- URL: `/api/v1/datasets/{dataset_id}/run_raptor`
- Headers:
- `'Authorization: Bearer <YOUR_API_KEY>'`
##### Request example
```bash
curl --request POST \
--url http://{address}/api/v1/datasets/{dataset_id}/run_raptor \
--header 'Authorization: Bearer <YOUR_API_KEY>'
```
##### Request parameters
- `dataset_id`: (*Path parameter*)
The ID of the target dataset.
#### Response
Success:
```json
{
"code":0,
"data":{
"raptor_task_id":"50d3c31cbfbd11f0ba028f704583b57b"
}
}
```
Failure:
```json
{
"code": 102,
"message": "Invalid Dataset ID"
}
```
---
### Get RAPTOR construction status
**GET** `/api/v1/datasets/{dataset_id}/trace_raptor`
Retrieves the RAPTOR construction status for a specified dataset.
#### Request
- Method: GET
- URL: `/api/v1/datasets/{dataset_id}/trace_raptor`
- Headers:
- `'Authorization: Bearer <YOUR_API_KEY>'`
##### Request example
```bash
curl --request GET \
--url http://{address}/api/v1/datasets/{dataset_id}/trace_raptor \
--header 'Authorization: Bearer <YOUR_API_KEY>'
```
##### Request parameters
- `dataset_id`: (*Path parameter*)
The ID of the target dataset.
#### Response
Success:
```json
{
"code":0,
"data":{
"begin_at":"Wed, 12 Nov 2025 19:47:07 GMT",
"chunk_ids":"",
"create_date":"Wed, 12 Nov 2025 19:47:07 GMT",
"create_time":1762948027427,
"digest":"8b279a6248cb8fc6",
"doc_id":"44661c10bde211f0bc93c164a47ffc40",
"from_page":100000000,
"id":"50d3c31cbfbd11f0ba028f704583b57b",
"priority":0,
"process_duration":0.948244,
"progress":1.0,
"progress_msg":"19:47:07 created task raptor\n19:47:07 Task has been received.\n19:47:07 Processing...\n19:47:07 Processing...\n19:47:07 Indexing done (0.01s).\n19:47:07 Task done (0.29s)",
"retry_count":1,
"task_type":"raptor",
"to_page":100000000,
"update_date":"Wed, 12 Nov 2025 19:47:07 GMT",
"update_time":1762948027948
}
}
```
Failure:
```json
{
"code": 102,
"message": "Invalid Dataset ID"
}
```
---
## FILE MANAGEMENT WITHIN DATASET
---

View file

@ -114,7 +114,7 @@ class Extractor:
async def extract_all(doc_id, chunks, max_concurrency=MAX_CONCURRENT_PROCESS_AND_EXTRACT_CHUNK, task_id=""):
out_results = []
error_count = 0
max_errors = 3
max_errors = int(os.environ.get("GRAPHRAG_MAX_ERRORS", 3))
limiter = trio.Semaphore(max_concurrency)

View file

@ -482,7 +482,7 @@ def tree_merge(bull, sections, depth):
root = Node(level=0, depth=target_level, texts=[])
root.build_tree(lines)
return [("\n").join(element) for element in root.get_tree() if element]
return [element for element in root.get_tree() if element]
def hierarchical_merge(bull, sections, depth):

View file

@ -347,7 +347,7 @@ class Dealer:
## For rank feature(tag_fea) scores.
rank_fea = self._rank_feature_scores(rank_feature, sres)
return tkweight * (np.array(tksim)+rank_fea) + vtweight * vtsim, tksim, vtsim
return tkweight * np.array(tksim) + vtweight * vtsim + rank_fea, tksim, vtsim
def hybrid_similarity(self, ans_embd, ins_embd, ans, inst):
return self.qryr.hybrid_similarity(ans_embd,

View file

@ -15,27 +15,35 @@
#
import logging
import re
import umap
import numpy as np
from sklearn.mixture import GaussianMixture
import trio
import umap
from sklearn.mixture import GaussianMixture
from api.db.services.task_service import has_canceled
from common.connection_utils import timeout
from common.exceptions import TaskCanceledException
from common.token_utils import truncate
from graphrag.utils import (
get_llm_cache,
chat_limiter,
get_embed_cache,
get_llm_cache,
set_embed_cache,
set_llm_cache,
chat_limiter,
)
from common.token_utils import truncate
class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
def __init__(
self, max_cluster, llm_model, embd_model, prompt, max_token=512, threshold=0.1
self,
max_cluster,
llm_model,
embd_model,
prompt,
max_token=512,
threshold=0.1,
max_errors=3,
):
self._max_cluster = max_cluster
self._llm_model = llm_model
@ -43,31 +51,35 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
self._threshold = threshold
self._prompt = prompt
self._max_token = max_token
self._max_errors = max(1, max_errors)
self._error_count = 0
@timeout(60*20)
@timeout(60 * 20)
async def _chat(self, system, history, gen_conf):
response = await trio.to_thread.run_sync(
lambda: get_llm_cache(self._llm_model.llm_name, system, history, gen_conf)
)
cached = await trio.to_thread.run_sync(lambda: get_llm_cache(self._llm_model.llm_name, system, history, gen_conf))
if cached:
return cached
if response:
return response
response = await trio.to_thread.run_sync(
lambda: self._llm_model.chat(system, history, gen_conf)
)
response = re.sub(r"^.*</think>", "", response, flags=re.DOTALL)
if response.find("**ERROR**") >= 0:
raise Exception(response)
await trio.to_thread.run_sync(
lambda: set_llm_cache(self._llm_model.llm_name, system, response, history, gen_conf)
)
return response
last_exc = None
for attempt in range(3):
try:
response = await trio.to_thread.run_sync(lambda: self._llm_model.chat(system, history, gen_conf))
response = re.sub(r"^.*</think>", "", response, flags=re.DOTALL)
if response.find("**ERROR**") >= 0:
raise Exception(response)
await trio.to_thread.run_sync(lambda: set_llm_cache(self._llm_model.llm_name, system, response, history, gen_conf))
return response
except Exception as exc:
last_exc = exc
logging.warning("RAPTOR LLM call failed on attempt %d/3: %s", attempt + 1, exc)
if attempt < 2:
await trio.sleep(1 + attempt)
raise last_exc if last_exc else Exception("LLM chat failed without exception")
@timeout(20)
async def _embedding_encode(self, txt):
response = await trio.to_thread.run_sync(
lambda: get_embed_cache(self._embd_model.llm_name, txt)
)
response = await trio.to_thread.run_sync(lambda: get_embed_cache(self._embd_model.llm_name, txt))
if response is not None:
return response
embds, _ = await trio.to_thread.run_sync(lambda: self._embd_model.encode([txt]))
@ -82,7 +94,6 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
n_clusters = np.arange(1, max_clusters)
bics = []
for n in n_clusters:
if task_id:
if has_canceled(task_id):
logging.info(f"Task {task_id} cancelled during get optimal clusters.")
@ -101,7 +112,7 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
layers = [(0, len(chunks))]
start, end = 0, len(chunks)
@timeout(60*20)
@timeout(60 * 20)
async def summarize(ck_idx: list[int]):
nonlocal chunks
@ -111,47 +122,50 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
raise TaskCanceledException(f"Task {task_id} was cancelled")
texts = [chunks[i][0] for i in ck_idx]
len_per_chunk = int(
(self._llm_model.max_length - self._max_token) / len(texts)
)
cluster_content = "\n".join(
[truncate(t, max(1, len_per_chunk)) for t in texts]
)
async with chat_limiter:
len_per_chunk = int((self._llm_model.max_length - self._max_token) / len(texts))
cluster_content = "\n".join([truncate(t, max(1, len_per_chunk)) for t in texts])
try:
async with chat_limiter:
if task_id and has_canceled(task_id):
logging.info(f"Task {task_id} cancelled before RAPTOR LLM call.")
raise TaskCanceledException(f"Task {task_id} was cancelled")
if task_id and has_canceled(task_id):
logging.info(f"Task {task_id} cancelled before RAPTOR LLM call.")
raise TaskCanceledException(f"Task {task_id} was cancelled")
cnt = await self._chat(
"You're a helpful assistant.",
[
{
"role": "user",
"content": self._prompt.format(cluster_content=cluster_content),
}
],
{"max_tokens": max(self._max_token, 512)}, # fix issue: #10235
)
cnt = re.sub(
"(······\n由于长度的原因,回答被截断了,要继续吗?|For the content length reason, it stopped, continue?)",
"",
cnt,
)
logging.debug(f"SUM: {cnt}")
cnt = await self._chat(
"You're a helpful assistant.",
[
{
"role": "user",
"content": self._prompt.format(
cluster_content=cluster_content
),
}
],
{"max_tokens": max(self._max_token, 512)}, # fix issue: #10235
)
cnt = re.sub(
"(······\n由于长度的原因,回答被截断了,要继续吗?|For the content length reason, it stopped, continue?)",
"",
cnt,
)
logging.debug(f"SUM: {cnt}")
if task_id and has_canceled(task_id):
logging.info(f"Task {task_id} cancelled before RAPTOR embedding.")
raise TaskCanceledException(f"Task {task_id} was cancelled")
if task_id and has_canceled(task_id):
logging.info(f"Task {task_id} cancelled before RAPTOR embedding.")
raise TaskCanceledException(f"Task {task_id} was cancelled")
embds = await self._embedding_encode(cnt)
chunks.append((cnt, embds))
embds = await self._embedding_encode(cnt)
chunks.append((cnt, embds))
except TaskCanceledException:
raise
except Exception as exc:
self._error_count += 1
warn_msg = f"[RAPTOR] Skip cluster ({len(ck_idx)} chunks) due to error: {exc}"
logging.warning(warn_msg)
if callback:
callback(msg=warn_msg)
if self._error_count >= self._max_errors:
raise RuntimeError(f"RAPTOR aborted after {self._error_count} errors. Last error: {exc}") from exc
labels = []
while end - start > 1:
if task_id:
if has_canceled(task_id):
logging.info(f"Task {task_id} cancelled during RAPTOR layer processing.")
@ -161,11 +175,7 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
if len(embeddings) == 2:
await summarize([start, start + 1])
if callback:
callback(
msg="Cluster one layer: {} -> {}".format(
end - start, len(chunks) - end
)
)
callback(msg="Cluster one layer: {} -> {}".format(end - start, len(chunks) - end))
labels.extend([0, 0])
layers.append((end, len(chunks)))
start = end
@ -199,17 +209,11 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
nursery.start_soon(summarize, ck_idx)
assert len(chunks) - end == n_clusters, "{} vs. {}".format(
len(chunks) - end, n_clusters
)
assert len(chunks) - end == n_clusters, "{} vs. {}".format(len(chunks) - end, n_clusters)
labels.extend(lbls)
layers.append((end, len(chunks)))
if callback:
callback(
msg="Cluster one layer: {} -> {}".format(
end - start, len(chunks) - end
)
)
callback(msg="Cluster one layer: {} -> {}".format(end - start, len(chunks) - end))
start = end
end = len(chunks)

View file

@ -442,7 +442,7 @@ async def embedding(docs, mdl, parser_config=None, callback=None):
tk_count = 0
if len(tts) == len(cnts):
vts, c = await trio.to_thread.run_sync(lambda: mdl.encode(tts[0: 1]))
tts = np.concatenate([vts[0] for _ in range(len(tts))], axis=0)
tts = np.tile(vts[0], (len(cnts), 1))
tk_count += c
@timeout(60)
@ -465,8 +465,10 @@ async def embedding(docs, mdl, parser_config=None, callback=None):
if not filename_embd_weight:
filename_embd_weight = 0.1
title_w = float(filename_embd_weight)
vects = (title_w * tts + (1 - title_w) *
cnts) if len(tts) == len(cnts) else cnts
if tts.ndim == 2 and cnts.ndim == 2 and tts.shape == cnts.shape:
vects = title_w * tts + (1 - title_w) * cnts
else:
vects = cnts
assert len(vects) == len(docs)
vector_size = 0
@ -649,6 +651,8 @@ async def run_raptor_for_kb(row, kb_parser_config, chat_mdl, embd_mdl, vector_si
res = []
tk_count = 0
max_errors = int(os.environ.get("RAPTOR_MAX_ERRORS", 3))
async def generate(chunks, did):
nonlocal tk_count, res
raptor = Raptor(
@ -658,6 +662,7 @@ async def run_raptor_for_kb(row, kb_parser_config, chat_mdl, embd_mdl, vector_si
raptor_config["prompt"],
raptor_config["max_token"],
raptor_config["threshold"],
max_errors=max_errors,
)
original_length = len(chunks)
chunks = await raptor(chunks, kb_parser_config["raptor"]["random_seed"], callback, row["id"])

View file

@ -16,14 +16,15 @@
from concurrent.futures import ThreadPoolExecutor, as_completed
import pytest
from common import create_dataset
from configs import DATASET_NAME_LIMIT, INVALID_API_TOKEN
from configs import DATASET_NAME_LIMIT, DEFAULT_PARSER_CONFIG, INVALID_API_TOKEN
from hypothesis import example, given, settings
from libs.auth import RAGFlowHttpApiAuth
from utils import encode_avatar
from utils.file_utils import create_image_file
from utils.hypothesis_utils import valid_names
from configs import DEFAULT_PARSER_CONFIG
from common import create_dataset
@pytest.mark.usefixtures("clear_datasets")
class TestAuthorization:
@ -125,8 +126,8 @@ class TestDatasetCreate:
assert res["code"] == 0, res
res = create_dataset(HttpApiAuth, payload)
assert res["code"] == 103, res
assert res["message"] == f"Dataset name '{name}' already exists", res
assert res["code"] == 0, res
assert res["data"]["name"] == name + "(1)", res
@pytest.mark.p3
def test_name_case_insensitive(self, HttpApiAuth):
@ -137,8 +138,8 @@ class TestDatasetCreate:
payload = {"name": name.lower()}
res = create_dataset(HttpApiAuth, payload)
assert res["code"] == 103, res
assert res["message"] == f"Dataset name '{name.lower()}' already exists", res
assert res["code"] == 0, res
assert res["data"]["name"] == name.lower() + "(1)", res
@pytest.mark.p2
def test_avatar(self, HttpApiAuth, tmp_path):

View file

@ -17,13 +17,13 @@ from concurrent.futures import ThreadPoolExecutor, as_completed
from operator import attrgetter
import pytest
from configs import DATASET_NAME_LIMIT, HOST_ADDRESS, INVALID_API_TOKEN
from configs import DATASET_NAME_LIMIT, DEFAULT_PARSER_CONFIG, HOST_ADDRESS, INVALID_API_TOKEN
from hypothesis import example, given, settings
from ragflow_sdk import DataSet, RAGFlow
from utils import encode_avatar
from utils.file_utils import create_image_file
from utils.hypothesis_utils import valid_names
from configs import DEFAULT_PARSER_CONFIG
@pytest.mark.usefixtures("clear_datasets")
class TestAuthorization:
@ -95,9 +95,8 @@ class TestDatasetCreate:
payload = {"name": name}
client.create_dataset(**payload)
with pytest.raises(Exception) as excinfo:
client.create_dataset(**payload)
assert str(excinfo.value) == f"Dataset name '{name}' already exists", str(excinfo.value)
dataset = client.create_dataset(**payload)
assert dataset.name == name + "(1)", str(dataset)
@pytest.mark.p3
def test_name_case_insensitive(self, client):
@ -106,9 +105,8 @@ class TestDatasetCreate:
client.create_dataset(**payload)
payload = {"name": name.lower()}
with pytest.raises(Exception) as excinfo:
client.create_dataset(**payload)
assert str(excinfo.value) == f"Dataset name '{name.lower()}' already exists", str(excinfo.value)
dataset = client.create_dataset(**payload)
assert dataset.name == name.lower() + "(1)", str(dataset)
@pytest.mark.p2
def test_avatar(self, client, tmp_path):

View file

@ -61,6 +61,12 @@ export interface FormFieldConfig {
horizontal?: boolean;
onChange?: (value: any) => void;
tooltip?: React.ReactNode;
customValidate?: (
value: any,
formValues: any,
) => string | boolean | Promise<string | boolean>;
dependencies?: string[];
schema?: ZodSchema;
}
// Component props interface
@ -94,36 +100,40 @@ const generateSchema = (fields: FormFieldConfig[]): ZodSchema<any> => {
let fieldSchema: ZodSchema;
// Create base validation schema based on field type
switch (field.type) {
case FormFieldType.Email:
fieldSchema = z.string().email('Please enter a valid email address');
break;
case FormFieldType.Number:
fieldSchema = z.coerce.number();
if (field.validation?.min !== undefined) {
fieldSchema = (fieldSchema as z.ZodNumber).min(
field.validation.min,
field.validation.message ||
`Value cannot be less than ${field.validation.min}`,
);
}
if (field.validation?.max !== undefined) {
fieldSchema = (fieldSchema as z.ZodNumber).max(
field.validation.max,
field.validation.message ||
`Value cannot be greater than ${field.validation.max}`,
);
}
break;
case FormFieldType.Checkbox:
fieldSchema = z.boolean();
break;
case FormFieldType.Tag:
fieldSchema = z.array(z.string());
break;
default:
fieldSchema = z.string();
break;
if (field.schema) {
fieldSchema = field.schema;
} else {
switch (field.type) {
case FormFieldType.Email:
fieldSchema = z.string().email('Please enter a valid email address');
break;
case FormFieldType.Number:
fieldSchema = z.coerce.number();
if (field.validation?.min !== undefined) {
fieldSchema = (fieldSchema as z.ZodNumber).min(
field.validation.min,
field.validation.message ||
`Value cannot be less than ${field.validation.min}`,
);
}
if (field.validation?.max !== undefined) {
fieldSchema = (fieldSchema as z.ZodNumber).max(
field.validation.max,
field.validation.message ||
`Value cannot be greater than ${field.validation.max}`,
);
}
break;
case FormFieldType.Checkbox:
fieldSchema = z.boolean();
break;
case FormFieldType.Tag:
fieldSchema = z.array(z.string());
break;
default:
fieldSchema = z.string();
break;
}
}
// Handle required fields
@ -300,10 +310,90 @@ const DynamicForm = {
// Initialize form
const form = useForm<T>({
resolver: zodResolver(schema),
resolver: async (data, context, options) => {
const zodResult = await zodResolver(schema)(data, context, options);
let combinedErrors = { ...zodResult.errors };
const fieldErrors: Record<string, { type: string; message: string }> =
{};
for (const field of fields) {
if (field.customValidate && data[field.name] !== undefined) {
try {
const result = await field.customValidate(
data[field.name],
data,
);
if (typeof result === 'string') {
fieldErrors[field.name] = {
type: 'custom',
message: result,
};
} else if (result === false) {
fieldErrors[field.name] = {
type: 'custom',
message:
field.validation?.message || `${field.label} is invalid`,
};
}
} catch (error) {
fieldErrors[field.name] = {
type: 'custom',
message:
error instanceof Error
? error.message
: 'Validation failed',
};
}
}
}
combinedErrors = {
...combinedErrors,
...fieldErrors,
} as any;
console.log('combinedErrors', combinedErrors);
return {
values: Object.keys(combinedErrors).length ? {} : data,
errors: combinedErrors,
} as any;
},
defaultValues,
});
useEffect(() => {
const dependencyMap: Record<string, string[]> = {};
fields.forEach((field) => {
if (field.dependencies && field.dependencies.length > 0) {
field.dependencies.forEach((dep) => {
if (!dependencyMap[dep]) {
dependencyMap[dep] = [];
}
dependencyMap[dep].push(field.name);
});
}
});
const subscriptions = Object.keys(dependencyMap).map((depField) => {
return form.watch((values: any, { name }) => {
if (name === depField && dependencyMap[depField]) {
dependencyMap[depField].forEach((dependentField) => {
form.trigger(dependentField as any);
});
}
});
});
return () => {
subscriptions.forEach((sub) => {
if (sub.unsubscribe) {
sub.unsubscribe();
}
});
};
}, [fields, form]);
// Expose form methods via ref
useImperativeHandle(ref, () => ({
submit: () => form.handleSubmit(onSubmit)(),

View file

@ -51,6 +51,7 @@ export interface SegmentedProps
direction?: 'ltr' | 'rtl';
motionName?: string;
activeClassName?: string;
itemClassName?: string;
rounded?: keyof typeof segmentedVariants.round;
sizeType?: keyof typeof segmentedVariants.size;
buttonSize?: keyof typeof segmentedVariants.buttonSize;
@ -62,6 +63,7 @@ export function Segmented({
onChange,
className,
activeClassName,
itemClassName,
rounded = 'default',
sizeType = 'default',
buttonSize = 'default',
@ -92,12 +94,13 @@ export function Segmented({
<div
key={actualValue}
className={cn(
'inline-flex items-center text-base font-normal cursor-pointer',
'inline-flex items-center text-base font-normal cursor-pointer',
segmentedVariants.round[rounded],
segmentedVariants.buttonSize[buttonSize],
{
'text-text-primary bg-bg-base': selectedValue === actualValue,
},
itemClassName,
activeClassName && selectedValue === actualValue
? activeClassName
: '',

View file

@ -1009,6 +1009,7 @@ Example: general/v2/`,
pleaseUploadAtLeastOneFile: 'Please upload at least one file',
},
flow: {
formatTypeError: 'Format or type error',
variableNameMessage:
'Variable name can only contain letters and underscores',
variableDescription: 'Variable Description',

View file

@ -956,6 +956,7 @@ General实体和关系提取提示来自 GitHub - microsoft/graphrag基于
pleaseUploadAtLeastOneFile: '请上传至少一个文件',
},
flow: {
formatTypeError: '格式或类型错误',
variableNameMessage: '名称只能包含字母和下划线',
variableDescription: '变量的描述',
defaultValue: '默认值',

View file

@ -0,0 +1,134 @@
import {
DynamicForm,
DynamicFormRef,
FormFieldConfig,
} from '@/components/dynamic-form';
import { Modal } from '@/components/ui/modal/modal';
import { t } from 'i18next';
import { useEffect, useRef } from 'react';
import { FieldValues } from 'react-hook-form';
import { TypeMaps, TypesWithArray } from '../constant';
import { useHandleForm } from '../hooks/use-form';
import { useObjectFields } from '../hooks/use-object-fields';
/**
 * Modal for creating a conversation (global) variable on an agent.
 *
 * Whenever the `type` selector changes, the `value` field's widget, custom
 * validator and zod schema are swapped via `onFieldUpdate`, and a
 * type-appropriate default value is pushed into the form so the field never
 * starts in a state that fails its own validator.
 */
export const AddVariableModal = (props: {
  fields?: FormFieldConfig[];
  setFields: (value: any) => void;
  visible?: boolean;
  hideModal: () => void;
  defaultValues?: FieldValues;
  setDefaultValues?: (value: FieldValues) => void;
}) => {
  const {
    fields,
    setFields,
    visible,
    hideModal,
    defaultValues,
    setDefaultValues,
  } = props;
  const { handleSubmit: submitForm, loading } = useHandleForm();
  const { handleCustomValidate, handleCustomSchema, handleRender } =
    useObjectFields();

  const formRef = useRef<DynamicFormRef>(null);

  // Merge a partial config update into the field identified by `fieldName`,
  // producing new field objects so React re-renders.
  const handleFieldUpdate = (
    fieldName: string,
    updatedField: Partial<FormFieldConfig>,
  ) => {
    setFields((prevFields: any) =>
      prevFields.map((field: any) =>
        field.name === fieldName ? { ...field, ...updatedField } : field,
      ),
    );
  };

  // Re-wire the `value` field whenever the `type` selector changes.
  // NOTE(review): this assigns `onChange` onto the field config object in
  // place and only re-binds when `fields` changes identity — confirm all
  // updates go through handleFieldUpdate (which creates new objects).
  useEffect(() => {
    const typeField = fields?.find((item) => item.name === 'type');
    if (typeField) {
      typeField.onChange = (value) => {
        handleFieldUpdate('value', {
          type: TypeMaps[value as keyof typeof TypeMaps],
          render: handleRender(value),
          customValidate: handleCustomValidate(value),
          schema: handleCustomSchema(value),
        });
        const values = formRef.current?.getValues();
        // Seed a default `value` matching the newly selected type.
        switch (value) {
          case TypesWithArray.Boolean:
            setDefaultValues?.({ ...values, value: false });
            break;
          case TypesWithArray.Number:
            setDefaultValues?.({ ...values, value: 0 });
            break;
          case TypesWithArray.Object:
            setDefaultValues?.({ ...values, value: {} });
            break;
          case TypesWithArray.ArrayString:
            setDefaultValues?.({ ...values, value: [''] });
            break;
          case TypesWithArray.ArrayNumber:
            // BUG FIX: was [''] — a string element fails arrayNumberValidate
            // (every item must be typeof 'number'), so the form opened in an
            // invalid state. Seed a numeric element instead.
            setDefaultValues?.({ ...values, value: [0] });
            break;
          case TypesWithArray.ArrayBoolean:
            setDefaultValues?.({ ...values, value: [false] });
            break;
          case TypesWithArray.ArrayObject:
            setDefaultValues?.({ ...values, value: [] });
            break;
          default:
            setDefaultValues?.({ ...values, value: '' });
            break;
        }
      };
    }
  }, [fields]);

  // Persist the variable, then close the modal.
  const handleSubmit = async (fieldValue: FieldValues) => {
    await submitForm(fieldValue);
    hideModal();
  };

  return (
    <Modal
      title={t('flow.add') + t('flow.conversationVariable')}
      open={visible || false}
      onCancel={hideModal}
      showfooter={false}
    >
      <DynamicForm.Root
        ref={formRef}
        fields={fields || []}
        onSubmit={() => {
          // Submission is driven by the SavingButton below.
        }}
        defaultValues={defaultValues}
        onFieldUpdate={handleFieldUpdate}
      >
        <div className="flex items-center justify-end w-full gap-2">
          <DynamicForm.CancelButton
            handleCancel={() => {
              hideModal?.();
            }}
          />
          <DynamicForm.SavingButton
            submitLoading={loading || false}
            buttonText={t('common.ok')}
            submitFunc={(values: FieldValues) => {
              handleSubmit(values);
            }}
          />
        </div>
      </DynamicForm.Root>
    </Modal>
  );
};

View file

@ -13,14 +13,14 @@ export enum TypesWithArray {
String = 'string',
Number = 'number',
Boolean = 'boolean',
// Object = 'object',
// ArrayString = 'array<string>',
// ArrayNumber = 'array<number>',
// ArrayBoolean = 'array<boolean>',
// ArrayObject = 'array<object>',
Object = 'object',
ArrayString = 'array<string>',
ArrayNumber = 'array<number>',
ArrayBoolean = 'array<boolean>',
ArrayObject = 'array<object>',
}
export const GobalFormFields = [
export const GlobalFormFields = [
{
label: t('flow.name'),
name: 'name',
@ -50,11 +50,11 @@ export const GobalFormFields = [
label: t('flow.description'),
name: 'description',
placeholder: t('flow.variableDescription'),
type: 'textarea',
type: FormFieldType.Textarea,
},
] as FormFieldConfig[];
export const GobalVariableFormDefaultValues = {
export const GlobalVariableFormDefaultValues = {
name: '',
type: TypesWithArray.String,
value: '',
@ -65,9 +65,9 @@ export const TypeMaps = {
[TypesWithArray.String]: FormFieldType.Textarea,
[TypesWithArray.Number]: FormFieldType.Number,
[TypesWithArray.Boolean]: FormFieldType.Checkbox,
// [TypesWithArray.Object]: FormFieldType.Textarea,
// [TypesWithArray.ArrayString]: FormFieldType.Textarea,
// [TypesWithArray.ArrayNumber]: FormFieldType.Textarea,
// [TypesWithArray.ArrayBoolean]: FormFieldType.Textarea,
// [TypesWithArray.ArrayObject]: FormFieldType.Textarea,
[TypesWithArray.Object]: FormFieldType.Textarea,
[TypesWithArray.ArrayString]: FormFieldType.Textarea,
[TypesWithArray.ArrayNumber]: FormFieldType.Textarea,
[TypesWithArray.ArrayBoolean]: FormFieldType.Textarea,
[TypesWithArray.ArrayObject]: FormFieldType.Textarea,
};

View file

@ -0,0 +1,41 @@
import { useFetchAgent } from '@/hooks/use-agent-request';
import { GlobalVariableType } from '@/interfaces/database/agent';
import { useCallback } from 'react';
import { FieldValues } from 'react-hook-form';
import { useSaveGraph } from '../../hooks/use-save-graph';
import { TypesWithArray } from '../constant';
/**
 * Hook that persists a single global variable into the agent's DSL.
 *
 * Merges the submitted field values into the existing `dsl.variables` map
 * (keyed by variable name), JSON-parsing object-typed values, then saves the
 * graph and refetches the agent on success.
 */
export const useHandleForm = () => {
  const { data, refetch } = useFetchAgent();
  const { saveGraph, loading } = useSaveGraph();

  // Parse JSON-typed values; fall back to the raw string when it is not
  // valid JSON (the form validator should have caught that already).
  const handleObjectData = (value: any) => {
    try {
      return JSON.parse(value);
    } catch (error) {
      return value;
    }
  };

  const handleSubmit = useCallback(
    async (fieldValue: FieldValues) => {
      const param = {
        ...(data.dsl?.variables || {}),
        [fieldValue.name]: {
          ...fieldValue,
          value:
            fieldValue.type === TypesWithArray.Object ||
            fieldValue.type === TypesWithArray.ArrayObject
              ? handleObjectData(fieldValue.value)
              : fieldValue.value,
        },
      } as Record<string, GlobalVariableType>;
      const res = await saveGraph(undefined, {
        globalVariables: param,
      });
      if (res.code === 0) {
        refetch();
      }
    },
    // BUG FIX: the dependency array was empty, so the callback captured a
    // stale `data` — saving a second variable could overwrite variables added
    // since the first render. Re-create the callback when its inputs change.
    [data, saveGraph, refetch],
  );

  return { handleSubmit, loading };
};

View file

@ -0,0 +1,246 @@
import { BlockButton, Button } from '@/components/ui/button';
import { Input } from '@/components/ui/input';
import { Segmented } from '@/components/ui/segmented';
import { Editor } from '@monaco-editor/react';
import { t } from 'i18next';
import { Trash2, X } from 'lucide-react';
import { useCallback } from 'react';
import { FieldValues } from 'react-hook-form';
import { z } from 'zod';
import { TypesWithArray } from '../constant';
export const useObjectFields = () => {
const booleanRender = useCallback(
(field: FieldValues, className?: string) => {
const fieldValue = field.value ? true : false;
return (
<Segmented
options={
[
{ value: true, label: 'True' },
{ value: false, label: 'False' },
] as any
}
sizeType="sm"
value={fieldValue}
onChange={field.onChange}
className={className}
itemClassName="justify-center flex-1"
></Segmented>
);
},
[],
);
const objectRender = useCallback((field: FieldValues) => {
const fieldValue =
typeof field.value === 'object'
? JSON.stringify(field.value, null, 2)
: JSON.stringify({}, null, 2);
console.log('object-render-field', field, fieldValue);
return (
<Editor
height={200}
defaultLanguage="json"
theme="vs-dark"
value={fieldValue}
onChange={field.onChange}
/>
);
}, []);
const objectValidate = useCallback((value: any) => {
try {
if (!JSON.parse(value)) {
throw new Error(t('knowledgeDetails.formatTypeError'));
}
return true;
} catch (e) {
throw new Error(t('knowledgeDetails.formatTypeError'));
}
}, []);
const arrayStringRender = useCallback((field: FieldValues, type = 'text') => {
const values = Array.isArray(field.value)
? field.value
: [type === 'number' ? 0 : ''];
return (
<>
{values?.map((item: any, index: number) => (
<div key={index} className="flex gap-1 items-center">
<Input
type={type}
value={item}
onChange={(e) => {
const newValues = [...values];
newValues[index] = e.target.value;
field.onChange(newValues);
}}
/>
<Button
variant={'secondary'}
onClick={() => {
const newValues = [...values];
newValues.splice(index, 1);
field.onChange(newValues);
}}
>
<Trash2 />
</Button>
</div>
))}
<BlockButton
type="button"
onClick={() => {
field.onChange([...field.value, '']);
}}
>
{t('flow.add')}
</BlockButton>
</>
);
}, []);
const arrayBooleanRender = useCallback(
(field: FieldValues) => {
// const values = field.value || [false];
const values = Array.isArray(field.value) ? field.value : [false];
return (
<div className="flex items-center gap-1 flex-wrap ">
{values?.map((item: any, index: number) => (
<div
key={index}
className="flex gap-1 items-center bg-bg-card rounded-lg border-[0.5px] border-border-button"
>
{booleanRender(
{
value: item,
onChange: (value) => {
values[index] = !!value;
field.onChange(values);
},
},
'bg-transparent',
)}
<Button
variant={'transparent'}
className="border-none py-0 px-1"
onClick={() => {
const newValues = [...values];
newValues.splice(index, 1);
field.onChange(newValues);
}}
>
<X />
</Button>
</div>
))}
<BlockButton
className="w-auto"
type="button"
onClick={() => {
field.onChange([...field.value, false]);
}}
>
{t('flow.add')}
</BlockButton>
</div>
);
},
[booleanRender],
);
const arrayNumberRender = useCallback(
(field: FieldValues) => {
return arrayStringRender(field, 'number');
},
[arrayStringRender],
);
const arrayValidate = useCallback((value: any, type: string = 'string') => {
if (!Array.isArray(value) || !value.every((item) => typeof item === type)) {
throw new Error(t('flow.formatTypeError'));
}
return true;
}, []);
const arrayStringValidate = useCallback(
(value: any) => {
return arrayValidate(value, 'string');
},
[arrayValidate],
);
const arrayNumberValidate = useCallback(
(value: any) => {
return arrayValidate(value, 'number');
},
[arrayValidate],
);
const arrayBooleanValidate = useCallback(
(value: any) => {
return arrayValidate(value, 'boolean');
},
[arrayValidate],
);
// Maps a declared variable type to its custom field renderer; returns
// undefined for plain string/number types so the default input is used.
const handleRender = (value: TypesWithArray) => {
  if (value === TypesWithArray.Boolean) {
    return booleanRender;
  }
  if (value === TypesWithArray.Object || value === TypesWithArray.ArrayObject) {
    return objectRender;
  }
  if (value === TypesWithArray.ArrayString) {
    return arrayStringRender;
  }
  if (value === TypesWithArray.ArrayNumber) {
    return arrayNumberRender;
  }
  if (value === TypesWithArray.ArrayBoolean) {
    return arrayBooleanRender;
  }
  return undefined;
};
// Picks the custom validator matching the declared variable type; types
// without a specialized validator yield undefined (no custom validation).
const handleCustomValidate = (value: TypesWithArray) => {
  const validators: Partial<Record<TypesWithArray, any>> = {
    [TypesWithArray.Object]: objectValidate,
    [TypesWithArray.ArrayObject]: objectValidate,
    [TypesWithArray.ArrayString]: arrayStringValidate,
    [TypesWithArray.ArrayNumber]: arrayNumberValidate,
    [TypesWithArray.ArrayBoolean]: arrayBooleanValidate,
  };
  return validators[value];
};
// Builds the zod schema for the array-typed variables; non-array types
// return undefined and rely on the default field schema.
const handleCustomSchema = (value: TypesWithArray) => {
  if (value === TypesWithArray.ArrayString) {
    return z.array(z.string());
  }
  if (value === TypesWithArray.ArrayNumber) {
    return z.array(z.number());
  }
  if (value === TypesWithArray.ArrayBoolean) {
    return z.array(z.boolean());
  }
  return undefined;
};
// Public surface of the hook: per-type renderers and validators plus the
// type-driven dispatch helpers (handleRender / handleCustomValidate /
// handleCustomSchema) consumed by the dynamic global-variable form.
return {
  objectRender,
  objectValidate,
  arrayStringRender,
  arrayStringValidate,
  arrayNumberRender,
  booleanRender,
  arrayBooleanRender,
  arrayNumberValidate,
  arrayBooleanValidate,
  handleRender,
  handleCustomValidate,
  handleCustomSchema,
};
};

View file

@ -1,12 +1,6 @@
import { ConfirmDeleteDialog } from '@/components/confirm-delete-dialog';
import {
DynamicForm,
DynamicFormRef,
FormFieldConfig,
FormFieldType,
} from '@/components/dynamic-form';
import { FormFieldConfig } from '@/components/dynamic-form';
import { BlockButton, Button } from '@/components/ui/button';
import { Modal } from '@/components/ui/modal/modal';
import {
Sheet,
SheetContent,
@ -19,117 +13,65 @@ import { GlobalVariableType } from '@/interfaces/database/agent';
import { cn } from '@/lib/utils';
import { t } from 'i18next';
import { Trash2 } from 'lucide-react';
import { useEffect, useRef, useState } from 'react';
import { useState } from 'react';
import { FieldValues } from 'react-hook-form';
import { useSaveGraph } from '../hooks/use-save-graph';
import { AddVariableModal } from './component/add-variable-modal';
import {
GobalFormFields,
GobalVariableFormDefaultValues,
GlobalFormFields,
GlobalVariableFormDefaultValues,
TypeMaps,
TypesWithArray,
} from './contant';
} from './constant';
import { useObjectFields } from './hooks/use-object-fields';
export type IGobalParamModalProps = {
export type IGlobalParamModalProps = {
data: any;
hideModal: (open: boolean) => void;
};
export const GobalParamSheet = (props: IGobalParamModalProps) => {
export const GlobalParamSheet = (props: IGlobalParamModalProps) => {
const { hideModal } = props;
const { data, refetch } = useFetchAgent();
const [fields, setFields] = useState<FormFieldConfig[]>(GobalFormFields);
const { visible, showModal, hideModal: hideAddModal } = useSetModalState();
const [fields, setFields] = useState<FormFieldConfig[]>(GlobalFormFields);
const [defaultValues, setDefaultValues] = useState<FieldValues>(
GobalVariableFormDefaultValues,
GlobalVariableFormDefaultValues,
);
const formRef = useRef<DynamicFormRef>(null);
const { handleCustomValidate, handleCustomSchema, handleRender } =
useObjectFields();
const { saveGraph } = useSaveGraph();
const handleFieldUpdate = (
fieldName: string,
updatedField: Partial<FormFieldConfig>,
) => {
setFields((prevFields) =>
prevFields.map((field) =>
field.name === fieldName ? { ...field, ...updatedField } : field,
),
);
};
useEffect(() => {
const typefileld = fields.find((item) => item.name === 'type');
if (typefileld) {
typefileld.onChange = (value) => {
// setWatchType(value);
handleFieldUpdate('value', {
type: TypeMaps[value as keyof typeof TypeMaps],
});
const values = formRef.current?.getValues();
setTimeout(() => {
switch (value) {
case TypesWithArray.Boolean:
setDefaultValues({ ...values, value: false });
break;
case TypesWithArray.Number:
setDefaultValues({ ...values, value: 0 });
break;
default:
setDefaultValues({ ...values, value: '' });
}
}, 0);
};
}
}, [fields]);
const { saveGraph, loading } = useSaveGraph();
const handleSubmit = async (value: FieldValues) => {
const param = {
...(data.dsl?.variables || {}),
[value.name]: value,
} as Record<string, GlobalVariableType>;
const res = await saveGraph(undefined, {
gobalVariables: param,
});
if (res.code === 0) {
refetch();
}
hideAddModal();
};
const handleDeleteGobalVariable = async (key: string) => {
const handleDeleteGlobalVariable = async (key: string) => {
const param = {
...(data.dsl?.variables || {}),
} as Record<string, GlobalVariableType>;
delete param[key];
const res = await saveGraph(undefined, {
gobalVariables: param,
globalVariables: param,
});
console.log('delete gobal variable-->', res);
if (res.code === 0) {
refetch();
}
};
const handleEditGobalVariable = (item: FieldValues) => {
fields.forEach((field) => {
if (field.name === 'value') {
switch (item.type) {
// [TypesWithArray.String]: FormFieldType.Textarea,
// [TypesWithArray.Number]: FormFieldType.Number,
// [TypesWithArray.Boolean]: FormFieldType.Checkbox,
case TypesWithArray.Boolean:
field.type = FormFieldType.Checkbox;
break;
case TypesWithArray.Number:
field.type = FormFieldType.Number;
break;
default:
field.type = FormFieldType.Textarea;
}
const handleEditGlobalVariable = (item: FieldValues) => {
const newFields = fields.map((field) => {
let newField = field;
newField.render = undefined;
newField.schema = undefined;
newField.customValidate = undefined;
if (newField.name === 'value') {
newField = {
...newField,
type: TypeMaps[item.type as keyof typeof TypeMaps],
render: handleRender(item.type),
customValidate: handleCustomValidate(item.type),
schema: handleCustomSchema(item.type),
};
}
return newField;
});
setFields(newFields);
setDefaultValues(item);
showModal();
};
@ -149,8 +91,8 @@ export const GobalParamSheet = (props: IGobalParamModalProps) => {
<div className="px-5 pb-5">
<BlockButton
onClick={() => {
setFields(GobalFormFields);
setDefaultValues(GobalVariableFormDefaultValues);
setFields(GlobalFormFields);
setDefaultValues(GlobalVariableFormDefaultValues);
showModal();
}}
>
@ -167,7 +109,7 @@ export const GobalParamSheet = (props: IGobalParamModalProps) => {
key={key}
className="flex items-center gap-3 min-h-14 justify-between px-5 py-3 border border-border-default rounded-lg hover:bg-bg-card group"
onClick={() => {
handleEditGobalVariable(item);
handleEditGlobalVariable(item);
}}
>
<div className="flex flex-col">
@ -177,13 +119,23 @@ export const GobalParamSheet = (props: IGobalParamModalProps) => {
{item.type}
</span>
</div>
<div>
<span className="text-text-primary">{item.value}</span>
</div>
{![
TypesWithArray.Object,
TypesWithArray.ArrayObject,
TypesWithArray.ArrayString,
TypesWithArray.ArrayNumber,
TypesWithArray.ArrayBoolean,
].includes(item.type as TypesWithArray) && (
<div>
<span className="text-text-primary">
{item.value}
</span>
</div>
)}
</div>
<div>
<ConfirmDeleteDialog
onOk={() => handleDeleteGobalVariable(key)}
onOk={() => handleDeleteGlobalVariable(key)}
>
<Button
variant={'secondary'}
@ -201,40 +153,14 @@ export const GobalParamSheet = (props: IGobalParamModalProps) => {
})}
</div>
</SheetContent>
<Modal
title={t('flow.add') + t('flow.conversationVariable')}
open={visible}
onCancel={hideAddModal}
showfooter={false}
>
<DynamicForm.Root
ref={formRef}
fields={fields}
onSubmit={(data) => {
console.log(data);
}}
defaultValues={defaultValues}
onFieldUpdate={handleFieldUpdate}
>
<div className="flex items-center justify-end w-full gap-2">
<DynamicForm.CancelButton
handleCancel={() => {
hideAddModal?.();
}}
/>
<DynamicForm.SavingButton
submitLoading={loading || false}
buttonText={t('common.ok')}
submitFunc={(values: FieldValues) => {
handleSubmit(values);
// console.log(values);
// console.log(nodes, edges);
// handleOk(values);
}}
/>
</div>
</DynamicForm.Root>
</Modal>
<AddVariableModal
visible={visible}
hideModal={hideAddModal}
fields={fields}
setFields={setFields}
defaultValues={defaultValues}
setDefaultValues={setDefaultValues}
/>
</Sheet>
</>
);

View file

@ -4,7 +4,7 @@ import { RAGFlowNodeType } from '@/interfaces/database/flow';
import { useCallback } from 'react';
import { Operator } from '../constant';
import useGraphStore from '../store';
import { buildDslComponentsByGraph, buildDslGobalVariables } from '../utils';
import { buildDslComponentsByGraph, buildDslGlobalVariables } from '../utils';
export const useBuildDslData = () => {
const { data } = useFetchAgent();
@ -13,7 +13,7 @@ export const useBuildDslData = () => {
const buildDslData = useCallback(
(
currentNodes?: RAGFlowNodeType[],
otherParam?: { gobalVariables: Record<string, GlobalVariableType> },
otherParam?: { globalVariables: Record<string, GlobalVariableType> },
) => {
const nodesToProcess = currentNodes ?? nodes;
@ -41,13 +41,13 @@ export const useBuildDslData = () => {
data.dsl.components,
);
const gobalVariables = buildDslGobalVariables(
const globalVariables = buildDslGlobalVariables(
data.dsl,
otherParam?.gobalVariables,
otherParam?.globalVariables,
);
return {
...data.dsl,
...gobalVariables,
...globalVariables,
graph: { nodes: filteredNodes, edges: filteredEdges },
components: dslComponents,
};

View file

@ -21,7 +21,7 @@ export const useSaveGraph = (showMessage: boolean = true) => {
const saveGraph = useCallback(
async (
currentNodes?: RAGFlowNodeType[],
otherParam?: { gobalVariables: Record<string, GlobalVariableType> },
otherParam?: { globalVariables: Record<string, GlobalVariableType> },
) => {
return setAgent({
id,

View file

@ -39,7 +39,7 @@ import { useParams } from 'umi';
import AgentCanvas from './canvas';
import { DropdownProvider } from './canvas/context';
import { Operator } from './constant';
import { GobalParamSheet } from './gobal-variable-sheet';
import { GlobalParamSheet } from './gobal-variable-sheet';
import { useCancelCurrentDataflow } from './hooks/use-cancel-dataflow';
import { useHandleExportJsonFile } from './hooks/use-export-json';
import { useFetchDataOnMount } from './hooks/use-fetch-data';
@ -126,9 +126,9 @@ export default function Agent() {
} = useSetModalState();
const {
visible: gobalParamSheetVisible,
showModal: showGobalParamSheet,
hideModal: hideGobalParamSheet,
visible: globalParamSheetVisible,
showModal: showGlobalParamSheet,
hideModal: hideGlobalParamSheet,
} = useSetModalState();
const {
@ -216,7 +216,7 @@ export default function Agent() {
</ButtonLoading>
<ButtonLoading
variant={'secondary'}
onClick={() => showGobalParamSheet()}
onClick={() => showGlobalParamSheet()}
loading={loading}
>
<MessageSquareCode /> {t('flow.conversationVariable')}
@ -314,11 +314,11 @@ export default function Agent() {
loading={pipelineRunning}
></PipelineRunSheet>
)}
{gobalParamSheetVisible && (
<GobalParamSheet
{globalParamSheetVisible && (
<GlobalParamSheet
data={{}}
hideModal={hideGobalParamSheet}
></GobalParamSheet>
hideModal={hideGlobalParamSheet}
></GlobalParamSheet>
)}
</section>
);

View file

@ -348,30 +348,30 @@ export const buildDslComponentsByGraph = (
return components;
};
export const buildDslGobalVariables = (
export const buildDslGlobalVariables = (
dsl: DSL,
gobalVariables?: Record<string, GlobalVariableType>,
globalVariables?: Record<string, GlobalVariableType>,
) => {
if (!gobalVariables) {
if (!globalVariables) {
return { globals: dsl.globals, variables: dsl.variables || {} };
}
let gobalVariablesTemp: Record<string, any> = {};
let gobalSystem: Record<string, any> = {};
let globalVariablesTemp: Record<string, any> = {};
let globalSystem: Record<string, any> = {};
Object.keys(dsl.globals)?.forEach((key) => {
if (key.indexOf('sys') > -1) {
gobalSystem[key] = dsl.globals[key];
globalSystem[key] = dsl.globals[key];
}
});
Object.keys(gobalVariables).forEach((key) => {
gobalVariablesTemp['env.' + key] = gobalVariables[key].value;
Object.keys(globalVariables).forEach((key) => {
globalVariablesTemp['env.' + key] = globalVariables[key].value;
});
const gobalVariablesResult = {
...gobalSystem,
...gobalVariablesTemp,
const globalVariablesResult = {
...globalSystem,
...globalVariablesTemp,
};
return { globals: gobalVariablesResult, variables: gobalVariables };
return { globals: globalVariablesResult, variables: globalVariables };
};
export const receiveMessageError = (res: any) =>

View file

@ -7,11 +7,14 @@ import {
FormMessage,
} from '@/components/ui/form';
import { Radio } from '@/components/ui/radio';
import { Spin } from '@/components/ui/spin';
import { Switch } from '@/components/ui/switch';
import { useTranslate } from '@/hooks/common-hooks';
import { cn } from '@/lib/utils';
import { useMemo, useState } from 'react';
import { useFormContext } from 'react-hook-form';
import {
useHandleKbEmbedding,
useHasParsedDocument,
useSelectChunkMethodList,
useSelectEmbeddingModelOptions,
@ -62,11 +65,17 @@ export function ChunkMethodItem(props: IProps) {
/>
);
}
export function EmbeddingModelItem({ line = 1, isEdit = true }: IProps) {
export function EmbeddingModelItem({ line = 1, isEdit }: IProps) {
const { t } = useTranslate('knowledgeConfiguration');
const form = useFormContext();
const embeddingModelOptions = useSelectEmbeddingModelOptions();
const { handleChange } = useHandleKbEmbedding();
const disabled = useHasParsedDocument(isEdit);
const oldValue = useMemo(() => {
const embdStr = form.getValues('embd_id');
return embdStr || '';
}, [form]);
const [loading, setLoading] = useState(false);
return (
<>
<FormField
@ -93,14 +102,33 @@ export function EmbeddingModelItem({ line = 1, isEdit = true }: IProps) {
className={cn('text-muted-foreground', { 'w-3/4': line === 1 })}
>
<FormControl>
<SelectWithSearch
onChange={field.onChange}
value={field.value}
options={embeddingModelOptions}
disabled={isEdit ? disabled : false}
placeholder={t('embeddingModelPlaceholder')}
triggerClassName="!bg-bg-base"
/>
<Spin
spinning={loading}
className={cn(' rounded-lg after:bg-bg-base', {
'opacity-20': loading,
})}
>
<SelectWithSearch
onChange={async (value) => {
field.onChange(value);
if (isEdit && disabled) {
setLoading(true);
const res = await handleChange({
embed_id: value,
callback: field.onChange,
});
if (res.code !== 0) {
field.onChange(oldValue);
}
setLoading(false);
}
}}
value={field.value}
options={embeddingModelOptions}
placeholder={t('embeddingModelPlaceholder')}
triggerClassName="!bg-bg-base"
/>
</Spin>
</FormControl>
</div>
</div>

View file

@ -88,7 +88,7 @@ export function GeneralForm() {
}}
/>
<PermissionFormField></PermissionFormField>
<EmbeddingModelItem></EmbeddingModelItem>
<EmbeddingModelItem isEdit={true}></EmbeddingModelItem>
<PageRankFormField></PageRankFormField>
<TagItems></TagItems>

View file

@ -4,10 +4,12 @@ import { useSetModalState } from '@/hooks/common-hooks';
import { useSelectLlmOptionsByModelType } from '@/hooks/llm-hooks';
import { useFetchKnowledgeBaseConfiguration } from '@/hooks/use-knowledge-request';
import { useSelectParserList } from '@/hooks/user-setting-hooks';
import kbService from '@/services/knowledge-service';
import { useIsFetching } from '@tanstack/react-query';
import { pick } from 'lodash';
import { useCallback, useEffect, useState } from 'react';
import { UseFormReturn } from 'react-hook-form';
import { useParams, useSearchParams } from 'umi';
import { z } from 'zod';
import { formSchema } from './form-schema';
@ -98,3 +100,22 @@ export const useRenameKnowledgeTag = () => {
showTagRenameModal: handleShowTagRenameModal,
};
};
export const useHandleKbEmbedding = () => {
const { id } = useParams();
const [searchParams] = useSearchParams();
const knowledgeBaseId = searchParams.get('id') || id;
const handleChange = useCallback(
async ({ embed_id }: { embed_id: string }) => {
const res = await kbService.checkEmbedding({
kb_id: knowledgeBaseId,
embd_id: embed_id,
});
return res.data;
},
[knowledgeBaseId],
);
return {
handleChange,
};
};

View file

@ -47,6 +47,7 @@ const {
traceGraphRag,
runRaptor,
traceRaptor,
check_embedding,
} = api;
const methods = {
@ -214,6 +215,11 @@ const methods = {
url: api.pipelineRerun,
method: 'post',
},
checkEmbedding: {
url: check_embedding,
method: 'post',
},
};
const kbService = registerServer<keyof typeof methods>(methods, request);

View file

@ -49,6 +49,8 @@ export default {
llm_tools: `${api_host}/plugin/llm_tools`,
// knowledge base
check_embedding: `${api_host}/kb/check_embedding`,
kb_list: `${api_host}/kb/list`,
create_kb: `${api_host}/kb/create`,
update_kb: `${api_host}/kb/update`,