Refa:replace trio with asyncio in graphrag and api
This commit is contained in:
parent
94dccd8270
commit
9298007478
17 changed files with 411 additions and 222 deletions
|
|
@ -13,6 +13,7 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
#
|
#
|
||||||
|
import asyncio
|
||||||
import binascii
|
import binascii
|
||||||
import logging
|
import logging
|
||||||
import re
|
import re
|
||||||
|
|
@ -21,7 +22,6 @@ from copy import deepcopy
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from timeit import default_timer as timer
|
from timeit import default_timer as timer
|
||||||
import trio
|
|
||||||
from langfuse import Langfuse
|
from langfuse import Langfuse
|
||||||
from peewee import fn
|
from peewee import fn
|
||||||
from agentic_reasoning import DeepResearcher
|
from agentic_reasoning import DeepResearcher
|
||||||
|
|
@ -931,5 +931,5 @@ def gen_mindmap(question, kb_ids, tenant_id, search_config={}):
|
||||||
rank_feature=label_question(question, kbs),
|
rank_feature=label_question(question, kbs),
|
||||||
)
|
)
|
||||||
mindmap = MindMapExtractor(chat_mdl)
|
mindmap = MindMapExtractor(chat_mdl)
|
||||||
mind_map = trio.run(mindmap, [c["content_with_weight"] for c in ranks["chunks"]])
|
mind_map = asyncio.run(mindmap([c["content_with_weight"] for c in ranks["chunks"]]))
|
||||||
return mind_map.output
|
return mind_map.output
|
||||||
|
|
|
||||||
|
|
@ -13,6 +13,7 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
#
|
#
|
||||||
|
import asyncio
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import random
|
import random
|
||||||
|
|
@ -22,7 +23,6 @@ from copy import deepcopy
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
|
|
||||||
import trio
|
|
||||||
import xxhash
|
import xxhash
|
||||||
from peewee import fn, Case, JOIN
|
from peewee import fn, Case, JOIN
|
||||||
|
|
||||||
|
|
@ -999,7 +999,7 @@ def doc_upload_and_parse(conversation_id, file_objs, user_id):
|
||||||
from graphrag.general.mind_map_extractor import MindMapExtractor
|
from graphrag.general.mind_map_extractor import MindMapExtractor
|
||||||
mindmap = MindMapExtractor(llm_bdl)
|
mindmap = MindMapExtractor(llm_bdl)
|
||||||
try:
|
try:
|
||||||
mind_map = trio.run(mindmap, [c["content_with_weight"] for c in docs if c["doc_id"] == doc_id])
|
mind_map = asyncio.run(mindmap([c["content_with_weight"] for c in docs if c["doc_id"] == doc_id]))
|
||||||
mind_map = json.dumps(mind_map.output, ensure_ascii=False, indent=2)
|
mind_map = json.dumps(mind_map.output, ensure_ascii=False, indent=2)
|
||||||
if len(mind_map) < 32:
|
if len(mind_map) < 32:
|
||||||
raise Exception("Few content: " + mind_map)
|
raise Exception("Few content: " + mind_map)
|
||||||
|
|
|
||||||
|
|
@ -14,6 +14,7 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
#
|
#
|
||||||
|
|
||||||
|
import asyncio
|
||||||
import functools
|
import functools
|
||||||
import inspect
|
import inspect
|
||||||
import json
|
import json
|
||||||
|
|
@ -25,7 +26,6 @@ from functools import wraps
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
import trio
|
|
||||||
from quart import (
|
from quart import (
|
||||||
Response,
|
Response,
|
||||||
jsonify,
|
jsonify,
|
||||||
|
|
@ -681,18 +681,36 @@ async def is_strong_enough(chat_model, embedding_model):
|
||||||
async def _is_strong_enough():
|
async def _is_strong_enough():
|
||||||
nonlocal chat_model, embedding_model
|
nonlocal chat_model, embedding_model
|
||||||
if embedding_model:
|
if embedding_model:
|
||||||
with trio.fail_after(10):
|
await asyncio.wait_for(
|
||||||
_ = await trio.to_thread.run_sync(lambda: embedding_model.encode(["Are you strong enough!?"]))
|
asyncio.to_thread(embedding_model.encode, ["Are you strong enough!?"]),
|
||||||
|
timeout=10
|
||||||
|
)
|
||||||
|
|
||||||
if chat_model:
|
if chat_model:
|
||||||
with trio.fail_after(30):
|
res = await asyncio.wait_for(
|
||||||
res = await trio.to_thread.run_sync(lambda: chat_model.chat("Nothing special.", [{"role": "user", "content": "Are you strong enough!?"}], {}))
|
asyncio.to_thread(
|
||||||
if res.find("**ERROR**") >= 0:
|
chat_model.chat,
|
||||||
|
"Nothing special.",
|
||||||
|
[{"role": "user", "content": "Are you strong enough!?"}],
|
||||||
|
{}
|
||||||
|
),
|
||||||
|
timeout=30
|
||||||
|
)
|
||||||
|
if "**ERROR**" in res:
|
||||||
raise Exception(res)
|
raise Exception(res)
|
||||||
|
|
||||||
# Pressure test for GraphRAG task
|
# Pressure test for GraphRAG task
|
||||||
async with trio.open_nursery() as nursery:
|
tasks = [
|
||||||
for _ in range(count):
|
asyncio.create_task(_is_strong_enough())
|
||||||
nursery.start_soon(_is_strong_enough)
|
for _ in range(count)
|
||||||
|
]
|
||||||
|
try:
|
||||||
|
await asyncio.gather(*tasks, return_exceptions=False)
|
||||||
|
except Exception as e:
|
||||||
|
for t in tasks:
|
||||||
|
t.cancel()
|
||||||
|
await asyncio.gather(*tasks, return_exceptions=True)
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
def get_allowed_llm_factories() -> list:
|
def get_allowed_llm_factories() -> list:
|
||||||
|
|
|
||||||
|
|
@ -19,7 +19,6 @@ import queue
|
||||||
import threading
|
import threading
|
||||||
from typing import Any, Callable, Coroutine, Optional, Type, Union
|
from typing import Any, Callable, Coroutine, Optional, Type, Union
|
||||||
import asyncio
|
import asyncio
|
||||||
import trio
|
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
from quart import make_response, jsonify
|
from quart import make_response, jsonify
|
||||||
from common.constants import RetCode
|
from common.constants import RetCode
|
||||||
|
|
@ -70,11 +69,10 @@ def timeout(seconds: float | int | str = None, attempts: int = 2, *, exception:
|
||||||
for a in range(attempts):
|
for a in range(attempts):
|
||||||
try:
|
try:
|
||||||
if os.environ.get("ENABLE_TIMEOUT_ASSERTION"):
|
if os.environ.get("ENABLE_TIMEOUT_ASSERTION"):
|
||||||
with trio.fail_after(seconds):
|
return await asyncio.wait_for(func(*args, **kwargs), timeout=seconds)
|
||||||
return await func(*args, **kwargs)
|
|
||||||
else:
|
else:
|
||||||
return await func(*args, **kwargs)
|
return await func(*args, **kwargs)
|
||||||
except trio.TooSlowError:
|
except asyncio.TimeoutError:
|
||||||
if a < attempts - 1:
|
if a < attempts - 1:
|
||||||
continue
|
continue
|
||||||
if on_timeout is not None:
|
if on_timeout is not None:
|
||||||
|
|
|
||||||
|
|
@ -13,6 +13,7 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
#
|
#
|
||||||
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
import itertools
|
import itertools
|
||||||
import os
|
import os
|
||||||
|
|
@ -21,7 +22,6 @@ from dataclasses import dataclass
|
||||||
from typing import Any, Callable
|
from typing import Any, Callable
|
||||||
|
|
||||||
import networkx as nx
|
import networkx as nx
|
||||||
import trio
|
|
||||||
|
|
||||||
from graphrag.general.extractor import Extractor
|
from graphrag.general.extractor import Extractor
|
||||||
from rag.nlp import is_english
|
from rag.nlp import is_english
|
||||||
|
|
@ -101,35 +101,55 @@ class EntityResolution(Extractor):
|
||||||
remain_candidates_to_resolve = num_candidates
|
remain_candidates_to_resolve = num_candidates
|
||||||
|
|
||||||
resolution_result = set()
|
resolution_result = set()
|
||||||
resolution_result_lock = trio.Lock()
|
resolution_result_lock = asyncio.Lock()
|
||||||
resolution_batch_size = 100
|
resolution_batch_size = 100
|
||||||
max_concurrent_tasks = 5
|
max_concurrent_tasks = 5
|
||||||
semaphore = trio.Semaphore(max_concurrent_tasks)
|
semaphore = asyncio.Semaphore(max_concurrent_tasks)
|
||||||
|
|
||||||
async def limited_resolve_candidate(candidate_batch, result_set, result_lock):
|
async def limited_resolve_candidate(candidate_batch, result_set, result_lock):
|
||||||
nonlocal remain_candidates_to_resolve, callback
|
nonlocal remain_candidates_to_resolve, callback
|
||||||
async with semaphore:
|
async with semaphore:
|
||||||
try:
|
try:
|
||||||
enable_timeout_assertion = os.environ.get("ENABLE_TIMEOUT_ASSERTION")
|
enable_timeout_assertion = os.environ.get("ENABLE_TIMEOUT_ASSERTION")
|
||||||
with trio.move_on_after(280 if enable_timeout_assertion else 1000000000) as cancel_scope:
|
timeout_sec = 280 if enable_timeout_assertion else 1_000_000_000
|
||||||
await self._resolve_candidate(candidate_batch, result_set, result_lock, task_id)
|
|
||||||
remain_candidates_to_resolve = remain_candidates_to_resolve - len(candidate_batch[1])
|
try:
|
||||||
callback(msg=f"Resolved {len(candidate_batch[1])} pairs, {remain_candidates_to_resolve} are remained to resolve. ")
|
await asyncio.wait_for(
|
||||||
if cancel_scope.cancelled_caught:
|
self._resolve_candidate(candidate_batch, result_set, result_lock, task_id),
|
||||||
|
timeout=timeout_sec
|
||||||
|
)
|
||||||
|
remain_candidates_to_resolve -= len(candidate_batch[1])
|
||||||
|
callback(
|
||||||
|
msg=f"Resolved {len(candidate_batch[1])} pairs, "
|
||||||
|
f"{remain_candidates_to_resolve} remain."
|
||||||
|
)
|
||||||
|
|
||||||
|
except asyncio.TimeoutError:
|
||||||
logging.warning(f"Timeout resolving {candidate_batch}, skipping...")
|
logging.warning(f"Timeout resolving {candidate_batch}, skipping...")
|
||||||
remain_candidates_to_resolve = remain_candidates_to_resolve - len(candidate_batch[1])
|
remain_candidates_to_resolve -= len(candidate_batch[1])
|
||||||
callback(msg=f"Fail to resolved {len(candidate_batch[1])} pairs due to timeout reason, skipped. {remain_candidates_to_resolve} are remained to resolve. ")
|
callback(
|
||||||
|
msg=f"Failed to resolve {len(candidate_batch[1])} pairs due to timeout, skipped. "
|
||||||
|
f"{remain_candidates_to_resolve} remain."
|
||||||
|
)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.error(f"Error resolving candidate batch: {e}")
|
logging.error(f"Error resolving candidate batch: {e}")
|
||||||
|
|
||||||
|
|
||||||
async with trio.open_nursery() as nursery:
|
tasks = []
|
||||||
for candidate_resolution_i in candidate_resolution.items():
|
for key, lst in candidate_resolution.items():
|
||||||
if not candidate_resolution_i[1]:
|
if not lst:
|
||||||
continue
|
continue
|
||||||
for i in range(0, len(candidate_resolution_i[1]), resolution_batch_size):
|
for i in range(0, len(lst), resolution_batch_size):
|
||||||
candidate_batch = candidate_resolution_i[0], candidate_resolution_i[1][i:i + resolution_batch_size]
|
batch = (key, lst[i:i + resolution_batch_size])
|
||||||
nursery.start_soon(limited_resolve_candidate, candidate_batch, resolution_result, resolution_result_lock)
|
tasks.append(limited_resolve_candidate(batch, resolution_result, resolution_result_lock))
|
||||||
|
try:
|
||||||
|
await asyncio.gather(*tasks, return_exceptions=False)
|
||||||
|
except Exception as e:
|
||||||
|
for t in tasks:
|
||||||
|
t.cancel()
|
||||||
|
await asyncio.gather(*tasks, return_exceptions=True)
|
||||||
|
raise
|
||||||
|
|
||||||
callback(msg=f"Resolved {num_candidates} candidate pairs, {len(resolution_result)} of them are selected to merge.")
|
callback(msg=f"Resolved {num_candidates} candidate pairs, {len(resolution_result)} of them are selected to merge.")
|
||||||
|
|
||||||
|
|
@ -141,10 +161,18 @@ class EntityResolution(Extractor):
|
||||||
async with semaphore:
|
async with semaphore:
|
||||||
await self._merge_graph_nodes(graph, nodes, change, task_id)
|
await self._merge_graph_nodes(graph, nodes, change, task_id)
|
||||||
|
|
||||||
async with trio.open_nursery() as nursery:
|
tasks = []
|
||||||
for sub_connect_graph in nx.connected_components(connect_graph):
|
for sub_connect_graph in nx.connected_components(connect_graph):
|
||||||
merging_nodes = list(sub_connect_graph)
|
merging_nodes = list(sub_connect_graph)
|
||||||
nursery.start_soon(limited_merge_nodes, graph, merging_nodes, change)
|
tasks.append(asyncio.create_task(limited_merge_nodes(graph, merging_nodes, change))
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
await asyncio.gather(*tasks, return_exceptions=False)
|
||||||
|
except Exception as e:
|
||||||
|
for t in tasks:
|
||||||
|
t.cancel()
|
||||||
|
await asyncio.gather(*tasks, return_exceptions=True)
|
||||||
|
raise
|
||||||
|
|
||||||
# Update pagerank
|
# Update pagerank
|
||||||
pr = nx.pagerank(graph)
|
pr = nx.pagerank(graph)
|
||||||
|
|
@ -156,7 +184,7 @@ class EntityResolution(Extractor):
|
||||||
change=change,
|
change=change,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def _resolve_candidate(self, candidate_resolution_i: tuple[str, list[tuple[str, str]]], resolution_result: set[str], resolution_result_lock: trio.Lock, task_id: str = ""):
|
async def _resolve_candidate(self, candidate_resolution_i: tuple[str, list[tuple[str, str]]], resolution_result: set[str], resolution_result_lock: asyncio.Lock, task_id: str = ""):
|
||||||
if task_id:
|
if task_id:
|
||||||
if has_canceled(task_id):
|
if has_canceled(task_id):
|
||||||
logging.info(f"Task {task_id} cancelled during entity resolution candidate processing.")
|
logging.info(f"Task {task_id} cancelled during entity resolution candidate processing.")
|
||||||
|
|
@ -178,13 +206,22 @@ class EntityResolution(Extractor):
|
||||||
text = perform_variable_replacements(self._resolution_prompt, variables=variables)
|
text = perform_variable_replacements(self._resolution_prompt, variables=variables)
|
||||||
logging.info(f"Created resolution prompt {len(text)} bytes for {len(candidate_resolution_i[1])} entity pairs of type {candidate_resolution_i[0]}")
|
logging.info(f"Created resolution prompt {len(text)} bytes for {len(candidate_resolution_i[1])} entity pairs of type {candidate_resolution_i[0]}")
|
||||||
async with chat_limiter:
|
async with chat_limiter:
|
||||||
|
timeout_seconds = 280 if os.environ.get("ENABLE_TIMEOUT_ASSERTION") else 1000000000
|
||||||
try:
|
try:
|
||||||
enable_timeout_assertion = os.environ.get("ENABLE_TIMEOUT_ASSERTION")
|
response = await asyncio.wait_for(
|
||||||
with trio.move_on_after(280 if enable_timeout_assertion else 1000000000) as cancel_scope:
|
asyncio.to_thread(
|
||||||
response = await trio.to_thread.run_sync(self._chat, text, [{"role": "user", "content": "Output:"}], {}, task_id)
|
self._chat,
|
||||||
if cancel_scope.cancelled_caught:
|
text,
|
||||||
logging.warning("_resolve_candidate._chat timeout, skipping...")
|
[{"role": "user", "content": "Output:"}],
|
||||||
return
|
{},
|
||||||
|
task_id
|
||||||
|
),
|
||||||
|
timeout=timeout_seconds,
|
||||||
|
)
|
||||||
|
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
logging.warning("_resolve_candidate._chat timeout, skipping...")
|
||||||
|
return
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.error(f"_resolve_candidate._chat failed: {e}")
|
logging.error(f"_resolve_candidate._chat failed: {e}")
|
||||||
return
|
return
|
||||||
|
|
|
||||||
|
|
@ -5,6 +5,7 @@ Reference:
|
||||||
- [graphrag](https://github.com/microsoft/graphrag)
|
- [graphrag](https://github.com/microsoft/graphrag)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
|
|
@ -24,7 +25,6 @@ from graphrag.general.leiden import add_community_info2graph
|
||||||
from rag.llm.chat_model import Base as CompletionLLM
|
from rag.llm.chat_model import Base as CompletionLLM
|
||||||
from graphrag.utils import perform_variable_replacements, dict_has_keys_with_types, chat_limiter
|
from graphrag.utils import perform_variable_replacements, dict_has_keys_with_types, chat_limiter
|
||||||
from common.token_utils import num_tokens_from_string
|
from common.token_utils import num_tokens_from_string
|
||||||
import trio
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
|
@ -101,14 +101,11 @@ class CommunityReportsExtractor(Extractor):
|
||||||
text = perform_variable_replacements(self._extraction_prompt, variables=prompt_variables)
|
text = perform_variable_replacements(self._extraction_prompt, variables=prompt_variables)
|
||||||
async with chat_limiter:
|
async with chat_limiter:
|
||||||
try:
|
try:
|
||||||
with trio.move_on_after(180 if enable_timeout_assertion else 1000000000) as cancel_scope:
|
timeout = 180 if enable_timeout_assertion else 1000000000
|
||||||
if task_id and has_canceled(task_id):
|
response = await asyncio.wait_for(asyncio.to_thread(self._chat,text,[{"role": "user", "content": "Output:"}],{},task_id),timeout=timeout)
|
||||||
logging.info(f"Task {task_id} cancelled before LLM call.")
|
except asyncio.TimeoutError:
|
||||||
raise TaskCanceledException(f"Task {task_id} was cancelled")
|
logging.warning("extract_community_report._chat timeout, skipping...")
|
||||||
response = await trio.to_thread.run_sync( self._chat, text, [{"role": "user", "content": "Output:"}], {}, task_id)
|
return
|
||||||
if cancel_scope.cancelled_caught:
|
|
||||||
logging.warning("extract_community_report._chat timeout, skipping...")
|
|
||||||
return
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.error(f"extract_community_report._chat failed: {e}")
|
logging.error(f"extract_community_report._chat failed: {e}")
|
||||||
return
|
return
|
||||||
|
|
@ -141,17 +138,24 @@ class CommunityReportsExtractor(Extractor):
|
||||||
if callback:
|
if callback:
|
||||||
callback(msg=f"Communities: {over}/{total}, used tokens: {token_count}")
|
callback(msg=f"Communities: {over}/{total}, used tokens: {token_count}")
|
||||||
|
|
||||||
st = trio.current_time()
|
st = asyncio.get_running_loop().time()
|
||||||
async with trio.open_nursery() as nursery:
|
tasks = []
|
||||||
for level, comm in communities.items():
|
for level, comm in communities.items():
|
||||||
logging.info(f"Level {level}: Community: {len(comm.keys())}")
|
logging.info(f"Level {level}: Community: {len(comm.keys())}")
|
||||||
for community in comm.items():
|
for community in comm.items():
|
||||||
if task_id and has_canceled(task_id):
|
if task_id and has_canceled(task_id):
|
||||||
logging.info(f"Task {task_id} cancelled before community processing.")
|
logging.info(f"Task {task_id} cancelled before community processing.")
|
||||||
raise TaskCanceledException(f"Task {task_id} was cancelled")
|
raise TaskCanceledException(f"Task {task_id} was cancelled")
|
||||||
nursery.start_soon(extract_community_report, community)
|
tasks.append(asyncio.create_task(extract_community_report(community)))
|
||||||
|
try:
|
||||||
|
await asyncio.gather(*tasks, return_exceptions=False)
|
||||||
|
except Exception as e:
|
||||||
|
for t in tasks:
|
||||||
|
t.cancel()
|
||||||
|
await asyncio.gather(*tasks, return_exceptions=True)
|
||||||
|
raise
|
||||||
if callback:
|
if callback:
|
||||||
callback(msg=f"Community reports done in {trio.current_time() - st:.2f}s, used tokens: {token_count}")
|
callback(msg=f"Community reports done in {asyncio.get_running_loop().time() - st:.2f}s, used tokens: {token_count}")
|
||||||
|
|
||||||
return CommunityReportsResult(
|
return CommunityReportsResult(
|
||||||
structured_output=res_dict,
|
structured_output=res_dict,
|
||||||
|
|
|
||||||
|
|
@ -13,6 +13,7 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
#
|
#
|
||||||
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
|
|
@ -21,7 +22,6 @@ from copy import deepcopy
|
||||||
from typing import Callable
|
from typing import Callable
|
||||||
|
|
||||||
import networkx as nx
|
import networkx as nx
|
||||||
import trio
|
|
||||||
|
|
||||||
from api.db.services.task_service import has_canceled
|
from api.db.services.task_service import has_canceled
|
||||||
from common.connection_utils import timeout
|
from common.connection_utils import timeout
|
||||||
|
|
@ -109,14 +109,14 @@ class Extractor:
|
||||||
|
|
||||||
async def __call__(self, doc_id: str, chunks: list[str], callback: Callable | None = None, task_id: str = ""):
|
async def __call__(self, doc_id: str, chunks: list[str], callback: Callable | None = None, task_id: str = ""):
|
||||||
self.callback = callback
|
self.callback = callback
|
||||||
start_ts = trio.current_time()
|
start_ts = asyncio.get_running_loop().time()
|
||||||
|
|
||||||
async def extract_all(doc_id, chunks, max_concurrency=MAX_CONCURRENT_PROCESS_AND_EXTRACT_CHUNK, task_id=""):
|
async def extract_all(doc_id, chunks, max_concurrency=MAX_CONCURRENT_PROCESS_AND_EXTRACT_CHUNK, task_id=""):
|
||||||
out_results = []
|
out_results = []
|
||||||
error_count = 0
|
error_count = 0
|
||||||
max_errors = int(os.environ.get("GRAPHRAG_MAX_ERRORS", 3))
|
max_errors = int(os.environ.get("GRAPHRAG_MAX_ERRORS", 3))
|
||||||
|
|
||||||
limiter = trio.Semaphore(max_concurrency)
|
limiter = asyncio.Semaphore(max_concurrency)
|
||||||
|
|
||||||
async def worker(chunk_key_dp: tuple[str, str], idx: int, total: int, task_id=""):
|
async def worker(chunk_key_dp: tuple[str, str], idx: int, total: int, task_id=""):
|
||||||
nonlocal error_count
|
nonlocal error_count
|
||||||
|
|
@ -137,9 +137,18 @@ class Extractor:
|
||||||
if error_count > max_errors:
|
if error_count > max_errors:
|
||||||
raise Exception(f"Maximum error count ({max_errors}) reached. Last errors: {str(e)}")
|
raise Exception(f"Maximum error count ({max_errors}) reached. Last errors: {str(e)}")
|
||||||
|
|
||||||
async with trio.open_nursery() as nursery:
|
tasks = [
|
||||||
for i, ck in enumerate(chunks):
|
asyncio.create_task(worker((doc_id, ck), i, len(chunks), task_id))
|
||||||
nursery.start_soon(worker, (doc_id, ck), i, len(chunks), task_id)
|
for i, ck in enumerate(chunks)
|
||||||
|
]
|
||||||
|
|
||||||
|
try:
|
||||||
|
await asyncio.gather(*tasks, return_exceptions=False)
|
||||||
|
except Exception as e:
|
||||||
|
for t in tasks:
|
||||||
|
t.cancel()
|
||||||
|
await asyncio.gather(*tasks, return_exceptions=True)
|
||||||
|
raise
|
||||||
|
|
||||||
if error_count > 0:
|
if error_count > 0:
|
||||||
warning_msg = f"Completed with {error_count} errors (out of {len(chunks)} chunks processed)"
|
warning_msg = f"Completed with {error_count} errors (out of {len(chunks)} chunks processed)"
|
||||||
|
|
@ -166,7 +175,7 @@ class Extractor:
|
||||||
for k, v in m_edges.items():
|
for k, v in m_edges.items():
|
||||||
maybe_edges[tuple(sorted(k))].extend(v)
|
maybe_edges[tuple(sorted(k))].extend(v)
|
||||||
sum_token_count += token_count
|
sum_token_count += token_count
|
||||||
now = trio.current_time()
|
now = asyncio.get_running_loop().time()
|
||||||
if self.callback:
|
if self.callback:
|
||||||
self.callback(msg=f"Entities and relationships extraction done, {len(maybe_nodes)} nodes, {len(maybe_edges)} edges, {sum_token_count} tokens, {now - start_ts:.2f}s.")
|
self.callback(msg=f"Entities and relationships extraction done, {len(maybe_nodes)} nodes, {len(maybe_edges)} edges, {sum_token_count} tokens, {now - start_ts:.2f}s.")
|
||||||
start_ts = now
|
start_ts = now
|
||||||
|
|
@ -176,14 +185,22 @@ class Extractor:
|
||||||
if task_id and has_canceled(task_id):
|
if task_id and has_canceled(task_id):
|
||||||
raise TaskCanceledException(f"Task {task_id} was cancelled before nodes merging")
|
raise TaskCanceledException(f"Task {task_id} was cancelled before nodes merging")
|
||||||
|
|
||||||
async with trio.open_nursery() as nursery:
|
tasks = [
|
||||||
for en_nm, ents in maybe_nodes.items():
|
asyncio.create_task(self._merge_nodes(en_nm, ents, all_entities_data, task_id))
|
||||||
nursery.start_soon(self._merge_nodes, en_nm, ents, all_entities_data, task_id)
|
for en_nm, ents in maybe_nodes.items()
|
||||||
|
]
|
||||||
|
try:
|
||||||
|
await asyncio.gather(*tasks, return_exceptions=False)
|
||||||
|
except Exception as e:
|
||||||
|
for t in tasks:
|
||||||
|
t.cancel()
|
||||||
|
await asyncio.gather(*tasks, return_exceptions=True)
|
||||||
|
raise
|
||||||
|
|
||||||
if task_id and has_canceled(task_id):
|
if task_id and has_canceled(task_id):
|
||||||
raise TaskCanceledException(f"Task {task_id} was cancelled after nodes merging")
|
raise TaskCanceledException(f"Task {task_id} was cancelled after nodes merging")
|
||||||
|
|
||||||
now = trio.current_time()
|
now = asyncio.get_running_loop().time()
|
||||||
if self.callback:
|
if self.callback:
|
||||||
self.callback(msg=f"Entities merging done, {now - start_ts:.2f}s.")
|
self.callback(msg=f"Entities merging done, {now - start_ts:.2f}s.")
|
||||||
|
|
||||||
|
|
@ -194,14 +211,25 @@ class Extractor:
|
||||||
if task_id and has_canceled(task_id):
|
if task_id and has_canceled(task_id):
|
||||||
raise TaskCanceledException(f"Task {task_id} was cancelled before relationships merging")
|
raise TaskCanceledException(f"Task {task_id} was cancelled before relationships merging")
|
||||||
|
|
||||||
async with trio.open_nursery() as nursery:
|
tasks = []
|
||||||
for (src, tgt), rels in maybe_edges.items():
|
for (src, tgt), rels in maybe_edges.items():
|
||||||
nursery.start_soon(self._merge_edges, src, tgt, rels, all_relationships_data, task_id)
|
tasks.append(
|
||||||
|
asyncio.create_task(
|
||||||
|
self._merge_edges(src, tgt, rels, all_relationships_data, task_id)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
await asyncio.gather(*tasks, return_exceptions=False)
|
||||||
|
except Exception as e:
|
||||||
|
for t in tasks:
|
||||||
|
t.cancel()
|
||||||
|
await asyncio.gather(*tasks, return_exceptions=True)
|
||||||
|
raise
|
||||||
|
|
||||||
if task_id and has_canceled(task_id):
|
if task_id and has_canceled(task_id):
|
||||||
raise TaskCanceledException(f"Task {task_id} was cancelled after relationships merging")
|
raise TaskCanceledException(f"Task {task_id} was cancelled after relationships merging")
|
||||||
|
|
||||||
now = trio.current_time()
|
now = asyncio.get_running_loop().time()
|
||||||
if self.callback:
|
if self.callback:
|
||||||
self.callback(msg=f"Relationships merging done, {now - start_ts:.2f}s.")
|
self.callback(msg=f"Relationships merging done, {now - start_ts:.2f}s.")
|
||||||
|
|
||||||
|
|
@ -309,5 +337,5 @@ class Extractor:
|
||||||
raise TaskCanceledException(f"Task {task_id} was cancelled during summary handling")
|
raise TaskCanceledException(f"Task {task_id} was cancelled during summary handling")
|
||||||
|
|
||||||
async with chat_limiter:
|
async with chat_limiter:
|
||||||
summary = await trio.to_thread.run_sync(self._chat, "", [{"role": "user", "content": use_prompt}], {}, task_id)
|
summary = await asyncio.to_thread(self._chat, "", [{"role": "user", "content": use_prompt}], {}, task_id)
|
||||||
return summary
|
return summary
|
||||||
|
|
|
||||||
|
|
@ -5,11 +5,11 @@ Reference:
|
||||||
- [graphrag](https://github.com/microsoft/graphrag)
|
- [graphrag](https://github.com/microsoft/graphrag)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
import re
|
import re
|
||||||
from typing import Any
|
from typing import Any
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
import tiktoken
|
import tiktoken
|
||||||
import trio
|
|
||||||
|
|
||||||
from graphrag.general.extractor import Extractor, ENTITY_EXTRACTION_MAX_GLEANINGS
|
from graphrag.general.extractor import Extractor, ENTITY_EXTRACTION_MAX_GLEANINGS
|
||||||
from graphrag.general.graph_prompt import GRAPH_EXTRACTION_PROMPT, CONTINUE_PROMPT, LOOP_PROMPT
|
from graphrag.general.graph_prompt import GRAPH_EXTRACTION_PROMPT, CONTINUE_PROMPT, LOOP_PROMPT
|
||||||
|
|
@ -107,7 +107,7 @@ class GraphExtractor(Extractor):
|
||||||
}
|
}
|
||||||
hint_prompt = perform_variable_replacements(self._extraction_prompt, variables=variables)
|
hint_prompt = perform_variable_replacements(self._extraction_prompt, variables=variables)
|
||||||
async with chat_limiter:
|
async with chat_limiter:
|
||||||
response = await trio.to_thread.run_sync(self._chat, hint_prompt, [{"role": "user", "content": "Output:"}], {}, task_id)
|
response = await asyncio.to_thread(self._chat,hint_prompt,[{"role": "user", "content": "Output:"}],{},task_id)
|
||||||
token_count += num_tokens_from_string(hint_prompt + response)
|
token_count += num_tokens_from_string(hint_prompt + response)
|
||||||
|
|
||||||
results = response or ""
|
results = response or ""
|
||||||
|
|
@ -117,7 +117,7 @@ class GraphExtractor(Extractor):
|
||||||
for i in range(self._max_gleanings):
|
for i in range(self._max_gleanings):
|
||||||
history.append({"role": "user", "content": CONTINUE_PROMPT})
|
history.append({"role": "user", "content": CONTINUE_PROMPT})
|
||||||
async with chat_limiter:
|
async with chat_limiter:
|
||||||
response = await trio.to_thread.run_sync(lambda: self._chat("", history, {}))
|
response = await asyncio.to_thread(self._chat, "", history, {})
|
||||||
token_count += num_tokens_from_string("\n".join([m["content"] for m in history]) + response)
|
token_count += num_tokens_from_string("\n".join([m["content"] for m in history]) + response)
|
||||||
results += response or ""
|
results += response or ""
|
||||||
|
|
||||||
|
|
@ -127,7 +127,7 @@ class GraphExtractor(Extractor):
|
||||||
history.append({"role": "assistant", "content": response})
|
history.append({"role": "assistant", "content": response})
|
||||||
history.append({"role": "user", "content": LOOP_PROMPT})
|
history.append({"role": "user", "content": LOOP_PROMPT})
|
||||||
async with chat_limiter:
|
async with chat_limiter:
|
||||||
continuation = await trio.to_thread.run_sync(lambda: self._chat("", history))
|
continuation = await asyncio.to_thread(self._chat, "", history)
|
||||||
token_count += num_tokens_from_string("\n".join([m["content"] for m in history]) + response)
|
token_count += num_tokens_from_string("\n".join([m["content"] for m in history]) + response)
|
||||||
if continuation != "Y":
|
if continuation != "Y":
|
||||||
break
|
break
|
||||||
|
|
|
||||||
|
|
@ -13,12 +13,12 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
#
|
#
|
||||||
|
import asyncio
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
|
|
||||||
import networkx as nx
|
import networkx as nx
|
||||||
import trio
|
|
||||||
|
|
||||||
from api.db.services.document_service import DocumentService
|
from api.db.services.document_service import DocumentService
|
||||||
from api.db.services.task_service import has_canceled
|
from api.db.services.task_service import has_canceled
|
||||||
|
|
@ -54,25 +54,35 @@ async def run_graphrag(
|
||||||
callback,
|
callback,
|
||||||
):
|
):
|
||||||
enable_timeout_assertion = os.environ.get("ENABLE_TIMEOUT_ASSERTION")
|
enable_timeout_assertion = os.environ.get("ENABLE_TIMEOUT_ASSERTION")
|
||||||
start = trio.current_time()
|
start = asyncio.get_running_loop().time()
|
||||||
tenant_id, kb_id, doc_id = row["tenant_id"], str(row["kb_id"]), row["doc_id"]
|
tenant_id, kb_id, doc_id = row["tenant_id"], str(row["kb_id"]), row["doc_id"]
|
||||||
chunks = []
|
chunks = []
|
||||||
for d in settings.retriever.chunk_list(doc_id, tenant_id, [kb_id], max_count=10000, fields=["content_with_weight", "doc_id"], sort_by_position=True):
|
for d in settings.retriever.chunk_list(doc_id, tenant_id, [kb_id], max_count=10000, fields=["content_with_weight", "doc_id"], sort_by_position=True):
|
||||||
chunks.append(d["content_with_weight"])
|
chunks.append(d["content_with_weight"])
|
||||||
|
|
||||||
with trio.fail_after(max(120, len(chunks) * 60 * 10) if enable_timeout_assertion else 10000000000):
|
timeout_sec = max(120, len(chunks) * 60 * 10) if enable_timeout_assertion else 10000000000
|
||||||
subgraph = await generate_subgraph(
|
|
||||||
LightKGExt if "method" not in row["kb_parser_config"].get("graphrag", {}) or row["kb_parser_config"]["graphrag"]["method"] != "general" else GeneralKGExt,
|
try:
|
||||||
tenant_id,
|
subgraph = await asyncio.wait_for(
|
||||||
kb_id,
|
generate_subgraph(
|
||||||
doc_id,
|
LightKGExt if "method" not in row["kb_parser_config"].get("graphrag", {})
|
||||||
chunks,
|
or row["kb_parser_config"]["graphrag"]["method"] != "general"
|
||||||
language,
|
else GeneralKGExt,
|
||||||
row["kb_parser_config"]["graphrag"].get("entity_types", []),
|
tenant_id,
|
||||||
chat_model,
|
kb_id,
|
||||||
embedding_model,
|
doc_id,
|
||||||
callback,
|
chunks,
|
||||||
|
language,
|
||||||
|
row["kb_parser_config"]["graphrag"].get("entity_types", []),
|
||||||
|
chat_model,
|
||||||
|
embedding_model,
|
||||||
|
callback,
|
||||||
|
),
|
||||||
|
timeout=timeout_sec,
|
||||||
)
|
)
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
logging.error("generate_subgraph timeout")
|
||||||
|
raise
|
||||||
|
|
||||||
if not subgraph:
|
if not subgraph:
|
||||||
return
|
return
|
||||||
|
|
@ -125,7 +135,7 @@ async def run_graphrag(
|
||||||
)
|
)
|
||||||
finally:
|
finally:
|
||||||
graphrag_task_lock.release()
|
graphrag_task_lock.release()
|
||||||
now = trio.current_time()
|
now = asyncio.get_running_loop().time()
|
||||||
callback(msg=f"GraphRAG for doc {doc_id} done in {now - start:.2f} seconds.")
|
callback(msg=f"GraphRAG for doc {doc_id} done in {now - start:.2f} seconds.")
|
||||||
return
|
return
|
||||||
|
|
||||||
|
|
@ -145,7 +155,7 @@ async def run_graphrag_for_kb(
|
||||||
) -> dict:
|
) -> dict:
|
||||||
tenant_id, kb_id = row["tenant_id"], row["kb_id"]
|
tenant_id, kb_id = row["tenant_id"], row["kb_id"]
|
||||||
enable_timeout_assertion = os.environ.get("ENABLE_TIMEOUT_ASSERTION")
|
enable_timeout_assertion = os.environ.get("ENABLE_TIMEOUT_ASSERTION")
|
||||||
start = trio.current_time()
|
start = asyncio.get_running_loop().time()
|
||||||
fields_for_chunks = ["content_with_weight", "doc_id"]
|
fields_for_chunks = ["content_with_weight", "doc_id"]
|
||||||
|
|
||||||
if not doc_ids:
|
if not doc_ids:
|
||||||
|
|
@ -211,7 +221,7 @@ async def run_graphrag_for_kb(
|
||||||
callback(msg=f"[GraphRAG] kb:{kb_id} has no available chunks in all documents, skip.")
|
callback(msg=f"[GraphRAG] kb:{kb_id} has no available chunks in all documents, skip.")
|
||||||
return {"ok_docs": [], "failed_docs": doc_ids, "total_docs": len(doc_ids), "total_chunks": 0, "seconds": 0.0}
|
return {"ok_docs": [], "failed_docs": doc_ids, "total_docs": len(doc_ids), "total_chunks": 0, "seconds": 0.0}
|
||||||
|
|
||||||
semaphore = trio.Semaphore(max_parallel_docs)
|
semaphore = asyncio.Semaphore(max_parallel_docs)
|
||||||
|
|
||||||
subgraphs: dict[str, object] = {}
|
subgraphs: dict[str, object] = {}
|
||||||
failed_docs: list[tuple[str, str]] = [] # (doc_id, error)
|
failed_docs: list[tuple[str, str]] = [] # (doc_id, error)
|
||||||
|
|
@ -234,20 +244,28 @@ async def run_graphrag_for_kb(
|
||||||
try:
|
try:
|
||||||
msg = f"[GraphRAG] build_subgraph doc:{doc_id}"
|
msg = f"[GraphRAG] build_subgraph doc:{doc_id}"
|
||||||
callback(msg=f"{msg} start (chunks={len(chunks)}, timeout={deadline}s)")
|
callback(msg=f"{msg} start (chunks={len(chunks)}, timeout={deadline}s)")
|
||||||
with trio.fail_after(deadline):
|
|
||||||
sg = await generate_subgraph(
|
try:
|
||||||
kg_extractor,
|
sg = await asyncio.wait_for(
|
||||||
tenant_id,
|
generate_subgraph(
|
||||||
kb_id,
|
kg_extractor,
|
||||||
doc_id,
|
tenant_id,
|
||||||
chunks,
|
kb_id,
|
||||||
language,
|
doc_id,
|
||||||
kb_parser_config.get("graphrag", {}).get("entity_types", []),
|
chunks,
|
||||||
chat_model,
|
language,
|
||||||
embedding_model,
|
kb_parser_config.get("graphrag", {}).get("entity_types", []),
|
||||||
callback,
|
chat_model,
|
||||||
task_id=row["id"]
|
embedding_model,
|
||||||
|
callback,
|
||||||
|
task_id=row["id"]
|
||||||
|
),
|
||||||
|
timeout=deadline,
|
||||||
)
|
)
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
failed_docs.append((doc_id, "timeout"))
|
||||||
|
callback(msg=f"{msg} FAILED: timeout")
|
||||||
|
return
|
||||||
if sg:
|
if sg:
|
||||||
subgraphs[doc_id] = sg
|
subgraphs[doc_id] = sg
|
||||||
callback(msg=f"{msg} done")
|
callback(msg=f"{msg} done")
|
||||||
|
|
@ -264,9 +282,14 @@ async def run_graphrag_for_kb(
|
||||||
callback(msg=f"Task {row['id']} cancelled before processing documents.")
|
callback(msg=f"Task {row['id']} cancelled before processing documents.")
|
||||||
raise TaskCanceledException(f"Task {row['id']} was cancelled")
|
raise TaskCanceledException(f"Task {row['id']} was cancelled")
|
||||||
|
|
||||||
async with trio.open_nursery() as nursery:
|
tasks = [asyncio.create_task(build_one(doc_id)) for doc_id in doc_ids]
|
||||||
for doc_id in doc_ids:
|
try:
|
||||||
nursery.start_soon(build_one, doc_id)
|
await asyncio.gather(*tasks, return_exceptions=False)
|
||||||
|
except Exception as e:
|
||||||
|
for t in tasks:
|
||||||
|
t.cancel()
|
||||||
|
await asyncio.gather(*tasks, return_exceptions=True)
|
||||||
|
raise
|
||||||
|
|
||||||
if has_canceled(row["id"]):
|
if has_canceled(row["id"]):
|
||||||
callback(msg=f"Task {row['id']} cancelled after document processing.")
|
callback(msg=f"Task {row['id']} cancelled after document processing.")
|
||||||
|
|
@ -275,7 +298,7 @@ async def run_graphrag_for_kb(
|
||||||
ok_docs = [d for d in doc_ids if d in subgraphs]
|
ok_docs = [d for d in doc_ids if d in subgraphs]
|
||||||
if not ok_docs:
|
if not ok_docs:
|
||||||
callback(msg=f"[GraphRAG] kb:{kb_id} no subgraphs generated successfully, end.")
|
callback(msg=f"[GraphRAG] kb:{kb_id} no subgraphs generated successfully, end.")
|
||||||
now = trio.current_time()
|
now = asyncio.get_running_loop().time()
|
||||||
return {"ok_docs": [], "failed_docs": failed_docs, "total_docs": len(doc_ids), "total_chunks": total_chunks, "seconds": now - start}
|
return {"ok_docs": [], "failed_docs": failed_docs, "total_docs": len(doc_ids), "total_chunks": total_chunks, "seconds": now - start}
|
||||||
|
|
||||||
kb_lock = RedisDistributedLock(f"graphrag_task_{kb_id}", lock_value="batch_merge", timeout=1200)
|
kb_lock = RedisDistributedLock(f"graphrag_task_{kb_id}", lock_value="batch_merge", timeout=1200)
|
||||||
|
|
@ -313,7 +336,7 @@ async def run_graphrag_for_kb(
|
||||||
kb_lock.release()
|
kb_lock.release()
|
||||||
|
|
||||||
if not with_resolution and not with_community:
|
if not with_resolution and not with_community:
|
||||||
now = trio.current_time()
|
now = asyncio.get_running_loop().time()
|
||||||
callback(msg=f"[GraphRAG] KB merge done in {now - start:.2f}s. ok={len(ok_docs)} / total={len(doc_ids)}")
|
callback(msg=f"[GraphRAG] KB merge done in {now - start:.2f}s. ok={len(ok_docs)} / total={len(doc_ids)}")
|
||||||
return {"ok_docs": ok_docs, "failed_docs": failed_docs, "total_docs": len(doc_ids), "total_chunks": total_chunks, "seconds": now - start}
|
return {"ok_docs": ok_docs, "failed_docs": failed_docs, "total_docs": len(doc_ids), "total_chunks": total_chunks, "seconds": now - start}
|
||||||
|
|
||||||
|
|
@ -356,7 +379,7 @@ async def run_graphrag_for_kb(
|
||||||
finally:
|
finally:
|
||||||
kb_lock.release()
|
kb_lock.release()
|
||||||
|
|
||||||
now = trio.current_time()
|
now = asyncio.get_running_loop().time()
|
||||||
callback(msg=f"[GraphRAG] GraphRAG for KB {kb_id} done in {now - start:.2f} seconds. ok={len(ok_docs)} failed={len(failed_docs)} total_docs={len(doc_ids)} total_chunks={total_chunks}")
|
callback(msg=f"[GraphRAG] GraphRAG for KB {kb_id} done in {now - start:.2f} seconds. ok={len(ok_docs)} failed={len(failed_docs)} total_docs={len(doc_ids)} total_chunks={total_chunks}")
|
||||||
return {
|
return {
|
||||||
"ok_docs": ok_docs,
|
"ok_docs": ok_docs,
|
||||||
|
|
@ -388,7 +411,7 @@ async def generate_subgraph(
|
||||||
if contains:
|
if contains:
|
||||||
callback(msg=f"Graph already contains {doc_id}")
|
callback(msg=f"Graph already contains {doc_id}")
|
||||||
return None
|
return None
|
||||||
start = trio.current_time()
|
start = asyncio.get_running_loop().time()
|
||||||
ext = extractor(
|
ext = extractor(
|
||||||
llm_bdl,
|
llm_bdl,
|
||||||
language=language,
|
language=language,
|
||||||
|
|
@ -436,9 +459,9 @@ async def generate_subgraph(
|
||||||
"removed_kwd": "N",
|
"removed_kwd": "N",
|
||||||
}
|
}
|
||||||
cid = chunk_id(chunk)
|
cid = chunk_id(chunk)
|
||||||
await trio.to_thread.run_sync(settings.docStoreConn.delete, {"knowledge_graph_kwd": "subgraph", "source_id": doc_id}, search.index_name(tenant_id), kb_id)
|
await asyncio.to_thread(settings.docStoreConn.delete,{"knowledge_graph_kwd": "subgraph", "source_id": doc_id},search.index_name(tenant_id),kb_id,)
|
||||||
await trio.to_thread.run_sync(settings.docStoreConn.insert, [{"id": cid, **chunk}], search.index_name(tenant_id), kb_id)
|
await asyncio.to_thread(settings.docStoreConn.insert,[{"id": cid, **chunk}],search.index_name(tenant_id),kb_id,)
|
||||||
now = trio.current_time()
|
now = asyncio.get_running_loop().time()
|
||||||
callback(msg=f"generated subgraph for doc {doc_id} in {now - start:.2f} seconds.")
|
callback(msg=f"generated subgraph for doc {doc_id} in {now - start:.2f} seconds.")
|
||||||
return subgraph
|
return subgraph
|
||||||
|
|
||||||
|
|
@ -452,7 +475,7 @@ async def merge_subgraph(
|
||||||
embedding_model,
|
embedding_model,
|
||||||
callback,
|
callback,
|
||||||
):
|
):
|
||||||
start = trio.current_time()
|
start = asyncio.get_running_loop().time()
|
||||||
change = GraphChange()
|
change = GraphChange()
|
||||||
old_graph = await get_graph(tenant_id, kb_id, subgraph.graph["source_id"])
|
old_graph = await get_graph(tenant_id, kb_id, subgraph.graph["source_id"])
|
||||||
if old_graph is not None:
|
if old_graph is not None:
|
||||||
|
|
@ -468,7 +491,7 @@ async def merge_subgraph(
|
||||||
new_graph.nodes[node_name]["pagerank"] = pagerank
|
new_graph.nodes[node_name]["pagerank"] = pagerank
|
||||||
|
|
||||||
await set_graph(tenant_id, kb_id, embedding_model, new_graph, change, callback)
|
await set_graph(tenant_id, kb_id, embedding_model, new_graph, change, callback)
|
||||||
now = trio.current_time()
|
now = asyncio.get_running_loop().time()
|
||||||
callback(msg=f"merging subgraph for doc {doc_id} into the global graph done in {now - start:.2f} seconds.")
|
callback(msg=f"merging subgraph for doc {doc_id} into the global graph done in {now - start:.2f} seconds.")
|
||||||
return new_graph
|
return new_graph
|
||||||
|
|
||||||
|
|
@ -490,7 +513,7 @@ async def resolve_entities(
|
||||||
callback(msg=f"Task {task_id} cancelled during entity resolution.")
|
callback(msg=f"Task {task_id} cancelled during entity resolution.")
|
||||||
raise TaskCanceledException(f"Task {task_id} was cancelled")
|
raise TaskCanceledException(f"Task {task_id} was cancelled")
|
||||||
|
|
||||||
start = trio.current_time()
|
start = asyncio.get_running_loop().time()
|
||||||
er = EntityResolution(
|
er = EntityResolution(
|
||||||
llm_bdl,
|
llm_bdl,
|
||||||
)
|
)
|
||||||
|
|
@ -505,7 +528,7 @@ async def resolve_entities(
|
||||||
raise TaskCanceledException(f"Task {task_id} was cancelled")
|
raise TaskCanceledException(f"Task {task_id} was cancelled")
|
||||||
|
|
||||||
await set_graph(tenant_id, kb_id, embed_bdl, graph, change, callback)
|
await set_graph(tenant_id, kb_id, embed_bdl, graph, change, callback)
|
||||||
now = trio.current_time()
|
now = asyncio.get_running_loop().time()
|
||||||
callback(msg=f"Graph resolution done in {now - start:.2f}s.")
|
callback(msg=f"Graph resolution done in {now - start:.2f}s.")
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -524,7 +547,7 @@ async def extract_community(
|
||||||
callback(msg=f"Task {task_id} cancelled before community extraction.")
|
callback(msg=f"Task {task_id} cancelled before community extraction.")
|
||||||
raise TaskCanceledException(f"Task {task_id} was cancelled")
|
raise TaskCanceledException(f"Task {task_id} was cancelled")
|
||||||
|
|
||||||
start = trio.current_time()
|
start = asyncio.get_running_loop().time()
|
||||||
ext = CommunityReportsExtractor(
|
ext = CommunityReportsExtractor(
|
||||||
llm_bdl,
|
llm_bdl,
|
||||||
)
|
)
|
||||||
|
|
@ -538,7 +561,7 @@ async def extract_community(
|
||||||
community_reports = cr.output
|
community_reports = cr.output
|
||||||
doc_ids = graph.graph["source_id"]
|
doc_ids = graph.graph["source_id"]
|
||||||
|
|
||||||
now = trio.current_time()
|
now = asyncio.get_running_loop().time()
|
||||||
callback(msg=f"Graph extracted {len(cr.structured_output)} communities in {now - start:.2f}s.")
|
callback(msg=f"Graph extracted {len(cr.structured_output)} communities in {now - start:.2f}s.")
|
||||||
start = now
|
start = now
|
||||||
if task_id and has_canceled(task_id):
|
if task_id and has_canceled(task_id):
|
||||||
|
|
@ -568,16 +591,10 @@ async def extract_community(
|
||||||
chunk["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(chunk["content_ltks"])
|
chunk["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(chunk["content_ltks"])
|
||||||
chunks.append(chunk)
|
chunks.append(chunk)
|
||||||
|
|
||||||
await trio.to_thread.run_sync(
|
await asyncio.to_thread(settings.docStoreConn.delete,{"knowledge_graph_kwd": "community_report", "kb_id": kb_id},search.index_name(tenant_id),kb_id,)
|
||||||
lambda: settings.docStoreConn.delete(
|
|
||||||
{"knowledge_graph_kwd": "community_report", "kb_id": kb_id},
|
|
||||||
search.index_name(tenant_id),
|
|
||||||
kb_id,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
es_bulk_size = 4
|
es_bulk_size = 4
|
||||||
for b in range(0, len(chunks), es_bulk_size):
|
for b in range(0, len(chunks), es_bulk_size):
|
||||||
doc_store_result = await trio.to_thread.run_sync(lambda: settings.docStoreConn.insert(chunks[b : b + es_bulk_size], search.index_name(tenant_id), kb_id))
|
doc_store_result = await asyncio.to_thread(settings.docStoreConn.insert,chunks[b : b + es_bulk_size],search.index_name(tenant_id),kb_id,)
|
||||||
if doc_store_result:
|
if doc_store_result:
|
||||||
error_message = f"Insert chunk error: {doc_store_result}, please check log file and Elasticsearch/Infinity status!"
|
error_message = f"Insert chunk error: {doc_store_result}, please check log file and Elasticsearch/Infinity status!"
|
||||||
raise Exception(error_message)
|
raise Exception(error_message)
|
||||||
|
|
@ -586,6 +603,6 @@ async def extract_community(
|
||||||
callback(msg=f"Task {task_id} cancelled after community indexing.")
|
callback(msg=f"Task {task_id} cancelled after community indexing.")
|
||||||
raise TaskCanceledException(f"Task {task_id} was cancelled")
|
raise TaskCanceledException(f"Task {task_id} was cancelled")
|
||||||
|
|
||||||
now = trio.current_time()
|
now = asyncio.get_running_loop().time()
|
||||||
callback(msg=f"Graph indexed {len(cr.structured_output)} communities in {now - start:.2f}s.")
|
callback(msg=f"Graph indexed {len(cr.structured_output)} communities in {now - start:.2f}s.")
|
||||||
return community_structure, community_reports
|
return community_structure, community_reports
|
||||||
|
|
|
||||||
|
|
@ -14,12 +14,12 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
#
|
#
|
||||||
|
|
||||||
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
import collections
|
import collections
|
||||||
import re
|
import re
|
||||||
from typing import Any
|
from typing import Any
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
import trio
|
|
||||||
|
|
||||||
from graphrag.general.extractor import Extractor
|
from graphrag.general.extractor import Extractor
|
||||||
from graphrag.general.mind_map_prompt import MIND_MAP_EXTRACTION_PROMPT
|
from graphrag.general.mind_map_prompt import MIND_MAP_EXTRACTION_PROMPT
|
||||||
|
|
@ -89,17 +89,29 @@ class MindMapExtractor(Extractor):
|
||||||
token_count = max(self._llm.max_length * 0.8, self._llm.max_length - 512)
|
token_count = max(self._llm.max_length * 0.8, self._llm.max_length - 512)
|
||||||
texts = []
|
texts = []
|
||||||
cnt = 0
|
cnt = 0
|
||||||
async with trio.open_nursery() as nursery:
|
tasks = []
|
||||||
for i in range(len(sections)):
|
for i in range(len(sections)):
|
||||||
section_cnt = num_tokens_from_string(sections[i])
|
section_cnt = num_tokens_from_string(sections[i])
|
||||||
if cnt + section_cnt >= token_count and texts:
|
if cnt + section_cnt >= token_count and texts:
|
||||||
nursery.start_soon(self._process_document, "".join(texts), prompt_variables, res)
|
tasks.append(asyncio.create_task(
|
||||||
texts = []
|
self._process_document("".join(texts), prompt_variables, res)
|
||||||
cnt = 0
|
))
|
||||||
texts.append(sections[i])
|
texts = []
|
||||||
cnt += section_cnt
|
cnt = 0
|
||||||
if texts:
|
|
||||||
nursery.start_soon(self._process_document, "".join(texts), prompt_variables, res)
|
texts.append(sections[i])
|
||||||
|
cnt += section_cnt
|
||||||
|
if texts:
|
||||||
|
tasks.append(asyncio.create_task(
|
||||||
|
self._process_document("".join(texts), prompt_variables, res)
|
||||||
|
))
|
||||||
|
try:
|
||||||
|
await asyncio.gather(*tasks, return_exceptions=False)
|
||||||
|
except Exception as e:
|
||||||
|
for t in tasks:
|
||||||
|
t.cancel()
|
||||||
|
await asyncio.gather(*tasks, return_exceptions=True)
|
||||||
|
raise
|
||||||
if not res:
|
if not res:
|
||||||
return MindMapResult(output={"id": "root", "children": []})
|
return MindMapResult(output={"id": "root", "children": []})
|
||||||
merge_json = reduce(self._merge, res)
|
merge_json = reduce(self._merge, res)
|
||||||
|
|
@ -172,7 +184,7 @@ class MindMapExtractor(Extractor):
|
||||||
}
|
}
|
||||||
text = perform_variable_replacements(self._mind_map_prompt, variables=variables)
|
text = perform_variable_replacements(self._mind_map_prompt, variables=variables)
|
||||||
async with chat_limiter:
|
async with chat_limiter:
|
||||||
response = await trio.to_thread.run_sync(lambda: self._chat(text, [{"role": "user", "content": "Output:"}], {}))
|
response = await asyncio.to_thread(self._chat,text,[{"role": "user", "content": "Output:"}],{})
|
||||||
response = re.sub(r"```[^\n]*", "", response)
|
response = re.sub(r"```[^\n]*", "", response)
|
||||||
logging.debug(response)
|
logging.debug(response)
|
||||||
logging.debug(self._todict(markdown_to_json.dictify(response)))
|
logging.debug(self._todict(markdown_to_json.dictify(response)))
|
||||||
|
|
|
||||||
|
|
@ -15,10 +15,10 @@
|
||||||
#
|
#
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
|
import asyncio
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import networkx as nx
|
import networkx as nx
|
||||||
import trio
|
|
||||||
|
|
||||||
from common.constants import LLMType
|
from common.constants import LLMType
|
||||||
from api.db.services.document_service import DocumentService
|
from api.db.services.document_service import DocumentService
|
||||||
|
|
@ -107,4 +107,4 @@ async def main():
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
trio.run(main)
|
asyncio.run(main)
|
||||||
|
|
|
||||||
|
|
@ -5,13 +5,13 @@ Reference:
|
||||||
- [graphrag](https://github.com/microsoft/graphrag)
|
- [graphrag](https://github.com/microsoft/graphrag)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
import re
|
import re
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
import networkx as nx
|
import networkx as nx
|
||||||
import trio
|
|
||||||
|
|
||||||
from graphrag.general.extractor import ENTITY_EXTRACTION_MAX_GLEANINGS, Extractor
|
from graphrag.general.extractor import ENTITY_EXTRACTION_MAX_GLEANINGS, Extractor
|
||||||
from graphrag.light.graph_prompt import PROMPTS
|
from graphrag.light.graph_prompt import PROMPTS
|
||||||
|
|
@ -86,13 +86,12 @@ class GraphExtractor(Extractor):
|
||||||
if self.callback:
|
if self.callback:
|
||||||
self.callback(msg=f"Start processing for {chunk_key}: {content[:25]}...")
|
self.callback(msg=f"Start processing for {chunk_key}: {content[:25]}...")
|
||||||
async with chat_limiter:
|
async with chat_limiter:
|
||||||
final_result = await trio.to_thread.run_sync(self._chat, "", [{"role": "user", "content": hint_prompt}], gen_conf, task_id)
|
final_result = await asyncio.to_thread(self._chat,"",[{"role": "user", "content": hint_prompt}],gen_conf,task_id)
|
||||||
token_count += num_tokens_from_string(hint_prompt + final_result)
|
token_count += num_tokens_from_string(hint_prompt + final_result)
|
||||||
history = pack_user_ass_to_openai_messages(hint_prompt, final_result, self._continue_prompt)
|
history = pack_user_ass_to_openai_messages(hint_prompt, final_result, self._continue_prompt)
|
||||||
for now_glean_index in range(self._max_gleanings):
|
for now_glean_index in range(self._max_gleanings):
|
||||||
async with chat_limiter:
|
async with chat_limiter:
|
||||||
# glean_result = await trio.to_thread.run_sync(lambda: self._chat(hint_prompt, history, gen_conf))
|
glean_result = await asyncio.to_thread(self._chat,"",history,gen_conf,task_id)
|
||||||
glean_result = await trio.to_thread.run_sync(self._chat, "", history, gen_conf, task_id)
|
|
||||||
history.extend([{"role": "assistant", "content": glean_result}])
|
history.extend([{"role": "assistant", "content": glean_result}])
|
||||||
token_count += num_tokens_from_string("\n".join([m["content"] for m in history]) + hint_prompt + self._continue_prompt)
|
token_count += num_tokens_from_string("\n".join([m["content"] for m in history]) + hint_prompt + self._continue_prompt)
|
||||||
final_result += glean_result
|
final_result += glean_result
|
||||||
|
|
@ -101,7 +100,7 @@ class GraphExtractor(Extractor):
|
||||||
|
|
||||||
history.extend([{"role": "user", "content": self._if_loop_prompt}])
|
history.extend([{"role": "user", "content": self._if_loop_prompt}])
|
||||||
async with chat_limiter:
|
async with chat_limiter:
|
||||||
if_loop_result = await trio.to_thread.run_sync(self._chat, "", history, gen_conf, task_id)
|
if_loop_result = await asyncio.to_thread(self._chat,"",history,gen_conf,task_id)
|
||||||
token_count += num_tokens_from_string("\n".join([m["content"] for m in history]) + if_loop_result + self._if_loop_prompt)
|
token_count += num_tokens_from_string("\n".join([m["content"] for m in history]) + if_loop_result + self._if_loop_prompt)
|
||||||
if_loop_result = if_loop_result.strip().strip('"').strip("'").lower()
|
if_loop_result = if_loop_result.strip().strip('"').strip("'").lower()
|
||||||
if if_loop_result != "yes":
|
if if_loop_result != "yes":
|
||||||
|
|
|
||||||
|
|
@ -15,10 +15,10 @@
|
||||||
#
|
#
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
|
import asyncio
|
||||||
import json
|
import json
|
||||||
import networkx as nx
|
import networkx as nx
|
||||||
import logging
|
import logging
|
||||||
import trio
|
|
||||||
|
|
||||||
from common.constants import LLMType
|
from common.constants import LLMType
|
||||||
from api.db.services.document_service import DocumentService
|
from api.db.services.document_service import DocumentService
|
||||||
|
|
@ -93,4 +93,4 @@ async def main():
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
trio.run(main)
|
asyncio.run(main)
|
||||||
|
|
|
||||||
|
|
@ -13,13 +13,13 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
#
|
#
|
||||||
|
import asyncio
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
import json_repair
|
import json_repair
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
import trio
|
|
||||||
|
|
||||||
from common.misc_utils import get_uuid
|
from common.misc_utils import get_uuid
|
||||||
from graphrag.query_analyze_prompt import PROMPTS
|
from graphrag.query_analyze_prompt import PROMPTS
|
||||||
|
|
@ -44,7 +44,7 @@ class KGSearch(Dealer):
|
||||||
return response
|
return response
|
||||||
|
|
||||||
def query_rewrite(self, llm, question, idxnms, kb_ids):
|
def query_rewrite(self, llm, question, idxnms, kb_ids):
|
||||||
ty2ents = trio.run(lambda: get_entity_type2samples(idxnms, kb_ids))
|
ty2ents = asyncio.run(get_entity_type2samples(idxnms, kb_ids))
|
||||||
hint_prompt = PROMPTS["minirag_query2kwd"].format(query=question,
|
hint_prompt = PROMPTS["minirag_query2kwd"].format(query=question,
|
||||||
TYPE_POOL=json.dumps(ty2ents, ensure_ascii=False, indent=2))
|
TYPE_POOL=json.dumps(ty2ents, ensure_ascii=False, indent=2))
|
||||||
result = self._chat(llm, hint_prompt, [{"role": "user", "content": "Output:"}], {})
|
result = self._chat(llm, hint_prompt, [{"role": "user", "content": "Output:"}], {})
|
||||||
|
|
|
||||||
|
|
@ -6,6 +6,7 @@ Reference:
|
||||||
- [LightRag](https://github.com/HKUDS/LightRAG)
|
- [LightRag](https://github.com/HKUDS/LightRAG)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
import dataclasses
|
import dataclasses
|
||||||
import html
|
import html
|
||||||
import json
|
import json
|
||||||
|
|
@ -19,7 +20,6 @@ from typing import Any, Callable, Set, Tuple
|
||||||
|
|
||||||
import networkx as nx
|
import networkx as nx
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import trio
|
|
||||||
import xxhash
|
import xxhash
|
||||||
from networkx.readwrite import json_graph
|
from networkx.readwrite import json_graph
|
||||||
|
|
||||||
|
|
@ -34,7 +34,7 @@ GRAPH_FIELD_SEP = "<SEP>"
|
||||||
|
|
||||||
ErrorHandlerFn = Callable[[BaseException | None, str | None, dict | None], None]
|
ErrorHandlerFn = Callable[[BaseException | None, str | None, dict | None], None]
|
||||||
|
|
||||||
chat_limiter = trio.CapacityLimiter(int(os.environ.get("MAX_CONCURRENT_CHATS", 10)))
|
chat_limiter = asyncio.Semaphore(int(os.environ.get("MAX_CONCURRENT_CHATS", 10)))
|
||||||
|
|
||||||
|
|
||||||
@dataclasses.dataclass
|
@dataclasses.dataclass
|
||||||
|
|
@ -314,8 +314,11 @@ async def graph_node_to_chunk(kb_id, embd_mdl, ent_name, meta, chunks):
|
||||||
ebd = get_embed_cache(embd_mdl.llm_name, ent_name)
|
ebd = get_embed_cache(embd_mdl.llm_name, ent_name)
|
||||||
if ebd is None:
|
if ebd is None:
|
||||||
async with chat_limiter:
|
async with chat_limiter:
|
||||||
with trio.fail_after(3 if enable_timeout_assertion else 30000000):
|
timeout = 3 if enable_timeout_assertion else 30000000
|
||||||
ebd, _ = await trio.to_thread.run_sync(lambda: embd_mdl.encode([ent_name]))
|
ebd, _ = await asyncio.wait_for(
|
||||||
|
asyncio.to_thread(embd_mdl.encode, [ent_name]),
|
||||||
|
timeout=timeout
|
||||||
|
)
|
||||||
ebd = ebd[0]
|
ebd = ebd[0]
|
||||||
set_embed_cache(embd_mdl.llm_name, ent_name, ebd)
|
set_embed_cache(embd_mdl.llm_name, ent_name, ebd)
|
||||||
assert ebd is not None
|
assert ebd is not None
|
||||||
|
|
@ -365,8 +368,14 @@ async def graph_edge_to_chunk(kb_id, embd_mdl, from_ent_name, to_ent_name, meta,
|
||||||
ebd = get_embed_cache(embd_mdl.llm_name, txt)
|
ebd = get_embed_cache(embd_mdl.llm_name, txt)
|
||||||
if ebd is None:
|
if ebd is None:
|
||||||
async with chat_limiter:
|
async with chat_limiter:
|
||||||
with trio.fail_after(3 if enable_timeout_assertion else 300000000):
|
timeout = 3 if enable_timeout_assertion else 300000000
|
||||||
ebd, _ = await trio.to_thread.run_sync(lambda: embd_mdl.encode([txt + f": {meta['description']}"]))
|
ebd, _ = await asyncio.wait_for(
|
||||||
|
asyncio.to_thread(
|
||||||
|
embd_mdl.encode,
|
||||||
|
[txt + f": {meta['description']}"]
|
||||||
|
),
|
||||||
|
timeout=timeout
|
||||||
|
)
|
||||||
ebd = ebd[0]
|
ebd = ebd[0]
|
||||||
set_embed_cache(embd_mdl.llm_name, txt, ebd)
|
set_embed_cache(embd_mdl.llm_name, txt, ebd)
|
||||||
assert ebd is not None
|
assert ebd is not None
|
||||||
|
|
@ -381,7 +390,11 @@ async def does_graph_contains(tenant_id, kb_id, doc_id):
|
||||||
"knowledge_graph_kwd": ["graph"],
|
"knowledge_graph_kwd": ["graph"],
|
||||||
"removed_kwd": "N",
|
"removed_kwd": "N",
|
||||||
}
|
}
|
||||||
res = await trio.to_thread.run_sync(lambda: settings.docStoreConn.search(fields, [], condition, [], OrderByExpr(), 0, 1, search.index_name(tenant_id), [kb_id]))
|
res = await asyncio.to_thread(
|
||||||
|
settings.docStoreConn.search,
|
||||||
|
fields, [], condition, [], OrderByExpr(),
|
||||||
|
0, 1, search.index_name(tenant_id), [kb_id]
|
||||||
|
)
|
||||||
fields2 = settings.docStoreConn.get_fields(res, fields)
|
fields2 = settings.docStoreConn.get_fields(res, fields)
|
||||||
graph_doc_ids = set()
|
graph_doc_ids = set()
|
||||||
for chunk_id in fields2.keys():
|
for chunk_id in fields2.keys():
|
||||||
|
|
@ -391,7 +404,12 @@ async def does_graph_contains(tenant_id, kb_id, doc_id):
|
||||||
|
|
||||||
async def get_graph_doc_ids(tenant_id, kb_id) -> list[str]:
|
async def get_graph_doc_ids(tenant_id, kb_id) -> list[str]:
|
||||||
conds = {"fields": ["source_id"], "removed_kwd": "N", "size": 1, "knowledge_graph_kwd": ["graph"]}
|
conds = {"fields": ["source_id"], "removed_kwd": "N", "size": 1, "knowledge_graph_kwd": ["graph"]}
|
||||||
res = await trio.to_thread.run_sync(lambda: settings.retriever.search(conds, search.index_name(tenant_id), [kb_id]))
|
res = await asyncio.to_thread(
|
||||||
|
settings.retriever.search,
|
||||||
|
conds,
|
||||||
|
search.index_name(tenant_id),
|
||||||
|
[kb_id]
|
||||||
|
)
|
||||||
doc_ids = []
|
doc_ids = []
|
||||||
if res.total == 0:
|
if res.total == 0:
|
||||||
return doc_ids
|
return doc_ids
|
||||||
|
|
@ -402,7 +420,12 @@ async def get_graph_doc_ids(tenant_id, kb_id) -> list[str]:
|
||||||
|
|
||||||
async def get_graph(tenant_id, kb_id, exclude_rebuild=None):
|
async def get_graph(tenant_id, kb_id, exclude_rebuild=None):
|
||||||
conds = {"fields": ["content_with_weight", "removed_kwd", "source_id"], "size": 1, "knowledge_graph_kwd": ["graph"]}
|
conds = {"fields": ["content_with_weight", "removed_kwd", "source_id"], "size": 1, "knowledge_graph_kwd": ["graph"]}
|
||||||
res = await trio.to_thread.run_sync(settings.retriever.search, conds, search.index_name(tenant_id), [kb_id])
|
res = await asyncio.to_thread(
|
||||||
|
settings.retriever.search,
|
||||||
|
conds,
|
||||||
|
search.index_name(tenant_id),
|
||||||
|
[kb_id]
|
||||||
|
)
|
||||||
if not res.total == 0:
|
if not res.total == 0:
|
||||||
for id in res.ids:
|
for id in res.ids:
|
||||||
try:
|
try:
|
||||||
|
|
@ -421,26 +444,47 @@ async def get_graph(tenant_id, kb_id, exclude_rebuild=None):
|
||||||
|
|
||||||
async def set_graph(tenant_id: str, kb_id: str, embd_mdl, graph: nx.Graph, change: GraphChange, callback):
|
async def set_graph(tenant_id: str, kb_id: str, embd_mdl, graph: nx.Graph, change: GraphChange, callback):
|
||||||
global chat_limiter
|
global chat_limiter
|
||||||
start = trio.current_time()
|
start = asyncio.get_running_loop().time()
|
||||||
|
|
||||||
await trio.to_thread.run_sync(settings.docStoreConn.delete, {"knowledge_graph_kwd": ["graph", "subgraph"]}, search.index_name(tenant_id), kb_id)
|
await asyncio.to_thread(
|
||||||
|
settings.docStoreConn.delete,
|
||||||
|
{"knowledge_graph_kwd": ["graph", "subgraph"]},
|
||||||
|
search.index_name(tenant_id),
|
||||||
|
kb_id
|
||||||
|
)
|
||||||
|
|
||||||
if change.removed_nodes:
|
if change.removed_nodes:
|
||||||
await trio.to_thread.run_sync(settings.docStoreConn.delete, {"knowledge_graph_kwd": ["entity"], "entity_kwd": sorted(change.removed_nodes)}, search.index_name(tenant_id), kb_id)
|
await asyncio.to_thread(
|
||||||
|
settings.docStoreConn.delete,
|
||||||
|
{"knowledge_graph_kwd": ["entity"], "entity_kwd": sorted(change.removed_nodes)},
|
||||||
|
search.index_name(tenant_id),
|
||||||
|
kb_id
|
||||||
|
)
|
||||||
|
|
||||||
if change.removed_edges:
|
if change.removed_edges:
|
||||||
|
|
||||||
async def del_edges(from_node, to_node):
|
async def del_edges(from_node, to_node):
|
||||||
async with chat_limiter:
|
async with chat_limiter:
|
||||||
await trio.to_thread.run_sync(
|
await asyncio.to_thread(
|
||||||
settings.docStoreConn.delete, {"knowledge_graph_kwd": ["relation"], "from_entity_kwd": from_node, "to_entity_kwd": to_node}, search.index_name(tenant_id), kb_id
|
settings.docStoreConn.delete,
|
||||||
|
{"knowledge_graph_kwd": ["relation"], "from_entity_kwd": from_node, "to_entity_kwd": to_node},
|
||||||
|
search.index_name(tenant_id),
|
||||||
|
kb_id
|
||||||
)
|
)
|
||||||
|
|
||||||
async with trio.open_nursery() as nursery:
|
tasks = []
|
||||||
for from_node, to_node in change.removed_edges:
|
for from_node, to_node in change.removed_edges:
|
||||||
nursery.start_soon(del_edges, from_node, to_node)
|
tasks.append(asyncio.create_task(del_edges(from_node, to_node)))
|
||||||
|
|
||||||
now = trio.current_time()
|
try:
|
||||||
|
await asyncio.gather(*tasks, return_exceptions=False)
|
||||||
|
except Exception as e:
|
||||||
|
for t in tasks:
|
||||||
|
t.cancel()
|
||||||
|
await asyncio.gather(*tasks, return_exceptions=True)
|
||||||
|
raise
|
||||||
|
|
||||||
|
now = asyncio.get_running_loop().time()
|
||||||
if callback:
|
if callback:
|
||||||
callback(msg=f"set_graph removed {len(change.removed_nodes)} nodes and {len(change.removed_edges)} edges from index in {now - start:.2f}s.")
|
callback(msg=f"set_graph removed {len(change.removed_nodes)} nodes and {len(change.removed_edges)} edges from index in {now - start:.2f}s.")
|
||||||
start = now
|
start = now
|
||||||
|
|
@ -475,24 +519,41 @@ async def set_graph(tenant_id: str, kb_id: str, embd_mdl, graph: nx.Graph, chang
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
async with trio.open_nursery() as nursery:
|
tasks = []
|
||||||
for ii, node in enumerate(change.added_updated_nodes):
|
for ii, node in enumerate(change.added_updated_nodes):
|
||||||
node_attrs = graph.nodes[node]
|
node_attrs = graph.nodes[node]
|
||||||
nursery.start_soon(graph_node_to_chunk, kb_id, embd_mdl, node, node_attrs, chunks)
|
tasks.append(asyncio.create_task(
|
||||||
if ii % 100 == 9 and callback:
|
graph_node_to_chunk(kb_id, embd_mdl, node, node_attrs, chunks)
|
||||||
callback(msg=f"Get embedding of nodes: {ii}/{len(change.added_updated_nodes)}")
|
))
|
||||||
|
if ii % 100 == 9 and callback:
|
||||||
|
callback(msg=f"Get embedding of nodes: {ii}/{len(change.added_updated_nodes)}")
|
||||||
|
try:
|
||||||
|
await asyncio.gather(*tasks, return_exceptions=False)
|
||||||
|
except Exception as e:
|
||||||
|
for t in tasks:
|
||||||
|
t.cancel()
|
||||||
|
await asyncio.gather(*tasks, return_exceptions=True)
|
||||||
|
raise
|
||||||
|
|
||||||
async with trio.open_nursery() as nursery:
|
tasks = []
|
||||||
for ii, (from_node, to_node) in enumerate(change.added_updated_edges):
|
for ii, (from_node, to_node) in enumerate(change.added_updated_edges):
|
||||||
edge_attrs = graph.get_edge_data(from_node, to_node)
|
edge_attrs = graph.get_edge_data(from_node, to_node)
|
||||||
if not edge_attrs:
|
if not edge_attrs:
|
||||||
# added_updated_edges could record a non-existing edge if both from_node and to_node participate in nodes merging.
|
continue
|
||||||
continue
|
tasks.append(asyncio.create_task(
|
||||||
nursery.start_soon(graph_edge_to_chunk, kb_id, embd_mdl, from_node, to_node, edge_attrs, chunks)
|
graph_edge_to_chunk(kb_id, embd_mdl, from_node, to_node, edge_attrs, chunks)
|
||||||
if ii % 100 == 9 and callback:
|
))
|
||||||
callback(msg=f"Get embedding of edges: {ii}/{len(change.added_updated_edges)}")
|
if ii % 100 == 9 and callback:
|
||||||
|
callback(msg=f"Get embedding of edges: {ii}/{len(change.added_updated_edges)}")
|
||||||
|
try:
|
||||||
|
await asyncio.gather(*tasks, return_exceptions=False)
|
||||||
|
except Exception as e:
|
||||||
|
for t in tasks:
|
||||||
|
t.cancel()
|
||||||
|
await asyncio.gather(*tasks, return_exceptions=True)
|
||||||
|
raise
|
||||||
|
|
||||||
now = trio.current_time()
|
now = asyncio.get_running_loop().time()
|
||||||
if callback:
|
if callback:
|
||||||
callback(msg=f"set_graph converted graph change to {len(chunks)} chunks in {now - start:.2f}s.")
|
callback(msg=f"set_graph converted graph change to {len(chunks)} chunks in {now - start:.2f}s.")
|
||||||
start = now
|
start = now
|
||||||
|
|
@ -500,14 +561,22 @@ async def set_graph(tenant_id: str, kb_id: str, embd_mdl, graph: nx.Graph, chang
|
||||||
enable_timeout_assertion = os.environ.get("ENABLE_TIMEOUT_ASSERTION")
|
enable_timeout_assertion = os.environ.get("ENABLE_TIMEOUT_ASSERTION")
|
||||||
es_bulk_size = 4
|
es_bulk_size = 4
|
||||||
for b in range(0, len(chunks), es_bulk_size):
|
for b in range(0, len(chunks), es_bulk_size):
|
||||||
with trio.fail_after(3 if enable_timeout_assertion else 30000000):
|
timeout = 3 if enable_timeout_assertion else 30000000
|
||||||
doc_store_result = await trio.to_thread.run_sync(lambda: settings.docStoreConn.insert(chunks[b : b + es_bulk_size], search.index_name(tenant_id), kb_id))
|
doc_store_result = await asyncio.wait_for(
|
||||||
|
asyncio.to_thread(
|
||||||
|
settings.docStoreConn.insert,
|
||||||
|
chunks[b : b + es_bulk_size],
|
||||||
|
search.index_name(tenant_id),
|
||||||
|
kb_id
|
||||||
|
),
|
||||||
|
timeout=timeout
|
||||||
|
)
|
||||||
if b % 100 == es_bulk_size and callback:
|
if b % 100 == es_bulk_size and callback:
|
||||||
callback(msg=f"Insert chunks: {b}/{len(chunks)}")
|
callback(msg=f"Insert chunks: {b}/{len(chunks)}")
|
||||||
if doc_store_result:
|
if doc_store_result:
|
||||||
error_message = f"Insert chunk error: {doc_store_result}, please check log file and Elasticsearch/Infinity status!"
|
error_message = f"Insert chunk error: {doc_store_result}, please check log file and Elasticsearch/Infinity status!"
|
||||||
raise Exception(error_message)
|
raise Exception(error_message)
|
||||||
now = trio.current_time()
|
now = asyncio.get_running_loop().time()
|
||||||
if callback:
|
if callback:
|
||||||
callback(msg=f"set_graph added/updated {len(change.added_updated_nodes)} nodes and {len(change.added_updated_edges)} edges from index in {now - start:.2f}s.")
|
callback(msg=f"set_graph added/updated {len(change.added_updated_nodes)} nodes and {len(change.added_updated_edges)} edges from index in {now - start:.2f}s.")
|
||||||
|
|
||||||
|
|
@ -555,7 +624,7 @@ def merge_tuples(list1, list2):
|
||||||
|
|
||||||
|
|
||||||
async def get_entity_type2samples(idxnms, kb_ids: list):
|
async def get_entity_type2samples(idxnms, kb_ids: list):
|
||||||
es_res = await trio.to_thread.run_sync(lambda: settings.retriever.search({"knowledge_graph_kwd": "ty2ents", "kb_id": kb_ids, "size": 10000, "fields": ["content_with_weight"]}, idxnms, kb_ids))
|
es_res = await asyncio.to_thread(settings.retriever.search,{"knowledge_graph_kwd": "ty2ents", "kb_id": kb_ids, "size": 10000, "fields": ["content_with_weight"]},idxnms,kb_ids)
|
||||||
|
|
||||||
res = defaultdict(list)
|
res = defaultdict(list)
|
||||||
for id in es_res.ids:
|
for id in es_res.ids:
|
||||||
|
|
@ -588,8 +657,10 @@ async def rebuild_graph(tenant_id, kb_id, exclude_rebuild=None):
|
||||||
flds = ["knowledge_graph_kwd", "content_with_weight", "source_id"]
|
flds = ["knowledge_graph_kwd", "content_with_weight", "source_id"]
|
||||||
bs = 256
|
bs = 256
|
||||||
for i in range(0, 1024 * bs, bs):
|
for i in range(0, 1024 * bs, bs):
|
||||||
es_res = await trio.to_thread.run_sync(
|
es_res = await asyncio.to_thread(
|
||||||
lambda: settings.docStoreConn.search(flds, [], {"kb_id": kb_id, "knowledge_graph_kwd": ["subgraph"]}, [], OrderByExpr(), i, bs, search.index_name(tenant_id), [kb_id])
|
settings.docStoreConn.search,
|
||||||
|
flds, [], {"kb_id": kb_id, "knowledge_graph_kwd": ["subgraph"]},
|
||||||
|
[], OrderByExpr(), i, bs, search.index_name(tenant_id), [kb_id]
|
||||||
)
|
)
|
||||||
# tot = settings.docStoreConn.get_total(es_res)
|
# tot = settings.docStoreConn.get_total(es_res)
|
||||||
es_res = settings.docStoreConn.get_fields(es_res, flds)
|
es_res = settings.docStoreConn.get_fields(es_res, flds)
|
||||||
|
|
|
||||||
|
|
@ -13,11 +13,11 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
#
|
#
|
||||||
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
import re
|
import re
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import trio
|
|
||||||
import umap
|
import umap
|
||||||
from sklearn.mixture import GaussianMixture
|
from sklearn.mixture import GaussianMixture
|
||||||
|
|
||||||
|
|
@ -56,37 +56,37 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
|
||||||
|
|
||||||
@timeout(60 * 20)
|
@timeout(60 * 20)
|
||||||
async def _chat(self, system, history, gen_conf):
|
async def _chat(self, system, history, gen_conf):
|
||||||
cached = await trio.to_thread.run_sync(lambda: get_llm_cache(self._llm_model.llm_name, system, history, gen_conf))
|
cached = await asyncio.to_thread(get_llm_cache, self._llm_model.llm_name, system, history, gen_conf)
|
||||||
if cached:
|
if cached:
|
||||||
return cached
|
return cached
|
||||||
|
|
||||||
last_exc = None
|
last_exc = None
|
||||||
for attempt in range(3):
|
for attempt in range(3):
|
||||||
try:
|
try:
|
||||||
response = await trio.to_thread.run_sync(lambda: self._llm_model.chat(system, history, gen_conf))
|
response = await asyncio.to_thread(self._llm_model.chat, system, history, gen_conf)
|
||||||
response = re.sub(r"^.*</think>", "", response, flags=re.DOTALL)
|
response = re.sub(r"^.*</think>", "", response, flags=re.DOTALL)
|
||||||
if response.find("**ERROR**") >= 0:
|
if response.find("**ERROR**") >= 0:
|
||||||
raise Exception(response)
|
raise Exception(response)
|
||||||
await trio.to_thread.run_sync(lambda: set_llm_cache(self._llm_model.llm_name, system, response, history, gen_conf))
|
await asyncio.to_thread(set_llm_cache,self._llm_model.llm_name,system,response,history,gen_conf)
|
||||||
return response
|
return response
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
last_exc = exc
|
last_exc = exc
|
||||||
logging.warning("RAPTOR LLM call failed on attempt %d/3: %s", attempt + 1, exc)
|
logging.warning("RAPTOR LLM call failed on attempt %d/3: %s", attempt + 1, exc)
|
||||||
if attempt < 2:
|
if attempt < 2:
|
||||||
await trio.sleep(1 + attempt)
|
await asyncio.sleep(1 + attempt)
|
||||||
|
|
||||||
raise last_exc if last_exc else Exception("LLM chat failed without exception")
|
raise last_exc if last_exc else Exception("LLM chat failed without exception")
|
||||||
|
|
||||||
@timeout(20)
|
@timeout(20)
|
||||||
async def _embedding_encode(self, txt):
|
async def _embedding_encode(self, txt):
|
||||||
response = await trio.to_thread.run_sync(lambda: get_embed_cache(self._embd_model.llm_name, txt))
|
response = await asyncio.to_thread(get_embed_cache, self._embd_model.llm_name, txt)
|
||||||
if response is not None:
|
if response is not None:
|
||||||
return response
|
return response
|
||||||
embds, _ = await trio.to_thread.run_sync(lambda: self._embd_model.encode([txt]))
|
embds, _ = await asyncio.to_thread(self._embd_model.encode, [txt])
|
||||||
if len(embds) < 1 or len(embds[0]) < 1:
|
if len(embds) < 1 or len(embds[0]) < 1:
|
||||||
raise Exception("Embedding error: ")
|
raise Exception("Embedding error: ")
|
||||||
embds = embds[0]
|
embds = embds[0]
|
||||||
await trio.to_thread.run_sync(lambda: set_embed_cache(self._embd_model.llm_name, txt, embds))
|
await asyncio.to_thread(set_embed_cache, self._embd_model.llm_name, txt, embds)
|
||||||
return embds
|
return embds
|
||||||
|
|
||||||
def _get_optimal_clusters(self, embeddings: np.ndarray, random_state: int, task_id: str = ""):
|
def _get_optimal_clusters(self, embeddings: np.ndarray, random_state: int, task_id: str = ""):
|
||||||
|
|
@ -198,16 +198,21 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
|
||||||
lbls = [np.where(prob > self._threshold)[0] for prob in probs]
|
lbls = [np.where(prob > self._threshold)[0] for prob in probs]
|
||||||
lbls = [lbl[0] if isinstance(lbl, np.ndarray) else lbl for lbl in lbls]
|
lbls = [lbl[0] if isinstance(lbl, np.ndarray) else lbl for lbl in lbls]
|
||||||
|
|
||||||
async with trio.open_nursery() as nursery:
|
tasks = []
|
||||||
for c in range(n_clusters):
|
for c in range(n_clusters):
|
||||||
ck_idx = [i + start for i in range(len(lbls)) if lbls[i] == c]
|
ck_idx = [i + start for i in range(len(lbls)) if lbls[i] == c]
|
||||||
assert len(ck_idx) > 0
|
assert len(ck_idx) > 0
|
||||||
|
if task_id and has_canceled(task_id):
|
||||||
if task_id and has_canceled(task_id):
|
logging.info(f"Task {task_id} cancelled before RAPTOR cluster processing.")
|
||||||
logging.info(f"Task {task_id} cancelled before RAPTOR cluster processing.")
|
raise TaskCanceledException(f"Task {task_id} was cancelled")
|
||||||
raise TaskCanceledException(f"Task {task_id} was cancelled")
|
tasks.append(asyncio.create_task(summarize(ck_idx)))
|
||||||
|
try:
|
||||||
nursery.start_soon(summarize, ck_idx)
|
await asyncio.gather(*tasks, return_exceptions=False)
|
||||||
|
except Exception as e:
|
||||||
|
for t in tasks:
|
||||||
|
t.cancel()
|
||||||
|
await asyncio.gather(*tasks, return_exceptions=True)
|
||||||
|
raise
|
||||||
|
|
||||||
assert len(chunks) - end == n_clusters, "{} vs. {}".format(len(chunks) - end, n_clusters)
|
assert len(chunks) - end == n_clusters, "{} vs. {}".format(len(chunks) - end, n_clusters)
|
||||||
labels.extend(lbls)
|
labels.extend(lbls)
|
||||||
|
|
|
||||||
|
|
@ -920,7 +920,7 @@ async def do_handle_task(task):
|
||||||
file_type = task.get("type", "")
|
file_type = task.get("type", "")
|
||||||
parser_id = task.get("parser_id", "")
|
parser_id = task.get("parser_id", "")
|
||||||
raptor_config = kb_parser_config.get("raptor", {})
|
raptor_config = kb_parser_config.get("raptor", {})
|
||||||
|
|
||||||
if should_skip_raptor(file_type, parser_id, task_parser_config, raptor_config):
|
if should_skip_raptor(file_type, parser_id, task_parser_config, raptor_config):
|
||||||
skip_reason = get_skip_reason(file_type, parser_id, task_parser_config)
|
skip_reason = get_skip_reason(file_type, parser_id, task_parser_config)
|
||||||
logging.info(f"Skipping Raptor for document {task_document_name}: {skip_reason}")
|
logging.info(f"Skipping Raptor for document {task_document_name}: {skip_reason}")
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue