use oracle bind variables to avoid error
This commit is contained in:
parent
33caba3e12
commit
d6589684ef
4 changed files with 193 additions and 146 deletions
|
|
@ -17,6 +17,7 @@ T = TypeVar("T")
|
||||||
class QueryParam:
|
class QueryParam:
|
||||||
mode: Literal["local", "global", "hybrid", "naive"] = "global"
|
mode: Literal["local", "global", "hybrid", "naive"] = "global"
|
||||||
only_need_context: bool = False
|
only_need_context: bool = False
|
||||||
|
only_need_prompt: bool = False
|
||||||
response_type: str = "Multiple Paragraphs"
|
response_type: str = "Multiple Paragraphs"
|
||||||
# Number of top-k items to retrieve; corresponds to entities in "local" mode and relationships in "global" mode.
|
# Number of top-k items to retrieve; corresponds to entities in "local" mode and relationships in "global" mode.
|
||||||
top_k: int = 60
|
top_k: int = 60
|
||||||
|
|
|
||||||
|
|
@ -114,16 +114,17 @@ class OracleDB:
|
||||||
|
|
||||||
logger.info("Finished check all tables in Oracle database")
|
logger.info("Finished check all tables in Oracle database")
|
||||||
|
|
||||||
async def query(self, sql: str, multirows: bool = False) -> Union[dict, None]:
|
async def query(self, sql: str, params: dict = None, multirows: bool = False) -> Union[dict, None]:
|
||||||
async with self.pool.acquire() as connection:
|
async with self.pool.acquire() as connection:
|
||||||
connection.inputtypehandler = self.input_type_handler
|
connection.inputtypehandler = self.input_type_handler
|
||||||
connection.outputtypehandler = self.output_type_handler
|
connection.outputtypehandler = self.output_type_handler
|
||||||
with connection.cursor() as cursor:
|
with connection.cursor() as cursor:
|
||||||
try:
|
try:
|
||||||
await cursor.execute(sql)
|
await cursor.execute(sql, params)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Oracle database error: {e}")
|
logger.error(f"Oracle database error: {e}")
|
||||||
print(sql)
|
print(sql)
|
||||||
|
print(params)
|
||||||
raise
|
raise
|
||||||
columns = [column[0].lower() for column in cursor.description]
|
columns = [column[0].lower() for column in cursor.description]
|
||||||
if multirows:
|
if multirows:
|
||||||
|
|
@ -140,7 +141,7 @@ class OracleDB:
|
||||||
data = None
|
data = None
|
||||||
return data
|
return data
|
||||||
|
|
||||||
async def execute(self, sql: str, data: list = None):
|
async def execute(self, sql: str, data: list | dict = None):
|
||||||
# logger.info("go into OracleDB execute method")
|
# logger.info("go into OracleDB execute method")
|
||||||
try:
|
try:
|
||||||
async with self.pool.acquire() as connection:
|
async with self.pool.acquire() as connection:
|
||||||
|
|
@ -172,11 +173,10 @@ class OracleKVStorage(BaseKVStorage):
|
||||||
|
|
||||||
async def get_by_id(self, id: str) -> Union[dict, None]:
|
async def get_by_id(self, id: str) -> Union[dict, None]:
|
||||||
"""根据 id 获取 doc_full 数据."""
|
"""根据 id 获取 doc_full 数据."""
|
||||||
SQL = SQL_TEMPLATES["get_by_id_" + self.namespace].format(
|
SQL = SQL_TEMPLATES["get_by_id_" + self.namespace]
|
||||||
workspace=self.db.workspace, id=id
|
params = {"workspace":self.db.workspace, "id":id}
|
||||||
)
|
|
||||||
# print("get_by_id:"+SQL)
|
# print("get_by_id:"+SQL)
|
||||||
res = await self.db.query(SQL)
|
res = await self.db.query(SQL,params)
|
||||||
if res:
|
if res:
|
||||||
data = res # {"data":res}
|
data = res # {"data":res}
|
||||||
# print (data)
|
# print (data)
|
||||||
|
|
@ -187,11 +187,11 @@ class OracleKVStorage(BaseKVStorage):
|
||||||
# Query by id
|
# Query by id
|
||||||
async def get_by_ids(self, ids: list[str], fields=None) -> Union[list[dict], None]:
|
async def get_by_ids(self, ids: list[str], fields=None) -> Union[list[dict], None]:
|
||||||
"""根据 id 获取 doc_chunks 数据"""
|
"""根据 id 获取 doc_chunks 数据"""
|
||||||
SQL = SQL_TEMPLATES["get_by_ids_" + self.namespace].format(
|
SQL = SQL_TEMPLATES["get_by_ids_" + self.namespace].format(ids=",".join([f"'{id}'" for id in ids]))
|
||||||
workspace=self.db.workspace, ids=",".join([f"'{id}'" for id in ids])
|
params = {"workspace":self.db.workspace}
|
||||||
)
|
#print("get_by_ids:"+SQL)
|
||||||
# print("get_by_ids:"+SQL)
|
#print(params)
|
||||||
res = await self.db.query(SQL, multirows=True)
|
res = await self.db.query(SQL,params, multirows=True)
|
||||||
if res:
|
if res:
|
||||||
data = res # [{"data":i} for i in res]
|
data = res # [{"data":i} for i in res]
|
||||||
# print(data)
|
# print(data)
|
||||||
|
|
@ -201,12 +201,16 @@ class OracleKVStorage(BaseKVStorage):
|
||||||
|
|
||||||
async def filter_keys(self, keys: list[str]) -> set[str]:
|
async def filter_keys(self, keys: list[str]) -> set[str]:
|
||||||
"""过滤掉重复内容"""
|
"""过滤掉重复内容"""
|
||||||
SQL = SQL_TEMPLATES["filter_keys"].format(
|
SQL = SQL_TEMPLATES["filter_keys"].format(table_name=N_T[self.namespace],
|
||||||
table_name=N_T[self.namespace],
|
ids=",".join([f"'{id}'" for id in keys]))
|
||||||
workspace=self.db.workspace,
|
params = {"workspace":self.db.workspace}
|
||||||
ids=",".join([f"'{k}'" for k in keys]),
|
try:
|
||||||
)
|
await self.db.query(SQL, params)
|
||||||
res = await self.db.query(SQL, multirows=True)
|
except Exception as e:
|
||||||
|
logger.error(f"Oracle database error: {e}")
|
||||||
|
print(SQL)
|
||||||
|
print(params)
|
||||||
|
res = await self.db.query(SQL, params,multirows=True)
|
||||||
data = None
|
data = None
|
||||||
if res:
|
if res:
|
||||||
exist_keys = [key["id"] for key in res]
|
exist_keys = [key["id"] for key in res]
|
||||||
|
|
@ -243,29 +247,31 @@ class OracleKVStorage(BaseKVStorage):
|
||||||
d["__vector__"] = embeddings[i]
|
d["__vector__"] = embeddings[i]
|
||||||
# print(list_data)
|
# print(list_data)
|
||||||
for item in list_data:
|
for item in list_data:
|
||||||
merge_sql = SQL_TEMPLATES["merge_chunk"].format(check_id=item["__id__"])
|
merge_sql = SQL_TEMPLATES["merge_chunk"]
|
||||||
|
data = {"check_id":item["__id__"],
|
||||||
values = [
|
"id":item["__id__"],
|
||||||
item["__id__"],
|
"content":item["content"],
|
||||||
item["content"],
|
"workspace":self.db.workspace,
|
||||||
self.db.workspace,
|
"tokens":item["tokens"],
|
||||||
item["tokens"],
|
"chunk_order_index":item["chunk_order_index"],
|
||||||
item["chunk_order_index"],
|
"full_doc_id":item["full_doc_id"],
|
||||||
item["full_doc_id"],
|
"content_vector":item["__vector__"]
|
||||||
item["__vector__"],
|
}
|
||||||
]
|
|
||||||
# print(merge_sql)
|
# print(merge_sql)
|
||||||
await self.db.execute(merge_sql, values)
|
await self.db.execute(merge_sql, data)
|
||||||
|
|
||||||
if self.namespace == "full_docs":
|
if self.namespace == "full_docs":
|
||||||
for k, v in self._data.items():
|
for k, v in self._data.items():
|
||||||
# values.clear()
|
# values.clear()
|
||||||
merge_sql = SQL_TEMPLATES["merge_doc_full"].format(
|
merge_sql = SQL_TEMPLATES["merge_doc_full"]
|
||||||
check_id=k,
|
data = {
|
||||||
)
|
"check_id":k,
|
||||||
values = [k, self._data[k]["content"], self.db.workspace]
|
"id":k,
|
||||||
|
"content":v["content"],
|
||||||
|
"workspace":self.db.workspace
|
||||||
|
}
|
||||||
# print(merge_sql)
|
# print(merge_sql)
|
||||||
await self.db.execute(merge_sql, values)
|
await self.db.execute(merge_sql, data)
|
||||||
return left_data
|
return left_data
|
||||||
|
|
||||||
async def index_done_callback(self):
|
async def index_done_callback(self):
|
||||||
|
|
@ -295,18 +301,17 @@ class OracleVectorDBStorage(BaseVectorStorage):
|
||||||
# 转换精度
|
# 转换精度
|
||||||
dtype = str(embedding.dtype).upper()
|
dtype = str(embedding.dtype).upper()
|
||||||
dimension = embedding.shape[0]
|
dimension = embedding.shape[0]
|
||||||
embedding_string = ", ".join(map(str, embedding.tolist()))
|
embedding_string = "["+", ".join(map(str, embedding.tolist()))+"]"
|
||||||
|
|
||||||
SQL = SQL_TEMPLATES[self.namespace].format(
|
SQL = SQL_TEMPLATES[self.namespace].format(dimension=dimension, dtype=dtype)
|
||||||
embedding_string=embedding_string,
|
params = {
|
||||||
dimension=dimension,
|
"embedding_string": embedding_string,
|
||||||
dtype=dtype,
|
"workspace": self.db.workspace,
|
||||||
workspace=self.db.workspace,
|
"top_k": top_k,
|
||||||
top_k=top_k,
|
"better_than_threshold": self.cosine_better_than_threshold,
|
||||||
better_than_threshold=self.cosine_better_than_threshold,
|
}
|
||||||
)
|
|
||||||
# print(SQL)
|
# print(SQL)
|
||||||
results = await self.db.query(SQL, multirows=True)
|
results = await self.db.query(SQL,params=params, multirows=True)
|
||||||
# print("vector search result:",results)
|
# print("vector search result:",results)
|
||||||
return results
|
return results
|
||||||
|
|
||||||
|
|
@ -339,22 +344,18 @@ class OracleGraphStorage(BaseGraphStorage):
|
||||||
)
|
)
|
||||||
embeddings = np.concatenate(embeddings_list)
|
embeddings = np.concatenate(embeddings_list)
|
||||||
content_vector = embeddings[0]
|
content_vector = embeddings[0]
|
||||||
merge_sql = SQL_TEMPLATES["merge_node"].format(
|
merge_sql = SQL_TEMPLATES["merge_node"]
|
||||||
workspace=self.db.workspace, name=entity_name, source_chunk_id=source_id
|
data = {
|
||||||
)
|
"workspace":self.db.workspace,
|
||||||
|
"name":entity_name,
|
||||||
|
"entity_type":entity_type,
|
||||||
|
"description":description,
|
||||||
|
"source_chunk_id":source_id,
|
||||||
|
"content":content,
|
||||||
|
"content_vector":content_vector
|
||||||
|
}
|
||||||
# print(merge_sql)
|
# print(merge_sql)
|
||||||
await self.db.execute(
|
await self.db.execute(merge_sql,data)
|
||||||
merge_sql,
|
|
||||||
[
|
|
||||||
self.db.workspace,
|
|
||||||
entity_name,
|
|
||||||
entity_type,
|
|
||||||
description,
|
|
||||||
source_id,
|
|
||||||
content,
|
|
||||||
content_vector,
|
|
||||||
],
|
|
||||||
)
|
|
||||||
# self._graph.add_node(node_id, **node_data)
|
# self._graph.add_node(node_id, **node_data)
|
||||||
|
|
||||||
async def upsert_edge(
|
async def upsert_edge(
|
||||||
|
|
@ -379,27 +380,20 @@ class OracleGraphStorage(BaseGraphStorage):
|
||||||
)
|
)
|
||||||
embeddings = np.concatenate(embeddings_list)
|
embeddings = np.concatenate(embeddings_list)
|
||||||
content_vector = embeddings[0]
|
content_vector = embeddings[0]
|
||||||
merge_sql = SQL_TEMPLATES["merge_edge"].format(
|
merge_sql = SQL_TEMPLATES["merge_edge"]
|
||||||
workspace=self.db.workspace,
|
data = {
|
||||||
source_name=source_name,
|
"workspace":self.db.workspace,
|
||||||
target_name=target_name,
|
"source_name":source_name,
|
||||||
source_chunk_id=source_chunk_id,
|
"target_name":target_name,
|
||||||
)
|
"weight":weight,
|
||||||
|
"keywords":keywords,
|
||||||
|
"description":description,
|
||||||
|
"source_chunk_id":source_chunk_id,
|
||||||
|
"content":content,
|
||||||
|
"content_vector":content_vector
|
||||||
|
}
|
||||||
# print(merge_sql)
|
# print(merge_sql)
|
||||||
await self.db.execute(
|
await self.db.execute(merge_sql,data)
|
||||||
merge_sql,
|
|
||||||
[
|
|
||||||
self.db.workspace,
|
|
||||||
source_name,
|
|
||||||
target_name,
|
|
||||||
weight,
|
|
||||||
keywords,
|
|
||||||
description,
|
|
||||||
source_chunk_id,
|
|
||||||
content,
|
|
||||||
content_vector,
|
|
||||||
],
|
|
||||||
)
|
|
||||||
# self._graph.add_edge(source_node_id, target_node_id, **edge_data)
|
# self._graph.add_edge(source_node_id, target_node_id, **edge_data)
|
||||||
|
|
||||||
async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray, list[str]]:
|
async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray, list[str]]:
|
||||||
|
|
@ -429,12 +423,14 @@ class OracleGraphStorage(BaseGraphStorage):
|
||||||
#################### query method #################
|
#################### query method #################
|
||||||
async def has_node(self, node_id: str) -> bool:
|
async def has_node(self, node_id: str) -> bool:
|
||||||
"""根据节点id检查节点是否存在"""
|
"""根据节点id检查节点是否存在"""
|
||||||
SQL = SQL_TEMPLATES["has_node"].format(
|
SQL = SQL_TEMPLATES["has_node"]
|
||||||
workspace=self.db.workspace, node_id=node_id
|
params = {
|
||||||
)
|
"workspace":self.db.workspace,
|
||||||
|
"node_id":node_id
|
||||||
|
}
|
||||||
# print(SQL)
|
# print(SQL)
|
||||||
# print(self.db.workspace, node_id)
|
# print(self.db.workspace, node_id)
|
||||||
res = await self.db.query(SQL)
|
res = await self.db.query(SQL,params)
|
||||||
if res:
|
if res:
|
||||||
# print("Node exist!",res)
|
# print("Node exist!",res)
|
||||||
return True
|
return True
|
||||||
|
|
@ -444,13 +440,14 @@ class OracleGraphStorage(BaseGraphStorage):
|
||||||
|
|
||||||
async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
|
async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
|
||||||
"""根据源和目标节点id检查边是否存在"""
|
"""根据源和目标节点id检查边是否存在"""
|
||||||
SQL = SQL_TEMPLATES["has_edge"].format(
|
SQL = SQL_TEMPLATES["has_edge"]
|
||||||
workspace=self.db.workspace,
|
params = {
|
||||||
source_node_id=source_node_id,
|
"workspace":self.db.workspace,
|
||||||
target_node_id=target_node_id,
|
"source_node_id":source_node_id,
|
||||||
)
|
"target_node_id":target_node_id
|
||||||
|
}
|
||||||
# print(SQL)
|
# print(SQL)
|
||||||
res = await self.db.query(SQL)
|
res = await self.db.query(SQL,params)
|
||||||
if res:
|
if res:
|
||||||
# print("Edge exist!",res)
|
# print("Edge exist!",res)
|
||||||
return True
|
return True
|
||||||
|
|
@ -460,11 +457,13 @@ class OracleGraphStorage(BaseGraphStorage):
|
||||||
|
|
||||||
async def node_degree(self, node_id: str) -> int:
|
async def node_degree(self, node_id: str) -> int:
|
||||||
"""根据节点id获取节点的度"""
|
"""根据节点id获取节点的度"""
|
||||||
SQL = SQL_TEMPLATES["node_degree"].format(
|
SQL = SQL_TEMPLATES["node_degree"]
|
||||||
workspace=self.db.workspace, node_id=node_id
|
params = {
|
||||||
)
|
"workspace":self.db.workspace,
|
||||||
|
"node_id":node_id
|
||||||
|
}
|
||||||
# print(SQL)
|
# print(SQL)
|
||||||
res = await self.db.query(SQL)
|
res = await self.db.query(SQL,params)
|
||||||
if res:
|
if res:
|
||||||
# print("Node degree",res["degree"])
|
# print("Node degree",res["degree"])
|
||||||
return res["degree"]
|
return res["degree"]
|
||||||
|
|
@ -480,12 +479,14 @@ class OracleGraphStorage(BaseGraphStorage):
|
||||||
|
|
||||||
async def get_node(self, node_id: str) -> Union[dict, None]:
|
async def get_node(self, node_id: str) -> Union[dict, None]:
|
||||||
"""根据节点id获取节点数据"""
|
"""根据节点id获取节点数据"""
|
||||||
SQL = SQL_TEMPLATES["get_node"].format(
|
SQL = SQL_TEMPLATES["get_node"]
|
||||||
workspace=self.db.workspace, node_id=node_id
|
params = {
|
||||||
)
|
"workspace":self.db.workspace,
|
||||||
|
"node_id":node_id
|
||||||
|
}
|
||||||
# print(self.db.workspace, node_id)
|
# print(self.db.workspace, node_id)
|
||||||
# print(SQL)
|
# print(SQL)
|
||||||
res = await self.db.query(SQL)
|
res = await self.db.query(SQL,params)
|
||||||
if res:
|
if res:
|
||||||
# print("Get node!",self.db.workspace, node_id,res)
|
# print("Get node!",self.db.workspace, node_id,res)
|
||||||
return res
|
return res
|
||||||
|
|
@ -497,12 +498,13 @@ class OracleGraphStorage(BaseGraphStorage):
|
||||||
self, source_node_id: str, target_node_id: str
|
self, source_node_id: str, target_node_id: str
|
||||||
) -> Union[dict, None]:
|
) -> Union[dict, None]:
|
||||||
"""根据源和目标节点id获取边"""
|
"""根据源和目标节点id获取边"""
|
||||||
SQL = SQL_TEMPLATES["get_edge"].format(
|
SQL = SQL_TEMPLATES["get_edge"]
|
||||||
workspace=self.db.workspace,
|
params = {
|
||||||
source_node_id=source_node_id,
|
"workspace":self.db.workspace,
|
||||||
target_node_id=target_node_id,
|
"source_node_id":source_node_id,
|
||||||
)
|
"target_node_id":target_node_id
|
||||||
res = await self.db.query(SQL)
|
}
|
||||||
|
res = await self.db.query(SQL,params)
|
||||||
if res:
|
if res:
|
||||||
# print("Get edge!",self.db.workspace, source_node_id, target_node_id,res[0])
|
# print("Get edge!",self.db.workspace, source_node_id, target_node_id,res[0])
|
||||||
return res
|
return res
|
||||||
|
|
@ -513,10 +515,12 @@ class OracleGraphStorage(BaseGraphStorage):
|
||||||
async def get_node_edges(self, source_node_id: str):
|
async def get_node_edges(self, source_node_id: str):
|
||||||
"""根据节点id获取节点的所有边"""
|
"""根据节点id获取节点的所有边"""
|
||||||
if await self.has_node(source_node_id):
|
if await self.has_node(source_node_id):
|
||||||
SQL = SQL_TEMPLATES["get_node_edges"].format(
|
SQL = SQL_TEMPLATES["get_node_edges"]
|
||||||
workspace=self.db.workspace, source_node_id=source_node_id
|
params = {
|
||||||
)
|
"workspace":self.db.workspace,
|
||||||
res = await self.db.query(sql=SQL, multirows=True)
|
"source_node_id":source_node_id
|
||||||
|
}
|
||||||
|
res = await self.db.query(sql=SQL, params=params, multirows=True)
|
||||||
if res:
|
if res:
|
||||||
data = [(i["source_name"], i["target_name"]) for i in res]
|
data = [(i["source_name"], i["target_name"]) for i in res]
|
||||||
# print("Get node edge!",self.db.workspace, source_node_id,data)
|
# print("Get node edge!",self.db.workspace, source_node_id,data)
|
||||||
|
|
@ -524,8 +528,22 @@ class OracleGraphStorage(BaseGraphStorage):
|
||||||
else:
|
else:
|
||||||
# print("Node Edge not exist!",self.db.workspace, source_node_id)
|
# print("Node Edge not exist!",self.db.workspace, source_node_id)
|
||||||
return []
|
return []
|
||||||
|
|
||||||
|
async def get_all_nodes(self, limit: int):
|
||||||
|
"""查询所有节点"""
|
||||||
|
SQL = SQL_TEMPLATES["get_all_nodes"]
|
||||||
|
params = {"workspace":self.db.workspace, "limit":str(limit)}
|
||||||
|
res = await self.db.query(sql=SQL,params=params, multirows=True)
|
||||||
|
if res:
|
||||||
|
return res
|
||||||
|
|
||||||
|
async def get_all_edges(self, limit: int):
|
||||||
|
"""查询所有边"""
|
||||||
|
SQL = SQL_TEMPLATES["get_all_edges"]
|
||||||
|
params = {"workspace":self.db.workspace, "limit":str(limit)}
|
||||||
|
res = await self.db.query(sql=SQL,params=params, multirows=True)
|
||||||
|
if res:
|
||||||
|
return res
|
||||||
N_T = {
|
N_T = {
|
||||||
"full_docs": "LIGHTRAG_DOC_FULL",
|
"full_docs": "LIGHTRAG_DOC_FULL",
|
||||||
"text_chunks": "LIGHTRAG_DOC_CHUNKS",
|
"text_chunks": "LIGHTRAG_DOC_CHUNKS",
|
||||||
|
|
@ -619,82 +637,96 @@ TABLES = {
|
||||||
|
|
||||||
SQL_TEMPLATES = {
|
SQL_TEMPLATES = {
|
||||||
# SQL for KVStorage
|
# SQL for KVStorage
|
||||||
"get_by_id_full_docs": "select ID,NVL(content,'') as content from LIGHTRAG_DOC_FULL where workspace='{workspace}' and ID='{id}'",
|
"get_by_id_full_docs": "select ID,NVL(content,'') as content from LIGHTRAG_DOC_FULL where workspace=:workspace and ID=:id",
|
||||||
"get_by_id_text_chunks": "select ID,TOKENS,NVL(content,'') as content,CHUNK_ORDER_INDEX,FULL_DOC_ID from LIGHTRAG_DOC_CHUNKS where workspace='{workspace}' and ID='{id}'",
|
"get_by_id_text_chunks": "select ID,TOKENS,NVL(content,'') as content,CHUNK_ORDER_INDEX,FULL_DOC_ID from LIGHTRAG_DOC_CHUNKS where workspace=:workspace and ID=:id",
|
||||||
"get_by_ids_full_docs": "select ID,NVL(content,'') as content from LIGHTRAG_DOC_FULL where workspace='{workspace}' and ID in ({ids})",
|
"get_by_ids_full_docs": "select ID,NVL(content,'') as content from LIGHTRAG_DOC_FULL where workspace=:workspace and ID in ({ids})",
|
||||||
"get_by_ids_text_chunks": "select ID,TOKENS,NVL(content,'') as content,CHUNK_ORDER_INDEX,FULL_DOC_ID from LIGHTRAG_DOC_CHUNKS where workspace='{workspace}' and ID in ({ids})",
|
"get_by_ids_text_chunks": "select ID,TOKENS,NVL(content,'') as content,CHUNK_ORDER_INDEX,FULL_DOC_ID from LIGHTRAG_DOC_CHUNKS where workspace=:workspace and ID in ({ids})",
|
||||||
"filter_keys": "select id from {table_name} where workspace='{workspace}' and id in ({ids})",
|
"filter_keys": "select id from {table_name} where workspace=:workspace and id in ({ids})",
|
||||||
"merge_doc_full": """ MERGE INTO LIGHTRAG_DOC_FULL a
|
"merge_doc_full": """ MERGE INTO LIGHTRAG_DOC_FULL a
|
||||||
USING DUAL
|
USING DUAL
|
||||||
ON (a.id = '{check_id}')
|
ON (a.id = :check_id)
|
||||||
WHEN NOT MATCHED THEN
|
WHEN NOT MATCHED THEN
|
||||||
INSERT(id,content,workspace) values(:1,:2,:3)
|
INSERT(id,content,workspace) values(:id,:content,:workspace)
|
||||||
""",
|
""",
|
||||||
"merge_chunk": """MERGE INTO LIGHTRAG_DOC_CHUNKS a
|
"merge_chunk": """MERGE INTO LIGHTRAG_DOC_CHUNKS a
|
||||||
USING DUAL
|
USING DUAL
|
||||||
ON (a.id = '{check_id}')
|
ON (a.id = :check_id)
|
||||||
WHEN NOT MATCHED THEN
|
WHEN NOT MATCHED THEN
|
||||||
INSERT(id,content,workspace,tokens,chunk_order_index,full_doc_id,content_vector)
|
INSERT(id,content,workspace,tokens,chunk_order_index,full_doc_id,content_vector)
|
||||||
values (:1,:2,:3,:4,:5,:6,:7) """,
|
values (:id,:content,:workspace,:tokens,:chunk_order_index,:full_doc_id,:content_vector) """,
|
||||||
# SQL for VectorStorage
|
# SQL for VectorStorage
|
||||||
"entities": """SELECT name as entity_name FROM
|
"entities": """SELECT name as entity_name FROM
|
||||||
(SELECT id,name,VECTOR_DISTANCE(content_vector,vector('[{embedding_string}]',{dimension},{dtype}),COSINE) as distance
|
(SELECT id,name,VECTOR_DISTANCE(content_vector,vector(:embedding_string,{dimension},{dtype}),COSINE) as distance
|
||||||
FROM LIGHTRAG_GRAPH_NODES WHERE workspace='{workspace}')
|
FROM LIGHTRAG_GRAPH_NODES WHERE workspace=:workspace)
|
||||||
WHERE distance>{better_than_threshold} ORDER BY distance ASC FETCH FIRST {top_k} ROWS ONLY""",
|
WHERE distance>:better_than_threshold ORDER BY distance ASC FETCH FIRST :top_k ROWS ONLY""",
|
||||||
"relationships": """SELECT source_name as src_id, target_name as tgt_id FROM
|
"relationships": """SELECT source_name as src_id, target_name as tgt_id FROM
|
||||||
(SELECT id,source_name,target_name,VECTOR_DISTANCE(content_vector,vector('[{embedding_string}]',{dimension},{dtype}),COSINE) as distance
|
(SELECT id,source_name,target_name,VECTOR_DISTANCE(content_vector,vector(:embedding_string,{dimension},{dtype}),COSINE) as distance
|
||||||
FROM LIGHTRAG_GRAPH_EDGES WHERE workspace='{workspace}')
|
FROM LIGHTRAG_GRAPH_EDGES WHERE workspace=:workspace)
|
||||||
WHERE distance>{better_than_threshold} ORDER BY distance ASC FETCH FIRST {top_k} ROWS ONLY""",
|
WHERE distance>:better_than_threshold ORDER BY distance ASC FETCH FIRST :top_k ROWS ONLY""",
|
||||||
"chunks": """SELECT id FROM
|
"chunks": """SELECT id FROM
|
||||||
(SELECT id,VECTOR_DISTANCE(content_vector,vector('[{embedding_string}]',{dimension},{dtype}),COSINE) as distance
|
(SELECT id,VECTOR_DISTANCE(content_vector,vector(:embedding_string,{dimension},{dtype}),COSINE) as distance
|
||||||
FROM LIGHTRAG_DOC_CHUNKS WHERE workspace='{workspace}')
|
FROM LIGHTRAG_DOC_CHUNKS WHERE workspace=:workspace)
|
||||||
WHERE distance>{better_than_threshold} ORDER BY distance ASC FETCH FIRST {top_k} ROWS ONLY""",
|
WHERE distance>:better_than_threshold ORDER BY distance ASC FETCH FIRST :top_k ROWS ONLY""",
|
||||||
# SQL for GraphStorage
|
# SQL for GraphStorage
|
||||||
"has_node": """SELECT * FROM GRAPH_TABLE (lightrag_graph
|
"has_node": """SELECT * FROM GRAPH_TABLE (lightrag_graph
|
||||||
MATCH (a)
|
MATCH (a)
|
||||||
WHERE a.workspace='{workspace}' AND a.name='{node_id}'
|
WHERE a.workspace=:workspace AND a.name=:node_id
|
||||||
COLUMNS (a.name))""",
|
COLUMNS (a.name))""",
|
||||||
"has_edge": """SELECT * FROM GRAPH_TABLE (lightrag_graph
|
"has_edge": """SELECT * FROM GRAPH_TABLE (lightrag_graph
|
||||||
MATCH (a) -[e]-> (b)
|
MATCH (a) -[e]-> (b)
|
||||||
WHERE e.workspace='{workspace}' and a.workspace='{workspace}' and b.workspace='{workspace}'
|
WHERE e.workspace=:workspace and a.workspace=:workspace and b.workspace=:workspace
|
||||||
AND a.name='{source_node_id}' AND b.name='{target_node_id}'
|
AND a.name=:source_node_id AND b.name=:target_node_id
|
||||||
COLUMNS (e.source_name,e.target_name) )""",
|
COLUMNS (e.source_name,e.target_name) )""",
|
||||||
"node_degree": """SELECT count(1) as degree FROM GRAPH_TABLE (lightrag_graph
|
"node_degree": """SELECT count(1) as degree FROM GRAPH_TABLE (lightrag_graph
|
||||||
MATCH (a)-[e]->(b)
|
MATCH (a)-[e]->(b)
|
||||||
WHERE a.workspace='{workspace}' and a.workspace='{workspace}' and b.workspace='{workspace}'
|
WHERE a.workspace=:workspace and a.workspace=:workspace and b.workspace=:workspace
|
||||||
AND a.name='{node_id}' or b.name = '{node_id}'
|
AND a.name=:node_id or b.name = :node_id
|
||||||
COLUMNS (a.name))""",
|
COLUMNS (a.name))""",
|
||||||
"get_node": """SELECT t1.name,t2.entity_type,t2.source_chunk_id as source_id,NVL(t2.description,'') AS description
|
"get_node": """SELECT t1.name,t2.entity_type,t2.source_chunk_id as source_id,NVL(t2.description,'') AS description
|
||||||
FROM GRAPH_TABLE (lightrag_graph
|
FROM GRAPH_TABLE (lightrag_graph
|
||||||
MATCH (a)
|
MATCH (a)
|
||||||
WHERE a.workspace='{workspace}' AND a.name='{node_id}'
|
WHERE a.workspace=:workspace AND a.name=:node_id
|
||||||
COLUMNS (a.name)
|
COLUMNS (a.name)
|
||||||
) t1 JOIN LIGHTRAG_GRAPH_NODES t2 on t1.name=t2.name
|
) t1 JOIN LIGHTRAG_GRAPH_NODES t2 on t1.name=t2.name
|
||||||
WHERE t2.workspace='{workspace}'""",
|
WHERE t2.workspace=:workspace""",
|
||||||
"get_edge": """SELECT t1.source_id,t2.weight,t2.source_chunk_id as source_id,t2.keywords,
|
"get_edge": """SELECT t1.source_id,t2.weight,t2.source_chunk_id as source_id,t2.keywords,
|
||||||
NVL(t2.description,'') AS description,NVL(t2.KEYWORDS,'') AS keywords
|
NVL(t2.description,'') AS description,NVL(t2.KEYWORDS,'') AS keywords
|
||||||
FROM GRAPH_TABLE (lightrag_graph
|
FROM GRAPH_TABLE (lightrag_graph
|
||||||
MATCH (a)-[e]->(b)
|
MATCH (a)-[e]->(b)
|
||||||
WHERE e.workspace='{workspace}' and a.workspace='{workspace}' and b.workspace='{workspace}'
|
WHERE e.workspace=:workspace and a.workspace=:workspace and b.workspace=:workspace
|
||||||
AND a.name='{source_node_id}' and b.name = '{target_node_id}'
|
AND a.name=:source_node_id and b.name = :target_node_id
|
||||||
COLUMNS (e.id,a.name as source_id)
|
COLUMNS (e.id,a.name as source_id)
|
||||||
) t1 JOIN LIGHTRAG_GRAPH_EDGES t2 on t1.id=t2.id""",
|
) t1 JOIN LIGHTRAG_GRAPH_EDGES t2 on t1.id=t2.id""",
|
||||||
"get_node_edges": """SELECT source_name,target_name
|
"get_node_edges": """SELECT source_name,target_name
|
||||||
FROM GRAPH_TABLE (lightrag_graph
|
FROM GRAPH_TABLE (lightrag_graph
|
||||||
MATCH (a)-[e]->(b)
|
MATCH (a)-[e]->(b)
|
||||||
WHERE e.workspace='{workspace}' and a.workspace='{workspace}' and b.workspace='{workspace}'
|
WHERE e.workspace=:workspace and a.workspace=:workspace and b.workspace=:workspace
|
||||||
AND a.name='{source_node_id}'
|
AND a.name=:source_node_id
|
||||||
COLUMNS (a.name as source_name,b.name as target_name))""",
|
COLUMNS (a.name as source_name,b.name as target_name))""",
|
||||||
"merge_node": """MERGE INTO LIGHTRAG_GRAPH_NODES a
|
"merge_node": """MERGE INTO LIGHTRAG_GRAPH_NODES a
|
||||||
USING DUAL
|
USING DUAL
|
||||||
ON (a.workspace = '{workspace}' and a.name='{name}' and a.source_chunk_id='{source_chunk_id}')
|
ON (a.workspace = :workspace and a.name=:name and a.source_chunk_id=:source_chunk_id)
|
||||||
WHEN NOT MATCHED THEN
|
WHEN NOT MATCHED THEN
|
||||||
INSERT(workspace,name,entity_type,description,source_chunk_id,content,content_vector)
|
INSERT(workspace,name,entity_type,description,source_chunk_id,content,content_vector)
|
||||||
values (:1,:2,:3,:4,:5,:6,:7) """,
|
values (:workspace,:name,:entity_type,:description,:source_chunk_id,:content,:content_vector) """,
|
||||||
"merge_edge": """MERGE INTO LIGHTRAG_GRAPH_EDGES a
|
"merge_edge": """MERGE INTO LIGHTRAG_GRAPH_EDGES a
|
||||||
USING DUAL
|
USING DUAL
|
||||||
ON (a.workspace = '{workspace}' and a.source_name='{source_name}' and a.target_name='{target_name}' and a.source_chunk_id='{source_chunk_id}')
|
ON (a.workspace = :workspace and a.source_name=:source_name and a.target_name=:target_name and a.source_chunk_id=:source_chunk_id)
|
||||||
WHEN NOT MATCHED THEN
|
WHEN NOT MATCHED THEN
|
||||||
INSERT(workspace,source_name,target_name,weight,keywords,description,source_chunk_id,content,content_vector)
|
INSERT(workspace,source_name,target_name,weight,keywords,description,source_chunk_id,content,content_vector)
|
||||||
values (:1,:2,:3,:4,:5,:6,:7,:8,:9) """,
|
values (:workspace,:source_name,:target_name,:weight,:keywords,:description,:source_chunk_id,:content,:content_vector) """,
|
||||||
|
"get_all_nodes":"""SELECT t1.name as id,t1.entity_type as label,t1.DESCRIPTION,t2.content
|
||||||
|
FROM LIGHTRAG_GRAPH_NODES t1
|
||||||
|
LEFT JOIN LIGHTRAG_DOC_CHUNKS t2 on t1.source_chunk_id=t2.id
|
||||||
|
WHERE t1.workspace=:workspace
|
||||||
|
order by t1.CREATETIME DESC
|
||||||
|
fetch first :limit rows only
|
||||||
|
""",
|
||||||
|
"get_all_edges":"""SELECT t1.id,t1.keywords as label,t1.keywords, t1.source_name as source, t1.target_name as target,
|
||||||
|
t1.weight,t1.DESCRIPTION,t2.content
|
||||||
|
FROM LIGHTRAG_GRAPH_EDGES t1
|
||||||
|
LEFT JOIN LIGHTRAG_DOC_CHUNKS t2 on t1.source_chunk_id=t2.id
|
||||||
|
WHERE t1.workspace=:workspace
|
||||||
|
order by t1.CREATETIME DESC
|
||||||
|
fetch first :limit rows only"""
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -405,12 +405,13 @@ async def local_query(
|
||||||
kw_prompt = kw_prompt_temp.format(query=query)
|
kw_prompt = kw_prompt_temp.format(query=query)
|
||||||
result = await use_model_func(kw_prompt)
|
result = await use_model_func(kw_prompt)
|
||||||
json_text = locate_json_string_body_from_string(result)
|
json_text = locate_json_string_body_from_string(result)
|
||||||
|
logger.debug("local_query json_text:", json_text)
|
||||||
try:
|
try:
|
||||||
keywords_data = json.loads(json_text)
|
keywords_data = json.loads(json_text)
|
||||||
keywords = keywords_data.get("low_level_keywords", [])
|
keywords = keywords_data.get("low_level_keywords", [])
|
||||||
keywords = ", ".join(keywords)
|
keywords = ", ".join(keywords)
|
||||||
except json.JSONDecodeError:
|
except json.JSONDecodeError:
|
||||||
|
print(result)
|
||||||
try:
|
try:
|
||||||
result = (
|
result = (
|
||||||
result.replace(kw_prompt[:-1], "")
|
result.replace(kw_prompt[:-1], "")
|
||||||
|
|
@ -443,6 +444,8 @@ async def local_query(
|
||||||
sys_prompt = sys_prompt_temp.format(
|
sys_prompt = sys_prompt_temp.format(
|
||||||
context_data=context, response_type=query_param.response_type
|
context_data=context, response_type=query_param.response_type
|
||||||
)
|
)
|
||||||
|
if query_param.only_need_prompt:
|
||||||
|
return sys_prompt
|
||||||
response = await use_model_func(
|
response = await use_model_func(
|
||||||
query,
|
query,
|
||||||
system_prompt=sys_prompt,
|
system_prompt=sys_prompt,
|
||||||
|
|
@ -672,12 +675,12 @@ async def global_query(
|
||||||
kw_prompt = kw_prompt_temp.format(query=query)
|
kw_prompt = kw_prompt_temp.format(query=query)
|
||||||
result = await use_model_func(kw_prompt)
|
result = await use_model_func(kw_prompt)
|
||||||
json_text = locate_json_string_body_from_string(result)
|
json_text = locate_json_string_body_from_string(result)
|
||||||
|
logger.debug("global json_text:", json_text)
|
||||||
try:
|
try:
|
||||||
keywords_data = json.loads(json_text)
|
keywords_data = json.loads(json_text)
|
||||||
keywords = keywords_data.get("high_level_keywords", [])
|
keywords = keywords_data.get("high_level_keywords", [])
|
||||||
keywords = ", ".join(keywords)
|
keywords = ", ".join(keywords)
|
||||||
except json.JSONDecodeError:
|
except json.JSONDecodeError:
|
||||||
try:
|
try:
|
||||||
result = (
|
result = (
|
||||||
result.replace(kw_prompt[:-1], "")
|
result.replace(kw_prompt[:-1], "")
|
||||||
|
|
@ -714,6 +717,8 @@ async def global_query(
|
||||||
sys_prompt = sys_prompt_temp.format(
|
sys_prompt = sys_prompt_temp.format(
|
||||||
context_data=context, response_type=query_param.response_type
|
context_data=context, response_type=query_param.response_type
|
||||||
)
|
)
|
||||||
|
if query_param.only_need_prompt:
|
||||||
|
return sys_prompt
|
||||||
response = await use_model_func(
|
response = await use_model_func(
|
||||||
query,
|
query,
|
||||||
system_prompt=sys_prompt,
|
system_prompt=sys_prompt,
|
||||||
|
|
@ -914,6 +919,7 @@ async def hybrid_query(
|
||||||
|
|
||||||
result = await use_model_func(kw_prompt)
|
result = await use_model_func(kw_prompt)
|
||||||
json_text = locate_json_string_body_from_string(result)
|
json_text = locate_json_string_body_from_string(result)
|
||||||
|
logger.debug("hybrid_query json_text:", json_text)
|
||||||
try:
|
try:
|
||||||
keywords_data = json.loads(json_text)
|
keywords_data = json.loads(json_text)
|
||||||
hl_keywords = keywords_data.get("high_level_keywords", [])
|
hl_keywords = keywords_data.get("high_level_keywords", [])
|
||||||
|
|
@ -969,6 +975,8 @@ async def hybrid_query(
|
||||||
sys_prompt = sys_prompt_temp.format(
|
sys_prompt = sys_prompt_temp.format(
|
||||||
context_data=context, response_type=query_param.response_type
|
context_data=context, response_type=query_param.response_type
|
||||||
)
|
)
|
||||||
|
if query_param.only_need_prompt:
|
||||||
|
return sys_prompt
|
||||||
response = await use_model_func(
|
response = await use_model_func(
|
||||||
query,
|
query,
|
||||||
system_prompt=sys_prompt,
|
system_prompt=sys_prompt,
|
||||||
|
|
@ -1079,6 +1087,8 @@ async def naive_query(
|
||||||
sys_prompt = sys_prompt_temp.format(
|
sys_prompt = sys_prompt_temp.format(
|
||||||
content_data=section, response_type=query_param.response_type
|
content_data=section, response_type=query_param.response_type
|
||||||
)
|
)
|
||||||
|
if query_param.only_need_prompt:
|
||||||
|
return sys_prompt
|
||||||
response = await use_model_func(
|
response = await use_model_func(
|
||||||
query,
|
query,
|
||||||
system_prompt=sys_prompt,
|
system_prompt=sys_prompt,
|
||||||
|
|
|
||||||
|
|
@ -49,7 +49,11 @@ def locate_json_string_body_from_string(content: str) -> Union[str, None]:
|
||||||
"""Locate the JSON string body from a string"""
|
"""Locate the JSON string body from a string"""
|
||||||
maybe_json_str = re.search(r"{.*}", content, re.DOTALL)
|
maybe_json_str = re.search(r"{.*}", content, re.DOTALL)
|
||||||
if maybe_json_str is not None:
|
if maybe_json_str is not None:
|
||||||
return maybe_json_str.group(0)
|
maybe_json_str = maybe_json_str.group(0)
|
||||||
|
maybe_json_str = maybe_json_str.replace("\\n", "")
|
||||||
|
maybe_json_str = maybe_json_str.replace("\n", "")
|
||||||
|
maybe_json_str = maybe_json_str.replace("'", '"')
|
||||||
|
return maybe_json_str
|
||||||
else:
|
else:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue