code format

This commit is contained in:
yongtenglei 2025-11-19 11:32:17 +08:00
parent 6de01dc601
commit e202d484ce

View file

@ -14,24 +14,27 @@
# limitations under the License.
#
import re
import base64
import json
import os
import tempfile
import logging
import os
import re
import tempfile
from abc import ABC
from copy import deepcopy
from io import BytesIO
from pathlib import Path
from urllib.parse import urljoin
import requests
from openai import OpenAI
from openai.lib.azure import AzureOpenAI
from zhipuai import ZhipuAI
from common.token_utils import num_tokens_from_string, total_token_count_from_response
from rag.nlp import is_english
from rag.prompts.generator import vision_llm_describe_prompt
from common.token_utils import num_tokens_from_string, total_token_count_from_response
class Base(ABC):
def __init__(self, **kwargs):
@ -70,12 +73,7 @@ class Base(ABC):
pmpt = [{"type": "text", "text": text}]
for img in images:
pmpt.append({
"type": "image_url",
"image_url": {
"url": img if isinstance(img, str) and img.startswith("data:") else f"data:image/png;base64,{img}"
}
})
pmpt.append({"type": "image_url", "image_url": {"url": img if isinstance(img, str) and img.startswith("data:") else f"data:image/png;base64,{img}"}})
return pmpt
def chat(self, system, history, gen_conf, images=None, **kwargs):
@ -128,7 +126,7 @@ class Base(ABC):
try:
image.save(buffered, format="JPEG")
except Exception:
# reset buffer before saving PNG
# reset buffer before saving PNG
buffered.seek(0)
buffered.truncate()
image.save(buffered, format="PNG")
@ -158,7 +156,7 @@ class Base(ABC):
try:
image.save(buffered, format="JPEG")
except Exception:
# reset buffer before saving PNG
# reset buffer before saving PNG
buffered.seek(0)
buffered.truncate()
image.save(buffered, format="PNG")
@ -176,18 +174,13 @@ class Base(ABC):
"请用中文详细描述一下图中的内容,比如时间,地点,人物,事情,人物心情等,如果有数据请提取出数据。"
if self.lang.lower() == "chinese"
else "Please describe the content of this picture, like where, when, who, what happen. If it has number data, please extract them out.",
b64
)
b64,
),
}
]
def vision_llm_prompt(self, b64, prompt=None):
    """Build a single-turn chat payload asking the model about one image.

    Args:
        b64: Image payload; forwarded unchanged to ``self._image_prompt``.
        prompt: Instruction text. When falsy (``None`` or ``""``), the text
            returned by ``vision_llm_describe_prompt()`` is used instead.

    Returns:
        A one-element list holding a ``{"role": "user", "content": ...}``
        message whose content is whatever ``self._image_prompt`` produces.
    """
    # The flattened source carried this return twice (pre- and post-format);
    # only one is kept so no statement is unreachable.
    text = prompt or vision_llm_describe_prompt()
    return [{"role": "user", "content": self._image_prompt(text, b64)}]
class GptV4(Base):
@ -208,7 +201,7 @@ class GptV4(Base):
model=self.model_name,
messages=self.prompt(b64),
extra_body=self.extra_body,
unused = None,
unused=None,
)
return res.choices[0].message.content.strip(), total_token_count_from_response(res)
@ -219,7 +212,7 @@ class GptV4(Base):
messages=self.vision_llm_prompt(b64, prompt),
extra_body=self.extra_body,
)
return res.choices[0].message.content.strip(),total_token_count_from_response(res)
return res.choices[0].message.content.strip(), total_token_count_from_response(res)
class AzureGptV4(GptV4):
@ -324,14 +317,12 @@ class Zhipu4V(GptV4):
self.lang = lang
Base.__init__(self, **kwargs)
def _clean_conf(self, gen_conf):
    """Remove ``max_tokens`` from ``gen_conf`` and apply the penalty cleanup.

    The mapping is modified in place; the value returned is whatever
    ``self._clean_conf_plealty`` hands back for that same mapping.
    """
    gen_conf.pop("max_tokens", None)
    return self._clean_conf_plealty(gen_conf)
def _clean_conf_plealty(self, gen_conf):
if "presence_penalty" in gen_conf:
del gen_conf["presence_penalty"]
@ -339,24 +330,17 @@ class Zhipu4V(GptV4):
del gen_conf["frequency_penalty"]
return gen_conf
def _request(self, msg, stream, gen_conf=None):
    """POST a chat request to ``self.base_url`` and return the decoded JSON body.

    Args:
        msg: Value sent as the ``messages`` field of the request body.
        stream: Value sent as the ``stream`` field of the request body.
        gen_conf: Optional mapping of extra generation settings merged into
            the request body. ``None`` (the default) means no extra settings.

    Returns:
        The result of ``response.json()``.
    """
    # A ``{}`` default would be one dict shared by every call; use None instead.
    payload = {
        "model": self.model_name,
        "messages": msg,
        "stream": stream,
        **(gen_conf or {}),
    }
    headers = {
        "Authorization": f"Bearer {self.api_key}",
        "Content-Type": "application/json",
    }
    # NOTE(review): no timeout is passed here and the HTTP status is not
    # checked before decoding — confirm that is intended.
    response = requests.post(self.base_url, json=payload, headers=headers)
    return response.json()
def chat(self, system, history, gen_conf, images=None, stream=False, **kwargs):
if system and history and history[0].get("role") != "system":
history.insert(0, {"role": "system", "content": system})
@ -370,7 +354,6 @@ class Zhipu4V(GptV4):
cleaned = re.sub(r"<\|(begin_of_box|end_of_box)\|>", "", content).strip()
return cleaned, total_token_count_from_response(response)
def chat_streamly(self, system, history, gen_conf, images=None, **kwargs):
from rag.llm.chat_model import LENGTH_NOTIFICATION_CN, LENGTH_NOTIFICATION_EN
from rag.nlp import is_chinese
@ -402,38 +385,18 @@ class Zhipu4V(GptV4):
yield tk_count
def describe(self, image):
    """Describe ``image`` by delegating to ``describe_with_prompt`` with no custom prompt.

    Returns whatever ``self.describe_with_prompt(image)`` returns.
    """
    return self.describe_with_prompt(image)
def describe_with_prompt(self, image, prompt=None):
b64 = self.image2base64(image)
if prompt is None:
prompt = "Describe this image."
# Chat messages
messages = [
{
"role": "user",
"content": [
{
"type": "image_url",
"image_url": { "url": b64 }
},
{
"type": "text",
"text": prompt
}
]
}
]
messages = [{"role": "user", "content": [{"type": "image_url", "image_url": {"url": b64}}, {"type": "text", "text": prompt}]}]
resp = self.client.chat.completions.create(
model=self.model_name,
messages=messages,
stream=False
)
resp = self.client.chat.completions.create(model=self.model_name, messages=messages, stream=False)
content = resp.choices[0].message.content.strip()
cleaned = re.sub(r"<\|(begin_of_box|end_of_box)\|>", "", content).strip()
@ -452,6 +415,7 @@ class StepFunCV(GptV4):
self.lang = lang
Base.__init__(self, **kwargs)
class VolcEngineCV(GptV4):
_FACTORY_NAME = "VolcEngine"
@ -464,6 +428,7 @@ class VolcEngineCV(GptV4):
self.lang = lang
Base.__init__(self, **kwargs)
class LmStudioCV(GptV4):
_FACTORY_NAME = "LM-Studio"
@ -502,13 +467,7 @@ class TogetherAICV(GptV4):
class YiCV(GptV4):
    """Vision client registered under the ``01.AI`` factory name.

    All behaviour comes from ``GptV4``; this subclass only supplies the
    default endpoint URL.
    """

    _FACTORY_NAME = "01.AI"

    def __init__(self, key, model_name, lang="Chinese", base_url="https://api.lingyiwanwu.com/v1", **kwargs):
        """Initialise the client, falling back to the default endpoint.

        Args:
            key: Forwarded to ``GptV4.__init__`` as its first argument.
            model_name: Forwarded to ``GptV4.__init__``.
            lang: Forwarded to ``GptV4.__init__``; defaults to ``"Chinese"``.
            base_url: Endpoint URL. Any falsy value (``None``, ``""``) is
                replaced by the default URL.
            **kwargs: Passed through to ``GptV4.__init__``.
        """
        # The signature default only covers an omitted argument; callers may
        # still pass an explicit empty value, so normalise it here as well.
        if not base_url:
            base_url = "https://api.lingyiwanwu.com/v1"
        super().__init__(key, model_name, lang, base_url, **kwargs)
@ -517,13 +476,7 @@ class YiCV(GptV4):
class SILICONFLOWCV(GptV4):
    """Vision client registered under the ``SILICONFLOW`` factory name.

    All behaviour comes from ``GptV4``; this subclass only supplies the
    default endpoint URL.
    """

    _FACTORY_NAME = "SILICONFLOW"

    def __init__(self, key, model_name, lang="Chinese", base_url="https://api.siliconflow.cn/v1", **kwargs):
        """Initialise the client, falling back to the default endpoint.

        Args:
            key: Forwarded to ``GptV4.__init__`` as its first argument.
            model_name: Forwarded to ``GptV4.__init__``.
            lang: Forwarded to ``GptV4.__init__``; defaults to ``"Chinese"``.
            base_url: Endpoint URL. Any falsy value (``None``, ``""``) is
                replaced by the default URL.
            **kwargs: Passed through to ``GptV4.__init__``.
        """
        # The signature default only covers an omitted argument; callers may
        # still pass an explicit empty value, so normalise it here as well.
        if not base_url:
            base_url = "https://api.siliconflow.cn/v1"
        super().__init__(key, model_name, lang, base_url, **kwargs)
@ -532,13 +485,7 @@ class SILICONFLOWCV(GptV4):
class OpenRouterCV(GptV4):
_FACTORY_NAME = "OpenRouter"
def __init__(
self,
key,
model_name,
lang="Chinese",
base_url="https://openrouter.ai/api/v1", **kwargs
):
def __init__(self, key, model_name, lang="Chinese", base_url="https://openrouter.ai/api/v1", **kwargs):
if not base_url:
base_url = "https://openrouter.ai/api/v1"
api_key = json.loads(key).get("api_key", "")
@ -549,6 +496,7 @@ class OpenRouterCV(GptV4):
provider_order = json.loads(key).get("provider_order", "")
self.extra_body = {}
if provider_order:
def _to_order_list(x):
if x is None:
return []
@ -557,6 +505,7 @@ class OpenRouterCV(GptV4):
if isinstance(x, (list, tuple)):
return [str(s).strip() for s in x if str(s).strip()]
return []
provider_cfg = {}
provider_order = _to_order_list(provider_order)
provider_cfg["order"] = provider_order
@ -616,18 +565,18 @@ class OllamaCV(Base):
def __init__(self, key, model_name, lang="Chinese", **kwargs):
    """Create the Ollama client and record model settings.

    Args:
        key: Accepted but not referenced in this method.
        model_name: Stored as ``self.model_name``.
        lang: Stored as ``self.lang``; defaults to ``"Chinese"``.
        **kwargs: Must contain ``base_url`` (used as the client host). May
            contain ``ollama_keep_alive``; when absent, the
            ``OLLAMA_KEEP_ALIVE`` environment variable (default ``-1``) is
            converted to ``int`` and used. All kwargs are also passed to
            ``Base.__init__``.

    Raises:
        KeyError: If ``base_url`` is missing from ``kwargs``.
    """
    # Imported lazily so the module loads without the ollama package installed.
    from ollama import Client

    host = kwargs["base_url"]
    self.client = Client(host=host)
    self.model_name = model_name
    self.lang = lang
    env_keep_alive = int(os.environ.get("OLLAMA_KEEP_ALIVE", -1))
    self.keep_alive = kwargs.get("ollama_keep_alive", env_keep_alive)
    Base.__init__(self, **kwargs)
def _clean_img(self, img):
    """Return the bare base64 payload of a data-URI image string.

    A string that starts with ``data:`` and contains ``;base64,`` has its
    header removed. Any other value, including non-string input, is
    returned unchanged.
    """
    has_data_uri_header = isinstance(img, str) and img.startswith("data:") and ";base64," in img
    if has_data_uri_header:
        # Drop the leading "data:<mime>;base64," header.
        return img.split(";base64,")[1]
    return img
@ -687,12 +636,7 @@ class OllamaCV(Base):
def chat(self, system, history, gen_conf, images=None, **kwargs):
try:
response = self.client.chat(
model=self.model_name,
messages=self._form_history(system, history, images),
options=self._clean_conf(gen_conf),
keep_alive=self.keep_alive
)
response = self.client.chat(model=self.model_name, messages=self._form_history(system, history, images), options=self._clean_conf(gen_conf), keep_alive=self.keep_alive)
ans = response["message"]["content"].strip()
return ans, response["eval_count"] + response.get("prompt_eval_count", 0)
@ -702,13 +646,7 @@ class OllamaCV(Base):
def chat_streamly(self, system, history, gen_conf, images=None, **kwargs):
ans = ""
try:
response = self.client.chat(
model=self.model_name,
messages=self._form_history(system, history, images),
stream=True,
options=self._clean_conf(gen_conf),
keep_alive=self.keep_alive
)
response = self.client.chat(model=self.model_name, messages=self._form_history(system, history, images), stream=True, options=self._clean_conf(gen_conf), keep_alive=self.keep_alive)
for resp in response:
if resp["done"]:
yield resp.get("prompt_eval_count", 0) + resp.get("eval_count", 0)
@ -823,7 +761,6 @@ class GeminiCV(Base):
def describe_with_prompt(self, image, prompt=None):
from google.genai import types
vision_prompt = prompt if prompt else vision_llm_describe_prompt()
contents = [
@ -842,7 +779,6 @@ class GeminiCV(Base):
)
return res.text, total_token_count_from_response(res)
def chat(self, system, history, gen_conf, images=None, video_bytes=None, filename="", **kwargs):
if video_bytes:
try:
@ -921,10 +857,7 @@ class GeminiCV(Base):
if video_size_mb <= 20:
response = client.models.generate_content(
model="models/gemini-2.5-flash",
contents=types.Content(parts=[
types.Part(inline_data=types.Blob(data=video_bytes, mime_type="video/mp4")),
types.Part(text="Please summarize the video in proper sentences.")
])
contents=types.Content(parts=[types.Part(inline_data=types.Blob(data=video_bytes, mime_type="video/mp4")), types.Part(text="Please summarize the video in proper sentences.")]),
)
else:
logging.info(f"Video size {video_size_mb:.2f}MB exceeds 20MB. Using Files API...")
@ -934,10 +867,7 @@ class GeminiCV(Base):
tmp_path = Path(tmp.name)
uploaded_file = client.files.upload(file=tmp_path)
response = client.models.generate_content(
model="gemini-2.5-flash",
contents=[uploaded_file, "Please summarize this video in proper sentences."]
)
response = client.models.generate_content(model="gemini-2.5-flash", contents=[uploaded_file, "Please summarize this video in proper sentences."])
summary = response.text or ""
logging.info(f"[GeminiCV] Video summarized: {summary[:32]}...")
@ -953,13 +883,7 @@ class GeminiCV(Base):
class NvidiaCV(Base):
_FACTORY_NAME = "NVIDIA"
def __init__(
self,
key,
model_name,
lang="Chinese",
base_url="https://ai.api.nvidia.com/v1/vlm", **kwargs
):
def __init__(self, key, model_name, lang="Chinese", base_url="https://ai.api.nvidia.com/v1/vlm", **kwargs):
if not base_url:
base_url = ("https://ai.api.nvidia.com/v1/vlm",)
self.lang = lang
@ -1004,9 +928,7 @@ class NvidiaCV(Base):
"content-type": "application/json",
"Authorization": f"Bearer {self.key}",
},
json={
"messages": msg, **gen_conf
},
json={"messages": msg, **gen_conf},
)
return response.json()
@ -1014,18 +936,12 @@ class NvidiaCV(Base):
b64 = self.image2base64(image)
vision_prompt = self.vision_llm_prompt(b64, prompt) if prompt else self.vision_llm_prompt(b64)
response = self._request(vision_prompt)
return (
response["choices"][0]["message"]["content"].strip(),
total_token_count_from_response(response)
)
return (response["choices"][0]["message"]["content"].strip(), total_token_count_from_response(response))
def chat(self, system, history, gen_conf, images=None, **kwargs):
    """Send one chat request and return the answer text with its token count.

    Args:
        system: Passed to ``self._form_history``.
        history: Passed to ``self._form_history``.
        gen_conf: Generation settings passed to ``self._request``.
        images: Optional images passed to ``self._form_history``.
        **kwargs: Accepted but not used here.

    Returns:
        ``(answer, tokens)`` on success, where ``answer`` is the stripped
        content of the first choice. On any exception, returns
        ``("**ERROR**: <message>", 0)`` instead of raising.
    """
    try:
        response = self._request(self._form_history(system, history, images), gen_conf)
        # The flattened source carried this return twice (pre- and
        # post-format); a single return keeps every statement reachable.
        answer = response["choices"][0]["message"]["content"].strip()
        return answer, total_token_count_from_response(response)
    except Exception as e:
        # Deliberate best-effort: failures are reported to the caller as text.
        return "**ERROR**: " + str(e), 0
@ -1034,7 +950,7 @@ class NvidiaCV(Base):
try:
response = self._request(self._form_history(system, history, images), gen_conf)
cnt = response["choices"][0]["message"]["content"]
total_tokens += total_token_count_from_response(response)
total_tokens += total_token_count_from_response(response)
for resp in cnt:
yield resp
except Exception as e:
@ -1062,14 +978,15 @@ class AnthropicCV(Base):
return text
pmpt = [{"type": "text", "text": text}]
for img in images:
pmpt.append({
"type": "image",
"source": {
"type": "base64",
"media_type": (img.split(":")[1].split(";")[0] if isinstance(img, str) and img[:4] == "data" else "image/png"),
"data": (img.split(",")[1] if isinstance(img, str) and img[:4] == "data" else img)
},
}
pmpt.append(
{
"type": "image",
"source": {
"type": "base64",
"media_type": (img.split(":")[1].split(";")[0] if isinstance(img, str) and img[:4] == "data" else "image/png"),
"data": (img.split(",")[1] if isinstance(img, str) and img[:4] == "data" else img),
},
}
)
return pmpt