Implement keyword context mode
This commit is contained in:
parent
af997c02c2
commit
035920a562
3 changed files with 297 additions and 37 deletions
|
|
@ -13,6 +13,7 @@ from .operate import (
|
||||||
global_query,
|
global_query,
|
||||||
hybird_query,
|
hybird_query,
|
||||||
naive_query,
|
naive_query,
|
||||||
|
keyword_context_query,
|
||||||
)
|
)
|
||||||
|
|
||||||
from .storage import (
|
from .storage import (
|
||||||
|
|
@ -242,11 +243,11 @@ class LightRAG:
|
||||||
tasks.append(cast(StorageNameSpace, storage_inst).index_done_callback())
|
tasks.append(cast(StorageNameSpace, storage_inst).index_done_callback())
|
||||||
await asyncio.gather(*tasks)
|
await asyncio.gather(*tasks)
|
||||||
|
|
||||||
def query(self, query: str, param: QueryParam = QueryParam()):
|
def query(self, query: str, param: QueryParam = QueryParam(), history_messages=[]):
|
||||||
loop = always_get_an_event_loop()
|
loop = always_get_an_event_loop()
|
||||||
return loop.run_until_complete(self.aquery(query, param))
|
return loop.run_until_complete(self.aquery(query, param, history_messages))
|
||||||
|
|
||||||
async def aquery(self, query: str, param: QueryParam = QueryParam()):
|
async def aquery(self, query: str, param: QueryParam = QueryParam(), history_messages=[]):
|
||||||
if param.mode == "local":
|
if param.mode == "local":
|
||||||
response = await local_query(
|
response = await local_query(
|
||||||
query,
|
query,
|
||||||
|
|
@ -285,6 +286,17 @@ class LightRAG:
|
||||||
param,
|
param,
|
||||||
asdict(self),
|
asdict(self),
|
||||||
)
|
)
|
||||||
|
elif param.mode == "keyword_context":
|
||||||
|
response = await keyword_context_query(
|
||||||
|
query,
|
||||||
|
history_messages,
|
||||||
|
self.chunk_entity_relation_graph,
|
||||||
|
self.entities_vdb,
|
||||||
|
self.relationships_vdb,
|
||||||
|
self.text_chunks,
|
||||||
|
param,
|
||||||
|
asdict(self),
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unknown mode {param.mode}")
|
raise ValueError(f"Unknown mode {param.mode}")
|
||||||
await self._query_done()
|
await self._query_done()
|
||||||
|
|
@ -298,5 +310,3 @@ class LightRAG:
|
||||||
continue
|
continue
|
||||||
tasks.append(cast(StorageNameSpace, storage_inst).index_done_callback())
|
tasks.append(cast(StorageNameSpace, storage_inst).index_done_callback())
|
||||||
await asyncio.gather(*tasks)
|
await asyncio.gather(*tasks)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -988,3 +988,215 @@ async def naive_query(
|
||||||
|
|
||||||
return response
|
return response
|
||||||
|
|
||||||
|
|
||||||
|
async def keyword_context_query(
|
||||||
|
query,
|
||||||
|
history_messages,
|
||||||
|
knowledge_graph_inst: BaseGraphStorage,
|
||||||
|
entities_vdb: BaseVectorStorage,
|
||||||
|
relationships_vdb: BaseVectorStorage,
|
||||||
|
text_chunks_db: BaseKVStorage[TextChunkSchema],
|
||||||
|
query_param: QueryParam,
|
||||||
|
global_config: dict,
|
||||||
|
) -> str:
|
||||||
|
use_model_func = global_config["llm_model_func"]
|
||||||
|
embedding_func = global_config["embedding_func"]
|
||||||
|
|
||||||
|
# プロンプトからキーワードを抽出 (これはlocal_queryと同じ)
|
||||||
|
kw_prompt_temp = PROMPTS["keywords_extraction"]
|
||||||
|
kw_prompt = kw_prompt_temp.format(query=query)
|
||||||
|
result = await use_model_func(kw_prompt)
|
||||||
|
try:
|
||||||
|
keywords_data = json.loads(result)
|
||||||
|
keywords = keywords_data.get("low_level_keywords", [])
|
||||||
|
if not keywords:
|
||||||
|
return PROMPTS["fail_response"]
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
return PROMPTS["fail_response"]
|
||||||
|
|
||||||
|
# 各キーワードに対して処理を行う
|
||||||
|
all_entity_results = []
|
||||||
|
all_relationship_results = []
|
||||||
|
for keyword in keywords:
|
||||||
|
# キーワードに関連するコンテキストを会話履歴から生成
|
||||||
|
# 例) history: 「歴史面白いよな」「戦国時代の武将誰が好き?」, prompt「信長かな〜」
|
||||||
|
# -> 「信長は戦国時代の武将です」 (情報を付与せずhistory + promptからコンテキストを生成)
|
||||||
|
keyword_context = await generate_keyword_context_from_history(keyword, history_messages, use_model_func)
|
||||||
|
if not keyword_context:
|
||||||
|
keyword_context = keyword # コンテキストが生成されない場合はキーワード自体を使用
|
||||||
|
|
||||||
|
# コンテキストを埋め込みに変換
|
||||||
|
context_embedding = await embedding_func([keyword_context])
|
||||||
|
context_embedding = context_embedding[0]
|
||||||
|
|
||||||
|
# エンティティVDBを検索
|
||||||
|
entity_results = await entities_vdb.query_by_embedding(context_embedding, top_k=query_param.top_k)
|
||||||
|
all_entity_results.extend(entity_results)
|
||||||
|
|
||||||
|
# リレーションシップVDBを検索
|
||||||
|
relationship_results = await relationships_vdb.query_by_embedding(context_embedding, top_k=query_param.top_k)
|
||||||
|
all_relationship_results.extend(relationship_results)
|
||||||
|
|
||||||
|
unique_entity_results = {res["id"]: res for res in all_entity_results}.values()
|
||||||
|
unique_relationship_results = { (res["metadata"]["src_id"], res["metadata"]["tgt_id"]): res for res in all_relationship_results }.values()
|
||||||
|
|
||||||
|
# 回答用のコンテキストを作成
|
||||||
|
use_entities = await _find_most_related_entities_from_results(
|
||||||
|
unique_entity_results, knowledge_graph_inst, query_param
|
||||||
|
)
|
||||||
|
use_relationships = await _find_most_related_relationships_from_results(
|
||||||
|
unique_relationship_results, knowledge_graph_inst, query_param
|
||||||
|
)
|
||||||
|
use_text_units = await _find_related_text_units(
|
||||||
|
use_entities, use_relationships, text_chunks_db, query_param
|
||||||
|
)
|
||||||
|
|
||||||
|
context_data = await _build_keyword_context(
|
||||||
|
use_entities, use_relationships, use_text_units, query_param
|
||||||
|
)
|
||||||
|
|
||||||
|
# コンテキストを使用して最終回答を生成
|
||||||
|
if query_param.only_need_context:
|
||||||
|
return context_data
|
||||||
|
|
||||||
|
sys_prompt_temp = PROMPTS["rag_response"]
|
||||||
|
sys_prompt = sys_prompt_temp.format(
|
||||||
|
context_data=context_data, response_type=query_param.response_type
|
||||||
|
)
|
||||||
|
response = await use_model_func(
|
||||||
|
query,
|
||||||
|
system_prompt=sys_prompt,
|
||||||
|
)
|
||||||
|
return response.strip()
|
||||||
|
|
||||||
|
async def generate_keyword_context_from_history(keyword, history_messages, use_model_func):
|
||||||
|
context_prompt_temp = PROMPTS["keyword_context_from_history"]
|
||||||
|
history_text = "\n".join([f"{msg['role']}: {msg['content']}" for msg in history_messages])
|
||||||
|
context_prompt = context_prompt_temp.format(
|
||||||
|
keyword=keyword,
|
||||||
|
history=history_text
|
||||||
|
)
|
||||||
|
context = await use_model_func(context_prompt)
|
||||||
|
return context.strip()
|
||||||
|
|
||||||
|
async def _find_most_related_entities_from_results(results, knowledge_graph_inst, query_param):
|
||||||
|
entity_ids = [r["metadata"].get("entity_name") for r in results]
|
||||||
|
entities = await asyncio.gather(
|
||||||
|
*[knowledge_graph_inst.get_node(entity_id) for entity_id in entity_ids]
|
||||||
|
)
|
||||||
|
node_degrees = await asyncio.gather(
|
||||||
|
*[knowledge_graph_inst.node_degree(entity_id) for entity_id in entity_ids]
|
||||||
|
)
|
||||||
|
entities_data = [
|
||||||
|
{
|
||||||
|
"entity_name": entity_id,
|
||||||
|
"entity_type": entity.get("entity_type", "UNKNOWN"),
|
||||||
|
"description": entity.get("description", ""),
|
||||||
|
"rank": degree,
|
||||||
|
"source_id": entity.get("source_id", ""),
|
||||||
|
}
|
||||||
|
for entity_id, entity, degree in zip(entity_ids, entities, node_degrees)
|
||||||
|
if entity is not None
|
||||||
|
]
|
||||||
|
entities_data = truncate_list_by_token_size(
|
||||||
|
entities_data,
|
||||||
|
key=lambda x: x["description"],
|
||||||
|
max_token_size=query_param.max_token_for_local_context,
|
||||||
|
)
|
||||||
|
return entities_data
|
||||||
|
|
||||||
|
async def _find_most_related_relationships_from_results(results, knowledge_graph_inst, query_param):
|
||||||
|
relationship_ids = [(r["metadata"]["src_id"], r["metadata"]["tgt_id"]) for r in results]
|
||||||
|
relationships = await asyncio.gather(
|
||||||
|
*[knowledge_graph_inst.get_edge(src_id, tgt_id) for src_id, tgt_id in relationship_ids]
|
||||||
|
)
|
||||||
|
edge_degrees = await asyncio.gather(
|
||||||
|
*[knowledge_graph_inst.edge_degree(src_id, tgt_id) for src_id, tgt_id in relationship_ids]
|
||||||
|
)
|
||||||
|
relationships_data = [
|
||||||
|
{
|
||||||
|
"src_id": src_id,
|
||||||
|
"tgt_id": tgt_id,
|
||||||
|
"description": edge.get("description", ""),
|
||||||
|
"keywords": edge.get("keywords", ""),
|
||||||
|
"weight": edge.get("weight", 1),
|
||||||
|
"rank": degree,
|
||||||
|
"source_id": edge.get("source_id", ""),
|
||||||
|
}
|
||||||
|
for (src_id, tgt_id), edge, degree in zip(relationship_ids, relationships, edge_degrees)
|
||||||
|
if edge is not None
|
||||||
|
]
|
||||||
|
relationships_data = truncate_list_by_token_size(
|
||||||
|
relationships_data,
|
||||||
|
key=lambda x: x["description"],
|
||||||
|
max_token_size=query_param.max_token_for_local_context,
|
||||||
|
)
|
||||||
|
return relationships_data
|
||||||
|
|
||||||
|
async def _find_related_text_units(use_entities, use_relationships, text_chunks_db, query_param):
|
||||||
|
text_unit_ids = set()
|
||||||
|
for entity in use_entities:
|
||||||
|
source_ids = entity.get("source_id", "")
|
||||||
|
text_unit_ids.update(source_ids.split(GRAPH_FIELD_SEP))
|
||||||
|
for relationship in use_relationships:
|
||||||
|
source_ids = relationship.get("source_id", "")
|
||||||
|
text_unit_ids.update(source_ids.split(GRAPH_FIELD_SEP))
|
||||||
|
text_unit_ids = list(text_unit_ids)
|
||||||
|
text_units = await text_chunks_db.get_by_ids(text_unit_ids)
|
||||||
|
text_units = [t for t in text_units if t is not None]
|
||||||
|
text_units = truncate_list_by_token_size(
|
||||||
|
text_units,
|
||||||
|
key=lambda x: x["content"],
|
||||||
|
max_token_size=query_param.max_token_for_text_unit,
|
||||||
|
)
|
||||||
|
return text_units
|
||||||
|
|
||||||
|
async def _build_keyword_context(use_entities, use_relationships, use_text_units, query_param):
|
||||||
|
entities_section_list = [["id", "entity", "type", "description", "rank"]]
|
||||||
|
for i, n in enumerate(use_entities):
|
||||||
|
entities_section_list.append(
|
||||||
|
[
|
||||||
|
i,
|
||||||
|
n["entity_name"],
|
||||||
|
n.get("entity_type", "UNKNOWN"),
|
||||||
|
n.get("description", "UNKNOWN"),
|
||||||
|
n["rank"],
|
||||||
|
]
|
||||||
|
)
|
||||||
|
entities_context = list_of_list_to_csv(entities_section_list)
|
||||||
|
relations_section_list = [
|
||||||
|
["id", "src_entity", "tgt_entity", "description", "keywords", "weight", "rank"]
|
||||||
|
]
|
||||||
|
for i, e in enumerate(use_relationships):
|
||||||
|
relations_section_list.append(
|
||||||
|
[
|
||||||
|
i,
|
||||||
|
e["src_id"],
|
||||||
|
e["tgt_id"],
|
||||||
|
e["description"],
|
||||||
|
e["keywords"],
|
||||||
|
e["weight"],
|
||||||
|
e["rank"],
|
||||||
|
]
|
||||||
|
)
|
||||||
|
relations_context = list_of_list_to_csv(relations_section_list)
|
||||||
|
|
||||||
|
text_units_section_list = [["id", "content"]]
|
||||||
|
for i, t in enumerate(use_text_units):
|
||||||
|
text_units_section_list.append([i, t["content"]])
|
||||||
|
text_units_context = list_of_list_to_csv(text_units_section_list)
|
||||||
|
|
||||||
|
return f"""
|
||||||
|
-----Entities-----
|
||||||
|
```csv
|
||||||
|
{entities_context}
|
||||||
|
```
|
||||||
|
-----Relationships-----
|
||||||
|
```csv
|
||||||
|
{relations_context}
|
||||||
|
```
|
||||||
|
-----Sources-----
|
||||||
|
```csv
|
||||||
|
{text_units_context}
|
||||||
|
```
|
||||||
|
""".strip()
|
||||||
|
|
@ -254,3 +254,41 @@ Do not include information where the supporting evidence for it is not provided.
|
||||||
---Target response length and format---
|
---Target response length and format---
|
||||||
{response_type}
|
{response_type}
|
||||||
"""
|
"""
|
||||||
|
PROMPTS["keywords_extraction_with_history"] = """---Role---
|
||||||
|
|
||||||
|
You are a helpful assistant tasked with identifying both high-level and low-level keywords from the conversation history.
|
||||||
|
|
||||||
|
---Goal---
|
||||||
|
|
||||||
|
Given the conversation history, list both high-level and low-level keywords that are relevant to the user's query.
|
||||||
|
|
||||||
|
---Instructions---
|
||||||
|
|
||||||
|
- Output the keywords in JSON format.
|
||||||
|
- The JSON should have two keys:
|
||||||
|
- "high_level_keywords" for overarching concepts or themes.
|
||||||
|
- "low_level_keywords" for specific entities or details.
|
||||||
|
|
||||||
|
---Conversation History---
|
||||||
|
{history}
|
||||||
|
"""
|
||||||
|
|
||||||
|
PROMPTS["context_generation_from_history"] = """---Role---
|
||||||
|
|
||||||
|
You are a helpful assistant tasked with generating context for the user's query based on the conversation history and the provided keywords.
|
||||||
|
|
||||||
|
---Goal---
|
||||||
|
|
||||||
|
Using the conversation history and the keywords, summarize the relevant information that the assistant knows so far about the keywords.
|
||||||
|
|
||||||
|
---Instructions---
|
||||||
|
|
||||||
|
- Provide a concise summary that connects the keywords with the context from the conversation history.
|
||||||
|
- Do not include irrelevant information.
|
||||||
|
|
||||||
|
---Keywords---
|
||||||
|
{keywords}
|
||||||
|
|
||||||
|
---Conversation History---
|
||||||
|
{history}
|
||||||
|
"""
|
||||||
Loading…
Add table
Reference in a new issue