Refa:replace trio with asyncio in graphrag and api

This commit is contained in:
buua436 2025-12-09 17:07:11 +08:00
parent 94dccd8270
commit 9298007478
17 changed files with 411 additions and 222 deletions

View file

@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# #
import asyncio
import binascii import binascii
import logging import logging
import re import re
@ -21,7 +22,6 @@ from copy import deepcopy
from datetime import datetime from datetime import datetime
from functools import partial from functools import partial
from timeit import default_timer as timer from timeit import default_timer as timer
import trio
from langfuse import Langfuse from langfuse import Langfuse
from peewee import fn from peewee import fn
from agentic_reasoning import DeepResearcher from agentic_reasoning import DeepResearcher
@ -931,5 +931,5 @@ def gen_mindmap(question, kb_ids, tenant_id, search_config={}):
rank_feature=label_question(question, kbs), rank_feature=label_question(question, kbs),
) )
mindmap = MindMapExtractor(chat_mdl) mindmap = MindMapExtractor(chat_mdl)
mind_map = trio.run(mindmap, [c["content_with_weight"] for c in ranks["chunks"]]) mind_map = asyncio.run(mindmap([c["content_with_weight"] for c in ranks["chunks"]]))
return mind_map.output return mind_map.output

View file

@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# #
import asyncio
import json import json
import logging import logging
import random import random
@ -22,7 +23,6 @@ from copy import deepcopy
from datetime import datetime from datetime import datetime
from io import BytesIO from io import BytesIO
import trio
import xxhash import xxhash
from peewee import fn, Case, JOIN from peewee import fn, Case, JOIN
@ -999,7 +999,7 @@ def doc_upload_and_parse(conversation_id, file_objs, user_id):
from graphrag.general.mind_map_extractor import MindMapExtractor from graphrag.general.mind_map_extractor import MindMapExtractor
mindmap = MindMapExtractor(llm_bdl) mindmap = MindMapExtractor(llm_bdl)
try: try:
mind_map = trio.run(mindmap, [c["content_with_weight"] for c in docs if c["doc_id"] == doc_id]) mind_map = asyncio.run(mindmap([c["content_with_weight"] for c in docs if c["doc_id"] == doc_id]))
mind_map = json.dumps(mind_map.output, ensure_ascii=False, indent=2) mind_map = json.dumps(mind_map.output, ensure_ascii=False, indent=2)
if len(mind_map) < 32: if len(mind_map) < 32:
raise Exception("Few content: " + mind_map) raise Exception("Few content: " + mind_map)

View file

@ -14,6 +14,7 @@
# limitations under the License. # limitations under the License.
# #
import asyncio
import functools import functools
import inspect import inspect
import json import json
@ -25,7 +26,6 @@ from functools import wraps
from typing import Any from typing import Any
import requests import requests
import trio
from quart import ( from quart import (
Response, Response,
jsonify, jsonify,
@ -681,18 +681,36 @@ async def is_strong_enough(chat_model, embedding_model):
async def _is_strong_enough(): async def _is_strong_enough():
nonlocal chat_model, embedding_model nonlocal chat_model, embedding_model
if embedding_model: if embedding_model:
with trio.fail_after(10): await asyncio.wait_for(
_ = await trio.to_thread.run_sync(lambda: embedding_model.encode(["Are you strong enough!?"])) asyncio.to_thread(embedding_model.encode, ["Are you strong enough!?"]),
timeout=10
)
if chat_model: if chat_model:
with trio.fail_after(30): res = await asyncio.wait_for(
res = await trio.to_thread.run_sync(lambda: chat_model.chat("Nothing special.", [{"role": "user", "content": "Are you strong enough!?"}], {})) asyncio.to_thread(
if res.find("**ERROR**") >= 0: chat_model.chat,
"Nothing special.",
[{"role": "user", "content": "Are you strong enough!?"}],
{}
),
timeout=30
)
if "**ERROR**" in res:
raise Exception(res) raise Exception(res)
# Pressure test for GraphRAG task # Pressure test for GraphRAG task
async with trio.open_nursery() as nursery: tasks = [
for _ in range(count): asyncio.create_task(_is_strong_enough())
nursery.start_soon(_is_strong_enough) for _ in range(count)
]
try:
await asyncio.gather(*tasks, return_exceptions=False)
except Exception as e:
for t in tasks:
t.cancel()
await asyncio.gather(*tasks, return_exceptions=True)
raise
def get_allowed_llm_factories() -> list: def get_allowed_llm_factories() -> list:

View file

@ -19,7 +19,6 @@ import queue
import threading import threading
from typing import Any, Callable, Coroutine, Optional, Type, Union from typing import Any, Callable, Coroutine, Optional, Type, Union
import asyncio import asyncio
import trio
from functools import wraps from functools import wraps
from quart import make_response, jsonify from quart import make_response, jsonify
from common.constants import RetCode from common.constants import RetCode
@ -70,11 +69,10 @@ def timeout(seconds: float | int | str = None, attempts: int = 2, *, exception:
for a in range(attempts): for a in range(attempts):
try: try:
if os.environ.get("ENABLE_TIMEOUT_ASSERTION"): if os.environ.get("ENABLE_TIMEOUT_ASSERTION"):
with trio.fail_after(seconds): return await asyncio.wait_for(func(*args, **kwargs), timeout=seconds)
return await func(*args, **kwargs)
else: else:
return await func(*args, **kwargs) return await func(*args, **kwargs)
except trio.TooSlowError: except asyncio.TimeoutError:
if a < attempts - 1: if a < attempts - 1:
continue continue
if on_timeout is not None: if on_timeout is not None:

View file

@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# #
import asyncio
import logging import logging
import itertools import itertools
import os import os
@ -21,7 +22,6 @@ from dataclasses import dataclass
from typing import Any, Callable from typing import Any, Callable
import networkx as nx import networkx as nx
import trio
from graphrag.general.extractor import Extractor from graphrag.general.extractor import Extractor
from rag.nlp import is_english from rag.nlp import is_english
@ -101,35 +101,55 @@ class EntityResolution(Extractor):
remain_candidates_to_resolve = num_candidates remain_candidates_to_resolve = num_candidates
resolution_result = set() resolution_result = set()
resolution_result_lock = trio.Lock() resolution_result_lock = asyncio.Lock()
resolution_batch_size = 100 resolution_batch_size = 100
max_concurrent_tasks = 5 max_concurrent_tasks = 5
semaphore = trio.Semaphore(max_concurrent_tasks) semaphore = asyncio.Semaphore(max_concurrent_tasks)
async def limited_resolve_candidate(candidate_batch, result_set, result_lock): async def limited_resolve_candidate(candidate_batch, result_set, result_lock):
nonlocal remain_candidates_to_resolve, callback nonlocal remain_candidates_to_resolve, callback
async with semaphore: async with semaphore:
try: try:
enable_timeout_assertion = os.environ.get("ENABLE_TIMEOUT_ASSERTION") enable_timeout_assertion = os.environ.get("ENABLE_TIMEOUT_ASSERTION")
with trio.move_on_after(280 if enable_timeout_assertion else 1000000000) as cancel_scope: timeout_sec = 280 if enable_timeout_assertion else 1_000_000_000
await self._resolve_candidate(candidate_batch, result_set, result_lock, task_id)
remain_candidates_to_resolve = remain_candidates_to_resolve - len(candidate_batch[1]) try:
callback(msg=f"Resolved {len(candidate_batch[1])} pairs, {remain_candidates_to_resolve} are remained to resolve. ") await asyncio.wait_for(
if cancel_scope.cancelled_caught: self._resolve_candidate(candidate_batch, result_set, result_lock, task_id),
timeout=timeout_sec
)
remain_candidates_to_resolve -= len(candidate_batch[1])
callback(
msg=f"Resolved {len(candidate_batch[1])} pairs, "
f"{remain_candidates_to_resolve} remain."
)
except asyncio.TimeoutError:
logging.warning(f"Timeout resolving {candidate_batch}, skipping...") logging.warning(f"Timeout resolving {candidate_batch}, skipping...")
remain_candidates_to_resolve = remain_candidates_to_resolve - len(candidate_batch[1]) remain_candidates_to_resolve -= len(candidate_batch[1])
callback(msg=f"Fail to resolved {len(candidate_batch[1])} pairs due to timeout reason, skipped. {remain_candidates_to_resolve} are remained to resolve. ") callback(
msg=f"Failed to resolve {len(candidate_batch[1])} pairs due to timeout, skipped. "
f"{remain_candidates_to_resolve} remain."
)
except Exception as e: except Exception as e:
logging.error(f"Error resolving candidate batch: {e}") logging.error(f"Error resolving candidate batch: {e}")
async with trio.open_nursery() as nursery: tasks = []
for candidate_resolution_i in candidate_resolution.items(): for key, lst in candidate_resolution.items():
if not candidate_resolution_i[1]: if not lst:
continue continue
for i in range(0, len(candidate_resolution_i[1]), resolution_batch_size): for i in range(0, len(lst), resolution_batch_size):
candidate_batch = candidate_resolution_i[0], candidate_resolution_i[1][i:i + resolution_batch_size] batch = (key, lst[i:i + resolution_batch_size])
nursery.start_soon(limited_resolve_candidate, candidate_batch, resolution_result, resolution_result_lock) tasks.append(limited_resolve_candidate(batch, resolution_result, resolution_result_lock))
try:
await asyncio.gather(*tasks, return_exceptions=False)
except Exception as e:
for t in tasks:
t.cancel()
await asyncio.gather(*tasks, return_exceptions=True)
raise
callback(msg=f"Resolved {num_candidates} candidate pairs, {len(resolution_result)} of them are selected to merge.") callback(msg=f"Resolved {num_candidates} candidate pairs, {len(resolution_result)} of them are selected to merge.")
@ -141,10 +161,18 @@ class EntityResolution(Extractor):
async with semaphore: async with semaphore:
await self._merge_graph_nodes(graph, nodes, change, task_id) await self._merge_graph_nodes(graph, nodes, change, task_id)
async with trio.open_nursery() as nursery: tasks = []
for sub_connect_graph in nx.connected_components(connect_graph): for sub_connect_graph in nx.connected_components(connect_graph):
merging_nodes = list(sub_connect_graph) merging_nodes = list(sub_connect_graph)
nursery.start_soon(limited_merge_nodes, graph, merging_nodes, change) tasks.append(asyncio.create_task(limited_merge_nodes(graph, merging_nodes, change))
)
try:
await asyncio.gather(*tasks, return_exceptions=False)
except Exception as e:
for t in tasks:
t.cancel()
await asyncio.gather(*tasks, return_exceptions=True)
raise
# Update pagerank # Update pagerank
pr = nx.pagerank(graph) pr = nx.pagerank(graph)
@ -156,7 +184,7 @@ class EntityResolution(Extractor):
change=change, change=change,
) )
async def _resolve_candidate(self, candidate_resolution_i: tuple[str, list[tuple[str, str]]], resolution_result: set[str], resolution_result_lock: trio.Lock, task_id: str = ""): async def _resolve_candidate(self, candidate_resolution_i: tuple[str, list[tuple[str, str]]], resolution_result: set[str], resolution_result_lock: asyncio.Lock, task_id: str = ""):
if task_id: if task_id:
if has_canceled(task_id): if has_canceled(task_id):
logging.info(f"Task {task_id} cancelled during entity resolution candidate processing.") logging.info(f"Task {task_id} cancelled during entity resolution candidate processing.")
@ -178,13 +206,22 @@ class EntityResolution(Extractor):
text = perform_variable_replacements(self._resolution_prompt, variables=variables) text = perform_variable_replacements(self._resolution_prompt, variables=variables)
logging.info(f"Created resolution prompt {len(text)} bytes for {len(candidate_resolution_i[1])} entity pairs of type {candidate_resolution_i[0]}") logging.info(f"Created resolution prompt {len(text)} bytes for {len(candidate_resolution_i[1])} entity pairs of type {candidate_resolution_i[0]}")
async with chat_limiter: async with chat_limiter:
timeout_seconds = 280 if os.environ.get("ENABLE_TIMEOUT_ASSERTION") else 1000000000
try: try:
enable_timeout_assertion = os.environ.get("ENABLE_TIMEOUT_ASSERTION") response = await asyncio.wait_for(
with trio.move_on_after(280 if enable_timeout_assertion else 1000000000) as cancel_scope: asyncio.to_thread(
response = await trio.to_thread.run_sync(self._chat, text, [{"role": "user", "content": "Output:"}], {}, task_id) self._chat,
if cancel_scope.cancelled_caught: text,
logging.warning("_resolve_candidate._chat timeout, skipping...") [{"role": "user", "content": "Output:"}],
return {},
task_id
),
timeout=timeout_seconds,
)
except asyncio.TimeoutError:
logging.warning("_resolve_candidate._chat timeout, skipping...")
return
except Exception as e: except Exception as e:
logging.error(f"_resolve_candidate._chat failed: {e}") logging.error(f"_resolve_candidate._chat failed: {e}")
return return

View file

@ -5,6 +5,7 @@ Reference:
- [graphrag](https://github.com/microsoft/graphrag) - [graphrag](https://github.com/microsoft/graphrag)
""" """
import asyncio
import logging import logging
import json import json
import os import os
@ -24,7 +25,6 @@ from graphrag.general.leiden import add_community_info2graph
from rag.llm.chat_model import Base as CompletionLLM from rag.llm.chat_model import Base as CompletionLLM
from graphrag.utils import perform_variable_replacements, dict_has_keys_with_types, chat_limiter from graphrag.utils import perform_variable_replacements, dict_has_keys_with_types, chat_limiter
from common.token_utils import num_tokens_from_string from common.token_utils import num_tokens_from_string
import trio
@dataclass @dataclass
@ -101,14 +101,11 @@ class CommunityReportsExtractor(Extractor):
text = perform_variable_replacements(self._extraction_prompt, variables=prompt_variables) text = perform_variable_replacements(self._extraction_prompt, variables=prompt_variables)
async with chat_limiter: async with chat_limiter:
try: try:
with trio.move_on_after(180 if enable_timeout_assertion else 1000000000) as cancel_scope: timeout = 180 if enable_timeout_assertion else 1000000000
if task_id and has_canceled(task_id): response = await asyncio.wait_for(asyncio.to_thread(self._chat,text,[{"role": "user", "content": "Output:"}],{},task_id),timeout=timeout)
logging.info(f"Task {task_id} cancelled before LLM call.") except asyncio.TimeoutError:
raise TaskCanceledException(f"Task {task_id} was cancelled") logging.warning("extract_community_report._chat timeout, skipping...")
response = await trio.to_thread.run_sync( self._chat, text, [{"role": "user", "content": "Output:"}], {}, task_id) return
if cancel_scope.cancelled_caught:
logging.warning("extract_community_report._chat timeout, skipping...")
return
except Exception as e: except Exception as e:
logging.error(f"extract_community_report._chat failed: {e}") logging.error(f"extract_community_report._chat failed: {e}")
return return
@ -141,17 +138,24 @@ class CommunityReportsExtractor(Extractor):
if callback: if callback:
callback(msg=f"Communities: {over}/{total}, used tokens: {token_count}") callback(msg=f"Communities: {over}/{total}, used tokens: {token_count}")
st = trio.current_time() st = asyncio.get_running_loop().time()
async with trio.open_nursery() as nursery: tasks = []
for level, comm in communities.items(): for level, comm in communities.items():
logging.info(f"Level {level}: Community: {len(comm.keys())}") logging.info(f"Level {level}: Community: {len(comm.keys())}")
for community in comm.items(): for community in comm.items():
if task_id and has_canceled(task_id): if task_id and has_canceled(task_id):
logging.info(f"Task {task_id} cancelled before community processing.") logging.info(f"Task {task_id} cancelled before community processing.")
raise TaskCanceledException(f"Task {task_id} was cancelled") raise TaskCanceledException(f"Task {task_id} was cancelled")
nursery.start_soon(extract_community_report, community) tasks.append(asyncio.create_task(extract_community_report(community)))
try:
await asyncio.gather(*tasks, return_exceptions=False)
except Exception as e:
for t in tasks:
t.cancel()
await asyncio.gather(*tasks, return_exceptions=True)
raise
if callback: if callback:
callback(msg=f"Community reports done in {trio.current_time() - st:.2f}s, used tokens: {token_count}") callback(msg=f"Community reports done in {asyncio.get_running_loop().time() - st:.2f}s, used tokens: {token_count}")
return CommunityReportsResult( return CommunityReportsResult(
structured_output=res_dict, structured_output=res_dict,

View file

@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# #
import asyncio
import logging import logging
import os import os
import re import re
@ -21,7 +22,6 @@ from copy import deepcopy
from typing import Callable from typing import Callable
import networkx as nx import networkx as nx
import trio
from api.db.services.task_service import has_canceled from api.db.services.task_service import has_canceled
from common.connection_utils import timeout from common.connection_utils import timeout
@ -109,14 +109,14 @@ class Extractor:
async def __call__(self, doc_id: str, chunks: list[str], callback: Callable | None = None, task_id: str = ""): async def __call__(self, doc_id: str, chunks: list[str], callback: Callable | None = None, task_id: str = ""):
self.callback = callback self.callback = callback
start_ts = trio.current_time() start_ts = asyncio.get_running_loop().time()
async def extract_all(doc_id, chunks, max_concurrency=MAX_CONCURRENT_PROCESS_AND_EXTRACT_CHUNK, task_id=""): async def extract_all(doc_id, chunks, max_concurrency=MAX_CONCURRENT_PROCESS_AND_EXTRACT_CHUNK, task_id=""):
out_results = [] out_results = []
error_count = 0 error_count = 0
max_errors = int(os.environ.get("GRAPHRAG_MAX_ERRORS", 3)) max_errors = int(os.environ.get("GRAPHRAG_MAX_ERRORS", 3))
limiter = trio.Semaphore(max_concurrency) limiter = asyncio.Semaphore(max_concurrency)
async def worker(chunk_key_dp: tuple[str, str], idx: int, total: int, task_id=""): async def worker(chunk_key_dp: tuple[str, str], idx: int, total: int, task_id=""):
nonlocal error_count nonlocal error_count
@ -137,9 +137,18 @@ class Extractor:
if error_count > max_errors: if error_count > max_errors:
raise Exception(f"Maximum error count ({max_errors}) reached. Last errors: {str(e)}") raise Exception(f"Maximum error count ({max_errors}) reached. Last errors: {str(e)}")
async with trio.open_nursery() as nursery: tasks = [
for i, ck in enumerate(chunks): asyncio.create_task(worker((doc_id, ck), i, len(chunks), task_id))
nursery.start_soon(worker, (doc_id, ck), i, len(chunks), task_id) for i, ck in enumerate(chunks)
]
try:
await asyncio.gather(*tasks, return_exceptions=False)
except Exception as e:
for t in tasks:
t.cancel()
await asyncio.gather(*tasks, return_exceptions=True)
raise
if error_count > 0: if error_count > 0:
warning_msg = f"Completed with {error_count} errors (out of {len(chunks)} chunks processed)" warning_msg = f"Completed with {error_count} errors (out of {len(chunks)} chunks processed)"
@ -166,7 +175,7 @@ class Extractor:
for k, v in m_edges.items(): for k, v in m_edges.items():
maybe_edges[tuple(sorted(k))].extend(v) maybe_edges[tuple(sorted(k))].extend(v)
sum_token_count += token_count sum_token_count += token_count
now = trio.current_time() now = asyncio.get_running_loop().time()
if self.callback: if self.callback:
self.callback(msg=f"Entities and relationships extraction done, {len(maybe_nodes)} nodes, {len(maybe_edges)} edges, {sum_token_count} tokens, {now - start_ts:.2f}s.") self.callback(msg=f"Entities and relationships extraction done, {len(maybe_nodes)} nodes, {len(maybe_edges)} edges, {sum_token_count} tokens, {now - start_ts:.2f}s.")
start_ts = now start_ts = now
@ -176,14 +185,22 @@ class Extractor:
if task_id and has_canceled(task_id): if task_id and has_canceled(task_id):
raise TaskCanceledException(f"Task {task_id} was cancelled before nodes merging") raise TaskCanceledException(f"Task {task_id} was cancelled before nodes merging")
async with trio.open_nursery() as nursery: tasks = [
for en_nm, ents in maybe_nodes.items(): asyncio.create_task(self._merge_nodes(en_nm, ents, all_entities_data, task_id))
nursery.start_soon(self._merge_nodes, en_nm, ents, all_entities_data, task_id) for en_nm, ents in maybe_nodes.items()
]
try:
await asyncio.gather(*tasks, return_exceptions=False)
except Exception as e:
for t in tasks:
t.cancel()
await asyncio.gather(*tasks, return_exceptions=True)
raise
if task_id and has_canceled(task_id): if task_id and has_canceled(task_id):
raise TaskCanceledException(f"Task {task_id} was cancelled after nodes merging") raise TaskCanceledException(f"Task {task_id} was cancelled after nodes merging")
now = trio.current_time() now = asyncio.get_running_loop().time()
if self.callback: if self.callback:
self.callback(msg=f"Entities merging done, {now - start_ts:.2f}s.") self.callback(msg=f"Entities merging done, {now - start_ts:.2f}s.")
@ -194,14 +211,25 @@ class Extractor:
if task_id and has_canceled(task_id): if task_id and has_canceled(task_id):
raise TaskCanceledException(f"Task {task_id} was cancelled before relationships merging") raise TaskCanceledException(f"Task {task_id} was cancelled before relationships merging")
async with trio.open_nursery() as nursery: tasks = []
for (src, tgt), rels in maybe_edges.items(): for (src, tgt), rels in maybe_edges.items():
nursery.start_soon(self._merge_edges, src, tgt, rels, all_relationships_data, task_id) tasks.append(
asyncio.create_task(
self._merge_edges(src, tgt, rels, all_relationships_data, task_id)
)
)
try:
await asyncio.gather(*tasks, return_exceptions=False)
except Exception as e:
for t in tasks:
t.cancel()
await asyncio.gather(*tasks, return_exceptions=True)
raise
if task_id and has_canceled(task_id): if task_id and has_canceled(task_id):
raise TaskCanceledException(f"Task {task_id} was cancelled after relationships merging") raise TaskCanceledException(f"Task {task_id} was cancelled after relationships merging")
now = trio.current_time() now = asyncio.get_running_loop().time()
if self.callback: if self.callback:
self.callback(msg=f"Relationships merging done, {now - start_ts:.2f}s.") self.callback(msg=f"Relationships merging done, {now - start_ts:.2f}s.")
@ -309,5 +337,5 @@ class Extractor:
raise TaskCanceledException(f"Task {task_id} was cancelled during summary handling") raise TaskCanceledException(f"Task {task_id} was cancelled during summary handling")
async with chat_limiter: async with chat_limiter:
summary = await trio.to_thread.run_sync(self._chat, "", [{"role": "user", "content": use_prompt}], {}, task_id) summary = await asyncio.to_thread(self._chat, "", [{"role": "user", "content": use_prompt}], {}, task_id)
return summary return summary

View file

@ -5,11 +5,11 @@ Reference:
- [graphrag](https://github.com/microsoft/graphrag) - [graphrag](https://github.com/microsoft/graphrag)
""" """
import asyncio
import re import re
from typing import Any from typing import Any
from dataclasses import dataclass from dataclasses import dataclass
import tiktoken import tiktoken
import trio
from graphrag.general.extractor import Extractor, ENTITY_EXTRACTION_MAX_GLEANINGS from graphrag.general.extractor import Extractor, ENTITY_EXTRACTION_MAX_GLEANINGS
from graphrag.general.graph_prompt import GRAPH_EXTRACTION_PROMPT, CONTINUE_PROMPT, LOOP_PROMPT from graphrag.general.graph_prompt import GRAPH_EXTRACTION_PROMPT, CONTINUE_PROMPT, LOOP_PROMPT
@ -107,7 +107,7 @@ class GraphExtractor(Extractor):
} }
hint_prompt = perform_variable_replacements(self._extraction_prompt, variables=variables) hint_prompt = perform_variable_replacements(self._extraction_prompt, variables=variables)
async with chat_limiter: async with chat_limiter:
response = await trio.to_thread.run_sync(self._chat, hint_prompt, [{"role": "user", "content": "Output:"}], {}, task_id) response = await asyncio.to_thread(self._chat,hint_prompt,[{"role": "user", "content": "Output:"}],{},task_id)
token_count += num_tokens_from_string(hint_prompt + response) token_count += num_tokens_from_string(hint_prompt + response)
results = response or "" results = response or ""
@ -117,7 +117,7 @@ class GraphExtractor(Extractor):
for i in range(self._max_gleanings): for i in range(self._max_gleanings):
history.append({"role": "user", "content": CONTINUE_PROMPT}) history.append({"role": "user", "content": CONTINUE_PROMPT})
async with chat_limiter: async with chat_limiter:
response = await trio.to_thread.run_sync(lambda: self._chat("", history, {})) response = await asyncio.to_thread(self._chat, "", history, {})
token_count += num_tokens_from_string("\n".join([m["content"] for m in history]) + response) token_count += num_tokens_from_string("\n".join([m["content"] for m in history]) + response)
results += response or "" results += response or ""
@ -127,7 +127,7 @@ class GraphExtractor(Extractor):
history.append({"role": "assistant", "content": response}) history.append({"role": "assistant", "content": response})
history.append({"role": "user", "content": LOOP_PROMPT}) history.append({"role": "user", "content": LOOP_PROMPT})
async with chat_limiter: async with chat_limiter:
continuation = await trio.to_thread.run_sync(lambda: self._chat("", history)) continuation = await asyncio.to_thread(self._chat, "", history)
token_count += num_tokens_from_string("\n".join([m["content"] for m in history]) + response) token_count += num_tokens_from_string("\n".join([m["content"] for m in history]) + response)
if continuation != "Y": if continuation != "Y":
break break

View file

@ -13,12 +13,12 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# #
import asyncio
import json import json
import logging import logging
import os import os
import networkx as nx import networkx as nx
import trio
from api.db.services.document_service import DocumentService from api.db.services.document_service import DocumentService
from api.db.services.task_service import has_canceled from api.db.services.task_service import has_canceled
@ -54,25 +54,35 @@ async def run_graphrag(
callback, callback,
): ):
enable_timeout_assertion = os.environ.get("ENABLE_TIMEOUT_ASSERTION") enable_timeout_assertion = os.environ.get("ENABLE_TIMEOUT_ASSERTION")
start = trio.current_time() start = asyncio.get_running_loop().time()
tenant_id, kb_id, doc_id = row["tenant_id"], str(row["kb_id"]), row["doc_id"] tenant_id, kb_id, doc_id = row["tenant_id"], str(row["kb_id"]), row["doc_id"]
chunks = [] chunks = []
for d in settings.retriever.chunk_list(doc_id, tenant_id, [kb_id], max_count=10000, fields=["content_with_weight", "doc_id"], sort_by_position=True): for d in settings.retriever.chunk_list(doc_id, tenant_id, [kb_id], max_count=10000, fields=["content_with_weight", "doc_id"], sort_by_position=True):
chunks.append(d["content_with_weight"]) chunks.append(d["content_with_weight"])
with trio.fail_after(max(120, len(chunks) * 60 * 10) if enable_timeout_assertion else 10000000000): timeout_sec = max(120, len(chunks) * 60 * 10) if enable_timeout_assertion else 10000000000
subgraph = await generate_subgraph(
LightKGExt if "method" not in row["kb_parser_config"].get("graphrag", {}) or row["kb_parser_config"]["graphrag"]["method"] != "general" else GeneralKGExt, try:
tenant_id, subgraph = await asyncio.wait_for(
kb_id, generate_subgraph(
doc_id, LightKGExt if "method" not in row["kb_parser_config"].get("graphrag", {})
chunks, or row["kb_parser_config"]["graphrag"]["method"] != "general"
language, else GeneralKGExt,
row["kb_parser_config"]["graphrag"].get("entity_types", []), tenant_id,
chat_model, kb_id,
embedding_model, doc_id,
callback, chunks,
language,
row["kb_parser_config"]["graphrag"].get("entity_types", []),
chat_model,
embedding_model,
callback,
),
timeout=timeout_sec,
) )
except asyncio.TimeoutError:
logging.error("generate_subgraph timeout")
raise
if not subgraph: if not subgraph:
return return
@ -125,7 +135,7 @@ async def run_graphrag(
) )
finally: finally:
graphrag_task_lock.release() graphrag_task_lock.release()
now = trio.current_time() now = asyncio.get_running_loop().time()
callback(msg=f"GraphRAG for doc {doc_id} done in {now - start:.2f} seconds.") callback(msg=f"GraphRAG for doc {doc_id} done in {now - start:.2f} seconds.")
return return
@ -145,7 +155,7 @@ async def run_graphrag_for_kb(
) -> dict: ) -> dict:
tenant_id, kb_id = row["tenant_id"], row["kb_id"] tenant_id, kb_id = row["tenant_id"], row["kb_id"]
enable_timeout_assertion = os.environ.get("ENABLE_TIMEOUT_ASSERTION") enable_timeout_assertion = os.environ.get("ENABLE_TIMEOUT_ASSERTION")
start = trio.current_time() start = asyncio.get_running_loop().time()
fields_for_chunks = ["content_with_weight", "doc_id"] fields_for_chunks = ["content_with_weight", "doc_id"]
if not doc_ids: if not doc_ids:
@ -211,7 +221,7 @@ async def run_graphrag_for_kb(
callback(msg=f"[GraphRAG] kb:{kb_id} has no available chunks in all documents, skip.") callback(msg=f"[GraphRAG] kb:{kb_id} has no available chunks in all documents, skip.")
return {"ok_docs": [], "failed_docs": doc_ids, "total_docs": len(doc_ids), "total_chunks": 0, "seconds": 0.0} return {"ok_docs": [], "failed_docs": doc_ids, "total_docs": len(doc_ids), "total_chunks": 0, "seconds": 0.0}
semaphore = trio.Semaphore(max_parallel_docs) semaphore = asyncio.Semaphore(max_parallel_docs)
subgraphs: dict[str, object] = {} subgraphs: dict[str, object] = {}
failed_docs: list[tuple[str, str]] = [] # (doc_id, error) failed_docs: list[tuple[str, str]] = [] # (doc_id, error)
@ -234,20 +244,28 @@ async def run_graphrag_for_kb(
try: try:
msg = f"[GraphRAG] build_subgraph doc:{doc_id}" msg = f"[GraphRAG] build_subgraph doc:{doc_id}"
callback(msg=f"{msg} start (chunks={len(chunks)}, timeout={deadline}s)") callback(msg=f"{msg} start (chunks={len(chunks)}, timeout={deadline}s)")
with trio.fail_after(deadline):
sg = await generate_subgraph( try:
kg_extractor, sg = await asyncio.wait_for(
tenant_id, generate_subgraph(
kb_id, kg_extractor,
doc_id, tenant_id,
chunks, kb_id,
language, doc_id,
kb_parser_config.get("graphrag", {}).get("entity_types", []), chunks,
chat_model, language,
embedding_model, kb_parser_config.get("graphrag", {}).get("entity_types", []),
callback, chat_model,
task_id=row["id"] embedding_model,
callback,
task_id=row["id"]
),
timeout=deadline,
) )
except asyncio.TimeoutError:
failed_docs.append((doc_id, "timeout"))
callback(msg=f"{msg} FAILED: timeout")
return
if sg: if sg:
subgraphs[doc_id] = sg subgraphs[doc_id] = sg
callback(msg=f"{msg} done") callback(msg=f"{msg} done")
@ -264,9 +282,14 @@ async def run_graphrag_for_kb(
callback(msg=f"Task {row['id']} cancelled before processing documents.") callback(msg=f"Task {row['id']} cancelled before processing documents.")
raise TaskCanceledException(f"Task {row['id']} was cancelled") raise TaskCanceledException(f"Task {row['id']} was cancelled")
async with trio.open_nursery() as nursery: tasks = [asyncio.create_task(build_one(doc_id)) for doc_id in doc_ids]
for doc_id in doc_ids: try:
nursery.start_soon(build_one, doc_id) await asyncio.gather(*tasks, return_exceptions=False)
except Exception as e:
for t in tasks:
t.cancel()
await asyncio.gather(*tasks, return_exceptions=True)
raise
if has_canceled(row["id"]): if has_canceled(row["id"]):
callback(msg=f"Task {row['id']} cancelled after document processing.") callback(msg=f"Task {row['id']} cancelled after document processing.")
@ -275,7 +298,7 @@ async def run_graphrag_for_kb(
ok_docs = [d for d in doc_ids if d in subgraphs] ok_docs = [d for d in doc_ids if d in subgraphs]
if not ok_docs: if not ok_docs:
callback(msg=f"[GraphRAG] kb:{kb_id} no subgraphs generated successfully, end.") callback(msg=f"[GraphRAG] kb:{kb_id} no subgraphs generated successfully, end.")
now = trio.current_time() now = asyncio.get_running_loop().time()
return {"ok_docs": [], "failed_docs": failed_docs, "total_docs": len(doc_ids), "total_chunks": total_chunks, "seconds": now - start} return {"ok_docs": [], "failed_docs": failed_docs, "total_docs": len(doc_ids), "total_chunks": total_chunks, "seconds": now - start}
kb_lock = RedisDistributedLock(f"graphrag_task_{kb_id}", lock_value="batch_merge", timeout=1200) kb_lock = RedisDistributedLock(f"graphrag_task_{kb_id}", lock_value="batch_merge", timeout=1200)
@ -313,7 +336,7 @@ async def run_graphrag_for_kb(
kb_lock.release() kb_lock.release()
if not with_resolution and not with_community: if not with_resolution and not with_community:
now = trio.current_time() now = asyncio.get_running_loop().time()
callback(msg=f"[GraphRAG] KB merge done in {now - start:.2f}s. ok={len(ok_docs)} / total={len(doc_ids)}") callback(msg=f"[GraphRAG] KB merge done in {now - start:.2f}s. ok={len(ok_docs)} / total={len(doc_ids)}")
return {"ok_docs": ok_docs, "failed_docs": failed_docs, "total_docs": len(doc_ids), "total_chunks": total_chunks, "seconds": now - start} return {"ok_docs": ok_docs, "failed_docs": failed_docs, "total_docs": len(doc_ids), "total_chunks": total_chunks, "seconds": now - start}
@ -356,7 +379,7 @@ async def run_graphrag_for_kb(
finally: finally:
kb_lock.release() kb_lock.release()
now = trio.current_time() now = asyncio.get_running_loop().time()
callback(msg=f"[GraphRAG] GraphRAG for KB {kb_id} done in {now - start:.2f} seconds. ok={len(ok_docs)} failed={len(failed_docs)} total_docs={len(doc_ids)} total_chunks={total_chunks}") callback(msg=f"[GraphRAG] GraphRAG for KB {kb_id} done in {now - start:.2f} seconds. ok={len(ok_docs)} failed={len(failed_docs)} total_docs={len(doc_ids)} total_chunks={total_chunks}")
return { return {
"ok_docs": ok_docs, "ok_docs": ok_docs,
@ -388,7 +411,7 @@ async def generate_subgraph(
if contains: if contains:
callback(msg=f"Graph already contains {doc_id}") callback(msg=f"Graph already contains {doc_id}")
return None return None
start = trio.current_time() start = asyncio.get_running_loop().time()
ext = extractor( ext = extractor(
llm_bdl, llm_bdl,
language=language, language=language,
@ -436,9 +459,9 @@ async def generate_subgraph(
"removed_kwd": "N", "removed_kwd": "N",
} }
cid = chunk_id(chunk) cid = chunk_id(chunk)
await trio.to_thread.run_sync(settings.docStoreConn.delete, {"knowledge_graph_kwd": "subgraph", "source_id": doc_id}, search.index_name(tenant_id), kb_id) await asyncio.to_thread(settings.docStoreConn.delete,{"knowledge_graph_kwd": "subgraph", "source_id": doc_id},search.index_name(tenant_id),kb_id,)
await trio.to_thread.run_sync(settings.docStoreConn.insert, [{"id": cid, **chunk}], search.index_name(tenant_id), kb_id) await asyncio.to_thread(settings.docStoreConn.insert,[{"id": cid, **chunk}],search.index_name(tenant_id),kb_id,)
now = trio.current_time() now = asyncio.get_running_loop().time()
callback(msg=f"generated subgraph for doc {doc_id} in {now - start:.2f} seconds.") callback(msg=f"generated subgraph for doc {doc_id} in {now - start:.2f} seconds.")
return subgraph return subgraph
@ -452,7 +475,7 @@ async def merge_subgraph(
embedding_model, embedding_model,
callback, callback,
): ):
start = trio.current_time() start = asyncio.get_running_loop().time()
change = GraphChange() change = GraphChange()
old_graph = await get_graph(tenant_id, kb_id, subgraph.graph["source_id"]) old_graph = await get_graph(tenant_id, kb_id, subgraph.graph["source_id"])
if old_graph is not None: if old_graph is not None:
@ -468,7 +491,7 @@ async def merge_subgraph(
new_graph.nodes[node_name]["pagerank"] = pagerank new_graph.nodes[node_name]["pagerank"] = pagerank
await set_graph(tenant_id, kb_id, embedding_model, new_graph, change, callback) await set_graph(tenant_id, kb_id, embedding_model, new_graph, change, callback)
now = trio.current_time() now = asyncio.get_running_loop().time()
callback(msg=f"merging subgraph for doc {doc_id} into the global graph done in {now - start:.2f} seconds.") callback(msg=f"merging subgraph for doc {doc_id} into the global graph done in {now - start:.2f} seconds.")
return new_graph return new_graph
@ -490,7 +513,7 @@ async def resolve_entities(
callback(msg=f"Task {task_id} cancelled during entity resolution.") callback(msg=f"Task {task_id} cancelled during entity resolution.")
raise TaskCanceledException(f"Task {task_id} was cancelled") raise TaskCanceledException(f"Task {task_id} was cancelled")
start = trio.current_time() start = asyncio.get_running_loop().time()
er = EntityResolution( er = EntityResolution(
llm_bdl, llm_bdl,
) )
@ -505,7 +528,7 @@ async def resolve_entities(
raise TaskCanceledException(f"Task {task_id} was cancelled") raise TaskCanceledException(f"Task {task_id} was cancelled")
await set_graph(tenant_id, kb_id, embed_bdl, graph, change, callback) await set_graph(tenant_id, kb_id, embed_bdl, graph, change, callback)
now = trio.current_time() now = asyncio.get_running_loop().time()
callback(msg=f"Graph resolution done in {now - start:.2f}s.") callback(msg=f"Graph resolution done in {now - start:.2f}s.")
@ -524,7 +547,7 @@ async def extract_community(
callback(msg=f"Task {task_id} cancelled before community extraction.") callback(msg=f"Task {task_id} cancelled before community extraction.")
raise TaskCanceledException(f"Task {task_id} was cancelled") raise TaskCanceledException(f"Task {task_id} was cancelled")
start = trio.current_time() start = asyncio.get_running_loop().time()
ext = CommunityReportsExtractor( ext = CommunityReportsExtractor(
llm_bdl, llm_bdl,
) )
@ -538,7 +561,7 @@ async def extract_community(
community_reports = cr.output community_reports = cr.output
doc_ids = graph.graph["source_id"] doc_ids = graph.graph["source_id"]
now = trio.current_time() now = asyncio.get_running_loop().time()
callback(msg=f"Graph extracted {len(cr.structured_output)} communities in {now - start:.2f}s.") callback(msg=f"Graph extracted {len(cr.structured_output)} communities in {now - start:.2f}s.")
start = now start = now
if task_id and has_canceled(task_id): if task_id and has_canceled(task_id):
@ -568,16 +591,10 @@ async def extract_community(
chunk["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(chunk["content_ltks"]) chunk["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(chunk["content_ltks"])
chunks.append(chunk) chunks.append(chunk)
await trio.to_thread.run_sync( await asyncio.to_thread(settings.docStoreConn.delete,{"knowledge_graph_kwd": "community_report", "kb_id": kb_id},search.index_name(tenant_id),kb_id,)
lambda: settings.docStoreConn.delete(
{"knowledge_graph_kwd": "community_report", "kb_id": kb_id},
search.index_name(tenant_id),
kb_id,
)
)
es_bulk_size = 4 es_bulk_size = 4
for b in range(0, len(chunks), es_bulk_size): for b in range(0, len(chunks), es_bulk_size):
doc_store_result = await trio.to_thread.run_sync(lambda: settings.docStoreConn.insert(chunks[b : b + es_bulk_size], search.index_name(tenant_id), kb_id)) doc_store_result = await asyncio.to_thread(settings.docStoreConn.insert,chunks[b : b + es_bulk_size],search.index_name(tenant_id),kb_id,)
if doc_store_result: if doc_store_result:
error_message = f"Insert chunk error: {doc_store_result}, please check log file and Elasticsearch/Infinity status!" error_message = f"Insert chunk error: {doc_store_result}, please check log file and Elasticsearch/Infinity status!"
raise Exception(error_message) raise Exception(error_message)
@ -586,6 +603,6 @@ async def extract_community(
callback(msg=f"Task {task_id} cancelled after community indexing.") callback(msg=f"Task {task_id} cancelled after community indexing.")
raise TaskCanceledException(f"Task {task_id} was cancelled") raise TaskCanceledException(f"Task {task_id} was cancelled")
now = trio.current_time() now = asyncio.get_running_loop().time()
callback(msg=f"Graph indexed {len(cr.structured_output)} communities in {now - start:.2f}s.") callback(msg=f"Graph indexed {len(cr.structured_output)} communities in {now - start:.2f}s.")
return community_structure, community_reports return community_structure, community_reports

View file

@ -14,12 +14,12 @@
# limitations under the License. # limitations under the License.
# #
import asyncio
import logging import logging
import collections import collections
import re import re
from typing import Any from typing import Any
from dataclasses import dataclass from dataclasses import dataclass
import trio
from graphrag.general.extractor import Extractor from graphrag.general.extractor import Extractor
from graphrag.general.mind_map_prompt import MIND_MAP_EXTRACTION_PROMPT from graphrag.general.mind_map_prompt import MIND_MAP_EXTRACTION_PROMPT
@ -89,17 +89,29 @@ class MindMapExtractor(Extractor):
token_count = max(self._llm.max_length * 0.8, self._llm.max_length - 512) token_count = max(self._llm.max_length * 0.8, self._llm.max_length - 512)
texts = [] texts = []
cnt = 0 cnt = 0
async with trio.open_nursery() as nursery: tasks = []
for i in range(len(sections)): for i in range(len(sections)):
section_cnt = num_tokens_from_string(sections[i]) section_cnt = num_tokens_from_string(sections[i])
if cnt + section_cnt >= token_count and texts: if cnt + section_cnt >= token_count and texts:
nursery.start_soon(self._process_document, "".join(texts), prompt_variables, res) tasks.append(asyncio.create_task(
texts = [] self._process_document("".join(texts), prompt_variables, res)
cnt = 0 ))
texts.append(sections[i]) texts = []
cnt += section_cnt cnt = 0
if texts:
nursery.start_soon(self._process_document, "".join(texts), prompt_variables, res) texts.append(sections[i])
cnt += section_cnt
if texts:
tasks.append(asyncio.create_task(
self._process_document("".join(texts), prompt_variables, res)
))
try:
await asyncio.gather(*tasks, return_exceptions=False)
except Exception as e:
for t in tasks:
t.cancel()
await asyncio.gather(*tasks, return_exceptions=True)
raise
if not res: if not res:
return MindMapResult(output={"id": "root", "children": []}) return MindMapResult(output={"id": "root", "children": []})
merge_json = reduce(self._merge, res) merge_json = reduce(self._merge, res)
@ -172,7 +184,7 @@ class MindMapExtractor(Extractor):
} }
text = perform_variable_replacements(self._mind_map_prompt, variables=variables) text = perform_variable_replacements(self._mind_map_prompt, variables=variables)
async with chat_limiter: async with chat_limiter:
response = await trio.to_thread.run_sync(lambda: self._chat(text, [{"role": "user", "content": "Output:"}], {})) response = await asyncio.to_thread(self._chat,text,[{"role": "user", "content": "Output:"}],{})
response = re.sub(r"```[^\n]*", "", response) response = re.sub(r"```[^\n]*", "", response)
logging.debug(response) logging.debug(response)
logging.debug(self._todict(markdown_to_json.dictify(response))) logging.debug(self._todict(markdown_to_json.dictify(response)))

View file

@ -15,10 +15,10 @@
# #
import argparse import argparse
import asyncio
import json import json
import logging import logging
import networkx as nx import networkx as nx
import trio
from common.constants import LLMType from common.constants import LLMType
from api.db.services.document_service import DocumentService from api.db.services.document_service import DocumentService
@ -107,4 +107,4 @@ async def main():
if __name__ == "__main__": if __name__ == "__main__":
trio.run(main) asyncio.run(main)

View file

@ -5,13 +5,13 @@ Reference:
- [graphrag](https://github.com/microsoft/graphrag) - [graphrag](https://github.com/microsoft/graphrag)
""" """
import asyncio
import logging import logging
import re import re
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any from typing import Any
import networkx as nx import networkx as nx
import trio
from graphrag.general.extractor import ENTITY_EXTRACTION_MAX_GLEANINGS, Extractor from graphrag.general.extractor import ENTITY_EXTRACTION_MAX_GLEANINGS, Extractor
from graphrag.light.graph_prompt import PROMPTS from graphrag.light.graph_prompt import PROMPTS
@ -86,13 +86,12 @@ class GraphExtractor(Extractor):
if self.callback: if self.callback:
self.callback(msg=f"Start processing for {chunk_key}: {content[:25]}...") self.callback(msg=f"Start processing for {chunk_key}: {content[:25]}...")
async with chat_limiter: async with chat_limiter:
final_result = await trio.to_thread.run_sync(self._chat, "", [{"role": "user", "content": hint_prompt}], gen_conf, task_id) final_result = await asyncio.to_thread(self._chat,"",[{"role": "user", "content": hint_prompt}],gen_conf,task_id)
token_count += num_tokens_from_string(hint_prompt + final_result) token_count += num_tokens_from_string(hint_prompt + final_result)
history = pack_user_ass_to_openai_messages(hint_prompt, final_result, self._continue_prompt) history = pack_user_ass_to_openai_messages(hint_prompt, final_result, self._continue_prompt)
for now_glean_index in range(self._max_gleanings): for now_glean_index in range(self._max_gleanings):
async with chat_limiter: async with chat_limiter:
# glean_result = await trio.to_thread.run_sync(lambda: self._chat(hint_prompt, history, gen_conf)) glean_result = await asyncio.to_thread(self._chat,"",history,gen_conf,task_id)
glean_result = await trio.to_thread.run_sync(self._chat, "", history, gen_conf, task_id)
history.extend([{"role": "assistant", "content": glean_result}]) history.extend([{"role": "assistant", "content": glean_result}])
token_count += num_tokens_from_string("\n".join([m["content"] for m in history]) + hint_prompt + self._continue_prompt) token_count += num_tokens_from_string("\n".join([m["content"] for m in history]) + hint_prompt + self._continue_prompt)
final_result += glean_result final_result += glean_result
@ -101,7 +100,7 @@ class GraphExtractor(Extractor):
history.extend([{"role": "user", "content": self._if_loop_prompt}]) history.extend([{"role": "user", "content": self._if_loop_prompt}])
async with chat_limiter: async with chat_limiter:
if_loop_result = await trio.to_thread.run_sync(self._chat, "", history, gen_conf, task_id) if_loop_result = await asyncio.to_thread(self._chat,"",history,gen_conf,task_id)
token_count += num_tokens_from_string("\n".join([m["content"] for m in history]) + if_loop_result + self._if_loop_prompt) token_count += num_tokens_from_string("\n".join([m["content"] for m in history]) + if_loop_result + self._if_loop_prompt)
if_loop_result = if_loop_result.strip().strip('"').strip("'").lower() if_loop_result = if_loop_result.strip().strip('"').strip("'").lower()
if if_loop_result != "yes": if if_loop_result != "yes":

View file

@ -15,10 +15,10 @@
# #
import argparse import argparse
import asyncio
import json import json
import networkx as nx import networkx as nx
import logging import logging
import trio
from common.constants import LLMType from common.constants import LLMType
from api.db.services.document_service import DocumentService from api.db.services.document_service import DocumentService
@ -93,4 +93,4 @@ async def main():
if __name__ == "__main__": if __name__ == "__main__":
trio.run(main) asyncio.run(main)

View file

@ -13,13 +13,13 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# #
import asyncio
import json import json
import logging import logging
from collections import defaultdict from collections import defaultdict
from copy import deepcopy from copy import deepcopy
import json_repair import json_repair
import pandas as pd import pandas as pd
import trio
from common.misc_utils import get_uuid from common.misc_utils import get_uuid
from graphrag.query_analyze_prompt import PROMPTS from graphrag.query_analyze_prompt import PROMPTS
@ -44,7 +44,7 @@ class KGSearch(Dealer):
return response return response
def query_rewrite(self, llm, question, idxnms, kb_ids): def query_rewrite(self, llm, question, idxnms, kb_ids):
ty2ents = trio.run(lambda: get_entity_type2samples(idxnms, kb_ids)) ty2ents = asyncio.run(get_entity_type2samples(idxnms, kb_ids))
hint_prompt = PROMPTS["minirag_query2kwd"].format(query=question, hint_prompt = PROMPTS["minirag_query2kwd"].format(query=question,
TYPE_POOL=json.dumps(ty2ents, ensure_ascii=False, indent=2)) TYPE_POOL=json.dumps(ty2ents, ensure_ascii=False, indent=2))
result = self._chat(llm, hint_prompt, [{"role": "user", "content": "Output:"}], {}) result = self._chat(llm, hint_prompt, [{"role": "user", "content": "Output:"}], {})

View file

@ -6,6 +6,7 @@ Reference:
- [LightRag](https://github.com/HKUDS/LightRAG) - [LightRag](https://github.com/HKUDS/LightRAG)
""" """
import asyncio
import dataclasses import dataclasses
import html import html
import json import json
@ -19,7 +20,6 @@ from typing import Any, Callable, Set, Tuple
import networkx as nx import networkx as nx
import numpy as np import numpy as np
import trio
import xxhash import xxhash
from networkx.readwrite import json_graph from networkx.readwrite import json_graph
@ -34,7 +34,7 @@ GRAPH_FIELD_SEP = "<SEP>"
ErrorHandlerFn = Callable[[BaseException | None, str | None, dict | None], None] ErrorHandlerFn = Callable[[BaseException | None, str | None, dict | None], None]
chat_limiter = trio.CapacityLimiter(int(os.environ.get("MAX_CONCURRENT_CHATS", 10))) chat_limiter = asyncio.Semaphore(int(os.environ.get("MAX_CONCURRENT_CHATS", 10)))
@dataclasses.dataclass @dataclasses.dataclass
@ -314,8 +314,11 @@ async def graph_node_to_chunk(kb_id, embd_mdl, ent_name, meta, chunks):
ebd = get_embed_cache(embd_mdl.llm_name, ent_name) ebd = get_embed_cache(embd_mdl.llm_name, ent_name)
if ebd is None: if ebd is None:
async with chat_limiter: async with chat_limiter:
with trio.fail_after(3 if enable_timeout_assertion else 30000000): timeout = 3 if enable_timeout_assertion else 30000000
ebd, _ = await trio.to_thread.run_sync(lambda: embd_mdl.encode([ent_name])) ebd, _ = await asyncio.wait_for(
asyncio.to_thread(embd_mdl.encode, [ent_name]),
timeout=timeout
)
ebd = ebd[0] ebd = ebd[0]
set_embed_cache(embd_mdl.llm_name, ent_name, ebd) set_embed_cache(embd_mdl.llm_name, ent_name, ebd)
assert ebd is not None assert ebd is not None
@ -365,8 +368,14 @@ async def graph_edge_to_chunk(kb_id, embd_mdl, from_ent_name, to_ent_name, meta,
ebd = get_embed_cache(embd_mdl.llm_name, txt) ebd = get_embed_cache(embd_mdl.llm_name, txt)
if ebd is None: if ebd is None:
async with chat_limiter: async with chat_limiter:
with trio.fail_after(3 if enable_timeout_assertion else 300000000): timeout = 3 if enable_timeout_assertion else 300000000
ebd, _ = await trio.to_thread.run_sync(lambda: embd_mdl.encode([txt + f": {meta['description']}"])) ebd, _ = await asyncio.wait_for(
asyncio.to_thread(
embd_mdl.encode,
[txt + f": {meta['description']}"]
),
timeout=timeout
)
ebd = ebd[0] ebd = ebd[0]
set_embed_cache(embd_mdl.llm_name, txt, ebd) set_embed_cache(embd_mdl.llm_name, txt, ebd)
assert ebd is not None assert ebd is not None
@ -381,7 +390,11 @@ async def does_graph_contains(tenant_id, kb_id, doc_id):
"knowledge_graph_kwd": ["graph"], "knowledge_graph_kwd": ["graph"],
"removed_kwd": "N", "removed_kwd": "N",
} }
res = await trio.to_thread.run_sync(lambda: settings.docStoreConn.search(fields, [], condition, [], OrderByExpr(), 0, 1, search.index_name(tenant_id), [kb_id])) res = await asyncio.to_thread(
settings.docStoreConn.search,
fields, [], condition, [], OrderByExpr(),
0, 1, search.index_name(tenant_id), [kb_id]
)
fields2 = settings.docStoreConn.get_fields(res, fields) fields2 = settings.docStoreConn.get_fields(res, fields)
graph_doc_ids = set() graph_doc_ids = set()
for chunk_id in fields2.keys(): for chunk_id in fields2.keys():
@ -391,7 +404,12 @@ async def does_graph_contains(tenant_id, kb_id, doc_id):
async def get_graph_doc_ids(tenant_id, kb_id) -> list[str]: async def get_graph_doc_ids(tenant_id, kb_id) -> list[str]:
conds = {"fields": ["source_id"], "removed_kwd": "N", "size": 1, "knowledge_graph_kwd": ["graph"]} conds = {"fields": ["source_id"], "removed_kwd": "N", "size": 1, "knowledge_graph_kwd": ["graph"]}
res = await trio.to_thread.run_sync(lambda: settings.retriever.search(conds, search.index_name(tenant_id), [kb_id])) res = await asyncio.to_thread(
settings.retriever.search,
conds,
search.index_name(tenant_id),
[kb_id]
)
doc_ids = [] doc_ids = []
if res.total == 0: if res.total == 0:
return doc_ids return doc_ids
@ -402,7 +420,12 @@ async def get_graph_doc_ids(tenant_id, kb_id) -> list[str]:
async def get_graph(tenant_id, kb_id, exclude_rebuild=None): async def get_graph(tenant_id, kb_id, exclude_rebuild=None):
conds = {"fields": ["content_with_weight", "removed_kwd", "source_id"], "size": 1, "knowledge_graph_kwd": ["graph"]} conds = {"fields": ["content_with_weight", "removed_kwd", "source_id"], "size": 1, "knowledge_graph_kwd": ["graph"]}
res = await trio.to_thread.run_sync(settings.retriever.search, conds, search.index_name(tenant_id), [kb_id]) res = await asyncio.to_thread(
settings.retriever.search,
conds,
search.index_name(tenant_id),
[kb_id]
)
if not res.total == 0: if not res.total == 0:
for id in res.ids: for id in res.ids:
try: try:
@ -421,26 +444,47 @@ async def get_graph(tenant_id, kb_id, exclude_rebuild=None):
async def set_graph(tenant_id: str, kb_id: str, embd_mdl, graph: nx.Graph, change: GraphChange, callback): async def set_graph(tenant_id: str, kb_id: str, embd_mdl, graph: nx.Graph, change: GraphChange, callback):
global chat_limiter global chat_limiter
start = trio.current_time() start = asyncio.get_running_loop().time()
await trio.to_thread.run_sync(settings.docStoreConn.delete, {"knowledge_graph_kwd": ["graph", "subgraph"]}, search.index_name(tenant_id), kb_id) await asyncio.to_thread(
settings.docStoreConn.delete,
{"knowledge_graph_kwd": ["graph", "subgraph"]},
search.index_name(tenant_id),
kb_id
)
if change.removed_nodes: if change.removed_nodes:
await trio.to_thread.run_sync(settings.docStoreConn.delete, {"knowledge_graph_kwd": ["entity"], "entity_kwd": sorted(change.removed_nodes)}, search.index_name(tenant_id), kb_id) await asyncio.to_thread(
settings.docStoreConn.delete,
{"knowledge_graph_kwd": ["entity"], "entity_kwd": sorted(change.removed_nodes)},
search.index_name(tenant_id),
kb_id
)
if change.removed_edges: if change.removed_edges:
async def del_edges(from_node, to_node): async def del_edges(from_node, to_node):
async with chat_limiter: async with chat_limiter:
await trio.to_thread.run_sync( await asyncio.to_thread(
settings.docStoreConn.delete, {"knowledge_graph_kwd": ["relation"], "from_entity_kwd": from_node, "to_entity_kwd": to_node}, search.index_name(tenant_id), kb_id settings.docStoreConn.delete,
{"knowledge_graph_kwd": ["relation"], "from_entity_kwd": from_node, "to_entity_kwd": to_node},
search.index_name(tenant_id),
kb_id
) )
async with trio.open_nursery() as nursery: tasks = []
for from_node, to_node in change.removed_edges: for from_node, to_node in change.removed_edges:
nursery.start_soon(del_edges, from_node, to_node) tasks.append(asyncio.create_task(del_edges(from_node, to_node)))
now = trio.current_time() try:
await asyncio.gather(*tasks, return_exceptions=False)
except Exception as e:
for t in tasks:
t.cancel()
await asyncio.gather(*tasks, return_exceptions=True)
raise
now = asyncio.get_running_loop().time()
if callback: if callback:
callback(msg=f"set_graph removed {len(change.removed_nodes)} nodes and {len(change.removed_edges)} edges from index in {now - start:.2f}s.") callback(msg=f"set_graph removed {len(change.removed_nodes)} nodes and {len(change.removed_edges)} edges from index in {now - start:.2f}s.")
start = now start = now
@ -475,24 +519,41 @@ async def set_graph(tenant_id: str, kb_id: str, embd_mdl, graph: nx.Graph, chang
} }
) )
async with trio.open_nursery() as nursery: tasks = []
for ii, node in enumerate(change.added_updated_nodes): for ii, node in enumerate(change.added_updated_nodes):
node_attrs = graph.nodes[node] node_attrs = graph.nodes[node]
nursery.start_soon(graph_node_to_chunk, kb_id, embd_mdl, node, node_attrs, chunks) tasks.append(asyncio.create_task(
if ii % 100 == 9 and callback: graph_node_to_chunk(kb_id, embd_mdl, node, node_attrs, chunks)
callback(msg=f"Get embedding of nodes: {ii}/{len(change.added_updated_nodes)}") ))
if ii % 100 == 9 and callback:
callback(msg=f"Get embedding of nodes: {ii}/{len(change.added_updated_nodes)}")
try:
await asyncio.gather(*tasks, return_exceptions=False)
except Exception as e:
for t in tasks:
t.cancel()
await asyncio.gather(*tasks, return_exceptions=True)
raise
async with trio.open_nursery() as nursery: tasks = []
for ii, (from_node, to_node) in enumerate(change.added_updated_edges): for ii, (from_node, to_node) in enumerate(change.added_updated_edges):
edge_attrs = graph.get_edge_data(from_node, to_node) edge_attrs = graph.get_edge_data(from_node, to_node)
if not edge_attrs: if not edge_attrs:
# added_updated_edges could record a non-existing edge if both from_node and to_node participate in nodes merging. continue
continue tasks.append(asyncio.create_task(
nursery.start_soon(graph_edge_to_chunk, kb_id, embd_mdl, from_node, to_node, edge_attrs, chunks) graph_edge_to_chunk(kb_id, embd_mdl, from_node, to_node, edge_attrs, chunks)
if ii % 100 == 9 and callback: ))
callback(msg=f"Get embedding of edges: {ii}/{len(change.added_updated_edges)}") if ii % 100 == 9 and callback:
callback(msg=f"Get embedding of edges: {ii}/{len(change.added_updated_edges)}")
try:
await asyncio.gather(*tasks, return_exceptions=False)
except Exception as e:
for t in tasks:
t.cancel()
await asyncio.gather(*tasks, return_exceptions=True)
raise
now = trio.current_time() now = asyncio.get_running_loop().time()
if callback: if callback:
callback(msg=f"set_graph converted graph change to {len(chunks)} chunks in {now - start:.2f}s.") callback(msg=f"set_graph converted graph change to {len(chunks)} chunks in {now - start:.2f}s.")
start = now start = now
@ -500,14 +561,22 @@ async def set_graph(tenant_id: str, kb_id: str, embd_mdl, graph: nx.Graph, chang
enable_timeout_assertion = os.environ.get("ENABLE_TIMEOUT_ASSERTION") enable_timeout_assertion = os.environ.get("ENABLE_TIMEOUT_ASSERTION")
es_bulk_size = 4 es_bulk_size = 4
for b in range(0, len(chunks), es_bulk_size): for b in range(0, len(chunks), es_bulk_size):
with trio.fail_after(3 if enable_timeout_assertion else 30000000): timeout = 3 if enable_timeout_assertion else 30000000
doc_store_result = await trio.to_thread.run_sync(lambda: settings.docStoreConn.insert(chunks[b : b + es_bulk_size], search.index_name(tenant_id), kb_id)) doc_store_result = await asyncio.wait_for(
asyncio.to_thread(
settings.docStoreConn.insert,
chunks[b : b + es_bulk_size],
search.index_name(tenant_id),
kb_id
),
timeout=timeout
)
if b % 100 == es_bulk_size and callback: if b % 100 == es_bulk_size and callback:
callback(msg=f"Insert chunks: {b}/{len(chunks)}") callback(msg=f"Insert chunks: {b}/{len(chunks)}")
if doc_store_result: if doc_store_result:
error_message = f"Insert chunk error: {doc_store_result}, please check log file and Elasticsearch/Infinity status!" error_message = f"Insert chunk error: {doc_store_result}, please check log file and Elasticsearch/Infinity status!"
raise Exception(error_message) raise Exception(error_message)
now = trio.current_time() now = asyncio.get_running_loop().time()
if callback: if callback:
callback(msg=f"set_graph added/updated {len(change.added_updated_nodes)} nodes and {len(change.added_updated_edges)} edges from index in {now - start:.2f}s.") callback(msg=f"set_graph added/updated {len(change.added_updated_nodes)} nodes and {len(change.added_updated_edges)} edges from index in {now - start:.2f}s.")
@ -555,7 +624,7 @@ def merge_tuples(list1, list2):
async def get_entity_type2samples(idxnms, kb_ids: list): async def get_entity_type2samples(idxnms, kb_ids: list):
es_res = await trio.to_thread.run_sync(lambda: settings.retriever.search({"knowledge_graph_kwd": "ty2ents", "kb_id": kb_ids, "size": 10000, "fields": ["content_with_weight"]}, idxnms, kb_ids)) es_res = await asyncio.to_thread(settings.retriever.search,{"knowledge_graph_kwd": "ty2ents", "kb_id": kb_ids, "size": 10000, "fields": ["content_with_weight"]},idxnms,kb_ids)
res = defaultdict(list) res = defaultdict(list)
for id in es_res.ids: for id in es_res.ids:
@ -588,8 +657,10 @@ async def rebuild_graph(tenant_id, kb_id, exclude_rebuild=None):
flds = ["knowledge_graph_kwd", "content_with_weight", "source_id"] flds = ["knowledge_graph_kwd", "content_with_weight", "source_id"]
bs = 256 bs = 256
for i in range(0, 1024 * bs, bs): for i in range(0, 1024 * bs, bs):
es_res = await trio.to_thread.run_sync( es_res = await asyncio.to_thread(
lambda: settings.docStoreConn.search(flds, [], {"kb_id": kb_id, "knowledge_graph_kwd": ["subgraph"]}, [], OrderByExpr(), i, bs, search.index_name(tenant_id), [kb_id]) settings.docStoreConn.search,
flds, [], {"kb_id": kb_id, "knowledge_graph_kwd": ["subgraph"]},
[], OrderByExpr(), i, bs, search.index_name(tenant_id), [kb_id]
) )
# tot = settings.docStoreConn.get_total(es_res) # tot = settings.docStoreConn.get_total(es_res)
es_res = settings.docStoreConn.get_fields(es_res, flds) es_res = settings.docStoreConn.get_fields(es_res, flds)

View file

@ -13,11 +13,11 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# #
import asyncio
import logging import logging
import re import re
import numpy as np import numpy as np
import trio
import umap import umap
from sklearn.mixture import GaussianMixture from sklearn.mixture import GaussianMixture
@ -56,37 +56,37 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
@timeout(60 * 20) @timeout(60 * 20)
async def _chat(self, system, history, gen_conf): async def _chat(self, system, history, gen_conf):
cached = await trio.to_thread.run_sync(lambda: get_llm_cache(self._llm_model.llm_name, system, history, gen_conf)) cached = await asyncio.to_thread(get_llm_cache, self._llm_model.llm_name, system, history, gen_conf)
if cached: if cached:
return cached return cached
last_exc = None last_exc = None
for attempt in range(3): for attempt in range(3):
try: try:
response = await trio.to_thread.run_sync(lambda: self._llm_model.chat(system, history, gen_conf)) response = await asyncio.to_thread(self._llm_model.chat, system, history, gen_conf)
response = re.sub(r"^.*</think>", "", response, flags=re.DOTALL) response = re.sub(r"^.*</think>", "", response, flags=re.DOTALL)
if response.find("**ERROR**") >= 0: if response.find("**ERROR**") >= 0:
raise Exception(response) raise Exception(response)
await trio.to_thread.run_sync(lambda: set_llm_cache(self._llm_model.llm_name, system, response, history, gen_conf)) await asyncio.to_thread(set_llm_cache,self._llm_model.llm_name,system,response,history,gen_conf)
return response return response
except Exception as exc: except Exception as exc:
last_exc = exc last_exc = exc
logging.warning("RAPTOR LLM call failed on attempt %d/3: %s", attempt + 1, exc) logging.warning("RAPTOR LLM call failed on attempt %d/3: %s", attempt + 1, exc)
if attempt < 2: if attempt < 2:
await trio.sleep(1 + attempt) await asyncio.sleep(1 + attempt)
raise last_exc if last_exc else Exception("LLM chat failed without exception") raise last_exc if last_exc else Exception("LLM chat failed without exception")
@timeout(20) @timeout(20)
async def _embedding_encode(self, txt): async def _embedding_encode(self, txt):
response = await trio.to_thread.run_sync(lambda: get_embed_cache(self._embd_model.llm_name, txt)) response = await asyncio.to_thread(get_embed_cache, self._embd_model.llm_name, txt)
if response is not None: if response is not None:
return response return response
embds, _ = await trio.to_thread.run_sync(lambda: self._embd_model.encode([txt])) embds, _ = await asyncio.to_thread(self._embd_model.encode, [txt])
if len(embds) < 1 or len(embds[0]) < 1: if len(embds) < 1 or len(embds[0]) < 1:
raise Exception("Embedding error: ") raise Exception("Embedding error: ")
embds = embds[0] embds = embds[0]
await trio.to_thread.run_sync(lambda: set_embed_cache(self._embd_model.llm_name, txt, embds)) await asyncio.to_thread(set_embed_cache, self._embd_model.llm_name, txt, embds)
return embds return embds
def _get_optimal_clusters(self, embeddings: np.ndarray, random_state: int, task_id: str = ""): def _get_optimal_clusters(self, embeddings: np.ndarray, random_state: int, task_id: str = ""):
@ -198,16 +198,21 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
lbls = [np.where(prob > self._threshold)[0] for prob in probs] lbls = [np.where(prob > self._threshold)[0] for prob in probs]
lbls = [lbl[0] if isinstance(lbl, np.ndarray) else lbl for lbl in lbls] lbls = [lbl[0] if isinstance(lbl, np.ndarray) else lbl for lbl in lbls]
async with trio.open_nursery() as nursery: tasks = []
for c in range(n_clusters): for c in range(n_clusters):
ck_idx = [i + start for i in range(len(lbls)) if lbls[i] == c] ck_idx = [i + start for i in range(len(lbls)) if lbls[i] == c]
assert len(ck_idx) > 0 assert len(ck_idx) > 0
if task_id and has_canceled(task_id):
if task_id and has_canceled(task_id): logging.info(f"Task {task_id} cancelled before RAPTOR cluster processing.")
logging.info(f"Task {task_id} cancelled before RAPTOR cluster processing.") raise TaskCanceledException(f"Task {task_id} was cancelled")
raise TaskCanceledException(f"Task {task_id} was cancelled") tasks.append(asyncio.create_task(summarize(ck_idx)))
try:
nursery.start_soon(summarize, ck_idx) await asyncio.gather(*tasks, return_exceptions=False)
except Exception as e:
for t in tasks:
t.cancel()
await asyncio.gather(*tasks, return_exceptions=True)
raise
assert len(chunks) - end == n_clusters, "{} vs. {}".format(len(chunks) - end, n_clusters) assert len(chunks) - end == n_clusters, "{} vs. {}".format(len(chunks) - end, n_clusters)
labels.extend(lbls) labels.extend(lbls)

View file

@ -920,7 +920,7 @@ async def do_handle_task(task):
file_type = task.get("type", "") file_type = task.get("type", "")
parser_id = task.get("parser_id", "") parser_id = task.get("parser_id", "")
raptor_config = kb_parser_config.get("raptor", {}) raptor_config = kb_parser_config.get("raptor", {})
if should_skip_raptor(file_type, parser_id, task_parser_config, raptor_config): if should_skip_raptor(file_type, parser_id, task_parser_config, raptor_config):
skip_reason = get_skip_reason(file_type, parser_id, task_parser_config) skip_reason = get_skip_reason(file_type, parser_id, task_parser_config)
logging.info(f"Skipping Raptor for document {task_document_name}: {skip_reason}") logging.info(f"Skipping Raptor for document {task_document_name}: {skip_reason}")