Merge branch 'COG-970-refactor-tokenizing' of github.com:topoteretes/cognee into COG-970-refactor-tokenizing

This commit is contained in:
Igor Ilic 2025-01-28 14:48:40 +01:00
commit e0b7be7cf0
3 changed files with 85 additions and 4 deletions

View file

@ -171,17 +171,18 @@ async def main():
)
parser.add_argument("--num_samples", type=int, default=500)
parser.add_argument("--metrics", type=str, nargs="+", default=["Correctness"])
parser.add_argument("--out_dir", type=str, help="Dir to save eval results")
args = parser.parse_args()
if args.rag_option == "cognee_incremental":
avg_scores = await incremental_eval_on_QA_dataset(
args.dataset, args.num_samples, args.metrics
args.dataset, args.num_samples, args.metrics, args.out_dir
)
else:
avg_scores = await eval_on_QA_dataset(
args.dataset, args.rag_option, args.num_samples, args.metrics
args.dataset, args.rag_option, args.num_samples, args.metrics, args.out_dir
)
logger.info(f"{avg_scores}")

View file

@ -0,0 +1,75 @@
import subprocess
import json
import argparse
import os
from typing import List
import sys
def run_command(command: List[str]):
try:
process = subprocess.Popen(
command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, bufsize=1
)
while True:
stdout_line = process.stdout.readline()
stderr_line = process.stderr.readline()
if stdout_line == "" and stderr_line == "" and process.poll() is not None:
break
if stdout_line:
print(stdout_line.rstrip())
if stderr_line:
print(f"Error: {stderr_line.rstrip()}", file=sys.stderr)
if process.returncode != 0:
raise subprocess.CalledProcessError(process.returncode, command)
finally:
process.stdout.close()
process.stderr.close()
def run_evals_for_paramsfile(params_file, out_dir):
with open(params_file, "r") as file:
parameters = json.load(file)
for metric in parameters["metric_names"]:
params = parameters
params["metric_names"] = [metric]
temp_paramfile = params_file.replace(".json", f"_{metric}.json")
with open(temp_paramfile, "w") as file:
json.dump(params, file)
command = [
"python",
"evals/run_qa_eval.py",
"--params_file",
temp_paramfile,
"--out_dir",
out_dir,
]
run_command(command)
if os.path.exists(temp_paramfile):
os.remove(temp_paramfile)
def main():
parser = argparse.ArgumentParser()
parser.add_argument(
"--params_file", type=str, required=True, help="Which dataset to evaluate on"
)
parser.add_argument("--out_dir", type=str, help="Dir to save eval results")
args = parser.parse_args()
run_evals_for_paramsfile(args.params_file, args.out_dir)
if __name__ == "__main__":
main()

View file

@ -42,9 +42,13 @@ def save_table_as_image(df, image_path):
def save_results_as_image(results, out_path):
for dataset, num_samples_data in results.items():
for num_samples, table_data in num_samples_data.items():
for rag_option, metric_data in table_data.items():
for name, value in metric_data.items():
metric_name = name
break
df = pd.DataFrame.from_dict(table_data, orient="index")
df.index.name = f"Dataset: {dataset}, Num Samples: {num_samples}"
image_path = out_path / Path(f"table_{dataset}_{num_samples}.png")
image_path = out_path / Path(f"table_{dataset}_{num_samples}_{metric_name}.png")
save_table_as_image(df, image_path)
@ -54,7 +58,8 @@ def get_combinations(parameters):
except ValidationError as e:
raise ValidationError(f"Invalid parameter set: {e.message}")
params_for_combos = {k: v for k, v in parameters.items() if k != "metric_name"}
# params_for_combos = {k: v for k, v in parameters.items() if k != "metric_name"}
params_for_combos = {k: v for k, v in parameters.items()}
keys, values = zip(*params_for_combos.items())
combinations = [dict(zip(keys, combo)) for combo in itertools.product(*values)]
return combinations