LightRAG/lightrag/llm/zhipu.py
clssck dd1413f3eb test(lightrag,examples): add prompt accuracy and quality tests
Add comprehensive test suites for prompt evaluation:
- test_prompt_accuracy.py: 365 lines testing prompt extraction accuracy
- test_prompt_quality_deep.py: 672 lines for deep quality analysis
- Refactor prompt.py to consolidate optimized variants (removed prompt_optimized.py)
- Apply ruff formatting and type hints across 30 files
- Update pyrightconfig.json for static type checking
- Modernize reproduce scripts and examples with improved type annotations
- Sync uv.lock dependencies
2025-12-05 16:39:52 +01:00

194 lines
6.6 KiB
Python

import json
import re
from lightrag.utils import verbose_debug
import pipmaster as pm # Pipmaster for dynamic library install

# install specific modules
# Ensure the zhipuai SDK is available so the lazy `from zhipuai import ZhipuAI`
# imports inside the functions below succeed at call time.
if not pm.is_installed('zhipuai'):
    pm.install('zhipuai')

import numpy as np
# Exception types raised by the OpenAI-compatible client; used as retry triggers.
from openai import (
    APIConnectionError,
    APITimeoutError,
    RateLimitError,
)
from tenacity import (
    retry,
    retry_if_exception_type,
    stop_after_attempt,
    wait_exponential,
)
from lightrag.types import GPTKeywordExtractionFormat
from lightrag.utils import (
    logger,
    wrap_embedding_func_with_attrs,
)
@retry(
    stop=stop_after_attempt(3),
    wait=wait_exponential(multiplier=1, min=4, max=10),
    retry=retry_if_exception_type((RateLimitError, APIConnectionError, APITimeoutError)),
)
async def zhipu_complete_if_cache(
    prompt: str | list[dict[str, str]],
    model: str = 'glm-4-flashx',  # The most cost/performance balance model in glm-4 series
    api_key: str | None = None,
    system_prompt: str | None = None,
    history_messages: list[dict[str, str]] | None = None,
    enable_cot: bool = False,
    **kwargs,
) -> str:
    """Send a chat-completion request to ZhipuAI and return the reply text.

    Args:
        prompt: User message content (string or pre-built message payload).
        model: ZhipuAI chat model name.
        api_key: Explicit API key; when omitted the SDK reads ZHIPUAI_API_KEY
            from the environment.
        system_prompt: Optional system message; a conservative default is
            substituted when falsy.
        history_messages: Prior conversation turns inserted before the prompt.
        enable_cot: Ignored — chain-of-thought is not supported by ZhipuAI.
        **kwargs: Extra arguments forwarded to ``chat.completions.create``;
            LightRAG-internal keys are stripped first.

    Returns:
        The content string of the first completion choice.

    Raises:
        ImportError: If the zhipuai SDK is not installed.
    """
    if history_messages is None:
        history_messages = []
    if enable_cot:
        logger.debug('enable_cot=True is not supported for ZhipuAI and will be ignored.')

    # dynamically load ZhipuAI so the module imports cleanly without the SDK
    try:
        from zhipuai import ZhipuAI
    except ImportError as e:
        raise ImportError('Please install zhipuai before initialize zhipuai backend.') from e

    # please set ZHIPUAI_API_KEY in your environment
    # os.environ["ZHIPUAI_API_KEY"]
    client = ZhipuAI(api_key=api_key) if api_key else ZhipuAI()

    # Fall back to a default system prompt. After this assignment the system
    # prompt is always non-empty, so it is unconditionally prepended (the
    # original re-check of `if system_prompt:` was redundant).
    if not system_prompt:
        system_prompt = (
            'You are a helpful assistant. Note that sensitive words in the content should be replaced with ***'
        )
    messages = [{'role': 'system', 'content': system_prompt}]
    messages.extend(history_messages)
    messages.append({'role': 'user', 'content': prompt})

    # Add debug logging
    logger.debug('===== Query Input to LLM =====')
    logger.debug(f'Query: {prompt}')
    verbose_debug(f'System prompt: {system_prompt}')

    # Remove unsupported kwargs (LightRAG-internal, rejected by the API)
    kwargs = {k: v for k, v in kwargs.items() if k not in ['hashing_kv', 'keyword_extraction']}
    response = client.chat.completions.create(model=model, messages=messages, **kwargs)
    return response.choices[0].message.content
def _parse_keyword_json(text: str) -> GPTKeywordExtractionFormat | None:
    """Try to decode *text* as the keyword-extraction JSON payload.

    Returns None when the text is not valid JSON; missing keys default to
    empty lists.
    """
    try:
        data = json.loads(text)
    except json.JSONDecodeError:
        return None
    return GPTKeywordExtractionFormat(
        high_level_keywords=data.get('high_level_keywords', []),
        low_level_keywords=data.get('low_level_keywords', []),
    )


async def zhipu_complete(
    prompt,
    system_prompt=None,
    history_messages=None,
    keyword_extraction=False,
    enable_cot: bool = False,
    **kwargs,
):
    """Complete a prompt with ZhipuAI, optionally extracting keywords.

    When ``keyword_extraction`` is True the model is steered to emit JSON and
    the response is parsed into a ``GPTKeywordExtractionFormat`` (empty on any
    parse or request failure); otherwise the raw response string is returned.
    """
    if history_messages is None:
        history_messages = []
    # Remove keyword_extraction from kwargs if it was passed redundantly
    kwargs.pop('keyword_extraction', None)

    if not keyword_extraction:
        # For non-keyword-extraction, just return the raw response string
        return await zhipu_complete_if_cache(
            prompt=prompt,
            system_prompt=system_prompt,
            history_messages=history_messages,
            enable_cot=enable_cot,
            **kwargs,
        )

    # Add a system prompt to guide the model to return JSON format
    extraction_prompt = """You are a helpful assistant that extracts keywords from text.
Please analyze the content and extract two types of keywords:
1. High-level keywords: Important concepts and main themes
2. Low-level keywords: Specific details and supporting elements
Return your response in this exact JSON format:
{
"high_level_keywords": ["keyword1", "keyword2"],
"low_level_keywords": ["keyword1", "keyword2", "keyword3"]
}
Only return the JSON, no other text."""
    # Combine with existing system prompt if any
    system_prompt = f'{system_prompt}\n\n{extraction_prompt}' if system_prompt else extraction_prompt

    try:
        response = await zhipu_complete_if_cache(
            prompt=prompt,
            system_prompt=system_prompt,
            history_messages=history_messages,
            enable_cot=enable_cot,
            **kwargs,
        )
        # Try direct JSON parsing first
        result = _parse_keyword_json(response)
        if result is not None:
            return result
        # If direct JSON parsing fails, try to extract a JSON object from text
        match = re.search(r'\{[\s\S]*\}', response)
        if match:
            result = _parse_keyword_json(match.group())
            if result is not None:
                return result
        # If all parsing fails, log warning and return empty format
        logger.warning(f'Failed to parse keyword extraction response: {response}')
        return GPTKeywordExtractionFormat(high_level_keywords=[], low_level_keywords=[])
    except Exception as e:
        logger.error(f'Error during keyword extraction: {e!s}')
        return GPTKeywordExtractionFormat(high_level_keywords=[], low_level_keywords=[])
@wrap_embedding_func_with_attrs(embedding_dim=1024)
@retry(
    stop=stop_after_attempt(3),
    wait=wait_exponential(multiplier=1, min=4, max=60),
    retry=retry_if_exception_type((RateLimitError, APIConnectionError, APITimeoutError)),
)
async def zhipu_embedding(
    texts: list[str], model: str = 'embedding-3', api_key: str | None = None, **kwargs
) -> np.ndarray:
    """Embed each text with the ZhipuAI embedding API.

    Every text is sent as its own request; the resulting vectors are stacked
    row-wise into a numpy array. A bare string is accepted and treated as a
    single-element list.
    """
    # dynamically load ZhipuAI so the module imports cleanly without the SDK
    try:
        from zhipuai import ZhipuAI
    except ImportError as e:
        raise ImportError('Please install zhipuai before initialize zhipuai backend.') from e

    # please set ZHIPUAI_API_KEY in your environment
    # os.environ["ZHIPUAI_API_KEY"]
    client = ZhipuAI(api_key=api_key) if api_key else ZhipuAI()

    # Convert single text to list if needed
    if isinstance(texts, str):
        texts = [texts]

    vectors = []
    for item in texts:
        try:
            result = client.embeddings.create(model=model, input=[item], **kwargs)
            vectors.append(result.data[0].embedding)
        except Exception as e:
            raise Exception(f'Error calling ChatGLM Embedding API: {e!s}') from e
    return np.array(vectors)