From 035920a56264ec0ee2c509bcdf6a50e4e54d5d37 Mon Sep 17 00:00:00 2001 From: sakasegawa Date: Tue, 15 Oct 2024 02:29:44 +0900 Subject: [PATCH] Implement keyword context mode --- lightrag/lightrag.py | 32 ++++-- lightrag/operate.py | 262 ++++++++++++++++++++++++++++++++++++++----- lightrag/prompt.py | 40 ++++++- 3 files changed, 297 insertions(+), 37 deletions(-) diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py index 9c34a607..2c2b39cd 100644 --- a/lightrag/lightrag.py +++ b/lightrag/lightrag.py @@ -13,6 +13,7 @@ from .operate import ( global_query, hybird_query, naive_query, + keyword_context_query, ) from .storage import ( @@ -78,7 +79,7 @@ class LightRAG: # text embedding # embedding_func: EmbeddingFunc = field(default_factory=lambda:hf_embedding) - embedding_func: EmbeddingFunc = field(default_factory=lambda:openai_embedding)# + embedding_func: EmbeddingFunc = field(default_factory=lambda:openai_embedding)# embedding_batch_num: int = 32 embedding_func_max_async: int = 16 @@ -99,11 +100,11 @@ class LightRAG: addon_params: dict = field(default_factory=dict) convert_response_to_json_func: callable = convert_response_to_json - def __post_init__(self): + def __post_init__(self): log_file = os.path.join(self.working_dir, "lightrag.log") set_logger(log_file) logger.info(f"Logger initialized for working directory: {self.working_dir}") - + _print_config = ",\n ".join([f"{k} = {v}" for k, v in asdict(self).items()]) logger.debug(f"LightRAG init with param:\n {_print_config}\n") @@ -155,7 +156,7 @@ class LightRAG: embedding_func=self.embedding_func, ) ) - + self.llm_model_func = limit_async_func_call(self.llm_model_max_async)( partial(self.llm_model_func, hashing_kv=self.llm_response_cache) ) @@ -242,11 +243,11 @@ class LightRAG: tasks.append(cast(StorageNameSpace, storage_inst).index_done_callback()) await asyncio.gather(*tasks) - def query(self, query: str, param: QueryParam = QueryParam()): + def query(self, query: str, param: QueryParam = QueryParam(), history_messages=[]): loop = always_get_an_event_loop() - return loop.run_until_complete(self.aquery(query, param)) - - async def aquery(self, query: str, param: QueryParam = QueryParam()): + return loop.run_until_complete(self.aquery(query, param, history_messages)) + + async def aquery(self, query: str, param: QueryParam = QueryParam(), history_messages=[]): if param.mode == "local": response = await local_query( query, @@ -285,11 +286,22 @@ class LightRAG: param, asdict(self), ) + elif param.mode == "keyword_context": + response = await keyword_context_query( + query, + history_messages, + self.chunk_entity_relation_graph, + self.entities_vdb, + self.relationships_vdb, + self.text_chunks, + param, + asdict(self), + ) else: raise ValueError(f"Unknown mode {param.mode}") await self._query_done() return response - + async def _query_done(self): tasks = [] @@ -298,5 +310,3 @@ class LightRAG: continue tasks.append(cast(StorageNameSpace, storage_inst).index_done_callback()) await asyncio.gather(*tasks) - - diff --git a/lightrag/operate.py b/lightrag/operate.py index a8213a37..cbb9fdd2 100644 --- a/lightrag/operate.py +++ b/lightrag/operate.py @@ -229,7 +229,7 @@ async def _merge_edges_then_upsert( description=description, keywords=keywords, ) - + return edge_data async def extract_entities( @@ -392,7 +392,7 @@ async def local_query( kw_prompt_temp = PROMPTS["keywords_extraction"] kw_prompt = kw_prompt_temp.format(query=query) result = await use_model_func(kw_prompt) - + try: keywords_data = json.loads(result) keywords = keywords_data.get("low_level_keywords", []) @@ -428,7 +428,7 @@ async def local_query( ) if len(response)>len(sys_prompt): response = response.replace(sys_prompt,'').replace('user','').replace('model','').replace(query,'').replace('','').replace('','').strip() - + return response async def _build_local_query_context( @@ -619,7 +619,7 @@ async def global_query( kw_prompt_temp = PROMPTS["keywords_extraction"] kw_prompt = kw_prompt_temp.format(query=query) result = await use_model_func(kw_prompt) - + try: keywords_data = json.loads(result) keywords = keywords_data.get("high_level_keywords", []) @@ -630,7 +630,7 @@ async def global_query( keywords_data = json.loads(result) keywords = keywords_data.get("high_level_keywords", []) keywords = ', '.join(keywords) - + except json.JSONDecodeError as e: # Handle parsing error print(f"JSON parsing error: {e}") @@ -644,12 +644,12 @@ async def global_query( text_chunks_db, query_param, ) - + if query_param.only_need_context: return context if context is None: return PROMPTS["fail_response"] - + sys_prompt_temp = PROMPTS["rag_response"] sys_prompt = sys_prompt_temp.format( context_data=context, response_type=query_param.response_type @@ -660,7 +660,7 @@ async def global_query( ) if len(response)>len(sys_prompt): response = response.replace(sys_prompt,'').replace('user','').replace('model','').replace(query,'').replace('','').replace('','').strip() - + return response async def _build_global_query_context( @@ -672,14 +672,14 @@ async def _build_global_query_context( query_param: QueryParam, ): results = await relationships_vdb.query(keywords, top_k=query_param.top_k) - + if not len(results): return None - + edge_datas = await asyncio.gather( *[knowledge_graph_inst.get_edge(r["src_id"], r["tgt_id"]) for r in results] ) - + if not all([n is not None for n in edge_datas]): logger.warning("Some edges are missing, maybe the storage is damaged") edge_degree = await asyncio.gather( @@ -767,7 +767,7 @@ async def _find_most_related_entities_from_relationships( for e in edge_datas: entity_names.add(e["src_id"]) entity_names.add(e["tgt_id"]) - + node_datas = await asyncio.gather( *[knowledge_graph_inst.get_node(entity_name) for entity_name in entity_names] ) @@ -794,7 +794,7 @@ async def _find_related_text_unit_from_relationships( text_chunks_db: BaseKVStorage[TextChunkSchema], knowledge_graph_inst: BaseGraphStorage, ): - + text_units = [ split_string_by_multi_markers(dp["source_id"], [GRAPH_FIELD_SEP]) for dp in edge_datas @@ -809,7 +809,7 @@ async def _find_related_text_unit_from_relationships( "data": await text_chunks_db.get_by_id(c_id), "order": index, } - + if any([v is None for v in all_text_units_lookup.values()]): logger.warning("Text chunks are missing, maybe the storage is damaged") all_text_units = [ @@ -840,7 +840,7 @@ async def hybird_query( kw_prompt_temp = PROMPTS["keywords_extraction"] kw_prompt = kw_prompt_temp.format(query=query) - + result = await use_model_func(kw_prompt) try: keywords_data = json.loads(result) @@ -884,7 +884,7 @@ async def hybird_query( return context if context is None: return PROMPTS["fail_response"] - + sys_prompt_temp = PROMPTS["rag_response"] sys_prompt = sys_prompt_temp.format( context_data=context, response_type=query_param.response_type @@ -904,17 +904,17 @@ def combine_contexts(high_level_context, low_level_context): entities_match = re.search(r'-----Entities-----\s*```csv\s*(.*?)\s*```', context, re.DOTALL) relationships_match = re.search(r'-----Relationships-----\s*```csv\s*(.*?)\s*```', context, re.DOTALL) sources_match = re.search(r'-----Sources-----\s*```csv\s*(.*?)\s*```', context, re.DOTALL) - + entities = entities_match.group(1) if entities_match else '' relationships = relationships_match.group(1) if relationships_match else '' sources = sources_match.group(1) if sources_match else '' - + return entities, relationships, sources - + # Extract sections from both contexts if high_level_context==None: - warnings.warn("High Level context is None. Return empty High entity/relationship/source") + warnings.warn("High Level context is None. Return empty High entity/relationship/source") hl_entities, hl_relationships, hl_sources = '','','' else: hl_entities, hl_relationships, hl_sources = extract_sections(high_level_context) @@ -927,19 +927,19 @@ def combine_contexts(high_level_context, low_level_context): ll_entities, ll_relationships, ll_sources = extract_sections(low_level_context) - + # Combine and deduplicate the entities combined_entities_set = set(filter(None, hl_entities.strip().split('\n') + ll_entities.strip().split('\n'))) combined_entities = '\n'.join(combined_entities_set) - + # Combine and deduplicate the relationships combined_relationships_set = set(filter(None, hl_relationships.strip().split('\n') + ll_relationships.strip().split('\n'))) combined_relationships = '\n'.join(combined_relationships_set) - + # Combine and deduplicate the sources combined_sources_set = set(filter(None, hl_sources.strip().split('\n') + ll_sources.strip().split('\n'))) combined_sources = '\n'.join(combined_sources_set) - + # Format the combined context return f""" -----Entities----- @@ -985,6 +985,218 @@ async def naive_query( if len(response)>len(sys_prompt): response = response[len(sys_prompt):].replace(sys_prompt,'').replace('user','').replace('model','').replace(query,'').replace('','').replace('','').strip() - + return response + +async def keyword_context_query( + query, + history_messages, + knowledge_graph_inst: BaseGraphStorage, + entities_vdb: BaseVectorStorage, + relationships_vdb: BaseVectorStorage, + text_chunks_db: BaseKVStorage[TextChunkSchema], + query_param: QueryParam, + global_config: dict, +) -> str: + use_model_func = global_config["llm_model_func"] + embedding_func = global_config["embedding_func"] + + # プロンプトからキーワードを抽出 (これはlocal_queryと同じ) + kw_prompt_temp = PROMPTS["keywords_extraction"] + kw_prompt = kw_prompt_temp.format(query=query) + result = await use_model_func(kw_prompt) + try: + keywords_data = json.loads(result) + keywords = keywords_data.get("low_level_keywords", []) + if not keywords: + return PROMPTS["fail_response"] + except json.JSONDecodeError: + return PROMPTS["fail_response"] + + # 各キーワードに対して処理を行う + all_entity_results = [] + all_relationship_results = [] + for keyword in keywords: + # キーワードに関連するコンテキストを会話履歴から生成 + # 例) history: 「歴史面白いよな」「戦国時代の武将誰が好き?」, prompt「信長かな〜」 + # -> 「信長は戦国時代の武将です」 (情報を付与せずhistory + promptからコンテキストを生成) + keyword_context = await generate_keyword_context_from_history(keyword, history_messages, use_model_func) + if not keyword_context: + keyword_context = keyword # コンテキストが生成されない場合はキーワード自体を使用 + + # コンテキストを埋め込みに変換 + context_embedding = await embedding_func([keyword_context]) + context_embedding = context_embedding[0] + + # エンティティVDBを検索 + entity_results = await entities_vdb.query_by_embedding(context_embedding, top_k=query_param.top_k) + all_entity_results.extend(entity_results) + + # リレーションシップVDBを検索 + relationship_results = await relationships_vdb.query_by_embedding(context_embedding, top_k=query_param.top_k) + all_relationship_results.extend(relationship_results) + + unique_entity_results = {res["id"]: res for res in all_entity_results}.values() + unique_relationship_results = { (res["metadata"]["src_id"], res["metadata"]["tgt_id"]): res for res in all_relationship_results }.values() + + # 回答用のコンテキストを作成 + use_entities = await _find_most_related_entities_from_results( + unique_entity_results, knowledge_graph_inst, query_param + ) + use_relationships = await _find_most_related_relationships_from_results( + unique_relationship_results, knowledge_graph_inst, query_param + ) + use_text_units = await _find_related_text_units( + use_entities, use_relationships, text_chunks_db, query_param + ) + + context_data = await _build_keyword_context( + use_entities, use_relationships, use_text_units, query_param + ) + + # コンテキストを使用して最終回答を生成 + if query_param.only_need_context: + return context_data + + sys_prompt_temp = PROMPTS["rag_response"] + sys_prompt = sys_prompt_temp.format( + context_data=context_data, response_type=query_param.response_type + ) + response = await use_model_func( + query, + system_prompt=sys_prompt, + ) + return response.strip() + +async def generate_keyword_context_from_history(keyword, history_messages, use_model_func): + context_prompt_temp = PROMPTS["keyword_context_from_history"] + history_text = "\n".join([f"{msg['role']}: {msg['content']}" for msg in history_messages]) + context_prompt = context_prompt_temp.format( + keyword=keyword, + history=history_text + ) + context = await use_model_func(context_prompt) + return context.strip() + +async def _find_most_related_entities_from_results(results, knowledge_graph_inst, query_param): + entity_ids = [r["metadata"].get("entity_name") for r in results] + entities = await asyncio.gather( + *[knowledge_graph_inst.get_node(entity_id) for entity_id in entity_ids] + ) + node_degrees = await asyncio.gather( + *[knowledge_graph_inst.node_degree(entity_id) for entity_id in entity_ids] + ) + entities_data = [ + { + "entity_name": entity_id, + "entity_type": entity.get("entity_type", "UNKNOWN"), + "description": entity.get("description", ""), + "rank": degree, + "source_id": entity.get("source_id", ""), + } + for entity_id, entity, degree in zip(entity_ids, entities, node_degrees) + if entity is not None + ] + entities_data = truncate_list_by_token_size( + entities_data, + key=lambda x: x["description"], + max_token_size=query_param.max_token_for_local_context, + ) + return entities_data + +async def _find_most_related_relationships_from_results(results, knowledge_graph_inst, query_param): + relationship_ids = [(r["metadata"]["src_id"], r["metadata"]["tgt_id"]) for r in results] + relationships = await asyncio.gather( + *[knowledge_graph_inst.get_edge(src_id, tgt_id) for src_id, tgt_id in relationship_ids] + ) + edge_degrees = await asyncio.gather( + *[knowledge_graph_inst.edge_degree(src_id, tgt_id) for src_id, tgt_id in relationship_ids] + ) + relationships_data = [ + { + "src_id": src_id, + "tgt_id": tgt_id, + "description": edge.get("description", ""), + "keywords": edge.get("keywords", ""), + "weight": edge.get("weight", 1), + "rank": degree, + "source_id": edge.get("source_id", ""), + } + for (src_id, tgt_id), edge, degree in zip(relationship_ids, relationships, edge_degrees) + if edge is not None + ] + relationships_data = truncate_list_by_token_size( + relationships_data, + key=lambda x: x["description"], + max_token_size=query_param.max_token_for_local_context, + ) + return relationships_data + +async def _find_related_text_units(use_entities, use_relationships, text_chunks_db, query_param): + text_unit_ids = set() + for entity in use_entities: + source_ids = entity.get("source_id", "") + text_unit_ids.update(source_ids.split(GRAPH_FIELD_SEP)) + for relationship in use_relationships: + source_ids = relationship.get("source_id", "") + text_unit_ids.update(source_ids.split(GRAPH_FIELD_SEP)) + text_unit_ids = list(text_unit_ids) + text_units = await text_chunks_db.get_by_ids(text_unit_ids) + text_units = [t for t in text_units if t is not None] + text_units = truncate_list_by_token_size( + text_units, + key=lambda x: x["content"], + max_token_size=query_param.max_token_for_text_unit, + ) + return text_units + +async def _build_keyword_context(use_entities, use_relationships, use_text_units, query_param): + entities_section_list = [["id", "entity", "type", "description", "rank"]] + for i, n in enumerate(use_entities): + entities_section_list.append( + [ + i, + n["entity_name"], + n.get("entity_type", "UNKNOWN"), + n.get("description", "UNKNOWN"), + n["rank"], + ] + ) + entities_context = list_of_list_to_csv(entities_section_list) + relations_section_list = [ + ["id", "src_entity", "tgt_entity", "description", "keywords", "weight", "rank"] + ] + for i, e in enumerate(use_relationships): + relations_section_list.append( + [ + i, + e["src_id"], + e["tgt_id"], + e["description"], + e["keywords"], + e["weight"], + e["rank"], + ] + ) + relations_context = list_of_list_to_csv(relations_section_list) + + text_units_section_list = [["id", "content"]] + for i, t in enumerate(use_text_units): + text_units_section_list.append([i, t["content"]]) + text_units_context = list_of_list_to_csv(text_units_section_list) + + return f""" +-----Entities----- +```csv +{entities_context} +``` +-----Relationships----- +```csv +{relations_context} +``` +-----Sources----- +```csv +{text_units_context} +``` +""".strip() \ No newline at end of file diff --git a/lightrag/prompt.py b/lightrag/prompt.py index 5d28e49c..4e004be6 100644 --- a/lightrag/prompt.py +++ b/lightrag/prompt.py @@ -32,7 +32,7 @@ Format each relationship as ("relationship"{tuple_delimiter}{tupl 3. Identify high-level key words that summarize the main concepts, themes, or topics of the entire text. These should capture the overarching ideas present in the document. Format the content-level key words as ("content_keywords"{tuple_delimiter}) - + 4. Return output in English as a single list of all the entities and relationships identified in steps 1 and 2. Use **{record_delimiter}** as the list delimiter. 5. When finished, output {completion_delimiter} @@ -254,3 +254,41 @@ Do not include information where the supporting evidence for it is not provided. ---Target response length and format--- {response_type} """ +PROMPTS["keywords_extraction_with_history"] = """---Role--- + +You are a helpful assistant tasked with identifying both high-level and low-level keywords from the conversation history. + +---Goal--- + +Given the conversation history, list both high-level and low-level keywords that are relevant to the user's query. + +---Instructions--- + +- Output the keywords in JSON format. +- The JSON should have two keys: + - "high_level_keywords" for overarching concepts or themes. + - "low_level_keywords" for specific entities or details. + +---Conversation History--- +{history} +""" + +PROMPTS["context_generation_from_history"] = """---Role--- + +You are a helpful assistant tasked with generating context for the user's query based on the conversation history and the provided keywords. + +---Goal--- + +Using the conversation history and the keywords, summarize the relevant information that the assistant knows so far about the keywords. + +---Instructions--- + +- Provide a concise summary that connects the keywords with the context from the conversation history. +- Do not include irrelevant information. + +---Keywords--- +{keywords} + +---Conversation History--- +{history} +""" \ No newline at end of file