add e2e eval
This commit is contained in:
parent
948a0057fb
commit
b35729643d
5 changed files with 156 additions and 35 deletions
|
|
@ -37,16 +37,25 @@ class EvalResponse(BaseModel):
|
|||
)
|
||||
|
||||
|
||||
# Structured verdict for comparing a baseline graph extraction against a
# candidate one. Presumably the response model paired with the
# `eval_add_episode_results` prompt -- confirm against the caller.
# NOTE: documented with comments rather than a class docstring so that the
# schema pydantic generates for this model is left exactly as it was.
class EvalAddEpisodeResults(BaseModel):
    # Required field: True when the baseline extraction is judged higher
    # quality than the candidate extraction.
    baseline_is_better: bool = Field(
        default=...,
        description='boolean if the baseline extraction is higher quality than the candidate extraction.',
    )
|
||||
|
||||
|
||||
class Prompt(Protocol):
    """Structural interface listing the prompts this module exposes.

    Each attribute is a ``PromptVersion`` and carries the same name as the
    matching entry in ``Versions``.
    """

    # Question-answering prompt.
    qa_prompt: PromptVersion
    # Evaluation prompt.
    eval_prompt: PromptVersion
    # Query-expansion prompt.
    query_expansion: PromptVersion
    # Judge prompt comparing baseline and candidate graph-building results.
    eval_add_episode_results: PromptVersion
|
||||
|
||||
|
||||
class Versions(TypedDict):
    """Typed mapping from prompt name to the function that builds that prompt.

    The keys mirror the attribute names declared on ``Prompt``.
    """

    # Builder for the question-answering prompt.
    qa_prompt: PromptFunction
    # Builder for the evaluation prompt.
    eval_prompt: PromptFunction
    # Builder for the query-expansion prompt.
    query_expansion: PromptFunction
    # Builder for the baseline-vs-candidate judge prompt.
    eval_add_episode_results: PromptFunction
|
||||
|
||||
|
||||
def query_expansion(context: dict[str, Any]) -> list[Message]:
|
||||
|
|
@ -112,8 +121,41 @@ def eval_prompt(context: dict[str, Any]) -> list[Message]:
|
|||
]
|
||||
|
||||
|
||||
def eval_add_episode_results(context: dict[str, Any]) -> list[Message]:
    """Build the judge prompt comparing a baseline and a candidate graph extraction.

    Args:
        context: Prompt inputs. Reads ``previous_messages``, ``baseline`` and
            ``candidate``, plus the current message under ``message`` (the key
            the e2e eval supplies), falling back to the older ``answer`` key.

    Returns:
        A two-element list: the system message followed by the user message.

    Raises:
        KeyError: If a required key is missing from ``context``.
    """
    sys_prompt = """You are a judge that determines whether a baseline graph building result from a list of messages is better
        than a candidate graph building result based on the same messages."""

    # Prefer 'message'; keep 'answer' as a fallback so any caller that still
    # passes the old key continues to work.
    message = context['message'] if 'message' in context else context['answer']

    # The verdict direction must agree with the `baseline_is_better` field of
    # the response model: True means the baseline wins, and a tie is not a win.
    user_prompt = f"""
    Given the following PREVIOUS MESSAGES and MESSAGE, determine if the BASELINE graph data extracted from the
    conversation is higher quality than the CANDIDATE graph data extracted from the conversation.

    Return True if the BASELINE extraction is better, and False otherwise. If the CANDIDATE extraction and
    BASELINE extraction are near identical in quality, return False.

    <PREVIOUS MESSAGES>
    {context['previous_messages']}
    </PREVIOUS MESSAGES>
    <MESSAGE>
    {message}
    </MESSAGE>

    <BASELINE>
    {context['baseline']}
    </BASELINE>

    <CANDIDATE>
    {context['candidate']}
    </CANDIDATE>
    """
    return [
        Message(role='system', content=sys_prompt),
        Message(role='user', content=user_prompt),
    ]
|
||||
|
||||
|
||||
# Registry of prompt builders, keyed by prompt name. Calling a TypedDict
# class with keyword arguments produces an ordinary dict.
versions: Versions = Versions(
    qa_prompt=qa_prompt,
    eval_prompt=eval_prompt,
    query_expansion=query_expansion,
    eval_add_episode_results=eval_add_episode_results,
)
|
||||
|
|
|
|||
54
poetry.lock
generated
54
poetry.lock
generated
|
|
@ -1,4 +1,4 @@
|
|||
# This file is automatically @generated by Poetry 1.8.4 and should not be changed by hand.
|
||||
# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand.
|
||||
|
||||
[[package]]
|
||||
name = "aiohappyeyeballs"
|
||||
|
|
@ -1008,13 +1008,13 @@ zstd = ["zstandard (>=0.18.0)"]
|
|||
|
||||
[[package]]
|
||||
name = "huggingface-hub"
|
||||
version = "0.30.1"
|
||||
version = "0.30.2"
|
||||
description = "Client library to download and publish models, datasets and other repos on the huggingface.co hub"
|
||||
optional = false
|
||||
python-versions = ">=3.8.0"
|
||||
files = [
|
||||
{file = "huggingface_hub-0.30.1-py3-none-any.whl", hash = "sha256:0f6aa5ec5a4e68e5b9e45d556b4e5ea180c58f5a5ffa734e7f38c9d573028959"},
|
||||
{file = "huggingface_hub-0.30.1.tar.gz", hash = "sha256:f379e8b8d0791295602538856638460ae3cf679c7f304201eb80fb98c771950e"},
|
||||
{file = "huggingface_hub-0.30.2-py3-none-any.whl", hash = "sha256:68ff05969927058cfa41df4f2155d4bb48f5f54f719dd0390103eefa9b191e28"},
|
||||
{file = "huggingface_hub-0.30.2.tar.gz", hash = "sha256:9a7897c5b6fd9dad3168a794a8998d6378210f5b9688d0dfc180b1a228dc2466"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
|
|
@ -1101,13 +1101,13 @@ test = ["flaky", "ipyparallel", "pre-commit", "pytest (>=7.0)", "pytest-asyncio
|
|||
|
||||
[[package]]
|
||||
name = "ipython"
|
||||
version = "8.34.0"
|
||||
version = "8.35.0"
|
||||
description = "IPython: Productive Interactive Computing"
|
||||
optional = false
|
||||
python-versions = ">=3.10"
|
||||
files = [
|
||||
{file = "ipython-8.34.0-py3-none-any.whl", hash = "sha256:0419883fa46e0baa182c5d50ebb8d6b49df1889fdb70750ad6d8cfe678eda6e3"},
|
||||
{file = "ipython-8.34.0.tar.gz", hash = "sha256:c31d658e754673ecc6514583e7dda8069e47136eb62458816b7d1e6625948b5a"},
|
||||
{file = "ipython-8.35.0-py3-none-any.whl", hash = "sha256:e6b7470468ba6f1f0a7b116bb688a3ece2f13e2f94138e508201fad677a788ba"},
|
||||
{file = "ipython-8.35.0.tar.gz", hash = "sha256:d200b7d93c3f5883fc36ab9ce28a18249c7706e51347681f80a0aef9895f2520"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
|
|
@ -1135,7 +1135,7 @@ notebook = ["ipywidgets", "notebook"]
|
|||
parallel = ["ipyparallel"]
|
||||
qtconsole = ["qtconsole"]
|
||||
test = ["packaging", "pickleshare", "pytest", "pytest-asyncio (<0.22)", "testpath"]
|
||||
test-extra = ["curio", "ipython[test]", "matplotlib (!=3.2.0)", "nbformat", "numpy (>=1.23)", "pandas", "trio"]
|
||||
test-extra = ["curio", "ipython[test]", "jupyter_ai", "matplotlib (!=3.2.0)", "nbformat", "numpy (>=1.23)", "pandas", "trio"]
|
||||
|
||||
[[package]]
|
||||
name = "isoduration"
|
||||
|
|
@ -2359,13 +2359,13 @@ files = [
|
|||
|
||||
[[package]]
|
||||
name = "openai"
|
||||
version = "1.70.0"
|
||||
version = "1.71.0"
|
||||
description = "The official Python library for the openai API"
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
files = [
|
||||
{file = "openai-1.70.0-py3-none-any.whl", hash = "sha256:f6438d053fd8b2e05fd6bef70871e832d9bbdf55e119d0ac5b92726f1ae6f614"},
|
||||
{file = "openai-1.70.0.tar.gz", hash = "sha256:e52a8d54c3efeb08cf58539b5b21a5abef25368b5432965e4de88cdf4e091b2b"},
|
||||
{file = "openai-1.71.0-py3-none-any.whl", hash = "sha256:e1c643738f1fff1af52bce6ef06a7716c95d089281e7011777179614f32937aa"},
|
||||
{file = "openai-1.71.0.tar.gz", hash = "sha256:52b20bb990a1780f9b0b8ccebac93416343ebd3e4e714e3eff730336833ca207"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
|
|
@ -2380,7 +2380,7 @@ typing-extensions = ">=4.11,<5"
|
|||
|
||||
[package.extras]
|
||||
datalib = ["numpy (>=1)", "pandas (>=1.2.3)", "pandas-stubs (>=1.1.0.11)"]
|
||||
realtime = ["websockets (>=13,<15)"]
|
||||
realtime = ["websockets (>=13,<16)"]
|
||||
voice-helpers = ["numpy (>=2.0.2)", "sounddevice (>=0.5.1)"]
|
||||
|
||||
[[package]]
|
||||
|
|
@ -2887,13 +2887,13 @@ files = [
|
|||
|
||||
[[package]]
|
||||
name = "pydantic"
|
||||
version = "2.11.2"
|
||||
version = "2.11.3"
|
||||
description = "Data validation using Python type hints"
|
||||
optional = false
|
||||
python-versions = ">=3.9"
|
||||
files = [
|
||||
{file = "pydantic-2.11.2-py3-none-any.whl", hash = "sha256:7f17d25846bcdf89b670a86cdfe7b29a9f1c9ca23dee154221c9aa81845cfca7"},
|
||||
{file = "pydantic-2.11.2.tar.gz", hash = "sha256:2138628e050bd7a1e70b91d4bf4a91167f4ad76fdb83209b107c8d84b854917e"},
|
||||
{file = "pydantic-2.11.3-py3-none-any.whl", hash = "sha256:a082753436a07f9ba1289c6ffa01cd93db3548776088aa917cc43b63f68fa60f"},
|
||||
{file = "pydantic-2.11.3.tar.gz", hash = "sha256:7471657138c16adad9322fe3070c0116dd6c3ad8d649300e3cbdfe91f4db4ec3"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
|
|
@ -4266,13 +4266,13 @@ test = ["argcomplete (>=3.0.3)", "mypy (>=1.7.0)", "pre-commit", "pytest (>=7.0,
|
|||
|
||||
[[package]]
|
||||
name = "transformers"
|
||||
version = "4.51.0"
|
||||
version = "4.51.1"
|
||||
description = "State-of-the-art Machine Learning for JAX, PyTorch and TensorFlow"
|
||||
optional = false
|
||||
python-versions = ">=3.9.0"
|
||||
files = [
|
||||
{file = "transformers-4.51.0-py3-none-any.whl", hash = "sha256:2e6baa476735ab8adccbaee6961525a0d1ce8c21d49293af30ef5ee4b082f64d"},
|
||||
{file = "transformers-4.51.0.tar.gz", hash = "sha256:2d302563ff6c2cc2d0e88ef352cf059f9a21ce18102fd43662bb1246f70b8a84"},
|
||||
{file = "transformers-4.51.1-py3-none-any.whl", hash = "sha256:c7038e216afb2a3e9b00dd12d87ad5e3af4c30895f70b28e92f65459eded0161"},
|
||||
{file = "transformers-4.51.1.tar.gz", hash = "sha256:206ea0b75dfde142ed7495b911da76579dce6ea249cc3695fdd29a544a9e007b"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
|
|
@ -4290,17 +4290,17 @@ tqdm = ">=4.27"
|
|||
[package.extras]
|
||||
accelerate = ["accelerate (>=0.26.0)"]
|
||||
agents = ["Pillow (>=10.0.1,<=15.0)", "accelerate (>=0.26.0)", "datasets (!=2.5.0)", "diffusers", "opencv-python", "sentencepiece (>=0.1.91,!=0.1.92)", "torch (>=2.0)"]
|
||||
all = ["Pillow (>=10.0.1,<=15.0)", "accelerate (>=0.26.0)", "av", "codecarbon (>=2.8.1)", "flax (>=0.4.1,<=0.7.0)", "jax (>=0.4.1,<=0.4.13)", "jaxlib (>=0.4.1,<=0.4.13)", "kenlm", "keras-nlp (>=0.3.1,<0.14.0)", "kernels (>=0.3.2,<0.4)", "librosa", "num2words", "onnxconverter-common", "optax (>=0.0.8,<=0.1.4)", "optuna", "phonemizer", "protobuf", "pyctcdecode (>=0.4.0)", "ray[tune] (>=2.7.0)", "scipy (<1.13.0)", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "tensorflow (>2.9,<2.16)", "tensorflow-text (<2.16)", "tf2onnx", "timm (<=1.0.11)", "tokenizers (>=0.21,<0.22)", "torch (>=2.0)", "torchaudio", "torchvision"]
|
||||
audio = ["kenlm", "librosa", "phonemizer", "pyctcdecode (>=0.4.0)"]
|
||||
all = ["Pillow (>=10.0.1,<=15.0)", "accelerate (>=0.26.0)", "av", "codecarbon (>=2.8.1)", "flax (>=0.4.1,<=0.7.0)", "jax (>=0.4.1,<=0.4.13)", "jaxlib (>=0.4.1,<=0.4.13)", "keras-nlp (>=0.3.1,<0.14.0)", "kernels (>=0.3.2,<0.4)", "librosa", "num2words", "onnxconverter-common", "optax (>=0.0.8,<=0.1.4)", "optuna", "phonemizer", "protobuf", "pyctcdecode (>=0.4.0)", "ray[tune] (>=2.7.0)", "scipy (<1.13.0)", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "tensorflow (>2.9,<2.16)", "tensorflow-text (<2.16)", "tf2onnx", "timm (<=1.0.11)", "tokenizers (>=0.21,<0.22)", "torch (>=2.0)", "torchaudio", "torchvision"]
|
||||
audio = ["librosa", "phonemizer", "pyctcdecode (>=0.4.0)"]
|
||||
benchmark = ["optimum-benchmark (>=0.3.0)"]
|
||||
codecarbon = ["codecarbon (>=2.8.1)"]
|
||||
deepspeed = ["accelerate (>=0.26.0)", "deepspeed (>=0.9.3)"]
|
||||
deepspeed-testing = ["GitPython (<3.1.19)", "accelerate (>=0.26.0)", "beautifulsoup4", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "deepspeed (>=0.9.3)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "nltk (<=3.8.1)", "optuna", "parameterized", "protobuf", "psutil", "pydantic", "pytest (>=7.2.0,<8.0.0)", "pytest-asyncio", "pytest-order", "pytest-rerunfailures", "pytest-rich", "pytest-timeout", "pytest-xdist", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.11.2)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "sentencepiece (>=0.1.91,!=0.1.92)", "tensorboard", "timeout-decorator"]
|
||||
dev = ["GitPython (<3.1.19)", "Pillow (>=10.0.1,<=15.0)", "accelerate (>=0.26.0)", "av", "beautifulsoup4", "codecarbon (>=2.8.1)", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "flax (>=0.4.1,<=0.7.0)", "fugashi (>=1.0)", "ipadic (>=1.0.0,<2.0)", "isort (>=5.5.4)", "jax (>=0.4.1,<=0.4.13)", "jaxlib (>=0.4.1,<=0.4.13)", "kenlm", "keras-nlp (>=0.3.1,<0.14.0)", "kernels (>=0.3.2,<0.4)", "libcst", "librosa", "nltk (<=3.8.1)", "num2words", "onnxconverter-common", "optax (>=0.0.8,<=0.1.4)", "optuna", "parameterized", "phonemizer", "protobuf", "psutil", "pyctcdecode (>=0.4.0)", "pydantic", "pytest (>=7.2.0,<8.0.0)", "pytest-asyncio", "pytest-order", "pytest-rerunfailures", "pytest-rich", "pytest-timeout", "pytest-xdist", "ray[tune] (>=2.7.0)", "rhoknp (>=1.1.0,<1.3.1)", "rich", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.11.2)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "scikit-learn", "scipy (<1.13.0)", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "sudachidict-core (>=20220729)", "sudachipy (>=0.6.6)", "tensorboard", "tensorflow (>2.9,<2.16)", "tensorflow-text (<2.16)", "tf2onnx", "timeout-decorator", "timm (<=1.0.11)", "tokenizers (>=0.21,<0.22)", "torch (>=2.0)", "torchaudio", "torchvision", "unidic (>=1.0.2)", "unidic-lite (>=1.0.7)", "urllib3 (<2.0.0)"]
|
||||
dev-tensorflow = ["GitPython (<3.1.19)", "Pillow (>=10.0.1,<=15.0)", "beautifulsoup4", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "isort (>=5.5.4)", "kenlm", "keras-nlp (>=0.3.1,<0.14.0)", "libcst", "librosa", "nltk (<=3.8.1)", "onnxconverter-common", "onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)", "parameterized", "phonemizer", "protobuf", "psutil", "pyctcdecode (>=0.4.0)", "pydantic", "pytest (>=7.2.0,<8.0.0)", "pytest-asyncio", "pytest-order", "pytest-rerunfailures", "pytest-rich", "pytest-timeout", "pytest-xdist", "rich", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.11.2)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "scikit-learn", "sentencepiece (>=0.1.91,!=0.1.92)", "tensorboard", "tensorflow (>2.9,<2.16)", "tensorflow-text (<2.16)", "tf2onnx", "timeout-decorator", "tokenizers (>=0.21,<0.22)", "urllib3 (<2.0.0)"]
|
||||
dev-torch = ["GitPython (<3.1.19)", "Pillow (>=10.0.1,<=15.0)", "accelerate (>=0.26.0)", "beautifulsoup4", "codecarbon (>=2.8.1)", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "fugashi (>=1.0)", "ipadic (>=1.0.0,<2.0)", "isort (>=5.5.4)", "kenlm", "kernels (>=0.3.2,<0.4)", "libcst", "librosa", "nltk (<=3.8.1)", "num2words", "onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)", "optuna", "parameterized", "phonemizer", "protobuf", "psutil", "pyctcdecode (>=0.4.0)", "pydantic", "pytest (>=7.2.0,<8.0.0)", "pytest-asyncio", "pytest-order", "pytest-rerunfailures", "pytest-rich", "pytest-timeout", "pytest-xdist", "ray[tune] (>=2.7.0)", "rhoknp (>=1.1.0,<1.3.1)", "rich", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.11.2)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "scikit-learn", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "sudachidict-core (>=20220729)", "sudachipy (>=0.6.6)", "tensorboard", "timeout-decorator", "timm (<=1.0.11)", "tokenizers (>=0.21,<0.22)", "torch (>=2.0)", "torchaudio", "torchvision", "unidic (>=1.0.2)", "unidic-lite (>=1.0.7)", "urllib3 (<2.0.0)"]
|
||||
dev = ["GitPython (<3.1.19)", "Pillow (>=10.0.1,<=15.0)", "accelerate (>=0.26.0)", "av", "beautifulsoup4", "codecarbon (>=2.8.1)", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "flax (>=0.4.1,<=0.7.0)", "fugashi (>=1.0)", "ipadic (>=1.0.0,<2.0)", "isort (>=5.5.4)", "jax (>=0.4.1,<=0.4.13)", "jaxlib (>=0.4.1,<=0.4.13)", "keras-nlp (>=0.3.1,<0.14.0)", "kernels (>=0.3.2,<0.4)", "libcst", "librosa", "nltk (<=3.8.1)", "num2words", "onnxconverter-common", "optax (>=0.0.8,<=0.1.4)", "optuna", "parameterized", "phonemizer", "protobuf", "psutil", "pyctcdecode (>=0.4.0)", "pydantic", "pytest (>=7.2.0,<8.0.0)", "pytest-asyncio", "pytest-order", "pytest-rerunfailures", "pytest-rich", "pytest-timeout", "pytest-xdist", "ray[tune] (>=2.7.0)", "rhoknp (>=1.1.0,<1.3.1)", "rich", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.11.2)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "scikit-learn", "scipy (<1.13.0)", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "sudachidict-core (>=20220729)", "sudachipy (>=0.6.6)", "tensorboard", "tensorflow (>2.9,<2.16)", "tensorflow-text (<2.16)", "tf2onnx", "timeout-decorator", "timm (<=1.0.11)", "tokenizers (>=0.21,<0.22)", "torch (>=2.0)", "torchaudio", "torchvision", "unidic (>=1.0.2)", "unidic-lite (>=1.0.7)", "urllib3 (<2.0.0)"]
|
||||
dev-tensorflow = ["GitPython (<3.1.19)", "Pillow (>=10.0.1,<=15.0)", "beautifulsoup4", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "isort (>=5.5.4)", "keras-nlp (>=0.3.1,<0.14.0)", "libcst", "librosa", "nltk (<=3.8.1)", "onnxconverter-common", "onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)", "parameterized", "phonemizer", "protobuf", "psutil", "pyctcdecode (>=0.4.0)", "pydantic", "pytest (>=7.2.0,<8.0.0)", "pytest-asyncio", "pytest-order", "pytest-rerunfailures", "pytest-rich", "pytest-timeout", "pytest-xdist", "rich", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.11.2)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "scikit-learn", "sentencepiece (>=0.1.91,!=0.1.92)", "tensorboard", "tensorflow (>2.9,<2.16)", "tensorflow-text (<2.16)", "tf2onnx", "timeout-decorator", "tokenizers (>=0.21,<0.22)", "urllib3 (<2.0.0)"]
|
||||
dev-torch = ["GitPython (<3.1.19)", "Pillow (>=10.0.1,<=15.0)", "accelerate (>=0.26.0)", "beautifulsoup4", "codecarbon (>=2.8.1)", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "fugashi (>=1.0)", "ipadic (>=1.0.0,<2.0)", "isort (>=5.5.4)", "kernels (>=0.3.2,<0.4)", "libcst", "librosa", "nltk (<=3.8.1)", "num2words", "onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)", "optuna", "parameterized", "phonemizer", "protobuf", "psutil", "pyctcdecode (>=0.4.0)", "pydantic", "pytest (>=7.2.0,<8.0.0)", "pytest-asyncio", "pytest-order", "pytest-rerunfailures", "pytest-rich", "pytest-timeout", "pytest-xdist", "ray[tune] (>=2.7.0)", "rhoknp (>=1.1.0,<1.3.1)", "rich", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.11.2)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "scikit-learn", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "sudachidict-core (>=20220729)", "sudachipy (>=0.6.6)", "tensorboard", "timeout-decorator", "timm (<=1.0.11)", "tokenizers (>=0.21,<0.22)", "torch (>=2.0)", "torchaudio", "torchvision", "unidic (>=1.0.2)", "unidic-lite (>=1.0.7)", "urllib3 (<2.0.0)"]
|
||||
flax = ["flax (>=0.4.1,<=0.7.0)", "jax (>=0.4.1,<=0.4.13)", "jaxlib (>=0.4.1,<=0.4.13)", "optax (>=0.0.8,<=0.1.4)", "scipy (<1.13.0)"]
|
||||
flax-speech = ["kenlm", "librosa", "phonemizer", "pyctcdecode (>=0.4.0)"]
|
||||
flax-speech = ["librosa", "phonemizer", "pyctcdecode (>=0.4.0)"]
|
||||
ftfy = ["ftfy"]
|
||||
hf-xet = ["hf-xet"]
|
||||
hub-kernels = ["kernels (>=0.3.2,<0.4)"]
|
||||
|
|
@ -4321,16 +4321,16 @@ sentencepiece = ["protobuf", "sentencepiece (>=0.1.91,!=0.1.92)"]
|
|||
serving = ["fastapi", "pydantic", "starlette", "uvicorn"]
|
||||
sigopt = ["sigopt"]
|
||||
sklearn = ["scikit-learn"]
|
||||
speech = ["kenlm", "librosa", "phonemizer", "pyctcdecode (>=0.4.0)", "torchaudio"]
|
||||
speech = ["librosa", "phonemizer", "pyctcdecode (>=0.4.0)", "torchaudio"]
|
||||
testing = ["GitPython (<3.1.19)", "beautifulsoup4", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "nltk (<=3.8.1)", "parameterized", "psutil", "pydantic", "pytest (>=7.2.0,<8.0.0)", "pytest-asyncio", "pytest-order", "pytest-rerunfailures", "pytest-rich", "pytest-timeout", "pytest-xdist", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.11.2)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "sentencepiece (>=0.1.91,!=0.1.92)", "tensorboard", "timeout-decorator"]
|
||||
tf = ["keras-nlp (>=0.3.1,<0.14.0)", "onnxconverter-common", "tensorflow (>2.9,<2.16)", "tensorflow-text (<2.16)", "tf2onnx"]
|
||||
tf-cpu = ["keras (>2.9,<2.16)", "keras-nlp (>=0.3.1,<0.14.0)", "onnxconverter-common", "tensorflow-cpu (>2.9,<2.16)", "tensorflow-probability (<0.24)", "tensorflow-text (<2.16)", "tf2onnx"]
|
||||
tf-speech = ["kenlm", "librosa", "phonemizer", "pyctcdecode (>=0.4.0)"]
|
||||
tf-speech = ["librosa", "phonemizer", "pyctcdecode (>=0.4.0)"]
|
||||
tiktoken = ["blobfile", "tiktoken"]
|
||||
timm = ["timm (<=1.0.11)"]
|
||||
tokenizers = ["tokenizers (>=0.21,<0.22)"]
|
||||
torch = ["accelerate (>=0.26.0)", "torch (>=2.0)"]
|
||||
torch-speech = ["kenlm", "librosa", "phonemizer", "pyctcdecode (>=0.4.0)", "torchaudio"]
|
||||
torch-speech = ["librosa", "phonemizer", "pyctcdecode (>=0.4.0)", "torchaudio"]
|
||||
torch-vision = ["Pillow (>=10.0.1,<=15.0)", "torchvision"]
|
||||
torchhub = ["filelock", "huggingface-hub (>=0.30.0,<1.0)", "importlib-metadata", "numpy (>=1.17)", "packaging (>=20.0)", "protobuf", "regex (!=2019.12.17)", "requests", "sentencepiece (>=0.1.91,!=0.1.92)", "tokenizers (>=0.21,<0.22)", "torch (>=2.0)", "tqdm (>=4.27)"]
|
||||
video = ["av"]
|
||||
|
|
|
|||
|
|
@ -134,8 +134,8 @@
|
|||
" max_num_previous_messages, message_index_across_sessions\n",
|
||||
" )\n",
|
||||
" previous_snippets = all_snippets_this_session[\n",
|
||||
" message_index_across_sessions - num_previous_messages:\n",
|
||||
" ]\n",
|
||||
" message_index_across_sessions - num_previous_messages :\n",
|
||||
" ]\n",
|
||||
" previous_messages_only = [\n",
|
||||
" {\n",
|
||||
" 'role': previous_snippet['message']['role'],\n",
|
||||
|
|
|
|||
|
|
@ -46,8 +46,8 @@
|
|||
"Requirement already satisfied: httpcore==1.* in /Users/prestonrasmussen/Library/Caches/pypoetry/virtualenvs/graphiti-core-XzHUgKi9-py3.12/lib/python3.12/site-packages (from httpx<1,>=0.23.0->openai<2.0.0,>=1.53.0->graphiti-core) (1.0.6)\r\n",
|
||||
"Requirement already satisfied: h11<0.15,>=0.13 in /Users/prestonrasmussen/Library/Caches/pypoetry/virtualenvs/graphiti-core-XzHUgKi9-py3.12/lib/python3.12/site-packages (from httpcore==1.*->httpx<1,>=0.23.0->openai<2.0.0,>=1.53.0->graphiti-core) (0.14.0)\r\n",
|
||||
"\r\n",
|
||||
"\u001B[1m[\u001B[0m\u001B[34;49mnotice\u001B[0m\u001B[1;39;49m]\u001B[0m\u001B[39;49m A new release of pip is available: \u001B[0m\u001B[31;49m24.1\u001B[0m\u001B[39;49m -> \u001B[0m\u001B[32;49m25.0.1\u001B[0m\r\n",
|
||||
"\u001B[1m[\u001B[0m\u001B[34;49mnotice\u001B[0m\u001B[1;39;49m]\u001B[0m\u001B[39;49m To update, run: \u001B[0m\u001B[32;49mpip install --upgrade pip\u001B[0m\r\n",
|
||||
"\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m A new release of pip is available: \u001b[0m\u001b[31;49m24.1\u001b[0m\u001b[39;49m -> \u001b[0m\u001b[32;49m25.0.1\u001b[0m\r\n",
|
||||
"\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m To update, run: \u001b[0m\u001b[32;49mpip install --upgrade pip\u001b[0m\r\n",
|
||||
"Note: you may need to restart the kernel to use updated packages.\n"
|
||||
]
|
||||
}
|
||||
|
|
@ -316,10 +316,10 @@
|
|||
"\n",
|
||||
" df['message'] = df.apply(\n",
|
||||
" lambda row: '|' * 10\n",
|
||||
" + f\" {row['message_role']} \"\n",
|
||||
" + '|' * 10\n",
|
||||
" + '\\n\\n'\n",
|
||||
" + f\"{row['message']}\"\n",
|
||||
" + f\" {row['message_role']} \"\n",
|
||||
" + '|' * 10\n",
|
||||
" + '\\n\\n'\n",
|
||||
" + f\"{row['message']}\"\n",
|
||||
" if row['message'] is not None\n",
|
||||
" else None,\n",
|
||||
" axis=1,\n",
|
||||
|
|
|
|||
79
tests/evals/eval_e2e_graph_building.py
Normal file
79
tests/evals/eval_e2e_graph_building.py
Normal file
|
|
@ -0,0 +1,79 @@
|
|||
from datetime import datetime, timezone
|
||||
from typing import Tuple
|
||||
|
||||
import pandas as pd
|
||||
|
||||
from graphiti_core import Graphiti
|
||||
from graphiti_core.graphiti import AddEpisodeResults
|
||||
from graphiti_core.llm_client import LLMConfig, OpenAIClient
|
||||
from graphiti_core.nodes import EpisodeType
|
||||
from graphiti_core.utils.maintenance import clear_data
|
||||
from tests.test_graphiti_int import NEO4J_URI, NEO4j_PASSWORD, NEO4j_USER
|
||||
|
||||
|
||||
async def build_graph(
    multi_session: list[int], session_length: int, graphiti: Graphiti
) -> Tuple[dict[str, list[AddEpisodeResults]], dict[str, list[str]]]:
    """Ingest selected LongMemEval conversations into the graph, one episode per message.

    For every dataset row index in ``multi_session`` the data stored under that
    row's group id is cleared, then each message of the first
    ``session_length`` sessions is added through ``graphiti.add_episode``.

    Args:
        multi_session: Row indices into the LongMemEval dataset to ingest.
        session_length: Maximum number of sessions to ingest per row.
        graphiti: Client used to clear existing data and add the episodes.

    Returns:
        A pair ``(add_episode_results, add_episode_context)``, both keyed by
        the generated user/group id.
    """
    # Get longmemeval dataset
    lme_dataset_option = 'data/longmemeval_oracle.json'  # Can be _oracle, _s, or _m
    lme_dataset_df = pd.read_json(lme_dataset_option)

    # ' UTC' is appended to each raw date so it matches this format.
    date_format = '%Y/%m/%d (%a) %H:%M UTC'

    add_episode_results: dict[str, list[AddEpisodeResults]] = {}
    add_episode_context: dict[str, list[str]] = {}
    for multi_session_idx in multi_session:
        # Distinct local names: the original rebound the `multi_session`
        # parameter while iterating over it.
        sessions = lme_dataset_df['haystack_sessions'].iloc[multi_session_idx]
        session_dates = lme_dataset_df['haystack_dates'].iloc[multi_session_idx]

        user_id = 'lme_oracle_experiment_user_' + str(multi_session_idx)
        await clear_data(graphiti.driver, [user_id])

        add_episode_results[user_id] = []
        # NOTE(review): nothing is ever appended to this list, so callers that
        # zip it with the results iterate zero times -- confirm what context
        # should be recorded per episode.
        add_episode_context[user_id] = []

        for session_idx, session in enumerate(sessions):
            if session_idx >= session_length:
                # Indices only increase, so no later session can qualify.
                break

            # Every message in a session shares the session's timestamp, so
            # parse it once per session instead of once per message.
            reference_time = datetime.strptime(
                session_dates[session_idx] + ' UTC', date_format
            ).replace(tzinfo=timezone.utc)

            for msg in session:
                # Single-quoted f-string: reusing the enclosing quote character
                # inside the braces only parses on Python 3.12+.
                episode_body = f'{msg["role"]}: {msg["content"]}'
                results = await graphiti.add_episode(
                    # NOTE(review): assumes every message dict has a 'name'
                    # key -- verify against the dataset schema.
                    name=msg['name'],
                    episode_body=episode_body,
                    reference_time=reference_time,
                    source=EpisodeType.message,
                    source_description='',
                    group_id=user_id,
                )

                add_episode_results[user_id].append(results)
    return add_episode_results, add_episode_context
|
||||
|
||||
|
||||
async def build_baseline_graph(
    multi_session: list[int], session_length: int
) -> dict[str, list[AddEpisodeResults]]:
    """Build the baseline graph with gpt-4o and return its per-episode results.

    Args:
        multi_session: Row indices into the LongMemEval dataset to ingest.
        session_length: Maximum number of sessions to ingest per row.

    Returns:
        The add-episode results from ``build_graph``, keyed by user/group id.
        The original bound these to a local and discarded them; returning them
        lets the caller keep the baseline for comparison.
    """
    # Use gpt-4o for graph building baseline
    llm_client = OpenAIClient(config=LLMConfig(model='gpt-4o'))
    graphiti = Graphiti(NEO4J_URI, NEO4j_USER, NEO4j_PASSWORD, llm_client=llm_client)

    add_episode_results, _ = await build_graph(multi_session, session_length, graphiti)
    return add_episode_results
|
||||
|
||||
|
||||
async def eval_graph(multi_session: list[int], session_length: int, llm_client=OpenAIClient()):
|
||||
graphiti = Graphiti(NEO4J_URI, NEO4j_USER, NEO4j_PASSWORD, llm_client=llm_client)
|
||||
baseline_results: dict[str, list[AddEpisodeResults]] = {}
|
||||
add_episode_results, add_episode_context = await build_graph(
|
||||
multi_session, session_length, graphiti
|
||||
)
|
||||
|
||||
for user_id in add_episode_results:
|
||||
for baseline_result, add_episode_result, episodes in zip(
|
||||
baseline_results[user_id], add_episode_results[user_id], add_episode_context[user_id]
|
||||
):
|
||||
context = {
|
||||
'baseline': baseline_result,
|
||||
'candidate': add_episode_result,
|
||||
'message': episodes[0],
|
||||
'previous_messages': episodes[1:],
|
||||
}
|
||||
Loading…
Add table
Reference in a new issue