multi-model integration test
This commit is contained in:
parent
a7c5a9f8f3
commit
fb35417586
1 changed files with 138 additions and 0 deletions
|
|
@ -204,6 +204,144 @@ async def test_upload_and_search_endpoint(tmp_path: Path, disable_langflow_inges
|
|||
pass
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_search_multi_embedding_models(
|
||||
tmp_path: Path, monkeypatch: pytest.MonkeyPatch
|
||||
):
|
||||
"""Ensure /search fans out across multiple embedding models when present."""
|
||||
os.environ["DISABLE_INGEST_WITH_LANGFLOW"] = "true"
|
||||
os.environ["DISABLE_STARTUP_INGEST"] = "true"
|
||||
os.environ["GOOGLE_OAUTH_CLIENT_ID"] = ""
|
||||
os.environ["GOOGLE_OAUTH_CLIENT_SECRET"] = ""
|
||||
|
||||
import sys
|
||||
|
||||
for mod in [
|
||||
"src.api.router",
|
||||
"api.router",
|
||||
"src.api.connector_router",
|
||||
"api.connector_router",
|
||||
"src.config.settings",
|
||||
"config.settings",
|
||||
"src.auth_middleware",
|
||||
"auth_middleware",
|
||||
"src.main",
|
||||
"services.search_service",
|
||||
"src.services.search_service",
|
||||
]:
|
||||
sys.modules.pop(mod, None)
|
||||
|
||||
from src.main import create_app, startup_tasks
|
||||
from src.config.settings import clients, INDEX_NAME
|
||||
from src.config.config_manager import config_manager
|
||||
|
||||
await clients.initialize()
|
||||
try:
|
||||
await clients.opensearch.indices.delete(index=INDEX_NAME)
|
||||
await asyncio.sleep(1)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
app = await create_app()
|
||||
await startup_tasks(app.state.services)
|
||||
|
||||
# Mark configuration as edited so /settings accepts updates
|
||||
config = config_manager.get_config()
|
||||
config.edited = True
|
||||
|
||||
from src.main import _ensure_opensearch_index
|
||||
|
||||
await _ensure_opensearch_index()
|
||||
|
||||
transport = httpx.ASGITransport(app=app)
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient(transport=transport, base_url="http://testserver") as client:
|
||||
await wait_for_service_ready(client)
|
||||
|
||||
async def _upload_doc(name: str, text: str) -> None:
|
||||
file_path = tmp_path / name
|
||||
file_path.write_text(text)
|
||||
files = {
|
||||
"file": (
|
||||
name,
|
||||
file_path.read_bytes(),
|
||||
"text/markdown",
|
||||
)
|
||||
}
|
||||
resp = await client.post("/upload", files=files)
|
||||
assert resp.status_code == 201, resp.text
|
||||
|
||||
async def _wait_for_models(expected_models: set[str], query: str = "physics"):
|
||||
deadline = asyncio.get_event_loop().time() + 30.0
|
||||
last_payload = None
|
||||
while asyncio.get_event_loop().time() < deadline:
|
||||
resp = await client.post(
|
||||
"/search",
|
||||
json={"query": query, "limit": 10, "scoreThreshold": 0},
|
||||
)
|
||||
if resp.status_code != 200:
|
||||
last_payload = resp.text
|
||||
await asyncio.sleep(0.5)
|
||||
continue
|
||||
payload = resp.json()
|
||||
buckets = (
|
||||
payload.get("aggregations", {})
|
||||
.get("embedding_models", {})
|
||||
.get("buckets", [])
|
||||
)
|
||||
models = {b.get("key") for b in buckets if b.get("key")}
|
||||
if expected_models <= models:
|
||||
return payload
|
||||
last_payload = payload
|
||||
await asyncio.sleep(0.5)
|
||||
raise AssertionError(
|
||||
f"Embedding models not detected. Last payload: {last_payload}"
|
||||
)
|
||||
|
||||
# Start with explicit small embedding model
|
||||
resp = await client.post(
|
||||
"/settings",
|
||||
json={"embedding_model": "text-embedding-3-small"},
|
||||
)
|
||||
assert resp.status_code == 200, resp.text
|
||||
|
||||
# Ingest first document (small model)
|
||||
await _upload_doc("doc-small.md", "Physics basics and fundamental principles.")
|
||||
payload_small = await _wait_for_models({"text-embedding-3-small"})
|
||||
result_models_small = {r.get("embedding_model") for r in payload_small.get("results", []) if r.get("embedding_model")}
|
||||
assert "text-embedding-3-small" in result_models_small or not result_models_small
|
||||
|
||||
# Update embedding model via settings
|
||||
resp = await client.post(
|
||||
"/settings",
|
||||
json={"embedding_model": "text-embedding-3-large"},
|
||||
)
|
||||
assert resp.status_code == 200, resp.text
|
||||
|
||||
# Ingest second document which should use the large embedding model
|
||||
await _upload_doc("doc-large.md", "Advanced physics covers quantum topics extensively.")
|
||||
|
||||
payload = await _wait_for_models({"text-embedding-3-small", "text-embedding-3-large"})
|
||||
buckets = payload.get("aggregations", {}).get("embedding_models", {}).get("buckets", [])
|
||||
models = {b.get("key") for b in buckets}
|
||||
assert {"text-embedding-3-small", "text-embedding-3-large"} <= models
|
||||
|
||||
result_models = {
|
||||
r.get("embedding_model")
|
||||
for r in payload.get("results", [])
|
||||
if r.get("embedding_model")
|
||||
}
|
||||
assert {"text-embedding-3-small", "text-embedding-3-large"} <= result_models
|
||||
finally:
|
||||
from src.config.settings import clients
|
||||
|
||||
try:
|
||||
await clients.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
@pytest.mark.parametrize("disable_langflow_ingest", [True, False])
|
||||
@pytest.mark.asyncio
|
||||
async def test_router_upload_ingest_traditional(tmp_path: Path, disable_langflow_ingest: bool):
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue