Cog 1293 corpus builder custom cognify tasks (#527)
<!-- .github/pull_request_template.md --> ## Description - Enable custom tasks in corpus building ## DCO Affirmation I affirm that all code in every commit of this pull request conforms to the terms of the Topoteretes Developer Certificate of Origin <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit - **New Features** - Introduced a configurable option to specify the task retrieval strategy during corpus building. - Enhanced the workflow with integrated task fetching, featuring a default retrieval mechanism. - Updated evaluation configuration to support customizable task selection for more flexible operations. - Added a new abstract base class for defining various task retrieval strategies. - Introduced a new enumeration to map task getter types to their corresponding classes. - **Dependencies** - Added a new dependency for downloading files from Google Drive. <!-- end of auto-generated comment: release notes by coderabbit.ai -->
This commit is contained in:
parent
e6db870264
commit
bb8cb692e0
11 changed files with 104 additions and 19 deletions
|
|
@ -10,7 +10,7 @@ from evals.eval_framework.benchmark_adapters.twowikimultihop_adapter import TwoW
|
|||
class BenchmarkAdapter(Enum):
|
||||
DUMMY = ("Dummy", DummyAdapter)
|
||||
HOTPOTQA = ("HotPotQA", HotpotQAAdapter)
|
||||
MUSIQUE = ('Musique', MusiqueQAAdapter)
|
||||
MUSIQUE = ("Musique", MusiqueQAAdapter)
|
||||
TWOWIKIMULTIHOP = ("TwoWikiMultiHop", TwoWikiMultihopAdapter)
|
||||
|
||||
def __new__(cls, adapter_name: str, adapter_class: Type):
|
||||
|
|
|
|||
|
|
@ -18,10 +18,8 @@ class MusiqueQAAdapter(BaseBenchmarkAdapter):
|
|||
dataset_info = {
|
||||
# Name of the final file we want to load
|
||||
"filename": "musique_ans_v1.0_dev.jsonl",
|
||||
|
||||
# A Google Drive URL (or share link) to the ZIP containing this file
|
||||
"download_url": "https://drive.google.com/file/d/1tGdADlNjWFaHLeZZGShh2IRcpO6Lv24h/view?usp=sharing",
|
||||
|
||||
# The name of the ZIP archive we expect after downloading
|
||||
"zip_filename": "musique_v1.0.zip",
|
||||
}
|
||||
|
|
@ -69,9 +67,7 @@ class MusiqueQAAdapter(BaseBenchmarkAdapter):
|
|||
for item in data:
|
||||
# Each 'paragraphs' is a list of dicts; we can concatenate their 'paragraph_text'
|
||||
paragraphs = item.get("paragraphs", [])
|
||||
combined_paragraphs = " ".join(
|
||||
paragraph["paragraph_text"] for paragraph in paragraphs
|
||||
)
|
||||
combined_paragraphs = " ".join(paragraph["paragraph_text"] for paragraph in paragraphs)
|
||||
corpus_list.append(combined_paragraphs)
|
||||
|
||||
# Example question & answer
|
||||
|
|
|
|||
|
|
@ -3,11 +3,15 @@ import logging
|
|||
from typing import Optional, Tuple, List, Dict, Union, Any
|
||||
|
||||
from evals.eval_framework.benchmark_adapters.benchmark_adapters import BenchmarkAdapter
|
||||
from evals.eval_framework.corpus_builder.task_getters.task_getters import TaskGetters
|
||||
from evals.eval_framework.corpus_builder.task_getters.base_task_getter import BaseTaskGetter
|
||||
from cognee.shared.utils import setup_logging
|
||||
|
||||
|
||||
class CorpusBuilderExecutor:
|
||||
def __init__(self, benchmark: Union[str, Any] = "Dummy") -> None:
|
||||
def __init__(
|
||||
self, benchmark: Union[str, Any] = "Dummy", task_getter_type: str = "DEFAULT"
|
||||
) -> None:
|
||||
if isinstance(benchmark, str):
|
||||
try:
|
||||
adapter_enum = BenchmarkAdapter(benchmark)
|
||||
|
|
@ -20,6 +24,13 @@ class CorpusBuilderExecutor:
|
|||
self.raw_corpus = None
|
||||
self.questions = None
|
||||
|
||||
try:
|
||||
task_enum = TaskGetters(task_getter_type)
|
||||
except KeyError:
|
||||
raise ValueError(f"Invalid task getter type: {task_getter_type}")
|
||||
|
||||
self.task_getter: BaseTaskGetter = task_enum.getter_class()
|
||||
|
||||
def load_corpus(self, limit: Optional[int] = None) -> Tuple[List[Dict], List[str]]:
|
||||
self.raw_corpus, self.questions = self.adapter.load_corpus(limit=limit)
|
||||
return self.raw_corpus, self.questions
|
||||
|
|
@ -32,12 +43,10 @@ class CorpusBuilderExecutor:
|
|||
async def run_cognee(self) -> None:
|
||||
setup_logging(logging.ERROR)
|
||||
|
||||
# Pruning system and databases.
|
||||
await cognee.prune.prune_data()
|
||||
await cognee.prune.prune_system(metadata=True)
|
||||
|
||||
# Adding corpus elements to the cognee metastore.
|
||||
await cognee.add(self.raw_corpus)
|
||||
|
||||
# Running cognify to build the knowledge graph.
|
||||
await cognee.cognify()
|
||||
tasks = await self.task_getter.get_tasks()
|
||||
await cognee.cognify(tasks=tasks)
|
||||
|
|
|
|||
|
|
@ -30,7 +30,10 @@ async def create_and_insert_questions_table(questions_payload):
|
|||
async def run_corpus_builder(params: dict) -> None:
|
||||
if params.get("building_corpus_from_scratch"):
|
||||
logging.info("Corpus Builder started...")
|
||||
corpus_builder = CorpusBuilderExecutor(benchmark=params["benchmark"])
|
||||
corpus_builder = CorpusBuilderExecutor(
|
||||
benchmark=params["benchmark"],
|
||||
task_getter_type=params.get("task_getter_type", "Default"),
|
||||
)
|
||||
questions = await corpus_builder.build_corpus(
|
||||
limit=params.get("number_of_samples_in_corpus")
|
||||
)
|
||||
|
|
|
|||
|
|
@ -0,0 +1,12 @@
|
|||
from abc import ABC, abstractmethod
|
||||
from typing import List
|
||||
from cognee.modules.pipelines.tasks.Task import Task
|
||||
|
||||
|
||||
class BaseTaskGetter(ABC):
|
||||
"""Abstract base class for asynchronous task retrieval implementations."""
|
||||
|
||||
@abstractmethod
|
||||
async def get_tasks(self) -> List[Task]:
|
||||
"""Asynchronously retrieve a list of tasks. Must be implemented by subclasses."""
|
||||
pass
|
||||
|
|
@ -0,0 +1,12 @@
|
|||
from cognee.api.v1.cognify.cognify_v2 import get_default_tasks
|
||||
from typing import List
|
||||
from evals.eval_framework.corpus_builder.task_getters.base_task_getter import BaseTaskGetter
|
||||
from cognee.modules.pipelines.tasks.Task import Task
|
||||
|
||||
|
||||
class DefaultTaskGetter(BaseTaskGetter):
|
||||
"""Default task getter that retrieves tasks using the standard get_default_tasks function."""
|
||||
|
||||
async def get_tasks(self) -> List[Task]:
|
||||
"""Retrieve default tasks asynchronously."""
|
||||
return await get_default_tasks()
|
||||
|
|
@ -0,0 +1,19 @@
|
|||
from enum import Enum
|
||||
from typing import Type
|
||||
from evals.eval_framework.corpus_builder.task_getters.default_task_getter import DefaultTaskGetter
|
||||
|
||||
|
||||
class TaskGetters(Enum):
|
||||
"""Enum mapping task getter types to their respective classes."""
|
||||
|
||||
DEFAULT = ("Default", DefaultTaskGetter)
|
||||
# CUSTOM = ("Custom", CustomTaskGetter)
|
||||
|
||||
def __new__(cls, getter_name: str, getter_class: Type):
|
||||
obj = object.__new__(cls)
|
||||
obj._value_ = getter_name
|
||||
obj.getter_class = getter_class
|
||||
return obj
|
||||
|
||||
def __str__(self):
|
||||
return self.value
|
||||
|
|
@ -8,6 +8,7 @@ class EvalConfig(BaseSettings):
|
|||
building_corpus_from_scratch: bool = True
|
||||
number_of_samples_in_corpus: int = 1
|
||||
benchmark: str = "Dummy" # Options: 'HotPotQA', 'Dummy', 'TwoWikiMultiHop'
|
||||
task_getter_type: str = "Default"
|
||||
|
||||
# Question answering params
|
||||
answering_questions: bool = True
|
||||
|
|
@ -48,6 +49,7 @@ class EvalConfig(BaseSettings):
|
|||
"metrics_path": self.metrics_path,
|
||||
"dashboard_path": self.dashboard_path,
|
||||
"deepeval_model": self.deepeval_model,
|
||||
"task_getter_type": self.task_getter_type,
|
||||
}
|
||||
|
||||
|
||||
|
|
|
|||
45
poetry.lock
generated
45
poetry.lock
generated
|
|
@ -550,7 +550,7 @@ typecheck = ["mypy"]
|
|||
name = "beautifulsoup4"
|
||||
version = "4.13.3"
|
||||
description = "Screen-scraping library"
|
||||
optional = true
|
||||
optional = false
|
||||
python-versions = ">=3.7.0"
|
||||
files = [
|
||||
{file = "beautifulsoup4-4.13.3-py3-none-any.whl", hash = "sha256:99045d7d3f08f91f0d656bc9b7efbae189426cd913d830294a15eefa0ea4df16"},
|
||||
|
|
@ -1067,7 +1067,6 @@ files = [
|
|||
{file = "cryptography-44.0.0-cp37-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:761817a3377ef15ac23cd7834715081791d4ec77f9297ee694ca1ee9c2c7e5eb"},
|
||||
{file = "cryptography-44.0.0-cp37-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:3c672a53c0fb4725a29c303be906d3c1fa99c32f58abe008a82705f9ee96f40b"},
|
||||
{file = "cryptography-44.0.0-cp37-abi3-manylinux_2_34_aarch64.whl", hash = "sha256:4ac4c9f37eba52cb6fbeaf5b59c152ea976726b865bd4cf87883a7e7006cc543"},
|
||||
{file = "cryptography-44.0.0-cp37-abi3-manylinux_2_34_x86_64.whl", hash = "sha256:60eb32934076fa07e4316b7b2742fa52cbb190b42c2df2863dbc4230a0a9b385"},
|
||||
{file = "cryptography-44.0.0-cp37-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:ed3534eb1090483c96178fcb0f8893719d96d5274dfde98aa6add34614e97c8e"},
|
||||
{file = "cryptography-44.0.0-cp37-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:f3f6fdfa89ee2d9d496e2c087cebef9d4fcbb0ad63c40e821b39f74bf48d9c5e"},
|
||||
{file = "cryptography-44.0.0-cp37-abi3-win32.whl", hash = "sha256:eb33480f1bad5b78233b0ad3e1b0be21e8ef1da745d8d2aecbb20671658b9053"},
|
||||
|
|
@ -1078,7 +1077,6 @@ files = [
|
|||
{file = "cryptography-44.0.0-cp39-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:c5eb858beed7835e5ad1faba59e865109f3e52b3783b9ac21e7e47dc5554e289"},
|
||||
{file = "cryptography-44.0.0-cp39-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:f53c2c87e0fb4b0c00fa9571082a057e37690a8f12233306161c8f4b819960b7"},
|
||||
{file = "cryptography-44.0.0-cp39-abi3-manylinux_2_34_aarch64.whl", hash = "sha256:9e6fc8a08e116fb7c7dd1f040074c9d7b51d74a8ea40d4df2fc7aa08b76b9e6c"},
|
||||
{file = "cryptography-44.0.0-cp39-abi3-manylinux_2_34_x86_64.whl", hash = "sha256:9abcc2e083cbe8dde89124a47e5e53ec38751f0d7dfd36801008f316a127d7ba"},
|
||||
{file = "cryptography-44.0.0-cp39-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:d2436114e46b36d00f8b72ff57e598978b37399d2786fd39793c36c6d5cb1c64"},
|
||||
{file = "cryptography-44.0.0-cp39-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:a01956ddfa0a6790d594f5b34fc1bfa6098aca434696a03cfdbe469b8ed79285"},
|
||||
{file = "cryptography-44.0.0-cp39-abi3-win32.whl", hash = "sha256:eca27345e1214d1b9f9490d200f9db5a874479be914199194e746c893788d417"},
|
||||
|
|
@ -1945,6 +1943,26 @@ test-downstream = ["aiobotocore (>=2.5.4,<3.0.0)", "dask-expr", "dask[dataframe,
|
|||
test-full = ["adlfs", "aiohttp (!=4.0.0a0,!=4.0.0a1)", "cloudpickle", "dask", "distributed", "dropbox", "dropboxdrivefs", "fastparquet", "fusepy", "gcsfs", "jinja2", "kerchunk", "libarchive-c", "lz4", "notebook", "numpy", "ocifs", "pandas", "panel", "paramiko", "pyarrow", "pyarrow (>=1)", "pyftpdlib", "pygit2", "pytest", "pytest-asyncio (!=0.22.0)", "pytest-benchmark", "pytest-cov", "pytest-mock", "pytest-recording", "pytest-rerunfailures", "python-snappy", "requests", "smbprotocol", "tqdm", "urllib3", "zarr", "zstandard"]
|
||||
tqdm = ["tqdm"]
|
||||
|
||||
[[package]]
|
||||
name = "gdown"
|
||||
version = "5.2.0"
|
||||
description = "Google Drive Public File/Folder Downloader"
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
files = [
|
||||
{file = "gdown-5.2.0-py3-none-any.whl", hash = "sha256:33083832d82b1101bdd0e9df3edd0fbc0e1c5f14c9d8c38d2a35bf1683b526d6"},
|
||||
{file = "gdown-5.2.0.tar.gz", hash = "sha256:2145165062d85520a3cd98b356c9ed522c5e7984d408535409fd46f94defc787"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
beautifulsoup4 = "*"
|
||||
filelock = "*"
|
||||
requests = {version = "*", extras = ["socks"]}
|
||||
tqdm = "*"
|
||||
|
||||
[package.extras]
|
||||
test = ["build", "mypy", "pytest", "pytest-xdist", "ruff", "twine", "types-requests", "types-setuptools"]
|
||||
|
||||
[[package]]
|
||||
name = "ghp-import"
|
||||
version = "2.1.0"
|
||||
|
|
@ -5204,8 +5222,8 @@ files = [
|
|||
[package.dependencies]
|
||||
numpy = [
|
||||
{version = ">=1.22.4", markers = "python_version < \"3.11\""},
|
||||
{version = ">=1.23.2", markers = "python_version == \"3.11\""},
|
||||
{version = ">=1.26.0", markers = "python_version >= \"3.12\""},
|
||||
{version = ">=1.23.2", markers = "python_version == \"3.11\""},
|
||||
]
|
||||
python-dateutil = ">=2.8.2"
|
||||
pytz = ">=2020.1"
|
||||
|
|
@ -6193,8 +6211,8 @@ astroid = ">=3.3.8,<=3.4.0-dev0"
|
|||
colorama = {version = ">=0.4.5", markers = "sys_platform == \"win32\""}
|
||||
dill = [
|
||||
{version = ">=0.2", markers = "python_version < \"3.11\""},
|
||||
{version = ">=0.3.6", markers = "python_version >= \"3.11\" and python_version < \"3.12\""},
|
||||
{version = ">=0.3.7", markers = "python_version >= \"3.12\""},
|
||||
{version = ">=0.3.6", markers = "python_version >= \"3.11\" and python_version < \"3.12\""},
|
||||
]
|
||||
isort = ">=4.2.5,<5.13.0 || >5.13.0,<7"
|
||||
mccabe = ">=0.6,<0.8"
|
||||
|
|
@ -6295,6 +6313,18 @@ docs = ["myst_parser", "sphinx", "sphinx_rtd_theme"]
|
|||
full = ["Pillow (>=8.0.0)", "PyCryptodome", "cryptography"]
|
||||
image = ["Pillow (>=8.0.0)"]
|
||||
|
||||
[[package]]
|
||||
name = "pysocks"
|
||||
version = "1.7.1"
|
||||
description = "A Python SOCKS client module. See https://github.com/Anorov/PySocks for more information."
|
||||
optional = false
|
||||
python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*"
|
||||
files = [
|
||||
{file = "PySocks-1.7.1-py27-none-any.whl", hash = "sha256:08e69f092cc6dbe92a0fdd16eeb9b9ffbc13cadfe5ca4c7bd92ffb078b293299"},
|
||||
{file = "PySocks-1.7.1-py3-none-any.whl", hash = "sha256:2725bd0a9925919b9b51739eea5f9e2bae91e83288108a9ad338b2e3a4435ee5"},
|
||||
{file = "PySocks-1.7.1.tar.gz", hash = "sha256:3f8804571ebe159c380ac6de37643bb4685970655d3bba243530d6558b799aa0"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "pytest"
|
||||
version = "7.4.4"
|
||||
|
|
@ -7025,6 +7055,7 @@ files = [
|
|||
certifi = ">=2017.4.17"
|
||||
charset-normalizer = ">=2,<4"
|
||||
idna = ">=2.5,<4"
|
||||
PySocks = {version = ">=1.5.6,<1.5.7 || >1.5.7", optional = true, markers = "extra == \"socks\""}
|
||||
urllib3 = ">=1.21.1,<3"
|
||||
|
||||
[package.extras]
|
||||
|
|
@ -7786,7 +7817,7 @@ files = [
|
|||
name = "soupsieve"
|
||||
version = "2.6"
|
||||
description = "A modern CSS selector implementation for Beautiful Soup."
|
||||
optional = true
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
files = [
|
||||
{file = "soupsieve-2.6-py3-none-any.whl", hash = "sha256:e72c4ff06e4fb6e4b5a9f0f55fe6e81514581fca1515028625d0f299c602ccc9"},
|
||||
|
|
@ -9185,4 +9216,4 @@ weaviate = ["weaviate-client"]
|
|||
[metadata]
|
||||
lock-version = "2.0"
|
||||
python-versions = ">=3.10.0,<3.13"
|
||||
content-hash = "a213bb86ac98ac7b327e40eafae1ef038ff74071dea144b8c005eb86b252612f"
|
||||
content-hash = "fa5e6192627dc994e4d4b8c3471f836326afbe045b44407bc5e29ec11bff5f08"
|
||||
|
|
|
|||
|
|
@ -82,6 +82,7 @@ parso = {version = "^0.8.4", optional = true}
|
|||
jedi = {version = "^0.19.2", optional = true}
|
||||
plotly = "^6.0.0"
|
||||
mistral-common = {version = "^1.5.2", optional = true}
|
||||
gdown = "^5.2.0"
|
||||
|
||||
|
||||
[tool.poetry.extras]
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue