Implement keyword context mode

This commit is contained in:
sakasegawa 2024-10-15 02:29:44 +09:00
parent af997c02c2
commit 035920a562
3 changed files with 297 additions and 37 deletions

View file

@ -13,6 +13,7 @@ from .operate import (
global_query, global_query,
hybird_query, hybird_query,
naive_query, naive_query,
keyword_context_query,
) )
from .storage import ( from .storage import (
@ -242,11 +243,11 @@ class LightRAG:
tasks.append(cast(StorageNameSpace, storage_inst).index_done_callback()) tasks.append(cast(StorageNameSpace, storage_inst).index_done_callback())
await asyncio.gather(*tasks) await asyncio.gather(*tasks)
def query(self, query: str, param: QueryParam = QueryParam()): def query(self, query: str, param: QueryParam = QueryParam(), history_messages=[]):
loop = always_get_an_event_loop() loop = always_get_an_event_loop()
return loop.run_until_complete(self.aquery(query, param)) return loop.run_until_complete(self.aquery(query, param, history_messages))
async def aquery(self, query: str, param: QueryParam = QueryParam()): async def aquery(self, query: str, param: QueryParam = QueryParam(), history_messages=[]):
if param.mode == "local": if param.mode == "local":
response = await local_query( response = await local_query(
query, query,
@ -285,6 +286,17 @@ class LightRAG:
param, param,
asdict(self), asdict(self),
) )
elif param.mode == "keyword_context":
response = await keyword_context_query(
query,
history_messages,
self.chunk_entity_relation_graph,
self.entities_vdb,
self.relationships_vdb,
self.text_chunks,
param,
asdict(self),
)
else: else:
raise ValueError(f"Unknown mode {param.mode}") raise ValueError(f"Unknown mode {param.mode}")
await self._query_done() await self._query_done()
@ -298,5 +310,3 @@ class LightRAG:
continue continue
tasks.append(cast(StorageNameSpace, storage_inst).index_done_callback()) tasks.append(cast(StorageNameSpace, storage_inst).index_done_callback())
await asyncio.gather(*tasks) await asyncio.gather(*tasks)

View file

@ -988,3 +988,215 @@ async def naive_query(
return response return response
async def keyword_context_query(
query,
history_messages,
knowledge_graph_inst: BaseGraphStorage,
entities_vdb: BaseVectorStorage,
relationships_vdb: BaseVectorStorage,
text_chunks_db: BaseKVStorage[TextChunkSchema],
query_param: QueryParam,
global_config: dict,
) -> str:
use_model_func = global_config["llm_model_func"]
embedding_func = global_config["embedding_func"]
# プロンプトからキーワードを抽出 (これはlocal_queryと同じ)
kw_prompt_temp = PROMPTS["keywords_extraction"]
kw_prompt = kw_prompt_temp.format(query=query)
result = await use_model_func(kw_prompt)
try:
keywords_data = json.loads(result)
keywords = keywords_data.get("low_level_keywords", [])
if not keywords:
return PROMPTS["fail_response"]
except json.JSONDecodeError:
return PROMPTS["fail_response"]
# 各キーワードに対して処理を行う
all_entity_results = []
all_relationship_results = []
for keyword in keywords:
# キーワードに関連するコンテキストを会話履歴から生成
# 例) history: 「歴史面白いよな」「戦国時代の武将誰が好き?」, prompt「信長かな〜」
# -> 「信長は戦国時代の武将です」 (情報を付与せずhistory + promptからコンテキストを生成)
keyword_context = await generate_keyword_context_from_history(keyword, history_messages, use_model_func)
if not keyword_context:
keyword_context = keyword # コンテキストが生成されない場合はキーワード自体を使用
# コンテキストを埋め込みに変換
context_embedding = await embedding_func([keyword_context])
context_embedding = context_embedding[0]
# エンティティVDBを検索
entity_results = await entities_vdb.query_by_embedding(context_embedding, top_k=query_param.top_k)
all_entity_results.extend(entity_results)
# リレーションシップVDBを検索
relationship_results = await relationships_vdb.query_by_embedding(context_embedding, top_k=query_param.top_k)
all_relationship_results.extend(relationship_results)
unique_entity_results = {res["id"]: res for res in all_entity_results}.values()
unique_relationship_results = { (res["metadata"]["src_id"], res["metadata"]["tgt_id"]): res for res in all_relationship_results }.values()
# 回答用のコンテキストを作成
use_entities = await _find_most_related_entities_from_results(
unique_entity_results, knowledge_graph_inst, query_param
)
use_relationships = await _find_most_related_relationships_from_results(
unique_relationship_results, knowledge_graph_inst, query_param
)
use_text_units = await _find_related_text_units(
use_entities, use_relationships, text_chunks_db, query_param
)
context_data = await _build_keyword_context(
use_entities, use_relationships, use_text_units, query_param
)
# コンテキストを使用して最終回答を生成
if query_param.only_need_context:
return context_data
sys_prompt_temp = PROMPTS["rag_response"]
sys_prompt = sys_prompt_temp.format(
context_data=context_data, response_type=query_param.response_type
)
response = await use_model_func(
query,
system_prompt=sys_prompt,
)
return response.strip()
async def generate_keyword_context_from_history(keyword, history_messages, use_model_func):
context_prompt_temp = PROMPTS["keyword_context_from_history"]
history_text = "\n".join([f"{msg['role']}: {msg['content']}" for msg in history_messages])
context_prompt = context_prompt_temp.format(
keyword=keyword,
history=history_text
)
context = await use_model_func(context_prompt)
return context.strip()
async def _find_most_related_entities_from_results(results, knowledge_graph_inst, query_param):
entity_ids = [r["metadata"].get("entity_name") for r in results]
entities = await asyncio.gather(
*[knowledge_graph_inst.get_node(entity_id) for entity_id in entity_ids]
)
node_degrees = await asyncio.gather(
*[knowledge_graph_inst.node_degree(entity_id) for entity_id in entity_ids]
)
entities_data = [
{
"entity_name": entity_id,
"entity_type": entity.get("entity_type", "UNKNOWN"),
"description": entity.get("description", ""),
"rank": degree,
"source_id": entity.get("source_id", ""),
}
for entity_id, entity, degree in zip(entity_ids, entities, node_degrees)
if entity is not None
]
entities_data = truncate_list_by_token_size(
entities_data,
key=lambda x: x["description"],
max_token_size=query_param.max_token_for_local_context,
)
return entities_data
async def _find_most_related_relationships_from_results(results, knowledge_graph_inst, query_param):
relationship_ids = [(r["metadata"]["src_id"], r["metadata"]["tgt_id"]) for r in results]
relationships = await asyncio.gather(
*[knowledge_graph_inst.get_edge(src_id, tgt_id) for src_id, tgt_id in relationship_ids]
)
edge_degrees = await asyncio.gather(
*[knowledge_graph_inst.edge_degree(src_id, tgt_id) for src_id, tgt_id in relationship_ids]
)
relationships_data = [
{
"src_id": src_id,
"tgt_id": tgt_id,
"description": edge.get("description", ""),
"keywords": edge.get("keywords", ""),
"weight": edge.get("weight", 1),
"rank": degree,
"source_id": edge.get("source_id", ""),
}
for (src_id, tgt_id), edge, degree in zip(relationship_ids, relationships, edge_degrees)
if edge is not None
]
relationships_data = truncate_list_by_token_size(
relationships_data,
key=lambda x: x["description"],
max_token_size=query_param.max_token_for_local_context,
)
return relationships_data
async def _find_related_text_units(use_entities, use_relationships, text_chunks_db, query_param):
text_unit_ids = set()
for entity in use_entities:
source_ids = entity.get("source_id", "")
text_unit_ids.update(source_ids.split(GRAPH_FIELD_SEP))
for relationship in use_relationships:
source_ids = relationship.get("source_id", "")
text_unit_ids.update(source_ids.split(GRAPH_FIELD_SEP))
text_unit_ids = list(text_unit_ids)
text_units = await text_chunks_db.get_by_ids(text_unit_ids)
text_units = [t for t in text_units if t is not None]
text_units = truncate_list_by_token_size(
text_units,
key=lambda x: x["content"],
max_token_size=query_param.max_token_for_text_unit,
)
return text_units
async def _build_keyword_context(use_entities, use_relationships, use_text_units, query_param):
entities_section_list = [["id", "entity", "type", "description", "rank"]]
for i, n in enumerate(use_entities):
entities_section_list.append(
[
i,
n["entity_name"],
n.get("entity_type", "UNKNOWN"),
n.get("description", "UNKNOWN"),
n["rank"],
]
)
entities_context = list_of_list_to_csv(entities_section_list)
relations_section_list = [
["id", "src_entity", "tgt_entity", "description", "keywords", "weight", "rank"]
]
for i, e in enumerate(use_relationships):
relations_section_list.append(
[
i,
e["src_id"],
e["tgt_id"],
e["description"],
e["keywords"],
e["weight"],
e["rank"],
]
)
relations_context = list_of_list_to_csv(relations_section_list)
text_units_section_list = [["id", "content"]]
for i, t in enumerate(use_text_units):
text_units_section_list.append([i, t["content"]])
text_units_context = list_of_list_to_csv(text_units_section_list)
return f"""
-----Entities-----
```csv
{entities_context}
```
-----Relationships-----
```csv
{relations_context}
```
-----Sources-----
```csv
{text_units_context}
```
""".strip()

View file

@ -254,3 +254,41 @@ Do not include information where the supporting evidence for it is not provided.
---Target response length and format--- ---Target response length and format---
{response_type} {response_type}
""" """
PROMPTS["keywords_extraction_with_history"] = """---Role---
You are a helpful assistant tasked with identifying both high-level and low-level keywords from the conversation history.
---Goal---
Given the conversation history, list both high-level and low-level keywords that are relevant to the user's query.
---Instructions---
- Output the keywords in JSON format.
- The JSON should have two keys:
- "high_level_keywords" for overarching concepts or themes.
- "low_level_keywords" for specific entities or details.
---Conversation History---
{history}
"""
PROMPTS["context_generation_from_history"] = """---Role---
You are a helpful assistant tasked with generating context for the user's query based on the conversation history and the provided keywords.
---Goal---
Using the conversation history and the keywords, summarize the relevant information that the assistant knows so far about the keywords.
---Instructions---
- Provide a concise summary that connects the keywords with the context from the conversation history.
- Do not include irrelevant information.
---Keywords---
{keywords}
---Conversation History---
{history}
"""