Remove ascii_colors dependency and fix stream handling errors
• Remove ascii_colors.trace_exception calls
• Add SafeStreamHandler for closed streams
• Patch ascii_colors console handler
• Prevent ValueError on stream close
• Improve logging error handling
(cherry picked from commit 0fb2925c6a)
This commit is contained in:
parent
fd76e0f7ce
commit
322ff19f72
3 changed files with 305 additions and 275 deletions
|
|
@ -8,7 +8,6 @@ import re
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from fastapi.responses import StreamingResponse
|
from fastapi.responses import StreamingResponse
|
||||||
import asyncio
|
import asyncio
|
||||||
from ascii_colors import trace_exception
|
|
||||||
from lightrag import LightRAG, QueryParam
|
from lightrag import LightRAG, QueryParam
|
||||||
from lightrag.utils import TiktokenTokenizer
|
from lightrag.utils import TiktokenTokenizer
|
||||||
from lightrag.api.utils_api import get_combined_auth_dependency
|
from lightrag.api.utils_api import get_combined_auth_dependency
|
||||||
|
|
@ -335,118 +334,113 @@ class OllamaAPI:
|
||||||
)
|
)
|
||||||
|
|
||||||
async def stream_generator():
|
async def stream_generator():
|
||||||
try:
|
first_chunk_time = None
|
||||||
first_chunk_time = None
|
last_chunk_time = time.time_ns()
|
||||||
|
total_response = ""
|
||||||
|
|
||||||
|
# Ensure response is an async generator
|
||||||
|
if isinstance(response, str):
|
||||||
|
# If it's a string, send in two parts
|
||||||
|
first_chunk_time = start_time
|
||||||
last_chunk_time = time.time_ns()
|
last_chunk_time = time.time_ns()
|
||||||
total_response = ""
|
total_response = response
|
||||||
|
|
||||||
# Ensure response is an async generator
|
data = {
|
||||||
if isinstance(response, str):
|
"model": self.ollama_server_infos.LIGHTRAG_MODEL,
|
||||||
# If it's a string, send in two parts
|
"created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT,
|
||||||
first_chunk_time = start_time
|
"response": response,
|
||||||
last_chunk_time = time.time_ns()
|
"done": False,
|
||||||
total_response = response
|
}
|
||||||
|
yield f"{json.dumps(data, ensure_ascii=False)}\n"
|
||||||
|
|
||||||
data = {
|
completion_tokens = estimate_tokens(total_response)
|
||||||
|
total_time = last_chunk_time - start_time
|
||||||
|
prompt_eval_time = first_chunk_time - start_time
|
||||||
|
eval_time = last_chunk_time - first_chunk_time
|
||||||
|
|
||||||
|
data = {
|
||||||
|
"model": self.ollama_server_infos.LIGHTRAG_MODEL,
|
||||||
|
"created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT,
|
||||||
|
"response": "",
|
||||||
|
"done": True,
|
||||||
|
"done_reason": "stop",
|
||||||
|
"context": [],
|
||||||
|
"total_duration": total_time,
|
||||||
|
"load_duration": 0,
|
||||||
|
"prompt_eval_count": prompt_tokens,
|
||||||
|
"prompt_eval_duration": prompt_eval_time,
|
||||||
|
"eval_count": completion_tokens,
|
||||||
|
"eval_duration": eval_time,
|
||||||
|
}
|
||||||
|
yield f"{json.dumps(data, ensure_ascii=False)}\n"
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
async for chunk in response:
|
||||||
|
if chunk:
|
||||||
|
if first_chunk_time is None:
|
||||||
|
first_chunk_time = time.time_ns()
|
||||||
|
|
||||||
|
last_chunk_time = time.time_ns()
|
||||||
|
|
||||||
|
total_response += chunk
|
||||||
|
data = {
|
||||||
|
"model": self.ollama_server_infos.LIGHTRAG_MODEL,
|
||||||
|
"created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT,
|
||||||
|
"response": chunk,
|
||||||
|
"done": False,
|
||||||
|
}
|
||||||
|
yield f"{json.dumps(data, ensure_ascii=False)}\n"
|
||||||
|
except (asyncio.CancelledError, Exception) as e:
|
||||||
|
error_msg = str(e)
|
||||||
|
if isinstance(e, asyncio.CancelledError):
|
||||||
|
error_msg = "Stream was cancelled by server"
|
||||||
|
else:
|
||||||
|
error_msg = f"Provider error: {error_msg}"
|
||||||
|
|
||||||
|
logger.error(f"Stream error: {error_msg}")
|
||||||
|
|
||||||
|
# Send error message to client
|
||||||
|
error_data = {
|
||||||
"model": self.ollama_server_infos.LIGHTRAG_MODEL,
|
"model": self.ollama_server_infos.LIGHTRAG_MODEL,
|
||||||
"created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT,
|
"created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT,
|
||||||
"response": response,
|
"response": f"\n\nError: {error_msg}",
|
||||||
|
"error": f"\n\nError: {error_msg}",
|
||||||
"done": False,
|
"done": False,
|
||||||
}
|
}
|
||||||
yield f"{json.dumps(data, ensure_ascii=False)}\n"
|
yield f"{json.dumps(error_data, ensure_ascii=False)}\n"
|
||||||
|
|
||||||
completion_tokens = estimate_tokens(total_response)
|
# Send final message to close the stream
|
||||||
total_time = last_chunk_time - start_time
|
final_data = {
|
||||||
prompt_eval_time = first_chunk_time - start_time
|
|
||||||
eval_time = last_chunk_time - first_chunk_time
|
|
||||||
|
|
||||||
data = {
|
|
||||||
"model": self.ollama_server_infos.LIGHTRAG_MODEL,
|
"model": self.ollama_server_infos.LIGHTRAG_MODEL,
|
||||||
"created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT,
|
"created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT,
|
||||||
"response": "",
|
"response": "",
|
||||||
"done": True,
|
"done": True,
|
||||||
"done_reason": "stop",
|
|
||||||
"context": [],
|
|
||||||
"total_duration": total_time,
|
|
||||||
"load_duration": 0,
|
|
||||||
"prompt_eval_count": prompt_tokens,
|
|
||||||
"prompt_eval_duration": prompt_eval_time,
|
|
||||||
"eval_count": completion_tokens,
|
|
||||||
"eval_duration": eval_time,
|
|
||||||
}
|
}
|
||||||
yield f"{json.dumps(data, ensure_ascii=False)}\n"
|
yield f"{json.dumps(final_data, ensure_ascii=False)}\n"
|
||||||
else:
|
|
||||||
try:
|
|
||||||
async for chunk in response:
|
|
||||||
if chunk:
|
|
||||||
if first_chunk_time is None:
|
|
||||||
first_chunk_time = time.time_ns()
|
|
||||||
|
|
||||||
last_chunk_time = time.time_ns()
|
|
||||||
|
|
||||||
total_response += chunk
|
|
||||||
data = {
|
|
||||||
"model": self.ollama_server_infos.LIGHTRAG_MODEL,
|
|
||||||
"created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT,
|
|
||||||
"response": chunk,
|
|
||||||
"done": False,
|
|
||||||
}
|
|
||||||
yield f"{json.dumps(data, ensure_ascii=False)}\n"
|
|
||||||
except (asyncio.CancelledError, Exception) as e:
|
|
||||||
error_msg = str(e)
|
|
||||||
if isinstance(e, asyncio.CancelledError):
|
|
||||||
error_msg = "Stream was cancelled by server"
|
|
||||||
else:
|
|
||||||
error_msg = f"Provider error: {error_msg}"
|
|
||||||
|
|
||||||
logger.error(f"Stream error: {error_msg}")
|
|
||||||
|
|
||||||
# Send error message to client
|
|
||||||
error_data = {
|
|
||||||
"model": self.ollama_server_infos.LIGHTRAG_MODEL,
|
|
||||||
"created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT,
|
|
||||||
"response": f"\n\nError: {error_msg}",
|
|
||||||
"error": f"\n\nError: {error_msg}",
|
|
||||||
"done": False,
|
|
||||||
}
|
|
||||||
yield f"{json.dumps(error_data, ensure_ascii=False)}\n"
|
|
||||||
|
|
||||||
# Send final message to close the stream
|
|
||||||
final_data = {
|
|
||||||
"model": self.ollama_server_infos.LIGHTRAG_MODEL,
|
|
||||||
"created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT,
|
|
||||||
"response": "",
|
|
||||||
"done": True,
|
|
||||||
}
|
|
||||||
yield f"{json.dumps(final_data, ensure_ascii=False)}\n"
|
|
||||||
return
|
|
||||||
if first_chunk_time is None:
|
|
||||||
first_chunk_time = start_time
|
|
||||||
completion_tokens = estimate_tokens(total_response)
|
|
||||||
total_time = last_chunk_time - start_time
|
|
||||||
prompt_eval_time = first_chunk_time - start_time
|
|
||||||
eval_time = last_chunk_time - first_chunk_time
|
|
||||||
|
|
||||||
data = {
|
|
||||||
"model": self.ollama_server_infos.LIGHTRAG_MODEL,
|
|
||||||
"created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT,
|
|
||||||
"response": "",
|
|
||||||
"done": True,
|
|
||||||
"done_reason": "stop",
|
|
||||||
"context": [],
|
|
||||||
"total_duration": total_time,
|
|
||||||
"load_duration": 0,
|
|
||||||
"prompt_eval_count": prompt_tokens,
|
|
||||||
"prompt_eval_duration": prompt_eval_time,
|
|
||||||
"eval_count": completion_tokens,
|
|
||||||
"eval_duration": eval_time,
|
|
||||||
}
|
|
||||||
yield f"{json.dumps(data, ensure_ascii=False)}\n"
|
|
||||||
return
|
return
|
||||||
|
if first_chunk_time is None:
|
||||||
|
first_chunk_time = start_time
|
||||||
|
completion_tokens = estimate_tokens(total_response)
|
||||||
|
total_time = last_chunk_time - start_time
|
||||||
|
prompt_eval_time = first_chunk_time - start_time
|
||||||
|
eval_time = last_chunk_time - first_chunk_time
|
||||||
|
|
||||||
except Exception as e:
|
data = {
|
||||||
trace_exception(e)
|
"model": self.ollama_server_infos.LIGHTRAG_MODEL,
|
||||||
raise
|
"created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT,
|
||||||
|
"response": "",
|
||||||
|
"done": True,
|
||||||
|
"done_reason": "stop",
|
||||||
|
"context": [],
|
||||||
|
"total_duration": total_time,
|
||||||
|
"load_duration": 0,
|
||||||
|
"prompt_eval_count": prompt_tokens,
|
||||||
|
"prompt_eval_duration": prompt_eval_time,
|
||||||
|
"eval_count": completion_tokens,
|
||||||
|
"eval_duration": eval_time,
|
||||||
|
}
|
||||||
|
yield f"{json.dumps(data, ensure_ascii=False)}\n"
|
||||||
|
return
|
||||||
|
|
||||||
return StreamingResponse(
|
return StreamingResponse(
|
||||||
stream_generator(),
|
stream_generator(),
|
||||||
|
|
@ -488,7 +482,6 @@ class OllamaAPI:
|
||||||
"eval_duration": eval_time,
|
"eval_duration": eval_time,
|
||||||
}
|
}
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
trace_exception(e)
|
|
||||||
raise HTTPException(status_code=500, detail=str(e))
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
|
|
||||||
@self.router.post(
|
@self.router.post(
|
||||||
|
|
@ -558,36 +551,98 @@ class OllamaAPI:
|
||||||
)
|
)
|
||||||
|
|
||||||
async def stream_generator():
|
async def stream_generator():
|
||||||
try:
|
first_chunk_time = None
|
||||||
first_chunk_time = None
|
last_chunk_time = time.time_ns()
|
||||||
|
total_response = ""
|
||||||
|
|
||||||
|
# Ensure response is an async generator
|
||||||
|
if isinstance(response, str):
|
||||||
|
# If it's a string, send in two parts
|
||||||
|
first_chunk_time = start_time
|
||||||
last_chunk_time = time.time_ns()
|
last_chunk_time = time.time_ns()
|
||||||
total_response = ""
|
total_response = response
|
||||||
|
|
||||||
# Ensure response is an async generator
|
data = {
|
||||||
if isinstance(response, str):
|
"model": self.ollama_server_infos.LIGHTRAG_MODEL,
|
||||||
# If it's a string, send in two parts
|
"created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT,
|
||||||
first_chunk_time = start_time
|
"message": {
|
||||||
last_chunk_time = time.time_ns()
|
"role": "assistant",
|
||||||
total_response = response
|
"content": response,
|
||||||
|
"images": None,
|
||||||
|
},
|
||||||
|
"done": False,
|
||||||
|
}
|
||||||
|
yield f"{json.dumps(data, ensure_ascii=False)}\n"
|
||||||
|
|
||||||
data = {
|
completion_tokens = estimate_tokens(total_response)
|
||||||
|
total_time = last_chunk_time - start_time
|
||||||
|
prompt_eval_time = first_chunk_time - start_time
|
||||||
|
eval_time = last_chunk_time - first_chunk_time
|
||||||
|
|
||||||
|
data = {
|
||||||
|
"model": self.ollama_server_infos.LIGHTRAG_MODEL,
|
||||||
|
"created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT,
|
||||||
|
"message": {
|
||||||
|
"role": "assistant",
|
||||||
|
"content": "",
|
||||||
|
"images": None,
|
||||||
|
},
|
||||||
|
"done_reason": "stop",
|
||||||
|
"done": True,
|
||||||
|
"total_duration": total_time,
|
||||||
|
"load_duration": 0,
|
||||||
|
"prompt_eval_count": prompt_tokens,
|
||||||
|
"prompt_eval_duration": prompt_eval_time,
|
||||||
|
"eval_count": completion_tokens,
|
||||||
|
"eval_duration": eval_time,
|
||||||
|
}
|
||||||
|
yield f"{json.dumps(data, ensure_ascii=False)}\n"
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
async for chunk in response:
|
||||||
|
if chunk:
|
||||||
|
if first_chunk_time is None:
|
||||||
|
first_chunk_time = time.time_ns()
|
||||||
|
|
||||||
|
last_chunk_time = time.time_ns()
|
||||||
|
|
||||||
|
total_response += chunk
|
||||||
|
data = {
|
||||||
|
"model": self.ollama_server_infos.LIGHTRAG_MODEL,
|
||||||
|
"created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT,
|
||||||
|
"message": {
|
||||||
|
"role": "assistant",
|
||||||
|
"content": chunk,
|
||||||
|
"images": None,
|
||||||
|
},
|
||||||
|
"done": False,
|
||||||
|
}
|
||||||
|
yield f"{json.dumps(data, ensure_ascii=False)}\n"
|
||||||
|
except (asyncio.CancelledError, Exception) as e:
|
||||||
|
error_msg = str(e)
|
||||||
|
if isinstance(e, asyncio.CancelledError):
|
||||||
|
error_msg = "Stream was cancelled by server"
|
||||||
|
else:
|
||||||
|
error_msg = f"Provider error: {error_msg}"
|
||||||
|
|
||||||
|
logger.error(f"Stream error: {error_msg}")
|
||||||
|
|
||||||
|
# Send error message to client
|
||||||
|
error_data = {
|
||||||
"model": self.ollama_server_infos.LIGHTRAG_MODEL,
|
"model": self.ollama_server_infos.LIGHTRAG_MODEL,
|
||||||
"created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT,
|
"created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT,
|
||||||
"message": {
|
"message": {
|
||||||
"role": "assistant",
|
"role": "assistant",
|
||||||
"content": response,
|
"content": f"\n\nError: {error_msg}",
|
||||||
"images": None,
|
"images": None,
|
||||||
},
|
},
|
||||||
|
"error": f"\n\nError: {error_msg}",
|
||||||
"done": False,
|
"done": False,
|
||||||
}
|
}
|
||||||
yield f"{json.dumps(data, ensure_ascii=False)}\n"
|
yield f"{json.dumps(error_data, ensure_ascii=False)}\n"
|
||||||
|
|
||||||
completion_tokens = estimate_tokens(total_response)
|
# Send final message to close the stream
|
||||||
total_time = last_chunk_time - start_time
|
final_data = {
|
||||||
prompt_eval_time = first_chunk_time - start_time
|
|
||||||
eval_time = last_chunk_time - first_chunk_time
|
|
||||||
|
|
||||||
data = {
|
|
||||||
"model": self.ollama_server_infos.LIGHTRAG_MODEL,
|
"model": self.ollama_server_infos.LIGHTRAG_MODEL,
|
||||||
"created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT,
|
"created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT,
|
||||||
"message": {
|
"message": {
|
||||||
|
|
@ -595,103 +650,36 @@ class OllamaAPI:
|
||||||
"content": "",
|
"content": "",
|
||||||
"images": None,
|
"images": None,
|
||||||
},
|
},
|
||||||
"done_reason": "stop",
|
|
||||||
"done": True,
|
"done": True,
|
||||||
"total_duration": total_time,
|
|
||||||
"load_duration": 0,
|
|
||||||
"prompt_eval_count": prompt_tokens,
|
|
||||||
"prompt_eval_duration": prompt_eval_time,
|
|
||||||
"eval_count": completion_tokens,
|
|
||||||
"eval_duration": eval_time,
|
|
||||||
}
|
}
|
||||||
yield f"{json.dumps(data, ensure_ascii=False)}\n"
|
yield f"{json.dumps(final_data, ensure_ascii=False)}\n"
|
||||||
else:
|
return
|
||||||
try:
|
|
||||||
async for chunk in response:
|
|
||||||
if chunk:
|
|
||||||
if first_chunk_time is None:
|
|
||||||
first_chunk_time = time.time_ns()
|
|
||||||
|
|
||||||
last_chunk_time = time.time_ns()
|
if first_chunk_time is None:
|
||||||
|
first_chunk_time = start_time
|
||||||
|
completion_tokens = estimate_tokens(total_response)
|
||||||
|
total_time = last_chunk_time - start_time
|
||||||
|
prompt_eval_time = first_chunk_time - start_time
|
||||||
|
eval_time = last_chunk_time - first_chunk_time
|
||||||
|
|
||||||
total_response += chunk
|
data = {
|
||||||
data = {
|
"model": self.ollama_server_infos.LIGHTRAG_MODEL,
|
||||||
"model": self.ollama_server_infos.LIGHTRAG_MODEL,
|
"created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT,
|
||||||
"created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT,
|
"message": {
|
||||||
"message": {
|
"role": "assistant",
|
||||||
"role": "assistant",
|
"content": "",
|
||||||
"content": chunk,
|
"images": None,
|
||||||
"images": None,
|
},
|
||||||
},
|
"done_reason": "stop",
|
||||||
"done": False,
|
"done": True,
|
||||||
}
|
"total_duration": total_time,
|
||||||
yield f"{json.dumps(data, ensure_ascii=False)}\n"
|
"load_duration": 0,
|
||||||
except (asyncio.CancelledError, Exception) as e:
|
"prompt_eval_count": prompt_tokens,
|
||||||
error_msg = str(e)
|
"prompt_eval_duration": prompt_eval_time,
|
||||||
if isinstance(e, asyncio.CancelledError):
|
"eval_count": completion_tokens,
|
||||||
error_msg = "Stream was cancelled by server"
|
"eval_duration": eval_time,
|
||||||
else:
|
}
|
||||||
error_msg = f"Provider error: {error_msg}"
|
yield f"{json.dumps(data, ensure_ascii=False)}\n"
|
||||||
|
|
||||||
logger.error(f"Stream error: {error_msg}")
|
|
||||||
|
|
||||||
# Send error message to client
|
|
||||||
error_data = {
|
|
||||||
"model": self.ollama_server_infos.LIGHTRAG_MODEL,
|
|
||||||
"created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT,
|
|
||||||
"message": {
|
|
||||||
"role": "assistant",
|
|
||||||
"content": f"\n\nError: {error_msg}",
|
|
||||||
"images": None,
|
|
||||||
},
|
|
||||||
"error": f"\n\nError: {error_msg}",
|
|
||||||
"done": False,
|
|
||||||
}
|
|
||||||
yield f"{json.dumps(error_data, ensure_ascii=False)}\n"
|
|
||||||
|
|
||||||
# Send final message to close the stream
|
|
||||||
final_data = {
|
|
||||||
"model": self.ollama_server_infos.LIGHTRAG_MODEL,
|
|
||||||
"created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT,
|
|
||||||
"message": {
|
|
||||||
"role": "assistant",
|
|
||||||
"content": "",
|
|
||||||
"images": None,
|
|
||||||
},
|
|
||||||
"done": True,
|
|
||||||
}
|
|
||||||
yield f"{json.dumps(final_data, ensure_ascii=False)}\n"
|
|
||||||
return
|
|
||||||
|
|
||||||
if first_chunk_time is None:
|
|
||||||
first_chunk_time = start_time
|
|
||||||
completion_tokens = estimate_tokens(total_response)
|
|
||||||
total_time = last_chunk_time - start_time
|
|
||||||
prompt_eval_time = first_chunk_time - start_time
|
|
||||||
eval_time = last_chunk_time - first_chunk_time
|
|
||||||
|
|
||||||
data = {
|
|
||||||
"model": self.ollama_server_infos.LIGHTRAG_MODEL,
|
|
||||||
"created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT,
|
|
||||||
"message": {
|
|
||||||
"role": "assistant",
|
|
||||||
"content": "",
|
|
||||||
"images": None,
|
|
||||||
},
|
|
||||||
"done_reason": "stop",
|
|
||||||
"done": True,
|
|
||||||
"total_duration": total_time,
|
|
||||||
"load_duration": 0,
|
|
||||||
"prompt_eval_count": prompt_tokens,
|
|
||||||
"prompt_eval_duration": prompt_eval_time,
|
|
||||||
"eval_count": completion_tokens,
|
|
||||||
"eval_duration": eval_time,
|
|
||||||
}
|
|
||||||
yield f"{json.dumps(data, ensure_ascii=False)}\n"
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
trace_exception(e)
|
|
||||||
raise
|
|
||||||
|
|
||||||
return StreamingResponse(
|
return StreamingResponse(
|
||||||
stream_generator(),
|
stream_generator(),
|
||||||
|
|
@ -753,5 +741,4 @@ class OllamaAPI:
|
||||||
"eval_duration": eval_time,
|
"eval_duration": eval_time,
|
||||||
}
|
}
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
trace_exception(e)
|
|
||||||
raise HTTPException(status_code=500, detail=str(e))
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
|
|
|
||||||
|
|
@ -15,8 +15,6 @@ from lightrag.models.tenant import TenantContext
|
||||||
from lightrag.tenant_rag_manager import TenantRAGManager
|
from lightrag.tenant_rag_manager import TenantRAGManager
|
||||||
from pydantic import BaseModel, Field, field_validator
|
from pydantic import BaseModel, Field, field_validator
|
||||||
|
|
||||||
from ascii_colors import trace_exception
|
|
||||||
|
|
||||||
router = APIRouter(tags=["query"])
|
router = APIRouter(tags=["query"])
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -399,7 +397,6 @@ def create_query_routes(rag, api_key: Optional[str] = None, top_k: int = 60, rag
|
||||||
else:
|
else:
|
||||||
return QueryResponse(response=response_content, references=None)
|
return QueryResponse(response=response_content, references=None)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
trace_exception(e)
|
|
||||||
raise HTTPException(status_code=500, detail=str(e))
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
|
|
||||||
@router.post(
|
@router.post(
|
||||||
|
|
@ -650,7 +647,6 @@ def create_query_routes(rag, api_key: Optional[str] = None, top_k: int = 60, rag
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
trace_exception(e)
|
|
||||||
raise HTTPException(status_code=500, detail=str(e))
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
|
|
||||||
@router.post(
|
@router.post(
|
||||||
|
|
@ -1061,7 +1057,6 @@ def create_query_routes(rag, api_key: Optional[str] = None, top_k: int = 60, rag
|
||||||
data={},
|
data={},
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
trace_exception(e)
|
|
||||||
raise HTTPException(status_code=500, detail=str(e))
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
|
|
||||||
return router
|
return router
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,8 @@
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
import weakref
|
import weakref
|
||||||
|
|
||||||
|
import sys
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import html
|
import html
|
||||||
import csv
|
import csv
|
||||||
|
|
@ -40,6 +42,35 @@ from lightrag.constants import (
|
||||||
SOURCE_IDS_LIMIT_METHOD_FIFO,
|
SOURCE_IDS_LIMIT_METHOD_FIFO,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Precompile regex pattern for JSON sanitization (module-level, compiled once)
|
||||||
|
_SURROGATE_PATTERN = re.compile(r"[\uD800-\uDFFF\uFFFE\uFFFF]")
|
||||||
|
|
||||||
|
|
||||||
|
class SafeStreamHandler(logging.StreamHandler):
|
||||||
|
"""StreamHandler that gracefully handles closed streams during shutdown.
|
||||||
|
|
||||||
|
This handler prevents "ValueError: I/O operation on closed file" errors
|
||||||
|
that can occur when pytest or other test frameworks close stdout/stderr
|
||||||
|
before Python's logging cleanup runs.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def flush(self):
|
||||||
|
"""Flush the stream, ignoring errors if the stream is closed."""
|
||||||
|
try:
|
||||||
|
super().flush()
|
||||||
|
except (ValueError, OSError):
|
||||||
|
# Stream is closed or otherwise unavailable, silently ignore
|
||||||
|
pass
|
||||||
|
|
||||||
|
def close(self):
|
||||||
|
"""Close the handler, ignoring errors if the stream is already closed."""
|
||||||
|
try:
|
||||||
|
super().close()
|
||||||
|
except (ValueError, OSError):
|
||||||
|
# Stream is closed or otherwise unavailable, silently ignore
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
# Initialize logger with basic configuration
|
# Initialize logger with basic configuration
|
||||||
logger = logging.getLogger("lightrag")
|
logger = logging.getLogger("lightrag")
|
||||||
logger.propagate = False # prevent log message send to root logger
|
logger.propagate = False # prevent log message send to root logger
|
||||||
|
|
@ -47,7 +78,7 @@ logger.setLevel(logging.INFO)
|
||||||
|
|
||||||
# Add console handler if no handlers exist
|
# Add console handler if no handlers exist
|
||||||
if not logger.handlers:
|
if not logger.handlers:
|
||||||
console_handler = logging.StreamHandler()
|
console_handler = SafeStreamHandler()
|
||||||
console_handler.setLevel(logging.INFO)
|
console_handler.setLevel(logging.INFO)
|
||||||
formatter = logging.Formatter("%(levelname)s: %(message)s")
|
formatter = logging.Formatter("%(levelname)s: %(message)s")
|
||||||
console_handler.setFormatter(formatter)
|
console_handler.setFormatter(formatter)
|
||||||
|
|
@ -56,6 +87,33 @@ if not logger.handlers:
|
||||||
# Set httpx logging level to WARNING
|
# Set httpx logging level to WARNING
|
||||||
logging.getLogger("httpx").setLevel(logging.WARNING)
|
logging.getLogger("httpx").setLevel(logging.WARNING)
|
||||||
|
|
||||||
|
|
||||||
|
def _patch_ascii_colors_console_handler() -> None:
|
||||||
|
"""Prevent ascii_colors from printing flush errors during interpreter exit."""
|
||||||
|
|
||||||
|
try:
|
||||||
|
from ascii_colors import ConsoleHandler
|
||||||
|
except ImportError:
|
||||||
|
return
|
||||||
|
|
||||||
|
if getattr(ConsoleHandler, "_lightrag_patched", False):
|
||||||
|
return
|
||||||
|
|
||||||
|
original_handle_error = ConsoleHandler.handle_error
|
||||||
|
|
||||||
|
def _safe_handle_error(self, message: str) -> None: # type: ignore[override]
|
||||||
|
exc_type, _, _ = sys.exc_info()
|
||||||
|
if exc_type in (ValueError, OSError) and "close" in message.lower():
|
||||||
|
return
|
||||||
|
original_handle_error(self, message)
|
||||||
|
|
||||||
|
ConsoleHandler.handle_error = _safe_handle_error # type: ignore[assignment]
|
||||||
|
ConsoleHandler._lightrag_patched = True # type: ignore[attr-defined]
|
||||||
|
|
||||||
|
|
||||||
|
_patch_ascii_colors_console_handler()
|
||||||
|
|
||||||
|
|
||||||
# Global import for pypinyin with startup-time logging
|
# Global import for pypinyin with startup-time logging
|
||||||
try:
|
try:
|
||||||
import pypinyin
|
import pypinyin
|
||||||
|
|
@ -283,8 +341,8 @@ def setup_logger(
|
||||||
logger_instance.handlers = [] # Clear existing handlers
|
logger_instance.handlers = [] # Clear existing handlers
|
||||||
logger_instance.propagate = False
|
logger_instance.propagate = False
|
||||||
|
|
||||||
# Add console handler
|
# Add console handler with safe stream handling
|
||||||
console_handler = logging.StreamHandler()
|
console_handler = SafeStreamHandler()
|
||||||
console_handler.setFormatter(simple_formatter)
|
console_handler.setFormatter(simple_formatter)
|
||||||
console_handler.setLevel(level)
|
console_handler.setLevel(level)
|
||||||
logger_instance.addHandler(console_handler)
|
logger_instance.addHandler(console_handler)
|
||||||
|
|
@ -350,9 +408,20 @@ class TaskState:
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class EmbeddingFunc:
|
class EmbeddingFunc:
|
||||||
|
"""Embedding function wrapper with dimension validation
|
||||||
|
This class wraps an embedding function to ensure that the output embeddings have the correct dimension.
|
||||||
|
This class should not be wrapped multiple times.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
embedding_dim: Expected dimension of the embeddings
|
||||||
|
func: The actual embedding function to wrap
|
||||||
|
max_token_size: Optional token limit for the embedding model
|
||||||
|
send_dimensions: Whether to inject embedding_dim as a keyword argument
|
||||||
|
"""
|
||||||
|
|
||||||
embedding_dim: int
|
embedding_dim: int
|
||||||
func: callable
|
func: callable
|
||||||
max_token_size: int | None = None # deprecated keep it for compatible only
|
max_token_size: int | None = None # Token limit for the embedding model
|
||||||
send_dimensions: bool = (
|
send_dimensions: bool = (
|
||||||
False # Control whether to send embedding_dim to the function
|
False # Control whether to send embedding_dim to the function
|
||||||
)
|
)
|
||||||
|
|
@ -376,7 +445,32 @@ class EmbeddingFunc:
|
||||||
# Inject embedding_dim from decorator
|
# Inject embedding_dim from decorator
|
||||||
kwargs["embedding_dim"] = self.embedding_dim
|
kwargs["embedding_dim"] = self.embedding_dim
|
||||||
|
|
||||||
return await self.func(*args, **kwargs)
|
# Call the actual embedding function
|
||||||
|
result = await self.func(*args, **kwargs)
|
||||||
|
|
||||||
|
# Validate embedding dimensions using total element count
|
||||||
|
total_elements = result.size # Total number of elements in the numpy array
|
||||||
|
expected_dim = self.embedding_dim
|
||||||
|
|
||||||
|
# Check if total elements can be evenly divided by embedding_dim
|
||||||
|
if total_elements % expected_dim != 0:
|
||||||
|
raise ValueError(
|
||||||
|
f"Embedding dimension mismatch detected: "
|
||||||
|
f"total elements ({total_elements}) cannot be evenly divided by "
|
||||||
|
f"expected dimension ({expected_dim}). "
|
||||||
|
)
|
||||||
|
|
||||||
|
# Optional: Verify vector count matches input text count
|
||||||
|
actual_vectors = total_elements // expected_dim
|
||||||
|
if args and isinstance(args[0], (list, tuple)):
|
||||||
|
expected_vectors = len(args[0])
|
||||||
|
if actual_vectors != expected_vectors:
|
||||||
|
raise ValueError(
|
||||||
|
f"Vector count mismatch: "
|
||||||
|
f"expected {expected_vectors} vectors but got {actual_vectors} vectors (from embedding result)."
|
||||||
|
)
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
def compute_args_hash(*args: Any) -> str:
|
def compute_args_hash(*args: Any) -> str:
|
||||||
|
|
@ -930,70 +1024,24 @@ def load_json(file_name):
|
||||||
def _sanitize_string_for_json(text: str) -> str:
|
def _sanitize_string_for_json(text: str) -> str:
|
||||||
"""Remove characters that cannot be encoded in UTF-8 for JSON serialization.
|
"""Remove characters that cannot be encoded in UTF-8 for JSON serialization.
|
||||||
|
|
||||||
This is a simpler sanitizer specifically for JSON that directly removes
|
Uses regex for optimal performance with zero-copy optimization for clean strings.
|
||||||
problematic characters without attempting to encode first.
|
Fast detection path for clean strings (99% of cases) with efficient removal for dirty strings.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
text: String to sanitize
|
text: String to sanitize
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Sanitized string safe for UTF-8 encoding in JSON
|
Original string if clean (zero-copy), sanitized string if dirty
|
||||||
"""
|
"""
|
||||||
if not text:
|
if not text:
|
||||||
return text
|
return text
|
||||||
|
|
||||||
# Directly filter out problematic characters without pre-validation
|
# Fast path: Check if sanitization is needed using C-level regex search
|
||||||
sanitized = ""
|
if not _SURROGATE_PATTERN.search(text):
|
||||||
for char in text:
|
return text # Zero-copy for clean strings - most common case
|
||||||
code_point = ord(char)
|
|
||||||
# Skip surrogate characters (U+D800 to U+DFFF) - main cause of encoding errors
|
|
||||||
if 0xD800 <= code_point <= 0xDFFF:
|
|
||||||
continue
|
|
||||||
# Skip other non-characters in Unicode
|
|
||||||
elif code_point == 0xFFFE or code_point == 0xFFFF:
|
|
||||||
continue
|
|
||||||
else:
|
|
||||||
sanitized += char
|
|
||||||
|
|
||||||
return sanitized
|
# Slow path: Remove problematic characters using C-level regex substitution
|
||||||
|
return _SURROGATE_PATTERN.sub("", text)
|
||||||
|
|
||||||
def _sanitize_json_data(data: Any) -> Any:
|
|
||||||
"""Recursively sanitize all string values in data structure for safe UTF-8 encoding
|
|
||||||
|
|
||||||
DEPRECATED: This function creates a deep copy of the data which can be memory-intensive.
|
|
||||||
For new code, prefer using write_json with SanitizingJSONEncoder which sanitizes during
|
|
||||||
serialization without creating copies.
|
|
||||||
|
|
||||||
Handles all JSON-serializable types including:
|
|
||||||
- Dictionary keys and values
|
|
||||||
- Lists and tuples (preserves type)
|
|
||||||
- Nested structures
|
|
||||||
- Strings at any level
|
|
||||||
|
|
||||||
Args:
|
|
||||||
data: Data to sanitize (dict, list, tuple, str, or other types)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Sanitized data with all strings cleaned of problematic characters
|
|
||||||
"""
|
|
||||||
if isinstance(data, dict):
|
|
||||||
# Sanitize both keys and values
|
|
||||||
return {
|
|
||||||
_sanitize_string_for_json(k)
|
|
||||||
if isinstance(k, str)
|
|
||||||
else k: _sanitize_json_data(v)
|
|
||||||
for k, v in data.items()
|
|
||||||
}
|
|
||||||
elif isinstance(data, (list, tuple)):
|
|
||||||
# Handle both lists and tuples, preserve original type
|
|
||||||
sanitized = [_sanitize_json_data(item) for item in data]
|
|
||||||
return type(data)(sanitized)
|
|
||||||
elif isinstance(data, str):
|
|
||||||
return _sanitize_string_for_json(data)
|
|
||||||
else:
|
|
||||||
# Numbers, booleans, None, etc. - return as-is
|
|
||||||
return data
|
|
||||||
|
|
||||||
|
|
||||||
class SanitizingJSONEncoder(json.JSONEncoder):
|
class SanitizingJSONEncoder(json.JSONEncoder):
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue