fix: use keyword-type field name for search
This commit is contained in:
parent
a424bb422a
commit
a7c5a9f8f3
8 changed files with 94 additions and 44 deletions
|
|
@ -2,8 +2,9 @@ from __future__ import annotations
|
||||||
|
|
||||||
import copy
|
import copy
|
||||||
import json
|
import json
|
||||||
|
import time
|
||||||
import uuid
|
import uuid
|
||||||
from typing import Any
|
from typing import Any, List, Optional
|
||||||
|
|
||||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||||
|
|
||||||
|
|
@ -445,6 +446,13 @@ class OpenSearchVectorStoreComponent(LCVectorStoreComponent):
|
||||||
logger.info(f"Added/updated embedding field mapping: {field_name}")
|
logger.info(f"Added/updated embedding field mapping: {field_name}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"Could not add embedding field mapping for {field_name}: {e}")
|
logger.warning(f"Could not add embedding field mapping for {field_name}: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
properties = self._get_index_properties(client)
|
||||||
|
if not self._is_knn_vector_field(properties, field_name):
|
||||||
|
raise ValueError(
|
||||||
|
f"Field '{field_name}' is not mapped as knn_vector. Current mapping: {properties.get(field_name)}"
|
||||||
|
)
|
||||||
|
|
||||||
def _validate_aoss_with_engines(self, *, is_aoss: bool, engine: str) -> None:
|
def _validate_aoss_with_engines(self, *, is_aoss: bool, engine: str) -> None:
|
||||||
"""Validate engine compatibility with Amazon OpenSearch Serverless (AOSS).
|
"""Validate engine compatibility with Amazon OpenSearch Serverless (AOSS).
|
||||||
|
|
@ -664,8 +672,8 @@ class OpenSearchVectorStoreComponent(LCVectorStoreComponent):
|
||||||
def embed_chunk(chunk_text: str) -> list[float]:
|
def embed_chunk(chunk_text: str) -> list[float]:
|
||||||
return self.embedding.embed_documents([chunk_text])[0]
|
return self.embedding.embed_documents([chunk_text])[0]
|
||||||
|
|
||||||
vectors: list[list[float]] | None = None
|
vectors: Optional[List[List[float]]] = None
|
||||||
last_exception: Exception | None = None
|
last_exception: Optional[Exception] = None
|
||||||
delay = 1.0
|
delay = 1.0
|
||||||
attempts = 0
|
attempts = 0
|
||||||
|
|
||||||
|
|
@ -864,7 +872,7 @@ class OpenSearchVectorStoreComponent(LCVectorStoreComponent):
|
||||||
"aggs": {
|
"aggs": {
|
||||||
"embedding_models": {
|
"embedding_models": {
|
||||||
"terms": {
|
"terms": {
|
||||||
"field": "embedding_model.keyword",
|
"field": "embedding_model",
|
||||||
"size": 10
|
"size": 10
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -879,7 +887,11 @@ class OpenSearchVectorStoreComponent(LCVectorStoreComponent):
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
result = client.search(index=self.index_name, body=agg_query)
|
result = client.search(
|
||||||
|
index=self.index_name,
|
||||||
|
body=agg_query,
|
||||||
|
params={"terminate_after": 0},
|
||||||
|
)
|
||||||
buckets = result.get("aggregations", {}).get("embedding_models", {}).get("buckets", [])
|
buckets = result.get("aggregations", {}).get("embedding_models", {}).get("buckets", [])
|
||||||
models = [b["key"] for b in buckets if b["key"]]
|
models = [b["key"] for b in buckets if b["key"]]
|
||||||
|
|
||||||
|
|
@ -1109,7 +1121,7 @@ class OpenSearchVectorStoreComponent(LCVectorStoreComponent):
|
||||||
"data_sources": {"terms": {"field": "filename", "size": 20}},
|
"data_sources": {"terms": {"field": "filename", "size": 20}},
|
||||||
"document_types": {"terms": {"field": "mimetype", "size": 10}},
|
"document_types": {"terms": {"field": "mimetype", "size": 10}},
|
||||||
"owners": {"terms": {"field": "owner", "size": 10}},
|
"owners": {"terms": {"field": "owner", "size": 10}},
|
||||||
"embedding_models": {"terms": {"field": "embedding_model.keyword", "size": 10}},
|
"embedding_models": {"terms": {"field": "embedding_model", "size": 10}},
|
||||||
},
|
},
|
||||||
"_source": [
|
"_source": [
|
||||||
"filename",
|
"filename",
|
||||||
|
|
@ -1133,7 +1145,9 @@ class OpenSearchVectorStoreComponent(LCVectorStoreComponent):
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
resp = client.search(index=self.index_name, body=body)
|
resp = client.search(
|
||||||
|
index=self.index_name, body=body, params={"terminate_after": 0}
|
||||||
|
)
|
||||||
except RequestError as e:
|
except RequestError as e:
|
||||||
error_message = str(e)
|
error_message = str(e)
|
||||||
lowered = error_message.lower()
|
lowered = error_message.lower()
|
||||||
|
|
@ -1147,7 +1161,11 @@ class OpenSearchVectorStoreComponent(LCVectorStoreComponent):
|
||||||
fallback_body["query"]["bool"]["should"][0]["dis_max"]["queries"] = knn_queries_without_candidates
|
fallback_body["query"]["bool"]["should"][0]["dis_max"]["queries"] = knn_queries_without_candidates
|
||||||
except (KeyError, IndexError, TypeError) as inner_err:
|
except (KeyError, IndexError, TypeError) as inner_err:
|
||||||
raise e from inner_err
|
raise e from inner_err
|
||||||
resp = client.search(index=self.index_name, body=fallback_body)
|
resp = client.search(
|
||||||
|
index=self.index_name,
|
||||||
|
body=fallback_body,
|
||||||
|
params={"terminate_after": 0},
|
||||||
|
)
|
||||||
elif "knn_vector" in lowered or ("field" in lowered and "knn" in lowered):
|
elif "knn_vector" in lowered or ("field" in lowered and "knn" in lowered):
|
||||||
fallback_vector = next(iter(query_embeddings.values()), None)
|
fallback_vector = next(iter(query_embeddings.values()), None)
|
||||||
if fallback_vector is None:
|
if fallback_vector is None:
|
||||||
|
|
@ -1170,7 +1188,11 @@ class OpenSearchVectorStoreComponent(LCVectorStoreComponent):
|
||||||
if use_num_candidates:
|
if use_num_candidates:
|
||||||
knn_fallback["knn"][fallback_field]["num_candidates"] = num_candidates
|
knn_fallback["knn"][fallback_field]["num_candidates"] = num_candidates
|
||||||
fallback_body["query"]["bool"]["should"][0]["dis_max"]["queries"] = [knn_fallback]
|
fallback_body["query"]["bool"]["should"][0]["dis_max"]["queries"] = [knn_fallback]
|
||||||
resp = client.search(index=self.index_name, body=fallback_body)
|
resp = client.search(
|
||||||
|
index=self.index_name,
|
||||||
|
body=fallback_body,
|
||||||
|
params={"terminate_after": 0},
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
raise
|
raise
|
||||||
hits = resp.get("hits", {}).get("hits", [])
|
hits = resp.get("hits", {}).get("hits", [])
|
||||||
|
|
|
||||||
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
|
|
@ -63,12 +63,11 @@ class SearchService:
|
||||||
query_embeddings = {}
|
query_embeddings = {}
|
||||||
available_models = []
|
available_models = []
|
||||||
|
|
||||||
if not is_wildcard_match_all:
|
opensearch_client = self.session_manager.get_user_opensearch_client(
|
||||||
# First, detect which embedding models exist in the corpus
|
user_id, jwt_token
|
||||||
opensearch_client = self.session_manager.get_user_opensearch_client(
|
)
|
||||||
user_id, jwt_token
|
|
||||||
)
|
|
||||||
|
|
||||||
|
if not is_wildcard_match_all:
|
||||||
# Build filter clauses first so we can use them in model detection
|
# Build filter clauses first so we can use them in model detection
|
||||||
filter_clauses = []
|
filter_clauses = []
|
||||||
if filters:
|
if filters:
|
||||||
|
|
@ -104,7 +103,7 @@ class SearchService:
|
||||||
"aggs": {
|
"aggs": {
|
||||||
"embedding_models": {
|
"embedding_models": {
|
||||||
"terms": {
|
"terms": {
|
||||||
"field": "embedding_model.keyword",
|
"field": "embedding_model",
|
||||||
"size": 10
|
"size": 10
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -119,7 +118,9 @@ class SearchService:
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
agg_result = await opensearch_client.search(index=INDEX_NAME, body=agg_query)
|
agg_result = await opensearch_client.search(
|
||||||
|
index=INDEX_NAME, body=agg_query, params={"terminate_after": 0}
|
||||||
|
)
|
||||||
buckets = agg_result.get("aggregations", {}).get("embedding_models", {}).get("buckets", [])
|
buckets = agg_result.get("aggregations", {}).get("embedding_models", {}).get("buckets", [])
|
||||||
available_models = [b["key"] for b in buckets if b["key"]]
|
available_models = [b["key"] for b in buckets if b["key"]]
|
||||||
|
|
||||||
|
|
@ -306,7 +307,7 @@ class SearchService:
|
||||||
"document_types": {"terms": {"field": "mimetype", "size": 10}},
|
"document_types": {"terms": {"field": "mimetype", "size": 10}},
|
||||||
"owners": {"terms": {"field": "owner_name.keyword", "size": 10}},
|
"owners": {"terms": {"field": "owner_name.keyword", "size": 10}},
|
||||||
"connector_types": {"terms": {"field": "connector_type", "size": 10}},
|
"connector_types": {"terms": {"field": "connector_type", "size": 10}},
|
||||||
"embedding_models": {"terms": {"field": "embedding_model.keyword", "size": 10}},
|
"embedding_models": {"terms": {"field": "embedding_model", "size": 10}},
|
||||||
},
|
},
|
||||||
"_source": [
|
"_source": [
|
||||||
"filename",
|
"filename",
|
||||||
|
|
@ -365,8 +366,12 @@ class SearchService:
|
||||||
|
|
||||||
from opensearchpy.exceptions import RequestError
|
from opensearchpy.exceptions import RequestError
|
||||||
|
|
||||||
|
search_params = {"terminate_after": 0}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
results = await opensearch_client.search(index=INDEX_NAME, body=search_body)
|
results = await opensearch_client.search(
|
||||||
|
index=INDEX_NAME, body=search_body, params=search_params
|
||||||
|
)
|
||||||
except RequestError as e:
|
except RequestError as e:
|
||||||
error_message = str(e)
|
error_message = str(e)
|
||||||
if (
|
if (
|
||||||
|
|
@ -378,7 +383,9 @@ class SearchService:
|
||||||
)
|
)
|
||||||
try:
|
try:
|
||||||
results = await opensearch_client.search(
|
results = await opensearch_client.search(
|
||||||
index=INDEX_NAME, body=fallback_search_body
|
index=INDEX_NAME,
|
||||||
|
body=fallback_search_body,
|
||||||
|
params=search_params,
|
||||||
)
|
)
|
||||||
except RequestError as retry_error:
|
except RequestError as retry_error:
|
||||||
logger.error(
|
logger.error(
|
||||||
|
|
|
||||||
|
|
@ -7,6 +7,8 @@ This module provides helpers for:
|
||||||
- Ensuring embedding fields exist in the OpenSearch index
|
- Ensuring embedding fields exist in the OpenSearch index
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
from typing import Dict, Any
|
||||||
|
|
||||||
from utils.logging_config import get_logger
|
from utils.logging_config import get_logger
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|
@ -100,6 +102,28 @@ async def ensure_embedding_field_exists(
|
||||||
dimensions=dimensions,
|
dimensions=dimensions,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
async def _get_field_definition() -> Dict[str, Any]:
|
||||||
|
try:
|
||||||
|
mapping = await opensearch_client.indices.get_mapping(index=index_name)
|
||||||
|
except Exception as e:
|
||||||
|
logger.debug(
|
||||||
|
"Failed to fetch mapping before ensuring embedding field",
|
||||||
|
index=index_name,
|
||||||
|
error=str(e),
|
||||||
|
)
|
||||||
|
return {}
|
||||||
|
|
||||||
|
properties = mapping.get(index_name, {}).get("mappings", {}).get("properties", {})
|
||||||
|
return properties.get(field_name, {}) if isinstance(properties, dict) else {}
|
||||||
|
|
||||||
|
existing_definition = await _get_field_definition()
|
||||||
|
if existing_definition:
|
||||||
|
if existing_definition.get("type") != "knn_vector":
|
||||||
|
raise RuntimeError(
|
||||||
|
f"Field '{field_name}' already exists with incompatible type '{existing_definition.get('type')}'"
|
||||||
|
)
|
||||||
|
return field_name
|
||||||
|
|
||||||
# Define the field mapping for both the vector field and the tracking field
|
# Define the field mapping for both the vector field and the tracking field
|
||||||
mapping = {
|
mapping = {
|
||||||
"properties": {
|
"properties": {
|
||||||
|
|
@ -136,22 +160,19 @@ async def ensure_embedding_field_exists(
|
||||||
model_name=model_name,
|
model_name=model_name,
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
error_msg = str(e).lower()
|
logger.error(
|
||||||
# These are expected/safe errors when field already exists
|
"Failed to add embedding field mapping",
|
||||||
if "already" in error_msg or "exists" in error_msg or "mapper_parsing_exception" in error_msg:
|
field_name=field_name,
|
||||||
logger.debug(
|
model_name=model_name,
|
||||||
"Embedding field already exists (expected)",
|
error=str(e),
|
||||||
field_name=field_name,
|
)
|
||||||
model_name=model_name,
|
raise
|
||||||
)
|
|
||||||
else:
|
# Verify mapping was applied correctly
|
||||||
logger.error(
|
new_definition = await _get_field_definition()
|
||||||
"Failed to ensure embedding field exists",
|
if new_definition.get("type") != "knn_vector":
|
||||||
field_name=field_name,
|
raise RuntimeError(
|
||||||
model_name=model_name,
|
f"Failed to ensure '{field_name}' is mapped as knn_vector. Current definition: {new_definition}"
|
||||||
error=str(e),
|
)
|
||||||
)
|
|
||||||
# Don't raise - field might already exist with different params
|
|
||||||
# Better to proceed and let indexing fail if there's a real issue
|
|
||||||
|
|
||||||
return field_name
|
return field_name
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue