Implement keyword context mode

This commit is contained in:
sakasegawa 2024-10-15 02:29:44 +09:00
parent af997c02c2
commit 035920a562
3 changed files with 297 additions and 37 deletions

View file

@ -13,6 +13,7 @@ from .operate import (
global_query, global_query,
hybird_query, hybird_query,
naive_query, naive_query,
keyword_context_query,
) )
from .storage import ( from .storage import (
@ -78,7 +79,7 @@ class LightRAG:
# text embedding # text embedding
# embedding_func: EmbeddingFunc = field(default_factory=lambda:hf_embedding) # embedding_func: EmbeddingFunc = field(default_factory=lambda:hf_embedding)
embedding_func: EmbeddingFunc = field(default_factory=lambda:openai_embedding)# embedding_func: EmbeddingFunc = field(default_factory=lambda:openai_embedding)#
embedding_batch_num: int = 32 embedding_batch_num: int = 32
embedding_func_max_async: int = 16 embedding_func_max_async: int = 16
@ -99,11 +100,11 @@ class LightRAG:
addon_params: dict = field(default_factory=dict) addon_params: dict = field(default_factory=dict)
convert_response_to_json_func: callable = convert_response_to_json convert_response_to_json_func: callable = convert_response_to_json
def __post_init__(self): def __post_init__(self):
log_file = os.path.join(self.working_dir, "lightrag.log") log_file = os.path.join(self.working_dir, "lightrag.log")
set_logger(log_file) set_logger(log_file)
logger.info(f"Logger initialized for working directory: {self.working_dir}") logger.info(f"Logger initialized for working directory: {self.working_dir}")
_print_config = ",\n ".join([f"{k} = {v}" for k, v in asdict(self).items()]) _print_config = ",\n ".join([f"{k} = {v}" for k, v in asdict(self).items()])
logger.debug(f"LightRAG init with param:\n {_print_config}\n") logger.debug(f"LightRAG init with param:\n {_print_config}\n")
@ -155,7 +156,7 @@ class LightRAG:
embedding_func=self.embedding_func, embedding_func=self.embedding_func,
) )
) )
self.llm_model_func = limit_async_func_call(self.llm_model_max_async)( self.llm_model_func = limit_async_func_call(self.llm_model_max_async)(
partial(self.llm_model_func, hashing_kv=self.llm_response_cache) partial(self.llm_model_func, hashing_kv=self.llm_response_cache)
) )
@ -242,11 +243,11 @@ class LightRAG:
tasks.append(cast(StorageNameSpace, storage_inst).index_done_callback()) tasks.append(cast(StorageNameSpace, storage_inst).index_done_callback())
await asyncio.gather(*tasks) await asyncio.gather(*tasks)
def query(self, query: str, param: QueryParam = QueryParam()): def query(self, query: str, param: QueryParam = QueryParam(), history_messages=[]):
loop = always_get_an_event_loop() loop = always_get_an_event_loop()
return loop.run_until_complete(self.aquery(query, param)) return loop.run_until_complete(self.aquery(query, param, history_messages))
async def aquery(self, query: str, param: QueryParam = QueryParam()): async def aquery(self, query: str, param: QueryParam = QueryParam(), history_messages=[]):
if param.mode == "local": if param.mode == "local":
response = await local_query( response = await local_query(
query, query,
@ -285,11 +286,22 @@ class LightRAG:
param, param,
asdict(self), asdict(self),
) )
elif param.mode == "keyword_context":
response = await keyword_context_query(
query,
history_messages,
self.chunk_entity_relation_graph,
self.entities_vdb,
self.relationships_vdb,
self.text_chunks,
param,
asdict(self),
)
else: else:
raise ValueError(f"Unknown mode {param.mode}") raise ValueError(f"Unknown mode {param.mode}")
await self._query_done() await self._query_done()
return response return response
async def _query_done(self): async def _query_done(self):
tasks = [] tasks = []
@ -298,5 +310,3 @@ class LightRAG:
continue continue
tasks.append(cast(StorageNameSpace, storage_inst).index_done_callback()) tasks.append(cast(StorageNameSpace, storage_inst).index_done_callback())
await asyncio.gather(*tasks) await asyncio.gather(*tasks)

View file

@ -229,7 +229,7 @@ async def _merge_edges_then_upsert(
description=description, description=description,
keywords=keywords, keywords=keywords,
) )
return edge_data return edge_data
async def extract_entities( async def extract_entities(
@ -392,7 +392,7 @@ async def local_query(
kw_prompt_temp = PROMPTS["keywords_extraction"] kw_prompt_temp = PROMPTS["keywords_extraction"]
kw_prompt = kw_prompt_temp.format(query=query) kw_prompt = kw_prompt_temp.format(query=query)
result = await use_model_func(kw_prompt) result = await use_model_func(kw_prompt)
try: try:
keywords_data = json.loads(result) keywords_data = json.loads(result)
keywords = keywords_data.get("low_level_keywords", []) keywords = keywords_data.get("low_level_keywords", [])
@ -428,7 +428,7 @@ async def local_query(
) )
if len(response)>len(sys_prompt): if len(response)>len(sys_prompt):
response = response.replace(sys_prompt,'').replace('user','').replace('model','').replace(query,'').replace('<system>','').replace('</system>','').strip() response = response.replace(sys_prompt,'').replace('user','').replace('model','').replace(query,'').replace('<system>','').replace('</system>','').strip()
return response return response
async def _build_local_query_context( async def _build_local_query_context(
@ -619,7 +619,7 @@ async def global_query(
kw_prompt_temp = PROMPTS["keywords_extraction"] kw_prompt_temp = PROMPTS["keywords_extraction"]
kw_prompt = kw_prompt_temp.format(query=query) kw_prompt = kw_prompt_temp.format(query=query)
result = await use_model_func(kw_prompt) result = await use_model_func(kw_prompt)
try: try:
keywords_data = json.loads(result) keywords_data = json.loads(result)
keywords = keywords_data.get("high_level_keywords", []) keywords = keywords_data.get("high_level_keywords", [])
@ -630,7 +630,7 @@ async def global_query(
keywords_data = json.loads(result) keywords_data = json.loads(result)
keywords = keywords_data.get("high_level_keywords", []) keywords = keywords_data.get("high_level_keywords", [])
keywords = ', '.join(keywords) keywords = ', '.join(keywords)
except json.JSONDecodeError as e: except json.JSONDecodeError as e:
# Handle parsing error # Handle parsing error
print(f"JSON parsing error: {e}") print(f"JSON parsing error: {e}")
@ -644,12 +644,12 @@ async def global_query(
text_chunks_db, text_chunks_db,
query_param, query_param,
) )
if query_param.only_need_context: if query_param.only_need_context:
return context return context
if context is None: if context is None:
return PROMPTS["fail_response"] return PROMPTS["fail_response"]
sys_prompt_temp = PROMPTS["rag_response"] sys_prompt_temp = PROMPTS["rag_response"]
sys_prompt = sys_prompt_temp.format( sys_prompt = sys_prompt_temp.format(
context_data=context, response_type=query_param.response_type context_data=context, response_type=query_param.response_type
@ -660,7 +660,7 @@ async def global_query(
) )
if len(response)>len(sys_prompt): if len(response)>len(sys_prompt):
response = response.replace(sys_prompt,'').replace('user','').replace('model','').replace(query,'').replace('<system>','').replace('</system>','').strip() response = response.replace(sys_prompt,'').replace('user','').replace('model','').replace(query,'').replace('<system>','').replace('</system>','').strip()
return response return response
async def _build_global_query_context( async def _build_global_query_context(
@ -672,14 +672,14 @@ async def _build_global_query_context(
query_param: QueryParam, query_param: QueryParam,
): ):
results = await relationships_vdb.query(keywords, top_k=query_param.top_k) results = await relationships_vdb.query(keywords, top_k=query_param.top_k)
if not len(results): if not len(results):
return None return None
edge_datas = await asyncio.gather( edge_datas = await asyncio.gather(
*[knowledge_graph_inst.get_edge(r["src_id"], r["tgt_id"]) for r in results] *[knowledge_graph_inst.get_edge(r["src_id"], r["tgt_id"]) for r in results]
) )
if not all([n is not None for n in edge_datas]): if not all([n is not None for n in edge_datas]):
logger.warning("Some edges are missing, maybe the storage is damaged") logger.warning("Some edges are missing, maybe the storage is damaged")
edge_degree = await asyncio.gather( edge_degree = await asyncio.gather(
@ -767,7 +767,7 @@ async def _find_most_related_entities_from_relationships(
for e in edge_datas: for e in edge_datas:
entity_names.add(e["src_id"]) entity_names.add(e["src_id"])
entity_names.add(e["tgt_id"]) entity_names.add(e["tgt_id"])
node_datas = await asyncio.gather( node_datas = await asyncio.gather(
*[knowledge_graph_inst.get_node(entity_name) for entity_name in entity_names] *[knowledge_graph_inst.get_node(entity_name) for entity_name in entity_names]
) )
@ -794,7 +794,7 @@ async def _find_related_text_unit_from_relationships(
text_chunks_db: BaseKVStorage[TextChunkSchema], text_chunks_db: BaseKVStorage[TextChunkSchema],
knowledge_graph_inst: BaseGraphStorage, knowledge_graph_inst: BaseGraphStorage,
): ):
text_units = [ text_units = [
split_string_by_multi_markers(dp["source_id"], [GRAPH_FIELD_SEP]) split_string_by_multi_markers(dp["source_id"], [GRAPH_FIELD_SEP])
for dp in edge_datas for dp in edge_datas
@ -809,7 +809,7 @@ async def _find_related_text_unit_from_relationships(
"data": await text_chunks_db.get_by_id(c_id), "data": await text_chunks_db.get_by_id(c_id),
"order": index, "order": index,
} }
if any([v is None for v in all_text_units_lookup.values()]): if any([v is None for v in all_text_units_lookup.values()]):
logger.warning("Text chunks are missing, maybe the storage is damaged") logger.warning("Text chunks are missing, maybe the storage is damaged")
all_text_units = [ all_text_units = [
@ -840,7 +840,7 @@ async def hybird_query(
kw_prompt_temp = PROMPTS["keywords_extraction"] kw_prompt_temp = PROMPTS["keywords_extraction"]
kw_prompt = kw_prompt_temp.format(query=query) kw_prompt = kw_prompt_temp.format(query=query)
result = await use_model_func(kw_prompt) result = await use_model_func(kw_prompt)
try: try:
keywords_data = json.loads(result) keywords_data = json.loads(result)
@ -884,7 +884,7 @@ async def hybird_query(
return context return context
if context is None: if context is None:
return PROMPTS["fail_response"] return PROMPTS["fail_response"]
sys_prompt_temp = PROMPTS["rag_response"] sys_prompt_temp = PROMPTS["rag_response"]
sys_prompt = sys_prompt_temp.format( sys_prompt = sys_prompt_temp.format(
context_data=context, response_type=query_param.response_type context_data=context, response_type=query_param.response_type
@ -904,17 +904,17 @@ def combine_contexts(high_level_context, low_level_context):
entities_match = re.search(r'-----Entities-----\s*```csv\s*(.*?)\s*```', context, re.DOTALL) entities_match = re.search(r'-----Entities-----\s*```csv\s*(.*?)\s*```', context, re.DOTALL)
relationships_match = re.search(r'-----Relationships-----\s*```csv\s*(.*?)\s*```', context, re.DOTALL) relationships_match = re.search(r'-----Relationships-----\s*```csv\s*(.*?)\s*```', context, re.DOTALL)
sources_match = re.search(r'-----Sources-----\s*```csv\s*(.*?)\s*```', context, re.DOTALL) sources_match = re.search(r'-----Sources-----\s*```csv\s*(.*?)\s*```', context, re.DOTALL)
entities = entities_match.group(1) if entities_match else '' entities = entities_match.group(1) if entities_match else ''
relationships = relationships_match.group(1) if relationships_match else '' relationships = relationships_match.group(1) if relationships_match else ''
sources = sources_match.group(1) if sources_match else '' sources = sources_match.group(1) if sources_match else ''
return entities, relationships, sources return entities, relationships, sources
# Extract sections from both contexts # Extract sections from both contexts
if high_level_context==None: if high_level_context==None:
warnings.warn("High Level context is None. Return empty High entity/relationship/source") warnings.warn("High Level context is None. Return empty High entity/relationship/source")
hl_entities, hl_relationships, hl_sources = '','','' hl_entities, hl_relationships, hl_sources = '','',''
else: else:
hl_entities, hl_relationships, hl_sources = extract_sections(high_level_context) hl_entities, hl_relationships, hl_sources = extract_sections(high_level_context)
@ -927,19 +927,19 @@ def combine_contexts(high_level_context, low_level_context):
ll_entities, ll_relationships, ll_sources = extract_sections(low_level_context) ll_entities, ll_relationships, ll_sources = extract_sections(low_level_context)
# Combine and deduplicate the entities # Combine and deduplicate the entities
combined_entities_set = set(filter(None, hl_entities.strip().split('\n') + ll_entities.strip().split('\n'))) combined_entities_set = set(filter(None, hl_entities.strip().split('\n') + ll_entities.strip().split('\n')))
combined_entities = '\n'.join(combined_entities_set) combined_entities = '\n'.join(combined_entities_set)
# Combine and deduplicate the relationships # Combine and deduplicate the relationships
combined_relationships_set = set(filter(None, hl_relationships.strip().split('\n') + ll_relationships.strip().split('\n'))) combined_relationships_set = set(filter(None, hl_relationships.strip().split('\n') + ll_relationships.strip().split('\n')))
combined_relationships = '\n'.join(combined_relationships_set) combined_relationships = '\n'.join(combined_relationships_set)
# Combine and deduplicate the sources # Combine and deduplicate the sources
combined_sources_set = set(filter(None, hl_sources.strip().split('\n') + ll_sources.strip().split('\n'))) combined_sources_set = set(filter(None, hl_sources.strip().split('\n') + ll_sources.strip().split('\n')))
combined_sources = '\n'.join(combined_sources_set) combined_sources = '\n'.join(combined_sources_set)
# Format the combined context # Format the combined context
return f""" return f"""
-----Entities----- -----Entities-----
@ -985,6 +985,218 @@ async def naive_query(
if len(response)>len(sys_prompt): if len(response)>len(sys_prompt):
response = response[len(sys_prompt):].replace(sys_prompt,'').replace('user','').replace('model','').replace(query,'').replace('<system>','').replace('</system>','').strip() response = response[len(sys_prompt):].replace(sys_prompt,'').replace('user','').replace('model','').replace(query,'').replace('<system>','').replace('</system>','').strip()
return response return response
async def keyword_context_query(
    query,
    history_messages,
    knowledge_graph_inst: BaseGraphStorage,
    entities_vdb: BaseVectorStorage,
    relationships_vdb: BaseVectorStorage,
    text_chunks_db: BaseKVStorage[TextChunkSchema],
    query_param: QueryParam,
    global_config: dict,
) -> str:
    """Answer *query* using conversation-history-aware keyword retrieval.

    For every low-level keyword extracted from the query, a short context
    sentence is generated from the chat history, embedded, and used to search
    both the entity and relationship vector stores. The merged, deduplicated
    hits are expanded into a CSV context block and fed to the LLM.

    Args:
        query: The user's current question.
        history_messages: Prior chat turns; assumed to be dicts with
            "role"/"content" keys (see generate_keyword_context_from_history).
        knowledge_graph_inst: Graph storage used to resolve hits to node/edge data.
        entities_vdb / relationships_vdb: Vector stores searched per keyword.
        text_chunks_db: KV store of source text chunks.
        query_param: Retrieval knobs (top_k, token budgets, only_need_context).
        global_config: Provides "llm_model_func" and "embedding_func".

    Returns:
        The LLM response string, the raw context (if only_need_context), or
        PROMPTS["fail_response"] when keyword extraction fails.
    """
    use_model_func = global_config["llm_model_func"]
    embedding_func = global_config["embedding_func"]
    # Extract keywords from the prompt (same approach as local_query).
    kw_prompt_temp = PROMPTS["keywords_extraction"]
    kw_prompt = kw_prompt_temp.format(query=query)
    result = await use_model_func(kw_prompt)
    try:
        keywords_data = json.loads(result)
        keywords = keywords_data.get("low_level_keywords", [])
        if not keywords:
            return PROMPTS["fail_response"]
    except json.JSONDecodeError:
        # LLM did not return valid JSON — give up rather than retry.
        return PROMPTS["fail_response"]
    # Process each keyword independently and pool the retrieval results.
    all_entity_results = []
    all_relationship_results = []
    for keyword in keywords:
        # Generate a context sentence for the keyword from the chat history.
        # e.g. history: "History is fun" / "Which Sengoku warlord do you like?",
        # prompt: "Nobunaga, I guess" -> "Nobunaga is a Sengoku-era warlord"
        # (summarizes history + prompt without adding outside information).
        keyword_context = await generate_keyword_context_from_history(keyword, history_messages, use_model_func)
        if not keyword_context:
            keyword_context = keyword  # fall back to the bare keyword when no context is produced
        # Embed the generated context.
        context_embedding = await embedding_func([keyword_context])
        context_embedding = context_embedding[0]
        # Search the entity vector store.
        # NOTE(review): query_by_embedding is assumed to exist on the vector
        # storage implementation — other query modes call .query(text, top_k);
        # confirm the storage backend provides an embedding-based query.
        entity_results = await entities_vdb.query_by_embedding(context_embedding, top_k=query_param.top_k)
        all_entity_results.extend(entity_results)
        # Search the relationship vector store.
        relationship_results = await relationships_vdb.query_by_embedding(context_embedding, top_k=query_param.top_k)
        all_relationship_results.extend(relationship_results)
    # Deduplicate pooled hits; assumes hits carry "id" and a "metadata" dict
    # with src_id/tgt_id — TODO confirm against the vdb result schema.
    unique_entity_results = {res["id"]: res for res in all_entity_results}.values()
    unique_relationship_results = { (res["metadata"]["src_id"], res["metadata"]["tgt_id"]): res for res in all_relationship_results }.values()
    # Build the answer context from the deduplicated hits.
    use_entities = await _find_most_related_entities_from_results(
        unique_entity_results, knowledge_graph_inst, query_param
    )
    use_relationships = await _find_most_related_relationships_from_results(
        unique_relationship_results, knowledge_graph_inst, query_param
    )
    use_text_units = await _find_related_text_units(
        use_entities, use_relationships, text_chunks_db, query_param
    )
    context_data = await _build_keyword_context(
        use_entities, use_relationships, use_text_units, query_param
    )
    # Generate the final answer from the assembled context.
    if query_param.only_need_context:
        return context_data
    sys_prompt_temp = PROMPTS["rag_response"]
    sys_prompt = sys_prompt_temp.format(
        context_data=context_data, response_type=query_param.response_type
    )
    response = await use_model_func(
        query,
        system_prompt=sys_prompt,
    )
    return response.strip()
async def generate_keyword_context_from_history(keyword, history_messages, use_model_func):
    """Summarize what the conversation history says about *keyword*.

    Renders the history as "role: content" lines, fills the
    context-generation prompt, and returns the LLM's stripped summary.

    Args:
        keyword: The keyword to contextualize.
        history_messages: List of dicts with "role" and "content" keys.
        use_model_func: Async LLM callable taking a prompt string.

    Returns:
        The summary string, or "" when there is no history to summarize
        (callers treat a falsy result as "use the bare keyword").
    """
    # Nothing to summarize without history — skip the LLM call entirely.
    if not history_messages:
        return ""
    # Bug fix: this previously looked up PROMPTS["keyword_context_from_history"],
    # a key that is never defined; the prompt shipped with this mode is
    # "context_generation_from_history" and its placeholders are
    # {keywords} and {history} (the old code also passed keyword=, which
    # would raise KeyError on .format).
    context_prompt_temp = PROMPTS["context_generation_from_history"]
    history_text = "\n".join([f"{msg['role']}: {msg['content']}" for msg in history_messages])
    context_prompt = context_prompt_temp.format(
        keywords=keyword,
        history=history_text,
    )
    context = await use_model_func(context_prompt)
    return context.strip()
async def _find_most_related_entities_from_results(results, knowledge_graph_inst, query_param):
    """Resolve vector-store hits into full entity records ranked by node degree.

    Hits whose entity no longer exists in the graph are dropped; the surviving
    records are truncated to the local-context token budget.
    """
    names = [hit["metadata"].get("entity_name") for hit in results]
    nodes = await asyncio.gather(
        *[knowledge_graph_inst.get_node(name) for name in names]
    )
    degrees = await asyncio.gather(
        *[knowledge_graph_inst.node_degree(name) for name in names]
    )
    records = []
    for name, node, degree in zip(names, nodes, degrees):
        if node is None:
            # Stale hit: present in the vector store but gone from the graph.
            continue
        records.append(
            {
                "entity_name": name,
                "entity_type": node.get("entity_type", "UNKNOWN"),
                "description": node.get("description", ""),
                "rank": degree,
                "source_id": node.get("source_id", ""),
            }
        )
    return truncate_list_by_token_size(
        records,
        key=lambda rec: rec["description"],
        max_token_size=query_param.max_token_for_local_context,
    )
async def _find_most_related_relationships_from_results(results, knowledge_graph_inst, query_param):
    """Resolve vector-store hits into full relationship records ranked by edge degree.

    Hits whose edge no longer exists in the graph are dropped; the surviving
    records are truncated to the local-context token budget.
    """
    pairs = [(hit["metadata"]["src_id"], hit["metadata"]["tgt_id"]) for hit in results]
    edges = await asyncio.gather(
        *[knowledge_graph_inst.get_edge(src, tgt) for src, tgt in pairs]
    )
    degrees = await asyncio.gather(
        *[knowledge_graph_inst.edge_degree(src, tgt) for src, tgt in pairs]
    )
    records = []
    for (src, tgt), edge, degree in zip(pairs, edges, degrees):
        if edge is None:
            # Stale hit: present in the vector store but gone from the graph.
            continue
        records.append(
            {
                "src_id": src,
                "tgt_id": tgt,
                "description": edge.get("description", ""),
                "keywords": edge.get("keywords", ""),
                "weight": edge.get("weight", 1),
                "rank": degree,
                "source_id": edge.get("source_id", ""),
            }
        )
    return truncate_list_by_token_size(
        records,
        key=lambda rec: rec["description"],
        max_token_size=query_param.max_token_for_local_context,
    )
async def _find_related_text_units(use_entities, use_relationships, text_chunks_db, query_param):
    """Collect the source text chunks referenced by the given entities and relationships.

    Chunk ids are gathered from each record's GRAPH_FIELD_SEP-joined
    "source_id" field, deduplicated, fetched from the KV store, and truncated
    to the text-unit token budget.

    Args:
        use_entities: Entity records carrying an optional "source_id".
        use_relationships: Relationship records carrying an optional "source_id".
        text_chunks_db: KV store mapping chunk id -> {"content": ...} dicts.
        query_param: Supplies max_token_for_text_unit.

    Returns:
        A token-budget-truncated list of chunk dicts (missing ids are skipped).
    """
    text_unit_ids = set()
    for record_list in (use_entities, use_relationships):
        for record in record_list:
            source_ids = record.get("source_id", "")
            # Bug fix: "".split(sep) yields [""], so records with an empty
            # source_id used to inject a blank chunk id into the lookup.
            text_unit_ids.update(cid for cid in source_ids.split(GRAPH_FIELD_SEP) if cid)
    # Sort for determinism: set iteration order is arbitrary, which previously
    # made the post-truncation chunk selection vary between runs.
    ordered_ids = sorted(text_unit_ids)
    text_units = await text_chunks_db.get_by_ids(ordered_ids)
    text_units = [t for t in text_units if t is not None]
    text_units = truncate_list_by_token_size(
        text_units,
        key=lambda x: x["content"],
        max_token_size=query_param.max_token_for_text_unit,
    )
    return text_units
async def _build_keyword_context(use_entities, use_relationships, use_text_units, query_param):
    """Render entities, relationships, and text chunks into the CSV context block.

    Produces the same -----Entities----- / -----Relationships----- /
    -----Sources----- layout used by the other query modes.
    """
    entity_rows = [["id", "entity", "type", "description", "rank"]]
    for idx, node in enumerate(use_entities):
        entity_rows.append(
            [
                idx,
                node["entity_name"],
                node.get("entity_type", "UNKNOWN"),
                node.get("description", "UNKNOWN"),
                node["rank"],
            ]
        )
    entities_context = list_of_list_to_csv(entity_rows)

    relation_rows = [
        ["id", "src_entity", "tgt_entity", "description", "keywords", "weight", "rank"]
    ]
    for idx, edge in enumerate(use_relationships):
        relation_rows.append(
            [
                idx,
                edge["src_id"],
                edge["tgt_id"],
                edge["description"],
                edge["keywords"],
                edge["weight"],
                edge["rank"],
            ]
        )
    relations_context = list_of_list_to_csv(relation_rows)

    chunk_rows = [["id", "content"]]
    for idx, chunk in enumerate(use_text_units):
        chunk_rows.append([idx, chunk["content"]])
    text_units_context = list_of_list_to_csv(chunk_rows)

    return f"""
-----Entities-----
```csv
{entities_context}
```
-----Relationships-----
```csv
{relations_context}
```
-----Sources-----
```csv
{text_units_context}
```
""".strip()

View file

@ -32,7 +32,7 @@ Format each relationship as ("relationship"{tuple_delimiter}<source_entity>{tupl
3. Identify high-level key words that summarize the main concepts, themes, or topics of the entire text. These should capture the overarching ideas present in the document. 3. Identify high-level key words that summarize the main concepts, themes, or topics of the entire text. These should capture the overarching ideas present in the document.
Format the content-level key words as ("content_keywords"{tuple_delimiter}<high_level_keywords>) Format the content-level key words as ("content_keywords"{tuple_delimiter}<high_level_keywords>)
4. Return output in English as a single list of all the entities and relationships identified in steps 1 and 2. Use **{record_delimiter}** as the list delimiter. 4. Return output in English as a single list of all the entities and relationships identified in steps 1 and 2. Use **{record_delimiter}** as the list delimiter.
5. When finished, output {completion_delimiter} 5. When finished, output {completion_delimiter}
@ -254,3 +254,41 @@ Do not include information where the supporting evidence for it is not provided.
---Target response length and format--- ---Target response length and format---
{response_type} {response_type}
""" """
# Prompt filled with {history}: extracts high/low-level keywords from an
# ongoing conversation rather than a single query.
# NOTE(review): defined here but not referenced by the visible operate.py
# changes — confirm it has (or will have) a caller.
PROMPTS["keywords_extraction_with_history"] = """---Role---
You are a helpful assistant tasked with identifying both high-level and low-level keywords from the conversation history.
---Goal---
Given the conversation history, list both high-level and low-level keywords that are relevant to the user's query.
---Instructions---
- Output the keywords in JSON format.
- The JSON should have two keys:
- "high_level_keywords" for overarching concepts or themes.
- "low_level_keywords" for specific entities or details.
---Conversation History---
{history}
"""

# Prompt filled with {keywords} and {history}: summarizes what the
# conversation so far establishes about the given keywords (used per keyword
# by the keyword_context query mode to build embedding queries).
PROMPTS["context_generation_from_history"] = """---Role---
You are a helpful assistant tasked with generating context for the user's query based on the conversation history and the provided keywords.
---Goal---
Using the conversation history and the keywords, summarize the relevant information that the assistant knows so far about the keywords.
---Instructions---
- Provide a concise summary that connects the keywords with the context from the conversation history.
- Do not include irrelevant information.
---Keywords---
{keywords}
---Conversation History---
{history}
"""