Merge branch 'dev' into add_docstrings

commit e7f24548dd

12 changed files with 131 additions and 48 deletions
README.md (38 changed lines)
@@ -85,7 +85,7 @@ import os
 os.environ["LLM_API_KEY"] = "YOUR OPENAI_API_KEY"
 ```
-or
+or
 ```
 import cognee
 cognee.config.set_llm_api_key("YOUR_OPENAI_API_KEY")
 ```
@@ -115,7 +115,7 @@ DB_PORT=5432
 DB_NAME=cognee_db
 DB_USERNAME=cognee
 DB_PASSWORD=cognee
-```
+```

 ### Simple example
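The three `DB_*` values above complete the Postgres settings in the README's `.env` example (the hunk context shows `DB_PORT=5432`). For orientation only, not part of the diff: a minimal sketch of reading these settings in Python, assuming the python-dotenv package; the `DB_HOST` default is a guess, the others mirror the values shown:

```python
import os

from dotenv import load_dotenv  # assumes python-dotenv is installed

load_dotenv()  # pull DB_* variables from a local .env file into the environment

db_settings = {
    "host": os.getenv("DB_HOST", "localhost"),  # DB_HOST is not shown in the diff
    "port": int(os.getenv("DB_PORT", "5432")),
    "name": os.getenv("DB_NAME", "cognee_db"),
    "user": os.getenv("DB_USERNAME", "cognee"),
    "password": os.getenv("DB_PASSWORD", "cognee"),
}
```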
@@ -140,14 +140,14 @@ async def main():
     Natural language processing (NLP) is an interdisciplinary
     subfield of computer science and information retrieval.
     """

     print("Adding text to cognee:")
-    print(text.strip())
+    print(text.strip())
     # Add the text, and make it available for cognify
     await cognee.add(text)
     print("Text added successfully.\n")


     print("Running cognify to create knowledge graph...\n")
     print("Cognify process steps:")
     print("1. Classifying the document: Determining the type and category of the input text.")
@@ -156,19 +156,19 @@ async def main():
     print("4. Adding data points: Storing the extracted chunks for processing.")
     print("5. Generating knowledge graph: Extracting entities and relationships to form a knowledge graph.")
     print("6. Summarizing text: Creating concise summaries of the content for quick insights.\n")

     # Use LLMs and cognee to create knowledge graph
     await cognee.cognify()
     print("Cognify process complete.\n")


     query_text = 'Tell me about NLP'
     print(f"Searching cognee for insights with query: '{query_text}'")
     # Query cognee for insights on the added text
     search_results = await cognee.search(
         SearchType.INSIGHTS, query_text=query_text
     )

     print("Search results:")
     # Display results
     for result_text in search_results:
@@ -212,7 +212,7 @@ Cognee supports a variety of tools and services for different operations:
 - **Language Models (LLMs)**: You can use either Anyscale or Ollama as your LLM provider.

 - **Graph Stores**: In addition to NetworkX, Neo4j is also supported for graph storage.

 - **User management**: Create individual user graphs and manage permissions

 ## Demo
@@ -258,13 +258,13 @@ pip install cognee

-| Name     | Type         | Current state | Known Issues |
-|----------|--------------|---------------|--------------|
-| Qdrant   | Vector       | Stable ✅     |              |
-| Weaviate | Vector       | Stable ✅     |              |
-| LanceDB  | Vector       | Stable ✅     |              |
-| Neo4j    | Graph        | Stable ✅     |              |
-| NetworkX | Graph        | Stable ✅     |              |
-| FalkorDB | Vector/Graph | Unstable ❌   |              |
-| PGVector | Vector       | Stable ✅     |              |
-| Milvus   | Vector       | Stable ✅     |              |
+| Name     | Type         | Current state (Mac/Linux) | Known Issues | Current state (Windows) | Known Issues |
+|----------|--------------|---------------------------|--------------|-------------------------|--------------|
+| Qdrant   | Vector       | Stable ✅                 |              | Unstable ❌             |              |
+| Weaviate | Vector       | Stable ✅                 |              | Unstable ❌             |              |
+| LanceDB  | Vector       | Stable ✅                 |              | Stable ✅               |              |
+| Neo4j    | Graph        | Stable ✅                 |              | Stable ✅               |              |
+| NetworkX | Graph        | Stable ✅                 |              | Stable ✅               |              |
+| FalkorDB | Vector/Graph | Stable ✅                 |              | Unstable ❌             |              |
+| PGVector | Vector       | Stable ✅                 |              | Unstable ❌             |              |
+| Milvus   | Vector       | Stable ✅                 |              | Unstable ❌             |              |
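The reworked matrix doubles as a selection guide: LanceDB and Neo4j/NetworkX are the only backends marked stable on every platform. A hedged sketch of pointing cognee at that pair; the environment variable names below are assumptions to verify against cognee's `.env.template`, not something this diff shows:

```python
import os

# Assumed variable names; confirm against .env.template before relying on them.
os.environ["VECTOR_DB_PROVIDER"] = "lancedb"        # stable on Mac/Linux and Windows per the table
os.environ["GRAPH_DATABASE_PROVIDER"] = "networkx"  # likewise stable everywhere

# Provider settings are typically read when configuration initializes,
# so set them before importing cognee.
import cognee  # noqa: E402
```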
@@ -152,7 +152,9 @@ class LanceDBAdapter(VectorDBInterface):
         connection = await self.get_connection()
         collection = await connection.open_table(collection_name)

-        results = await collection.vector_search(query_vector).to_pandas()
+        collection_size = await collection.count_rows()
+
+        results = await collection.vector_search(query_vector).limit(collection_size).to_pandas()

         result_values = list(results.to_dict("index").values())
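Why this change matters: when no explicit `.limit()` is given, LanceDB queries return only a small default number of rows (10), so the old call silently truncated search results; sizing the limit to `count_rows()` makes the search return every row in the table, ranked by distance. A standalone sketch of the same pattern, assuming lancedb's `connect_async` helper and a pre-existing table (the path and table name are illustrative):

```python
import asyncio

import lancedb  # assumes the lancedb package with its async API


async def search_all(query_vector: list[float]):
    # connect_async/open_table mirror the calls visible in the hunk above
    db = await lancedb.connect_async("./lancedb_data")  # illustrative path
    table = await db.open_table("items")                # illustrative table name

    # Without an explicit .limit(), only the default number of rows comes back.
    total = await table.count_rows()
    return await table.vector_search(query_vector).limit(total).to_pandas()


# asyncio.run(search_all([0.1, 0.2, 0.3]))
```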
@@ -250,9 +252,16 @@ class LanceDBAdapter(VectorDBInterface):
         )

     async def prune(self):
+        # Clean up the database if it was set up as temporary
         connection = await self.get_connection()
         collection_names = await connection.table_names()

+        for collection_name in collection_names:
+            collection = await connection.open_table(collection_name)
+            await collection.delete("id IS NOT NULL")
+            await connection.drop_table(collection_name)
+
         if self.url.startswith("/"):
-            LocalStorage.remove_all(self.url)  # Remove the temporary directory and files inside
+            LocalStorage.remove_all(self.url)

     def get_data_point_schema(self, model_type):
         return copy_model(
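The new `prune` now empties each table before dropping it, then clears the on-disk directory for path-style URLs. A sketch of the same cleanup idiom against lancedb directly, under the same API assumptions as the sketch above (the helper name is invented):

```python
import lancedb  # same async-API assumptions as before


async def drop_all_tables(uri: str) -> None:
    db = await lancedb.connect_async(uri)
    for name in await db.table_names():
        table = await db.open_table(name)
        await table.delete("id IS NOT NULL")  # empty the table first
        await db.drop_table(name)             # then remove the table itself
```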
@@ -30,4 +30,4 @@ if [ "$ENVIRONMENT" = "dev" ]; then
 else
     gunicorn -w 3 -k uvicorn.workers.UvicornWorker -t 30000 --bind=0.0.0.0:8000 --log-level error cognee.api.client:app
     # python ./cognee/api/client.py
-fi
+fi
@@ -9,7 +9,7 @@ from cognee.infrastructure.llm.get_llm_client import get_llm_client
 from cognee.infrastructure.llm.prompts import read_query_prompt, render_prompt
 from evals.qa_dataset_utils import load_qa_dataset
 from evals.qa_metrics_utils import get_metrics
-from evals.qa_context_provider_utils import qa_context_providers
+from evals.qa_context_provider_utils import qa_context_providers, create_cognee_context_getter

 logger = logging.getLogger(__name__)
@@ -94,14 +94,29 @@ async def eval_on_QA_dataset(
     return results


-if __name__ == "__main__":
+async def incremental_eval_on_QA_dataset(
+    dataset_name_or_filename: str, num_samples, metric_name_list
+):
+    pipeline_slice_names = ["base", "extract_chunks", "extract_graph", "summarize"]
+
+    incremental_results = {}
+    for pipeline_slice_name in pipeline_slice_names:
+        results = await eval_on_QA_dataset(
+            dataset_name_or_filename, pipeline_slice_name, num_samples, metric_name_list
+        )
+        incremental_results[pipeline_slice_name] = results
+
+    return incremental_results
+
+
+async def main():
     parser = argparse.ArgumentParser()

     parser.add_argument("--dataset", type=str, required=True, help="Which dataset to evaluate on")
     parser.add_argument(
         "--rag_option",
         type=str,
-        choices=qa_context_providers.keys(),
+        choices=list(qa_context_providers.keys()) + ["cognee_incremental"],
         required=True,
         help="RAG option to use for providing context",
     )
@@ -110,7 +125,18 @@ if __name__ == "__main__":

     args = parser.parse_args()

-    avg_scores = asyncio.run(
-        eval_on_QA_dataset(args.dataset, args.rag_option, args.num_samples, args.metrics)
-    )
+    if args.rag_option == "cognee_incremental":
+        avg_scores = await incremental_eval_on_QA_dataset(
+            args.dataset, args.num_samples, args.metrics
+        )
+    else:
+        avg_scores = await eval_on_QA_dataset(
+            args.dataset, args.rag_option, args.num_samples, args.metrics
+        )
+
     logger.info(f"{avg_scores}")
+
+
+if __name__ == "__main__":
+    asyncio.run(main())
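Together with the previous hunk, this means `--rag_option cognee_incremental` sweeps all four pipeline slices in one run. A hedged sketch of the same call made programmatically; `incremental_eval_on_QA_dataset` comes from this diff, but the dataset name, sample count, and metric list below are made-up placeholders:

```python
import asyncio

# Hypothetical values; substitute a real dataset identifier and metric names.
scores = asyncio.run(
    incremental_eval_on_QA_dataset(
        "hotpotqa", num_samples=10, metric_name_list=["correctness"]
    )
)
# Expected shape: one result set per pipeline slice, e.g.
# {"base": ..., "extract_chunks": ..., "extract_graph": ..., "summarize": ...}
```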
@@ -4,11 +4,8 @@ These are the official evaluation metrics for HotpotQA taken from https://hotpot

 import re
 import string
-import sys
 from collections import Counter

-import ujson as json
-

 def normalize_answer(s):
     def remove_articles(text):
@@ -3,35 +3,49 @@ from cognee.api.v1.search import SearchType
 from cognee.infrastructure.databases.vector import get_vector_engine
 from cognee.modules.retrieval.brute_force_triplet_search import brute_force_triplet_search
 from cognee.tasks.completion.graph_query_completion import retrieved_edges_to_string
+from functools import partial
+from cognee.api.v1.cognify.cognify_v2 import get_default_tasks


 async def get_raw_context(instance: dict) -> str:
     return instance["context"]


-async def cognify_instance(instance: dict):
+async def cognify_instance(instance: dict, task_indices: list[int] = None):
     await cognee.prune.prune_data()
     await cognee.prune.prune_system(metadata=True)

     for title, sentences in instance["context"]:
         await cognee.add("\n".join(sentences), dataset_name="QA")
-    await cognee.cognify("QA")
+
+    all_cognify_tasks = await get_default_tasks()
+    if task_indices:
+        selected_tasks = [all_cognify_tasks[ind] for ind in task_indices]
+    else:
+        selected_tasks = all_cognify_tasks
+    await cognee.cognify("QA", tasks=selected_tasks)


-async def get_context_with_cognee(instance: dict) -> str:
-    await cognify_instance(instance)
+async def get_context_with_cognee(
+    instance: dict,
+    task_indices: list[int] = None,
+    search_types: list[SearchType] = [SearchType.SUMMARIES, SearchType.CHUNKS],
+) -> str:
+    await cognify_instance(instance, task_indices)

-    # TODO: Fix insights
-    # insights = await cognee.search(SearchType.INSIGHTS, query_text=instance["question"])
-    summaries = await cognee.search(SearchType.SUMMARIES, query_text=instance["question"])
-    # search_results = insights + summaries
-    search_results = summaries
+    search_results = []
+    for search_type in search_types:
+        search_results += await cognee.search(search_type, query_text=instance["question"])

     search_results_str = "\n".join([context_item["text"] for context_item in search_results])

     return search_results_str


+def create_cognee_context_getter(
+    task_indices=None, search_types=[SearchType.SUMMARIES, SearchType.CHUNKS]
+):
+    return partial(get_context_with_cognee, task_indices=task_indices, search_types=search_types)
+
+
 async def get_context_with_simple_rag(instance: dict) -> str:
     await cognee.prune.prune_data()
     await cognee.prune.prune_system(metadata=True)
@@ -57,9 +71,19 @@ async def get_context_with_brute_force_triplet_search(instance: dict) -> str:
     return search_results_str


+valid_pipeline_slices = {
+    "base": [0, 1, 5],
+    "extract_chunks": [0, 1, 2, 5],
+    "extract_graph": [0, 1, 2, 3, 5],
+    "summarize": [0, 1, 2, 3, 4, 5],
+}
+
 qa_context_providers = {
     "no_rag": get_raw_context,
     "cognee": get_context_with_cognee,
     "simple_rag": get_context_with_simple_rag,
     "brute_force": get_context_with_brute_force_triplet_search,
+} | {
+    name: create_cognee_context_getter(task_indices=slice)
+    for name, slice in valid_pipeline_slices.items()
 }
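Each slice lists indices into the task list returned by `get_default_tasks()`, so `base` runs tasks 0, 1, and 5 while `summarize` runs all six. The provider table is then built with two small Python idioms: `functools.partial` freezes keyword arguments onto `get_context_with_cognee`, and the `|` operator (dict union, Python 3.9+) merges the hand-written dict with the generated one. A self-contained sketch of the same pattern with toy names:

```python
from functools import partial


def fetch(query: str, mode: str = "full") -> str:
    return f"{mode}:{query}"


base_providers = {"full": fetch}
sliced_providers = {name: partial(fetch, mode=name) for name in ["base", "summarize"]}

providers = base_providers | sliced_providers  # right-hand side wins on key clashes
print(providers["summarize"]("NLP"))           # -> "summarize:NLP"
```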
@@ -3,8 +3,9 @@
         "hotpotqa"
     ],
     "rag_option": [
-        "no_rag",
+        "cognee_incremental",
         "cognee",
+        "no_rag",
         "simple_rag",
         "brute_force"
     ],
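This parameter file feeds the evaluation matrix, and adding `cognee_incremental` here exercises the new CLI branch. A hypothetical loader sketch; the file name and the `dataset` key are guesses, since only `rag_option` and its values appear in this hunk:

```python
import json
from itertools import product

# "eval_params.json" is a placeholder name for the parameter file above.
with open("eval_params.json") as f:
    params = json.load(f)

# Iterate the full dataset x rag_option grid (the "dataset" key is assumed).
for dataset, rag_option in product(params["dataset"], params["rag_option"]):
    print(dataset, rag_option)
```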
@@ -44,7 +44,7 @@ def save_results_as_image(results, out_path):
     for num_samples, table_data in num_samples_data.items():
         df = pd.DataFrame.from_dict(table_data, orient="index")
         df.index.name = f"Dataset: {dataset}, Num Samples: {num_samples}"
-        image_path = Path(out_path) / Path(f"table_{dataset}_{num_samples}.png")
+        image_path = out_path / Path(f"table_{dataset}_{num_samples}.png")
         save_table_as_image(df, image_path)
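The one-line change drops a redundant `Path(out_path)` wrapper: the function now relies on `out_path` already being a `pathlib.Path`, whose `/` operator accepts plain strings directly. A tiny illustration with made-up values:

```python
from pathlib import Path

out_path = Path("evals_output")                  # caller passes a Path, not a str
image_path = out_path / "table_hotpotqa_10.png"  # Path / str joins directly
print(image_path)                                # evals_output/table_hotpotqa_10.png
```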
@@ -204,4 +204,9 @@ if __name__ == "__main__":
         "retriever": retrieve,
     }

-    asyncio.run(main(steps_to_enable))
+    loop = asyncio.new_event_loop()
+    asyncio.set_event_loop(loop)
+    try:
+        loop.run_until_complete(main(steps_to_enable))
+    finally:
+        loop.run_until_complete(loop.shutdown_asyncgens())
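The same replacement of `asyncio.run()` with an explicit event loop recurs in the example entry points below. For what it's worth, `asyncio.run()` also shuts down async generators before closing the loop; the hand-rolled form mainly makes that cleanup step visible and guarantees it runs on the same loop even if `main()` raises. A sketch of the pattern as a reusable helper (the helper name is invented, not from the diff):

```python
import asyncio


def run_with_asyncgen_cleanup(coro):
    """Run a coroutine, then finalize async generators before the loop dies."""
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    try:
        return loop.run_until_complete(coro)
    finally:
        # Give async generators a live loop for their finally-blocks.
        loop.run_until_complete(loop.shutdown_asyncgens())
```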
@@ -69,4 +69,9 @@ async def main():

 if __name__ == "__main__":
     setup_logging(logging.ERROR)
-    asyncio.run(main())
+    loop = asyncio.new_event_loop()
+    asyncio.set_event_loop(loop)
+    try:
+        loop.run_until_complete(main())
+    finally:
+        loop.run_until_complete(loop.shutdown_asyncgens())
@@ -1,9 +1,11 @@
 import os
 import asyncio
 import pathlib
+import logging

 import cognee
 from cognee.api.v1.search import SearchType
+from cognee.shared.utils import setup_logging

 # Prerequisites:
 # 1. Copy `.env.template` and rename it to `.env`.
@@ -45,4 +47,10 @@ async def main():


 if __name__ == "__main__":
-    asyncio.run(main())
+    setup_logging(logging.ERROR)
+    loop = asyncio.new_event_loop()
+    asyncio.set_event_loop(loop)
+    try:
+        loop.run_until_complete(main())
+    finally:
+        loop.run_until_complete(loop.shutdown_asyncgens())
@@ -1,6 +1,8 @@
 import asyncio
 import cognee
+import logging
 from cognee.api.v1.search import SearchType
+from cognee.shared.utils import setup_logging

 # Prerequisites:
 # 1. Copy `.env.template` and rename it to `.env`.
@@ -66,4 +68,10 @@ async def main():


 if __name__ == "__main__":
-    asyncio.run(main())
+    setup_logging(logging.ERROR)
+    loop = asyncio.new_event_loop()
+    asyncio.set_event_loop(loop)
+    try:
+        loop.run_until_complete(main())
+    finally:
+        loop.run_until_complete(loop.shutdown_asyncgens())