feat: Adds modal parallel evaluation for retriever development (#844)

<!-- .github/pull_request_template.md -->

## Description
Adds parallel evaluation runs on Modal for retriever development. Each benchmark (HotPotQA, TwoWikiMultiHop, Musique) gets its own `EvalConfig` using the `cognee_graph_completion` QA engine; the runs execute concurrently and persist their answers and metrics dashboards to a shared Modal volume. Supporting changes: `EvalConfig` gains an `instance_filter` option for restricting a run to selected corpus instances, `create_dashboard` now returns the rendered HTML so callers can persist it themselves, and the evaluation Docker image installs only the `neo4j` and `qdrant` extras instead of all extras.
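The fan-out pattern, as a minimal self-contained sketch (the app and function names here are illustrative, not the ones in this PR):

```python
# Sketch of the parallel fan-out this PR uses; requires a configured Modal
# account. The real remote function runs the full cognee evaluation pipeline.
import asyncio
import modal

app = modal.App("parallel-eval-sketch")  # illustrative app name


@app.function()
async def run_one(params: dict) -> str:
    # Stand-in for modal_run_eval: one independent evaluation per call.
    return f"finished {params['benchmark']}"


@app.local_entrypoint()
async def main():
    benchmarks = ["HotPotQA", "TwoWikiMultiHop", "Musique"]
    results = await asyncio.gather(
        *(run_one.remote.aio({"benchmark": b}) for b in benchmarks)
    )
    print(results)
```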

## DCO Affirmation
I affirm that all code in every commit of this pull request conforms to
the terms of the Topoteretes Developer Certificate of Origin.
Author: hajdul88 (committed via GitHub), 2025-05-20 15:16:13 +02:00
Commit: 5c36a5dd8a · Parent: f8f78773dd
4 changed files with 63 additions and 37 deletions

## Changes

#### Dockerfile

```diff
@@ -18,15 +18,12 @@ RUN apt-get update && apt-get install -y \
 WORKDIR /app
-ENV PYTHONPATH=/app
-WORKDIR /app
-COPY pyproject.toml poetry.lock /app/
+COPY pyproject.toml poetry.lock README.md /app/
 RUN pip install poetry
-RUN poetry install --all-extras --no-root --without dev
+RUN poetry config virtualenvs.create false
+RUN poetry install --extras neo4j --extras qdrant --no-root
 COPY cognee/ /app/cognee
-COPY README.md /app/README.md
```
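`poetry config virtualenvs.create false` installs dependencies into the container's system interpreter rather than a nested virtualenv, and narrowing `--all-extras` to `--extras neo4j --extras qdrant` keeps the image to what the evaluation actually needs. A quick sanity check one might run inside the built image (the package names are assumptions about what those extras pull in):

```python
# Verify the narrowed extras are importable inside the image; "neo4j" and
# "qdrant_client" are assumed to be the packages behind the two extras.
import importlib.util

for pkg in ("neo4j", "qdrant_client"):
    found = importlib.util.find_spec(pkg) is not None
    print(f"{pkg}: {'installed' if found else 'missing'}")
```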

#### cognee/eval_framework/eval_config.py

```diff
@@ -1,6 +1,6 @@
 from functools import lru_cache
 from pydantic_settings import BaseSettings, SettingsConfigDict
-from typing import List
+from typing import List, Optional


 class EvalConfig(BaseSettings):
@@ -43,6 +43,7 @@ class EvalConfig(BaseSettings):
     dashboard_path: str = "dashboard.html"
     direct_llm_system_prompt: str = "direct_llm_eval_system.txt"
     direct_llm_eval_prompt: str = "direct_llm_eval_prompt.txt"
+    instance_filter: Optional[List[str]] = None

     model_config = SettingsConfigDict(env_file=".env", extra="allow")
@@ -68,6 +69,7 @@ class EvalConfig(BaseSettings):
             "task_getter_type": self.task_getter_type,
             "direct_llm_system_prompt": self.direct_llm_system_prompt,
             "direct_llm_eval_prompt": self.direct_llm_eval_prompt,
+            "instance_filter": self.instance_filter,
         }
```
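`instance_filter` lets a run target a subset of corpus instances instead of the whole benchmark. A hypothetical use (the IDs are placeholders, not real benchmark instance IDs):

```python
# Hypothetical usage of the new instance_filter option.
from cognee.eval_framework.eval_config import EvalConfig

params = EvalConfig(instance_filter=["instance_a", "instance_b"]).to_dict()
print(params["instance_filter"])  # ['instance_a', 'instance_b']
```

Since `EvalConfig` is a pydantic `BaseSettings` with `env_file=".env"`, the same value can also come from the environment (pydantic-settings parses JSON for list-valued fields, e.g. `INSTANCE_FILTER='["instance_a"]'`).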

#### cognee/eval_framework/metrics_dashboard.py

```diff
@@ -167,4 +167,4 @@ def create_dashboard(
     with open(output_file, "w", encoding="utf-8") as f:
         f.write(dashboard_html)

-    return output_file
+    return dashboard_html
```
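Returning the rendered HTML instead of the output path lets a remote caller decide where to persist the dashboard, which is what the Modal function below does with its shared volume. A sketch of the new calling pattern (paths and benchmark are illustrative):

```python
# create_dashboard still writes output_file locally, but now also returns the
# HTML so the caller can store it elsewhere; all paths here are illustrative.
from cognee.eval_framework.metrics_dashboard import create_dashboard

html = create_dashboard(
    metrics_path="metrics.json",
    aggregate_metrics_path="aggregate_metrics.json",
    output_file="dashboard.html",
    benchmark="HotPotQA",
)
with open("archived_dashboard.html", "w", encoding="utf-8") as f:
    f.write(html)
```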

#### Modal evaluation script

```diff
@@ -1,8 +1,9 @@
 import modal
 import os
-import json
 import asyncio
 import datetime
+import hashlib
+import json
 from cognee.shared.logging_utils import get_logger
 from cognee.eval_framework.eval_config import EvalConfig
 from cognee.eval_framework.corpus_builder.run_corpus_builder import run_corpus_builder
@@ -10,8 +11,10 @@ from cognee.eval_framework.answer_generation.run_question_answering_module import (
     run_question_answering,
 )
 from cognee.eval_framework.evaluation.run_evaluation_module import run_evaluation
+from cognee.eval_framework.metrics_dashboard import create_dashboard

 logger = get_logger()
+vol = modal.Volume.from_name("evaluation_dashboard_results", create_if_missing=True)


 def read_and_combine_metrics(eval_params: dict) -> dict:
```
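The named volume is what makes results retrievable after the containers exit: every run writes under `/data` and commits. One way to inspect the volume afterwards from a local session (a sketch; the volume name comes from this diff, and `listdir` is assumed to behave as in current Modal releases):

```python
# List what the evaluation runs wrote to the shared Modal volume (sketch;
# assumes the volume exists and the local Modal client is authenticated).
import modal

vol = modal.Volume.from_name("evaluation_dashboard_results")
for entry in vol.listdir("/"):
    print(entry.path)
```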
```diff
@@ -46,32 +49,54 @@ image = (
             "OPENAI_API_KEY": os.getenv("OPENAI_API_KEY"),
         }
     )
-    .poetry_install_from_file(poetry_pyproject_toml="pyproject.toml")
     .pip_install("protobuf", "h2", "deepeval", "gdown", "plotly")
 )


-@app.function(image=image, concurrency_limit=2, timeout=1800, retries=1)
+@app.function(image=image, concurrency_limit=10, timeout=86400, volumes={"/data": vol})
 async def modal_run_eval(eval_params=None):
     """Runs evaluation pipeline and returns combined metrics results."""
     if eval_params is None:
         eval_params = EvalConfig().to_dict()
+
+    version_name = "baseline"
+    benchmark_name = os.environ.get("BENCHMARK", eval_params.get("benchmark", "benchmark"))
+    timestamp = datetime.datetime.utcnow().strftime("%Y%m%dT%H%M%SZ")
+    answers_filename = (
+        f"{version_name}_{benchmark_name}_{timestamp}_{eval_params.get('answers_path')}"
+    )
+    html_filename = (
+        f"{version_name}_{benchmark_name}_{timestamp}_{eval_params.get('dashboard_path')}"
+    )
+
     logger.info(f"Running evaluation with params: {eval_params}")

     # Run the evaluation pipeline
-    await run_corpus_builder(eval_params)
+    await run_corpus_builder(eval_params, instance_filter=eval_params.get("instance_filter"))
     await run_question_answering(eval_params)
-    await run_evaluation(eval_params)
+    answers = await run_evaluation(eval_params)

-    # Early return if metrics calculation wasn't requested
-    if not eval_params.get("evaluating_answers") or not eval_params.get("calculate_metrics"):
-        logger.info(
-            "Skipping metrics collection as either evaluating_answers or calculate_metrics is False"
-        )
-        return None
-
-    return read_and_combine_metrics(eval_params)
+    with open("/data/" + answers_filename, "w") as f:
+        json.dump(answers, f, ensure_ascii=False, indent=4)
+    vol.commit()
+
+    if eval_params.get("dashboard"):
+        logger.info("Generating dashboard...")
+        html_output = create_dashboard(
+            metrics_path=eval_params["metrics_path"],
+            aggregate_metrics_path=eval_params["aggregate_metrics_path"],
+            output_file=eval_params["dashboard_path"],
+            benchmark=eval_params["benchmark"],
+        )
+
+        with open("/data/" + html_filename, "w") as f:
+            f.write(html_output)
+        vol.commit()
+
+    logger.info("Evaluation set finished...")
+    return True


 @app.local_entrypoint()
```
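Each run now derives its artifact names from a version tag, the benchmark, and a UTC timestamp, and calls `vol.commit()` after writing so the files are durable and visible outside the container. Launching a single configuration by itself would look like this (a hypothetical extra entrypoint, assumed to live in the same file as `modal_run_eval`):

```python
# Hypothetical one-off entrypoint; assumes the surrounding script already
# defines `app` and `modal_run_eval` and imports EvalConfig.
@app.local_entrypoint()
async def single_run():
    params = EvalConfig(
        benchmark="HotPotQA",
        qa_engine="cognee_graph_completion",
        dashboard=True,
    ).to_dict()
    await modal_run_eval.remote.aio(params)
```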
```diff
@@ -80,37 +105,39 @@ async def main():
     configs = [
         EvalConfig(
             task_getter_type="Default",
-            number_of_samples_in_corpus=2,
+            number_of_samples_in_corpus=10,
+            benchmark="HotPotQA",
+            qa_engine="cognee_graph_completion",
             building_corpus_from_scratch=True,
             answering_questions=True,
             evaluating_answers=True,
             calculate_metrics=True,
-            dashboard=False,
+            dashboard=True,
         ),
         EvalConfig(
             task_getter_type="Default",
             number_of_samples_in_corpus=10,
+            benchmark="TwoWikiMultiHop",
+            qa_engine="cognee_graph_completion",
             building_corpus_from_scratch=True,
             answering_questions=True,
             evaluating_answers=True,
             calculate_metrics=True,
-            dashboard=False,
+            dashboard=True,
         ),
+        EvalConfig(
+            task_getter_type="Default",
+            number_of_samples_in_corpus=10,
+            benchmark="Musique",
+            qa_engine="cognee_graph_completion",
+            building_corpus_from_scratch=True,
+            answering_questions=True,
+            evaluating_answers=True,
+            calculate_metrics=True,
+            dashboard=True,
+        ),
     ]

     # Run evaluations in parallel with different configurations
     modal_tasks = [modal_run_eval.remote.aio(config.to_dict()) for config in configs]
-    results = await asyncio.gather(*modal_tasks)
-
-    # Filter out None results and save combined results
-    results = [r for r in results if r is not None]
-    if results:
-        timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
-        output_file = f"combined_results_{timestamp}.json"
-        with open(output_file, "w") as f:
-            json.dump(results, f, indent=2)
-        logger.info(f"Completed parallel evaluation runs. Results saved to {output_file}")
-    else:
-        logger.info("No metrics were collected from any of the evaluation runs")
+    await asyncio.gather(*modal_tasks)
```
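Because the results now land in the `evaluation_dashboard_results` volume, the old step that gathered return values into a local `combined_results_*.json` is gone; `main()` only needs to await the parallel runs, and the three-benchmark sweep is launched with `modal run` against this script.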