ragflow/rag/utils/ob_conn.py
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import json
import logging
import os
import re
import time
from typing import Any, Optional
from elasticsearch_dsl import Q, Search
from pydantic import BaseModel
from pymysql.converters import escape_string
from pyobvector import ObVecClient, FtsIndexParam, FtsParser, ARRAY, VECTOR
from pyobvector.client.hybrid_search import HybridSearch
from pyobvector.util import ObVersion
from sqlalchemy import text, Column, String, Integer, JSON, Double, Row, Table
from sqlalchemy.dialects.mysql import LONGTEXT, TEXT
from sqlalchemy.sql.type_api import TypeEngine
from common import settings
from common.constants import PAGERANK_FLD, TAG_FLD
from common.decorator import singleton
from common.float_utils import get_float
from rag.nlp import rag_tokenizer
from rag.utils.doc_store_conn import DocStoreConnection, MatchExpr, OrderByExpr, FusionExpr, MatchTextExpr, \
MatchDenseExpr
ATTEMPT_TIME = 2
OB_QUERY_TIMEOUT = int(os.environ.get("OB_QUERY_TIMEOUT", "100_000_000"))
logger = logging.getLogger('ragflow.ob_conn')
column_order_id = Column("_order_id", Integer, nullable=True, comment="chunk order id for maintaining sequence")
column_group_id = Column("group_id", String(256), nullable=True, comment="group id for external retrieval")
column_definitions: list[Column] = [
Column("id", String(256), primary_key=True, comment="chunk id"),
Column("kb_id", String(256), nullable=False, index=True, comment="knowledge base id"),
Column("doc_id", String(256), nullable=True, index=True, comment="document id"),
Column("docnm_kwd", String(256), nullable=True, comment="document name"),
Column("doc_type_kwd", String(256), nullable=True, comment="document type"),
Column("title_tks", String(256), nullable=True, comment="title tokens"),
Column("title_sm_tks", String(256), nullable=True, comment="fine-grained (small) title tokens"),
Column("content_with_weight", LONGTEXT, nullable=True, comment="the original content"),
Column("content_ltks", LONGTEXT, nullable=True, comment="long text tokens derived from content_with_weight"),
Column("content_sm_ltks", LONGTEXT, nullable=True, comment="fine-grained (small) tokens derived from content_ltks"),
Column("pagerank_fea", Integer, nullable=True, comment="page rank priority, usually set in kb level"),
Column("important_kwd", ARRAY(String(256)), nullable=True, comment="keywords"),
Column("important_tks", TEXT, nullable=True, comment="keyword tokens"),
Column("question_kwd", ARRAY(String(1024)), nullable=True, comment="questions"),
Column("question_tks", TEXT, nullable=True, comment="question tokens"),
Column("tag_kwd", ARRAY(String(256)), nullable=True, comment="tags"),
Column("tag_feas", JSON, nullable=True,
comment="tag features used for 'rank_feature', format: [tag -> relevance score]"),
Column("available_int", Integer, nullable=False, index=True, server_default="1",
comment="status of availability, 0 for unavailable, 1 for available"),
Column("create_time", String(19), nullable=True, comment="creation time in YYYY-MM-DD HH:MM:SS format"),
Column("create_timestamp_flt", Double, nullable=True, comment="creation timestamp in float format"),
Column("img_id", String(128), nullable=True, comment="image id"),
Column("position_int", ARRAY(ARRAY(Integer)), nullable=True, comment="position"),
Column("page_num_int", ARRAY(Integer), nullable=True, comment="page number"),
Column("top_int", ARRAY(Integer), nullable=True, comment="rank from the top"),
Column("knowledge_graph_kwd", String(256), nullable=True, index=True, comment="knowledge graph chunk type"),
Column("source_id", ARRAY(String(256)), nullable=True, comment="source document id"),
Column("entity_kwd", String(256), nullable=True, comment="entity name"),
Column("entity_type_kwd", String(256), nullable=True, index=True, comment="entity type"),
Column("from_entity_kwd", String(256), nullable=True, comment="the source entity of this edge"),
Column("to_entity_kwd", String(256), nullable=True, comment="the target entity of this edge"),
Column("weight_int", Integer, nullable=True, comment="the weight of this edge"),
Column("weight_flt", Double, nullable=True, comment="the weight of community report"),
Column("entities_kwd", ARRAY(String(256)), nullable=True, comment="node ids of entities"),
Column("rank_flt", Double, nullable=True, comment="rank of this entity"),
Column("removed_kwd", String(256), nullable=True, index=True, server_default="'N'",
comment="whether it has been deleted"),
Column("metadata", JSON, nullable=True, comment="metadata for this chunk"),
Column("extra", JSON, nullable=True, comment="extra information of non-general chunk"),
column_order_id,
column_group_id,
]
column_names: list[str] = [col.name for col in column_definitions]
column_types: dict[str, TypeEngine] = {col.name: col.type for col in column_definitions}
array_columns: list[str] = [col.name for col in column_definitions if isinstance(col.type, ARRAY)]
vector_column_pattern = re.compile(r"q_(?P<vector_size>\d+)_vec")
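# Illustrative chunk document (hypothetical values) covering a few of the columns above, plus a
# dynamically added vector column matching vector_column_pattern:
#   {"id": "c1", "kb_id": "kb1", "doc_id": "d1", "docnm_kwd": "guide.pdf",
#    "content_with_weight": "OceanBase is ...", "important_kwd": ["oceanbase"],
#    "available_int": 1, "q_768_vec": [0.1, 0.2, ...]}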
index_columns: list[str] = [
"kb_id",
"doc_id",
"available_int",
"knowledge_graph_kwd",
"entity_type_kwd",
"removed_kwd",
]
fulltext_search_columns: list[str] = [
"docnm_kwd",
"content_with_weight",
"title_tks",
"title_sm_tks",
"important_tks",
"question_tks",
"content_ltks",
"content_sm_ltks"
]
fts_columns_origin: list[str] = [
"docnm_kwd^10",
"content_with_weight",
"important_tks^20",
"question_tks^20",
]
fts_columns_tks: list[str] = [
"title_tks^10",
"title_sm_tks^5",
"important_tks^20",
"question_tks^20",
"content_ltks^2",
"content_sm_ltks",
]
index_name_template = "ix_%s_%s"
fulltext_index_name_template = "fts_idx_%s"
# MATCH AGAINST: https://www.oceanbase.com/docs/common-oceanbase-database-cn-1000000002017607
fulltext_search_template = "MATCH (%s) AGAINST ('%s' IN NATURAL LANGUAGE MODE)"
# cosine_distance: https://www.oceanbase.com/docs/common-oceanbase-database-cn-1000000002012938
vector_search_template = "cosine_distance(%s, %s)"
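# Illustrative expansions of the two templates above (column names and values are hypothetical):
#   fulltext_search_template % ("content_ltks", "machine learning")
#     -> MATCH (content_ltks) AGAINST ('machine learning' IN NATURAL LANGUAGE MODE)
#   vector_search_template % ("q_768_vec", "[0.1, 0.2, 0.3]")
#     -> cosine_distance(q_768_vec, [0.1, 0.2, 0.3])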
class SearchResult(BaseModel):
total: int
chunks: list[dict]
def get_column_value(column_name: str, value: Any) -> Any:
if column_name in column_types:
column_type = column_types[column_name]
if isinstance(column_type, String):
return str(value)
elif isinstance(column_type, Integer):
return int(value)
elif isinstance(column_type, Double):
return float(value)
elif isinstance(column_type, ARRAY) or isinstance(column_type, JSON):
if isinstance(value, str):
try:
return json.loads(value)
except json.JSONDecodeError:
# if JSON parsing fails, return the original string
return value
else:
return value
else:
raise ValueError(f"Unsupported column type for column '{column_name}': {column_type}")
elif vector_column_pattern.match(column_name):
if isinstance(value, str):
try:
return json.loads(value)
except json.JSONDecodeError:
# if JSON parsing fails, return the original string
return value
else:
return value
elif column_name == "_score":
return float(value)
else:
raise ValueError(f"Unknown column '{column_name}' with value '{value}'.")
def get_default_value(column_name: str) -> Any:
if column_name == "available_int":
return 1
elif column_name == "removed_kwd":
return "N"
elif column_name == "_order_id":
return 0
else:
return None
def get_value_str(value: Any) -> str:
if isinstance(value, str):
# additionally clean the string so it cannot break JSON parsing
cleaned_str = value.replace('\\', '\\\\')  # escape backslashes
cleaned_str = cleaned_str.replace('\n', '\\n')  # escape newlines
cleaned_str = cleaned_str.replace('\r', '\\r')  # escape carriage returns
cleaned_str = cleaned_str.replace('\t', '\\t')  # escape tabs
return f"'{escape_string(cleaned_str)}'"
elif isinstance(value, bool):
return "true" if value else "false"
elif value is None:
return "NULL"
elif isinstance(value, (list, dict)):
# make sure special characters in the JSON string are properly escaped
json_str = json.dumps(value, ensure_ascii=False)
return f"'{escape_string(json_str)}'"
else:
return str(value)
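# Illustrative outputs of get_value_str (for example):
#   get_value_str(None)        -> "NULL"
#   get_value_str(True)        -> "true"
#   get_value_str(3.14)        -> "3.14"
#   get_value_str(["a", "b"])  -> "'[\"a\", \"b\"]'"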
def get_metadata_filter_expression(metadata_filtering_conditions: dict) -> str:
"""
Convert metadata filtering conditions to MySQL JSON path expression.
Args:
metadata_filtering_conditions: dict with 'conditions' and 'logical_operator' keys
Returns:
MySQL JSON path expression string
"""
if not metadata_filtering_conditions:
return ""
conditions = metadata_filtering_conditions.get("conditions", [])
logical_operator = metadata_filtering_conditions.get("logical_operator", "and").upper()
if not conditions:
return ""
if logical_operator not in ["AND", "OR"]:
raise ValueError(f"Unsupported logical operator: {logical_operator}. Only 'and' and 'or' are supported.")
metadata_filters = []
for condition in conditions:
name = condition.get("name")
comparison_operator = condition.get("comparison_operator")
value = condition.get("value")
if not all([name, comparison_operator]):
continue
expr = f"JSON_EXTRACT(metadata, '$.{name}')"
value_str = get_value_str(value) if value is not None else ""
# Convert comparison operator to MySQL JSON path syntax
if comparison_operator == "is":
# JSON_EXTRACT(metadata, '$.field_name') = 'value'
metadata_filters.append(f"{expr} = {value_str}")
elif comparison_operator == "is not":
metadata_filters.append(f"{expr} != {value_str}")
elif comparison_operator == "contains":
metadata_filters.append(f"JSON_CONTAINS({expr}, {value_str})")
elif comparison_operator == "not contains":
metadata_filters.append(f"NOT JSON_CONTAINS({expr}, {value_str})")
elif comparison_operator == "start with":
metadata_filters.append(f"{expr} LIKE CONCAT({value_str}, '%')")
elif comparison_operator == "end with":
metadata_filters.append(f"{expr} LIKE CONCAT('%', {value_str})")
elif comparison_operator == "empty":
metadata_filters.append(f"({expr} IS NULL OR {expr} = '' OR {expr} = '[]' OR {expr} = '{{}}')")
elif comparison_operator == "not empty":
metadata_filters.append(f"({expr} IS NOT NULL AND {expr} != '' AND {expr} != '[]' AND {expr} != '{{}}')")
# Number operators
elif comparison_operator == "=":
metadata_filters.append(f"CAST({expr} AS DECIMAL(20,10)) = {value_str}")
elif comparison_operator == "":
metadata_filters.append(f"CAST({expr} AS DECIMAL(20,10)) != {value_str}")
elif comparison_operator == ">":
metadata_filters.append(f"CAST({expr} AS DECIMAL(20,10)) > {value_str}")
elif comparison_operator == "<":
metadata_filters.append(f"CAST({expr} AS DECIMAL(20,10)) < {value_str}")
elif comparison_operator == "":
metadata_filters.append(f"CAST({expr} AS DECIMAL(20,10)) >= {value_str}")
elif comparison_operator == "":
metadata_filters.append(f"CAST({expr} AS DECIMAL(20,10)) <= {value_str}")
# Time operators
elif comparison_operator == "before":
metadata_filters.append(f"CAST({expr} AS DATETIME) < {value_str}")
elif comparison_operator == "after":
metadata_filters.append(f"CAST({expr} AS DATETIME) > {value_str}")
else:
logger.warning(f"Unsupported comparison operator: {comparison_operator}")
continue
if not metadata_filters:
return ""
return f"({f' {logical_operator} '.join(metadata_filters)})"
def get_filters(condition: dict) -> list[str]:
filters: list[str] = []
for k, v in condition.items():
if not v:
continue
if k == "exists":
filters.append(f"{v} IS NOT NULL")
elif k == "must_not" and isinstance(v, dict) and "exists" in v:
filters.append(f"{v.get('exists')} IS NULL")
elif k == "metadata_filtering_conditions":
# Handle metadata filtering conditions
metadata_filter = get_metadata_filter_expression(v)
if metadata_filter:
filters.append(metadata_filter)
elif k in array_columns:
if isinstance(v, list):
array_filters = []
for vv in v:
array_filters.append(f"array_contains({k}, {get_value_str(vv)})")
array_filter = " OR ".join(array_filters)
filters.append(f"({array_filter})")
else:
filters.append(f"array_contains({k}, {get_value_str(v)})")
elif isinstance(v, list):
values: list[str] = []
for item in v:
values.append(get_value_str(item))
value = ", ".join(values)
filters.append(f"{k} IN ({value})")
else:
filters.append(f"{k} = {get_value_str(v)}")
return filters
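# Illustrative example (hypothetical ids):
#   get_filters({"kb_id": ["kb1", "kb2"], "doc_id": "d1", "tag_kwd": "news"})
# returns roughly:
#   ["kb_id IN ('kb1', 'kb2')", "doc_id = 'd1'", "array_contains(tag_kwd, 'news')"]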
def _try_with_lock(lock_name: str, process_func, check_func, timeout: int = None):
if not timeout:
timeout = int(os.environ.get("OB_DDL_TIMEOUT", "60"))
if not check_func():
from rag.utils.redis_conn import RedisDistributedLock
lock = RedisDistributedLock(lock_name)
if lock.acquire():
logger.info(f"acquired lock success: {lock_name}, start processing.")
try:
process_func()
return
finally:
lock.release()
if not check_func():
logger.info(f"Waiting for process complete for {lock_name} on other task executors.")
time.sleep(1)
count = 1
while count < timeout and not check_func():
count += 1
time.sleep(1)
if count >= timeout and not check_func():
raise Exception(f"Timeout to wait for process complete for {lock_name}.")
@singleton
class OBConnection(DocStoreConnection):
def __init__(self):
scheme: str = settings.OB.get("scheme")
ob_config = settings.OB.get("config", {})
if scheme and scheme.lower() == "mysql":
mysql_config = settings.get_base_config("mysql", {})
logger.info("Use MySQL scheme to create OceanBase connection.")
host = mysql_config.get("host", "localhost")
port = mysql_config.get("port", 2881)
self.username = mysql_config.get("user", "root@test")
self.password = mysql_config.get("password", "infini_rag_flow")
else:
logger.info("Use customized config to create OceanBase connection.")
host = ob_config.get("host", "localhost")
port = ob_config.get("port", 2881)
self.username = ob_config.get("user", "root@test")
self.password = ob_config.get("password", "infini_rag_flow")
self.db_name = ob_config.get("db_name", "test")
self.uri = f"{host}:{port}"
logger.info(f"Use OceanBase '{self.uri}' as the doc engine.")
self.client = None
for _ in range(ATTEMPT_TIME):
try:
self.client = ObVecClient(
uri=self.uri,
user=self.username,
password=self.password,
db_name=self.db_name,
pool_pre_ping=True,
pool_recycle=3600,
)
break
except Exception as e:
logger.warning(f"{str(e)}. Waiting OceanBase {self.uri} to be healthy.")
time.sleep(5)
if self.client is None:
msg = f"OceanBase {self.uri} connection failed after {ATTEMPT_TIME} attempts."
logger.error(msg)
raise Exception(msg)
self._load_env_vars()
self._check_ob_version()
self._try_to_update_ob_query_timeout()
logger.info(f"OceanBase {self.uri} is healthy.")
def _check_ob_version(self):
try:
res = self.client.perform_raw_text_sql("SELECT OB_VERSION() FROM DUAL").fetchone()
version_str = res[0] if res else None
logger.info(f"OceanBase {self.uri} version is {version_str}")
except Exception as e:
raise Exception(f"Failed to get OceanBase version from {self.uri}, error: {str(e)}")
if not version_str:
raise Exception(f"Failed to get OceanBase version from {self.uri}.")
ob_version = ObVersion.from_db_version_string(version_str)
if ob_version < ObVersion.from_db_version_nums(4, 3, 5, 1):
raise Exception(
f"The version of OceanBase needs to be higher than or equal to 4.3.5.1, current version is {version_str}"
)
self.es = None
if not ob_version < ObVersion.from_db_version_nums(4, 4, 1, 0) and self.enable_hybrid_search:
self.es = HybridSearch(
uri=self.uri,
user=self.username,
password=self.password,
db_name=self.db_name,
pool_pre_ping=True,
pool_recycle=3600,
)
logger.info("OceanBase Hybrid Search feature is enabled")
def _try_to_update_ob_query_timeout(self):
try:
val = self._get_variable_value("ob_query_timeout")
if val and int(val) >= OB_QUERY_TIMEOUT:
return
except Exception as e:
logger.warning("Failed to get 'ob_query_timeout' variable: %s", str(e))
try:
self.client.perform_raw_text_sql(f"SET GLOBAL ob_query_timeout={OB_QUERY_TIMEOUT}")
logger.info("Set GLOBAL variable 'ob_query_timeout' to %d.", OB_QUERY_TIMEOUT)
# refresh connection pool to ensure 'ob_query_timeout' has taken effect
self.client.engine.dispose()
if self.es is not None:
self.es.engine.dispose()
logger.info("Disposed all connections in engine pool to refresh connection pool")
except Exception as e:
logger.warning(f"Failed to set 'ob_query_timeout' variable: {str(e)}")
def _load_env_vars(self):
def is_true(var: str, default: str) -> bool:
return os.getenv(var, default).lower() in ['true', '1', 'yes', 'y']
self.enable_fulltext_search = is_true('ENABLE_FULLTEXT_SEARCH', 'true')
self.use_fulltext_hint = is_true('USE_FULLTEXT_HINT', 'true')
self.search_original_content = is_true("SEARCH_ORIGINAL_CONTENT", 'true')
self.enable_hybrid_search = is_true('ENABLE_HYBRID_SEARCH', 'false')
"""
Database operations
"""
def dbType(self) -> str:
return "oceanbase"
def health(self) -> dict:
return {
"uri": self.uri,
"version_comment": self._get_variable_value("version_comment")
}
def _get_variable_value(self, var_name: str) -> Any:
rows = self.client.perform_raw_text_sql(f"SHOW VARIABLES LIKE '{var_name}'")
for row in rows:
return row[1]
raise Exception(f"Variable '{var_name}' not found.")
"""
Table operations
"""
def createIdx(self, indexName: str, knowledgebaseId: str, vectorSize: int):
vector_field_name = f"q_{vectorSize}_vec"
vector_index_name = f"{vector_field_name}_idx"
try:
_try_with_lock(
lock_name=f"ob_create_table_{indexName}",
check_func=lambda: self.client.check_table_exists(indexName),
process_func=lambda: self._create_table(indexName),
)
for column_name in index_columns:
_try_with_lock(
lock_name=f"ob_add_idx_{indexName}_{column_name}",
check_func=lambda: self._index_exists(indexName, index_name_template % (indexName, column_name)),
process_func=lambda: self._add_index(indexName, column_name),
)
fts_columns = fts_columns_origin if self.search_original_content else fts_columns_tks
for fts_column in fts_columns:
column_name = fts_column.split("^")[0]
_try_with_lock(
lock_name=f"ob_add_fulltext_idx_{indexName}_{column_name}",
check_func=lambda: self._index_exists(indexName, fulltext_index_name_template % column_name),
process_func=lambda: self._add_fulltext_index(indexName, column_name),
)
_try_with_lock(
lock_name=f"ob_add_vector_column_{indexName}_{vector_field_name}",
check_func=lambda: self._column_exist(indexName, vector_field_name),
process_func=lambda: self._add_vector_column(indexName, vectorSize),
)
_try_with_lock(
lock_name=f"ob_add_vector_idx_{indexName}_{vector_field_name}",
check_func=lambda: self._index_exists(indexName, vector_index_name),
process_func=lambda: self._add_vector_index(indexName, vector_field_name),
)
# new columns migration
for column in [column_order_id, column_group_id]:
_try_with_lock(
lock_name=f"ob_add_{column.name}_{indexName}",
check_func=lambda: self._column_exist(indexName, column.name),
process_func=lambda: self._add_column(indexName, column),
)
except Exception as e:
raise Exception(f"OBConnection.createIndex error: {str(e)}")
finally:
# always refresh metadata to make sure it contains the latest table structure
self.client.refresh_metadata([indexName])
def deleteIdx(self, indexName: str, knowledgebaseId: str):
if len(knowledgebaseId) > 0:
# The index needs to stay alive after any kb deletion since all kbs under this tenant share one index.
return
try:
if self.client.check_table_exists(table_name=indexName):
self.client.drop_table_if_exist(indexName)
logger.info(f"Dropped table '{indexName}'.")
except Exception as e:
raise Exception(f"OBConnection.deleteIndex error: {str(e)}")
def indexExist(self, indexName: str, knowledgebaseId: str = None) -> bool:
try:
if not self.client.check_table_exists(indexName):
return False
for column_name in index_columns:
if not self._index_exists(indexName, index_name_template % (indexName, column_name)):
return False
fts_columns = fts_columns_origin if self.search_original_content else fts_columns_tks
for fts_column in fts_columns:
column_name = fts_column.split("^")[0]
if not self._index_exists(indexName, fulltext_index_name_template % column_name):
return False
for column in [column_order_id, column_group_id]:
if not self._column_exist(indexName, column.name):
return False
except Exception as e:
raise Exception(f"OBConnection.indexExist error: {str(e)}")
return True
def _get_count(self, table_name: str, filter_list: list[str] = None) -> int:
where_clause = "WHERE " + " AND ".join(filter_list) if filter_list else ""
(count,) = self.client.perform_raw_text_sql(
f"SELECT COUNT(*) FROM {table_name} {where_clause}"
).fetchone()
return count
def _column_exist(self, table_name: str, column_name: str) -> bool:
return self._get_count(
table_name="INFORMATION_SCHEMA.COLUMNS",
filter_list=[
f"TABLE_SCHEMA = '{self.db_name}'",
f"TABLE_NAME = '{table_name}'",
f"COLUMN_NAME = '{column_name}'",
]) > 0
def _index_exists(self, table_name: str, index_name: str) -> bool:
return self._get_count(
table_name="INFORMATION_SCHEMA.STATISTICS",
filter_list=[
f"TABLE_SCHEMA = '{self.db_name}'",
f"TABLE_NAME = '{table_name}'",
f"INDEX_NAME = '{index_name}'",
]) > 0
def _create_table(self, table_name: str):
# remove outdated metadata in case the table was changed externally
if table_name in self.client.metadata_obj.tables:
self.client.metadata_obj.remove(Table(table_name, self.client.metadata_obj))
table_options = {
"mysql_charset": "utf8mb4",
"mysql_collate": "utf8mb4_unicode_ci",
"mysql_organization": "heap",
}
self.client.create_table(
table_name=table_name,
columns=column_definitions,
**table_options,
)
logger.info(f"Created table '{table_name}'.")
def _add_index(self, table_name: str, column_name: str):
index_name = index_name_template % (table_name, column_name)
self.client.create_index(
table_name=table_name,
is_vec_index=False,
index_name=index_name,
column_names=[column_name],
)
logger.info(f"Created index '{index_name}' on table '{table_name}'.")
def _add_fulltext_index(self, table_name: str, column_name: str):
fulltext_index_name = fulltext_index_name_template % column_name
self.client.create_fts_idx_with_fts_index_param(
table_name=table_name,
fts_idx_param=FtsIndexParam(
index_name=fulltext_index_name,
field_names=[column_name],
parser_type=FtsParser.IK,
),
)
logger.info(f"Created full text index '{fulltext_index_name}' on table '{table_name}'.")
def _add_vector_column(self, table_name: str, vector_size: int):
vector_field_name = f"q_{vector_size}_vec"
self.client.add_columns(
table_name=table_name,
columns=[Column(vector_field_name, VECTOR(vector_size), nullable=True)],
)
logger.info(f"Added vector column '{vector_field_name}' to table '{table_name}'.")
def _add_vector_index(self, table_name: str, vector_field_name: str):
vector_index_name = f"{vector_field_name}_idx"
self.client.create_index(
table_name=table_name,
is_vec_index=True,
index_name=vector_index_name,
column_names=[vector_field_name],
vidx_params="distance=cosine, type=hnsw, lib=vsag",
)
logger.info(
f"Created vector index '{vector_index_name}' on table '{table_name}' with column '{vector_field_name}'."
)
def _add_column(self, table_name: str, column: Column):
try:
self.client.add_columns(
table_name=table_name,
columns=[column],
)
logger.info(f"Added column '{column.name}' to table '{table_name}'.")
except Exception as e:
logger.warning(f"Failed to add column '{column.name}' to table '{table_name}': {str(e)}")
"""
CRUD operations
"""
def search(
self,
selectFields: list[str],
highlightFields: list[str],
condition: dict,
matchExprs: list[MatchExpr],
orderBy: OrderByExpr,
offset: int,
limit: int,
indexNames: str | list[str],
knowledgebaseIds: list[str],
aggFields: list[str] = [],
rank_feature: dict | None = None,
**kwargs,
):
if isinstance(indexNames, str):
indexNames = indexNames.split(",")
assert isinstance(indexNames, list) and len(indexNames) > 0
indexNames = list(set(indexNames))
if len(matchExprs) == 3:
if not self.enable_fulltext_search:
# full-text search is disabled, so drop it from the fusion and fall back to vector-only search
matchExprs = [m for m in matchExprs if isinstance(m, MatchDenseExpr)]
else:
for m in matchExprs:
if isinstance(m, FusionExpr):
weights = m.fusion_params["weights"]
vector_similarity_weight = get_float(weights.split(",")[1])
# skip the search if its weight is zero
if vector_similarity_weight <= 0.0:
matchExprs = [m for m in matchExprs if isinstance(m, MatchTextExpr)]
elif vector_similarity_weight >= 1.0:
matchExprs = [m for m in matchExprs if isinstance(m, MatchDenseExpr)]
result: SearchResult = SearchResult(
total=0,
chunks=[],
)
# copied from es_conn.py
if len(matchExprs) == 3 and self.es:
bqry = Q("bool", must=[])
condition["kb_id"] = knowledgebaseIds
for k, v in condition.items():
if k == "available_int":
if v == 0:
bqry.filter.append(Q("range", available_int={"lt": 1}))
else:
bqry.filter.append(
Q("bool", must_not=Q("range", available_int={"lt": 1})))
continue
if not v:
continue
if isinstance(v, list):
bqry.filter.append(Q("terms", **{k: v}))
elif isinstance(v, str) or isinstance(v, int):
bqry.filter.append(Q("term", **{k: v}))
else:
raise Exception(
f"Condition `{str(k)}={str(v)}` value type is {str(type(v))}, expected to be int, str or list.")
s = Search()
vector_similarity_weight = 0.5
for m in matchExprs:
if isinstance(m, FusionExpr) and m.method == "weighted_sum" and "weights" in m.fusion_params:
assert len(matchExprs) == 3 and isinstance(matchExprs[0], MatchTextExpr) and isinstance(
matchExprs[1],
MatchDenseExpr) and isinstance(
matchExprs[2], FusionExpr)
weights = m.fusion_params["weights"]
vector_similarity_weight = get_float(weights.split(",")[1])
for m in matchExprs:
if isinstance(m, MatchTextExpr):
minimum_should_match = m.extra_options.get("minimum_should_match", 0.0)
if isinstance(minimum_should_match, float):
minimum_should_match = str(int(minimum_should_match * 100)) + "%"
bqry.must.append(Q("query_string", fields=fts_columns_tks,
type="best_fields", query=m.matching_text,
minimum_should_match=minimum_should_match,
boost=1))
bqry.boost = 1.0 - vector_similarity_weight
elif isinstance(m, MatchDenseExpr):
assert (bqry is not None)
similarity = 0.0
if "similarity" in m.extra_options:
similarity = m.extra_options["similarity"]
s = s.knn(m.vector_column_name,
m.topn,
m.topn * 2,
query_vector=list(m.embedding_data),
filter=bqry.to_dict(),
similarity=similarity,
)
if bqry and rank_feature:
for fld, sc in rank_feature.items():
if fld != PAGERANK_FLD:
fld = f"{TAG_FLD}.{fld}"
bqry.should.append(Q("rank_feature", field=fld, linear={}, boost=sc))
if bqry:
s = s.query(bqry)
# for field in highlightFields:
# s = s.highlight(field)
if orderBy:
orders = list()
for field, order in orderBy.fields:
order = "asc" if order == 0 else "desc"
if field in ["page_num_int", "top_int"]:
order_info = {"order": order, "unmapped_type": "float",
"mode": "avg", "numeric_type": "double"}
elif field.endswith("_int") or field.endswith("_flt"):
order_info = {"order": order, "unmapped_type": "float"}
else:
order_info = {"order": order, "unmapped_type": "text"}
orders.append({field: order_info})
s = s.sort(*orders)
for fld in aggFields:
s.aggs.bucket(f'aggs_{fld}', 'terms', field=fld, size=1000000)
if limit > 0:
s = s[offset:offset + limit]
q = s.to_dict()
logger.debug(f"OBConnection.hybrid_search {str(indexNames)} query: " + json.dumps(q))
for index_name in indexNames:
start_time = time.time()
res = self.es.search(index=index_name,
body=q,
timeout="600s",
track_total_hits=True,
_source=True)
elapsed_time = time.time() - start_time
logger.info(
f"OBConnection.search table {index_name}, search type: hybrid, elapsed time: {elapsed_time:.3f} seconds,"
f" got count: {len(res)}"
)
for chunk in res:
result.chunks.append(self._es_row_to_entity(chunk))
result.total = result.total + 1
return result
output_fields = selectFields.copy()
if "id" not in output_fields:
output_fields = ["id"] + output_fields
if "_score" in output_fields:
output_fields.remove("_score")
if highlightFields:
for field in highlightFields:
if field not in output_fields:
output_fields.append(field)
fields_expr = ", ".join(output_fields)
condition["kb_id"] = knowledgebaseIds
filters: list[str] = get_filters(condition)
filters_expr = " AND ".join(filters)
fulltext_query: Optional[str] = None
fulltext_topn: Optional[int] = None
fulltext_search_weight: dict[str, float] = {}
fulltext_search_expr: dict[str, str] = {}
fulltext_search_idx_list: list[str] = []
fulltext_search_score_expr: Optional[str] = None
fulltext_search_filter: Optional[str] = None
vector_column_name: Optional[str] = None
vector_data: Optional[list[float]] = None
vector_topn: Optional[int] = None
vector_similarity_threshold: Optional[float] = None
vector_similarity_weight: Optional[float] = None
vector_search_expr: Optional[str] = None
vector_search_score_expr: Optional[str] = None
vector_search_filter: Optional[str] = None
for m in matchExprs:
if isinstance(m, MatchTextExpr):
assert "original_query" in m.extra_options, "'original_query' is missing in extra_options."
fulltext_query = m.extra_options["original_query"]
fulltext_query = escape_string(fulltext_query.strip())
fulltext_topn = m.topn
fts_columns = fts_columns_origin if self.search_original_content else fts_columns_tks
# get fulltext match expression and weight values
for field in fts_columns:
parts = field.split("^")
column_name: str = parts[0]
column_weight: float = float(parts[1]) if (len(parts) > 1 and parts[1]) else 1.0
fulltext_search_weight[column_name] = column_weight
fulltext_search_expr[column_name] = fulltext_search_template % (column_name, fulltext_query)
fulltext_search_idx_list.append(fulltext_index_name_template % column_name)
# normalize the weights so they sum to 1
weight_sum = sum(fulltext_search_weight.values())
for column_name in fulltext_search_weight.keys():
fulltext_search_weight[column_name] = fulltext_search_weight[column_name] / weight_sum
elif isinstance(m, MatchDenseExpr):
assert m.embedding_data_type == "float", f"embedding data type '{m.embedding_data_type}' is not float."
vector_column_name = m.vector_column_name
vector_data = m.embedding_data
vector_topn = m.topn
vector_similarity_threshold = m.extra_options.get("similarity", 0.0)
elif isinstance(m, FusionExpr):
weights = m.fusion_params["weights"]
vector_similarity_weight = get_float(weights.split(",")[1])
if fulltext_query:
fulltext_search_filter = f"({' OR '.join([expr for expr in fulltext_search_expr.values()])})"
fulltext_search_score_expr = f"({' + '.join(f'{expr} * {fulltext_search_weight.get(col, 0)}' for col, expr in fulltext_search_expr.items())})"
if vector_data:
vector_search_expr = vector_search_template % (vector_column_name, vector_data)
# use (1 - cosine_distance) as score, which should be [-1, 1]
# https://www.oceanbase.com/docs/common-oceanbase-database-standalone-1000000003577323
vector_search_score_expr = f"(1 - {vector_search_expr})"
vector_search_filter = f"{vector_search_score_expr} >= {vector_similarity_threshold}"
pagerank_score_expr = f"(CAST(IFNULL({PAGERANK_FLD}, 0) AS DECIMAL(10, 2)) / 100)"
# TODO use tag rank_feature in sorting
# tag_rank_fea = {k: float(v) for k, v in (rank_feature or {}).items() if k != PAGERANK_FLD}
if fulltext_query and vector_data:
search_type = "fusion"
elif fulltext_query:
search_type = "fulltext"
elif vector_data:
search_type = "vector"
elif len(aggFields) > 0:
search_type = "aggregation"
else:
search_type = "filter"
if search_type in ["fusion", "fulltext", "vector"] and "_score" not in output_fields:
output_fields.append("_score")
group_results = kwargs.get("group_results", False)
for index_name in indexNames:
if not self.client.check_table_exists(index_name):
continue
fulltext_search_hint = f"/*+ UNION_MERGE({index_name} {' '.join(fulltext_search_idx_list)}) */" if self.use_fulltext_hint else ""
if search_type == "fusion":
# fusion search, usually for chat
num_candidates = vector_topn + fulltext_topn
if group_results:
count_sql = (
f"WITH fulltext_results AS ("
f" SELECT {fulltext_search_hint} *, {fulltext_search_score_expr} AS relevance"
f" FROM {index_name}"
f" WHERE {filters_expr} AND {fulltext_search_filter}"
f" ORDER BY relevance DESC"
f" LIMIT {num_candidates}"
f"),"
f" scored_results AS ("
f" SELECT *"
f" FROM fulltext_results"
f" WHERE {vector_search_filter}"
f"),"
f" group_results AS ("
f" SELECT *, ROW_NUMBER() OVER (PARTITION BY group_id) as rn"
f" FROM scored_results"
f")"
f" SELECT COUNT(*)"
f" FROM group_results"
f" WHERE rn = 1"
)
else:
count_sql = (
f"WITH fulltext_results AS ("
f" SELECT {fulltext_search_hint} *, {fulltext_search_score_expr} AS relevance"
f" FROM {index_name}"
f" WHERE {filters_expr} AND {fulltext_search_filter}"
f" ORDER BY relevance DESC"
f" LIMIT {num_candidates}"
f")"
f" SELECT COUNT(*) FROM fulltext_results WHERE {vector_search_filter}"
)
logger.debug("OBConnection.search with count sql: %s", count_sql)
start_time = time.time()
res = self.client.perform_raw_text_sql(count_sql)
total_count = res.fetchone()[0] if res else 0
result.total += total_count
elapsed_time = time.time() - start_time
logger.info(
f"OBConnection.search table {index_name}, search type: fusion, step: 1-count, elapsed time: {elapsed_time:.3f} seconds,"
f" vector column: '{vector_column_name}',"
f" query text: '{fulltext_query}',"
f" condition: '{condition}',"
f" vector_similarity_threshold: {vector_similarity_threshold},"
f" got count: {total_count}"
)
if total_count == 0:
continue
score_expr = f"(relevance * {1 - vector_similarity_weight} + {vector_search_score_expr} * {vector_similarity_weight} + {pagerank_score_expr})"
if group_results:
fusion_sql = (
f"WITH fulltext_results AS ("
f" SELECT {fulltext_search_hint} *, {fulltext_search_score_expr} AS relevance"
f" FROM {index_name}"
f" WHERE {filters_expr} AND {fulltext_search_filter}"
f" ORDER BY relevance DESC"
f" LIMIT {num_candidates}"
f"),"
f" scored_results AS ("
f" SELECT *, {score_expr} AS _score"
f" FROM fulltext_results"
f" WHERE {vector_search_filter}"
f"),"
f" group_results AS ("
f" SELECT *, ROW_NUMBER() OVER (PARTITION BY group_id ORDER BY _score DESC) as rn"
f" FROM scored_results"
f")"
f" SELECT {fields_expr}, _score"
f" FROM group_results"
f" WHERE rn = 1"
f" ORDER BY _score DESC"
f" LIMIT {offset}, {limit}"
)
else:
fusion_sql = (
f"WITH fulltext_results AS ("
f" SELECT {fulltext_search_hint} *, {fulltext_search_score_expr} AS relevance"
f" FROM {index_name}"
f" WHERE {filters_expr} AND {fulltext_search_filter}"
f" ORDER BY relevance DESC"
f" LIMIT {num_candidates}"
f")"
f" SELECT {fields_expr}, {score_expr} AS _score"
f" FROM fulltext_results"
f" WHERE {vector_search_filter}"
f" ORDER BY _score DESC"
f" LIMIT {offset}, {limit}"
)
logger.debug("OBConnection.search with fusion sql: %s", fusion_sql)
start_time = time.time()
res = self.client.perform_raw_text_sql(fusion_sql)
rows = res.fetchall()
elapsed_time = time.time() - start_time
logger.info(
f"OBConnection.search table {index_name}, search type: fusion, step: 2-query, elapsed time: {elapsed_time:.3f} seconds,"
f" select fields: '{output_fields}',"
f" vector column: '{vector_column_name}',"
f" query text: '{fulltext_query}',"
f" condition: '{condition}',"
f" vector_similarity_threshold: {vector_similarity_threshold},"
f" vector_similarity_weight: {vector_similarity_weight},"
f" return rows count: {len(rows)}"
)
for row in rows:
result.chunks.append(self._row_to_entity(row, output_fields))
elif search_type == "vector":
# vector search, usually used for graph search
count_sql = f"SELECT COUNT(id) FROM {index_name} WHERE {filters_expr} AND {vector_search_filter}"
logger.debug("OBConnection.search with vector count sql: %s", count_sql)
start_time = time.time()
res = self.client.perform_raw_text_sql(count_sql)
total_count = res.fetchone()[0] if res else 0
result.total += total_count
elapsed_time = time.time() - start_time
logger.info(
f"OBConnection.search table {index_name}, search type: vector, step: 1-count, elapsed time: {elapsed_time:.3f} seconds,"
f" vector column: '{vector_column_name}',"
f" condition: '{condition}',"
f" vector_similarity_threshold: {vector_similarity_threshold},"
f" got count: {total_count}"
)
if total_count == 0:
continue
vector_sql = (
f"SELECT {fields_expr}, {vector_search_score_expr} AS _score"
f" FROM {index_name}"
f" WHERE {filters_expr} AND {vector_search_filter}"
f" ORDER BY {vector_search_expr}"
f" APPROXIMATE LIMIT {limit if limit != 0 else vector_topn}"
)
if offset != 0:
vector_sql += f" OFFSET {offset}"
logger.debug("OBConnection.search with vector sql: %s", vector_sql)
start_time = time.time()
res = self.client.perform_raw_text_sql(vector_sql)
rows = res.fetchall()
elapsed_time = time.time() - start_time
logger.info(
f"OBConnection.search table {index_name}, search type: vector, step: 2-query, elapsed time: {elapsed_time:.3f} seconds,"
f" select fields: '{output_fields}',"
f" vector column: '{vector_column_name}',"
f" condition: '{condition}',"
f" vector_similarity_threshold: {vector_similarity_threshold},"
f" return rows count: {len(rows)}"
)
for row in rows:
result.chunks.append(self._row_to_entity(row, output_fields))
elif search_type == "fulltext":
# fulltext search, usually used to search chunks in one dataset
count_sql = f"SELECT {fulltext_search_hint} COUNT(id) FROM {index_name} WHERE {filters_expr} AND {fulltext_search_filter}"
logger.debug("OBConnection.search with fulltext count sql: %s", count_sql)
start_time = time.time()
res = self.client.perform_raw_text_sql(count_sql)
total_count = res.fetchone()[0] if res else 0
result.total += total_count
elapsed_time = time.time() - start_time
logger.info(
f"OBConnection.search table {index_name}, search type: fulltext, step: 1-count, elapsed time: {elapsed_time:.3f} seconds,"
f" query text: '{fulltext_query}',"
f" condition: '{condition}',"
f" got count: {total_count}"
)
if total_count == 0:
continue
fulltext_sql = (
f"SELECT {fulltext_search_hint} {fields_expr}, {fulltext_search_score_expr} AS _score"
f" FROM {index_name}"
f" WHERE {filters_expr} AND {fulltext_search_filter}"
f" ORDER BY _score DESC"
f" LIMIT {offset}, {limit if limit != 0 else fulltext_topn}"
)
logger.debug("OBConnection.search with fulltext sql: %s", fulltext_sql)
start_time = time.time()
res = self.client.perform_raw_text_sql(fulltext_sql)
rows = res.fetchall()
elapsed_time = time.time() - start_time
logger.info(
f"OBConnection.search table {index_name}, search type: fulltext, step: 2-query, elapsed time: {elapsed_time:.3f} seconds,"
f" select fields: '{output_fields}',"
f" query text: '{fulltext_query}',"
f" condition: '{condition}',"
f" return rows count: {len(rows)}"
)
for row in rows:
result.chunks.append(self._row_to_entity(row, output_fields))
elif search_type == "aggregation":
# aggregation search
assert len(aggFields) == 1, "Only one aggregation field is supported in OceanBase."
agg_field = aggFields[0]
if agg_field in array_columns:
res = self.client.perform_raw_text_sql(
f"SELECT {agg_field} FROM {index_name}"
f" WHERE {agg_field} IS NOT NULL AND {filters_expr}"
)
counts = {}
for row in res:
if row[0]:
if isinstance(row[0], str):
try:
arr = json.loads(row[0])
except json.JSONDecodeError:
logger.warning(f"Failed to parse JSON array: {row[0]}")
continue
else:
arr = row[0]
if isinstance(arr, list):
for v in arr:
if isinstance(v, str) and v.strip():
counts[v] = counts.get(v, 0) + 1
for v, count in counts.items():
result.chunks.append({
"value": v,
"count": count,
})
result.total += len(counts)
else:
res = self.client.perform_raw_text_sql(
f"SELECT {agg_field}, COUNT(*) as count FROM {index_name}"
f" WHERE {agg_field} IS NOT NULL AND {filters_expr}"
f" GROUP BY {agg_field}"
)
for row in res:
result.chunks.append({
"value": row[0],
"count": int(row[1]),
})
result.total += 1
else:
# only filter
orders: list[str] = []
if orderBy:
for field, order in orderBy.fields:
if isinstance(column_types[field], ARRAY):
f = field + "_sort"
fields_expr += f", array_to_string({field}, ',') AS {f}"
field = f
order = "ASC" if order == 0 else "DESC"
orders.append(f"{field} {order}")
count_sql = f"SELECT COUNT(id) FROM {index_name} WHERE {filters_expr}"
logger.debug("OBConnection.search with normal count sql: %s", count_sql)
start_time = time.time()
res = self.client.perform_raw_text_sql(count_sql)
total_count = res.fetchone()[0] if res else 0
result.total += total_count
elapsed_time = time.time() - start_time
logger.info(
f"OBConnection.search table {index_name}, search type: normal, step: 1-count, elapsed time: {elapsed_time:.3f} seconds,"
f" condition: '{condition}',"
f" got count: {total_count}"
)
if total_count == 0:
continue
order_by_expr = ("ORDER BY " + ", ".join(orders)) if len(orders) > 0 else ""
limit_expr = f"LIMIT {offset}, {limit}" if limit != 0 else ""
filter_sql = (
f"SELECT {fields_expr}"
f" FROM {index_name}"
f" WHERE {filters_expr}"
f" {order_by_expr} {limit_expr}"
)
logger.debug("OBConnection.search with normal sql: %s", filter_sql)
start_time = time.time()
res = self.client.perform_raw_text_sql(filter_sql)
rows = res.fetchall()
elapsed_time = time.time() - start_time
logger.info(
f"OBConnection.search table {index_name}, search type: normal, step: 2-query, elapsed time: {elapsed_time:.3f} seconds,"
f" select fields: '{output_fields}',"
f" condition: '{condition}',"
f" return rows count: {len(rows)}"
)
for row in rows:
result.chunks.append(self._row_to_entity(row, output_fields))
return result
def get(self, chunkId: str, indexName: str, knowledgebaseIds: list[str]) -> dict | None:
if not self.client.check_table_exists(indexName):
return None
try:
res = self.client.get(
table_name=indexName,
ids=[chunkId],
)
row = res.fetchone()
if row is None:
raise Exception(f"ChunkId {chunkId} not found in index {indexName}.")
return self._row_to_entity(row, fields=list(res.keys()))
except json.JSONDecodeError as e:
logger.error(f"JSON decode error when getting chunk {chunkId}: {str(e)}")
# if JSON parsing fails, fall back to returning basic chunk info
return {
"id": chunkId,
"error": f"Failed to parse chunk data due to invalid JSON: {str(e)}"
}
except Exception as e:
logger.error(f"Error getting chunk {chunkId}: {str(e)}")
raise
def insert(self, documents: list[dict], indexName: str, knowledgebaseId: str = None) -> list[str]:
if not documents:
return []
docs: list[dict] = []
ids: list[str] = []
for document in documents:
d: dict = {}
for k, v in document.items():
if vector_column_pattern.match(k):
d[k] = v
continue
if k not in column_names:
if "extra" not in d:
d["extra"] = {}
d["extra"][k] = v
continue
if v is None:
d[k] = get_default_value(k)
continue
if k == "kb_id" and isinstance(v, list):
d[k] = v[0]
elif k == "content_with_weight" and isinstance(v, dict):
d[k] = json.dumps(v, ensure_ascii=False)
elif k == "position_int":
d[k] = json.dumps([list(vv) for vv in v], ensure_ascii=False)
elif isinstance(v, list):
# escape characters like '\t' so the JSON dump stays valid, and clean special characters
cleaned_v = []
for vv in v:
if isinstance(vv, str):
# clean special characters that could break JSON parsing
cleaned_str = vv.strip()
# escape characters that could break JSON parsing
cleaned_str = cleaned_str.replace('\\', '\\\\')  # escape backslashes
cleaned_str = cleaned_str.replace('\n', '\\n')  # escape newlines
cleaned_str = cleaned_str.replace('\r', '\\r')  # escape carriage returns
cleaned_str = cleaned_str.replace('\t', '\\t')  # escape tabs
cleaned_v.append(cleaned_str)
else:
cleaned_v.append(vv)
d[k] = json.dumps(cleaned_v, ensure_ascii=False)
else:
d[k] = v
ids.append(d["id"])
# this is to fix https://github.com/sqlalchemy/sqlalchemy/issues/9703
for column_name in column_names:
if column_name not in d:
d[column_name] = get_default_value(column_name)
metadata = d.get("metadata", {})
if metadata is None:
metadata = {}
group_id = metadata.get("_group_id")
title = metadata.get("_title")
if d.get("doc_id"):
if group_id:
d["group_id"] = group_id
else:
d["group_id"] = d["doc_id"]
if title:
d["docnm_kwd"] = title
docs.append(d)
logger.debug("OBConnection.insert chunks: %s", docs)
res = []
try:
self.client.upsert(indexName, docs)
except Exception as e:
logger.error(f"OBConnection.insert error: {str(e)}")
res.append(str(e))
return res
def update(self, condition: dict, newValue: dict, indexName: str, knowledgebaseId: str) -> bool:
if not self.client.check_table_exists(indexName):
return True
condition["kb_id"] = knowledgebaseId
filters = get_filters(condition)
set_values: list[str] = []
for k, v in newValue.items():
if k == "remove":
if isinstance(v, str):
set_values.append(f"{v} = NULL")
else:
assert isinstance(v, dict), f"Expected str or dict for 'remove', got {type(newValue[k])}."
for kk, vv in v.items():
assert kk in array_columns, f"Column '{kk}' is not an array column."
set_values.append(f"{kk} = array_remove({kk}, {get_value_str(vv)})")
elif k == "add":
assert isinstance(v, dict), f"Expected str or dict for 'add', got {type(newValue[k])}."
for kk, vv in v.items():
assert kk in array_columns, f"Column '{kk}' is not an array column."
set_values.append(f"{kk} = array_append({kk}, {get_value_str(vv)})")
elif k == "metadata":
assert isinstance(v, dict), f"Expected dict for 'metadata', got {type(newValue[k])}"
set_values.append(f"{k} = {get_value_str(v)}")
if v and "doc_id" in condition:
group_id = v.get("_group_id")
title = v.get("_title")
if group_id:
set_values.append(f"group_id = {get_value_str(group_id)}")
if title:
set_values.append(f"docnm_kwd = {get_value_str(title)}")
else:
set_values.append(f"{k} = {get_value_str(v)}")
if not set_values:
return True
update_sql = (
f"UPDATE {indexName}"
f" SET {', '.join(set_values)}"
f" WHERE {' AND '.join(filters)}"
)
logger.debug("OBConnection.update sql: %s", update_sql)
try:
self.client.perform_raw_text_sql(update_sql)
return True
except Exception as e:
logger.error(f"OBConnection.update error: {str(e)}")
return False
def delete(self, condition: dict, indexName: str, knowledgebaseId: str) -> int:
if not self.client.check_table_exists(indexName):
return 0
condition["kb_id"] = knowledgebaseId
try:
res = self.client.get(
table_name=indexName,
ids=None,
where_clause=[text(f) for f in get_filters(condition)],
output_column_name=["id"],
)
rows = res.fetchall()
if len(rows) == 0:
return 0
ids = [row[0] for row in rows]
logger.debug(f"OBConnection.delete chunks, filters: {condition}, ids: {ids}")
self.client.delete(
table_name=indexName,
ids=ids,
)
return len(ids)
except Exception as e:
logger.error(f"OBConnection.delete error: {str(e)}")
return 0
@staticmethod
def _row_to_entity(data: Row, fields: list[str]) -> dict:
entity = {}
for i, field in enumerate(fields):
value = data[i]
if value is None:
continue
entity[field] = get_column_value(field, value)
return entity
@staticmethod
def _es_row_to_entity(data: dict) -> dict:
entity = {}
for k, v in data.items():
if v is None:
continue
entity[k] = get_column_value(k, v)
return entity
"""
Helper functions for search result
"""
def getTotal(self, res) -> int:
return res.total
def getChunkIds(self, res) -> list[str]:
return [row["id"] for row in res.chunks]
def getFields(self, res, fields: list[str]) -> dict[str, dict]:
result = {}
for row in res.chunks:
data = {}
for field in fields:
v = row.get(field)
if v is not None:
data[field] = v
result[row["id"]] = data
return result
# copied from query.FulltextQueryer
def is_chinese(self, line):
arr = re.split(r"[ \t]+", line)
if len(arr) <= 3:
return True
e = 0
for t in arr:
if not re.match(r"[a-zA-Z]+$", t):
e += 1
return e * 1.0 / len(arr) >= 0.7
def highlight(self, txt: str, tks: str, question: str, keywords: list[str]) -> Optional[str]:
if not txt or not keywords:
return None
highlighted_txt = txt
if question and not self.is_chinese(question):
highlighted_txt = re.sub(
r"(^|\W)(%s)(\W|$)" % re.escape(question),
r"\1<em>\2</em>\3", highlighted_txt,
flags=re.IGNORECASE | re.MULTILINE,
)
if re.search(r"<em>[^<>]+</em>", highlighted_txt, flags=re.IGNORECASE | re.MULTILINE):
return highlighted_txt
for keyword in keywords:
highlighted_txt = re.sub(
r"(^|\W)(%s)(\W|$)" % re.escape(keyword),
r"\1<em>\2</em>\3", highlighted_txt,
flags=re.IGNORECASE | re.MULTILINE,
)
if len(re.findall(r'</em><em>', highlighted_txt)) > 0 or len(
re.findall(r'</em>\s*<em>', highlighted_txt)) > 0:
return highlighted_txt
else:
return None
if not tks:
tks = rag_tokenizer.tokenize(txt)
tokens = tks.split()
if not tokens:
return None
last_pos = len(txt)
for i in range(len(tokens) - 1, -1, -1):
token = tokens[i]
token_pos = highlighted_txt.rfind(token, 0, last_pos)
if token_pos != -1:
if token in keywords:
highlighted_txt = (
highlighted_txt[:token_pos] +
f'<em>{token}</em>' +
highlighted_txt[token_pos + len(token):]
)
last_pos = token_pos
return re.sub(r'</em><em>', '', highlighted_txt)
def getHighlight(self, res, keywords: list[str], fieldnm: str):
ans = {}
if len(res.chunks) == 0 or len(keywords) == 0:
return ans
for d in res.chunks:
txt = d.get(fieldnm)
if not txt:
continue
tks = d.get("content_ltks") if fieldnm == "content_with_weight" else ""
highlighted_txt = self.highlight(txt, tks, " ".join(keywords), keywords)
if highlighted_txt:
ans[d["id"]] = highlighted_txt
return ans
def getAggregation(self, res, fieldnm: str):
if len(res.chunks) == 0:
return []
counts = {}
result = []
for d in res.chunks:
if "value" in d and "count" in d:
# directly use the aggregation result
result.append((d["value"], d["count"]))
elif fieldnm in d:
# aggregate the values of specific field
v = d[fieldnm]
if isinstance(v, list):
for vv in v:
if isinstance(vv, str) and vv.strip():
counts[vv] = counts.get(vv, 0) + 1
elif isinstance(v, str) and v.strip():
counts[v] = counts.get(v, 0) + 1
if len(counts) > 0:
for k, v in counts.items():
result.append((k, v))
return result
"""
SQL
"""
def sql(self, sql: str, fetch_size: int, format: str):
# TODO: execute the sql generated by text-to-sql
return None
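# Illustrative usage sketch (assumes settings.OB / MySQL credentials are configured; the table
# name "ragflow_tenant_1" and the 768-dimension vector size are hypothetical):
#
#   conn = OBConnection()                              # singleton, one client per process
#   print(conn.health())                               # {'uri': 'host:2881', 'version_comment': ...}
#   conn.createIdx("ragflow_tenant_1", "", vectorSize=768)
#   if conn.indexExist("ragflow_tenant_1"):
#       conn.insert([{"id": "c1", "kb_id": "kb1", "content_with_weight": "hello"}], "ragflow_tenant_1")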