diff --git a/api/db/services/dialog_service.py b/api/db/services/dialog_service.py index 43e345cd2..cd6a9a4ba 100644 --- a/api/db/services/dialog_service.py +++ b/api/db/services/dialog_service.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # +import asyncio import binascii import logging import re @@ -21,7 +22,6 @@ from copy import deepcopy from datetime import datetime from functools import partial from timeit import default_timer as timer -import trio from langfuse import Langfuse from peewee import fn from agentic_reasoning import DeepResearcher @@ -931,5 +931,5 @@ def gen_mindmap(question, kb_ids, tenant_id, search_config={}): rank_feature=label_question(question, kbs), ) mindmap = MindMapExtractor(chat_mdl) - mind_map = trio.run(mindmap, [c["content_with_weight"] for c in ranks["chunks"]]) + mind_map = asyncio.run(mindmap([c["content_with_weight"] for c in ranks["chunks"]])) return mind_map.output diff --git a/api/db/services/document_service.py b/api/db/services/document_service.py index 395dcad83..43adf5d8e 100644 --- a/api/db/services/document_service.py +++ b/api/db/services/document_service.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# +import asyncio import json import logging import random @@ -22,7 +23,6 @@ from copy import deepcopy from datetime import datetime from io import BytesIO -import trio import xxhash from peewee import fn, Case, JOIN @@ -999,7 +999,7 @@ def doc_upload_and_parse(conversation_id, file_objs, user_id): from graphrag.general.mind_map_extractor import MindMapExtractor mindmap = MindMapExtractor(llm_bdl) try: - mind_map = trio.run(mindmap, [c["content_with_weight"] for c in docs if c["doc_id"] == doc_id]) + mind_map = asyncio.run(mindmap([c["content_with_weight"] for c in docs if c["doc_id"] == doc_id])) mind_map = json.dumps(mind_map.output, ensure_ascii=False, indent=2) if len(mind_map) < 32: raise Exception("Few content: " + mind_map) diff --git a/api/utils/api_utils.py b/api/utils/api_utils.py index 8f17e1de0..53cb1ce02 100644 --- a/api/utils/api_utils.py +++ b/api/utils/api_utils.py @@ -14,6 +14,7 @@ # limitations under the License. # +import asyncio import functools import inspect import json @@ -25,7 +26,6 @@ from functools import wraps from typing import Any import requests -import trio from quart import ( Response, jsonify, @@ -681,18 +681,36 @@ async def is_strong_enough(chat_model, embedding_model): async def _is_strong_enough(): nonlocal chat_model, embedding_model if embedding_model: - with trio.fail_after(10): - _ = await trio.to_thread.run_sync(lambda: embedding_model.encode(["Are you strong enough!?"])) + await asyncio.wait_for( + asyncio.to_thread(embedding_model.encode, ["Are you strong enough!?"]), + timeout=10 + ) + if chat_model: - with trio.fail_after(30): - res = await trio.to_thread.run_sync(lambda: chat_model.chat("Nothing special.", [{"role": "user", "content": "Are you strong enough!?"}], {})) - if res.find("**ERROR**") >= 0: + res = await asyncio.wait_for( + asyncio.to_thread( + chat_model.chat, + "Nothing special.", + [{"role": "user", "content": "Are you strong enough!?"}], + {} + ), + timeout=30 + ) + if "**ERROR**" in res: raise 
Exception(res) # Pressure test for GraphRAG task - async with trio.open_nursery() as nursery: - for _ in range(count): - nursery.start_soon(_is_strong_enough) + tasks = [ + asyncio.create_task(_is_strong_enough()) + for _ in range(count) + ] + try: + await asyncio.gather(*tasks, return_exceptions=False) + except Exception as e: + for t in tasks: + t.cancel() + await asyncio.gather(*tasks, return_exceptions=True) + raise def get_allowed_llm_factories() -> list: diff --git a/common/connection_utils.py b/common/connection_utils.py index 5b8154f0c..86ebc371d 100644 --- a/common/connection_utils.py +++ b/common/connection_utils.py @@ -19,7 +19,6 @@ import queue import threading from typing import Any, Callable, Coroutine, Optional, Type, Union import asyncio -import trio from functools import wraps from quart import make_response, jsonify from common.constants import RetCode @@ -70,11 +69,10 @@ def timeout(seconds: float | int | str = None, attempts: int = 2, *, exception: for a in range(attempts): try: if os.environ.get("ENABLE_TIMEOUT_ASSERTION"): - with trio.fail_after(seconds): - return await func(*args, **kwargs) + return await asyncio.wait_for(func(*args, **kwargs), timeout=seconds) else: return await func(*args, **kwargs) - except trio.TooSlowError: + except asyncio.TimeoutError: if a < attempts - 1: continue if on_timeout is not None: diff --git a/graphrag/entity_resolution.py b/graphrag/entity_resolution.py index 7ffc52538..9e99fe941 100644 --- a/graphrag/entity_resolution.py +++ b/graphrag/entity_resolution.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# +import asyncio import logging import itertools import os @@ -21,7 +22,6 @@ from dataclasses import dataclass from typing import Any, Callable import networkx as nx -import trio from graphrag.general.extractor import Extractor from rag.nlp import is_english @@ -101,35 +101,55 @@ class EntityResolution(Extractor): remain_candidates_to_resolve = num_candidates resolution_result = set() - resolution_result_lock = trio.Lock() + resolution_result_lock = asyncio.Lock() resolution_batch_size = 100 max_concurrent_tasks = 5 - semaphore = trio.Semaphore(max_concurrent_tasks) + semaphore = asyncio.Semaphore(max_concurrent_tasks) async def limited_resolve_candidate(candidate_batch, result_set, result_lock): nonlocal remain_candidates_to_resolve, callback async with semaphore: try: enable_timeout_assertion = os.environ.get("ENABLE_TIMEOUT_ASSERTION") - with trio.move_on_after(280 if enable_timeout_assertion else 1000000000) as cancel_scope: - await self._resolve_candidate(candidate_batch, result_set, result_lock, task_id) - remain_candidates_to_resolve = remain_candidates_to_resolve - len(candidate_batch[1]) - callback(msg=f"Resolved {len(candidate_batch[1])} pairs, {remain_candidates_to_resolve} are remained to resolve. ") - if cancel_scope.cancelled_caught: + timeout_sec = 280 if enable_timeout_assertion else 1_000_000_000 + + try: + await asyncio.wait_for( + self._resolve_candidate(candidate_batch, result_set, result_lock, task_id), + timeout=timeout_sec + ) + remain_candidates_to_resolve -= len(candidate_batch[1]) + callback( + msg=f"Resolved {len(candidate_batch[1])} pairs, " + f"{remain_candidates_to_resolve} remain." + ) + + except asyncio.TimeoutError: logging.warning(f"Timeout resolving {candidate_batch}, skipping...") - remain_candidates_to_resolve = remain_candidates_to_resolve - len(candidate_batch[1]) - callback(msg=f"Fail to resolved {len(candidate_batch[1])} pairs due to timeout reason, skipped. {remain_candidates_to_resolve} are remained to resolve. 
") + remain_candidates_to_resolve -= len(candidate_batch[1]) + callback( + msg=f"Failed to resolve {len(candidate_batch[1])} pairs due to timeout, skipped. " + f"{remain_candidates_to_resolve} remain." + ) + except Exception as e: logging.error(f"Error resolving candidate batch: {e}") - async with trio.open_nursery() as nursery: - for candidate_resolution_i in candidate_resolution.items(): - if not candidate_resolution_i[1]: - continue - for i in range(0, len(candidate_resolution_i[1]), resolution_batch_size): - candidate_batch = candidate_resolution_i[0], candidate_resolution_i[1][i:i + resolution_batch_size] - nursery.start_soon(limited_resolve_candidate, candidate_batch, resolution_result, resolution_result_lock) + tasks = [] + for key, lst in candidate_resolution.items(): + if not lst: + continue + for i in range(0, len(lst), resolution_batch_size): + batch = (key, lst[i:i + resolution_batch_size]) + tasks.append(limited_resolve_candidate(batch, resolution_result, resolution_result_lock)) + try: + await asyncio.gather(*tasks, return_exceptions=False) + except Exception as e: + for t in tasks: + t.cancel() + await asyncio.gather(*tasks, return_exceptions=True) + raise callback(msg=f"Resolved {num_candidates} candidate pairs, {len(resolution_result)} of them are selected to merge.") @@ -141,10 +161,18 @@ class EntityResolution(Extractor): async with semaphore: await self._merge_graph_nodes(graph, nodes, change, task_id) - async with trio.open_nursery() as nursery: - for sub_connect_graph in nx.connected_components(connect_graph): - merging_nodes = list(sub_connect_graph) - nursery.start_soon(limited_merge_nodes, graph, merging_nodes, change) + tasks = [] + for sub_connect_graph in nx.connected_components(connect_graph): + merging_nodes = list(sub_connect_graph) + tasks.append(asyncio.create_task(limited_merge_nodes(graph, merging_nodes, change)) + ) + try: + await asyncio.gather(*tasks, return_exceptions=False) + except Exception as e: + for t in tasks: + 
t.cancel() + await asyncio.gather(*tasks, return_exceptions=True) + raise # Update pagerank pr = nx.pagerank(graph) @@ -156,7 +184,7 @@ class EntityResolution(Extractor): change=change, ) - async def _resolve_candidate(self, candidate_resolution_i: tuple[str, list[tuple[str, str]]], resolution_result: set[str], resolution_result_lock: trio.Lock, task_id: str = ""): + async def _resolve_candidate(self, candidate_resolution_i: tuple[str, list[tuple[str, str]]], resolution_result: set[str], resolution_result_lock: asyncio.Lock, task_id: str = ""): if task_id: if has_canceled(task_id): logging.info(f"Task {task_id} cancelled during entity resolution candidate processing.") @@ -178,13 +206,22 @@ class EntityResolution(Extractor): text = perform_variable_replacements(self._resolution_prompt, variables=variables) logging.info(f"Created resolution prompt {len(text)} bytes for {len(candidate_resolution_i[1])} entity pairs of type {candidate_resolution_i[0]}") async with chat_limiter: + timeout_seconds = 280 if os.environ.get("ENABLE_TIMEOUT_ASSERTION") else 1000000000 try: - enable_timeout_assertion = os.environ.get("ENABLE_TIMEOUT_ASSERTION") - with trio.move_on_after(280 if enable_timeout_assertion else 1000000000) as cancel_scope: - response = await trio.to_thread.run_sync(self._chat, text, [{"role": "user", "content": "Output:"}], {}, task_id) - if cancel_scope.cancelled_caught: - logging.warning("_resolve_candidate._chat timeout, skipping...") - return + response = await asyncio.wait_for( + asyncio.to_thread( + self._chat, + text, + [{"role": "user", "content": "Output:"}], + {}, + task_id + ), + timeout=timeout_seconds, + ) + + except asyncio.TimeoutError: + logging.warning("_resolve_candidate._chat timeout, skipping...") + return except Exception as e: logging.error(f"_resolve_candidate._chat failed: {e}") return diff --git a/graphrag/general/community_reports_extractor.py b/graphrag/general/community_reports_extractor.py index 09634fb4d..734a5401d 100644 --- 
a/graphrag/general/community_reports_extractor.py +++ b/graphrag/general/community_reports_extractor.py @@ -5,6 +5,7 @@ Reference: - [graphrag](https://github.com/microsoft/graphrag) """ +import asyncio import logging import json import os @@ -24,7 +25,6 @@ from graphrag.general.leiden import add_community_info2graph from rag.llm.chat_model import Base as CompletionLLM from graphrag.utils import perform_variable_replacements, dict_has_keys_with_types, chat_limiter from common.token_utils import num_tokens_from_string -import trio @dataclass @@ -101,14 +101,11 @@ class CommunityReportsExtractor(Extractor): text = perform_variable_replacements(self._extraction_prompt, variables=prompt_variables) async with chat_limiter: try: - with trio.move_on_after(180 if enable_timeout_assertion else 1000000000) as cancel_scope: - if task_id and has_canceled(task_id): - logging.info(f"Task {task_id} cancelled before LLM call.") - raise TaskCanceledException(f"Task {task_id} was cancelled") - response = await trio.to_thread.run_sync( self._chat, text, [{"role": "user", "content": "Output:"}], {}, task_id) - if cancel_scope.cancelled_caught: - logging.warning("extract_community_report._chat timeout, skipping...") - return + timeout = 180 if enable_timeout_assertion else 1000000000 + response = await asyncio.wait_for(asyncio.to_thread(self._chat,text,[{"role": "user", "content": "Output:"}],{},task_id),timeout=timeout) + except asyncio.TimeoutError: + logging.warning("extract_community_report._chat timeout, skipping...") + return except Exception as e: logging.error(f"extract_community_report._chat failed: {e}") return @@ -141,17 +138,24 @@ class CommunityReportsExtractor(Extractor): if callback: callback(msg=f"Communities: {over}/{total}, used tokens: {token_count}") - st = trio.current_time() - async with trio.open_nursery() as nursery: - for level, comm in communities.items(): - logging.info(f"Level {level}: Community: {len(comm.keys())}") - for community in comm.items(): - if 
task_id and has_canceled(task_id): - logging.info(f"Task {task_id} cancelled before community processing.") - raise TaskCanceledException(f"Task {task_id} was cancelled") - nursery.start_soon(extract_community_report, community) + st = asyncio.get_running_loop().time() + tasks = [] + for level, comm in communities.items(): + logging.info(f"Level {level}: Community: {len(comm.keys())}") + for community in comm.items(): + if task_id and has_canceled(task_id): + logging.info(f"Task {task_id} cancelled before community processing.") + raise TaskCanceledException(f"Task {task_id} was cancelled") + tasks.append(asyncio.create_task(extract_community_report(community))) + try: + await asyncio.gather(*tasks, return_exceptions=False) + except Exception as e: + for t in tasks: + t.cancel() + await asyncio.gather(*tasks, return_exceptions=True) + raise if callback: - callback(msg=f"Community reports done in {trio.current_time() - st:.2f}s, used tokens: {token_count}") + callback(msg=f"Community reports done in {asyncio.get_running_loop().time() - st:.2f}s, used tokens: {token_count}") return CommunityReportsResult( structured_output=res_dict, diff --git a/graphrag/general/extractor.py b/graphrag/general/extractor.py index 495e562ed..8549985e6 100644 --- a/graphrag/general/extractor.py +++ b/graphrag/general/extractor.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# +import asyncio import logging import os import re @@ -21,7 +22,6 @@ from copy import deepcopy from typing import Callable import networkx as nx -import trio from api.db.services.task_service import has_canceled from common.connection_utils import timeout @@ -109,14 +109,14 @@ class Extractor: async def __call__(self, doc_id: str, chunks: list[str], callback: Callable | None = None, task_id: str = ""): self.callback = callback - start_ts = trio.current_time() + start_ts = asyncio.get_running_loop().time() async def extract_all(doc_id, chunks, max_concurrency=MAX_CONCURRENT_PROCESS_AND_EXTRACT_CHUNK, task_id=""): out_results = [] error_count = 0 max_errors = int(os.environ.get("GRAPHRAG_MAX_ERRORS", 3)) - limiter = trio.Semaphore(max_concurrency) + limiter = asyncio.Semaphore(max_concurrency) async def worker(chunk_key_dp: tuple[str, str], idx: int, total: int, task_id=""): nonlocal error_count @@ -137,9 +137,18 @@ class Extractor: if error_count > max_errors: raise Exception(f"Maximum error count ({max_errors}) reached. 
Last errors: {str(e)}") - async with trio.open_nursery() as nursery: - for i, ck in enumerate(chunks): - nursery.start_soon(worker, (doc_id, ck), i, len(chunks), task_id) + tasks = [ + asyncio.create_task(worker((doc_id, ck), i, len(chunks), task_id)) + for i, ck in enumerate(chunks) + ] + + try: + await asyncio.gather(*tasks, return_exceptions=False) + except Exception as e: + for t in tasks: + t.cancel() + await asyncio.gather(*tasks, return_exceptions=True) + raise if error_count > 0: warning_msg = f"Completed with {error_count} errors (out of {len(chunks)} chunks processed)" @@ -166,7 +175,7 @@ class Extractor: for k, v in m_edges.items(): maybe_edges[tuple(sorted(k))].extend(v) sum_token_count += token_count - now = trio.current_time() + now = asyncio.get_running_loop().time() if self.callback: self.callback(msg=f"Entities and relationships extraction done, {len(maybe_nodes)} nodes, {len(maybe_edges)} edges, {sum_token_count} tokens, {now - start_ts:.2f}s.") start_ts = now @@ -176,14 +185,22 @@ class Extractor: if task_id and has_canceled(task_id): raise TaskCanceledException(f"Task {task_id} was cancelled before nodes merging") - async with trio.open_nursery() as nursery: - for en_nm, ents in maybe_nodes.items(): - nursery.start_soon(self._merge_nodes, en_nm, ents, all_entities_data, task_id) + tasks = [ + asyncio.create_task(self._merge_nodes(en_nm, ents, all_entities_data, task_id)) + for en_nm, ents in maybe_nodes.items() + ] + try: + await asyncio.gather(*tasks, return_exceptions=False) + except Exception as e: + for t in tasks: + t.cancel() + await asyncio.gather(*tasks, return_exceptions=True) + raise if task_id and has_canceled(task_id): raise TaskCanceledException(f"Task {task_id} was cancelled after nodes merging") - now = trio.current_time() + now = asyncio.get_running_loop().time() if self.callback: self.callback(msg=f"Entities merging done, {now - start_ts:.2f}s.") @@ -194,14 +211,25 @@ class Extractor: if task_id and has_canceled(task_id): raise 
TaskCanceledException(f"Task {task_id} was cancelled before relationships merging") - async with trio.open_nursery() as nursery: - for (src, tgt), rels in maybe_edges.items(): - nursery.start_soon(self._merge_edges, src, tgt, rels, all_relationships_data, task_id) + tasks = [] + for (src, tgt), rels in maybe_edges.items(): + tasks.append( + asyncio.create_task( + self._merge_edges(src, tgt, rels, all_relationships_data, task_id) + ) + ) + try: + await asyncio.gather(*tasks, return_exceptions=False) + except Exception as e: + for t in tasks: + t.cancel() + await asyncio.gather(*tasks, return_exceptions=True) + raise if task_id and has_canceled(task_id): raise TaskCanceledException(f"Task {task_id} was cancelled after relationships merging") - now = trio.current_time() + now = asyncio.get_running_loop().time() if self.callback: self.callback(msg=f"Relationships merging done, {now - start_ts:.2f}s.") @@ -309,5 +337,5 @@ class Extractor: raise TaskCanceledException(f"Task {task_id} was cancelled during summary handling") async with chat_limiter: - summary = await trio.to_thread.run_sync(self._chat, "", [{"role": "user", "content": use_prompt}], {}, task_id) + summary = await asyncio.to_thread(self._chat, "", [{"role": "user", "content": use_prompt}], {}, task_id) return summary diff --git a/graphrag/general/graph_extractor.py b/graphrag/general/graph_extractor.py index d156fcb2e..f2bc7949f 100644 --- a/graphrag/general/graph_extractor.py +++ b/graphrag/general/graph_extractor.py @@ -5,11 +5,11 @@ Reference: - [graphrag](https://github.com/microsoft/graphrag) """ +import asyncio import re from typing import Any from dataclasses import dataclass import tiktoken -import trio from graphrag.general.extractor import Extractor, ENTITY_EXTRACTION_MAX_GLEANINGS from graphrag.general.graph_prompt import GRAPH_EXTRACTION_PROMPT, CONTINUE_PROMPT, LOOP_PROMPT @@ -107,7 +107,7 @@ class GraphExtractor(Extractor): } hint_prompt = perform_variable_replacements(self._extraction_prompt, 
variables=variables) async with chat_limiter: - response = await trio.to_thread.run_sync(self._chat, hint_prompt, [{"role": "user", "content": "Output:"}], {}, task_id) + response = await asyncio.to_thread(self._chat,hint_prompt,[{"role": "user", "content": "Output:"}],{},task_id) token_count += num_tokens_from_string(hint_prompt + response) results = response or "" @@ -117,7 +117,7 @@ class GraphExtractor(Extractor): for i in range(self._max_gleanings): history.append({"role": "user", "content": CONTINUE_PROMPT}) async with chat_limiter: - response = await trio.to_thread.run_sync(lambda: self._chat("", history, {})) + response = await asyncio.to_thread(self._chat, "", history, {}) token_count += num_tokens_from_string("\n".join([m["content"] for m in history]) + response) results += response or "" @@ -127,7 +127,7 @@ class GraphExtractor(Extractor): history.append({"role": "assistant", "content": response}) history.append({"role": "user", "content": LOOP_PROMPT}) async with chat_limiter: - continuation = await trio.to_thread.run_sync(lambda: self._chat("", history)) + continuation = await asyncio.to_thread(self._chat, "", history) token_count += num_tokens_from_string("\n".join([m["content"] for m in history]) + response) if continuation != "Y": break diff --git a/graphrag/general/index.py b/graphrag/general/index.py index f307e5d91..813027e38 100644 --- a/graphrag/general/index.py +++ b/graphrag/general/index.py @@ -13,12 +13,12 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# +import asyncio import json import logging import os import networkx as nx -import trio from api.db.services.document_service import DocumentService from api.db.services.task_service import has_canceled @@ -54,25 +54,35 @@ async def run_graphrag( callback, ): enable_timeout_assertion = os.environ.get("ENABLE_TIMEOUT_ASSERTION") - start = trio.current_time() + start = asyncio.get_running_loop().time() tenant_id, kb_id, doc_id = row["tenant_id"], str(row["kb_id"]), row["doc_id"] chunks = [] for d in settings.retriever.chunk_list(doc_id, tenant_id, [kb_id], max_count=10000, fields=["content_with_weight", "doc_id"], sort_by_position=True): chunks.append(d["content_with_weight"]) - with trio.fail_after(max(120, len(chunks) * 60 * 10) if enable_timeout_assertion else 10000000000): - subgraph = await generate_subgraph( - LightKGExt if "method" not in row["kb_parser_config"].get("graphrag", {}) or row["kb_parser_config"]["graphrag"]["method"] != "general" else GeneralKGExt, - tenant_id, - kb_id, - doc_id, - chunks, - language, - row["kb_parser_config"]["graphrag"].get("entity_types", []), - chat_model, - embedding_model, - callback, + timeout_sec = max(120, len(chunks) * 60 * 10) if enable_timeout_assertion else 10000000000 + + try: + subgraph = await asyncio.wait_for( + generate_subgraph( + LightKGExt if "method" not in row["kb_parser_config"].get("graphrag", {}) + or row["kb_parser_config"]["graphrag"]["method"] != "general" + else GeneralKGExt, + tenant_id, + kb_id, + doc_id, + chunks, + language, + row["kb_parser_config"]["graphrag"].get("entity_types", []), + chat_model, + embedding_model, + callback, + ), + timeout=timeout_sec, ) + except asyncio.TimeoutError: + logging.error("generate_subgraph timeout") + raise if not subgraph: return @@ -125,7 +135,7 @@ async def run_graphrag( ) finally: graphrag_task_lock.release() - now = trio.current_time() + now = asyncio.get_running_loop().time() callback(msg=f"GraphRAG for doc {doc_id} done in {now - start:.2f} seconds.") 
return @@ -145,7 +155,7 @@ async def run_graphrag_for_kb( ) -> dict: tenant_id, kb_id = row["tenant_id"], row["kb_id"] enable_timeout_assertion = os.environ.get("ENABLE_TIMEOUT_ASSERTION") - start = trio.current_time() + start = asyncio.get_running_loop().time() fields_for_chunks = ["content_with_weight", "doc_id"] if not doc_ids: @@ -211,7 +221,7 @@ async def run_graphrag_for_kb( callback(msg=f"[GraphRAG] kb:{kb_id} has no available chunks in all documents, skip.") return {"ok_docs": [], "failed_docs": doc_ids, "total_docs": len(doc_ids), "total_chunks": 0, "seconds": 0.0} - semaphore = trio.Semaphore(max_parallel_docs) + semaphore = asyncio.Semaphore(max_parallel_docs) subgraphs: dict[str, object] = {} failed_docs: list[tuple[str, str]] = [] # (doc_id, error) @@ -234,20 +244,28 @@ async def run_graphrag_for_kb( try: msg = f"[GraphRAG] build_subgraph doc:{doc_id}" callback(msg=f"{msg} start (chunks={len(chunks)}, timeout={deadline}s)") - with trio.fail_after(deadline): - sg = await generate_subgraph( - kg_extractor, - tenant_id, - kb_id, - doc_id, - chunks, - language, - kb_parser_config.get("graphrag", {}).get("entity_types", []), - chat_model, - embedding_model, - callback, - task_id=row["id"] + + try: + sg = await asyncio.wait_for( + generate_subgraph( + kg_extractor, + tenant_id, + kb_id, + doc_id, + chunks, + language, + kb_parser_config.get("graphrag", {}).get("entity_types", []), + chat_model, + embedding_model, + callback, + task_id=row["id"] + ), + timeout=deadline, ) + except asyncio.TimeoutError: + failed_docs.append((doc_id, "timeout")) + callback(msg=f"{msg} FAILED: timeout") + return if sg: subgraphs[doc_id] = sg callback(msg=f"{msg} done") @@ -264,9 +282,14 @@ async def run_graphrag_for_kb( callback(msg=f"Task {row['id']} cancelled before processing documents.") raise TaskCanceledException(f"Task {row['id']} was cancelled") - async with trio.open_nursery() as nursery: - for doc_id in doc_ids: - nursery.start_soon(build_one, doc_id) + tasks = 
[asyncio.create_task(build_one(doc_id)) for doc_id in doc_ids] + try: + await asyncio.gather(*tasks, return_exceptions=False) + except Exception as e: + for t in tasks: + t.cancel() + await asyncio.gather(*tasks, return_exceptions=True) + raise if has_canceled(row["id"]): callback(msg=f"Task {row['id']} cancelled after document processing.") @@ -275,7 +298,7 @@ async def run_graphrag_for_kb( ok_docs = [d for d in doc_ids if d in subgraphs] if not ok_docs: callback(msg=f"[GraphRAG] kb:{kb_id} no subgraphs generated successfully, end.") - now = trio.current_time() + now = asyncio.get_running_loop().time() return {"ok_docs": [], "failed_docs": failed_docs, "total_docs": len(doc_ids), "total_chunks": total_chunks, "seconds": now - start} kb_lock = RedisDistributedLock(f"graphrag_task_{kb_id}", lock_value="batch_merge", timeout=1200) @@ -313,7 +336,7 @@ async def run_graphrag_for_kb( kb_lock.release() if not with_resolution and not with_community: - now = trio.current_time() + now = asyncio.get_running_loop().time() callback(msg=f"[GraphRAG] KB merge done in {now - start:.2f}s. ok={len(ok_docs)} / total={len(doc_ids)}") return {"ok_docs": ok_docs, "failed_docs": failed_docs, "total_docs": len(doc_ids), "total_chunks": total_chunks, "seconds": now - start} @@ -356,7 +379,7 @@ async def run_graphrag_for_kb( finally: kb_lock.release() - now = trio.current_time() + now = asyncio.get_running_loop().time() callback(msg=f"[GraphRAG] GraphRAG for KB {kb_id} done in {now - start:.2f} seconds. 
ok={len(ok_docs)} failed={len(failed_docs)} total_docs={len(doc_ids)} total_chunks={total_chunks}") return { "ok_docs": ok_docs, @@ -388,7 +411,7 @@ async def generate_subgraph( if contains: callback(msg=f"Graph already contains {doc_id}") return None - start = trio.current_time() + start = asyncio.get_running_loop().time() ext = extractor( llm_bdl, language=language, @@ -436,9 +459,9 @@ async def generate_subgraph( "removed_kwd": "N", } cid = chunk_id(chunk) - await trio.to_thread.run_sync(settings.docStoreConn.delete, {"knowledge_graph_kwd": "subgraph", "source_id": doc_id}, search.index_name(tenant_id), kb_id) - await trio.to_thread.run_sync(settings.docStoreConn.insert, [{"id": cid, **chunk}], search.index_name(tenant_id), kb_id) - now = trio.current_time() + await asyncio.to_thread(settings.docStoreConn.delete,{"knowledge_graph_kwd": "subgraph", "source_id": doc_id},search.index_name(tenant_id),kb_id,) + await asyncio.to_thread(settings.docStoreConn.insert,[{"id": cid, **chunk}],search.index_name(tenant_id),kb_id,) + now = asyncio.get_running_loop().time() callback(msg=f"generated subgraph for doc {doc_id} in {now - start:.2f} seconds.") return subgraph @@ -452,7 +475,7 @@ async def merge_subgraph( embedding_model, callback, ): - start = trio.current_time() + start = asyncio.get_running_loop().time() change = GraphChange() old_graph = await get_graph(tenant_id, kb_id, subgraph.graph["source_id"]) if old_graph is not None: @@ -468,7 +491,7 @@ async def merge_subgraph( new_graph.nodes[node_name]["pagerank"] = pagerank await set_graph(tenant_id, kb_id, embedding_model, new_graph, change, callback) - now = trio.current_time() + now = asyncio.get_running_loop().time() callback(msg=f"merging subgraph for doc {doc_id} into the global graph done in {now - start:.2f} seconds.") return new_graph @@ -490,7 +513,7 @@ async def resolve_entities( callback(msg=f"Task {task_id} cancelled during entity resolution.") raise TaskCanceledException(f"Task {task_id} was cancelled") 
- start = trio.current_time() + start = asyncio.get_running_loop().time() er = EntityResolution( llm_bdl, ) @@ -505,7 +528,7 @@ async def resolve_entities( raise TaskCanceledException(f"Task {task_id} was cancelled") await set_graph(tenant_id, kb_id, embed_bdl, graph, change, callback) - now = trio.current_time() + now = asyncio.get_running_loop().time() callback(msg=f"Graph resolution done in {now - start:.2f}s.") @@ -524,7 +547,7 @@ async def extract_community( callback(msg=f"Task {task_id} cancelled before community extraction.") raise TaskCanceledException(f"Task {task_id} was cancelled") - start = trio.current_time() + start = asyncio.get_running_loop().time() ext = CommunityReportsExtractor( llm_bdl, ) @@ -538,7 +561,7 @@ async def extract_community( community_reports = cr.output doc_ids = graph.graph["source_id"] - now = trio.current_time() + now = asyncio.get_running_loop().time() callback(msg=f"Graph extracted {len(cr.structured_output)} communities in {now - start:.2f}s.") start = now if task_id and has_canceled(task_id): @@ -568,16 +591,10 @@ async def extract_community( chunk["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(chunk["content_ltks"]) chunks.append(chunk) - await trio.to_thread.run_sync( - lambda: settings.docStoreConn.delete( - {"knowledge_graph_kwd": "community_report", "kb_id": kb_id}, - search.index_name(tenant_id), - kb_id, - ) - ) + await asyncio.to_thread(settings.docStoreConn.delete,{"knowledge_graph_kwd": "community_report", "kb_id": kb_id},search.index_name(tenant_id),kb_id,) es_bulk_size = 4 for b in range(0, len(chunks), es_bulk_size): - doc_store_result = await trio.to_thread.run_sync(lambda: settings.docStoreConn.insert(chunks[b : b + es_bulk_size], search.index_name(tenant_id), kb_id)) + doc_store_result = await asyncio.to_thread(settings.docStoreConn.insert,chunks[b : b + es_bulk_size],search.index_name(tenant_id),kb_id,) if doc_store_result: error_message = f"Insert chunk error: {doc_store_result}, please check log 
file and Elasticsearch/Infinity status!" raise Exception(error_message) @@ -586,6 +603,6 @@ async def extract_community( callback(msg=f"Task {task_id} cancelled after community indexing.") raise TaskCanceledException(f"Task {task_id} was cancelled") - now = trio.current_time() + now = asyncio.get_running_loop().time() callback(msg=f"Graph indexed {len(cr.structured_output)} communities in {now - start:.2f}s.") return community_structure, community_reports diff --git a/graphrag/general/mind_map_extractor.py b/graphrag/general/mind_map_extractor.py index c85579d3d..f944aec98 100644 --- a/graphrag/general/mind_map_extractor.py +++ b/graphrag/general/mind_map_extractor.py @@ -14,12 +14,12 @@ # limitations under the License. # +import asyncio import logging import collections import re from typing import Any from dataclasses import dataclass -import trio from graphrag.general.extractor import Extractor from graphrag.general.mind_map_prompt import MIND_MAP_EXTRACTION_PROMPT @@ -89,17 +89,29 @@ class MindMapExtractor(Extractor): token_count = max(self._llm.max_length * 0.8, self._llm.max_length - 512) texts = [] cnt = 0 - async with trio.open_nursery() as nursery: - for i in range(len(sections)): - section_cnt = num_tokens_from_string(sections[i]) - if cnt + section_cnt >= token_count and texts: - nursery.start_soon(self._process_document, "".join(texts), prompt_variables, res) - texts = [] - cnt = 0 - texts.append(sections[i]) - cnt += section_cnt - if texts: - nursery.start_soon(self._process_document, "".join(texts), prompt_variables, res) + tasks = [] + for i in range(len(sections)): + section_cnt = num_tokens_from_string(sections[i]) + if cnt + section_cnt >= token_count and texts: + tasks.append(asyncio.create_task( + self._process_document("".join(texts), prompt_variables, res) + )) + texts = [] + cnt = 0 + + texts.append(sections[i]) + cnt += section_cnt + if texts: + tasks.append(asyncio.create_task( + self._process_document("".join(texts), prompt_variables, res) 
+ )) + try: + await asyncio.gather(*tasks, return_exceptions=False) + except Exception as e: + for t in tasks: + t.cancel() + await asyncio.gather(*tasks, return_exceptions=True) + raise if not res: return MindMapResult(output={"id": "root", "children": []}) merge_json = reduce(self._merge, res) @@ -172,7 +184,7 @@ class MindMapExtractor(Extractor): } text = perform_variable_replacements(self._mind_map_prompt, variables=variables) async with chat_limiter: - response = await trio.to_thread.run_sync(lambda: self._chat(text, [{"role": "user", "content": "Output:"}], {})) + response = await asyncio.to_thread(self._chat,text,[{"role": "user", "content": "Output:"}],{}) response = re.sub(r"```[^\n]*", "", response) logging.debug(response) logging.debug(self._todict(markdown_to_json.dictify(response))) diff --git a/graphrag/general/smoke.py b/graphrag/general/smoke.py index 5a04d9782..ba405e193 100644 --- a/graphrag/general/smoke.py +++ b/graphrag/general/smoke.py @@ -15,10 +15,10 @@ # import argparse +import asyncio import json import logging import networkx as nx -import trio from common.constants import LLMType from api.db.services.document_service import DocumentService @@ -107,4 +107,4 @@ async def main(): if __name__ == "__main__": - trio.run(main) + asyncio.run(main()) diff --git a/graphrag/light/graph_extractor.py b/graphrag/light/graph_extractor.py index e698c2b9f..f507f4617 100644 --- a/graphrag/light/graph_extractor.py +++ b/graphrag/light/graph_extractor.py @@ -5,13 +5,13 @@ Reference: - [graphrag](https://github.com/microsoft/graphrag) """ +import asyncio import logging import re from dataclasses import dataclass from typing import Any import networkx as nx -import trio from graphrag.general.extractor import ENTITY_EXTRACTION_MAX_GLEANINGS, Extractor from graphrag.light.graph_prompt import PROMPTS @@ -86,13 +86,12 @@ class GraphExtractor(Extractor): if self.callback: self.callback(msg=f"Start processing for {chunk_key}: {content[:25]}...") async with
chat_limiter: - final_result = await trio.to_thread.run_sync(self._chat, "", [{"role": "user", "content": hint_prompt}], gen_conf, task_id) + final_result = await asyncio.to_thread(self._chat,"",[{"role": "user", "content": hint_prompt}],gen_conf,task_id) token_count += num_tokens_from_string(hint_prompt + final_result) history = pack_user_ass_to_openai_messages(hint_prompt, final_result, self._continue_prompt) for now_glean_index in range(self._max_gleanings): async with chat_limiter: - # glean_result = await trio.to_thread.run_sync(lambda: self._chat(hint_prompt, history, gen_conf)) - glean_result = await trio.to_thread.run_sync(self._chat, "", history, gen_conf, task_id) + glean_result = await asyncio.to_thread(self._chat,"",history,gen_conf,task_id) history.extend([{"role": "assistant", "content": glean_result}]) token_count += num_tokens_from_string("\n".join([m["content"] for m in history]) + hint_prompt + self._continue_prompt) final_result += glean_result @@ -101,7 +100,7 @@ class GraphExtractor(Extractor): history.extend([{"role": "user", "content": self._if_loop_prompt}]) async with chat_limiter: - if_loop_result = await trio.to_thread.run_sync(self._chat, "", history, gen_conf, task_id) + if_loop_result = await asyncio.to_thread(self._chat,"",history,gen_conf,task_id) token_count += num_tokens_from_string("\n".join([m["content"] for m in history]) + if_loop_result + self._if_loop_prompt) if_loop_result = if_loop_result.strip().strip('"').strip("'").lower() if if_loop_result != "yes": diff --git a/graphrag/light/smoke.py b/graphrag/light/smoke.py index bd4107ce6..bfa3ca256 100644 --- a/graphrag/light/smoke.py +++ b/graphrag/light/smoke.py @@ -15,10 +15,10 @@ # import argparse +import asyncio import json import networkx as nx import logging -import trio from common.constants import LLMType from api.db.services.document_service import DocumentService @@ -93,4 +93,4 @@ async def main(): if __name__ == "__main__": - trio.run(main) + asyncio.run(main()) diff
--git a/graphrag/search.py b/graphrag/search.py index b3a0104e1..7399ea393 100644 --- a/graphrag/search.py +++ b/graphrag/search.py @@ -13,13 +13,13 @@ # See the License for the specific language governing permissions and # limitations under the License. # +import asyncio import json import logging from collections import defaultdict from copy import deepcopy import json_repair import pandas as pd -import trio from common.misc_utils import get_uuid from graphrag.query_analyze_prompt import PROMPTS @@ -44,7 +44,7 @@ class KGSearch(Dealer): return response def query_rewrite(self, llm, question, idxnms, kb_ids): - ty2ents = trio.run(lambda: get_entity_type2samples(idxnms, kb_ids)) + ty2ents = asyncio.run(get_entity_type2samples(idxnms, kb_ids)) hint_prompt = PROMPTS["minirag_query2kwd"].format(query=question, TYPE_POOL=json.dumps(ty2ents, ensure_ascii=False, indent=2)) result = self._chat(llm, hint_prompt, [{"role": "user", "content": "Output:"}], {}) diff --git a/graphrag/utils.py b/graphrag/utils.py index 51a9c1abc..a39bdd2d7 100644 --- a/graphrag/utils.py +++ b/graphrag/utils.py @@ -6,6 +6,7 @@ Reference: - [LightRag](https://github.com/HKUDS/LightRAG) """ +import asyncio import dataclasses import html import json @@ -19,7 +20,6 @@ from typing import Any, Callable, Set, Tuple import networkx as nx import numpy as np -import trio import xxhash from networkx.readwrite import json_graph @@ -34,7 +34,7 @@ GRAPH_FIELD_SEP = "" ErrorHandlerFn = Callable[[BaseException | None, str | None, dict | None], None] -chat_limiter = trio.CapacityLimiter(int(os.environ.get("MAX_CONCURRENT_CHATS", 10))) +chat_limiter = asyncio.Semaphore(int(os.environ.get("MAX_CONCURRENT_CHATS", 10))) @dataclasses.dataclass @@ -314,8 +314,11 @@ async def graph_node_to_chunk(kb_id, embd_mdl, ent_name, meta, chunks): ebd = get_embed_cache(embd_mdl.llm_name, ent_name) if ebd is None: async with chat_limiter: - with trio.fail_after(3 if enable_timeout_assertion else 30000000): - ebd, _ = await 
trio.to_thread.run_sync(lambda: embd_mdl.encode([ent_name])) + timeout = 3 if enable_timeout_assertion else 30000000 + ebd, _ = await asyncio.wait_for( + asyncio.to_thread(embd_mdl.encode, [ent_name]), + timeout=timeout + ) ebd = ebd[0] set_embed_cache(embd_mdl.llm_name, ent_name, ebd) assert ebd is not None @@ -365,8 +368,14 @@ async def graph_edge_to_chunk(kb_id, embd_mdl, from_ent_name, to_ent_name, meta, ebd = get_embed_cache(embd_mdl.llm_name, txt) if ebd is None: async with chat_limiter: - with trio.fail_after(3 if enable_timeout_assertion else 300000000): - ebd, _ = await trio.to_thread.run_sync(lambda: embd_mdl.encode([txt + f": {meta['description']}"])) + timeout = 3 if enable_timeout_assertion else 300000000 + ebd, _ = await asyncio.wait_for( + asyncio.to_thread( + embd_mdl.encode, + [txt + f": {meta['description']}"] + ), + timeout=timeout + ) ebd = ebd[0] set_embed_cache(embd_mdl.llm_name, txt, ebd) assert ebd is not None @@ -381,7 +390,11 @@ async def does_graph_contains(tenant_id, kb_id, doc_id): "knowledge_graph_kwd": ["graph"], "removed_kwd": "N", } - res = await trio.to_thread.run_sync(lambda: settings.docStoreConn.search(fields, [], condition, [], OrderByExpr(), 0, 1, search.index_name(tenant_id), [kb_id])) + res = await asyncio.to_thread( + settings.docStoreConn.search, + fields, [], condition, [], OrderByExpr(), + 0, 1, search.index_name(tenant_id), [kb_id] + ) fields2 = settings.docStoreConn.get_fields(res, fields) graph_doc_ids = set() for chunk_id in fields2.keys(): @@ -391,7 +404,12 @@ async def does_graph_contains(tenant_id, kb_id, doc_id): async def get_graph_doc_ids(tenant_id, kb_id) -> list[str]: conds = {"fields": ["source_id"], "removed_kwd": "N", "size": 1, "knowledge_graph_kwd": ["graph"]} - res = await trio.to_thread.run_sync(lambda: settings.retriever.search(conds, search.index_name(tenant_id), [kb_id])) + res = await asyncio.to_thread( + settings.retriever.search, + conds, + search.index_name(tenant_id), + [kb_id] + ) doc_ids = [] 
if res.total == 0: return doc_ids @@ -402,7 +420,12 @@ async def get_graph_doc_ids(tenant_id, kb_id) -> list[str]: async def get_graph(tenant_id, kb_id, exclude_rebuild=None): conds = {"fields": ["content_with_weight", "removed_kwd", "source_id"], "size": 1, "knowledge_graph_kwd": ["graph"]} - res = await trio.to_thread.run_sync(settings.retriever.search, conds, search.index_name(tenant_id), [kb_id]) + res = await asyncio.to_thread( + settings.retriever.search, + conds, + search.index_name(tenant_id), + [kb_id] + ) if not res.total == 0: for id in res.ids: try: @@ -421,26 +444,47 @@ async def get_graph(tenant_id, kb_id, exclude_rebuild=None): async def set_graph(tenant_id: str, kb_id: str, embd_mdl, graph: nx.Graph, change: GraphChange, callback): global chat_limiter - start = trio.current_time() + start = asyncio.get_running_loop().time() - await trio.to_thread.run_sync(settings.docStoreConn.delete, {"knowledge_graph_kwd": ["graph", "subgraph"]}, search.index_name(tenant_id), kb_id) + await asyncio.to_thread( + settings.docStoreConn.delete, + {"knowledge_graph_kwd": ["graph", "subgraph"]}, + search.index_name(tenant_id), + kb_id + ) if change.removed_nodes: - await trio.to_thread.run_sync(settings.docStoreConn.delete, {"knowledge_graph_kwd": ["entity"], "entity_kwd": sorted(change.removed_nodes)}, search.index_name(tenant_id), kb_id) + await asyncio.to_thread( + settings.docStoreConn.delete, + {"knowledge_graph_kwd": ["entity"], "entity_kwd": sorted(change.removed_nodes)}, + search.index_name(tenant_id), + kb_id + ) if change.removed_edges: async def del_edges(from_node, to_node): async with chat_limiter: - await trio.to_thread.run_sync( - settings.docStoreConn.delete, {"knowledge_graph_kwd": ["relation"], "from_entity_kwd": from_node, "to_entity_kwd": to_node}, search.index_name(tenant_id), kb_id + await asyncio.to_thread( + settings.docStoreConn.delete, + {"knowledge_graph_kwd": ["relation"], "from_entity_kwd": from_node, "to_entity_kwd": to_node}, + 
search.index_name(tenant_id), + kb_id ) - async with trio.open_nursery() as nursery: - for from_node, to_node in change.removed_edges: - nursery.start_soon(del_edges, from_node, to_node) + tasks = [] + for from_node, to_node in change.removed_edges: + tasks.append(asyncio.create_task(del_edges(from_node, to_node))) - now = trio.current_time() + try: + await asyncio.gather(*tasks, return_exceptions=False) + except Exception as e: + for t in tasks: + t.cancel() + await asyncio.gather(*tasks, return_exceptions=True) + raise + + now = asyncio.get_running_loop().time() if callback: callback(msg=f"set_graph removed {len(change.removed_nodes)} nodes and {len(change.removed_edges)} edges from index in {now - start:.2f}s.") start = now @@ -475,24 +519,41 @@ async def set_graph(tenant_id: str, kb_id: str, embd_mdl, graph: nx.Graph, chang } ) - async with trio.open_nursery() as nursery: - for ii, node in enumerate(change.added_updated_nodes): - node_attrs = graph.nodes[node] - nursery.start_soon(graph_node_to_chunk, kb_id, embd_mdl, node, node_attrs, chunks) - if ii % 100 == 9 and callback: - callback(msg=f"Get embedding of nodes: {ii}/{len(change.added_updated_nodes)}") + tasks = [] + for ii, node in enumerate(change.added_updated_nodes): + node_attrs = graph.nodes[node] + tasks.append(asyncio.create_task( + graph_node_to_chunk(kb_id, embd_mdl, node, node_attrs, chunks) + )) + if ii % 100 == 9 and callback: + callback(msg=f"Get embedding of nodes: {ii}/{len(change.added_updated_nodes)}") + try: + await asyncio.gather(*tasks, return_exceptions=False) + except Exception as e: + for t in tasks: + t.cancel() + await asyncio.gather(*tasks, return_exceptions=True) + raise - async with trio.open_nursery() as nursery: - for ii, (from_node, to_node) in enumerate(change.added_updated_edges): - edge_attrs = graph.get_edge_data(from_node, to_node) - if not edge_attrs: - # added_updated_edges could record a non-existing edge if both from_node and to_node participate in nodes merging. 
- continue - nursery.start_soon(graph_edge_to_chunk, kb_id, embd_mdl, from_node, to_node, edge_attrs, chunks) - if ii % 100 == 9 and callback: - callback(msg=f"Get embedding of edges: {ii}/{len(change.added_updated_edges)}") + tasks = [] + for ii, (from_node, to_node) in enumerate(change.added_updated_edges): + edge_attrs = graph.get_edge_data(from_node, to_node) + if not edge_attrs: + continue + tasks.append(asyncio.create_task( + graph_edge_to_chunk(kb_id, embd_mdl, from_node, to_node, edge_attrs, chunks) + )) + if ii % 100 == 9 and callback: + callback(msg=f"Get embedding of edges: {ii}/{len(change.added_updated_edges)}") + try: + await asyncio.gather(*tasks, return_exceptions=False) + except Exception as e: + for t in tasks: + t.cancel() + await asyncio.gather(*tasks, return_exceptions=True) + raise - now = trio.current_time() + now = asyncio.get_running_loop().time() if callback: callback(msg=f"set_graph converted graph change to {len(chunks)} chunks in {now - start:.2f}s.") start = now @@ -500,14 +561,22 @@ async def set_graph(tenant_id: str, kb_id: str, embd_mdl, graph: nx.Graph, chang enable_timeout_assertion = os.environ.get("ENABLE_TIMEOUT_ASSERTION") es_bulk_size = 4 for b in range(0, len(chunks), es_bulk_size): - with trio.fail_after(3 if enable_timeout_assertion else 30000000): - doc_store_result = await trio.to_thread.run_sync(lambda: settings.docStoreConn.insert(chunks[b : b + es_bulk_size], search.index_name(tenant_id), kb_id)) + timeout = 3 if enable_timeout_assertion else 30000000 + doc_store_result = await asyncio.wait_for( + asyncio.to_thread( + settings.docStoreConn.insert, + chunks[b : b + es_bulk_size], + search.index_name(tenant_id), + kb_id + ), + timeout=timeout + ) if b % 100 == es_bulk_size and callback: callback(msg=f"Insert chunks: {b}/{len(chunks)}") if doc_store_result: error_message = f"Insert chunk error: {doc_store_result}, please check log file and Elasticsearch/Infinity status!" 
raise Exception(error_message) - now = trio.current_time() + now = asyncio.get_running_loop().time() if callback: callback(msg=f"set_graph added/updated {len(change.added_updated_nodes)} nodes and {len(change.added_updated_edges)} edges from index in {now - start:.2f}s.") @@ -555,7 +624,7 @@ def merge_tuples(list1, list2): async def get_entity_type2samples(idxnms, kb_ids: list): - es_res = await trio.to_thread.run_sync(lambda: settings.retriever.search({"knowledge_graph_kwd": "ty2ents", "kb_id": kb_ids, "size": 10000, "fields": ["content_with_weight"]}, idxnms, kb_ids)) + es_res = await asyncio.to_thread(settings.retriever.search,{"knowledge_graph_kwd": "ty2ents", "kb_id": kb_ids, "size": 10000, "fields": ["content_with_weight"]},idxnms,kb_ids) res = defaultdict(list) for id in es_res.ids: @@ -588,8 +657,10 @@ async def rebuild_graph(tenant_id, kb_id, exclude_rebuild=None): flds = ["knowledge_graph_kwd", "content_with_weight", "source_id"] bs = 256 for i in range(0, 1024 * bs, bs): - es_res = await trio.to_thread.run_sync( - lambda: settings.docStoreConn.search(flds, [], {"kb_id": kb_id, "knowledge_graph_kwd": ["subgraph"]}, [], OrderByExpr(), i, bs, search.index_name(tenant_id), [kb_id]) + es_res = await asyncio.to_thread( + settings.docStoreConn.search, + flds, [], {"kb_id": kb_id, "knowledge_graph_kwd": ["subgraph"]}, + [], OrderByExpr(), i, bs, search.index_name(tenant_id), [kb_id] ) # tot = settings.docStoreConn.get_total(es_res) es_res = settings.docStoreConn.get_fields(es_res, flds) diff --git a/rag/raptor.py b/rag/raptor.py index a455d0127..201e67560 100644 --- a/rag/raptor.py +++ b/rag/raptor.py @@ -13,11 +13,11 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# +import asyncio import logging import re import numpy as np -import trio import umap from sklearn.mixture import GaussianMixture @@ -56,37 +56,37 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval: @timeout(60 * 20) async def _chat(self, system, history, gen_conf): - cached = await trio.to_thread.run_sync(lambda: get_llm_cache(self._llm_model.llm_name, system, history, gen_conf)) + cached = await asyncio.to_thread(get_llm_cache, self._llm_model.llm_name, system, history, gen_conf) if cached: return cached last_exc = None for attempt in range(3): try: - response = await trio.to_thread.run_sync(lambda: self._llm_model.chat(system, history, gen_conf)) + response = await asyncio.to_thread(self._llm_model.chat, system, history, gen_conf) response = re.sub(r"^.*", "", response, flags=re.DOTALL) if response.find("**ERROR**") >= 0: raise Exception(response) - await trio.to_thread.run_sync(lambda: set_llm_cache(self._llm_model.llm_name, system, response, history, gen_conf)) + await asyncio.to_thread(set_llm_cache,self._llm_model.llm_name,system,response,history,gen_conf) return response except Exception as exc: last_exc = exc logging.warning("RAPTOR LLM call failed on attempt %d/3: %s", attempt + 1, exc) if attempt < 2: - await trio.sleep(1 + attempt) + await asyncio.sleep(1 + attempt) raise last_exc if last_exc else Exception("LLM chat failed without exception") @timeout(20) async def _embedding_encode(self, txt): - response = await trio.to_thread.run_sync(lambda: get_embed_cache(self._embd_model.llm_name, txt)) + response = await asyncio.to_thread(get_embed_cache, self._embd_model.llm_name, txt) if response is not None: return response - embds, _ = await trio.to_thread.run_sync(lambda: self._embd_model.encode([txt])) + embds, _ = await asyncio.to_thread(self._embd_model.encode, [txt]) if len(embds) < 1 or len(embds[0]) < 1: raise Exception("Embedding error: ") embds = embds[0] - await trio.to_thread.run_sync(lambda: set_embed_cache(self._embd_model.llm_name, 
txt, embds)) + await asyncio.to_thread(set_embed_cache, self._embd_model.llm_name, txt, embds) return embds def _get_optimal_clusters(self, embeddings: np.ndarray, random_state: int, task_id: str = ""): @@ -198,16 +198,21 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval: lbls = [np.where(prob > self._threshold)[0] for prob in probs] lbls = [lbl[0] if isinstance(lbl, np.ndarray) else lbl for lbl in lbls] - async with trio.open_nursery() as nursery: - for c in range(n_clusters): - ck_idx = [i + start for i in range(len(lbls)) if lbls[i] == c] - assert len(ck_idx) > 0 - - if task_id and has_canceled(task_id): - logging.info(f"Task {task_id} cancelled before RAPTOR cluster processing.") - raise TaskCanceledException(f"Task {task_id} was cancelled") - - nursery.start_soon(summarize, ck_idx) + tasks = [] + for c in range(n_clusters): + ck_idx = [i + start for i in range(len(lbls)) if lbls[i] == c] + assert len(ck_idx) > 0 + if task_id and has_canceled(task_id): + logging.info(f"Task {task_id} cancelled before RAPTOR cluster processing.") + raise TaskCanceledException(f"Task {task_id} was cancelled") + tasks.append(asyncio.create_task(summarize(ck_idx))) + try: + await asyncio.gather(*tasks, return_exceptions=False) + except Exception as e: + for t in tasks: + t.cancel() + await asyncio.gather(*tasks, return_exceptions=True) + raise assert len(chunks) - end == n_clusters, "{} vs. 
{}".format(len(chunks) - end, n_clusters) labels.extend(lbls) diff --git a/rag/svr/task_executor.py b/rag/svr/task_executor.py index c1776c5c3..3a6d64d8c 100644 --- a/rag/svr/task_executor.py +++ b/rag/svr/task_executor.py @@ -920,7 +920,7 @@ async def do_handle_task(task): file_type = task.get("type", "") parser_id = task.get("parser_id", "") raptor_config = kb_parser_config.get("raptor", {}) - + if should_skip_raptor(file_type, parser_id, task_parser_config, raptor_config): skip_reason = get_skip_reason(file_type, parser_id, task_parser_config) logging.info(f"Skipping Raptor for document {task_document_name}: {skip_reason}")