add llm API
This commit is contained in:
parent
0866821a89
commit
499a355a90
6 changed files with 114 additions and 11 deletions
|
|
@ -1,10 +1,9 @@
|
||||||
[infiniflow]
|
[infiniflow]
|
||||||
es=http://es01:9200
|
es=http://es01:9200
|
||||||
pgdb_usr=root
|
postgres_user=root
|
||||||
pgdb_pwd=infiniflow_docgpt
|
postgres_password=infiniflow_docgpt
|
||||||
pgdb_host=postgres
|
postgres_host=postgres
|
||||||
pgdb_port=5432
|
postgres_port=5432
|
||||||
minio_host=minio:9000
|
minio_host=minio:9000
|
||||||
minio_usr=infiniflow
|
minio_user=infiniflow
|
||||||
minio_pwd=infiniflow_docgpt
|
minio_password=infiniflow_docgpt
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -25,9 +25,10 @@ class QWen(Base):
|
||||||
from dashscope import Generation
|
from dashscope import Generation
|
||||||
from dashscope.api_entities.dashscope_response import Role
|
from dashscope.api_entities.dashscope_response import Role
|
||||||
# export DASHSCOPE_API_KEY=YOUR_DASHSCOPE_API_KEY
|
# export DASHSCOPE_API_KEY=YOUR_DASHSCOPE_API_KEY
|
||||||
|
history.insert(0, {"role": "system", "content": system})
|
||||||
response = Generation.call(
|
response = Generation.call(
|
||||||
Generation.Models.qwen_turbo,
|
Generation.Models.qwen_turbo,
|
||||||
messages=messages,
|
messages=history,
|
||||||
result_format='message'
|
result_format='message'
|
||||||
)
|
)
|
||||||
if response.status_code == HTTPStatus.OK:
|
if response.status_code == HTTPStatus.OK:
|
||||||
|
|
|
||||||
70
python/llm/cv_model.py
Normal file
70
python/llm/cv_model.py
Normal file
|
|
@ -0,0 +1,70 @@
|
||||||
|
from abc import ABC
|
||||||
|
from openai import OpenAI
|
||||||
|
import os
|
||||||
|
import base64
|
||||||
|
from io import BytesIO
|
||||||
|
|
||||||
|
class Base(ABC):
|
||||||
|
def describe(self, image, max_tokens=300):
|
||||||
|
raise NotImplementedError("Please implement encode method!")
|
||||||
|
|
||||||
|
def image2base64(self, image):
|
||||||
|
if isinstance(image, BytesIO):
|
||||||
|
return base64.b64encode(image.getvalue()).decode("utf-8")
|
||||||
|
buffered = BytesIO()
|
||||||
|
try:
|
||||||
|
image.save(buffered, format="JPEG")
|
||||||
|
except Exception as e:
|
||||||
|
image.save(buffered, format="PNG")
|
||||||
|
return base64.b64encode(buffered.getvalue()).decode("utf-8")
|
||||||
|
|
||||||
|
|
||||||
|
def prompt(self, b64):
|
||||||
|
return [
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": [
|
||||||
|
{
|
||||||
|
"type": "text",
|
||||||
|
"text": "请用中文详细描述一下图中的内容,比如时间,地点,人物,事情,人物心情等。",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "image_url",
|
||||||
|
"image_url": {
|
||||||
|
"url": f"data:image/jpeg;base64,{b64}"
|
||||||
|
},
|
||||||
|
},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
class GptV4(Base):
|
||||||
|
def __init__(self):
|
||||||
|
import openapi
|
||||||
|
openapi.api_key = os.environ["OPENAPI_KEY"]
|
||||||
|
self.client = OpenAI()
|
||||||
|
|
||||||
|
def describe(self, image, max_tokens=300):
|
||||||
|
b64 = self.image2base64(image)
|
||||||
|
|
||||||
|
res = self.client.chat.completions.create(
|
||||||
|
model="gpt-4-vision-preview",
|
||||||
|
messages=self.prompt(b64),
|
||||||
|
max_tokens=max_tokens,
|
||||||
|
)
|
||||||
|
return res.choices[0].message.content.strip()
|
||||||
|
|
||||||
|
|
||||||
|
class QWen(Base):
|
||||||
|
def describe(self, image, max_tokens=300):
|
||||||
|
from http import HTTPStatus
|
||||||
|
from dashscope import MultiModalConversation1`
|
||||||
|
# export DASHSCOPE_API_KEY=YOUR_DASHSCOPE_API_KEY
|
||||||
|
response = MultiModalConversation.call(model=MultiModalConversation.Models.qwen_vl_chat_v1,
|
||||||
|
messages=self.prompt(self.image2base64(image)))
|
||||||
|
)
|
||||||
|
if response.status_code == HTTPStatus.OK:
|
||||||
|
return response.output.choices[0]['message']['content']
|
||||||
|
return response.message
|
||||||
|
|
||||||
|
|
@ -30,3 +30,32 @@ class HuEmbedding(Base):
|
||||||
for i in range(0, len(texts), batch_size):
|
for i in range(0, len(texts), batch_size):
|
||||||
res.extend(self.model.encode(texts[i:i+batch_size]).tolist())
|
res.extend(self.model.encode(texts[i:i+batch_size]).tolist())
|
||||||
return np.array(res)
|
return np.array(res)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class GptEmbed(Base):
|
||||||
|
def __init__(self):
|
||||||
|
import openapi,os
|
||||||
|
from openai import OpenAI
|
||||||
|
openapi.api_key = os.environ["OPENAPI_KEY"]
|
||||||
|
self.client = OpenAI()
|
||||||
|
|
||||||
|
def encode(self, texts: list, batch_size=32):
|
||||||
|
res = self.client.embeddings.create(input = texts,
|
||||||
|
model="text-embedding-ada-002")
|
||||||
|
return [d["embedding"] for d in res["data"]]
|
||||||
|
|
||||||
|
|
||||||
|
class QWen(base):
|
||||||
|
def encode(self, texts: list, batch_size=32, text_type="document"):
|
||||||
|
import dashscope
|
||||||
|
from http import HTTPStatus
|
||||||
|
res = []
|
||||||
|
for txt in texts:
|
||||||
|
resp = dashscope.TextEmbedding.call(
|
||||||
|
model=dashscope.TextEmbedding.Models.text_embedding_v2,
|
||||||
|
input=txt[:2048],
|
||||||
|
text_type=text_type
|
||||||
|
)
|
||||||
|
res.append(resp["output"]["embeddings"][0]["embedding"])
|
||||||
|
return res
|
||||||
|
|
|
||||||
|
|
@ -19,7 +19,11 @@ class Postgres(object):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
try:
|
try:
|
||||||
self.conn = psycopg2.connect(f"dbname={self.dbnm} user={self.config.get('pgdb_usr')} password={self.config.get('pgdb_pwd')} host={self.config.get('pgdb_host')} port={self.config.get('pgdb_port')}")
|
self.conn = psycopg2.connect(f"""dbname={self.dbnm}
|
||||||
|
user={self.config.get('postgres_user')}
|
||||||
|
password={self.config.get('postgres_password')}
|
||||||
|
host={self.config.get('postgres_host')}
|
||||||
|
port={self.config.get('postgres_port')}""")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.error("Fail to connect %s "%self.config.get("pgdb_host") + str(e))
|
logging.error("Fail to connect %s "%self.config.get("pgdb_host") + str(e))
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -18,8 +18,8 @@ class HuMinio(object):
|
||||||
|
|
||||||
try:
|
try:
|
||||||
self.conn = Minio(self.config.get("minio_host"),
|
self.conn = Minio(self.config.get("minio_host"),
|
||||||
access_key=self.config.get("minio_usr"),
|
access_key=self.config.get("minio_user"),
|
||||||
secret_key=self.config.get("minio_pwd"),
|
secret_key=self.config.get("minio_password"),
|
||||||
secure=False
|
secure=False
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue