Merge branch 'main' into release-update-0.1.50
This commit is contained in:
commit
df9234d0ee
19 changed files with 870 additions and 338 deletions
|
|
@ -4,12 +4,11 @@ import copy
|
||||||
import json
|
import json
|
||||||
import time
|
import time
|
||||||
import uuid
|
import uuid
|
||||||
from typing import Any, List, Optional
|
|
||||||
|
|
||||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
from opensearchpy import OpenSearch, helpers
|
from opensearchpy import OpenSearch, helpers
|
||||||
from opensearchpy.exceptions import RequestError
|
from opensearchpy.exceptions import OpenSearchException, RequestError
|
||||||
|
|
||||||
from lfx.base.vectorstores.model import LCVectorStoreComponent, check_cached_vector_store
|
from lfx.base.vectorstores.model import LCVectorStoreComponent, check_cached_vector_store
|
||||||
from lfx.base.vectorstores.vector_store_connection_decorator import vector_store_connection
|
from lfx.base.vectorstores.vector_store_connection_decorator import vector_store_connection
|
||||||
|
|
@ -50,11 +49,12 @@ def get_embedding_field_name(model_name: str) -> str:
|
||||||
Returns:
|
Returns:
|
||||||
Field name in format: chunk_embedding_{normalized_model_name}
|
Field name in format: chunk_embedding_{normalized_model_name}
|
||||||
"""
|
"""
|
||||||
|
logger.info(f"chunk_embedding_{normalize_model_name(model_name)}")
|
||||||
return f"chunk_embedding_{normalize_model_name(model_name)}"
|
return f"chunk_embedding_{normalize_model_name(model_name)}"
|
||||||
|
|
||||||
|
|
||||||
@vector_store_connection
|
@vector_store_connection
|
||||||
class OpenSearchVectorStoreComponent(LCVectorStoreComponent):
|
class OpenSearchVectorStoreComponentMultimodalMultiEmbedding(LCVectorStoreComponent):
|
||||||
"""OpenSearch Vector Store Component with Multi-Model Hybrid Search Capabilities.
|
"""OpenSearch Vector Store Component with Multi-Model Hybrid Search Capabilities.
|
||||||
|
|
||||||
This component provides vector storage and retrieval using OpenSearch, combining semantic
|
This component provides vector storage and retrieval using OpenSearch, combining semantic
|
||||||
|
|
@ -73,9 +73,15 @@ class OpenSearchVectorStoreComponent(LCVectorStoreComponent):
|
||||||
- Parallel query embedding generation for all detected models
|
- Parallel query embedding generation for all detected models
|
||||||
- Vector storage with configurable engines (jvector, nmslib, faiss, lucene)
|
- Vector storage with configurable engines (jvector, nmslib, faiss, lucene)
|
||||||
- Flexible authentication (Basic auth, JWT tokens)
|
- Flexible authentication (Basic auth, JWT tokens)
|
||||||
|
|
||||||
|
Model Name Resolution:
|
||||||
|
- Priority: deployment > model > model_name attributes
|
||||||
|
- This ensures correct matching between embedding objects and index fields
|
||||||
|
- When multiple embeddings are provided, specify embedding_model_name to select which one to use
|
||||||
|
- During search, each detected model in the index is matched to its corresponding embedding object
|
||||||
"""
|
"""
|
||||||
|
|
||||||
display_name: str = "OpenSearch (Multi-Model)"
|
display_name: str = "OpenSearch (Multi-Model Multi-Embedding)"
|
||||||
icon: str = "OpenSearch"
|
icon: str = "OpenSearch"
|
||||||
description: str = (
|
description: str = (
|
||||||
"Store and search documents using OpenSearch with multi-model hybrid semantic and keyword search."
|
"Store and search documents using OpenSearch with multi-model hybrid semantic and keyword search."
|
||||||
|
|
@ -130,7 +136,7 @@ class OpenSearchVectorStoreComponent(LCVectorStoreComponent):
|
||||||
},
|
},
|
||||||
],
|
],
|
||||||
value=[],
|
value=[],
|
||||||
input_types=["Data"]
|
input_types=["Data"],
|
||||||
),
|
),
|
||||||
StrInput(
|
StrInput(
|
||||||
name="opensearch_url",
|
name="opensearch_url",
|
||||||
|
|
@ -203,16 +209,19 @@ class OpenSearchVectorStoreComponent(LCVectorStoreComponent):
|
||||||
advanced=True,
|
advanced=True,
|
||||||
),
|
),
|
||||||
*LCVectorStoreComponent.inputs, # includes search_query, add_documents, etc.
|
*LCVectorStoreComponent.inputs, # includes search_query, add_documents, etc.
|
||||||
HandleInput(name="embedding", display_name="Embedding", input_types=["Embeddings"]),
|
HandleInput(name="embedding", display_name="Embedding", input_types=["Embeddings"], is_list=True),
|
||||||
StrInput(
|
StrInput(
|
||||||
name="embedding_model_name",
|
name="embedding_model_name",
|
||||||
display_name="Embedding Model Name",
|
display_name="Embedding Model Name",
|
||||||
value="",
|
value="",
|
||||||
info=(
|
info=(
|
||||||
"Name of the embedding model being used (e.g., 'text-embedding-3-small'). "
|
"Name of the embedding model to use for ingestion. This selects which embedding from the list "
|
||||||
"Used to create dynamic vector field names and track which model embedded each document. "
|
"will be used to embed documents. Matches on deployment, model, model_id, or model_name. "
|
||||||
"Auto-detected from embedding component if not specified."
|
"For duplicate deployments, use combined format: 'deployment:model' "
|
||||||
|
"(e.g., 'text-embedding-ada-002:text-embedding-3-large'). "
|
||||||
|
"Leave empty to use the first embedding. Error message will show all available identifiers."
|
||||||
),
|
),
|
||||||
|
advanced=False,
|
||||||
),
|
),
|
||||||
StrInput(
|
StrInput(
|
||||||
name="vector_field",
|
name="vector_field",
|
||||||
|
|
@ -265,20 +274,20 @@ class OpenSearchVectorStoreComponent(LCVectorStoreComponent):
|
||||||
name="username",
|
name="username",
|
||||||
display_name="Username",
|
display_name="Username",
|
||||||
value="admin",
|
value="admin",
|
||||||
show=False,
|
show=True,
|
||||||
),
|
),
|
||||||
SecretStrInput(
|
SecretStrInput(
|
||||||
name="password",
|
name="password",
|
||||||
display_name="OpenSearch Password",
|
display_name="OpenSearch Password",
|
||||||
value="admin",
|
value="admin",
|
||||||
show=False,
|
show=True,
|
||||||
),
|
),
|
||||||
SecretStrInput(
|
SecretStrInput(
|
||||||
name="jwt_token",
|
name="jwt_token",
|
||||||
display_name="JWT Token",
|
display_name="JWT Token",
|
||||||
value="JWT",
|
value="JWT",
|
||||||
load_from_db=False,
|
load_from_db=False,
|
||||||
show=True,
|
show=False,
|
||||||
info=(
|
info=(
|
||||||
"Valid JSON Web Token for authentication. "
|
"Valid JSON Web Token for authentication. "
|
||||||
"Will be sent in the Authorization header (with optional 'Bearer ' prefix)."
|
"Will be sent in the Authorization header (with optional 'Bearer ' prefix)."
|
||||||
|
|
@ -318,9 +327,16 @@ class OpenSearchVectorStoreComponent(LCVectorStoreComponent):
|
||||||
),
|
),
|
||||||
]
|
]
|
||||||
|
|
||||||
def _get_embedding_model_name(self) -> str:
|
def _get_embedding_model_name(self, embedding_obj=None) -> str:
|
||||||
"""Get the embedding model name from component config or embedding object.
|
"""Get the embedding model name from component config or embedding object.
|
||||||
|
|
||||||
|
Priority: deployment > model > model_id > model_name
|
||||||
|
This ensures we use the actual model being deployed, not just the configured model.
|
||||||
|
Supports multiple embedding providers (OpenAI, Watsonx, Cohere, etc.)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
embedding_obj: Specific embedding object to get name from (optional)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Embedding model name
|
Embedding model name
|
||||||
|
|
||||||
|
|
@ -331,17 +347,46 @@ class OpenSearchVectorStoreComponent(LCVectorStoreComponent):
|
||||||
if hasattr(self, "embedding_model_name") and self.embedding_model_name:
|
if hasattr(self, "embedding_model_name") and self.embedding_model_name:
|
||||||
return self.embedding_model_name.strip()
|
return self.embedding_model_name.strip()
|
||||||
|
|
||||||
# Try to get from embedding component
|
# Try to get from provided embedding object
|
||||||
|
if embedding_obj:
|
||||||
|
# Priority: deployment > model > model_id > model_name
|
||||||
|
if hasattr(embedding_obj, "deployment") and embedding_obj.deployment:
|
||||||
|
return str(embedding_obj.deployment)
|
||||||
|
if hasattr(embedding_obj, "model") and embedding_obj.model:
|
||||||
|
return str(embedding_obj.model)
|
||||||
|
if hasattr(embedding_obj, "model_id") and embedding_obj.model_id:
|
||||||
|
return str(embedding_obj.model_id)
|
||||||
|
if hasattr(embedding_obj, "model_name") and embedding_obj.model_name:
|
||||||
|
return str(embedding_obj.model_name)
|
||||||
|
|
||||||
|
# Try to get from embedding component (legacy single embedding)
|
||||||
if hasattr(self, "embedding") and self.embedding:
|
if hasattr(self, "embedding") and self.embedding:
|
||||||
if hasattr(self.embedding, "model"):
|
# Handle list of embeddings
|
||||||
return str(self.embedding.model)
|
if isinstance(self.embedding, list) and len(self.embedding) > 0:
|
||||||
if hasattr(self.embedding, "model_name"):
|
first_emb = self.embedding[0]
|
||||||
return str(self.embedding.model_name)
|
if hasattr(first_emb, "deployment") and first_emb.deployment:
|
||||||
|
return str(first_emb.deployment)
|
||||||
|
if hasattr(first_emb, "model") and first_emb.model:
|
||||||
|
return str(first_emb.model)
|
||||||
|
if hasattr(first_emb, "model_id") and first_emb.model_id:
|
||||||
|
return str(first_emb.model_id)
|
||||||
|
if hasattr(first_emb, "model_name") and first_emb.model_name:
|
||||||
|
return str(first_emb.model_name)
|
||||||
|
# Handle single embedding
|
||||||
|
elif not isinstance(self.embedding, list):
|
||||||
|
if hasattr(self.embedding, "deployment") and self.embedding.deployment:
|
||||||
|
return str(self.embedding.deployment)
|
||||||
|
if hasattr(self.embedding, "model") and self.embedding.model:
|
||||||
|
return str(self.embedding.model)
|
||||||
|
if hasattr(self.embedding, "model_id") and self.embedding.model_id:
|
||||||
|
return str(self.embedding.model_id)
|
||||||
|
if hasattr(self.embedding, "model_name") and self.embedding.model_name:
|
||||||
|
return str(self.embedding.model_name)
|
||||||
|
|
||||||
msg = (
|
msg = (
|
||||||
"Could not determine embedding model name. "
|
"Could not determine embedding model name. "
|
||||||
"Please set the 'embedding_model_name' field or ensure the embedding component "
|
"Please set the 'embedding_model_name' field or ensure the embedding component "
|
||||||
"has a 'model' or 'model_name' attribute."
|
"has a 'deployment', 'model', 'model_id', or 'model_name' attribute."
|
||||||
)
|
)
|
||||||
raise ValueError(msg)
|
raise ValueError(msg)
|
||||||
|
|
||||||
|
|
@ -434,12 +479,8 @@ class OpenSearchVectorStoreComponent(LCVectorStoreComponent):
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
# Also ensure the embedding_model tracking field exists as keyword
|
# Also ensure the embedding_model tracking field exists as keyword
|
||||||
"embedding_model": {
|
"embedding_model": {"type": "keyword"},
|
||||||
"type": "keyword"
|
"embedding_dimensions": {"type": "integer"},
|
||||||
},
|
|
||||||
"embedding_dimensions": {
|
|
||||||
"type": "integer"
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
client.indices.put_mapping(index=index_name, body=mapping)
|
client.indices.put_mapping(index=index_name, body=mapping)
|
||||||
|
|
@ -450,9 +491,9 @@ class OpenSearchVectorStoreComponent(LCVectorStoreComponent):
|
||||||
|
|
||||||
properties = self._get_index_properties(client)
|
properties = self._get_index_properties(client)
|
||||||
if not self._is_knn_vector_field(properties, field_name):
|
if not self._is_knn_vector_field(properties, field_name):
|
||||||
raise ValueError(
|
msg = f"Field '{field_name}' is not mapped as knn_vector. Current mapping: {properties.get(field_name)}"
|
||||||
f"Field '{field_name}' is not mapped as knn_vector. Current mapping: {properties.get(field_name)}"
|
logger.aerror(msg)
|
||||||
)
|
raise ValueError(msg)
|
||||||
|
|
||||||
def _validate_aoss_with_engines(self, *, is_aoss: bool, engine: str) -> None:
|
def _validate_aoss_with_engines(self, *, is_aoss: bool, engine: str) -> None:
|
||||||
"""Validate engine compatibility with Amazon OpenSearch Serverless (AOSS).
|
"""Validate engine compatibility with Amazon OpenSearch Serverless (AOSS).
|
||||||
|
|
@ -600,8 +641,15 @@ class OpenSearchVectorStoreComponent(LCVectorStoreComponent):
|
||||||
@check_cached_vector_store
|
@check_cached_vector_store
|
||||||
def build_vector_store(self) -> OpenSearch:
|
def build_vector_store(self) -> OpenSearch:
|
||||||
# Return raw OpenSearch client as our "vector store."
|
# Return raw OpenSearch client as our "vector store."
|
||||||
self.log(self.ingest_data)
|
|
||||||
client = self.build_client()
|
client = self.build_client()
|
||||||
|
|
||||||
|
# Check if we're in ingestion-only mode (no search query)
|
||||||
|
has_search_query = bool((self.search_query or "").strip())
|
||||||
|
if not has_search_query:
|
||||||
|
logger.debug("Ingestion-only mode activated: search operations will be skipped")
|
||||||
|
logger.debug("Starting ingestion mode...")
|
||||||
|
|
||||||
|
logger.warning(f"Embedding: {self.embedding}")
|
||||||
self._add_documents_to_vector_store(client=client)
|
self._add_documents_to_vector_store(client=client)
|
||||||
return client
|
return client
|
||||||
|
|
||||||
|
|
@ -611,33 +659,185 @@ class OpenSearchVectorStoreComponent(LCVectorStoreComponent):
|
||||||
|
|
||||||
This method handles the complete document ingestion pipeline:
|
This method handles the complete document ingestion pipeline:
|
||||||
- Prepares document data and metadata
|
- Prepares document data and metadata
|
||||||
- Generates vector embeddings
|
- Generates vector embeddings using the selected model
|
||||||
- Creates appropriate index mappings with dynamic field names
|
- Creates appropriate index mappings with dynamic field names
|
||||||
- Bulk inserts documents with vectors and model tracking
|
- Bulk inserts documents with vectors and model tracking
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
client: OpenSearch client for performing operations
|
client: OpenSearch client for performing operations
|
||||||
"""
|
"""
|
||||||
|
logger.debug("[INGESTION] _add_documents_to_vector_store called")
|
||||||
# Convert DataFrame to Data if needed using parent's method
|
# Convert DataFrame to Data if needed using parent's method
|
||||||
self.ingest_data = self._prepare_ingest_data()
|
self.ingest_data = self._prepare_ingest_data()
|
||||||
|
|
||||||
|
logger.debug(
|
||||||
|
f"[INGESTION] ingest_data type: "
|
||||||
|
f"{type(self.ingest_data)}, length: {len(self.ingest_data) if self.ingest_data else 0}"
|
||||||
|
)
|
||||||
|
logger.debug(
|
||||||
|
f"[INGESTION] ingest_data content: "
|
||||||
|
f"{self.ingest_data[:2] if self.ingest_data and len(self.ingest_data) > 0 else 'empty'}"
|
||||||
|
)
|
||||||
|
|
||||||
docs = self.ingest_data or []
|
docs = self.ingest_data or []
|
||||||
if not docs:
|
if not docs:
|
||||||
self.log("No documents to ingest.")
|
logger.debug("Ingestion complete: No documents provided")
|
||||||
return
|
return
|
||||||
|
|
||||||
# Get embedding model name
|
if not self.embedding:
|
||||||
embedding_model = self._get_embedding_model_name()
|
msg = "Embedding handle is required to embed documents."
|
||||||
|
raise ValueError(msg)
|
||||||
|
|
||||||
|
# Normalize embedding to list first
|
||||||
|
embeddings_list = self.embedding if isinstance(self.embedding, list) else [self.embedding]
|
||||||
|
|
||||||
|
# Filter out None values (fail-safe mode) - do this BEFORE checking if empty
|
||||||
|
embeddings_list = [e for e in embeddings_list if e is not None]
|
||||||
|
|
||||||
|
# NOW check if we have any valid embeddings left after filtering
|
||||||
|
if not embeddings_list:
|
||||||
|
logger.warning("All embeddings returned None (fail-safe mode enabled). Skipping document ingestion.")
|
||||||
|
self.log("Embedding returned None (fail-safe mode enabled). Skipping document ingestion.")
|
||||||
|
return
|
||||||
|
|
||||||
|
logger.debug(f"[INGESTION] Valid embeddings after filtering: {len(embeddings_list)}")
|
||||||
|
self.log(f"Available embedding models: {len(embeddings_list)}")
|
||||||
|
|
||||||
|
# Select the embedding to use for ingestion
|
||||||
|
selected_embedding = None
|
||||||
|
embedding_model = None
|
||||||
|
|
||||||
|
# If embedding_model_name is specified, find matching embedding
|
||||||
|
if hasattr(self, "embedding_model_name") and self.embedding_model_name and self.embedding_model_name.strip():
|
||||||
|
target_model_name = self.embedding_model_name.strip()
|
||||||
|
self.log(f"Looking for embedding model: {target_model_name}")
|
||||||
|
|
||||||
|
for emb_obj in embeddings_list:
|
||||||
|
# Check all possible model identifiers (deployment, model, model_id, model_name)
|
||||||
|
# Also check available_models list from EmbeddingsWithModels
|
||||||
|
possible_names = []
|
||||||
|
deployment = getattr(emb_obj, "deployment", None)
|
||||||
|
model = getattr(emb_obj, "model", None)
|
||||||
|
model_id = getattr(emb_obj, "model_id", None)
|
||||||
|
model_name = getattr(emb_obj, "model_name", None)
|
||||||
|
available_models_attr = getattr(emb_obj, "available_models", None)
|
||||||
|
|
||||||
|
if deployment:
|
||||||
|
possible_names.append(str(deployment))
|
||||||
|
if model:
|
||||||
|
possible_names.append(str(model))
|
||||||
|
if model_id:
|
||||||
|
possible_names.append(str(model_id))
|
||||||
|
if model_name:
|
||||||
|
possible_names.append(str(model_name))
|
||||||
|
|
||||||
|
# Also add combined identifier
|
||||||
|
if deployment and model and deployment != model:
|
||||||
|
possible_names.append(f"{deployment}:{model}")
|
||||||
|
|
||||||
|
# Add all models from available_models dict
|
||||||
|
if available_models_attr and isinstance(available_models_attr, dict):
|
||||||
|
possible_names.extend(
|
||||||
|
str(model_key).strip()
|
||||||
|
for model_key in available_models_attr
|
||||||
|
if model_key and str(model_key).strip()
|
||||||
|
)
|
||||||
|
|
||||||
|
# Match if target matches any of the possible names
|
||||||
|
if target_model_name in possible_names:
|
||||||
|
# Check if target is in available_models dict - use dedicated instance
|
||||||
|
if (
|
||||||
|
available_models_attr
|
||||||
|
and isinstance(available_models_attr, dict)
|
||||||
|
and target_model_name in available_models_attr
|
||||||
|
):
|
||||||
|
# Use the dedicated embedding instance from the dict
|
||||||
|
selected_embedding = available_models_attr[target_model_name]
|
||||||
|
embedding_model = target_model_name
|
||||||
|
self.log(f"Found dedicated embedding instance for '{embedding_model}' in available_models dict")
|
||||||
|
else:
|
||||||
|
# Traditional identifier match
|
||||||
|
selected_embedding = emb_obj
|
||||||
|
embedding_model = self._get_embedding_model_name(emb_obj)
|
||||||
|
self.log(f"Found matching embedding model: {embedding_model} (matched on: {target_model_name})")
|
||||||
|
break
|
||||||
|
|
||||||
|
if not selected_embedding:
|
||||||
|
# Build detailed list of available embeddings with all their identifiers
|
||||||
|
available_info = []
|
||||||
|
for idx, emb in enumerate(embeddings_list):
|
||||||
|
emb_type = type(emb).__name__
|
||||||
|
identifiers = []
|
||||||
|
deployment = getattr(emb, "deployment", None)
|
||||||
|
model = getattr(emb, "model", None)
|
||||||
|
model_id = getattr(emb, "model_id", None)
|
||||||
|
model_name = getattr(emb, "model_name", None)
|
||||||
|
available_models_attr = getattr(emb, "available_models", None)
|
||||||
|
|
||||||
|
if deployment:
|
||||||
|
identifiers.append(f"deployment='{deployment}'")
|
||||||
|
if model:
|
||||||
|
identifiers.append(f"model='{model}'")
|
||||||
|
if model_id:
|
||||||
|
identifiers.append(f"model_id='{model_id}'")
|
||||||
|
if model_name:
|
||||||
|
identifiers.append(f"model_name='{model_name}'")
|
||||||
|
|
||||||
|
# Add combined identifier as an option
|
||||||
|
if deployment and model and deployment != model:
|
||||||
|
identifiers.append(f"combined='{deployment}:{model}'")
|
||||||
|
|
||||||
|
# Add available_models dict if present
|
||||||
|
if available_models_attr and isinstance(available_models_attr, dict):
|
||||||
|
identifiers.append(f"available_models={list(available_models_attr.keys())}")
|
||||||
|
|
||||||
|
available_info.append(
|
||||||
|
f" [{idx}] {emb_type}: {', '.join(identifiers) if identifiers else 'No identifiers'}"
|
||||||
|
)
|
||||||
|
|
||||||
|
msg = (
|
||||||
|
f"Embedding model '{target_model_name}' not found in available embeddings.\n\n"
|
||||||
|
f"Available embeddings:\n" + "\n".join(available_info) + "\n\n"
|
||||||
|
"Please set 'embedding_model_name' to one of the identifier values shown above "
|
||||||
|
"(use the value after the '=' sign, without quotes).\n"
|
||||||
|
"For duplicate deployments, use the 'combined' format.\n"
|
||||||
|
"Or leave it empty to use the first embedding."
|
||||||
|
)
|
||||||
|
raise ValueError(msg)
|
||||||
|
else:
|
||||||
|
# Use first embedding if no model name specified
|
||||||
|
selected_embedding = embeddings_list[0]
|
||||||
|
embedding_model = self._get_embedding_model_name(selected_embedding)
|
||||||
|
self.log(f"No embedding_model_name specified, using first embedding: {embedding_model}")
|
||||||
|
|
||||||
dynamic_field_name = get_embedding_field_name(embedding_model)
|
dynamic_field_name = get_embedding_field_name(embedding_model)
|
||||||
|
|
||||||
self.log(f"Using embedding model: {embedding_model}")
|
logger.info(f"Selected embedding model for ingestion: '{embedding_model}'")
|
||||||
|
self.log(f"Using embedding model for ingestion: {embedding_model}")
|
||||||
self.log(f"Dynamic vector field: {dynamic_field_name}")
|
self.log(f"Dynamic vector field: {dynamic_field_name}")
|
||||||
|
|
||||||
|
# Log embedding details for debugging
|
||||||
|
if hasattr(selected_embedding, "deployment"):
|
||||||
|
logger.info(f"Embedding deployment: {selected_embedding.deployment}")
|
||||||
|
if hasattr(selected_embedding, "model"):
|
||||||
|
logger.info(f"Embedding model: {selected_embedding.model}")
|
||||||
|
if hasattr(selected_embedding, "model_id"):
|
||||||
|
logger.info(f"Embedding model_id: {selected_embedding.model_id}")
|
||||||
|
if hasattr(selected_embedding, "dimensions"):
|
||||||
|
logger.info(f"Embedding dimensions: {selected_embedding.dimensions}")
|
||||||
|
if hasattr(selected_embedding, "available_models"):
|
||||||
|
logger.info(f"Embedding available_models: {selected_embedding.available_models}")
|
||||||
|
|
||||||
|
# No model switching needed - each model in available_models has its own dedicated instance
|
||||||
|
# The selected_embedding is already configured correctly for the target model
|
||||||
|
logger.info(f"Using embedding instance for '{embedding_model}' - pre-configured and ready to use")
|
||||||
|
|
||||||
# Extract texts and metadata from documents
|
# Extract texts and metadata from documents
|
||||||
texts = []
|
texts = []
|
||||||
metadatas = []
|
metadatas = []
|
||||||
# Process docs_metadata table input into a dict
|
# Process docs_metadata table input into a dict
|
||||||
additional_metadata = {}
|
additional_metadata = {}
|
||||||
|
logger.debug(f"[LF] Docs metadata {self.docs_metadata}")
|
||||||
if hasattr(self, "docs_metadata") and self.docs_metadata:
|
if hasattr(self, "docs_metadata") and self.docs_metadata:
|
||||||
logger.info(f"[LF] Docs metadata {self.docs_metadata}")
|
logger.info(f"[LF] Docs metadata {self.docs_metadata}")
|
||||||
if isinstance(self.docs_metadata[-1], Data):
|
if isinstance(self.docs_metadata[-1], Data):
|
||||||
|
|
@ -664,23 +864,27 @@ class OpenSearchVectorStoreComponent(LCVectorStoreComponent):
|
||||||
|
|
||||||
metadatas.append(data_copy)
|
metadatas.append(data_copy)
|
||||||
self.log(metadatas)
|
self.log(metadatas)
|
||||||
if not self.embedding:
|
|
||||||
msg = "Embedding handle is required to embed documents."
|
|
||||||
raise ValueError(msg)
|
|
||||||
|
|
||||||
# Generate embeddings (threaded for concurrency) with retries
|
# Generate embeddings (threaded for concurrency) with retries
|
||||||
def embed_chunk(chunk_text: str) -> list[float]:
|
def embed_chunk(chunk_text: str) -> list[float]:
|
||||||
return self.embedding.embed_documents([chunk_text])[0]
|
return selected_embedding.embed_documents([chunk_text])[0]
|
||||||
|
|
||||||
vectors: Optional[List[List[float]]] = None
|
vectors: list[list[float]] | None = None
|
||||||
last_exception: Optional[Exception] = None
|
last_exception: Exception | None = None
|
||||||
delay = 1.0
|
delay = 1.0
|
||||||
attempts = 0
|
attempts = 0
|
||||||
|
max_attempts = 3
|
||||||
|
|
||||||
while attempts < 3:
|
while attempts < max_attempts:
|
||||||
attempts += 1
|
attempts += 1
|
||||||
try:
|
try:
|
||||||
max_workers = min(max(len(texts), 1), 8)
|
# Restrict concurrency for IBM/Watsonx models to avoid rate limits
|
||||||
|
is_ibm = (embedding_model and "ibm" in str(embedding_model).lower()) or (
|
||||||
|
selected_embedding and "watsonx" in type(selected_embedding).__name__.lower()
|
||||||
|
)
|
||||||
|
logger.debug(f"Is IBM: {is_ibm}")
|
||||||
|
max_workers = 1 if is_ibm else min(max(len(texts), 1), 8)
|
||||||
|
|
||||||
with ThreadPoolExecutor(max_workers=max_workers) as executor:
|
with ThreadPoolExecutor(max_workers=max_workers) as executor:
|
||||||
futures = {executor.submit(embed_chunk, chunk): idx for idx, chunk in enumerate(texts)}
|
futures = {executor.submit(embed_chunk, chunk): idx for idx, chunk in enumerate(texts)}
|
||||||
vectors = [None] * len(texts)
|
vectors = [None] * len(texts)
|
||||||
|
|
@ -690,16 +894,17 @@ class OpenSearchVectorStoreComponent(LCVectorStoreComponent):
|
||||||
break
|
break
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
last_exception = exc
|
last_exception = exc
|
||||||
if attempts >= 3:
|
if attempts >= max_attempts:
|
||||||
logger.error(
|
logger.error(
|
||||||
"Embedding generation failed after retries",
|
f"Embedding generation failed for model {embedding_model} after retries",
|
||||||
error=str(exc),
|
error=str(exc),
|
||||||
)
|
)
|
||||||
raise
|
raise
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"Threaded embedding generation failed (attempt %s/%s), retrying in %.1fs",
|
"Threaded embedding generation failed for model %s (attempt %s/%s), retrying in %.1fs",
|
||||||
|
embedding_model,
|
||||||
attempts,
|
attempts,
|
||||||
3,
|
max_attempts,
|
||||||
delay,
|
delay,
|
||||||
)
|
)
|
||||||
time.sleep(delay)
|
time.sleep(delay)
|
||||||
|
|
@ -707,11 +912,13 @@ class OpenSearchVectorStoreComponent(LCVectorStoreComponent):
|
||||||
|
|
||||||
if vectors is None:
|
if vectors is None:
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
f"Embedding generation failed: {last_exception}" if last_exception else "Embedding generation failed"
|
f"Embedding generation failed for {embedding_model}: {last_exception}"
|
||||||
|
if last_exception
|
||||||
|
else f"Embedding generation failed for {embedding_model}"
|
||||||
)
|
)
|
||||||
|
|
||||||
if not vectors:
|
if not vectors:
|
||||||
self.log("No vectors generated from documents.")
|
self.log(f"No vectors generated from documents for model {embedding_model}.")
|
||||||
return
|
return
|
||||||
|
|
||||||
# Get vector dimension for mapping
|
# Get vector dimension for mapping
|
||||||
|
|
@ -746,9 +953,7 @@ class OpenSearchVectorStoreComponent(LCVectorStoreComponent):
|
||||||
client.indices.create(index=self.index_name, body=mapping)
|
client.indices.create(index=self.index_name, body=mapping)
|
||||||
except RequestError as creation_error:
|
except RequestError as creation_error:
|
||||||
if creation_error.error != "resource_already_exists_exception":
|
if creation_error.error != "resource_already_exists_exception":
|
||||||
logger.warning(
|
logger.warning(f"Failed to create index '{self.index_name}': {creation_error}")
|
||||||
f"Failed to create index '{self.index_name}': {creation_error}"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Ensure the dynamic field exists in the index
|
# Ensure the dynamic field exists in the index
|
||||||
self._ensure_embedding_field_mapping(
|
self._ensure_embedding_field_mapping(
|
||||||
|
|
@ -763,6 +968,8 @@ class OpenSearchVectorStoreComponent(LCVectorStoreComponent):
|
||||||
)
|
)
|
||||||
|
|
||||||
self.log(f"Indexing {len(texts)} documents into '{self.index_name}' with model '{embedding_model}'...")
|
self.log(f"Indexing {len(texts)} documents into '{self.index_name}' with model '{embedding_model}'...")
|
||||||
|
logger.info(f"Will store embeddings in field: {dynamic_field_name}")
|
||||||
|
logger.info(f"Will tag documents with embedding_model: {embedding_model}")
|
||||||
|
|
||||||
# Use the bulk ingestion with model tracking
|
# Use the bulk ingestion with model tracking
|
||||||
return_ids = self._bulk_ingest_embeddings(
|
return_ids = self._bulk_ingest_embeddings(
|
||||||
|
|
@ -779,6 +986,9 @@ class OpenSearchVectorStoreComponent(LCVectorStoreComponent):
|
||||||
)
|
)
|
||||||
self.log(metadatas)
|
self.log(metadatas)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"Ingestion complete: Successfully indexed {len(return_ids)} documents with model '{embedding_model}'"
|
||||||
|
)
|
||||||
self.log(f"Successfully indexed {len(return_ids)} documents with model {embedding_model}.")
|
self.log(f"Successfully indexed {len(return_ids)} documents with model {embedding_model}.")
|
||||||
|
|
||||||
# ---------- helpers for filters ----------
|
# ---------- helpers for filters ----------
|
||||||
|
|
@ -853,7 +1063,7 @@ class OpenSearchVectorStoreComponent(LCVectorStoreComponent):
|
||||||
context_clauses.append({"terms": {field: values}})
|
context_clauses.append({"terms": {field: values}})
|
||||||
return context_clauses
|
return context_clauses
|
||||||
|
|
||||||
def _detect_available_models(self, client: OpenSearch, filter_clauses: list[dict] = None) -> list[str]:
|
def _detect_available_models(self, client: OpenSearch, filter_clauses: list[dict] | None = None) -> list[str]:
|
||||||
"""Detect which embedding models have documents in the index.
|
"""Detect which embedding models have documents in the index.
|
||||||
|
|
||||||
Uses aggregation to find all unique embedding_model values, optionally
|
Uses aggregation to find all unique embedding_model values, optionally
|
||||||
|
|
@ -867,26 +1077,13 @@ class OpenSearchVectorStoreComponent(LCVectorStoreComponent):
|
||||||
List of embedding model names found in the index
|
List of embedding model names found in the index
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
agg_query = {
|
agg_query = {"size": 0, "aggs": {"embedding_models": {"terms": {"field": "embedding_model", "size": 10}}}}
|
||||||
"size": 0,
|
|
||||||
"aggs": {
|
|
||||||
"embedding_models": {
|
|
||||||
"terms": {
|
|
||||||
"field": "embedding_model",
|
|
||||||
"size": 10
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
# Apply filters to model detection if any exist
|
# Apply filters to model detection if any exist
|
||||||
if filter_clauses:
|
if filter_clauses:
|
||||||
agg_query["query"] = {
|
agg_query["query"] = {"bool": {"filter": filter_clauses}}
|
||||||
"bool": {
|
|
||||||
"filter": filter_clauses
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
logger.debug(f"Model detection query: {agg_query}")
|
||||||
result = client.search(
|
result = client.search(
|
||||||
index=self.index_name,
|
index=self.index_name,
|
||||||
body=agg_query,
|
body=agg_query,
|
||||||
|
|
@ -895,21 +1092,33 @@ class OpenSearchVectorStoreComponent(LCVectorStoreComponent):
|
||||||
buckets = result.get("aggregations", {}).get("embedding_models", {}).get("buckets", [])
|
buckets = result.get("aggregations", {}).get("embedding_models", {}).get("buckets", [])
|
||||||
models = [b["key"] for b in buckets if b["key"]]
|
models = [b["key"] for b in buckets if b["key"]]
|
||||||
|
|
||||||
|
# Log detailed bucket info for debugging
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Detected embedding models in corpus: {models}"
|
f"Detected embedding models in corpus: {models}"
|
||||||
+ (f" (with {len(filter_clauses)} filters)" if filter_clauses else "")
|
+ (f" (with {len(filter_clauses)} filters)" if filter_clauses else "")
|
||||||
)
|
)
|
||||||
return models
|
if not models:
|
||||||
except Exception as e:
|
total_hits = result.get("hits", {}).get("total", {})
|
||||||
|
total_count = total_hits.get("value", 0) if isinstance(total_hits, dict) else total_hits
|
||||||
|
logger.warning(
|
||||||
|
f"No embedding_model values found in index '{self.index_name}'. "
|
||||||
|
f"Total docs in index: {total_count}. "
|
||||||
|
f"This may indicate documents were indexed without the embedding_model field."
|
||||||
|
)
|
||||||
|
except (OpenSearchException, KeyError, ValueError) as e:
|
||||||
logger.warning(f"Failed to detect embedding models: {e}")
|
logger.warning(f"Failed to detect embedding models: {e}")
|
||||||
# Fallback to current model
|
# Fallback to current model
|
||||||
return [self._get_embedding_model_name()]
|
fallback_model = self._get_embedding_model_name()
|
||||||
|
logger.info(f"Using fallback model: {fallback_model}")
|
||||||
|
return [fallback_model]
|
||||||
|
else:
|
||||||
|
return models
|
||||||
|
|
||||||
def _get_index_properties(self, client: OpenSearch) -> dict[str, Any] | None:
|
def _get_index_properties(self, client: OpenSearch) -> dict[str, Any] | None:
|
||||||
"""Retrieve flattened mapping properties for the current index."""
|
"""Retrieve flattened mapping properties for the current index."""
|
||||||
try:
|
try:
|
||||||
mapping = client.indices.get_mapping(index=self.index_name)
|
mapping = client.indices.get_mapping(index=self.index_name)
|
||||||
except Exception as e:
|
except OpenSearchException as e:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"Failed to fetch mapping for index '{self.index_name}': {e}. Proceeding without mapping metadata."
|
f"Failed to fetch mapping for index '{self.index_name}': {e}. Proceeding without mapping metadata."
|
||||||
)
|
)
|
||||||
|
|
@ -927,9 +1136,7 @@ class OpenSearchVectorStoreComponent(LCVectorStoreComponent):
|
||||||
if not field_name:
|
if not field_name:
|
||||||
return False
|
return False
|
||||||
if properties is None:
|
if properties is None:
|
||||||
logger.warning(
|
logger.warning(f"Mapping metadata unavailable; assuming field '{field_name}' is usable.")
|
||||||
f"Mapping metadata unavailable; assuming field '{field_name}' is usable."
|
|
||||||
)
|
|
||||||
return True
|
return True
|
||||||
field_def = properties.get(field_name)
|
field_def = properties.get(field_name)
|
||||||
if not isinstance(field_def, dict):
|
if not isinstance(field_def, dict):
|
||||||
|
|
@ -938,10 +1145,35 @@ class OpenSearchVectorStoreComponent(LCVectorStoreComponent):
|
||||||
return True
|
return True
|
||||||
|
|
||||||
nested_props = field_def.get("properties")
|
nested_props = field_def.get("properties")
|
||||||
if isinstance(nested_props, dict) and nested_props.get("type") == "knn_vector":
|
return bool(isinstance(nested_props, dict) and nested_props.get("type") == "knn_vector")
|
||||||
return True
|
|
||||||
|
|
||||||
return False
|
def _get_field_dimension(self, properties: dict[str, Any] | None, field_name: str) -> int | None:
|
||||||
|
"""Get the dimension of a knn_vector field from the index mapping.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
properties: Index properties from mapping
|
||||||
|
field_name: Name of the vector field
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dimension of the field, or None if not found
|
||||||
|
"""
|
||||||
|
if not field_name or properties is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
field_def = properties.get(field_name)
|
||||||
|
if not isinstance(field_def, dict):
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Check direct knn_vector field
|
||||||
|
if field_def.get("type") == "knn_vector":
|
||||||
|
return field_def.get("dimension")
|
||||||
|
|
||||||
|
# Check nested properties
|
||||||
|
nested_props = field_def.get("properties")
|
||||||
|
if isinstance(nested_props, dict) and nested_props.get("type") == "knn_vector":
|
||||||
|
return nested_props.get("dimension")
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
# ---------- search (multi-model hybrid) ----------
|
# ---------- search (multi-model hybrid) ----------
|
||||||
def search(self, query: str | None = None) -> list[dict[str, Any]]:
|
def search(self, query: str | None = None) -> list[dict[str, Any]]:
|
||||||
|
|
@ -985,6 +1217,11 @@ class OpenSearchVectorStoreComponent(LCVectorStoreComponent):
|
||||||
msg = "Embedding is required to run hybrid search (KNN + keyword)."
|
msg = "Embedding is required to run hybrid search (KNN + keyword)."
|
||||||
raise ValueError(msg)
|
raise ValueError(msg)
|
||||||
|
|
||||||
|
# Check if embedding is None (fail-safe mode)
|
||||||
|
if self.embedding is None or (isinstance(self.embedding, list) and all(e is None for e in self.embedding)):
|
||||||
|
logger.error("Embedding returned None (fail-safe mode enabled). Cannot perform search.")
|
||||||
|
return []
|
||||||
|
|
||||||
# Build filter clauses first so we can use them in model detection
|
# Build filter clauses first so we can use them in model detection
|
||||||
filter_clauses = self._coerce_filter_clauses(filter_obj)
|
filter_clauses = self._coerce_filter_clauses(filter_obj)
|
||||||
|
|
||||||
|
|
@ -995,42 +1232,166 @@ class OpenSearchVectorStoreComponent(LCVectorStoreComponent):
|
||||||
logger.warning("No embedding models found in index, using current model")
|
logger.warning("No embedding models found in index, using current model")
|
||||||
available_models = [self._get_embedding_model_name()]
|
available_models = [self._get_embedding_model_name()]
|
||||||
|
|
||||||
# Generate embeddings for ALL detected models in parallel
|
# Generate embeddings for ALL detected models
|
||||||
query_embeddings = {}
|
query_embeddings = {}
|
||||||
|
|
||||||
# Note: Langflow is synchronous, so we can't use true async here
|
# Normalize embedding to list
|
||||||
# But we log the intent for parallel processing
|
embeddings_list = self.embedding if isinstance(self.embedding, list) else [self.embedding]
|
||||||
logger.info(f"Generating embeddings for {len(available_models)} models")
|
# Filter out None values (fail-safe mode)
|
||||||
|
embeddings_list = [e for e in embeddings_list if e is not None]
|
||||||
|
|
||||||
original_model_attr = getattr(self.embedding, "model", None)
|
if not embeddings_list:
|
||||||
original_deployment_attr = getattr(self.embedding, "deployment", None)
|
logger.error(
|
||||||
original_dimensions_attr = getattr(self.embedding, "dimensions", None)
|
"No valid embeddings available after filtering None values (fail-safe mode). Cannot perform search."
|
||||||
|
)
|
||||||
|
return []
|
||||||
|
|
||||||
|
# Create a comprehensive map of model names to embedding objects
|
||||||
|
# Check all possible identifiers (deployment, model, model_id, model_name)
|
||||||
|
# Also leverage available_models list from EmbeddingsWithModels
|
||||||
|
# Handle duplicate identifiers by creating combined keys
|
||||||
|
embedding_by_model = {}
|
||||||
|
identifier_conflicts = {} # Track which identifiers have conflicts
|
||||||
|
|
||||||
|
for idx, emb_obj in enumerate(embeddings_list):
|
||||||
|
# Get all possible identifiers for this embedding
|
||||||
|
identifiers = []
|
||||||
|
deployment = getattr(emb_obj, "deployment", None)
|
||||||
|
model = getattr(emb_obj, "model", None)
|
||||||
|
model_id = getattr(emb_obj, "model_id", None)
|
||||||
|
model_name = getattr(emb_obj, "model_name", None)
|
||||||
|
dimensions = getattr(emb_obj, "dimensions", None)
|
||||||
|
available_models_attr = getattr(emb_obj, "available_models", None)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"Embedding object {idx}: deployment={deployment}, model={model}, "
|
||||||
|
f"model_id={model_id}, model_name={model_name}, dimensions={dimensions}, "
|
||||||
|
f"available_models={available_models_attr}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# If this embedding has available_models dict, map all models to their dedicated instances
|
||||||
|
if available_models_attr and isinstance(available_models_attr, dict):
|
||||||
|
logger.info(
|
||||||
|
f"Embedding object {idx} provides {len(available_models_attr)} models via available_models dict"
|
||||||
|
)
|
||||||
|
for model_name_key, dedicated_embedding in available_models_attr.items():
|
||||||
|
if model_name_key and str(model_name_key).strip():
|
||||||
|
model_str = str(model_name_key).strip()
|
||||||
|
if model_str not in embedding_by_model:
|
||||||
|
# Use the dedicated embedding instance from the dict
|
||||||
|
embedding_by_model[model_str] = dedicated_embedding
|
||||||
|
logger.info(f"Mapped available model '{model_str}' to dedicated embedding instance")
|
||||||
|
else:
|
||||||
|
# Conflict detected - track it
|
||||||
|
if model_str not in identifier_conflicts:
|
||||||
|
identifier_conflicts[model_str] = [embedding_by_model[model_str]]
|
||||||
|
identifier_conflicts[model_str].append(dedicated_embedding)
|
||||||
|
logger.warning(f"Available model '{model_str}' has conflict - used by multiple embeddings")
|
||||||
|
|
||||||
|
# Also map traditional identifiers (for backward compatibility)
|
||||||
|
if deployment:
|
||||||
|
identifiers.append(str(deployment))
|
||||||
|
if model:
|
||||||
|
identifiers.append(str(model))
|
||||||
|
if model_id:
|
||||||
|
identifiers.append(str(model_id))
|
||||||
|
if model_name:
|
||||||
|
identifiers.append(str(model_name))
|
||||||
|
|
||||||
|
# Map all identifiers to this embedding object
|
||||||
|
for identifier in identifiers:
|
||||||
|
if identifier not in embedding_by_model:
|
||||||
|
embedding_by_model[identifier] = emb_obj
|
||||||
|
logger.info(f"Mapped identifier '{identifier}' to embedding object {idx}")
|
||||||
|
else:
|
||||||
|
# Conflict detected - track it
|
||||||
|
if identifier not in identifier_conflicts:
|
||||||
|
identifier_conflicts[identifier] = [embedding_by_model[identifier]]
|
||||||
|
identifier_conflicts[identifier].append(emb_obj)
|
||||||
|
logger.warning(f"Identifier '{identifier}' has conflict - used by multiple embeddings")
|
||||||
|
|
||||||
|
# For embeddings with model+deployment, create combined identifier
|
||||||
|
# This helps when deployment is the same but model differs
|
||||||
|
if deployment and model and deployment != model:
|
||||||
|
combined_id = f"{deployment}:{model}"
|
||||||
|
if combined_id not in embedding_by_model:
|
||||||
|
embedding_by_model[combined_id] = emb_obj
|
||||||
|
logger.info(f"Created combined identifier '{combined_id}' for embedding object {idx}")
|
||||||
|
|
||||||
|
# Log conflicts
|
||||||
|
if identifier_conflicts:
|
||||||
|
logger.warning(
|
||||||
|
f"Found {len(identifier_conflicts)} conflicting identifiers. "
|
||||||
|
f"Consider using combined format 'deployment:model' or specifying unique model names."
|
||||||
|
)
|
||||||
|
for conflict_id, emb_list in identifier_conflicts.items():
|
||||||
|
logger.warning(f" Conflict on '{conflict_id}': {len(emb_list)} embeddings use this identifier")
|
||||||
|
|
||||||
|
logger.info(f"Generating embeddings for {len(available_models)} models in index")
|
||||||
|
logger.info(f"Available embedding identifiers: {list(embedding_by_model.keys())}")
|
||||||
|
self.log(f"[SEARCH] Models detected in index: {available_models}")
|
||||||
|
self.log(f"[SEARCH] Available embedding identifiers: {list(embedding_by_model.keys())}")
|
||||||
|
|
||||||
|
# Track matching status for debugging
|
||||||
|
matched_models = []
|
||||||
|
unmatched_models = []
|
||||||
|
|
||||||
for model_name in available_models:
|
for model_name in available_models:
|
||||||
try:
|
try:
|
||||||
# In a real async environment, these would run in parallel
|
# Check if we have an embedding object for this model
|
||||||
# For now, they run sequentially
|
if model_name in embedding_by_model:
|
||||||
if hasattr(self.embedding, "model"):
|
# Use the matching embedding object directly
|
||||||
setattr(self.embedding, "model", model_name)
|
emb_obj = embedding_by_model[model_name]
|
||||||
if hasattr(self.embedding, "deployment"):
|
emb_deployment = getattr(emb_obj, "deployment", None)
|
||||||
setattr(self.embedding, "deployment", model_name)
|
emb_model = getattr(emb_obj, "model", None)
|
||||||
if hasattr(self.embedding, "dimensions"):
|
emb_model_id = getattr(emb_obj, "model_id", None)
|
||||||
setattr(self.embedding, "dimensions", None)
|
emb_dimensions = getattr(emb_obj, "dimensions", None)
|
||||||
vec = self.embedding.embed_query(q)
|
emb_available_models = getattr(emb_obj, "available_models", None)
|
||||||
query_embeddings[model_name] = vec
|
|
||||||
logger.info(f"Generated embedding for model: {model_name}")
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Failed to generate embedding for {model_name}: {e}")
|
|
||||||
|
|
||||||
if hasattr(self.embedding, "model"):
|
logger.info(
|
||||||
setattr(self.embedding, "model", original_model_attr)
|
f"Using embedding object for model '{model_name}': "
|
||||||
if hasattr(self.embedding, "deployment"):
|
f"deployment={emb_deployment}, model={emb_model}, model_id={emb_model_id}, "
|
||||||
setattr(self.embedding, "deployment", original_deployment_attr)
|
f"dimensions={emb_dimensions}"
|
||||||
if hasattr(self.embedding, "dimensions"):
|
)
|
||||||
setattr(self.embedding, "dimensions", original_dimensions_attr)
|
|
||||||
|
# Check if this is a dedicated instance from available_models dict
|
||||||
|
if emb_available_models and isinstance(emb_available_models, dict):
|
||||||
|
logger.info(
|
||||||
|
f"Model '{model_name}' using dedicated instance from available_models dict "
|
||||||
|
f"(pre-configured with correct model and dimensions)"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Use the embedding instance directly - no model switching needed!
|
||||||
|
vec = emb_obj.embed_query(q)
|
||||||
|
query_embeddings[model_name] = vec
|
||||||
|
matched_models.append(model_name)
|
||||||
|
logger.info(f"Generated embedding for model: {model_name} (actual dimensions: {len(vec)})")
|
||||||
|
self.log(f"[MATCH] Model '{model_name}' - generated {len(vec)}-dim embedding")
|
||||||
|
else:
|
||||||
|
# No matching embedding found for this model
|
||||||
|
unmatched_models.append(model_name)
|
||||||
|
logger.warning(
|
||||||
|
f"No matching embedding found for model '{model_name}'. "
|
||||||
|
f"This model will be skipped. Available identifiers: {list(embedding_by_model.keys())}"
|
||||||
|
)
|
||||||
|
self.log(f"[NO MATCH] Model '{model_name}' - available: {list(embedding_by_model.keys())}")
|
||||||
|
except (RuntimeError, ValueError, ConnectionError, TimeoutError, AttributeError, KeyError) as e:
|
||||||
|
logger.warning(f"Failed to generate embedding for {model_name}: {e}")
|
||||||
|
self.log(f"[ERROR] Embedding generation failed for '{model_name}': {e}")
|
||||||
|
|
||||||
|
# Log summary of model matching
|
||||||
|
logger.info(f"Model matching summary: {len(matched_models)} matched, {len(unmatched_models)} unmatched")
|
||||||
|
self.log(f"[SUMMARY] Model matching: {len(matched_models)} matched, {len(unmatched_models)} unmatched")
|
||||||
|
if unmatched_models:
|
||||||
|
self.log(f"[WARN] Unmatched models in index: {unmatched_models}")
|
||||||
|
|
||||||
if not query_embeddings:
|
if not query_embeddings:
|
||||||
msg = "Failed to generate embeddings for any model"
|
msg = (
|
||||||
|
f"Failed to generate embeddings for any model. "
|
||||||
|
f"Index has models: {available_models}, but no matching embedding objects found. "
|
||||||
|
f"Available embedding identifiers: {list(embedding_by_model.keys())}"
|
||||||
|
)
|
||||||
|
self.log(f"[FAIL] Search failed: {msg}")
|
||||||
raise ValueError(msg)
|
raise ValueError(msg)
|
||||||
|
|
||||||
index_properties = self._get_index_properties(client)
|
index_properties = self._get_index_properties(client)
|
||||||
|
|
@ -1051,6 +1412,7 @@ class OpenSearchVectorStoreComponent(LCVectorStoreComponent):
|
||||||
for model_name, embedding_vector in query_embeddings.items():
|
for model_name, embedding_vector in query_embeddings.items():
|
||||||
field_name = get_embedding_field_name(model_name)
|
field_name = get_embedding_field_name(model_name)
|
||||||
selected_field = field_name
|
selected_field = field_name
|
||||||
|
vector_dim = len(embedding_vector)
|
||||||
|
|
||||||
# Only use the expected dynamic field - no legacy fallback
|
# Only use the expected dynamic field - no legacy fallback
|
||||||
# This prevents dimension mismatches between models
|
# This prevents dimension mismatches between models
|
||||||
|
|
@ -1059,8 +1421,24 @@ class OpenSearchVectorStoreComponent(LCVectorStoreComponent):
|
||||||
f"Skipping model {model_name}: field '{field_name}' is not mapped as knn_vector. "
|
f"Skipping model {model_name}: field '{field_name}' is not mapped as knn_vector. "
|
||||||
f"Documents must be indexed with this embedding model before querying."
|
f"Documents must be indexed with this embedding model before querying."
|
||||||
)
|
)
|
||||||
|
self.log(f"[SKIP] Field '{selected_field}' not a knn_vector - skipping model '{model_name}'")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
# Validate vector dimensions match the field dimensions
|
||||||
|
field_dim = self._get_field_dimension(index_properties, selected_field)
|
||||||
|
if field_dim is not None and field_dim != vector_dim:
|
||||||
|
logger.error(
|
||||||
|
f"Dimension mismatch for model '{model_name}': "
|
||||||
|
f"Query vector has {vector_dim} dimensions but field '{selected_field}' expects {field_dim}. "
|
||||||
|
f"Skipping this model to prevent search errors."
|
||||||
|
)
|
||||||
|
self.log(f"[DIM MISMATCH] Model '{model_name}': query={vector_dim} vs field={field_dim} - skipping")
|
||||||
|
continue
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"Adding KNN query for model '{model_name}': field='{selected_field}', "
|
||||||
|
f"query_dims={vector_dim}, field_dims={field_dim or 'unknown'}"
|
||||||
|
)
|
||||||
embedding_fields.append(selected_field)
|
embedding_fields.append(selected_field)
|
||||||
|
|
||||||
base_query = {
|
base_query = {
|
||||||
|
|
@ -1091,14 +1469,16 @@ class OpenSearchVectorStoreComponent(LCVectorStoreComponent):
|
||||||
"This may indicate an empty index or missing field mappings. "
|
"This may indicate an empty index or missing field mappings. "
|
||||||
"Returning empty search results."
|
"Returning empty search results."
|
||||||
)
|
)
|
||||||
|
self.log(
|
||||||
|
f"[WARN] No valid KNN queries could be built. "
|
||||||
|
f"Query embeddings generated: {list(query_embeddings.keys())}, "
|
||||||
|
f"but no matching knn_vector fields found in index."
|
||||||
|
)
|
||||||
return []
|
return []
|
||||||
|
|
||||||
# Build exists filter - document must have at least one embedding field
|
# Build exists filter - document must have at least one embedding field
|
||||||
exists_any_embedding = {
|
exists_any_embedding = {
|
||||||
"bool": {
|
"bool": {"should": [{"exists": {"field": f}} for f in set(embedding_fields)], "minimum_should_match": 1}
|
||||||
"should": [{"exists": {"field": f}} for f in set(embedding_fields)],
|
|
||||||
"minimum_should_match": 1
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
# Combine user filters with exists filter
|
# Combine user filters with exists filter
|
||||||
|
|
@ -1117,7 +1497,7 @@ class OpenSearchVectorStoreComponent(LCVectorStoreComponent):
|
||||||
"dis_max": {
|
"dis_max": {
|
||||||
"tie_breaker": 0.0, # Take only the best match, no blending
|
"tie_breaker": 0.0, # Take only the best match, no blending
|
||||||
"boost": 0.7, # 70% weight for semantic search
|
"boost": 0.7, # 70% weight for semantic search
|
||||||
"queries": knn_queries_with_candidates
|
"queries": knn_queries_with_candidates,
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
|
@ -1158,13 +1538,15 @@ class OpenSearchVectorStoreComponent(LCVectorStoreComponent):
|
||||||
body["min_score"] = score_threshold
|
body["min_score"] = score_threshold
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Executing multi-model hybrid search with {len(knn_queries_with_candidates)} embedding models"
|
f"Executing multi-model hybrid search with {len(knn_queries_with_candidates)} embedding models: "
|
||||||
|
f"{list(query_embeddings.keys())}"
|
||||||
)
|
)
|
||||||
|
self.log(f"[EXEC] Executing search with {len(knn_queries_with_candidates)} KNN queries, limit={limit}")
|
||||||
|
self.log(f"[EXEC] Embedding models used: {list(query_embeddings.keys())}")
|
||||||
|
self.log(f"[EXEC] KNN fields being queried: {embedding_fields}")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
resp = client.search(
|
resp = client.search(index=self.index_name, body=body, params={"terminate_after": 0})
|
||||||
index=self.index_name, body=body, params={"terminate_after": 0}
|
|
||||||
)
|
|
||||||
except RequestError as e:
|
except RequestError as e:
|
||||||
error_message = str(e)
|
error_message = str(e)
|
||||||
lowered = error_message.lower()
|
lowered = error_message.lower()
|
||||||
|
|
@ -1215,6 +1597,16 @@ class OpenSearchVectorStoreComponent(LCVectorStoreComponent):
|
||||||
hits = resp.get("hits", {}).get("hits", [])
|
hits = resp.get("hits", {}).get("hits", [])
|
||||||
|
|
||||||
logger.info(f"Found {len(hits)} results")
|
logger.info(f"Found {len(hits)} results")
|
||||||
|
self.log(f"[RESULT] Search complete: {len(hits)} results found")
|
||||||
|
|
||||||
|
if len(hits) == 0:
|
||||||
|
self.log(
|
||||||
|
f"[EMPTY] Debug info: "
|
||||||
|
f"models_in_index={available_models}, "
|
||||||
|
f"matched_models={matched_models}, "
|
||||||
|
f"knn_fields={embedding_fields}, "
|
||||||
|
f"filters={len(filter_clauses)} clauses"
|
||||||
|
)
|
||||||
|
|
||||||
return [
|
return [
|
||||||
{
|
{
|
||||||
|
|
@ -1231,6 +1623,9 @@ class OpenSearchVectorStoreComponent(LCVectorStoreComponent):
|
||||||
This is the main interface method that performs the multi-model search using the
|
This is the main interface method that performs the multi-model search using the
|
||||||
configured search_query and returns results in Langflow's Data format.
|
configured search_query and returns results in Langflow's Data format.
|
||||||
|
|
||||||
|
Always builds the vector store (triggering ingestion if needed), then performs
|
||||||
|
search only if a query is provided.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
List of Data objects containing search results with text and metadata
|
List of Data objects containing search results with text and metadata
|
||||||
|
|
||||||
|
|
@ -1238,9 +1633,20 @@ class OpenSearchVectorStoreComponent(LCVectorStoreComponent):
|
||||||
Exception: If search operation fails
|
Exception: If search operation fails
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
raw = self.search(self.search_query or "")
|
# Always build/cache the vector store to ensure ingestion happens
|
||||||
|
logger.info(f"Search query: {self.search_query}")
|
||||||
|
if self._cached_vector_store is None:
|
||||||
|
self.build_vector_store()
|
||||||
|
|
||||||
|
# Only perform search if query is provided
|
||||||
|
search_query = (self.search_query or "").strip()
|
||||||
|
if not search_query:
|
||||||
|
self.log("No search query provided - ingestion completed, returning empty results")
|
||||||
|
return []
|
||||||
|
|
||||||
|
# Perform search with the provided query
|
||||||
|
raw = self.search(search_query)
|
||||||
return [Data(text=hit["page_content"], **hit["metadata"]) for hit in raw]
|
return [Data(text=hit["page_content"], **hit["metadata"]) for hit in raw]
|
||||||
self.log(self.ingest_data)
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.log(f"search_documents error: {e}")
|
self.log(f"search_documents error: {e}")
|
||||||
raise
|
raise
|
||||||
|
|
@ -1280,9 +1686,6 @@ class OpenSearchVectorStoreComponent(LCVectorStoreComponent):
|
||||||
build_config["jwt_header"]["required"] = is_jwt
|
build_config["jwt_header"]["required"] = is_jwt
|
||||||
build_config["bearer_prefix"]["required"] = False
|
build_config["bearer_prefix"]["required"] = False
|
||||||
|
|
||||||
if is_basic:
|
|
||||||
build_config["jwt_token"]["value"] = ""
|
|
||||||
|
|
||||||
return build_config
|
return build_config
|
||||||
|
|
||||||
except (KeyError, ValueError) as e:
|
except (KeyError, ValueError) as e:
|
||||||
|
|
|
||||||
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
|
|
@ -3,6 +3,7 @@ import {
|
||||||
useQuery,
|
useQuery,
|
||||||
useQueryClient,
|
useQueryClient,
|
||||||
} from "@tanstack/react-query";
|
} from "@tanstack/react-query";
|
||||||
|
import { useChat } from "@/contexts/chat-context";
|
||||||
import { useGetSettingsQuery } from "./useGetSettingsQuery";
|
import { useGetSettingsQuery } from "./useGetSettingsQuery";
|
||||||
|
|
||||||
export interface ProviderHealthDetails {
|
export interface ProviderHealthDetails {
|
||||||
|
|
@ -24,6 +25,7 @@ export interface ProviderHealthResponse {
|
||||||
|
|
||||||
export interface ProviderHealthParams {
|
export interface ProviderHealthParams {
|
||||||
provider?: "openai" | "ollama" | "watsonx";
|
provider?: "openai" | "ollama" | "watsonx";
|
||||||
|
test_completion?: boolean;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Track consecutive failures for exponential backoff
|
// Track consecutive failures for exponential backoff
|
||||||
|
|
@ -38,6 +40,9 @@ export const useProviderHealthQuery = (
|
||||||
) => {
|
) => {
|
||||||
const queryClient = useQueryClient();
|
const queryClient = useQueryClient();
|
||||||
|
|
||||||
|
// Get chat error state from context (ChatProvider wraps the entire app in layout.tsx)
|
||||||
|
const { hasChatError, setChatError } = useChat();
|
||||||
|
|
||||||
const { data: settings = {} } = useGetSettingsQuery();
|
const { data: settings = {} } = useGetSettingsQuery();
|
||||||
|
|
||||||
async function checkProviderHealth(): Promise<ProviderHealthResponse> {
|
async function checkProviderHealth(): Promise<ProviderHealthResponse> {
|
||||||
|
|
@ -49,6 +54,12 @@ export const useProviderHealthQuery = (
|
||||||
url.searchParams.set("provider", params.provider);
|
url.searchParams.set("provider", params.provider);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Add test_completion query param if specified or if chat error exists
|
||||||
|
const testCompletion = params?.test_completion ?? hasChatError;
|
||||||
|
if (testCompletion) {
|
||||||
|
url.searchParams.set("test_completion", "true");
|
||||||
|
}
|
||||||
|
|
||||||
const response = await fetch(url.toString());
|
const response = await fetch(url.toString());
|
||||||
|
|
||||||
if (response.ok) {
|
if (response.ok) {
|
||||||
|
|
@ -90,7 +101,7 @@ export const useProviderHealthQuery = (
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
const queryKey = ["provider", "health"];
|
const queryKey = ["provider", "health", params?.test_completion];
|
||||||
const failureCountKey = queryKey.join("-");
|
const failureCountKey = queryKey.join("-");
|
||||||
|
|
||||||
const queryResult = useQuery(
|
const queryResult = useQuery(
|
||||||
|
|
@ -101,26 +112,32 @@ export const useProviderHealthQuery = (
|
||||||
refetchInterval: (query) => {
|
refetchInterval: (query) => {
|
||||||
const data = query.state.data;
|
const data = query.state.data;
|
||||||
const status = data?.status;
|
const status = data?.status;
|
||||||
|
|
||||||
// If healthy, reset failure count and check every 30 seconds
|
// If healthy, reset failure count and check every 30 seconds
|
||||||
|
// Also reset chat error flag if we're using test_completion=true and it succeeded
|
||||||
if (status === "healthy") {
|
if (status === "healthy") {
|
||||||
failureCountMap.set(failureCountKey, 0);
|
failureCountMap.set(failureCountKey, 0);
|
||||||
|
// If we were checking with test_completion=true due to chat errors, reset the flag
|
||||||
|
if (hasChatError && setChatError) {
|
||||||
|
setChatError(false);
|
||||||
|
}
|
||||||
return 30000;
|
return 30000;
|
||||||
}
|
}
|
||||||
|
|
||||||
// If backend unavailable, use moderate polling
|
// If backend unavailable, use moderate polling
|
||||||
if (status === "backend-unavailable") {
|
if (status === "backend-unavailable") {
|
||||||
return 15000;
|
return 15000;
|
||||||
}
|
}
|
||||||
|
|
||||||
// For unhealthy/error status, use exponential backoff
|
// For unhealthy/error status, use exponential backoff
|
||||||
const currentFailures = failureCountMap.get(failureCountKey) || 0;
|
const currentFailures = failureCountMap.get(failureCountKey) || 0;
|
||||||
failureCountMap.set(failureCountKey, currentFailures + 1);
|
failureCountMap.set(failureCountKey, currentFailures + 1);
|
||||||
|
|
||||||
// Exponential backoff: 5s, 10s, 20s, then 30s
|
// Exponential backoff: 5s, 10s, 20s, then 30s
|
||||||
const backoffDelays = [5000, 10000, 20000, 30000];
|
const backoffDelays = [5000, 10000, 20000, 30000];
|
||||||
const delay = backoffDelays[Math.min(currentFailures, backoffDelays.length - 1)];
|
const delay =
|
||||||
|
backoffDelays[Math.min(currentFailures, backoffDelays.length - 1)];
|
||||||
|
|
||||||
return delay;
|
return delay;
|
||||||
},
|
},
|
||||||
refetchOnWindowFocus: false, // Disabled to reduce unnecessary calls on tab switches
|
refetchOnWindowFocus: false, // Disabled to reduce unnecessary calls on tab switches
|
||||||
|
|
|
||||||
|
|
@ -51,6 +51,7 @@ function ChatPage() {
|
||||||
]);
|
]);
|
||||||
const [input, setInput] = useState("");
|
const [input, setInput] = useState("");
|
||||||
const { loading, setLoading } = useLoadingStore();
|
const { loading, setLoading } = useLoadingStore();
|
||||||
|
const { setChatError } = useChat();
|
||||||
const [asyncMode, setAsyncMode] = useState(true);
|
const [asyncMode, setAsyncMode] = useState(true);
|
||||||
const [expandedFunctionCalls, setExpandedFunctionCalls] = useState<
|
const [expandedFunctionCalls, setExpandedFunctionCalls] = useState<
|
||||||
Set<string>
|
Set<string>
|
||||||
|
|
@ -123,6 +124,8 @@ function ChatPage() {
|
||||||
console.error("Streaming error:", error);
|
console.error("Streaming error:", error);
|
||||||
setLoading(false);
|
setLoading(false);
|
||||||
setWaitingTooLong(false);
|
setWaitingTooLong(false);
|
||||||
|
// Set chat error flag to trigger test_completion=true on health checks
|
||||||
|
setChatError(true);
|
||||||
const errorMessage: Message = {
|
const errorMessage: Message = {
|
||||||
role: "assistant",
|
role: "assistant",
|
||||||
content:
|
content:
|
||||||
|
|
@ -197,6 +200,11 @@ function ChatPage() {
|
||||||
const result = await response.json();
|
const result = await response.json();
|
||||||
console.log("Upload result:", result);
|
console.log("Upload result:", result);
|
||||||
|
|
||||||
|
if (!response.ok) {
|
||||||
|
// Set chat error flag if upload fails
|
||||||
|
setChatError(true);
|
||||||
|
}
|
||||||
|
|
||||||
if (response.status === 201) {
|
if (response.status === 201) {
|
||||||
// New flow: Got task ID, start tracking with centralized system
|
// New flow: Got task ID, start tracking with centralized system
|
||||||
const taskId = result.task_id || result.id;
|
const taskId = result.task_id || result.id;
|
||||||
|
|
@ -255,6 +263,8 @@ function ChatPage() {
|
||||||
}
|
}
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
console.error("Upload failed:", error);
|
console.error("Upload failed:", error);
|
||||||
|
// Set chat error flag to trigger test_completion=true on health checks
|
||||||
|
setChatError(true);
|
||||||
const errorMessage: Message = {
|
const errorMessage: Message = {
|
||||||
role: "assistant",
|
role: "assistant",
|
||||||
content: `❌ Failed to process document. Please try again.`,
|
content: `❌ Failed to process document. Please try again.`,
|
||||||
|
|
@ -858,6 +868,8 @@ function ChatPage() {
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
console.error("Chat failed:", result.error);
|
console.error("Chat failed:", result.error);
|
||||||
|
// Set chat error flag to trigger test_completion=true on health checks
|
||||||
|
setChatError(true);
|
||||||
const errorMessage: Message = {
|
const errorMessage: Message = {
|
||||||
role: "assistant",
|
role: "assistant",
|
||||||
content: "Sorry, I encountered an error. Please try again.",
|
content: "Sorry, I encountered an error. Please try again.",
|
||||||
|
|
@ -867,6 +879,8 @@ function ChatPage() {
|
||||||
}
|
}
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
console.error("Chat error:", error);
|
console.error("Chat error:", error);
|
||||||
|
// Set chat error flag to trigger test_completion=true on health checks
|
||||||
|
setChatError(true);
|
||||||
const errorMessage: Message = {
|
const errorMessage: Message = {
|
||||||
role: "assistant",
|
role: "assistant",
|
||||||
content:
|
content:
|
||||||
|
|
|
||||||
|
|
@ -191,7 +191,7 @@ export function OnboardingContent({
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<StickToBottom
|
<StickToBottom
|
||||||
className="flex h-full flex-1 flex-col"
|
className="flex h-full flex-1 flex-col [&>div]:scrollbar-hide"
|
||||||
resize="smooth"
|
resize="smooth"
|
||||||
initial="instant"
|
initial="instant"
|
||||||
mass={1}
|
mass={1}
|
||||||
|
|
|
||||||
|
|
@ -158,6 +158,16 @@ const OnboardingUpload = ({ onComplete }: OnboardingUploadProps) => {
|
||||||
const errorMessage = error instanceof Error ? error.message : "Upload failed";
|
const errorMessage = error instanceof Error ? error.message : "Upload failed";
|
||||||
console.error("Upload failed", errorMessage);
|
console.error("Upload failed", errorMessage);
|
||||||
|
|
||||||
|
// Dispatch event that chat context can listen to
|
||||||
|
// This avoids circular dependency issues
|
||||||
|
if (typeof window !== "undefined") {
|
||||||
|
window.dispatchEvent(
|
||||||
|
new CustomEvent("ingestionFailed", {
|
||||||
|
detail: { source: "onboarding" },
|
||||||
|
}),
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
// Show error toast notification
|
// Show error toast notification
|
||||||
toast.error("Document upload failed", {
|
toast.error("Document upload failed", {
|
||||||
description: errorMessage,
|
description: errorMessage,
|
||||||
|
|
|
||||||
|
|
@ -1,81 +0,0 @@
|
||||||
"use client";
|
|
||||||
|
|
||||||
import { Suspense, useEffect } from "react";
|
|
||||||
import { useRouter } from "next/navigation";
|
|
||||||
import { DoclingHealthBanner } from "@/components/docling-health-banner";
|
|
||||||
import { ProtectedRoute } from "@/components/protected-route";
|
|
||||||
import { DotPattern } from "@/components/ui/dot-pattern";
|
|
||||||
import { cn } from "@/lib/utils";
|
|
||||||
import { useGetSettingsQuery } from "@/app/api/queries/useGetSettingsQuery";
|
|
||||||
import OnboardingCard from "./_components/onboarding-card";
|
|
||||||
|
|
||||||
function LegacyOnboardingPage() {
|
|
||||||
const router = useRouter();
|
|
||||||
const { data: settingsDb, isLoading: isSettingsLoading } =
|
|
||||||
useGetSettingsQuery();
|
|
||||||
|
|
||||||
// Redirect if already completed onboarding
|
|
||||||
useEffect(() => {
|
|
||||||
if (!isSettingsLoading && settingsDb && settingsDb.edited) {
|
|
||||||
router.push("/");
|
|
||||||
}
|
|
||||||
}, [isSettingsLoading, settingsDb, router]);
|
|
||||||
|
|
||||||
const handleComplete = () => {
|
|
||||||
router.push("/");
|
|
||||||
};
|
|
||||||
|
|
||||||
return (
|
|
||||||
<div className="min-h-dvh w-full flex gap-5 flex-col items-center justify-center bg-background relative p-4">
|
|
||||||
<DotPattern
|
|
||||||
width={24}
|
|
||||||
height={24}
|
|
||||||
cx={1}
|
|
||||||
cy={1}
|
|
||||||
cr={1}
|
|
||||||
className={cn(
|
|
||||||
"[mask-image:linear-gradient(to_bottom,white,transparent,transparent)]",
|
|
||||||
"text-input/70",
|
|
||||||
)}
|
|
||||||
/>
|
|
||||||
|
|
||||||
<DoclingHealthBanner className="absolute top-0 left-0 right-0 w-full z-20" />
|
|
||||||
|
|
||||||
<div className="flex flex-col items-center gap-5 min-h-[550px] w-full z-10">
|
|
||||||
<div className="flex flex-col items-center justify-center gap-4">
|
|
||||||
<h1 className="text-2xl font-medium font-chivo">
|
|
||||||
Connect a model provider
|
|
||||||
</h1>
|
|
||||||
</div>
|
|
||||||
<OnboardingCard onComplete={handleComplete} />
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
function OnboardingRouter() {
|
|
||||||
const updatedOnboarding = process.env.UPDATED_ONBOARDING === "true";
|
|
||||||
const router = useRouter();
|
|
||||||
|
|
||||||
useEffect(() => {
|
|
||||||
if (updatedOnboarding) {
|
|
||||||
router.push("/new-onboarding");
|
|
||||||
}
|
|
||||||
}, [updatedOnboarding, router]);
|
|
||||||
|
|
||||||
if (updatedOnboarding) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
return <LegacyOnboardingPage />;
|
|
||||||
}
|
|
||||||
|
|
||||||
export default function ProtectedOnboardingPage() {
|
|
||||||
return (
|
|
||||||
<ProtectedRoute>
|
|
||||||
<Suspense fallback={<div>Loading onboarding...</div>}>
|
|
||||||
<OnboardingRouter />
|
|
||||||
</Suspense>
|
|
||||||
</ProtectedRoute>
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
@ -238,6 +238,15 @@ export function KnowledgeDropdown() {
|
||||||
await uploadFileUtil(file, replace);
|
await uploadFileUtil(file, replace);
|
||||||
refetchTasks();
|
refetchTasks();
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
|
// Dispatch event that chat context can listen to
|
||||||
|
// This avoids circular dependency issues
|
||||||
|
if (typeof window !== "undefined") {
|
||||||
|
window.dispatchEvent(
|
||||||
|
new CustomEvent("ingestionFailed", {
|
||||||
|
detail: { source: "knowledge-dropdown" },
|
||||||
|
}),
|
||||||
|
);
|
||||||
|
}
|
||||||
toast.error("Upload failed", {
|
toast.error("Upload failed", {
|
||||||
description: error instanceof Error ? error.message : "Unknown error",
|
description: error instanceof Error ? error.message : "Unknown error",
|
||||||
});
|
});
|
||||||
|
|
|
||||||
|
|
@ -6,6 +6,7 @@ import { useProviderHealthQuery } from "@/app/api/queries/useProviderHealthQuery
|
||||||
import type { ModelProvider } from "@/app/settings/_helpers/model-helpers";
|
import type { ModelProvider } from "@/app/settings/_helpers/model-helpers";
|
||||||
import { Banner, BannerIcon, BannerTitle } from "@/components/ui/banner";
|
import { Banner, BannerIcon, BannerTitle } from "@/components/ui/banner";
|
||||||
import { cn } from "@/lib/utils";
|
import { cn } from "@/lib/utils";
|
||||||
|
import { useChat } from "@/contexts/chat-context";
|
||||||
import { Button } from "./ui/button";
|
import { Button } from "./ui/button";
|
||||||
|
|
||||||
interface ProviderHealthBannerProps {
|
interface ProviderHealthBannerProps {
|
||||||
|
|
@ -14,13 +15,16 @@ interface ProviderHealthBannerProps {
|
||||||
|
|
||||||
// Custom hook to check provider health status
|
// Custom hook to check provider health status
|
||||||
export function useProviderHealth() {
|
export function useProviderHealth() {
|
||||||
|
const { hasChatError } = useChat();
|
||||||
const {
|
const {
|
||||||
data: health,
|
data: health,
|
||||||
isLoading,
|
isLoading,
|
||||||
isFetching,
|
isFetching,
|
||||||
error,
|
error,
|
||||||
isError,
|
isError,
|
||||||
} = useProviderHealthQuery();
|
} = useProviderHealthQuery({
|
||||||
|
test_completion: hasChatError, // Use test_completion=true when chat errors occur
|
||||||
|
});
|
||||||
|
|
||||||
const isHealthy = health?.status === "healthy" && !isError;
|
const isHealthy = health?.status === "healthy" && !isError;
|
||||||
// Only consider unhealthy if backend is up but provider validation failed
|
// Only consider unhealthy if backend is up but provider validation failed
|
||||||
|
|
|
||||||
|
|
@ -79,6 +79,8 @@ interface ChatContextType {
|
||||||
conversationFilter: KnowledgeFilter | null;
|
conversationFilter: KnowledgeFilter | null;
|
||||||
// responseId: undefined = use currentConversationId, null = don't save to localStorage
|
// responseId: undefined = use currentConversationId, null = don't save to localStorage
|
||||||
setConversationFilter: (filter: KnowledgeFilter | null, responseId?: string | null) => void;
|
setConversationFilter: (filter: KnowledgeFilter | null, responseId?: string | null) => void;
|
||||||
|
hasChatError: boolean;
|
||||||
|
setChatError: (hasError: boolean) => void;
|
||||||
}
|
}
|
||||||
|
|
||||||
const ChatContext = createContext<ChatContextType | undefined>(undefined);
|
const ChatContext = createContext<ChatContextType | undefined>(undefined);
|
||||||
|
|
@ -108,6 +110,19 @@ export function ChatProvider({ children }: ChatProviderProps) {
|
||||||
const [conversationLoaded, setConversationLoaded] = useState(false);
|
const [conversationLoaded, setConversationLoaded] = useState(false);
|
||||||
const [conversationFilter, setConversationFilterState] =
|
const [conversationFilter, setConversationFilterState] =
|
||||||
useState<KnowledgeFilter | null>(null);
|
useState<KnowledgeFilter | null>(null);
|
||||||
|
const [hasChatError, setChatError] = useState(false);
|
||||||
|
|
||||||
|
// Listen for ingestion failures and set chat error flag
|
||||||
|
useEffect(() => {
|
||||||
|
const handleIngestionFailed = () => {
|
||||||
|
setChatError(true);
|
||||||
|
};
|
||||||
|
|
||||||
|
window.addEventListener("ingestionFailed", handleIngestionFailed);
|
||||||
|
return () => {
|
||||||
|
window.removeEventListener("ingestionFailed", handleIngestionFailed);
|
||||||
|
};
|
||||||
|
}, []);
|
||||||
|
|
||||||
// Debounce refresh requests to prevent excessive reloads
|
// Debounce refresh requests to prevent excessive reloads
|
||||||
const refreshTimeoutRef = useRef<NodeJS.Timeout | null>(null);
|
const refreshTimeoutRef = useRef<NodeJS.Timeout | null>(null);
|
||||||
|
|
@ -358,6 +373,8 @@ export function ChatProvider({ children }: ChatProviderProps) {
|
||||||
setConversationLoaded,
|
setConversationLoaded,
|
||||||
conversationFilter,
|
conversationFilter,
|
||||||
setConversationFilter,
|
setConversationFilter,
|
||||||
|
hasChatError,
|
||||||
|
setChatError,
|
||||||
}),
|
}),
|
||||||
[
|
[
|
||||||
endpoint,
|
endpoint,
|
||||||
|
|
@ -378,6 +395,7 @@ export function ChatProvider({ children }: ChatProviderProps) {
|
||||||
conversationLoaded,
|
conversationLoaded,
|
||||||
conversationFilter,
|
conversationFilter,
|
||||||
setConversationFilter,
|
setConversationFilter,
|
||||||
|
hasChatError,
|
||||||
],
|
],
|
||||||
);
|
);
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -323,6 +323,20 @@ export function TaskProvider({ children }: { children: React.ReactNode }) {
|
||||||
currentTask.error || "Unknown error"
|
currentTask.error || "Unknown error"
|
||||||
}`,
|
}`,
|
||||||
});
|
});
|
||||||
|
|
||||||
|
// Set chat error flag to trigger test_completion=true on health checks
|
||||||
|
// Only for ingestion-related tasks (tasks with files are ingestion tasks)
|
||||||
|
if (currentTask.files && Object.keys(currentTask.files).length > 0) {
|
||||||
|
// Dispatch event that chat context can listen to
|
||||||
|
// This avoids circular dependency issues
|
||||||
|
if (typeof window !== "undefined") {
|
||||||
|
window.dispatchEvent(
|
||||||
|
new CustomEvent("ingestionFailed", {
|
||||||
|
detail: { taskId: currentTask.task_id },
|
||||||
|
}),
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
|
||||||
|
|
@ -14,9 +14,6 @@ const nextConfig: NextConfig = {
|
||||||
eslint: {
|
eslint: {
|
||||||
ignoreDuringBuilds: true,
|
ignoreDuringBuilds: true,
|
||||||
},
|
},
|
||||||
env: {
|
|
||||||
UPDATED_ONBOARDING: process.env.UPDATED_ONBOARDING,
|
|
||||||
},
|
|
||||||
};
|
};
|
||||||
|
|
||||||
export default nextConfig;
|
export default nextConfig;
|
||||||
|
|
|
||||||
|
|
@ -1,10 +1,11 @@
|
||||||
"""Provider health check endpoint."""
|
"""Provider health check endpoint."""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
import httpx
|
import httpx
|
||||||
from starlette.responses import JSONResponse
|
from starlette.responses import JSONResponse
|
||||||
from utils.logging_config import get_logger
|
from utils.logging_config import get_logger
|
||||||
from config.settings import get_openrag_config
|
from config.settings import get_openrag_config
|
||||||
from api.provider_validation import validate_provider_setup, _test_ollama_lightweight_health
|
from api.provider_validation import validate_provider_setup
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
@ -16,6 +17,8 @@ async def check_provider_health(request):
|
||||||
Query parameters:
|
Query parameters:
|
||||||
provider (optional): Provider to check ('openai', 'ollama', 'watsonx', 'anthropic').
|
provider (optional): Provider to check ('openai', 'ollama', 'watsonx', 'anthropic').
|
||||||
If not provided, checks the currently configured provider.
|
If not provided, checks the currently configured provider.
|
||||||
|
test_completion (optional): If 'true', performs full validation with completion/embedding tests (consumes credits).
|
||||||
|
If 'false' or not provided, performs lightweight validation (no/minimal credits consumed).
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
200: Provider is healthy and validated
|
200: Provider is healthy and validated
|
||||||
|
|
@ -26,6 +29,7 @@ async def check_provider_health(request):
|
||||||
# Get optional provider from query params
|
# Get optional provider from query params
|
||||||
query_params = dict(request.query_params)
|
query_params = dict(request.query_params)
|
||||||
check_provider = query_params.get("provider")
|
check_provider = query_params.get("provider")
|
||||||
|
test_completion = query_params.get("test_completion", "false").lower() == "true"
|
||||||
|
|
||||||
# Get current config
|
# Get current config
|
||||||
current_config = get_openrag_config()
|
current_config = get_openrag_config()
|
||||||
|
|
@ -100,6 +104,7 @@ async def check_provider_health(request):
|
||||||
llm_model=llm_model,
|
llm_model=llm_model,
|
||||||
endpoint=endpoint,
|
endpoint=endpoint,
|
||||||
project_id=project_id,
|
project_id=project_id,
|
||||||
|
test_completion=test_completion,
|
||||||
)
|
)
|
||||||
|
|
||||||
return JSONResponse(
|
return JSONResponse(
|
||||||
|
|
@ -124,23 +129,14 @@ async def check_provider_health(request):
|
||||||
|
|
||||||
# Validate LLM provider
|
# Validate LLM provider
|
||||||
try:
|
try:
|
||||||
# For Ollama, use lightweight health check that doesn't block on active requests
|
await validate_provider_setup(
|
||||||
if provider == "ollama":
|
provider=provider,
|
||||||
try:
|
api_key=api_key,
|
||||||
await _test_ollama_lightweight_health(endpoint)
|
llm_model=llm_model,
|
||||||
except Exception as lightweight_error:
|
endpoint=endpoint,
|
||||||
# If lightweight check fails, Ollama is down or misconfigured
|
project_id=project_id,
|
||||||
llm_error = str(lightweight_error)
|
test_completion=test_completion,
|
||||||
logger.error(f"LLM provider ({provider}) lightweight check failed: {llm_error}")
|
)
|
||||||
raise
|
|
||||||
else:
|
|
||||||
await validate_provider_setup(
|
|
||||||
provider=provider,
|
|
||||||
api_key=api_key,
|
|
||||||
llm_model=llm_model,
|
|
||||||
endpoint=endpoint,
|
|
||||||
project_id=project_id,
|
|
||||||
)
|
|
||||||
except httpx.TimeoutException as e:
|
except httpx.TimeoutException as e:
|
||||||
# Timeout means provider is busy, not misconfigured
|
# Timeout means provider is busy, not misconfigured
|
||||||
if provider == "ollama":
|
if provider == "ollama":
|
||||||
|
|
@ -154,24 +150,25 @@ async def check_provider_health(request):
|
||||||
logger.error(f"LLM provider ({provider}) validation failed: {llm_error}")
|
logger.error(f"LLM provider ({provider}) validation failed: {llm_error}")
|
||||||
|
|
||||||
# Validate embedding provider
|
# Validate embedding provider
|
||||||
|
# For WatsonX with test_completion=True, wait 2 seconds between completion and embedding tests
|
||||||
|
if (
|
||||||
|
test_completion
|
||||||
|
and provider == "watsonx"
|
||||||
|
and embedding_provider == "watsonx"
|
||||||
|
and llm_error is None
|
||||||
|
):
|
||||||
|
logger.info("Waiting 2 seconds before WatsonX embedding test (after completion test)")
|
||||||
|
await asyncio.sleep(2)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# For Ollama, use lightweight health check first
|
await validate_provider_setup(
|
||||||
if embedding_provider == "ollama":
|
provider=embedding_provider,
|
||||||
try:
|
api_key=embedding_api_key,
|
||||||
await _test_ollama_lightweight_health(embedding_endpoint)
|
embedding_model=embedding_model,
|
||||||
except Exception as lightweight_error:
|
endpoint=embedding_endpoint,
|
||||||
# If lightweight check fails, Ollama is down or misconfigured
|
project_id=embedding_project_id,
|
||||||
embedding_error = str(lightweight_error)
|
test_completion=test_completion,
|
||||||
logger.error(f"Embedding provider ({embedding_provider}) lightweight check failed: {embedding_error}")
|
)
|
||||||
raise
|
|
||||||
else:
|
|
||||||
await validate_provider_setup(
|
|
||||||
provider=embedding_provider,
|
|
||||||
api_key=embedding_api_key,
|
|
||||||
embedding_model=embedding_model,
|
|
||||||
endpoint=embedding_endpoint,
|
|
||||||
project_id=embedding_project_id,
|
|
||||||
)
|
|
||||||
except httpx.TimeoutException as e:
|
except httpx.TimeoutException as e:
|
||||||
# Timeout means provider is busy, not misconfigured
|
# Timeout means provider is busy, not misconfigured
|
||||||
if embedding_provider == "ollama":
|
if embedding_provider == "ollama":
|
||||||
|
|
|
||||||
|
|
@ -14,17 +14,20 @@ async def validate_provider_setup(
|
||||||
llm_model: str = None,
|
llm_model: str = None,
|
||||||
endpoint: str = None,
|
endpoint: str = None,
|
||||||
project_id: str = None,
|
project_id: str = None,
|
||||||
|
test_completion: bool = False,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Validate provider setup by testing completion with tool calling and embedding.
|
Validate provider setup by testing completion with tool calling and embedding.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
provider: Provider name ('openai', 'watsonx', 'ollama')
|
provider: Provider name ('openai', 'watsonx', 'ollama', 'anthropic')
|
||||||
api_key: API key for the provider (optional for ollama)
|
api_key: API key for the provider (optional for ollama)
|
||||||
embedding_model: Embedding model to test
|
embedding_model: Embedding model to test
|
||||||
llm_model: LLM model to test
|
llm_model: LLM model to test
|
||||||
endpoint: Provider endpoint (required for ollama and watsonx)
|
endpoint: Provider endpoint (required for ollama and watsonx)
|
||||||
project_id: Project ID (required for watsonx)
|
project_id: Project ID (required for watsonx)
|
||||||
|
test_completion: If True, performs full validation with completion/embedding tests (consumes credits).
|
||||||
|
If False, performs lightweight validation (no credits consumed). Default: False.
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
Exception: If validation fails with message "Setup failed, please try again or select a different provider."
|
Exception: If validation fails with message "Setup failed, please try again or select a different provider."
|
||||||
|
|
@ -32,29 +35,37 @@ async def validate_provider_setup(
|
||||||
provider_lower = provider.lower()
|
provider_lower = provider.lower()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
logger.info(f"Starting validation for provider: {provider_lower}")
|
logger.info(f"Starting validation for provider: {provider_lower} (test_completion={test_completion})")
|
||||||
|
|
||||||
if embedding_model:
|
if test_completion:
|
||||||
# Test embedding
|
# Full validation with completion/embedding tests (consumes credits)
|
||||||
await test_embedding(
|
if embedding_model:
|
||||||
|
# Test embedding
|
||||||
|
await test_embedding(
|
||||||
|
provider=provider_lower,
|
||||||
|
api_key=api_key,
|
||||||
|
embedding_model=embedding_model,
|
||||||
|
endpoint=endpoint,
|
||||||
|
project_id=project_id,
|
||||||
|
)
|
||||||
|
elif llm_model:
|
||||||
|
# Test completion with tool calling
|
||||||
|
await test_completion_with_tools(
|
||||||
|
provider=provider_lower,
|
||||||
|
api_key=api_key,
|
||||||
|
llm_model=llm_model,
|
||||||
|
endpoint=endpoint,
|
||||||
|
project_id=project_id,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# Lightweight validation (no credits consumed)
|
||||||
|
await test_lightweight_health(
|
||||||
provider=provider_lower,
|
provider=provider_lower,
|
||||||
api_key=api_key,
|
api_key=api_key,
|
||||||
embedding_model=embedding_model,
|
|
||||||
endpoint=endpoint,
|
endpoint=endpoint,
|
||||||
project_id=project_id,
|
project_id=project_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
elif llm_model:
|
|
||||||
# Test completion with tool calling
|
|
||||||
await test_completion_with_tools(
|
|
||||||
provider=provider_lower,
|
|
||||||
api_key=api_key,
|
|
||||||
llm_model=llm_model,
|
|
||||||
endpoint=endpoint,
|
|
||||||
project_id=project_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
logger.info(f"Validation successful for provider: {provider_lower}")
|
logger.info(f"Validation successful for provider: {provider_lower}")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|
@ -62,6 +73,26 @@ async def validate_provider_setup(
|
||||||
raise Exception("Setup failed, please try again or select a different provider.")
|
raise Exception("Setup failed, please try again or select a different provider.")
|
||||||
|
|
||||||
|
|
||||||
|
async def test_lightweight_health(
|
||||||
|
provider: str,
|
||||||
|
api_key: str = None,
|
||||||
|
endpoint: str = None,
|
||||||
|
project_id: str = None,
|
||||||
|
) -> None:
|
||||||
|
"""Test provider health with lightweight check (no credits consumed)."""
|
||||||
|
|
||||||
|
if provider == "openai":
|
||||||
|
await _test_openai_lightweight_health(api_key)
|
||||||
|
elif provider == "watsonx":
|
||||||
|
await _test_watsonx_lightweight_health(api_key, endpoint, project_id)
|
||||||
|
elif provider == "ollama":
|
||||||
|
await _test_ollama_lightweight_health(endpoint)
|
||||||
|
elif provider == "anthropic":
|
||||||
|
await _test_anthropic_lightweight_health(api_key)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unknown provider: {provider}")
|
||||||
|
|
||||||
|
|
||||||
async def test_completion_with_tools(
|
async def test_completion_with_tools(
|
||||||
provider: str,
|
provider: str,
|
||||||
api_key: str = None,
|
api_key: str = None,
|
||||||
|
|
@ -103,6 +134,40 @@ async def test_embedding(
|
||||||
|
|
||||||
|
|
||||||
# OpenAI validation functions
|
# OpenAI validation functions
|
||||||
|
async def _test_openai_lightweight_health(api_key: str) -> None:
|
||||||
|
"""Test OpenAI API key validity with lightweight check.
|
||||||
|
|
||||||
|
Only checks if the API key is valid without consuming credits.
|
||||||
|
Uses the /v1/models endpoint which doesn't consume credits.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
headers = {
|
||||||
|
"Authorization": f"Bearer {api_key}",
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
}
|
||||||
|
|
||||||
|
async with httpx.AsyncClient() as client:
|
||||||
|
# Use /v1/models endpoint which validates the key without consuming credits
|
||||||
|
response = await client.get(
|
||||||
|
"https://api.openai.com/v1/models",
|
||||||
|
headers=headers,
|
||||||
|
timeout=10.0, # Short timeout for lightweight check
|
||||||
|
)
|
||||||
|
|
||||||
|
if response.status_code != 200:
|
||||||
|
logger.error(f"OpenAI lightweight health check failed: {response.status_code}")
|
||||||
|
raise Exception(f"OpenAI API key validation failed: {response.status_code}")
|
||||||
|
|
||||||
|
logger.info("OpenAI lightweight health check passed")
|
||||||
|
|
||||||
|
except httpx.TimeoutException:
|
||||||
|
logger.error("OpenAI lightweight health check timed out")
|
||||||
|
raise Exception("OpenAI API request timed out")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"OpenAI lightweight health check failed: {str(e)}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
async def _test_openai_completion_with_tools(api_key: str, llm_model: str) -> None:
|
async def _test_openai_completion_with_tools(api_key: str, llm_model: str) -> None:
|
||||||
"""Test OpenAI completion with tool calling."""
|
"""Test OpenAI completion with tool calling."""
|
||||||
try:
|
try:
|
||||||
|
|
@ -213,6 +278,45 @@ async def _test_openai_embedding(api_key: str, embedding_model: str) -> None:
|
||||||
|
|
||||||
|
|
||||||
# IBM Watson validation functions
|
# IBM Watson validation functions
|
||||||
|
async def _test_watsonx_lightweight_health(
|
||||||
|
api_key: str, endpoint: str, project_id: str
|
||||||
|
) -> None:
|
||||||
|
"""Test WatsonX API key validity with lightweight check.
|
||||||
|
|
||||||
|
Only checks if the API key is valid by getting a bearer token.
|
||||||
|
Does not consume credits by avoiding model inference requests.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# Get bearer token from IBM IAM - this validates the API key without consuming credits
|
||||||
|
async with httpx.AsyncClient() as client:
|
||||||
|
token_response = await client.post(
|
||||||
|
"https://iam.cloud.ibm.com/identity/token",
|
||||||
|
headers={"Content-Type": "application/x-www-form-urlencoded"},
|
||||||
|
data={
|
||||||
|
"grant_type": "urn:ibm:params:oauth:grant-type:apikey",
|
||||||
|
"apikey": api_key,
|
||||||
|
},
|
||||||
|
timeout=10.0, # Short timeout for lightweight check
|
||||||
|
)
|
||||||
|
|
||||||
|
if token_response.status_code != 200:
|
||||||
|
logger.error(f"IBM IAM token request failed: {token_response.status_code}")
|
||||||
|
raise Exception("Failed to authenticate with IBM Watson - invalid API key")
|
||||||
|
|
||||||
|
bearer_token = token_response.json().get("access_token")
|
||||||
|
if not bearer_token:
|
||||||
|
raise Exception("No access token received from IBM")
|
||||||
|
|
||||||
|
logger.info("WatsonX lightweight health check passed - API key is valid")
|
||||||
|
|
||||||
|
except httpx.TimeoutException:
|
||||||
|
logger.error("WatsonX lightweight health check timed out")
|
||||||
|
raise Exception("WatsonX API request timed out")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"WatsonX lightweight health check failed: {str(e)}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
async def _test_watsonx_completion_with_tools(
|
async def _test_watsonx_completion_with_tools(
|
||||||
api_key: str, llm_model: str, endpoint: str, project_id: str
|
api_key: str, llm_model: str, endpoint: str, project_id: str
|
||||||
) -> None:
|
) -> None:
|
||||||
|
|
@ -483,6 +587,48 @@ async def _test_ollama_embedding(embedding_model: str, endpoint: str) -> None:
|
||||||
|
|
||||||
|
|
||||||
# Anthropic validation functions
|
# Anthropic validation functions
|
||||||
|
async def _test_anthropic_lightweight_health(api_key: str) -> None:
|
||||||
|
"""Test Anthropic API key validity with lightweight check.
|
||||||
|
|
||||||
|
Only checks if the API key is valid without consuming credits.
|
||||||
|
Uses a minimal messages request with max_tokens=1 to validate the key.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
headers = {
|
||||||
|
"x-api-key": api_key,
|
||||||
|
"anthropic-version": "2023-06-01",
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
}
|
||||||
|
|
||||||
|
# Minimal validation request - uses cheapest model with minimal tokens
|
||||||
|
payload = {
|
||||||
|
"model": "claude-3-5-haiku-latest", # Cheapest model
|
||||||
|
"max_tokens": 1, # Minimum tokens to validate key
|
||||||
|
"messages": [{"role": "user", "content": "test"}],
|
||||||
|
}
|
||||||
|
|
||||||
|
async with httpx.AsyncClient() as client:
|
||||||
|
response = await client.post(
|
||||||
|
"https://api.anthropic.com/v1/messages",
|
||||||
|
headers=headers,
|
||||||
|
json=payload,
|
||||||
|
timeout=10.0, # Short timeout for lightweight check
|
||||||
|
)
|
||||||
|
|
||||||
|
if response.status_code != 200:
|
||||||
|
logger.error(f"Anthropic lightweight health check failed: {response.status_code}")
|
||||||
|
raise Exception(f"Anthropic API key validation failed: {response.status_code}")
|
||||||
|
|
||||||
|
logger.info("Anthropic lightweight health check passed")
|
||||||
|
|
||||||
|
except httpx.TimeoutException:
|
||||||
|
logger.error("Anthropic lightweight health check timed out")
|
||||||
|
raise Exception("Anthropic API request timed out")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Anthropic lightweight health check failed: {str(e)}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
async def _test_anthropic_completion_with_tools(api_key: str, llm_model: str) -> None:
|
async def _test_anthropic_completion_with_tools(api_key: str, llm_model: str) -> None:
|
||||||
"""Test Anthropic completion with tool calling."""
|
"""Test Anthropic completion with tool calling."""
|
||||||
try:
|
try:
|
||||||
|
|
|
||||||
|
|
@ -897,6 +897,7 @@ async def onboarding(request, flows_service, session_manager=None):
|
||||||
)
|
)
|
||||||
|
|
||||||
# Validate provider setup before initializing OpenSearch index
|
# Validate provider setup before initializing OpenSearch index
|
||||||
|
# Use lightweight validation (test_completion=False) to avoid consuming credits during onboarding
|
||||||
try:
|
try:
|
||||||
from api.provider_validation import validate_provider_setup
|
from api.provider_validation import validate_provider_setup
|
||||||
|
|
||||||
|
|
@ -905,13 +906,14 @@ async def onboarding(request, flows_service, session_manager=None):
|
||||||
llm_provider = current_config.agent.llm_provider.lower()
|
llm_provider = current_config.agent.llm_provider.lower()
|
||||||
llm_provider_config = current_config.get_llm_provider_config()
|
llm_provider_config = current_config.get_llm_provider_config()
|
||||||
|
|
||||||
logger.info(f"Validating LLM provider setup for {llm_provider}")
|
logger.info(f"Validating LLM provider setup for {llm_provider} (lightweight)")
|
||||||
await validate_provider_setup(
|
await validate_provider_setup(
|
||||||
provider=llm_provider,
|
provider=llm_provider,
|
||||||
api_key=getattr(llm_provider_config, "api_key", None),
|
api_key=getattr(llm_provider_config, "api_key", None),
|
||||||
llm_model=current_config.agent.llm_model,
|
llm_model=current_config.agent.llm_model,
|
||||||
endpoint=getattr(llm_provider_config, "endpoint", None),
|
endpoint=getattr(llm_provider_config, "endpoint", None),
|
||||||
project_id=getattr(llm_provider_config, "project_id", None),
|
project_id=getattr(llm_provider_config, "project_id", None),
|
||||||
|
test_completion=False, # Lightweight validation - no credits consumed
|
||||||
)
|
)
|
||||||
logger.info(f"LLM provider setup validation completed successfully for {llm_provider}")
|
logger.info(f"LLM provider setup validation completed successfully for {llm_provider}")
|
||||||
|
|
||||||
|
|
@ -920,13 +922,14 @@ async def onboarding(request, flows_service, session_manager=None):
|
||||||
embedding_provider = current_config.knowledge.embedding_provider.lower()
|
embedding_provider = current_config.knowledge.embedding_provider.lower()
|
||||||
embedding_provider_config = current_config.get_embedding_provider_config()
|
embedding_provider_config = current_config.get_embedding_provider_config()
|
||||||
|
|
||||||
logger.info(f"Validating embedding provider setup for {embedding_provider}")
|
logger.info(f"Validating embedding provider setup for {embedding_provider} (lightweight)")
|
||||||
await validate_provider_setup(
|
await validate_provider_setup(
|
||||||
provider=embedding_provider,
|
provider=embedding_provider,
|
||||||
api_key=getattr(embedding_provider_config, "api_key", None),
|
api_key=getattr(embedding_provider_config, "api_key", None),
|
||||||
embedding_model=current_config.knowledge.embedding_model,
|
embedding_model=current_config.knowledge.embedding_model,
|
||||||
endpoint=getattr(embedding_provider_config, "endpoint", None),
|
endpoint=getattr(embedding_provider_config, "endpoint", None),
|
||||||
project_id=getattr(embedding_provider_config, "project_id", None),
|
project_id=getattr(embedding_provider_config, "project_id", None),
|
||||||
|
test_completion=False, # Lightweight validation - no credits consumed
|
||||||
)
|
)
|
||||||
logger.info(f"Embedding provider setup validation completed successfully for {embedding_provider}")
|
logger.info(f"Embedding provider setup validation completed successfully for {embedding_provider}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|
|
||||||
|
|
@ -50,7 +50,7 @@ class ModelsService:
|
||||||
self.session_manager = None
|
self.session_manager = None
|
||||||
|
|
||||||
async def get_openai_models(self, api_key: str) -> Dict[str, List[Dict[str, str]]]:
|
async def get_openai_models(self, api_key: str) -> Dict[str, List[Dict[str, str]]]:
|
||||||
"""Fetch available models from OpenAI API"""
|
"""Fetch available models from OpenAI API with lightweight validation"""
|
||||||
try:
|
try:
|
||||||
headers = {
|
headers = {
|
||||||
"Authorization": f"Bearer {api_key}",
|
"Authorization": f"Bearer {api_key}",
|
||||||
|
|
@ -58,6 +58,8 @@ class ModelsService:
|
||||||
}
|
}
|
||||||
|
|
||||||
async with httpx.AsyncClient() as client:
|
async with httpx.AsyncClient() as client:
|
||||||
|
# Lightweight validation: just check if API key is valid
|
||||||
|
# This doesn't consume credits, only validates the key
|
||||||
response = await client.get(
|
response = await client.get(
|
||||||
"https://api.openai.com/v1/models", headers=headers, timeout=10.0
|
"https://api.openai.com/v1/models", headers=headers, timeout=10.0
|
||||||
)
|
)
|
||||||
|
|
@ -101,6 +103,7 @@ class ModelsService:
|
||||||
key=lambda x: (not x.get("default", False), x["value"])
|
key=lambda x: (not x.get("default", False), x["value"])
|
||||||
)
|
)
|
||||||
|
|
||||||
|
logger.info("OpenAI API key validated successfully without consuming credits")
|
||||||
return {
|
return {
|
||||||
"language_models": language_models,
|
"language_models": language_models,
|
||||||
"embedding_models": embedding_models,
|
"embedding_models": embedding_models,
|
||||||
|
|
@ -389,38 +392,12 @@ class ModelsService:
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
# Validate credentials with the first available LLM model
|
# Lightweight validation: API key is already validated by successfully getting bearer token
|
||||||
if language_models:
|
# No need to make a generation request that consumes credits
|
||||||
first_llm_model = language_models[0]["value"]
|
if bearer_token:
|
||||||
|
logger.info("IBM Watson API key validated successfully without consuming credits")
|
||||||
async with httpx.AsyncClient() as client:
|
|
||||||
validation_url = f"{watson_endpoint}/ml/v1/text/generation"
|
|
||||||
validation_params = {"version": "2024-09-16"}
|
|
||||||
validation_payload = {
|
|
||||||
"input": "test",
|
|
||||||
"model_id": first_llm_model,
|
|
||||||
"project_id": project_id,
|
|
||||||
"parameters": {
|
|
||||||
"max_new_tokens": 1,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
validation_response = await client.post(
|
|
||||||
validation_url,
|
|
||||||
headers=headers,
|
|
||||||
params=validation_params,
|
|
||||||
json=validation_payload,
|
|
||||||
timeout=10.0,
|
|
||||||
)
|
|
||||||
|
|
||||||
if validation_response.status_code != 200:
|
|
||||||
raise Exception(
|
|
||||||
f"Invalid credentials or endpoint: {validation_response.status_code} - {validation_response.text}"
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.info(f"IBM Watson credentials validated successfully using model: {first_llm_model}")
|
|
||||||
else:
|
else:
|
||||||
logger.warning("No language models available to validate credentials")
|
logger.warning("No bearer token available - API key validation may have failed")
|
||||||
|
|
||||||
if not language_models and not embedding_models:
|
if not language_models and not embedding_models:
|
||||||
raise Exception("No IBM models retrieved from API")
|
raise Exception("No IBM models retrieved from API")
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue