Merge branch 'COG-970-refactor-tokenizing' of github.com:topoteretes/cognee into COG-970-refactor-tokenizing
commit e0b7be7cf0
3 changed files with 85 additions and 4 deletions
@@ -171,17 +171,18 @@ async def main():
     )
     parser.add_argument("--num_samples", type=int, default=500)
     parser.add_argument("--metrics", type=str, nargs="+", default=["Correctness"])
+    parser.add_argument("--out_dir", type=str, help="Dir to save eval results")

     args = parser.parse_args()

     if args.rag_option == "cognee_incremental":
         avg_scores = await incremental_eval_on_QA_dataset(
-            args.dataset, args.num_samples, args.metrics
+            args.dataset, args.num_samples, args.metrics, args.out_dir
         )
     else:
         avg_scores = await eval_on_QA_dataset(
-            args.dataset, args.rag_option, args.num_samples, args.metrics
+            args.dataset, args.rag_option, args.num_samples, args.metrics, args.out_dir
         )

     logger.info(f"{avg_scores}")
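The hunk above threads a new --out_dir option from the CLI into both evaluation paths. A minimal sketch of the helper signatures implied by the updated call sites; the parameter names are inferred and do not appear in this diff:

# Inferred from the call sites above (not shown in this diff): both helpers
# now take the output directory as their final positional argument.
async def incremental_eval_on_QA_dataset(dataset, num_samples, metrics, out_dir): ...

async def eval_on_QA_dataset(dataset, rag_option, num_samples, metrics, out_dir): ...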
75  evals/multimetric_qa_eval_run.py  Normal file

@@ -0,0 +1,75 @@
+import subprocess
+import json
+import argparse
+import os
+from typing import List
+import sys
+
+
+def run_command(command: List[str]):
+    try:
+        process = subprocess.Popen(
+            command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, bufsize=1
+        )
+
+        while True:
+            stdout_line = process.stdout.readline()
+            stderr_line = process.stderr.readline()
+
+            if stdout_line == "" and stderr_line == "" and process.poll() is not None:
+                break
+
+            if stdout_line:
+                print(stdout_line.rstrip())
+            if stderr_line:
+                print(f"Error: {stderr_line.rstrip()}", file=sys.stderr)
+
+        if process.returncode != 0:
+            raise subprocess.CalledProcessError(process.returncode, command)
+    finally:
+        process.stdout.close()
+        process.stderr.close()
+
+
+def run_evals_for_paramsfile(params_file, out_dir):
+    with open(params_file, "r") as file:
+        parameters = json.load(file)
+
+    for metric in parameters["metric_names"]:
+        params = parameters
+        params["metric_names"] = [metric]
+
+        temp_paramfile = params_file.replace(".json", f"_{metric}.json")
+        with open(temp_paramfile, "w") as file:
+            json.dump(params, file)
+
+        command = [
+            "python",
+            "evals/run_qa_eval.py",
+            "--params_file",
+            temp_paramfile,
+            "--out_dir",
+            out_dir,
+        ]
+
+        run_command(command)
+
+        if os.path.exists(temp_paramfile):
+            os.remove(temp_paramfile)
+
+
+def main():
+    parser = argparse.ArgumentParser()
+
+    parser.add_argument(
+        "--params_file", type=str, required=True, help="Which dataset to evaluate on"
+    )
+    parser.add_argument("--out_dir", type=str, help="Dir to save eval results")
+
+    args = parser.parse_args()
+
+    run_evals_for_paramsfile(args.params_file, args.out_dir)
+
+
+if __name__ == "__main__":
+    main()
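The new evals/multimetric_qa_eval_run.py fans a multi-metric params file out into one evals/run_qa_eval.py invocation per metric, writing a temporary single-metric params file for each run and deleting it afterwards. A usage sketch follows; only the "metric_names" key and the two CLI flags are dictated by the code above, while the params-file path, the second metric, and any other keys are placeholders:

# Usage sketch for the new runner (placeholder values marked as such).
import json
import subprocess

params = {
    "metric_names": ["Correctness", "F1"],  # "F1" is an illustrative placeholder
    # ... any other parameters evals/run_qa_eval.py expects (not defined by this diff)
}

# Hypothetical params-file path.
with open("evals/qa_eval_params.json", "w") as f:
    json.dump(params, f, indent=2)

# Same as: python evals/multimetric_qa_eval_run.py --params_file evals/qa_eval_params.json --out_dir ./eval_results
subprocess.run(
    [
        "python",
        "evals/multimetric_qa_eval_run.py",
        "--params_file",
        "evals/qa_eval_params.json",
        "--out_dir",
        "./eval_results",
    ],
    check=True,
)
# The runner writes evals/qa_eval_params_Correctness.json, calls
# evals/run_qa_eval.py on it, removes it, and repeats for "F1".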
@@ -42,9 +42,13 @@ def save_table_as_image(df, image_path):
 def save_results_as_image(results, out_path):
     for dataset, num_samples_data in results.items():
         for num_samples, table_data in num_samples_data.items():
+            for rag_option, metric_data in table_data.items():
+                for name, value in metric_data.items():
+                    metric_name = name
+                    break
             df = pd.DataFrame.from_dict(table_data, orient="index")
             df.index.name = f"Dataset: {dataset}, Num Samples: {num_samples}"
-            image_path = out_path / Path(f"table_{dataset}_{num_samples}.png")
+            image_path = out_path / Path(f"table_{dataset}_{num_samples}_{metric_name}.png")
             save_table_as_image(df, image_path)

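The nested loops added in save_results_as_image imply the results layout sketched below; the dataset, RAG option, and score values are illustrative rather than taken from the repository:

# Shape of `results` implied by the loops above (illustrative values only).
results = {
    "hotpotqa": {                # dataset
        500: {                   # num_samples
            "cognee": {"Correctness": 0.61},       # rag_option -> {metric_name: score}
            "simple_rag": {"Correctness": 0.55},
        }
    }
}
# The new inner loop grabs the first metric name it encounters ("Correctness"),
# so the saved image is named table_hotpotqa_500_Correctness.png instead of
# table_hotpotqa_500.png.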
@@ -54,7 +58,8 @@ def get_combinations(parameters):
     except ValidationError as e:
         raise ValidationError(f"Invalid parameter set: {e.message}")

-    params_for_combos = {k: v for k, v in parameters.items() if k != "metric_name"}
+    # params_for_combos = {k: v for k, v in parameters.items() if k != "metric_name"}
+    params_for_combos = {k: v for k, v in parameters.items()}
     keys, values = zip(*params_for_combos.items())
     combinations = [dict(zip(keys, combo)) for combo in itertools.product(*values)]
     return combinations
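With the filter removed, "metric_name" now participates in the parameter grid, so every generated combination carries its metric alongside the other settings. A minimal sketch of the combination logic, using placeholder keys and values rather than the real params schema:

# Minimal reproduction of the get_combinations core after this change
# (placeholder keys/values, not the real params schema).
import itertools

parameters = {
    "dataset": ["hotpotqa"],
    "rag_option": ["cognee", "simple_rag"],
    "metric_name": ["Correctness"],
}

keys, values = zip(*parameters.items())
combinations = [dict(zip(keys, combo)) for combo in itertools.product(*values)]
# -> [{'dataset': 'hotpotqa', 'rag_option': 'cognee', 'metric_name': 'Correctness'},
#     {'dataset': 'hotpotqa', 'rag_option': 'simple_rag', 'metric_name': 'Correctness'}]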