Merge branch 'dev' into add_docstrings

Igor Ilic, 2025-01-17 17:00:23 +01:00 (committed by GitHub)
commit e7f24548dd
12 changed files with 131 additions and 48 deletions

@@ -258,13 +258,13 @@ pip install cognee
-| Name     | Type         | Current state | Known Issues |
-|----------|--------------|---------------|--------------|
-| Qdrant   | Vector       | Stable ✅      |              |
-| Weaviate | Vector       | Stable ✅      |              |
-| LanceDB  | Vector       | Stable ✅      |              |
-| Neo4j    | Graph        | Stable ✅      |              |
-| NetworkX | Graph        | Stable ✅      |              |
-| FalkorDB | Vector/Graph | Unstable ❌    |              |
-| PGVector | Vector       | Stable ✅      |              |
-| Milvus   | Vector       | Stable ✅      |              |
+| Name     | Type         | Current state (Mac/Linux) | Known Issues | Current state (Windows) | Known Issues |
+|----------|--------------|---------------------------|--------------|-------------------------|--------------|
+| Qdrant   | Vector       | Stable ✅                  |              | Unstable ❌              |              |
+| Weaviate | Vector       | Stable ✅                  |              | Unstable ❌              |              |
+| LanceDB  | Vector       | Stable ✅                  |              | Stable ✅                |              |
+| Neo4j    | Graph        | Stable ✅                  |              | Stable ✅                |              |
+| NetworkX | Graph        | Stable ✅                  |              | Stable ✅                |              |
+| FalkorDB | Vector/Graph | Stable ✅                  |              | Unstable ❌              |              |
+| PGVector | Vector       | Stable ✅                  |              | Unstable ❌              |              |
+| Milvus   | Vector       | Stable ✅                  |              | Unstable ❌              |              |

@@ -152,7 +152,9 @@ class LanceDBAdapter(VectorDBInterface):
         connection = await self.get_connection()
         collection = await connection.open_table(collection_name)

-        results = await collection.vector_search(query_vector).to_pandas()
+        collection_size = await collection.count_rows()
+        results = await collection.vector_search(query_vector).limit(collection_size).to_pandas()

         result_values = list(results.to_dict("index").values())
@@ -250,9 +252,16 @@ class LanceDBAdapter(VectorDBInterface):
         )

     async def prune(self):
-        # Clean up the database if it was set up as temporary
+        connection = await self.get_connection()
+        collection_names = await connection.table_names()
+
+        for collection_name in collection_names:
+            collection = await connection.open_table(collection_name)
+            await collection.delete("id IS NOT NULL")
+            await connection.drop_table(collection_name)
+
         if self.url.startswith("/"):
+            # Remove the temporary directory and files inside
             LocalStorage.remove_all(self.url)

     def get_data_point_schema(self, model_type):
         return copy_model(
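The first hunk matters because LanceDB's query builder applies a small default row limit when none is given, so the unbounded `vector_search(...)` call silently truncated results; the reworked `prune()` additionally drops every table before clearing the on-disk directory. A minimal sketch of the fixed search path, assuming the async LanceDB client (`connect_async`) that the adapter's `get_connection()` wraps:

```python
import lancedb

async def search_whole_collection(db_path: str, collection_name: str, query_vector: list[float]):
    # Stand-in for the adapter's get_connection(): open an async
    # LanceDB connection at the given path.
    connection = await lancedb.connect_async(db_path)
    collection = await connection.open_table(collection_name)

    # Size the limit to the row count so no match is silently dropped.
    collection_size = await collection.count_rows()
    results = await collection.vector_search(query_vector).limit(collection_size).to_pandas()
    return list(results.to_dict("index").values())
```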

@@ -9,7 +9,7 @@ from cognee.infrastructure.llm.get_llm_client import get_llm_client
 from cognee.infrastructure.llm.prompts import read_query_prompt, render_prompt
 from evals.qa_dataset_utils import load_qa_dataset
 from evals.qa_metrics_utils import get_metrics
-from evals.qa_context_provider_utils import qa_context_providers
+from evals.qa_context_provider_utils import qa_context_providers, create_cognee_context_getter

 logger = logging.getLogger(__name__)
@@ -94,14 +94,29 @@ async def eval_on_QA_dataset(
     return results


-if __name__ == "__main__":
+async def incremental_eval_on_QA_dataset(
+    dataset_name_or_filename: str, num_samples, metric_name_list
+):
+    pipeline_slice_names = ["base", "extract_chunks", "extract_graph", "summarize"]
+
+    incremental_results = {}
+    for pipeline_slice_name in pipeline_slice_names:
+        results = await eval_on_QA_dataset(
+            dataset_name_or_filename, pipeline_slice_name, num_samples, metric_name_list
+        )
+        incremental_results[pipeline_slice_name] = results
+
+    return incremental_results
+
+
+async def main():
     parser = argparse.ArgumentParser()
     parser.add_argument("--dataset", type=str, required=True, help="Which dataset to evaluate on")
     parser.add_argument(
         "--rag_option",
         type=str,
-        choices=qa_context_providers.keys(),
+        choices=list(qa_context_providers.keys()) + ["cognee_incremental"],
         required=True,
         help="RAG option to use for providing context",
     )
@@ -110,7 +125,18 @@ if __name__ == "__main__":
     args = parser.parse_args()

-    avg_scores = asyncio.run(
-        eval_on_QA_dataset(args.dataset, args.rag_option, args.num_samples, args.metrics)
-    )
+    if args.rag_option == "cognee_incremental":
+        avg_scores = await incremental_eval_on_QA_dataset(
+            args.dataset, args.num_samples, args.metrics
+        )
+    else:
+        avg_scores = await eval_on_QA_dataset(
+            args.dataset, args.rag_option, args.num_samples, args.metrics
+        )
+
     logger.info(f"{avg_scores}")
+
+
+if __name__ == "__main__":
+    asyncio.run(main())
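With the new `cognee_incremental` option, the same QA evaluation runs once per pipeline slice, so score deltas between consecutive slices indicate what each cognify stage contributes. A hedged sketch of how the returned per-slice dict could be summarized; the metric key `"f1"` and the helper name are assumptions, not part of the diff:

```python
def report_slice_gains(incremental_results: dict[str, dict[str, float]], metric: str = "f1") -> None:
    """Print each pipeline slice's score and its gain over the previous slice."""
    previous = None
    # dicts preserve insertion order, so slices print in pipeline order
    for slice_name, scores in incremental_results.items():
        value = scores[metric]  # "f1" is a hypothetical metric key
        delta = 0.0 if previous is None else value - previous
        print(f"{slice_name:>15}: {metric}={value:.3f} (gain {delta:+.3f})")
        previous = value
```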

@@ -4,11 +4,8 @@ These are the official evaluation metrics for HotpotQA taken from https://hotpot
 import re
 import string
-import sys
 from collections import Counter

-import ujson as json
-

 def normalize_answer(s):
     def remove_articles(text):
@@ -3,35 +3,49 @@ from cognee.api.v1.search import SearchType
 from cognee.infrastructure.databases.vector import get_vector_engine
 from cognee.modules.retrieval.brute_force_triplet_search import brute_force_triplet_search
 from cognee.tasks.completion.graph_query_completion import retrieved_edges_to_string
+from functools import partial
+from cognee.api.v1.cognify.cognify_v2 import get_default_tasks


 async def get_raw_context(instance: dict) -> str:
     return instance["context"]


-async def cognify_instance(instance: dict):
+async def cognify_instance(instance: dict, task_indices: list[int] = None):
     await cognee.prune.prune_data()
     await cognee.prune.prune_system(metadata=True)
     for title, sentences in instance["context"]:
         await cognee.add("\n".join(sentences), dataset_name="QA")

-    await cognee.cognify("QA")
+    all_cognify_tasks = await get_default_tasks()
+    if task_indices:
+        selected_tasks = [all_cognify_tasks[ind] for ind in task_indices]
+    else:
+        selected_tasks = all_cognify_tasks
+
+    await cognee.cognify("QA", tasks=selected_tasks)


-async def get_context_with_cognee(instance: dict) -> str:
-    await cognify_instance(instance)
+async def get_context_with_cognee(
+    instance: dict,
+    task_indices: list[int] = None,
+    search_types: list[SearchType] = [SearchType.SUMMARIES, SearchType.CHUNKS],
+) -> str:
+    await cognify_instance(instance, task_indices)

-    # TODO: Fix insights
-    # insights = await cognee.search(SearchType.INSIGHTS, query_text=instance["question"])
-    summaries = await cognee.search(SearchType.SUMMARIES, query_text=instance["question"])
-    # search_results = insights + summaries
-    search_results = summaries
+    search_results = []
+    for search_type in search_types:
+        search_results += await cognee.search(search_type, query_text=instance["question"])

     search_results_str = "\n".join([context_item["text"] for context_item in search_results])
     return search_results_str


+def create_cognee_context_getter(
+    task_indices=None, search_types=[SearchType.SUMMARIES, SearchType.CHUNKS]
+):
+    return partial(get_context_with_cognee, task_indices=task_indices, search_types=search_types)
+
+
 async def get_context_with_simple_rag(instance: dict) -> str:
     await cognee.prune.prune_data()
     await cognee.prune.prune_system(metadata=True)
@@ -57,9 +71,19 @@ async def get_context_with_brute_force_triplet_search(instance: dict) -> str:
     return search_results_str


+valid_pipeline_slices = {
+    "base": [0, 1, 5],
+    "extract_chunks": [0, 1, 2, 5],
+    "extract_graph": [0, 1, 2, 3, 5],
+    "summarize": [0, 1, 2, 3, 4, 5],
+}
+
 qa_context_providers = {
     "no_rag": get_raw_context,
     "cognee": get_context_with_cognee,
     "simple_rag": get_context_with_simple_rag,
     "brute_force": get_context_with_brute_force_triplet_search,
+} | {
+    name: create_cognee_context_getter(task_indices=slice)
+    for name, slice in valid_pipeline_slices.items()
 }
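The provider table is built by merging the fixed providers with slice-specific ones via the dict-union operator (`|`, Python 3.9+); `create_cognee_context_getter` uses `functools.partial` so every entry keeps the single-instance call signature the eval loop expects. A minimal standalone sketch of that pattern, with a stub in place of `get_context_with_cognee`:

```python
from functools import partial

async def get_context(instance: dict, task_indices: list[int] | None = None) -> str:
    # Stub standing in for get_context_with_cognee.
    return f"context built from tasks {task_indices}"

valid_pipeline_slices = {
    "base": [0, 1, 5],
    "extract_chunks": [0, 1, 2, 5],
}

# partial bakes the slice-specific task_indices into each callable, so
# callers can treat every provider as an awaitable of one argument.
providers = {"cognee": get_context} | {
    name: partial(get_context, task_indices=indices)
    for name, indices in valid_pipeline_slices.items()
}
# await providers["base"]({"context": []}) -> get_context(..., task_indices=[0, 1, 5])
```

One caveat worth flagging in review: the mutable list defaults (`search_types=[SearchType.SUMMARIES, SearchType.CHUNKS]`) are shared across calls, and `slice` in the comprehension shadows the builtin; neither breaks this code, but `None` defaults and a different loop variable are the safer idiom.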

@@ -3,8 +3,9 @@
         "hotpotqa"
     ],
     "rag_option": [
-        "no_rag",
+        "cognee_incremental",
         "cognee",
+        "no_rag",
         "simple_rag",
         "brute_force"
     ],

@@ -44,7 +44,7 @@ def save_results_as_image(results, out_path):
     for num_samples, table_data in num_samples_data.items():
         df = pd.DataFrame.from_dict(table_data, orient="index")
         df.index.name = f"Dataset: {dataset}, Num Samples: {num_samples}"
-        image_path = Path(out_path) / Path(f"table_{dataset}_{num_samples}.png")
+        image_path = out_path / Path(f"table_{dataset}_{num_samples}.png")
         save_table_as_image(df, image_path)
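Dropping the extra `Path()` wrapper is safe because `pathlib.Path` implements both `__truediv__` and `__rtruediv__`: the right-hand `Path(...)` handles even a plain-`str` `out_path`. A quick illustration with hypothetical values:

```python
from pathlib import Path

out_path = Path("eval_results")   # hypothetical output directory
name = "table_hotpotqa_50.png"    # hypothetical file name

# All three spellings yield the same path; Path.__rtruediv__ even covers
# the case where out_path is still a plain str.
assert out_path / Path(name) == Path(out_path) / Path(name) == out_path / name
print(out_path / name)  # eval_results/table_hotpotqa_50.png
```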

@ -204,4 +204,9 @@ if __name__ == "__main__":
"retriever": retrieve, "retriever": retrieve,
} }
asyncio.run(main(steps_to_enable)) loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
loop.run_until_complete(main(steps_to_enable))
finally:
loop.run_until_complete(loop.shutdown_asyncgens())
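This same shutdown pattern is applied to every script entry point below: an explicit event loop replaces `asyncio.run()` so pending async generators are drained before the process exits. Note that `asyncio.run()` already calls `loop.shutdown_asyncgens()` internally, so the explicit form mainly buys room for additional teardown steps. A minimal template:

```python
import asyncio

async def main():
    ...  # application entry point

if __name__ == "__main__":
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    try:
        loop.run_until_complete(main())
    finally:
        # Finalize async generators that are still alive; skipping this
        # can emit "was never properly finalized" warnings at exit.
        loop.run_until_complete(loop.shutdown_asyncgens())
        loop.close()  # not in the diff above, but the usual last step
```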

@@ -69,4 +69,9 @@ async def main():

 if __name__ == "__main__":
     setup_logging(logging.ERROR)
-    asyncio.run(main())
+    loop = asyncio.new_event_loop()
+    asyncio.set_event_loop(loop)
+    try:
+        loop.run_until_complete(main())
+    finally:
+        loop.run_until_complete(loop.shutdown_asyncgens())

@@ -1,9 +1,11 @@
 import os
 import asyncio
 import pathlib
+import logging

 import cognee
 from cognee.api.v1.search import SearchType
+from cognee.shared.utils import setup_logging

 # Prerequisites:
 # 1. Copy `.env.template` and rename it to `.env`.
@@ -45,4 +47,10 @@ async def main():

 if __name__ == "__main__":
-    asyncio.run(main())
+    setup_logging(logging.ERROR)
+
+    loop = asyncio.new_event_loop()
+    asyncio.set_event_loop(loop)
+    try:
+        loop.run_until_complete(main())
+    finally:
+        loop.run_until_complete(loop.shutdown_asyncgens())

@@ -1,6 +1,8 @@
 import asyncio
 import cognee
+import logging

 from cognee.api.v1.search import SearchType
+from cognee.shared.utils import setup_logging

 # Prerequisites:
 # 1. Copy `.env.template` and rename it to `.env`.
@@ -66,4 +68,10 @@ async def main():

 if __name__ == "__main__":
-    asyncio.run(main())
+    setup_logging(logging.ERROR)
+
+    loop = asyncio.new_event_loop()
+    asyncio.set_event_loop(loop)
+    try:
+        loop.run_until_complete(main())
+    finally:
+        loop.run_until_complete(loop.shutdown_asyncgens())