diff --git a/conf/llm_factories.json b/conf/llm_factories.json index d3b2dcc1c..3c84bd03d 100644 --- a/conf/llm_factories.json +++ b/conf/llm_factories.json @@ -1194,6 +1194,12 @@ "tags": "TEXT EMBEDDING", "max_tokens": 8196, "model_type": "embedding" + }, + { + "llm_name": "jina-embeddings-v4", + "tags": "TEXT EMBEDDING", + "max_tokens": 32768, + "model_type": "embedding" } ] }, diff --git a/rag/llm/embedding_model.py b/rag/llm/embedding_model.py index 7f2f9ee7d..b8d2f158c 100644 --- a/rag/llm/embedding_model.py +++ b/rag/llm/embedding_model.py @@ -357,13 +357,14 @@ class JinaEmbed(Base): self.headers = {"Content-Type": "application/json", "Authorization": f"Bearer {key}"} self.model_name = model_name + # Jina v2/v3/v4 never reaches here def encode(self, texts: list): texts = [truncate(t, 8196) for t in texts] batch_size = 16 ress = [] token_count = 0 for i in range(0, len(texts), batch_size): - data = {"model": self.model_name, "input": texts[i : i + batch_size], "encoding_type": "float"} + data = {"model": self.model_name, "input": texts[i : i + batch_size]} response = requests.post(self.base_url, headers=self.headers, json=data) try: res = response.json() @@ -403,11 +404,28 @@ class JinaMultiVecEmbed(Base): img_b64s = base64.b64encode(text).decode('utf8') input.append({"image": img_b64s}) # base64 encoded image for i in range(0, len(texts), batch_size): - data = {"model": self.model_name, "task": task, "truncate": True, "return_multivector": True, "input": input[i : i + batch_size]} + data = {"model": self.model_name, "input": input[i : i + batch_size]} + if "v4" in self.model_name: + data["return_multivector"] = True + + if "v3" in self.model_name or "v4" in self.model_name: + data['task'] = task + data['truncate'] = True + response = requests.post(self.base_url, headers=self.headers, json=data) try: res = response.json() - ress.extend([d["embeddings"] for d in res["data"]]) + for d in res['data']: + if data.get("return_multivector", False): # v4 + token_embs = np.asarray(d['embeddings'], dtype=np.float32) + chunk_emb = token_embs.mean(axis=0) + + else: + # v2/v3 + chunk_emb = np.asarray(d['embedding'], dtype=np.float32) + + ress.append(chunk_emb) + token_count += self.total_token_count(res) except Exception as _e: log_exception(_e, response)