multi-model integration test

This commit is contained in:
phact 2025-10-11 02:34:21 -04:00
parent a7c5a9f8f3
commit fb35417586

View file

@ -204,6 +204,144 @@ async def test_upload_and_search_endpoint(tmp_path: Path, disable_langflow_inges
pass
@pytest.mark.asyncio
async def test_search_multi_embedding_models(
    tmp_path: Path, monkeypatch: pytest.MonkeyPatch
):
    """Ensure /search fans out across multiple embedding models when present.

    Flow:
      1. Set the ingest/OAuth environment variables and drop the cached
         application modules from ``sys.modules`` before importing them.
      2. Initialize the shared clients, delete the index (best-effort),
         build the app, run its startup tasks and ensure the index exists.
      3. Select ``text-embedding-3-small`` via ``/settings``, upload a
         document, and wait until ``/search`` reports that model in its
         ``embedding_models`` aggregation.
      4. Switch to ``text-embedding-3-large``, upload a second document, and
         assert that both models appear in the aggregation buckets and in
         the per-result ``embedding_model`` fields.
    """
    # Go through monkeypatch so the overrides are undone at teardown instead
    # of leaking into every test that runs afterwards.
    monkeypatch.setenv("DISABLE_INGEST_WITH_LANGFLOW", "true")
    monkeypatch.setenv("DISABLE_STARTUP_INGEST", "true")
    monkeypatch.setenv("GOOGLE_OAUTH_CLIENT_ID", "")
    monkeypatch.setenv("GOOGLE_OAUTH_CLIENT_SECRET", "")

    import sys

    # Remove cached modules so the imports below execute them again
    # (presumably so they pick up the environment configured above).
    for mod in (
        "src.api.router",
        "api.router",
        "src.api.connector_router",
        "api.connector_router",
        "src.config.settings",
        "config.settings",
        "src.auth_middleware",
        "auth_middleware",
        "src.main",
        "services.search_service",
        "src.services.search_service",
    ):
        sys.modules.pop(mod, None)

    from src.main import _ensure_opensearch_index, create_app, startup_tasks
    from src.config.settings import clients, INDEX_NAME
    from src.config.config_manager import config_manager

    small_model = "text-embedding-3-small"
    large_model = "text-embedding-3-large"

    def _result_models(payload: dict) -> set:
        """Collect the non-empty ``embedding_model`` values of the results."""
        return {
            r.get("embedding_model")
            for r in payload.get("results", [])
            if r.get("embedding_model")
        }

    await clients.initialize()
    # Everything after a successful initialize() is covered by the finally
    # below, so the clients are closed even when app setup itself fails.
    try:
        # Best-effort reset: the delete may fail (e.g. no index yet), which
        # is fine because the index is (re)created further down.
        try:
            await clients.opensearch.indices.delete(index=INDEX_NAME)
            await asyncio.sleep(1)
        except Exception:
            pass

        app = await create_app()
        await startup_tasks(app.state.services)

        # Mark configuration as edited so /settings accepts updates
        config = config_manager.get_config()
        config.edited = True

        await _ensure_opensearch_index()

        transport = httpx.ASGITransport(app=app)
        async with httpx.AsyncClient(
            transport=transport, base_url="http://testserver"
        ) as client:
            await wait_for_service_ready(client)

            async def _upload_doc(name: str, text: str) -> None:
                """Write *text* to ``tmp_path/name`` and POST it to /upload."""
                file_path = tmp_path / name
                file_path.write_text(text, encoding="utf-8")
                files = {
                    "file": (
                        name,
                        file_path.read_bytes(),
                        "text/markdown",
                    )
                }
                resp = await client.post("/upload", files=files)
                assert resp.status_code == 201, resp.text

            async def _wait_for_models(
                expected_models: set[str],
                query: str = "physics",
                *,
                timeout: float = 30.0,
                interval: float = 0.5,
            ):
                """Poll /search until every expected model has a bucket.

                Returns the first search payload whose ``embedding_models``
                aggregation contains all of *expected_models*. Raises
                ``AssertionError`` carrying the last response seen if that
                does not happen within *timeout* seconds.
                """
                loop = asyncio.get_running_loop()
                deadline = loop.time() + timeout
                last_payload = None
                while loop.time() < deadline:
                    resp = await client.post(
                        "/search",
                        json={"query": query, "limit": 10, "scoreThreshold": 0},
                    )
                    if resp.status_code != 200:
                        # Keep the raw body for the failure message and retry.
                        last_payload = resp.text
                        await asyncio.sleep(interval)
                        continue
                    payload = resp.json()
                    buckets = (
                        payload.get("aggregations", {})
                        .get("embedding_models", {})
                        .get("buckets", [])
                    )
                    models = {b.get("key") for b in buckets if b.get("key")}
                    if expected_models <= models:
                        return payload
                    last_payload = payload
                    await asyncio.sleep(interval)
                raise AssertionError(
                    f"Embedding models not detected. Last payload: {last_payload}"
                )

            # Start with explicit small embedding model
            resp = await client.post(
                "/settings",
                json={"embedding_model": small_model},
            )
            assert resp.status_code == 200, resp.text

            # Ingest first document (small model)
            await _upload_doc(
                "doc-small.md", "Physics basics and fundamental principles."
            )
            payload_small = await _wait_for_models({small_model})
            result_models_small = _result_models(payload_small)
            # Results without any embedding_model field are tolerated here.
            assert small_model in result_models_small or not result_models_small

            # Update embedding model via settings
            resp = await client.post(
                "/settings",
                json={"embedding_model": large_model},
            )
            assert resp.status_code == 200, resp.text

            # Ingest second document which should use the large embedding model
            await _upload_doc(
                "doc-large.md",
                "Advanced physics covers quantum topics extensively.",
            )
            payload = await _wait_for_models({small_model, large_model})

            buckets = (
                payload.get("aggregations", {})
                .get("embedding_models", {})
                .get("buckets", [])
            )
            models = {b.get("key") for b in buckets}
            assert {small_model, large_model} <= models

            result_models = _result_models(payload)
            assert {small_model, large_model} <= result_models
    finally:
        # Best-effort shutdown; a failure here must not mask the test outcome.
        try:
            await clients.close()
        except Exception:
            pass
@pytest.mark.parametrize("disable_langflow_ingest", [True, False])
@pytest.mark.asyncio
async def test_router_upload_ingest_traditional(tmp_path: Path, disable_langflow_ingest: bool):