@pytest.mark.asyncio
async def test_search_multi_embedding_models(
    tmp_path: Path, monkeypatch: pytest.MonkeyPatch
):
    """Ensure /search fans out across multiple embedding models when present.

    Flow:
      1. Boot the app with Langflow ingest and startup ingest disabled.
      2. Select ``text-embedding-3-small`` through ``/settings`` and upload a
         document.
      3. Switch to ``text-embedding-3-large`` and upload a second document.
      4. Poll ``/search`` until the ``embedding_models`` aggregation reports
         both models, then check that the returned results carry both models.

    Args:
        tmp_path: Per-test temporary directory used to stage uploaded files.
        monkeypatch: Used to set environment variables so they are restored
            automatically when the test finishes.
    """
    # Go through monkeypatch instead of writing to os.environ directly so the
    # values do not leak into tests that run after this one.
    for env_name, env_value in (
        ("DISABLE_INGEST_WITH_LANGFLOW", "true"),
        ("DISABLE_STARTUP_INGEST", "true"),
        ("GOOGLE_OAUTH_CLIENT_ID", ""),
        ("GOOGLE_OAUTH_CLIENT_SECRET", ""),
    ):
        monkeypatch.setenv(env_name, env_value)

    import sys

    # Drop cached modules so the imports below execute them again
    # (presumably so they pick up the environment set above).
    for mod in [
        "src.api.router",
        "api.router",
        "src.api.connector_router",
        "api.connector_router",
        "src.config.settings",
        "config.settings",
        "src.auth_middleware",
        "auth_middleware",
        "src.main",
        "services.search_service",
        "src.services.search_service",
    ]:
        sys.modules.pop(mod, None)

    from src.main import _ensure_opensearch_index, create_app, startup_tasks
    from src.config.settings import clients, INDEX_NAME
    from src.config.config_manager import config_manager

    await clients.initialize()
    # Everything after initialize() lives inside try/finally so the clients
    # are closed even when app creation or startup fails.
    try:
        # Best-effort reset of the index; a failure here (presumably because
        # the index does not exist yet) must not fail the test.
        try:
            await clients.opensearch.indices.delete(index=INDEX_NAME)
            await asyncio.sleep(1)
        except Exception:
            pass

        app = await create_app()
        await startup_tasks(app.state.services)

        # Mark configuration as edited so /settings accepts updates
        config = config_manager.get_config()
        config.edited = True

        await _ensure_opensearch_index()

        transport = httpx.ASGITransport(app=app)

        async with httpx.AsyncClient(
            transport=transport, base_url="http://testserver"
        ) as client:
            await wait_for_service_ready(client)

            async def _upload_doc(name: str, text: str) -> None:
                """Write ``text`` to ``tmp_path / name`` and POST it to /upload."""
                file_path = tmp_path / name
                file_path.write_text(text, encoding="utf-8")
                files = {
                    "file": (
                        name,
                        file_path.read_bytes(),
                        "text/markdown",
                    )
                }
                resp = await client.post("/upload", files=files)
                assert resp.status_code == 201, resp.text

            async def _wait_for_models(
                expected_models: set[str],
                query: str = "physics",
                timeout: float = 30.0,
            ):
                """Poll /search until the aggregation lists ``expected_models``.

                Returns the JSON payload of the first response whose
                ``embedding_models`` buckets contain every expected model.
                Raises AssertionError, including the last payload seen, when
                ``timeout`` seconds elapse first.
                """
                # get_running_loop() is the supported way to reach the loop
                # from inside a coroutine.
                loop = asyncio.get_running_loop()
                deadline = loop.time() + timeout
                last_payload = None
                while loop.time() < deadline:
                    resp = await client.post(
                        "/search",
                        json={"query": query, "limit": 10, "scoreThreshold": 0},
                    )
                    if resp.status_code != 200:
                        last_payload = resp.text
                        await asyncio.sleep(0.5)
                        continue
                    payload = resp.json()
                    buckets = (
                        payload.get("aggregations", {})
                        .get("embedding_models", {})
                        .get("buckets", [])
                    )
                    models = {b.get("key") for b in buckets if b.get("key")}
                    if expected_models <= models:
                        return payload
                    last_payload = payload
                    await asyncio.sleep(0.5)
                raise AssertionError(
                    f"Embedding models not detected. Last payload: {last_payload}"
                )

            # Start with explicit small embedding model
            resp = await client.post(
                "/settings",
                json={"embedding_model": "text-embedding-3-small"},
            )
            assert resp.status_code == 200, resp.text

            # Ingest first document (small model)
            await _upload_doc(
                "doc-small.md", "Physics basics and fundamental principles."
            )
            payload_small = await _wait_for_models({"text-embedding-3-small"})
            result_models_small = {
                r.get("embedding_model")
                for r in payload_small.get("results", [])
                if r.get("embedding_model")
            }
            # Results may omit the model field entirely; only fail when models
            # are reported and the small one is not among them.
            assert (
                "text-embedding-3-small" in result_models_small
                or not result_models_small
            )

            # Update embedding model via settings
            resp = await client.post(
                "/settings",
                json={"embedding_model": "text-embedding-3-large"},
            )
            assert resp.status_code == 200, resp.text

            # Ingest second document which should use the large embedding model
            await _upload_doc(
                "doc-large.md",
                "Advanced physics covers quantum topics extensively.",
            )

            payload = await _wait_for_models(
                {"text-embedding-3-small", "text-embedding-3-large"}
            )
            buckets = (
                payload.get("aggregations", {})
                .get("embedding_models", {})
                .get("buckets", [])
            )
            models = {b.get("key") for b in buckets}
            assert {"text-embedding-3-small", "text-embedding-3-large"} <= models

            result_models = {
                r.get("embedding_model")
                for r in payload.get("results", [])
                if r.get("embedding_model")
            }
            assert {
                "text-embedding-3-small",
                "text-embedding-3-large",
            } <= result_models
    finally:
        # Teardown is best-effort: a failing close() must not mask the real
        # test outcome.
        try:
            await clients.close()
        except Exception:
            pass