diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py
index 9c34a607..2c2b39cd 100644
--- a/lightrag/lightrag.py
+++ b/lightrag/lightrag.py
@@ -13,6 +13,7 @@ from .operate import (
global_query,
hybird_query,
naive_query,
+ keyword_context_query,
)
from .storage import (
@@ -78,7 +79,7 @@ class LightRAG:
# text embedding
# embedding_func: EmbeddingFunc = field(default_factory=lambda:hf_embedding)
- embedding_func: EmbeddingFunc = field(default_factory=lambda:openai_embedding)#
+ embedding_func: EmbeddingFunc = field(default_factory=lambda:openai_embedding)#
embedding_batch_num: int = 32
embedding_func_max_async: int = 16
@@ -99,11 +100,11 @@ class LightRAG:
addon_params: dict = field(default_factory=dict)
convert_response_to_json_func: callable = convert_response_to_json
- def __post_init__(self):
+ def __post_init__(self):
log_file = os.path.join(self.working_dir, "lightrag.log")
set_logger(log_file)
logger.info(f"Logger initialized for working directory: {self.working_dir}")
-
+
_print_config = ",\n ".join([f"{k} = {v}" for k, v in asdict(self).items()])
logger.debug(f"LightRAG init with param:\n {_print_config}\n")
@@ -155,7 +156,7 @@ class LightRAG:
embedding_func=self.embedding_func,
)
)
-
+
self.llm_model_func = limit_async_func_call(self.llm_model_max_async)(
partial(self.llm_model_func, hashing_kv=self.llm_response_cache)
)
@@ -242,11 +243,11 @@ class LightRAG:
tasks.append(cast(StorageNameSpace, storage_inst).index_done_callback())
await asyncio.gather(*tasks)
- def query(self, query: str, param: QueryParam = QueryParam()):
+ def query(self, query: str, param: QueryParam = QueryParam(), history_messages=[]):
loop = always_get_an_event_loop()
- return loop.run_until_complete(self.aquery(query, param))
-
- async def aquery(self, query: str, param: QueryParam = QueryParam()):
+ return loop.run_until_complete(self.aquery(query, param, history_messages))
+
+ async def aquery(self, query: str, param: QueryParam = QueryParam(), history_messages=[]):
if param.mode == "local":
response = await local_query(
query,
@@ -285,11 +286,22 @@ class LightRAG:
param,
asdict(self),
)
+ elif param.mode == "keyword_context":
+ response = await keyword_context_query(
+ query,
+ history_messages,
+ self.chunk_entity_relation_graph,
+ self.entities_vdb,
+ self.relationships_vdb,
+ self.text_chunks,
+ param,
+ asdict(self),
+ )
else:
raise ValueError(f"Unknown mode {param.mode}")
await self._query_done()
return response
-
+
async def _query_done(self):
tasks = []
@@ -298,5 +310,3 @@ class LightRAG:
continue
tasks.append(cast(StorageNameSpace, storage_inst).index_done_callback())
await asyncio.gather(*tasks)
-
-
diff --git a/lightrag/operate.py b/lightrag/operate.py
index a8213a37..cbb9fdd2 100644
--- a/lightrag/operate.py
+++ b/lightrag/operate.py
@@ -229,7 +229,7 @@ async def _merge_edges_then_upsert(
description=description,
keywords=keywords,
)
-
+
return edge_data
async def extract_entities(
@@ -392,7 +392,7 @@ async def local_query(
kw_prompt_temp = PROMPTS["keywords_extraction"]
kw_prompt = kw_prompt_temp.format(query=query)
result = await use_model_func(kw_prompt)
-
+
try:
keywords_data = json.loads(result)
keywords = keywords_data.get("low_level_keywords", [])
@@ -428,7 +428,7 @@ async def local_query(
)
if len(response)>len(sys_prompt):
response = response.replace(sys_prompt,'').replace('user','').replace('model','').replace(query,'').replace('','').replace('','').strip()
-
+
return response
async def _build_local_query_context(
@@ -619,7 +619,7 @@ async def global_query(
kw_prompt_temp = PROMPTS["keywords_extraction"]
kw_prompt = kw_prompt_temp.format(query=query)
result = await use_model_func(kw_prompt)
-
+
try:
keywords_data = json.loads(result)
keywords = keywords_data.get("high_level_keywords", [])
@@ -630,7 +630,7 @@ async def global_query(
keywords_data = json.loads(result)
keywords = keywords_data.get("high_level_keywords", [])
keywords = ', '.join(keywords)
-
+
except json.JSONDecodeError as e:
# Handle parsing error
print(f"JSON parsing error: {e}")
@@ -644,12 +644,12 @@ async def global_query(
text_chunks_db,
query_param,
)
-
+
if query_param.only_need_context:
return context
if context is None:
return PROMPTS["fail_response"]
-
+
sys_prompt_temp = PROMPTS["rag_response"]
sys_prompt = sys_prompt_temp.format(
context_data=context, response_type=query_param.response_type
@@ -660,7 +660,7 @@ async def global_query(
)
if len(response)>len(sys_prompt):
response = response.replace(sys_prompt,'').replace('user','').replace('model','').replace(query,'').replace('','').replace('','').strip()
-
+
return response
async def _build_global_query_context(
@@ -672,14 +672,14 @@ async def _build_global_query_context(
query_param: QueryParam,
):
results = await relationships_vdb.query(keywords, top_k=query_param.top_k)
-
+
if not len(results):
return None
-
+
edge_datas = await asyncio.gather(
*[knowledge_graph_inst.get_edge(r["src_id"], r["tgt_id"]) for r in results]
)
-
+
if not all([n is not None for n in edge_datas]):
logger.warning("Some edges are missing, maybe the storage is damaged")
edge_degree = await asyncio.gather(
@@ -767,7 +767,7 @@ async def _find_most_related_entities_from_relationships(
for e in edge_datas:
entity_names.add(e["src_id"])
entity_names.add(e["tgt_id"])
-
+
node_datas = await asyncio.gather(
*[knowledge_graph_inst.get_node(entity_name) for entity_name in entity_names]
)
@@ -794,7 +794,7 @@ async def _find_related_text_unit_from_relationships(
text_chunks_db: BaseKVStorage[TextChunkSchema],
knowledge_graph_inst: BaseGraphStorage,
):
-
+
text_units = [
split_string_by_multi_markers(dp["source_id"], [GRAPH_FIELD_SEP])
for dp in edge_datas
@@ -809,7 +809,7 @@ async def _find_related_text_unit_from_relationships(
"data": await text_chunks_db.get_by_id(c_id),
"order": index,
}
-
+
if any([v is None for v in all_text_units_lookup.values()]):
logger.warning("Text chunks are missing, maybe the storage is damaged")
all_text_units = [
@@ -840,7 +840,7 @@ async def hybird_query(
kw_prompt_temp = PROMPTS["keywords_extraction"]
kw_prompt = kw_prompt_temp.format(query=query)
-
+
result = await use_model_func(kw_prompt)
try:
keywords_data = json.loads(result)
@@ -884,7 +884,7 @@ async def hybird_query(
return context
if context is None:
return PROMPTS["fail_response"]
-
+
sys_prompt_temp = PROMPTS["rag_response"]
sys_prompt = sys_prompt_temp.format(
context_data=context, response_type=query_param.response_type
@@ -904,17 +904,17 @@ def combine_contexts(high_level_context, low_level_context):
entities_match = re.search(r'-----Entities-----\s*```csv\s*(.*?)\s*```', context, re.DOTALL)
relationships_match = re.search(r'-----Relationships-----\s*```csv\s*(.*?)\s*```', context, re.DOTALL)
sources_match = re.search(r'-----Sources-----\s*```csv\s*(.*?)\s*```', context, re.DOTALL)
-
+
entities = entities_match.group(1) if entities_match else ''
relationships = relationships_match.group(1) if relationships_match else ''
sources = sources_match.group(1) if sources_match else ''
-
+
return entities, relationships, sources
-
+
# Extract sections from both contexts
if high_level_context==None:
- warnings.warn("High Level context is None. Return empty High entity/relationship/source")
+ warnings.warn("High Level context is None. Return empty High entity/relationship/source")
hl_entities, hl_relationships, hl_sources = '','',''
else:
hl_entities, hl_relationships, hl_sources = extract_sections(high_level_context)
@@ -927,19 +927,19 @@ def combine_contexts(high_level_context, low_level_context):
ll_entities, ll_relationships, ll_sources = extract_sections(low_level_context)
-
+
# Combine and deduplicate the entities
combined_entities_set = set(filter(None, hl_entities.strip().split('\n') + ll_entities.strip().split('\n')))
combined_entities = '\n'.join(combined_entities_set)
-
+
# Combine and deduplicate the relationships
combined_relationships_set = set(filter(None, hl_relationships.strip().split('\n') + ll_relationships.strip().split('\n')))
combined_relationships = '\n'.join(combined_relationships_set)
-
+
# Combine and deduplicate the sources
combined_sources_set = set(filter(None, hl_sources.strip().split('\n') + ll_sources.strip().split('\n')))
combined_sources = '\n'.join(combined_sources_set)
-
+
# Format the combined context
return f"""
-----Entities-----
@@ -985,6 +985,218 @@ async def naive_query(
if len(response)>len(sys_prompt):
response = response[len(sys_prompt):].replace(sys_prompt,'').replace('user','').replace('model','').replace(query,'').replace('','').replace('','').strip()
-
+
return response
+
+async def keyword_context_query(
+ query,
+ history_messages,
+ knowledge_graph_inst: BaseGraphStorage,
+ entities_vdb: BaseVectorStorage,
+ relationships_vdb: BaseVectorStorage,
+ text_chunks_db: BaseKVStorage[TextChunkSchema],
+ query_param: QueryParam,
+ global_config: dict,
+) -> str:
+ use_model_func = global_config["llm_model_func"]
+ embedding_func = global_config["embedding_func"]
+
+ # プロンプトからキーワードを抽出 (これはlocal_queryと同じ)
+ kw_prompt_temp = PROMPTS["keywords_extraction"]
+ kw_prompt = kw_prompt_temp.format(query=query)
+ result = await use_model_func(kw_prompt)
+ try:
+ keywords_data = json.loads(result)
+ keywords = keywords_data.get("low_level_keywords", [])
+ if not keywords:
+ return PROMPTS["fail_response"]
+ except json.JSONDecodeError:
+ return PROMPTS["fail_response"]
+
+ # 各キーワードに対して処理を行う
+ all_entity_results = []
+ all_relationship_results = []
+ for keyword in keywords:
+ # キーワードに関連するコンテキストを会話履歴から生成
+ # 例) history: 「歴史面白いよな」「戦国時代の武将誰が好き?」, prompt「信長かな〜」
+ # -> 「信長は戦国時代の武将です」 (情報を付与せずhistory + promptからコンテキストを生成)
+ keyword_context = await generate_keyword_context_from_history(keyword, history_messages, use_model_func)
+ if not keyword_context:
+ keyword_context = keyword # コンテキストが生成されない場合はキーワード自体を使用
+
+ # コンテキストを埋め込みに変換
+ context_embedding = await embedding_func([keyword_context])
+ context_embedding = context_embedding[0]
+
+ # エンティティVDBを検索
+ entity_results = await entities_vdb.query_by_embedding(context_embedding, top_k=query_param.top_k)
+ all_entity_results.extend(entity_results)
+
+ # リレーションシップVDBを検索
+ relationship_results = await relationships_vdb.query_by_embedding(context_embedding, top_k=query_param.top_k)
+ all_relationship_results.extend(relationship_results)
+
+ unique_entity_results = {res["id"]: res for res in all_entity_results}.values()
+ unique_relationship_results = { (res["metadata"]["src_id"], res["metadata"]["tgt_id"]): res for res in all_relationship_results }.values()
+
+ # 回答用のコンテキストを作成
+ use_entities = await _find_most_related_entities_from_results(
+ unique_entity_results, knowledge_graph_inst, query_param
+ )
+ use_relationships = await _find_most_related_relationships_from_results(
+ unique_relationship_results, knowledge_graph_inst, query_param
+ )
+ use_text_units = await _find_related_text_units(
+ use_entities, use_relationships, text_chunks_db, query_param
+ )
+
+ context_data = await _build_keyword_context(
+ use_entities, use_relationships, use_text_units, query_param
+ )
+
+ # コンテキストを使用して最終回答を生成
+ if query_param.only_need_context:
+ return context_data
+
+ sys_prompt_temp = PROMPTS["rag_response"]
+ sys_prompt = sys_prompt_temp.format(
+ context_data=context_data, response_type=query_param.response_type
+ )
+ response = await use_model_func(
+ query,
+ system_prompt=sys_prompt,
+ )
+ return response.strip()
+
+async def generate_keyword_context_from_history(keyword, history_messages, use_model_func):
+ context_prompt_temp = PROMPTS["keyword_context_from_history"]
+ history_text = "\n".join([f"{msg['role']}: {msg['content']}" for msg in history_messages])
+ context_prompt = context_prompt_temp.format(
+ keyword=keyword,
+ history=history_text
+ )
+ context = await use_model_func(context_prompt)
+ return context.strip()
+
+async def _find_most_related_entities_from_results(results, knowledge_graph_inst, query_param):
+ entity_ids = [r["metadata"].get("entity_name") for r in results]
+ entities = await asyncio.gather(
+ *[knowledge_graph_inst.get_node(entity_id) for entity_id in entity_ids]
+ )
+ node_degrees = await asyncio.gather(
+ *[knowledge_graph_inst.node_degree(entity_id) for entity_id in entity_ids]
+ )
+ entities_data = [
+ {
+ "entity_name": entity_id,
+ "entity_type": entity.get("entity_type", "UNKNOWN"),
+ "description": entity.get("description", ""),
+ "rank": degree,
+ "source_id": entity.get("source_id", ""),
+ }
+ for entity_id, entity, degree in zip(entity_ids, entities, node_degrees)
+ if entity is not None
+ ]
+ entities_data = truncate_list_by_token_size(
+ entities_data,
+ key=lambda x: x["description"],
+ max_token_size=query_param.max_token_for_local_context,
+ )
+ return entities_data
+
+async def _find_most_related_relationships_from_results(results, knowledge_graph_inst, query_param):
+ relationship_ids = [(r["metadata"]["src_id"], r["metadata"]["tgt_id"]) for r in results]
+ relationships = await asyncio.gather(
+ *[knowledge_graph_inst.get_edge(src_id, tgt_id) for src_id, tgt_id in relationship_ids]
+ )
+ edge_degrees = await asyncio.gather(
+ *[knowledge_graph_inst.edge_degree(src_id, tgt_id) for src_id, tgt_id in relationship_ids]
+ )
+ relationships_data = [
+ {
+ "src_id": src_id,
+ "tgt_id": tgt_id,
+ "description": edge.get("description", ""),
+ "keywords": edge.get("keywords", ""),
+ "weight": edge.get("weight", 1),
+ "rank": degree,
+ "source_id": edge.get("source_id", ""),
+ }
+ for (src_id, tgt_id), edge, degree in zip(relationship_ids, relationships, edge_degrees)
+ if edge is not None
+ ]
+ relationships_data = truncate_list_by_token_size(
+ relationships_data,
+ key=lambda x: x["description"],
+ max_token_size=query_param.max_token_for_local_context,
+ )
+ return relationships_data
+
+async def _find_related_text_units(use_entities, use_relationships, text_chunks_db, query_param):
+ text_unit_ids = set()
+ for entity in use_entities:
+ source_ids = entity.get("source_id", "")
+ text_unit_ids.update(source_ids.split(GRAPH_FIELD_SEP))
+ for relationship in use_relationships:
+ source_ids = relationship.get("source_id", "")
+ text_unit_ids.update(source_ids.split(GRAPH_FIELD_SEP))
+ text_unit_ids = list(text_unit_ids)
+ text_units = await text_chunks_db.get_by_ids(text_unit_ids)
+ text_units = [t for t in text_units if t is not None]
+ text_units = truncate_list_by_token_size(
+ text_units,
+ key=lambda x: x["content"],
+ max_token_size=query_param.max_token_for_text_unit,
+ )
+ return text_units
+
+async def _build_keyword_context(use_entities, use_relationships, use_text_units, query_param):
+ entities_section_list = [["id", "entity", "type", "description", "rank"]]
+ for i, n in enumerate(use_entities):
+ entities_section_list.append(
+ [
+ i,
+ n["entity_name"],
+ n.get("entity_type", "UNKNOWN"),
+ n.get("description", "UNKNOWN"),
+ n["rank"],
+ ]
+ )
+ entities_context = list_of_list_to_csv(entities_section_list)
+ relations_section_list = [
+ ["id", "src_entity", "tgt_entity", "description", "keywords", "weight", "rank"]
+ ]
+ for i, e in enumerate(use_relationships):
+ relations_section_list.append(
+ [
+ i,
+ e["src_id"],
+ e["tgt_id"],
+ e["description"],
+ e["keywords"],
+ e["weight"],
+ e["rank"],
+ ]
+ )
+ relations_context = list_of_list_to_csv(relations_section_list)
+
+ text_units_section_list = [["id", "content"]]
+ for i, t in enumerate(use_text_units):
+ text_units_section_list.append([i, t["content"]])
+ text_units_context = list_of_list_to_csv(text_units_section_list)
+
+ return f"""
+-----Entities-----
+```csv
+{entities_context}
+```
+-----Relationships-----
+```csv
+{relations_context}
+```
+-----Sources-----
+```csv
+{text_units_context}
+```
+""".strip()
\ No newline at end of file
diff --git a/lightrag/prompt.py b/lightrag/prompt.py
index 5d28e49c..4e004be6 100644
--- a/lightrag/prompt.py
+++ b/lightrag/prompt.py
@@ -32,7 +32,7 @@ Format each relationship as ("relationship"{tuple_delimiter}{tupl
3. Identify high-level key words that summarize the main concepts, themes, or topics of the entire text. These should capture the overarching ideas present in the document.
Format the content-level key words as ("content_keywords"{tuple_delimiter})
-
+
4. Return output in English as a single list of all the entities and relationships identified in steps 1 and 2. Use **{record_delimiter}** as the list delimiter.
5. When finished, output {completion_delimiter}
@@ -254,3 +254,41 @@ Do not include information where the supporting evidence for it is not provided.
---Target response length and format---
{response_type}
"""
+PROMPTS["keywords_extraction_with_history"] = """---Role---
+
+You are a helpful assistant tasked with identifying both high-level and low-level keywords from the conversation history.
+
+---Goal---
+
+Given the conversation history, list both high-level and low-level keywords that are relevant to the user's query.
+
+---Instructions---
+
+- Output the keywords in JSON format.
+- The JSON should have two keys:
+ - "high_level_keywords" for overarching concepts or themes.
+ - "low_level_keywords" for specific entities or details.
+
+---Conversation History---
+{history}
+"""
+
+PROMPTS["context_generation_from_history"] = """---Role---
+
+You are a helpful assistant tasked with generating context for the user's query based on the conversation history and the provided keywords.
+
+---Goal---
+
+Using the conversation history and the keywords, summarize the relevant information that the assistant knows so far about the keywords.
+
+---Instructions---
+
+- Provide a concise summary that connects the keywords with the context from the conversation history.
+- Do not include irrelevant information.
+
+---Keywords---
+{keywords}
+
+---Conversation History---
+{history}
+"""
\ No newline at end of file