Add extensive test suites for API routes and utilities: - Implement test_search_routes.py (406 lines) for search endpoint validation - Implement test_upload_routes.py (724 lines) for document upload workflows - Implement test_s3_client.py (618 lines) for S3 storage operations - Implement test_citation_utils.py (352 lines) for citation extraction - Implement test_chunking.py (216 lines) for text chunking validation Add S3 storage client implementation: - Create lightrag/storage/s3_client.py with S3 operations - Add storage module initialization with exports - Integrate S3 client with document upload handling Enhance API routes and core functionality: - Add search_routes.py with full-text and graph search endpoints - Add upload_routes.py with multipart document upload support - Update operate.py with bulk operations and health checks - Enhance postgres_impl.py with bulk upsert and parameterized queries - Update lightrag_server.py to register new API routes - Improve utils.py with citation and formatting utilities Update dependencies and configuration: - Add S3 and test dependencies to pyproject.toml - Update docker-compose.test.yml for testing environment - Sync uv.lock with new dependencies Apply code quality improvements across all modified files: - Add type hints to function signatures - Update imports and router initialization - Fix logging and error handling
194 lines
6.6 KiB
Python
194 lines
6.6 KiB
Python
import json
|
|
import re
|
|
|
|
import pipmaster as pm # Pipmaster for dynamic library install
|
|
|
|
from lightrag.utils import verbose_debug
|
|
|
|
# install specific modules
|
|
if not pm.is_installed('zhipuai'):
|
|
pm.install('zhipuai')
|
|
|
|
|
|
import numpy as np
|
|
from openai import (
|
|
APIConnectionError,
|
|
APITimeoutError,
|
|
RateLimitError,
|
|
)
|
|
from tenacity import (
|
|
retry,
|
|
retry_if_exception_type,
|
|
stop_after_attempt,
|
|
wait_exponential,
|
|
)
|
|
|
|
from lightrag.types import GPTKeywordExtractionFormat
|
|
from lightrag.utils import (
|
|
logger,
|
|
wrap_embedding_func_with_attrs,
|
|
)
|
|
|
|
|
|
@retry(
    stop=stop_after_attempt(3),
    wait=wait_exponential(multiplier=1, min=4, max=10),
    # NOTE(review): these exception classes are the ones imported from `openai`
    # at the top of the file; confirm the ZhipuAI SDK actually raises them,
    # otherwise this retry policy never triggers.
    retry=retry_if_exception_type((RateLimitError, APIConnectionError, APITimeoutError)),
)
async def zhipu_complete_if_cache(
    prompt: str | list[dict[str, str]],
    model: str = 'glm-4-flashx',  # The most cost/performance balance model in glm-4 series
    api_key: str | None = None,
    system_prompt: str | None = None,
    history_messages: list[dict[str, str]] | None = None,
    enable_cot: bool = False,
    **kwargs,
) -> str:
    """Send one chat-completion request to ZhipuAI and return the reply content.

    The message list is built as: system prompt, then ``history_messages``,
    then a final user message whose content is ``prompt``.

    The decorator retries the call (at most 3 attempts, exponential wait
    bounded to 4-10 seconds) when one of the listed exception types is raised.

    Args:
        prompt: Content of the final ``user`` message.
        model: Model name forwarded to the SDK.
        api_key: Explicit API key. When falsy, ``ZhipuAI()`` is constructed
            without arguments (presumably reading ``ZHIPUAI_API_KEY`` from the
            environment - confirm against the SDK).
        system_prompt: Content of the ``system`` message. When empty or
            ``None`` a built-in default prompt is used instead.
        history_messages: Earlier messages placed between the system and user
            messages. ``None`` is treated as an empty list.
        enable_cot: Not supported by this backend; a debug message is logged
            and the flag is otherwise ignored.
        **kwargs: Forwarded to ``client.chat.completions.create`` after
            dropping ``hashing_kv`` and ``keyword_extraction``.

    Returns:
        ``response.choices[0].message.content`` of the SDK response.

    Raises:
        ImportError: If the ``zhipuai`` package cannot be imported.
    """
    # Local import, matching the function-scope import style already used
    # below for the SDK itself.
    import asyncio

    if history_messages is None:
        history_messages = []
    if enable_cot:
        logger.debug('enable_cot=True is not supported for ZhipuAI and will be ignored.')

    # dynamically load ZhipuAI
    try:
        from zhipuai import ZhipuAI
    except ImportError as e:
        raise ImportError('Please install zhipuai before initialize zhipuai backend.') from e

    # please set ZHIPUAI_API_KEY in your environment
    client = ZhipuAI(api_key=api_key) if api_key else ZhipuAI()

    if not system_prompt:
        system_prompt = (
            'You are a helpful assistant. Note that sensitive words in the content should be replaced with ***'
        )

    # system_prompt is always non-empty at this point, so the system message
    # is added unconditionally.
    messages = [
        {'role': 'system', 'content': system_prompt},
        *history_messages,
        {'role': 'user', 'content': prompt},
    ]

    # Add debug logging
    logger.debug('===== Query Input to LLM =====')
    logger.debug(f'Query: {prompt}')
    verbose_debug(f'System prompt: {system_prompt}')

    # Remove unsupported kwargs
    kwargs = {k: v for k, v in kwargs.items() if k not in ('hashing_kv', 'keyword_extraction')}

    # The SDK call is synchronous; run it in a worker thread so this coroutine
    # does not block the event loop while waiting for the HTTP response.
    response = await asyncio.to_thread(
        client.chat.completions.create,
        model=model,
        messages=messages,
        **kwargs,
    )

    return response.choices[0].message.content
|
|
|
|
|
|
async def zhipu_complete(
    prompt,
    system_prompt=None,
    history_messages=None,
    keyword_extraction=False,
    enable_cot: bool = False,
    **kwargs,
):
    """Complete ``prompt`` via ``zhipu_complete_if_cache``, optionally as keyword extraction.

    Args:
        prompt: Forwarded unchanged to ``zhipu_complete_if_cache``.
        system_prompt: Optional system prompt. In keyword-extraction mode the
            extraction instructions are appended to it (or replace it when it
            is empty).
        history_messages: Earlier messages; ``None`` is treated as ``[]``.
        keyword_extraction: When truthy, ask the model for a JSON keyword
            object and parse the reply.
        enable_cot: Forwarded unchanged to ``zhipu_complete_if_cache``.
        **kwargs: Forwarded to ``zhipu_complete_if_cache``; any
            ``keyword_extraction`` entry is removed first.

    Returns:
        Without keyword extraction: whatever ``zhipu_complete_if_cache``
        returns. With keyword extraction: a ``GPTKeywordExtractionFormat``
        built from the parsed reply, or one with two empty lists when the
        reply cannot be parsed or any exception occurs.
    """
    history_messages = [] if history_messages is None else history_messages
    # Remove keyword_extraction from kwargs if it was passed redundantly
    kwargs.pop('keyword_extraction', None)

    # Plain completion: hand the raw result straight back to the caller.
    if not keyword_extraction:
        return await zhipu_complete_if_cache(
            prompt=prompt,
            system_prompt=system_prompt,
            history_messages=history_messages,
            enable_cot=enable_cot,
            **kwargs,
        )

    # Instructions steering the model towards a JSON-only answer.
    # The literal's continuation lines carry no leading whitespace; re-indenting
    # them would change the prompt text sent to the model.
    extraction_prompt = """You are a helpful assistant that extracts keywords from text.
Please analyze the content and extract two types of keywords:
1. High-level keywords: Important concepts and main themes
2. Low-level keywords: Specific details and supporting elements

Return your response in this exact JSON format:
{
"high_level_keywords": ["keyword1", "keyword2"],
"low_level_keywords": ["keyword1", "keyword2", "keyword3"]
}

Only return the JSON, no other text."""

    # Combine with existing system prompt if any
    if system_prompt:
        system_prompt = f'{system_prompt}\n\n{extraction_prompt}'
    else:
        system_prompt = extraction_prompt

    # Unique marker for "no JSON could be decoded"; distinct from every value
    # json.loads can return (including None).
    not_decoded = object()

    try:
        response = await zhipu_complete_if_cache(
            prompt=prompt,
            system_prompt=system_prompt,
            history_messages=history_messages,
            enable_cot=enable_cot,
            **kwargs,
        )

        try:
            # First attempt: the whole reply is JSON.
            payload = json.loads(response)
        except json.JSONDecodeError:
            # Second attempt: decode the span from the first '{' to the last '}'.
            payload = not_decoded
            embedded = re.search(r'\{[\s\S]*\}', response)
            if embedded is not None:
                try:
                    payload = json.loads(embedded.group())
                except json.JSONDecodeError:
                    payload = not_decoded

        if payload is not_decoded:
            # If all parsing fails, log warning and return empty format
            logger.warning(f'Failed to parse keyword extraction response: {response}')
            return GPTKeywordExtractionFormat(high_level_keywords=[], low_level_keywords=[])

        return GPTKeywordExtractionFormat(
            high_level_keywords=payload.get('high_level_keywords', []),
            low_level_keywords=payload.get('low_level_keywords', []),
        )
    except Exception as e:
        # Best-effort mode: any failure is logged and yields an empty result.
        logger.error(f'Error during keyword extraction: {e!s}')
        return GPTKeywordExtractionFormat(high_level_keywords=[], low_level_keywords=[])
|
|
|
|
|
|
@wrap_embedding_func_with_attrs(embedding_dim=1024)
@retry(
    stop=stop_after_attempt(3),
    wait=wait_exponential(multiplier=1, min=4, max=60),
    # NOTE(review): these exception classes are the ones imported from `openai`
    # at the top of the file; confirm the ZhipuAI SDK actually raises them.
    retry=retry_if_exception_type((RateLimitError, APIConnectionError, APITimeoutError)),
)
async def zhipu_embedding(
    texts: list[str], model: str = 'embedding-3', api_key: str | None = None, **kwargs
) -> np.ndarray:
    """Embed each text with the ZhipuAI embeddings endpoint.

    One request is issued per text, sequentially, in input order.

    The decorator retries the whole call (at most 3 attempts, exponential wait
    bounded to 4-60 seconds) when one of the listed exception types is raised.

    Args:
        texts: Texts to embed. A single ``str`` is accepted and treated as a
            one-element list.
        model: Embedding model name forwarded to the SDK.
        api_key: Explicit API key. When falsy, ``ZhipuAI()`` is constructed
            without arguments (presumably reading ``ZHIPUAI_API_KEY`` from the
            environment - confirm against the SDK).
        **kwargs: Forwarded to ``client.embeddings.create``.

    Returns:
        ``np.array`` built from the collected ``response.data[0].embedding``
        values, one entry per input text. The decorator declares
        ``embedding_dim=1024``; the width is not checked here.

    Raises:
        ImportError: If the ``zhipuai`` package cannot be imported.
        RateLimitError, APIConnectionError, APITimeoutError: Propagated
            unwrapped so the retry decorator can act on them.
        Exception: Any other failure of the embeddings call, wrapped with a
            descriptive message and chained to the original error.
    """
    # Local import, matching the function-scope import style already used
    # below for the SDK itself.
    import asyncio

    # dynamically load ZhipuAI
    try:
        from zhipuai import ZhipuAI
    except ImportError as e:
        raise ImportError('Please install zhipuai before initialize zhipuai backend.') from e

    # please set ZHIPUAI_API_KEY in your environment
    client = ZhipuAI(api_key=api_key) if api_key else ZhipuAI()

    # Convert single text to list if needed
    if isinstance(texts, str):
        texts = [texts]

    embeddings = []
    for text in texts:
        try:
            # The SDK call is synchronous; run it in a worker thread so the
            # event loop stays responsive during the request.
            response = await asyncio.to_thread(
                client.embeddings.create,
                model=model,
                input=[text],
                **kwargs,
            )
            embeddings.append(response.data[0].embedding)
        except (RateLimitError, APIConnectionError, APITimeoutError):
            # Re-raise untouched: wrapping these in a bare Exception would
            # hide them from the retry predicate above.
            raise
        except Exception as e:
            raise Exception(f'Error calling ChatGLM Embedding API: {e!s}') from e

    return np.array(embeddings)
|