code format

This commit is contained in:
yongtenglei 2025-11-19 11:32:17 +08:00
parent 6de01dc601
commit e202d484ce

View file

@ -14,24 +14,27 @@
# limitations under the License. # limitations under the License.
# #
import re
import base64 import base64
import json import json
import os
import tempfile
import logging import logging
import os
import re
import tempfile
from abc import ABC from abc import ABC
from copy import deepcopy from copy import deepcopy
from io import BytesIO from io import BytesIO
from pathlib import Path from pathlib import Path
from urllib.parse import urljoin from urllib.parse import urljoin
import requests import requests
from openai import OpenAI from openai import OpenAI
from openai.lib.azure import AzureOpenAI from openai.lib.azure import AzureOpenAI
from zhipuai import ZhipuAI from zhipuai import ZhipuAI
from common.token_utils import num_tokens_from_string, total_token_count_from_response
from rag.nlp import is_english from rag.nlp import is_english
from rag.prompts.generator import vision_llm_describe_prompt from rag.prompts.generator import vision_llm_describe_prompt
from common.token_utils import num_tokens_from_string, total_token_count_from_response
class Base(ABC): class Base(ABC):
def __init__(self, **kwargs): def __init__(self, **kwargs):
@ -70,12 +73,7 @@ class Base(ABC):
pmpt = [{"type": "text", "text": text}] pmpt = [{"type": "text", "text": text}]
for img in images: for img in images:
pmpt.append({ pmpt.append({"type": "image_url", "image_url": {"url": img if isinstance(img, str) and img.startswith("data:") else f"data:image/png;base64,{img}"}})
"type": "image_url",
"image_url": {
"url": img if isinstance(img, str) and img.startswith("data:") else f"data:image/png;base64,{img}"
}
})
return pmpt return pmpt
def chat(self, system, history, gen_conf, images=None, **kwargs): def chat(self, system, history, gen_conf, images=None, **kwargs):
@ -128,7 +126,7 @@ class Base(ABC):
try: try:
image.save(buffered, format="JPEG") image.save(buffered, format="JPEG")
except Exception: except Exception:
# reset buffer before saving PNG # reset buffer before saving PNG
buffered.seek(0) buffered.seek(0)
buffered.truncate() buffered.truncate()
image.save(buffered, format="PNG") image.save(buffered, format="PNG")
@ -158,7 +156,7 @@ class Base(ABC):
try: try:
image.save(buffered, format="JPEG") image.save(buffered, format="JPEG")
except Exception: except Exception:
# reset buffer before saving PNG # reset buffer before saving PNG
buffered.seek(0) buffered.seek(0)
buffered.truncate() buffered.truncate()
image.save(buffered, format="PNG") image.save(buffered, format="PNG")
@ -176,18 +174,13 @@ class Base(ABC):
"请用中文详细描述一下图中的内容,比如时间,地点,人物,事情,人物心情等,如果有数据请提取出数据。" "请用中文详细描述一下图中的内容,比如时间,地点,人物,事情,人物心情等,如果有数据请提取出数据。"
if self.lang.lower() == "chinese" if self.lang.lower() == "chinese"
else "Please describe the content of this picture, like where, when, who, what happen. If it has number data, please extract them out.", else "Please describe the content of this picture, like where, when, who, what happen. If it has number data, please extract them out.",
b64 b64,
) ),
} }
] ]
def vision_llm_prompt(self, b64, prompt=None): def vision_llm_prompt(self, b64, prompt=None):
return [ return [{"role": "user", "content": self._image_prompt(prompt if prompt else vision_llm_describe_prompt(), b64)}]
{
"role": "user",
"content": self._image_prompt(prompt if prompt else vision_llm_describe_prompt(), b64)
}
]
class GptV4(Base): class GptV4(Base):
@ -208,7 +201,7 @@ class GptV4(Base):
model=self.model_name, model=self.model_name,
messages=self.prompt(b64), messages=self.prompt(b64),
extra_body=self.extra_body, extra_body=self.extra_body,
unused = None, unused=None,
) )
return res.choices[0].message.content.strip(), total_token_count_from_response(res) return res.choices[0].message.content.strip(), total_token_count_from_response(res)
@ -219,7 +212,7 @@ class GptV4(Base):
messages=self.vision_llm_prompt(b64, prompt), messages=self.vision_llm_prompt(b64, prompt),
extra_body=self.extra_body, extra_body=self.extra_body,
) )
return res.choices[0].message.content.strip(),total_token_count_from_response(res) return res.choices[0].message.content.strip(), total_token_count_from_response(res)
class AzureGptV4(GptV4): class AzureGptV4(GptV4):
@ -324,14 +317,12 @@ class Zhipu4V(GptV4):
self.lang = lang self.lang = lang
Base.__init__(self, **kwargs) Base.__init__(self, **kwargs)
def _clean_conf(self, gen_conf): def _clean_conf(self, gen_conf):
if "max_tokens" in gen_conf: if "max_tokens" in gen_conf:
del gen_conf["max_tokens"] del gen_conf["max_tokens"]
gen_conf = self._clean_conf_plealty(gen_conf) gen_conf = self._clean_conf_plealty(gen_conf)
return gen_conf return gen_conf
def _clean_conf_plealty(self, gen_conf): def _clean_conf_plealty(self, gen_conf):
if "presence_penalty" in gen_conf: if "presence_penalty" in gen_conf:
del gen_conf["presence_penalty"] del gen_conf["presence_penalty"]
@ -339,24 +330,17 @@ class Zhipu4V(GptV4):
del gen_conf["frequency_penalty"] del gen_conf["frequency_penalty"]
return gen_conf return gen_conf
def _request(self, msg, stream, gen_conf={}): def _request(self, msg, stream, gen_conf={}):
response = requests.post( response = requests.post(
self.base_url, self.base_url,
json={ json={"model": self.model_name, "messages": msg, "stream": stream, **gen_conf},
"model": self.model_name, headers={
"messages": msg, "Authorization": f"Bearer {self.api_key}",
"stream": stream, "Content-Type": "application/json",
**gen_conf
}, },
headers= {
"Authorization": f"Bearer {self.api_key}",
"Content-Type": "application/json",
}
) )
return response.json() return response.json()
def chat(self, system, history, gen_conf, images=None, stream=False, **kwargs): def chat(self, system, history, gen_conf, images=None, stream=False, **kwargs):
if system and history and history[0].get("role") != "system": if system and history and history[0].get("role") != "system":
history.insert(0, {"role": "system", "content": system}) history.insert(0, {"role": "system", "content": system})
@ -370,7 +354,6 @@ class Zhipu4V(GptV4):
cleaned = re.sub(r"<\|(begin_of_box|end_of_box)\|>", "", content).strip() cleaned = re.sub(r"<\|(begin_of_box|end_of_box)\|>", "", content).strip()
return cleaned, total_token_count_from_response(response) return cleaned, total_token_count_from_response(response)
def chat_streamly(self, system, history, gen_conf, images=None, **kwargs): def chat_streamly(self, system, history, gen_conf, images=None, **kwargs):
from rag.llm.chat_model import LENGTH_NOTIFICATION_CN, LENGTH_NOTIFICATION_EN from rag.llm.chat_model import LENGTH_NOTIFICATION_CN, LENGTH_NOTIFICATION_EN
from rag.nlp import is_chinese from rag.nlp import is_chinese
@ -402,38 +385,18 @@ class Zhipu4V(GptV4):
yield tk_count yield tk_count
def describe(self, image): def describe(self, image):
return self.describe_with_prompt(image) return self.describe_with_prompt(image)
def describe_with_prompt(self, image, prompt=None): def describe_with_prompt(self, image, prompt=None):
b64 = self.image2base64(image) b64 = self.image2base64(image)
if prompt is None: if prompt is None:
prompt = "Describe this image." prompt = "Describe this image."
# Chat messages # Chat messages
messages = [ messages = [{"role": "user", "content": [{"type": "image_url", "image_url": {"url": b64}}, {"type": "text", "text": prompt}]}]
{
"role": "user",
"content": [
{
"type": "image_url",
"image_url": { "url": b64 }
},
{
"type": "text",
"text": prompt
}
]
}
]
resp = self.client.chat.completions.create( resp = self.client.chat.completions.create(model=self.model_name, messages=messages, stream=False)
model=self.model_name,
messages=messages,
stream=False
)
content = resp.choices[0].message.content.strip() content = resp.choices[0].message.content.strip()
cleaned = re.sub(r"<\|(begin_of_box|end_of_box)\|>", "", content).strip() cleaned = re.sub(r"<\|(begin_of_box|end_of_box)\|>", "", content).strip()
@ -452,6 +415,7 @@ class StepFunCV(GptV4):
self.lang = lang self.lang = lang
Base.__init__(self, **kwargs) Base.__init__(self, **kwargs)
class VolcEngineCV(GptV4): class VolcEngineCV(GptV4):
_FACTORY_NAME = "VolcEngine" _FACTORY_NAME = "VolcEngine"
@ -464,6 +428,7 @@ class VolcEngineCV(GptV4):
self.lang = lang self.lang = lang
Base.__init__(self, **kwargs) Base.__init__(self, **kwargs)
class LmStudioCV(GptV4): class LmStudioCV(GptV4):
_FACTORY_NAME = "LM-Studio" _FACTORY_NAME = "LM-Studio"
@ -502,13 +467,7 @@ class TogetherAICV(GptV4):
class YiCV(GptV4): class YiCV(GptV4):
_FACTORY_NAME = "01.AI" _FACTORY_NAME = "01.AI"
def __init__( def __init__(self, key, model_name, lang="Chinese", base_url="https://api.lingyiwanwu.com/v1", **kwargs):
self,
key,
model_name,
lang="Chinese",
base_url="https://api.lingyiwanwu.com/v1", **kwargs
):
if not base_url: if not base_url:
base_url = "https://api.lingyiwanwu.com/v1" base_url = "https://api.lingyiwanwu.com/v1"
super().__init__(key, model_name, lang, base_url, **kwargs) super().__init__(key, model_name, lang, base_url, **kwargs)
@ -517,13 +476,7 @@ class YiCV(GptV4):
class SILICONFLOWCV(GptV4): class SILICONFLOWCV(GptV4):
_FACTORY_NAME = "SILICONFLOW" _FACTORY_NAME = "SILICONFLOW"
def __init__( def __init__(self, key, model_name, lang="Chinese", base_url="https://api.siliconflow.cn/v1", **kwargs):
self,
key,
model_name,
lang="Chinese",
base_url="https://api.siliconflow.cn/v1", **kwargs
):
if not base_url: if not base_url:
base_url = "https://api.siliconflow.cn/v1" base_url = "https://api.siliconflow.cn/v1"
super().__init__(key, model_name, lang, base_url, **kwargs) super().__init__(key, model_name, lang, base_url, **kwargs)
@ -532,13 +485,7 @@ class SILICONFLOWCV(GptV4):
class OpenRouterCV(GptV4): class OpenRouterCV(GptV4):
_FACTORY_NAME = "OpenRouter" _FACTORY_NAME = "OpenRouter"
def __init__( def __init__(self, key, model_name, lang="Chinese", base_url="https://openrouter.ai/api/v1", **kwargs):
self,
key,
model_name,
lang="Chinese",
base_url="https://openrouter.ai/api/v1", **kwargs
):
if not base_url: if not base_url:
base_url = "https://openrouter.ai/api/v1" base_url = "https://openrouter.ai/api/v1"
api_key = json.loads(key).get("api_key", "") api_key = json.loads(key).get("api_key", "")
@ -549,6 +496,7 @@ class OpenRouterCV(GptV4):
provider_order = json.loads(key).get("provider_order", "") provider_order = json.loads(key).get("provider_order", "")
self.extra_body = {} self.extra_body = {}
if provider_order: if provider_order:
def _to_order_list(x): def _to_order_list(x):
if x is None: if x is None:
return [] return []
@ -557,6 +505,7 @@ class OpenRouterCV(GptV4):
if isinstance(x, (list, tuple)): if isinstance(x, (list, tuple)):
return [str(s).strip() for s in x if str(s).strip()] return [str(s).strip() for s in x if str(s).strip()]
return [] return []
provider_cfg = {} provider_cfg = {}
provider_order = _to_order_list(provider_order) provider_order = _to_order_list(provider_order)
provider_cfg["order"] = provider_order provider_cfg["order"] = provider_order
@ -616,18 +565,18 @@ class OllamaCV(Base):
def __init__(self, key, model_name, lang="Chinese", **kwargs): def __init__(self, key, model_name, lang="Chinese", **kwargs):
from ollama import Client from ollama import Client
self.client = Client(host=kwargs["base_url"]) self.client = Client(host=kwargs["base_url"])
self.model_name = model_name self.model_name = model_name
self.lang = lang self.lang = lang
self.keep_alive = kwargs.get("ollama_keep_alive", int(os.environ.get("OLLAMA_KEEP_ALIVE", -1))) self.keep_alive = kwargs.get("ollama_keep_alive", int(os.environ.get("OLLAMA_KEEP_ALIVE", -1)))
Base.__init__(self, **kwargs) Base.__init__(self, **kwargs)
def _clean_img(self, img): def _clean_img(self, img):
if not isinstance(img, str): if not isinstance(img, str):
return img return img
#remove the header like "data/*;base64," # remove the header like "data/*;base64,"
if img.startswith("data:") and ";base64," in img: if img.startswith("data:") and ";base64," in img:
img = img.split(";base64,")[1] img = img.split(";base64,")[1]
return img return img
@ -687,12 +636,7 @@ class OllamaCV(Base):
def chat(self, system, history, gen_conf, images=None, **kwargs): def chat(self, system, history, gen_conf, images=None, **kwargs):
try: try:
response = self.client.chat( response = self.client.chat(model=self.model_name, messages=self._form_history(system, history, images), options=self._clean_conf(gen_conf), keep_alive=self.keep_alive)
model=self.model_name,
messages=self._form_history(system, history, images),
options=self._clean_conf(gen_conf),
keep_alive=self.keep_alive
)
ans = response["message"]["content"].strip() ans = response["message"]["content"].strip()
return ans, response["eval_count"] + response.get("prompt_eval_count", 0) return ans, response["eval_count"] + response.get("prompt_eval_count", 0)
@ -702,13 +646,7 @@ class OllamaCV(Base):
def chat_streamly(self, system, history, gen_conf, images=None, **kwargs): def chat_streamly(self, system, history, gen_conf, images=None, **kwargs):
ans = "" ans = ""
try: try:
response = self.client.chat( response = self.client.chat(model=self.model_name, messages=self._form_history(system, history, images), stream=True, options=self._clean_conf(gen_conf), keep_alive=self.keep_alive)
model=self.model_name,
messages=self._form_history(system, history, images),
stream=True,
options=self._clean_conf(gen_conf),
keep_alive=self.keep_alive
)
for resp in response: for resp in response:
if resp["done"]: if resp["done"]:
yield resp.get("prompt_eval_count", 0) + resp.get("eval_count", 0) yield resp.get("prompt_eval_count", 0) + resp.get("eval_count", 0)
@ -823,7 +761,6 @@ class GeminiCV(Base):
def describe_with_prompt(self, image, prompt=None): def describe_with_prompt(self, image, prompt=None):
from google.genai import types from google.genai import types
vision_prompt = prompt if prompt else vision_llm_describe_prompt() vision_prompt = prompt if prompt else vision_llm_describe_prompt()
contents = [ contents = [
@ -842,7 +779,6 @@ class GeminiCV(Base):
) )
return res.text, total_token_count_from_response(res) return res.text, total_token_count_from_response(res)
def chat(self, system, history, gen_conf, images=None, video_bytes=None, filename="", **kwargs): def chat(self, system, history, gen_conf, images=None, video_bytes=None, filename="", **kwargs):
if video_bytes: if video_bytes:
try: try:
@ -921,10 +857,7 @@ class GeminiCV(Base):
if video_size_mb <= 20: if video_size_mb <= 20:
response = client.models.generate_content( response = client.models.generate_content(
model="models/gemini-2.5-flash", model="models/gemini-2.5-flash",
contents=types.Content(parts=[ contents=types.Content(parts=[types.Part(inline_data=types.Blob(data=video_bytes, mime_type="video/mp4")), types.Part(text="Please summarize the video in proper sentences.")]),
types.Part(inline_data=types.Blob(data=video_bytes, mime_type="video/mp4")),
types.Part(text="Please summarize the video in proper sentences.")
])
) )
else: else:
logging.info(f"Video size {video_size_mb:.2f}MB exceeds 20MB. Using Files API...") logging.info(f"Video size {video_size_mb:.2f}MB exceeds 20MB. Using Files API...")
@ -934,10 +867,7 @@ class GeminiCV(Base):
tmp_path = Path(tmp.name) tmp_path = Path(tmp.name)
uploaded_file = client.files.upload(file=tmp_path) uploaded_file = client.files.upload(file=tmp_path)
response = client.models.generate_content( response = client.models.generate_content(model="gemini-2.5-flash", contents=[uploaded_file, "Please summarize this video in proper sentences."])
model="gemini-2.5-flash",
contents=[uploaded_file, "Please summarize this video in proper sentences."]
)
summary = response.text or "" summary = response.text or ""
logging.info(f"[GeminiCV] Video summarized: {summary[:32]}...") logging.info(f"[GeminiCV] Video summarized: {summary[:32]}...")
@ -953,13 +883,7 @@ class GeminiCV(Base):
class NvidiaCV(Base): class NvidiaCV(Base):
_FACTORY_NAME = "NVIDIA" _FACTORY_NAME = "NVIDIA"
def __init__( def __init__(self, key, model_name, lang="Chinese", base_url="https://ai.api.nvidia.com/v1/vlm", **kwargs):
self,
key,
model_name,
lang="Chinese",
base_url="https://ai.api.nvidia.com/v1/vlm", **kwargs
):
if not base_url: if not base_url:
base_url = ("https://ai.api.nvidia.com/v1/vlm",) base_url = ("https://ai.api.nvidia.com/v1/vlm",)
self.lang = lang self.lang = lang
@ -1004,9 +928,7 @@ class NvidiaCV(Base):
"content-type": "application/json", "content-type": "application/json",
"Authorization": f"Bearer {self.key}", "Authorization": f"Bearer {self.key}",
}, },
json={ json={"messages": msg, **gen_conf},
"messages": msg, **gen_conf
},
) )
return response.json() return response.json()
@ -1014,18 +936,12 @@ class NvidiaCV(Base):
b64 = self.image2base64(image) b64 = self.image2base64(image)
vision_prompt = self.vision_llm_prompt(b64, prompt) if prompt else self.vision_llm_prompt(b64) vision_prompt = self.vision_llm_prompt(b64, prompt) if prompt else self.vision_llm_prompt(b64)
response = self._request(vision_prompt) response = self._request(vision_prompt)
return ( return (response["choices"][0]["message"]["content"].strip(), total_token_count_from_response(response))
response["choices"][0]["message"]["content"].strip(),
total_token_count_from_response(response)
)
def chat(self, system, history, gen_conf, images=None, **kwargs): def chat(self, system, history, gen_conf, images=None, **kwargs):
try: try:
response = self._request(self._form_history(system, history, images), gen_conf) response = self._request(self._form_history(system, history, images), gen_conf)
return ( return (response["choices"][0]["message"]["content"].strip(), total_token_count_from_response(response))
response["choices"][0]["message"]["content"].strip(),
total_token_count_from_response(response)
)
except Exception as e: except Exception as e:
return "**ERROR**: " + str(e), 0 return "**ERROR**: " + str(e), 0
@ -1034,7 +950,7 @@ class NvidiaCV(Base):
try: try:
response = self._request(self._form_history(system, history, images), gen_conf) response = self._request(self._form_history(system, history, images), gen_conf)
cnt = response["choices"][0]["message"]["content"] cnt = response["choices"][0]["message"]["content"]
total_tokens += total_token_count_from_response(response) total_tokens += total_token_count_from_response(response)
for resp in cnt: for resp in cnt:
yield resp yield resp
except Exception as e: except Exception as e:
@ -1062,14 +978,15 @@ class AnthropicCV(Base):
return text return text
pmpt = [{"type": "text", "text": text}] pmpt = [{"type": "text", "text": text}]
for img in images: for img in images:
pmpt.append({ pmpt.append(
"type": "image", {
"source": { "type": "image",
"type": "base64", "source": {
"media_type": (img.split(":")[1].split(";")[0] if isinstance(img, str) and img[:4] == "data" else "image/png"), "type": "base64",
"data": (img.split(",")[1] if isinstance(img, str) and img[:4] == "data" else img) "media_type": (img.split(":")[1].split(";")[0] if isinstance(img, str) and img[:4] == "data" else "image/png"),
}, "data": (img.split(",")[1] if isinstance(img, str) and img[:4] == "data" else img),
} },
}
) )
return pmpt return pmpt