Implement keyword context mode
This commit is contained in:
parent
af997c02c2
commit
035920a562
3 changed files with 297 additions and 37 deletions
|
|
@ -13,6 +13,7 @@ from .operate import (
|
||||||
global_query,
|
global_query,
|
||||||
hybird_query,
|
hybird_query,
|
||||||
naive_query,
|
naive_query,
|
||||||
|
keyword_context_query,
|
||||||
)
|
)
|
||||||
|
|
||||||
from .storage import (
|
from .storage import (
|
||||||
|
|
@ -78,7 +79,7 @@ class LightRAG:
|
||||||
|
|
||||||
# text embedding
|
# text embedding
|
||||||
# embedding_func: EmbeddingFunc = field(default_factory=lambda:hf_embedding)
|
# embedding_func: EmbeddingFunc = field(default_factory=lambda:hf_embedding)
|
||||||
embedding_func: EmbeddingFunc = field(default_factory=lambda:openai_embedding)#
|
embedding_func: EmbeddingFunc = field(default_factory=lambda:openai_embedding)#
|
||||||
embedding_batch_num: int = 32
|
embedding_batch_num: int = 32
|
||||||
embedding_func_max_async: int = 16
|
embedding_func_max_async: int = 16
|
||||||
|
|
||||||
|
|
@ -99,11 +100,11 @@ class LightRAG:
|
||||||
addon_params: dict = field(default_factory=dict)
|
addon_params: dict = field(default_factory=dict)
|
||||||
convert_response_to_json_func: callable = convert_response_to_json
|
convert_response_to_json_func: callable = convert_response_to_json
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
log_file = os.path.join(self.working_dir, "lightrag.log")
|
log_file = os.path.join(self.working_dir, "lightrag.log")
|
||||||
set_logger(log_file)
|
set_logger(log_file)
|
||||||
logger.info(f"Logger initialized for working directory: {self.working_dir}")
|
logger.info(f"Logger initialized for working directory: {self.working_dir}")
|
||||||
|
|
||||||
_print_config = ",\n ".join([f"{k} = {v}" for k, v in asdict(self).items()])
|
_print_config = ",\n ".join([f"{k} = {v}" for k, v in asdict(self).items()])
|
||||||
logger.debug(f"LightRAG init with param:\n {_print_config}\n")
|
logger.debug(f"LightRAG init with param:\n {_print_config}\n")
|
||||||
|
|
||||||
|
|
@ -155,7 +156,7 @@ class LightRAG:
|
||||||
embedding_func=self.embedding_func,
|
embedding_func=self.embedding_func,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
self.llm_model_func = limit_async_func_call(self.llm_model_max_async)(
|
self.llm_model_func = limit_async_func_call(self.llm_model_max_async)(
|
||||||
partial(self.llm_model_func, hashing_kv=self.llm_response_cache)
|
partial(self.llm_model_func, hashing_kv=self.llm_response_cache)
|
||||||
)
|
)
|
||||||
|
|
@ -242,11 +243,11 @@ class LightRAG:
|
||||||
tasks.append(cast(StorageNameSpace, storage_inst).index_done_callback())
|
tasks.append(cast(StorageNameSpace, storage_inst).index_done_callback())
|
||||||
await asyncio.gather(*tasks)
|
await asyncio.gather(*tasks)
|
||||||
|
|
||||||
def query(self, query: str, param: QueryParam = QueryParam()):
|
def query(self, query: str, param: QueryParam = QueryParam(), history_messages=[]):
|
||||||
loop = always_get_an_event_loop()
|
loop = always_get_an_event_loop()
|
||||||
return loop.run_until_complete(self.aquery(query, param))
|
return loop.run_until_complete(self.aquery(query, param, history_messages))
|
||||||
|
|
||||||
async def aquery(self, query: str, param: QueryParam = QueryParam()):
|
async def aquery(self, query: str, param: QueryParam = QueryParam(), history_messages=[]):
|
||||||
if param.mode == "local":
|
if param.mode == "local":
|
||||||
response = await local_query(
|
response = await local_query(
|
||||||
query,
|
query,
|
||||||
|
|
@ -285,11 +286,22 @@ class LightRAG:
|
||||||
param,
|
param,
|
||||||
asdict(self),
|
asdict(self),
|
||||||
)
|
)
|
||||||
|
elif param.mode == "keyword_context":
|
||||||
|
response = await keyword_context_query(
|
||||||
|
query,
|
||||||
|
history_messages,
|
||||||
|
self.chunk_entity_relation_graph,
|
||||||
|
self.entities_vdb,
|
||||||
|
self.relationships_vdb,
|
||||||
|
self.text_chunks,
|
||||||
|
param,
|
||||||
|
asdict(self),
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unknown mode {param.mode}")
|
raise ValueError(f"Unknown mode {param.mode}")
|
||||||
await self._query_done()
|
await self._query_done()
|
||||||
return response
|
return response
|
||||||
|
|
||||||
|
|
||||||
async def _query_done(self):
|
async def _query_done(self):
|
||||||
tasks = []
|
tasks = []
|
||||||
|
|
@ -298,5 +310,3 @@ class LightRAG:
|
||||||
continue
|
continue
|
||||||
tasks.append(cast(StorageNameSpace, storage_inst).index_done_callback())
|
tasks.append(cast(StorageNameSpace, storage_inst).index_done_callback())
|
||||||
await asyncio.gather(*tasks)
|
await asyncio.gather(*tasks)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -229,7 +229,7 @@ async def _merge_edges_then_upsert(
|
||||||
description=description,
|
description=description,
|
||||||
keywords=keywords,
|
keywords=keywords,
|
||||||
)
|
)
|
||||||
|
|
||||||
return edge_data
|
return edge_data
|
||||||
|
|
||||||
async def extract_entities(
|
async def extract_entities(
|
||||||
|
|
@ -392,7 +392,7 @@ async def local_query(
|
||||||
kw_prompt_temp = PROMPTS["keywords_extraction"]
|
kw_prompt_temp = PROMPTS["keywords_extraction"]
|
||||||
kw_prompt = kw_prompt_temp.format(query=query)
|
kw_prompt = kw_prompt_temp.format(query=query)
|
||||||
result = await use_model_func(kw_prompt)
|
result = await use_model_func(kw_prompt)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
keywords_data = json.loads(result)
|
keywords_data = json.loads(result)
|
||||||
keywords = keywords_data.get("low_level_keywords", [])
|
keywords = keywords_data.get("low_level_keywords", [])
|
||||||
|
|
@ -428,7 +428,7 @@ async def local_query(
|
||||||
)
|
)
|
||||||
if len(response)>len(sys_prompt):
|
if len(response)>len(sys_prompt):
|
||||||
response = response.replace(sys_prompt,'').replace('user','').replace('model','').replace(query,'').replace('<system>','').replace('</system>','').strip()
|
response = response.replace(sys_prompt,'').replace('user','').replace('model','').replace(query,'').replace('<system>','').replace('</system>','').strip()
|
||||||
|
|
||||||
return response
|
return response
|
||||||
|
|
||||||
async def _build_local_query_context(
|
async def _build_local_query_context(
|
||||||
|
|
@ -619,7 +619,7 @@ async def global_query(
|
||||||
kw_prompt_temp = PROMPTS["keywords_extraction"]
|
kw_prompt_temp = PROMPTS["keywords_extraction"]
|
||||||
kw_prompt = kw_prompt_temp.format(query=query)
|
kw_prompt = kw_prompt_temp.format(query=query)
|
||||||
result = await use_model_func(kw_prompt)
|
result = await use_model_func(kw_prompt)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
keywords_data = json.loads(result)
|
keywords_data = json.loads(result)
|
||||||
keywords = keywords_data.get("high_level_keywords", [])
|
keywords = keywords_data.get("high_level_keywords", [])
|
||||||
|
|
@ -630,7 +630,7 @@ async def global_query(
|
||||||
keywords_data = json.loads(result)
|
keywords_data = json.loads(result)
|
||||||
keywords = keywords_data.get("high_level_keywords", [])
|
keywords = keywords_data.get("high_level_keywords", [])
|
||||||
keywords = ', '.join(keywords)
|
keywords = ', '.join(keywords)
|
||||||
|
|
||||||
except json.JSONDecodeError as e:
|
except json.JSONDecodeError as e:
|
||||||
# Handle parsing error
|
# Handle parsing error
|
||||||
print(f"JSON parsing error: {e}")
|
print(f"JSON parsing error: {e}")
|
||||||
|
|
@ -644,12 +644,12 @@ async def global_query(
|
||||||
text_chunks_db,
|
text_chunks_db,
|
||||||
query_param,
|
query_param,
|
||||||
)
|
)
|
||||||
|
|
||||||
if query_param.only_need_context:
|
if query_param.only_need_context:
|
||||||
return context
|
return context
|
||||||
if context is None:
|
if context is None:
|
||||||
return PROMPTS["fail_response"]
|
return PROMPTS["fail_response"]
|
||||||
|
|
||||||
sys_prompt_temp = PROMPTS["rag_response"]
|
sys_prompt_temp = PROMPTS["rag_response"]
|
||||||
sys_prompt = sys_prompt_temp.format(
|
sys_prompt = sys_prompt_temp.format(
|
||||||
context_data=context, response_type=query_param.response_type
|
context_data=context, response_type=query_param.response_type
|
||||||
|
|
@ -660,7 +660,7 @@ async def global_query(
|
||||||
)
|
)
|
||||||
if len(response)>len(sys_prompt):
|
if len(response)>len(sys_prompt):
|
||||||
response = response.replace(sys_prompt,'').replace('user','').replace('model','').replace(query,'').replace('<system>','').replace('</system>','').strip()
|
response = response.replace(sys_prompt,'').replace('user','').replace('model','').replace(query,'').replace('<system>','').replace('</system>','').strip()
|
||||||
|
|
||||||
return response
|
return response
|
||||||
|
|
||||||
async def _build_global_query_context(
|
async def _build_global_query_context(
|
||||||
|
|
@ -672,14 +672,14 @@ async def _build_global_query_context(
|
||||||
query_param: QueryParam,
|
query_param: QueryParam,
|
||||||
):
|
):
|
||||||
results = await relationships_vdb.query(keywords, top_k=query_param.top_k)
|
results = await relationships_vdb.query(keywords, top_k=query_param.top_k)
|
||||||
|
|
||||||
if not len(results):
|
if not len(results):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
edge_datas = await asyncio.gather(
|
edge_datas = await asyncio.gather(
|
||||||
*[knowledge_graph_inst.get_edge(r["src_id"], r["tgt_id"]) for r in results]
|
*[knowledge_graph_inst.get_edge(r["src_id"], r["tgt_id"]) for r in results]
|
||||||
)
|
)
|
||||||
|
|
||||||
if not all([n is not None for n in edge_datas]):
|
if not all([n is not None for n in edge_datas]):
|
||||||
logger.warning("Some edges are missing, maybe the storage is damaged")
|
logger.warning("Some edges are missing, maybe the storage is damaged")
|
||||||
edge_degree = await asyncio.gather(
|
edge_degree = await asyncio.gather(
|
||||||
|
|
@ -767,7 +767,7 @@ async def _find_most_related_entities_from_relationships(
|
||||||
for e in edge_datas:
|
for e in edge_datas:
|
||||||
entity_names.add(e["src_id"])
|
entity_names.add(e["src_id"])
|
||||||
entity_names.add(e["tgt_id"])
|
entity_names.add(e["tgt_id"])
|
||||||
|
|
||||||
node_datas = await asyncio.gather(
|
node_datas = await asyncio.gather(
|
||||||
*[knowledge_graph_inst.get_node(entity_name) for entity_name in entity_names]
|
*[knowledge_graph_inst.get_node(entity_name) for entity_name in entity_names]
|
||||||
)
|
)
|
||||||
|
|
@ -794,7 +794,7 @@ async def _find_related_text_unit_from_relationships(
|
||||||
text_chunks_db: BaseKVStorage[TextChunkSchema],
|
text_chunks_db: BaseKVStorage[TextChunkSchema],
|
||||||
knowledge_graph_inst: BaseGraphStorage,
|
knowledge_graph_inst: BaseGraphStorage,
|
||||||
):
|
):
|
||||||
|
|
||||||
text_units = [
|
text_units = [
|
||||||
split_string_by_multi_markers(dp["source_id"], [GRAPH_FIELD_SEP])
|
split_string_by_multi_markers(dp["source_id"], [GRAPH_FIELD_SEP])
|
||||||
for dp in edge_datas
|
for dp in edge_datas
|
||||||
|
|
@ -809,7 +809,7 @@ async def _find_related_text_unit_from_relationships(
|
||||||
"data": await text_chunks_db.get_by_id(c_id),
|
"data": await text_chunks_db.get_by_id(c_id),
|
||||||
"order": index,
|
"order": index,
|
||||||
}
|
}
|
||||||
|
|
||||||
if any([v is None for v in all_text_units_lookup.values()]):
|
if any([v is None for v in all_text_units_lookup.values()]):
|
||||||
logger.warning("Text chunks are missing, maybe the storage is damaged")
|
logger.warning("Text chunks are missing, maybe the storage is damaged")
|
||||||
all_text_units = [
|
all_text_units = [
|
||||||
|
|
@ -840,7 +840,7 @@ async def hybird_query(
|
||||||
|
|
||||||
kw_prompt_temp = PROMPTS["keywords_extraction"]
|
kw_prompt_temp = PROMPTS["keywords_extraction"]
|
||||||
kw_prompt = kw_prompt_temp.format(query=query)
|
kw_prompt = kw_prompt_temp.format(query=query)
|
||||||
|
|
||||||
result = await use_model_func(kw_prompt)
|
result = await use_model_func(kw_prompt)
|
||||||
try:
|
try:
|
||||||
keywords_data = json.loads(result)
|
keywords_data = json.loads(result)
|
||||||
|
|
@ -884,7 +884,7 @@ async def hybird_query(
|
||||||
return context
|
return context
|
||||||
if context is None:
|
if context is None:
|
||||||
return PROMPTS["fail_response"]
|
return PROMPTS["fail_response"]
|
||||||
|
|
||||||
sys_prompt_temp = PROMPTS["rag_response"]
|
sys_prompt_temp = PROMPTS["rag_response"]
|
||||||
sys_prompt = sys_prompt_temp.format(
|
sys_prompt = sys_prompt_temp.format(
|
||||||
context_data=context, response_type=query_param.response_type
|
context_data=context, response_type=query_param.response_type
|
||||||
|
|
@ -904,17 +904,17 @@ def combine_contexts(high_level_context, low_level_context):
|
||||||
entities_match = re.search(r'-----Entities-----\s*```csv\s*(.*?)\s*```', context, re.DOTALL)
|
entities_match = re.search(r'-----Entities-----\s*```csv\s*(.*?)\s*```', context, re.DOTALL)
|
||||||
relationships_match = re.search(r'-----Relationships-----\s*```csv\s*(.*?)\s*```', context, re.DOTALL)
|
relationships_match = re.search(r'-----Relationships-----\s*```csv\s*(.*?)\s*```', context, re.DOTALL)
|
||||||
sources_match = re.search(r'-----Sources-----\s*```csv\s*(.*?)\s*```', context, re.DOTALL)
|
sources_match = re.search(r'-----Sources-----\s*```csv\s*(.*?)\s*```', context, re.DOTALL)
|
||||||
|
|
||||||
entities = entities_match.group(1) if entities_match else ''
|
entities = entities_match.group(1) if entities_match else ''
|
||||||
relationships = relationships_match.group(1) if relationships_match else ''
|
relationships = relationships_match.group(1) if relationships_match else ''
|
||||||
sources = sources_match.group(1) if sources_match else ''
|
sources = sources_match.group(1) if sources_match else ''
|
||||||
|
|
||||||
return entities, relationships, sources
|
return entities, relationships, sources
|
||||||
|
|
||||||
# Extract sections from both contexts
|
# Extract sections from both contexts
|
||||||
|
|
||||||
if high_level_context==None:
|
if high_level_context==None:
|
||||||
warnings.warn("High Level context is None. Return empty High entity/relationship/source")
|
warnings.warn("High Level context is None. Return empty High entity/relationship/source")
|
||||||
hl_entities, hl_relationships, hl_sources = '','',''
|
hl_entities, hl_relationships, hl_sources = '','',''
|
||||||
else:
|
else:
|
||||||
hl_entities, hl_relationships, hl_sources = extract_sections(high_level_context)
|
hl_entities, hl_relationships, hl_sources = extract_sections(high_level_context)
|
||||||
|
|
@ -927,19 +927,19 @@ def combine_contexts(high_level_context, low_level_context):
|
||||||
ll_entities, ll_relationships, ll_sources = extract_sections(low_level_context)
|
ll_entities, ll_relationships, ll_sources = extract_sections(low_level_context)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# Combine and deduplicate the entities
|
# Combine and deduplicate the entities
|
||||||
combined_entities_set = set(filter(None, hl_entities.strip().split('\n') + ll_entities.strip().split('\n')))
|
combined_entities_set = set(filter(None, hl_entities.strip().split('\n') + ll_entities.strip().split('\n')))
|
||||||
combined_entities = '\n'.join(combined_entities_set)
|
combined_entities = '\n'.join(combined_entities_set)
|
||||||
|
|
||||||
# Combine and deduplicate the relationships
|
# Combine and deduplicate the relationships
|
||||||
combined_relationships_set = set(filter(None, hl_relationships.strip().split('\n') + ll_relationships.strip().split('\n')))
|
combined_relationships_set = set(filter(None, hl_relationships.strip().split('\n') + ll_relationships.strip().split('\n')))
|
||||||
combined_relationships = '\n'.join(combined_relationships_set)
|
combined_relationships = '\n'.join(combined_relationships_set)
|
||||||
|
|
||||||
# Combine and deduplicate the sources
|
# Combine and deduplicate the sources
|
||||||
combined_sources_set = set(filter(None, hl_sources.strip().split('\n') + ll_sources.strip().split('\n')))
|
combined_sources_set = set(filter(None, hl_sources.strip().split('\n') + ll_sources.strip().split('\n')))
|
||||||
combined_sources = '\n'.join(combined_sources_set)
|
combined_sources = '\n'.join(combined_sources_set)
|
||||||
|
|
||||||
# Format the combined context
|
# Format the combined context
|
||||||
return f"""
|
return f"""
|
||||||
-----Entities-----
|
-----Entities-----
|
||||||
|
|
@ -985,6 +985,218 @@ async def naive_query(
|
||||||
|
|
||||||
if len(response)>len(sys_prompt):
|
if len(response)>len(sys_prompt):
|
||||||
response = response[len(sys_prompt):].replace(sys_prompt,'').replace('user','').replace('model','').replace(query,'').replace('<system>','').replace('</system>','').strip()
|
response = response[len(sys_prompt):].replace(sys_prompt,'').replace('user','').replace('model','').replace(query,'').replace('<system>','').replace('</system>','').strip()
|
||||||
|
|
||||||
return response
|
return response
|
||||||
|
|
||||||
|
|
||||||
|
async def keyword_context_query(
|
||||||
|
query,
|
||||||
|
history_messages,
|
||||||
|
knowledge_graph_inst: BaseGraphStorage,
|
||||||
|
entities_vdb: BaseVectorStorage,
|
||||||
|
relationships_vdb: BaseVectorStorage,
|
||||||
|
text_chunks_db: BaseKVStorage[TextChunkSchema],
|
||||||
|
query_param: QueryParam,
|
||||||
|
global_config: dict,
|
||||||
|
) -> str:
|
||||||
|
use_model_func = global_config["llm_model_func"]
|
||||||
|
embedding_func = global_config["embedding_func"]
|
||||||
|
|
||||||
|
# プロンプトからキーワードを抽出 (これはlocal_queryと同じ)
|
||||||
|
kw_prompt_temp = PROMPTS["keywords_extraction"]
|
||||||
|
kw_prompt = kw_prompt_temp.format(query=query)
|
||||||
|
result = await use_model_func(kw_prompt)
|
||||||
|
try:
|
||||||
|
keywords_data = json.loads(result)
|
||||||
|
keywords = keywords_data.get("low_level_keywords", [])
|
||||||
|
if not keywords:
|
||||||
|
return PROMPTS["fail_response"]
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
return PROMPTS["fail_response"]
|
||||||
|
|
||||||
|
# 各キーワードに対して処理を行う
|
||||||
|
all_entity_results = []
|
||||||
|
all_relationship_results = []
|
||||||
|
for keyword in keywords:
|
||||||
|
# キーワードに関連するコンテキストを会話履歴から生成
|
||||||
|
# 例) history: 「歴史面白いよな」「戦国時代の武将誰が好き?」, prompt「信長かな〜」
|
||||||
|
# -> 「信長は戦国時代の武将です」 (情報を付与せずhistory + promptからコンテキストを生成)
|
||||||
|
keyword_context = await generate_keyword_context_from_history(keyword, history_messages, use_model_func)
|
||||||
|
if not keyword_context:
|
||||||
|
keyword_context = keyword # コンテキストが生成されない場合はキーワード自体を使用
|
||||||
|
|
||||||
|
# コンテキストを埋め込みに変換
|
||||||
|
context_embedding = await embedding_func([keyword_context])
|
||||||
|
context_embedding = context_embedding[0]
|
||||||
|
|
||||||
|
# エンティティVDBを検索
|
||||||
|
entity_results = await entities_vdb.query_by_embedding(context_embedding, top_k=query_param.top_k)
|
||||||
|
all_entity_results.extend(entity_results)
|
||||||
|
|
||||||
|
# リレーションシップVDBを検索
|
||||||
|
relationship_results = await relationships_vdb.query_by_embedding(context_embedding, top_k=query_param.top_k)
|
||||||
|
all_relationship_results.extend(relationship_results)
|
||||||
|
|
||||||
|
unique_entity_results = {res["id"]: res for res in all_entity_results}.values()
|
||||||
|
unique_relationship_results = { (res["metadata"]["src_id"], res["metadata"]["tgt_id"]): res for res in all_relationship_results }.values()
|
||||||
|
|
||||||
|
# 回答用のコンテキストを作成
|
||||||
|
use_entities = await _find_most_related_entities_from_results(
|
||||||
|
unique_entity_results, knowledge_graph_inst, query_param
|
||||||
|
)
|
||||||
|
use_relationships = await _find_most_related_relationships_from_results(
|
||||||
|
unique_relationship_results, knowledge_graph_inst, query_param
|
||||||
|
)
|
||||||
|
use_text_units = await _find_related_text_units(
|
||||||
|
use_entities, use_relationships, text_chunks_db, query_param
|
||||||
|
)
|
||||||
|
|
||||||
|
context_data = await _build_keyword_context(
|
||||||
|
use_entities, use_relationships, use_text_units, query_param
|
||||||
|
)
|
||||||
|
|
||||||
|
# コンテキストを使用して最終回答を生成
|
||||||
|
if query_param.only_need_context:
|
||||||
|
return context_data
|
||||||
|
|
||||||
|
sys_prompt_temp = PROMPTS["rag_response"]
|
||||||
|
sys_prompt = sys_prompt_temp.format(
|
||||||
|
context_data=context_data, response_type=query_param.response_type
|
||||||
|
)
|
||||||
|
response = await use_model_func(
|
||||||
|
query,
|
||||||
|
system_prompt=sys_prompt,
|
||||||
|
)
|
||||||
|
return response.strip()
|
||||||
|
|
||||||
|
async def generate_keyword_context_from_history(keyword, history_messages, use_model_func):
|
||||||
|
context_prompt_temp = PROMPTS["keyword_context_from_history"]
|
||||||
|
history_text = "\n".join([f"{msg['role']}: {msg['content']}" for msg in history_messages])
|
||||||
|
context_prompt = context_prompt_temp.format(
|
||||||
|
keyword=keyword,
|
||||||
|
history=history_text
|
||||||
|
)
|
||||||
|
context = await use_model_func(context_prompt)
|
||||||
|
return context.strip()
|
||||||
|
|
||||||
|
async def _find_most_related_entities_from_results(results, knowledge_graph_inst, query_param):
|
||||||
|
entity_ids = [r["metadata"].get("entity_name") for r in results]
|
||||||
|
entities = await asyncio.gather(
|
||||||
|
*[knowledge_graph_inst.get_node(entity_id) for entity_id in entity_ids]
|
||||||
|
)
|
||||||
|
node_degrees = await asyncio.gather(
|
||||||
|
*[knowledge_graph_inst.node_degree(entity_id) for entity_id in entity_ids]
|
||||||
|
)
|
||||||
|
entities_data = [
|
||||||
|
{
|
||||||
|
"entity_name": entity_id,
|
||||||
|
"entity_type": entity.get("entity_type", "UNKNOWN"),
|
||||||
|
"description": entity.get("description", ""),
|
||||||
|
"rank": degree,
|
||||||
|
"source_id": entity.get("source_id", ""),
|
||||||
|
}
|
||||||
|
for entity_id, entity, degree in zip(entity_ids, entities, node_degrees)
|
||||||
|
if entity is not None
|
||||||
|
]
|
||||||
|
entities_data = truncate_list_by_token_size(
|
||||||
|
entities_data,
|
||||||
|
key=lambda x: x["description"],
|
||||||
|
max_token_size=query_param.max_token_for_local_context,
|
||||||
|
)
|
||||||
|
return entities_data
|
||||||
|
|
||||||
|
async def _find_most_related_relationships_from_results(results, knowledge_graph_inst, query_param):
|
||||||
|
relationship_ids = [(r["metadata"]["src_id"], r["metadata"]["tgt_id"]) for r in results]
|
||||||
|
relationships = await asyncio.gather(
|
||||||
|
*[knowledge_graph_inst.get_edge(src_id, tgt_id) for src_id, tgt_id in relationship_ids]
|
||||||
|
)
|
||||||
|
edge_degrees = await asyncio.gather(
|
||||||
|
*[knowledge_graph_inst.edge_degree(src_id, tgt_id) for src_id, tgt_id in relationship_ids]
|
||||||
|
)
|
||||||
|
relationships_data = [
|
||||||
|
{
|
||||||
|
"src_id": src_id,
|
||||||
|
"tgt_id": tgt_id,
|
||||||
|
"description": edge.get("description", ""),
|
||||||
|
"keywords": edge.get("keywords", ""),
|
||||||
|
"weight": edge.get("weight", 1),
|
||||||
|
"rank": degree,
|
||||||
|
"source_id": edge.get("source_id", ""),
|
||||||
|
}
|
||||||
|
for (src_id, tgt_id), edge, degree in zip(relationship_ids, relationships, edge_degrees)
|
||||||
|
if edge is not None
|
||||||
|
]
|
||||||
|
relationships_data = truncate_list_by_token_size(
|
||||||
|
relationships_data,
|
||||||
|
key=lambda x: x["description"],
|
||||||
|
max_token_size=query_param.max_token_for_local_context,
|
||||||
|
)
|
||||||
|
return relationships_data
|
||||||
|
|
||||||
|
async def _find_related_text_units(use_entities, use_relationships, text_chunks_db, query_param):
|
||||||
|
text_unit_ids = set()
|
||||||
|
for entity in use_entities:
|
||||||
|
source_ids = entity.get("source_id", "")
|
||||||
|
text_unit_ids.update(source_ids.split(GRAPH_FIELD_SEP))
|
||||||
|
for relationship in use_relationships:
|
||||||
|
source_ids = relationship.get("source_id", "")
|
||||||
|
text_unit_ids.update(source_ids.split(GRAPH_FIELD_SEP))
|
||||||
|
text_unit_ids = list(text_unit_ids)
|
||||||
|
text_units = await text_chunks_db.get_by_ids(text_unit_ids)
|
||||||
|
text_units = [t for t in text_units if t is not None]
|
||||||
|
text_units = truncate_list_by_token_size(
|
||||||
|
text_units,
|
||||||
|
key=lambda x: x["content"],
|
||||||
|
max_token_size=query_param.max_token_for_text_unit,
|
||||||
|
)
|
||||||
|
return text_units
|
||||||
|
|
||||||
|
async def _build_keyword_context(use_entities, use_relationships, use_text_units, query_param):
|
||||||
|
entities_section_list = [["id", "entity", "type", "description", "rank"]]
|
||||||
|
for i, n in enumerate(use_entities):
|
||||||
|
entities_section_list.append(
|
||||||
|
[
|
||||||
|
i,
|
||||||
|
n["entity_name"],
|
||||||
|
n.get("entity_type", "UNKNOWN"),
|
||||||
|
n.get("description", "UNKNOWN"),
|
||||||
|
n["rank"],
|
||||||
|
]
|
||||||
|
)
|
||||||
|
entities_context = list_of_list_to_csv(entities_section_list)
|
||||||
|
relations_section_list = [
|
||||||
|
["id", "src_entity", "tgt_entity", "description", "keywords", "weight", "rank"]
|
||||||
|
]
|
||||||
|
for i, e in enumerate(use_relationships):
|
||||||
|
relations_section_list.append(
|
||||||
|
[
|
||||||
|
i,
|
||||||
|
e["src_id"],
|
||||||
|
e["tgt_id"],
|
||||||
|
e["description"],
|
||||||
|
e["keywords"],
|
||||||
|
e["weight"],
|
||||||
|
e["rank"],
|
||||||
|
]
|
||||||
|
)
|
||||||
|
relations_context = list_of_list_to_csv(relations_section_list)
|
||||||
|
|
||||||
|
text_units_section_list = [["id", "content"]]
|
||||||
|
for i, t in enumerate(use_text_units):
|
||||||
|
text_units_section_list.append([i, t["content"]])
|
||||||
|
text_units_context = list_of_list_to_csv(text_units_section_list)
|
||||||
|
|
||||||
|
return f"""
|
||||||
|
-----Entities-----
|
||||||
|
```csv
|
||||||
|
{entities_context}
|
||||||
|
```
|
||||||
|
-----Relationships-----
|
||||||
|
```csv
|
||||||
|
{relations_context}
|
||||||
|
```
|
||||||
|
-----Sources-----
|
||||||
|
```csv
|
||||||
|
{text_units_context}
|
||||||
|
```
|
||||||
|
""".strip()
|
||||||
|
|
@ -32,7 +32,7 @@ Format each relationship as ("relationship"{tuple_delimiter}<source_entity>{tupl
|
||||||
|
|
||||||
3. Identify high-level key words that summarize the main concepts, themes, or topics of the entire text. These should capture the overarching ideas present in the document.
|
3. Identify high-level key words that summarize the main concepts, themes, or topics of the entire text. These should capture the overarching ideas present in the document.
|
||||||
Format the content-level key words as ("content_keywords"{tuple_delimiter}<high_level_keywords>)
|
Format the content-level key words as ("content_keywords"{tuple_delimiter}<high_level_keywords>)
|
||||||
|
|
||||||
4. Return output in English as a single list of all the entities and relationships identified in steps 1 and 2. Use **{record_delimiter}** as the list delimiter.
|
4. Return output in English as a single list of all the entities and relationships identified in steps 1 and 2. Use **{record_delimiter}** as the list delimiter.
|
||||||
|
|
||||||
5. When finished, output {completion_delimiter}
|
5. When finished, output {completion_delimiter}
|
||||||
|
|
@ -254,3 +254,41 @@ Do not include information where the supporting evidence for it is not provided.
|
||||||
---Target response length and format---
|
---Target response length and format---
|
||||||
{response_type}
|
{response_type}
|
||||||
"""
|
"""
|
||||||
|
PROMPTS["keywords_extraction_with_history"] = """---Role---
|
||||||
|
|
||||||
|
You are a helpful assistant tasked with identifying both high-level and low-level keywords from the conversation history.
|
||||||
|
|
||||||
|
---Goal---
|
||||||
|
|
||||||
|
Given the conversation history, list both high-level and low-level keywords that are relevant to the user's query.
|
||||||
|
|
||||||
|
---Instructions---
|
||||||
|
|
||||||
|
- Output the keywords in JSON format.
|
||||||
|
- The JSON should have two keys:
|
||||||
|
- "high_level_keywords" for overarching concepts or themes.
|
||||||
|
- "low_level_keywords" for specific entities or details.
|
||||||
|
|
||||||
|
---Conversation History---
|
||||||
|
{history}
|
||||||
|
"""
|
||||||
|
|
||||||
|
PROMPTS["context_generation_from_history"] = """---Role---
|
||||||
|
|
||||||
|
You are a helpful assistant tasked with generating context for the user's query based on the conversation history and the provided keywords.
|
||||||
|
|
||||||
|
---Goal---
|
||||||
|
|
||||||
|
Using the conversation history and the keywords, summarize the relevant information that the assistant knows so far about the keywords.
|
||||||
|
|
||||||
|
---Instructions---
|
||||||
|
|
||||||
|
- Provide a concise summary that connects the keywords with the context from the conversation history.
|
||||||
|
- Do not include irrelevant information.
|
||||||
|
|
||||||
|
---Keywords---
|
||||||
|
{keywords}
|
||||||
|
|
||||||
|
---Conversation History---
|
||||||
|
{history}
|
||||||
|
"""
|
||||||
Loading…
Add table
Reference in a new issue