fix jina embedding
This commit is contained in:
parent
80f6d22d2a
commit
03832b2951
2 changed files with 27 additions and 3 deletions
|
|
@ -1194,6 +1194,12 @@
|
|||
"tags": "TEXT EMBEDDING",
|
||||
"max_tokens": 8196,
|
||||
"model_type": "embedding"
|
||||
},
|
||||
{
|
||||
"llm_name": "jina-embeddings-v4",
|
||||
"tags": "TEXT EMBEDDING",
|
||||
"max_tokens": 32768,
|
||||
"model_type": "embedding"
|
||||
}
|
||||
]
|
||||
},
|
||||
|
|
|
|||
|
|
@ -357,13 +357,14 @@ class JinaEmbed(Base):
|
|||
self.headers = {"Content-Type": "application/json", "Authorization": f"Bearer {key}"}
|
||||
self.model_name = model_name
|
||||
|
||||
# Jina v2/v3/v4 never reaches here
|
||||
def encode(self, texts: list):
|
||||
texts = [truncate(t, 8196) for t in texts]
|
||||
batch_size = 16
|
||||
ress = []
|
||||
token_count = 0
|
||||
for i in range(0, len(texts), batch_size):
|
||||
data = {"model": self.model_name, "input": texts[i : i + batch_size], "encoding_type": "float"}
|
||||
data = {"model": self.model_name, "input": texts[i : i + batch_size]}
|
||||
response = requests.post(self.base_url, headers=self.headers, json=data)
|
||||
try:
|
||||
res = response.json()
|
||||
|
|
@ -403,11 +404,28 @@ class JinaMultiVecEmbed(Base):
|
|||
img_b64s = base64.b64encode(text).decode('utf8')
|
||||
input.append({"image": img_b64s}) # base64 encoded image
|
||||
for i in range(0, len(texts), batch_size):
|
||||
data = {"model": self.model_name, "task": task, "truncate": True, "return_multivector": True, "input": input[i : i + batch_size]}
|
||||
data = {"model": self.model_name, "input": input[i : i + batch_size]}
|
||||
if "v4" in self.model_name:
|
||||
data["return_multivector"] = True
|
||||
|
||||
if "v3" in self.model_name or "v4" in self.model_name:
|
||||
data['task'] = task
|
||||
data['truncate'] = True
|
||||
|
||||
response = requests.post(self.base_url, headers=self.headers, json=data)
|
||||
try:
|
||||
res = response.json()
|
||||
ress.extend([d["embeddings"] for d in res["data"]])
|
||||
for d in res['data']:
|
||||
if data.get("return_multivector", False): # v4
|
||||
token_embs = np.asarray(d['embeddings'], dtype=np.float32)
|
||||
chunk_emb = token_embs.mean(axis=0)
|
||||
|
||||
else:
|
||||
# v2/v3
|
||||
chunk_emb = np.asarray(d['embedding'], dtype=np.float32)
|
||||
|
||||
ress.append(chunk_emb)
|
||||
|
||||
token_count += self.total_token_count(res)
|
||||
except Exception as _e:
|
||||
log_exception(_e, response)
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue