chore: Fix and update visualization (#518)

<!-- .github/pull_request_template.md -->

## Description
<!-- Provide a clear description of the changes in this PR -->

## DCO Affirmation
I affirm that all code in every commit of this pull request conforms to
the terms of the Topoteretes Developer Certificate of Origin.


<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

- **New Features**
  - Introduced enhanced visualization capabilities that let users launch a dedicated local HTTP server to display graph visualizations.

- **Documentation**
  - Updated several interactive notebooks to include execution outputs and expanded explanatory content for better user guidance.

- **Style**
  - Refined formatting and layout across notebooks for consistent presentation and improved readability.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Co-authored-by: Igor Ilic <30923996+dexters1@users.noreply.github.com>
Vasilije 2025-02-11 13:25:01 -05:00 committed by GitHub
parent 1b630366c9
commit 9ba2e0d6c1
10 changed files with 630 additions and 302 deletions

View file

@ -4,7 +4,7 @@ from .api.v1.config.config import config
from .api.v1.datasets.datasets import datasets
from .api.v1.prune import prune
from .api.v1.search import SearchType, get_search_history, search
from .api.v1.visualize import visualize_graph
from .api.v1.visualize import visualize_graph, start_visualization_server
from cognee.modules.visualization.cognee_network_visualization import (
cognee_network_visualization,
)

View file

@ -1 +1,2 @@
from .visualize import visualize_graph
from .start_visualization_server import visualization_server

View file

@ -0,0 +1,17 @@
from cognee.shared.utils import start_visualization_server


def visualization_server(port):
    """
    Start a visualization server on the specified port.

    Args:
        port (int): The port number to run the server on

    Returns:
        callable: A shutdown function that can be called to stop the server

    Raises:
        ValueError: If port is not a valid port number
    """
    return start_visualization_server(port=port)
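For quick reference, a minimal usage sketch of this wrapper (not part of the diff); the import path follows the `__init__.py` change above, and the port and URL are illustrative:

```python
# Hedged usage sketch of the new visualization_server wrapper.
from cognee.api.v1.visualize import visualization_server

# Start a background HTTP server on port 8001 that serves files from the
# current working directory and returns a shutdown callable.
shutdown = visualization_server(8001)

# ... open http://localhost:8001/ in a browser to view the rendered graph ...

# Stop the server and free the socket once done.
shutdown()
```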

View file

@ -12,7 +12,6 @@ from cognee.shared.utils import setup_logging
async def visualize_graph(destination_file_path: str = None):
    graph_engine = await get_graph_engine()
    graph_data = await graph_engine.get_graph_data()
    logging.info(graph_data)
    graph = await cognee_network_visualization(graph_data, destination_file_path)
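For context, a minimal sketch of how this coroutine might be driven from a script; the destination path is an example, not taken from the diff, and it assumes cognee has already been configured and populated with data:

```python
# Illustrative driver for the visualize_graph coroutine shown above.
import asyncio

from cognee.api.v1.visualize import visualize_graph


async def main():
    # Example output path; leaving destination_file_path as None lets the
    # library choose its default location.
    await visualize_graph(destination_file_path="./graph_visualization.html")


asyncio.run(main())
```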

View file

@ -10,7 +10,9 @@ import graphistry
import networkx as nx
import pandas as pd
import matplotlib.pyplot as plt
import http.server
import socketserver
from threading import Thread
import logging
import sys
@ -364,3 +366,41 @@ def setup_logging(log_level=logging.INFO):
    root_logger.addHandler(stream_handler)
    root_logger.setLevel(log_level)


def start_visualization_server(
    host="0.0.0.0", port=8001, handler_class=http.server.SimpleHTTPRequestHandler
):
    """
    Spin up a simple HTTP server in a background thread to serve files.
    This is especially handy for quick demos or visualization purposes.
    Returns a shutdown() function that can be called to stop the server.

    :param host: Host/IP to bind to. Defaults to '0.0.0.0'.
    :param port: Port to listen on. Defaults to 8001.
    :param handler_class: A handler class, defaults to SimpleHTTPRequestHandler.
    :return: A no-argument function `shutdown` which, when called, stops the server.
    """
    # Create the server
    server = socketserver.TCPServer((host, port), handler_class)

    def _serve_forever():
        print(f"Visualization server running at: http://{host}:{port}")
        server.serve_forever()

    # Start the server in a background thread
    thread = Thread(target=_serve_forever, daemon=True)
    thread.start()

    def shutdown():
        """
        Shuts down the server and blocks until the thread is joined.
        """
        server.shutdown()  # Signals the serve_forever() loop to stop
        server.server_close()  # Frees up the socket
        thread.join()
        print(f"Visualization server on port {port} has been shut down.")

    # Return only the shutdown function (the server runs in the background)
    return shutdown
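A sketch of driving the underlying utility directly, for instance to serve a specific output directory; `start_visualization_server` comes from the diff above, while the directory-bound handler and path are assumptions for this example:

```python
# Illustrative: point the background server at a chosen directory by
# pre-binding SimpleHTTPRequestHandler's `directory` argument (Python 3.7+).
import functools
import http.server

from cognee.shared.utils import start_visualization_server

handler = functools.partial(
    http.server.SimpleHTTPRequestHandler,
    directory="/tmp/cognee_visualizations",  # example path, not from the PR
)

shutdown = start_visualization_server(host="127.0.0.1", port=8001, handler_class=handler)

# ... fetch http://127.0.0.1:8001/ to browse the served files ...

shutdown()  # stops serve_forever(), closes the socket, joins the thread
```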

File diff suppressed because one or more lines are too long

View file

@ -16,7 +16,9 @@
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import cognee\n",
"import logging\n",
@ -34,9 +36,7 @@
"from cognee.tasks.completion.graph_query_completion import retrieved_edges_to_string\n",
"from cognee.infrastructure.llm.prompts import read_query_prompt, render_prompt\n",
"from cognee.infrastructure.llm.get_llm_client import get_llm_client"
],
"outputs": [],
"execution_count": null
]
},
{
"cell_type": "markdown",
@ -47,17 +47,19 @@
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"ExecuteTime": {
"end_time": "2025-01-15T10:43:57.893763Z",
"start_time": "2025-01-15T10:43:57.891332Z"
}
},
"outputs": [],
"source": [
"import os\n",
"\n",
"# We ignore warnigns for now\n",
"warnings.filterwarnings('ignore')\n",
"warnings.filterwarnings(\"ignore\")\n",
"\n",
"# API key for cognee\n",
"if \"LLM_API_KEY\" not in os.environ:\n",
@ -68,10 +70,10 @@
" os.environ[\"OPENAI_API_KEY\"] = \"\"\n",
"\n",
"# Graphiti integration is only tested with neo4j + pgvector + postgres for now\n",
"GRAPH_DATABASE_PROVIDER=\"neo4j\"\n",
"GRAPH_DATABASE_URL=\"bolt://localhost:7687\"\n",
"GRAPH_DATABASE_USERNAME=\"neo4j\"\n",
"GRAPH_DATABASE_PASSWORD=\"pleaseletmein\"\n",
"GRAPH_DATABASE_PROVIDER = \"neo4j\"\n",
"GRAPH_DATABASE_URL = \"bolt://localhost:7687\"\n",
"GRAPH_DATABASE_USERNAME = \"neo4j\"\n",
"GRAPH_DATABASE_PASSWORD = \"pleaseletmein\"\n",
"\n",
"os.environ[\"VECTOR_DB_PROVIDER\"] = \"pgvector\"\n",
"\n",
@ -79,13 +81,11 @@
"\n",
"os.environ[\"DB_NAME\"] = \"cognee_db\"\n",
"\n",
"os.environ[\"DB_HOST\"]=\"127.0.0.1\"\n",
"os.environ[\"DB_PORT\"]=\"5432\"\n",
"os.environ[\"DB_USERNAME\"]=\"cognee\"\n",
"os.environ[\"DB_PASSWORD\"]=\"cognee\""
],
"outputs": [],
"execution_count": 2
"os.environ[\"DB_HOST\"] = \"127.0.0.1\"\n",
"os.environ[\"DB_PORT\"] = \"5432\"\n",
"os.environ[\"DB_USERNAME\"] = \"cognee\"\n",
"os.environ[\"DB_PASSWORD\"] = \"cognee\""
]
},
{
"cell_type": "markdown",
@ -94,21 +94,21 @@
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"ExecuteTime": {
"end_time": "2025-01-15T10:43:57.928664Z",
"start_time": "2025-01-15T10:43:57.927105Z"
}
},
"outputs": [],
"source": [
"text_list = [\n",
" \"Kamala Harris is the Attorney General of California. She was previously \"\n",
" \"the district attorney for San Francisco.\",\n",
" \"As AG, Harris was in office from January 3, 2011 January 3, 2017\",\n",
"]"
],
"outputs": [],
"execution_count": 3
]
},
{
"cell_type": "markdown",
@ -117,12 +117,36 @@
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"ExecuteTime": {
"end_time": "2025-01-15T10:44:25.008501Z",
"start_time": "2025-01-15T10:43:57.932240Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Database deleted successfully.\n",
"Database deleted successfully.\n",
"User d3b51a32-38e1-4fe5-8270-6dc1d6ebfdf0 has registered.\n",
"Pipeline file_load_from_filesystem load step completed in 0.10 seconds\n",
"1 load package(s) were loaded to destination sqlalchemy and into dataset public\n",
"The sqlalchemy destination used postgresql://cognee:***@127.0.0.1:5432/cognee_db location to store data\n",
"Load package 1736937839.7739599 is LOADED and contains no failed jobs\n",
"Pipeline file_load_from_filesystem load step completed in 0.06 seconds\n",
"1 load package(s) were loaded to destination sqlalchemy and into dataset public\n",
"The sqlalchemy destination used postgresql://cognee:***@127.0.0.1:5432/cognee_db location to store data\n",
"Load package 1736937841.8467042 is LOADED and contains no failed jobs\n",
"Graph database initialized.\n",
"Added text: Kamala Harris is the Attorney Gener...\n",
"Added text: As AG, Harris was in office from Ja...\n",
"✅ Result Processed: <graphiti_core.graphiti.Graphiti object at 0x326fe0ce0>\n"
]
}
],
"source": [
"# 🔧 Setting Up Logging to Suppress Errors\n",
"setup_logging(logging.ERROR) # Keeping logs clean and focused\n",
@ -152,45 +176,31 @@
"\n",
"# 🔄 Indexing and Transforming Graph Data\n",
"await index_and_transform_graphiti_nodes_and_edges()"
],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Database deleted successfully.\n",
"Database deleted successfully.\n",
"User d3b51a32-38e1-4fe5-8270-6dc1d6ebfdf0 has registered.\n",
"Pipeline file_load_from_filesystem load step completed in 0.10 seconds\n",
"1 load package(s) were loaded to destination sqlalchemy and into dataset public\n",
"The sqlalchemy destination used postgresql://cognee:***@127.0.0.1:5432/cognee_db location to store data\n",
"Load package 1736937839.7739599 is LOADED and contains no failed jobs\n",
"Pipeline file_load_from_filesystem load step completed in 0.06 seconds\n",
"1 load package(s) were loaded to destination sqlalchemy and into dataset public\n",
"The sqlalchemy destination used postgresql://cognee:***@127.0.0.1:5432/cognee_db location to store data\n",
"Load package 1736937841.8467042 is LOADED and contains no failed jobs\n",
"Graph database initialized.\n",
"Added text: Kamala Harris is the Attorney Gener...\n",
"Added text: As AG, Harris was in office from Ja...\n",
"✅ Result Processed: <graphiti_core.graphiti.Graphiti object at 0x326fe0ce0>\n"
]
}
],
"execution_count": 4
]
},
{
"metadata": {},
"cell_type": "markdown",
"metadata": {},
"source": "## Retrieving and generating answer from graphiti graph with cognee retriever"
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"ExecuteTime": {
"end_time": "2025-01-15T10:44:27.844438Z",
"start_time": "2025-01-15T10:44:25.013325Z"
}
},
"cell_type": "code",
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"💡 Answer: Kamala Harris was in office as Attorney General of California from January 3, 2011, to January 3, 2017.\n"
]
}
],
"source": [
"# Step 1: Formulating the Query 🔍\n",
"query = \"When was Kamala Harris in office?\"\n",
@ -205,10 +215,7 @@
"# Step 3: Preparing the Context for the LLM\n",
"context = retrieved_edges_to_string(triplets)\n",
"\n",
"args = {\n",
" \"question\": query,\n",
" \"context\": context\n",
"}\n",
"args = {\"question\": query, \"context\": context}\n",
"\n",
"# Step 4: Generating Prompts ✍️\n",
"user_prompt = render_prompt(\"graph_context_for_question.txt\", args)\n",
@ -217,24 +224,14 @@
"# Step 5: Interacting with the LLM 🤖\n",
"llm_client = get_llm_client()\n",
"computed_answer = await llm_client.acreate_structured_output(\n",
" text_input=user_prompt, # Input prompt for the user context\n",
" system_prompt=system_prompt, # System-level instructions for the model\n",
" text_input=user_prompt, # Input prompt for the user context\n",
" system_prompt=system_prompt, # System-level instructions for the model\n",
" response_model=str,\n",
")\n",
"\n",
"# Step 6: Displaying the Computed Answer ✨\n",
"print(f\"💡 Answer: {computed_answer}\")"
],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"💡 Answer: Kamala Harris was in office as Attorney General of California from January 3, 2011, to January 3, 2017.\n"
]
}
],
"execution_count": 5
]
}
],
"metadata": {

View file

@ -98,12 +98,12 @@
"random.seed(42)\n",
"instances = dataset if not num_samples else random.sample(dataset, num_samples)\n",
"\n",
"out_path = \"out\" \n",
"out_path = \"out\"\n",
"if not Path(out_path).exists():\n",
" Path(out_path).mkdir()\n",
"contexts_filename = out_path / Path(\n",
" f\"contexts_{dataset_name_or_filename.split('.')[0]}_{context_provider_name}.json\"\n",
" )\n",
" f\"contexts_{dataset_name_or_filename.split('.')[0]}_{context_provider_name}.json\"\n",
")\n",
"\n",
"answers = []\n",
"for instance in tqdm(instances, desc=\"Getting answers\"):\n",
@ -143,7 +143,7 @@
"source": [
"metric_name_list = [\"Correctness\"]\n",
"eval_metrics = get_metrics(metric_name_list)\n",
"eval_results = await deepeval_answers(instances, answers, eval_metrics[\"deepeval_metrics\"]) "
"eval_results = await deepeval_answers(instances, answers, eval_metrics[\"deepeval_metrics\"])"
]
},
{
@ -173,7 +173,7 @@
"source": [
"metric_name_list = [\"Comprehensiveness\"]\n",
"eval_metrics = get_metrics(metric_name_list)\n",
"eval_results = await deepeval_answers(instances, answers, eval_metrics[\"deepeval_metrics\"]) "
"eval_results = await deepeval_answers(instances, answers, eval_metrics[\"deepeval_metrics\"])"
]
},
{
@ -203,7 +203,7 @@
"source": [
"metric_name_list = [\"Diversity\"]\n",
"eval_metrics = get_metrics(metric_name_list)\n",
"eval_results = await deepeval_answers(instances, answers, eval_metrics[\"deepeval_metrics\"]) "
"eval_results = await deepeval_answers(instances, answers, eval_metrics[\"deepeval_metrics\"])"
]
},
{
@ -212,9 +212,7 @@
"metadata": {},
"outputs": [],
"source": [
"Diversity = statistics.mean(\n",
" [result.metrics_data[0].score for result in eval_results.test_results]\n",
")\n",
"Diversity = statistics.mean([result.metrics_data[0].score for result in eval_results.test_results])\n",
"print(Diversity)"
]
},
@ -233,7 +231,7 @@
"source": [
"metric_name_list = [\"Empowerment\"]\n",
"eval_metrics = get_metrics(metric_name_list)\n",
"eval_results = await deepeval_answers(instances, answers, eval_metrics[\"deepeval_metrics\"]) "
"eval_results = await deepeval_answers(instances, answers, eval_metrics[\"deepeval_metrics\"])"
]
},
{
@ -263,7 +261,7 @@
"source": [
"metric_name_list = [\"Directness\"]\n",
"eval_metrics = get_metrics(metric_name_list)\n",
"eval_results = await deepeval_answers(instances, answers, eval_metrics[\"deepeval_metrics\"]) "
"eval_results = await deepeval_answers(instances, answers, eval_metrics[\"deepeval_metrics\"])"
]
},
{
@ -272,9 +270,7 @@
"metadata": {},
"outputs": [],
"source": [
"Directness = statistics.mean(\n",
" [result.metrics_data[0].score for result in eval_results.test_results]\n",
")\n",
"Directness = statistics.mean([result.metrics_data[0].score for result in eval_results.test_results])\n",
"print(Directness)"
]
},
@ -301,7 +297,7 @@
"metadata": {},
"outputs": [],
"source": [
"eval_results = await deepeval_answers(instances, answers, eval_metrics[\"deepeval_metrics\"]) "
"eval_results = await deepeval_answers(instances, answers, eval_metrics[\"deepeval_metrics\"])"
]
},
{
@ -310,9 +306,7 @@
"metadata": {},
"outputs": [],
"source": [
"F1_score = statistics.mean(\n",
" [result.metrics_data[0].score for result in eval_results.test_results]\n",
")\n",
"F1_score = statistics.mean([result.metrics_data[0].score for result in eval_results.test_results])\n",
"print(F1_score)"
]
},
@ -331,7 +325,7 @@
"source": [
"metric_name_list = [\"EM\"]\n",
"eval_metrics = get_metrics(metric_name_list)\n",
"eval_results = await deepeval_answers(instances, answers, eval_metrics[\"deepeval_metrics\"]) "
"eval_results = await deepeval_answers(instances, answers, eval_metrics[\"deepeval_metrics\"])"
]
},
{
@ -340,9 +334,7 @@
"metadata": {},
"outputs": [],
"source": [
"EM = statistics.mean(\n",
" [result.metrics_data[0].score for result in eval_results.test_results]\n",
")\n",
"EM = statistics.mean([result.metrics_data[0].score for result in eval_results.test_results])\n",
"print(EM)"
]
},
@ -379,12 +371,12 @@
"random.seed(42)\n",
"instances = dataset if not num_samples else random.sample(dataset, num_samples)\n",
"\n",
"out_path = \"out\" \n",
"out_path = \"out\"\n",
"if not Path(out_path).exists():\n",
" Path(out_path).mkdir()\n",
"contexts_filename = out_path / Path(\n",
" f\"contexts_{dataset_name_or_filename.split('.')[0]}_{context_provider_name}.json\"\n",
" )\n",
" f\"contexts_{dataset_name_or_filename.split('.')[0]}_{context_provider_name}.json\"\n",
")\n",
"\n",
"answers = []\n",
"for instance in tqdm(instances, desc=\"Getting answers\"):\n",
@ -414,7 +406,7 @@
"source": [
"metric_name_list = [\"Correctness\"]\n",
"eval_metrics = get_metrics(metric_name_list)\n",
"eval_results = await deepeval_answers(instances, answers, eval_metrics[\"deepeval_metrics\"]) "
"eval_results = await deepeval_answers(instances, answers, eval_metrics[\"deepeval_metrics\"])"
]
},
{
@ -444,7 +436,7 @@
"source": [
"metric_name_list = [\"Comprehensiveness\"]\n",
"eval_metrics = get_metrics(metric_name_list)\n",
"eval_results = await deepeval_answers(instances, answers, eval_metrics[\"deepeval_metrics\"]) "
"eval_results = await deepeval_answers(instances, answers, eval_metrics[\"deepeval_metrics\"])"
]
},
{
@ -474,7 +466,7 @@
"source": [
"metric_name_list = [\"Diversity\"]\n",
"eval_metrics = get_metrics(metric_name_list)\n",
"eval_results = await deepeval_answers(instances, answers, eval_metrics[\"deepeval_metrics\"]) "
"eval_results = await deepeval_answers(instances, answers, eval_metrics[\"deepeval_metrics\"])"
]
},
{
@ -483,9 +475,7 @@
"metadata": {},
"outputs": [],
"source": [
"Diversity = statistics.mean(\n",
" [result.metrics_data[0].score for result in eval_results.test_results]\n",
")\n",
"Diversity = statistics.mean([result.metrics_data[0].score for result in eval_results.test_results])\n",
"print(Diversity)"
]
},
@ -504,7 +494,7 @@
"source": [
"metric_name_list = [\"Empowerment\"]\n",
"eval_metrics = get_metrics(metric_name_list)\n",
"eval_results = await deepeval_answers(instances, answers, eval_metrics[\"deepeval_metrics\"]) "
"eval_results = await deepeval_answers(instances, answers, eval_metrics[\"deepeval_metrics\"])"
]
},
{
@ -534,7 +524,7 @@
"source": [
"metric_name_list = [\"Directness\"]\n",
"eval_metrics = get_metrics(metric_name_list)\n",
"eval_results = await deepeval_answers(instances, answers, eval_metrics[\"deepeval_metrics\"]) "
"eval_results = await deepeval_answers(instances, answers, eval_metrics[\"deepeval_metrics\"])"
]
},
{
@ -543,9 +533,7 @@
"metadata": {},
"outputs": [],
"source": [
"Directness = statistics.mean(\n",
" [result.metrics_data[0].score for result in eval_results.test_results]\n",
")\n",
"Directness = statistics.mean([result.metrics_data[0].score for result in eval_results.test_results])\n",
"print(Directness)"
]
},
@ -572,7 +560,7 @@
"metadata": {},
"outputs": [],
"source": [
"eval_results = await deepeval_answers(instances, answers, eval_metrics[\"deepeval_metrics\"]) "
"eval_results = await deepeval_answers(instances, answers, eval_metrics[\"deepeval_metrics\"])"
]
},
{
@ -581,9 +569,7 @@
"metadata": {},
"outputs": [],
"source": [
"F1_score = statistics.mean(\n",
" [result.metrics_data[0].score for result in eval_results.test_results]\n",
")\n",
"F1_score = statistics.mean([result.metrics_data[0].score for result in eval_results.test_results])\n",
"print(F1_score)"
]
},
@ -602,7 +588,7 @@
"source": [
"metric_name_list = [\"EM\"]\n",
"eval_metrics = get_metrics(metric_name_list)\n",
"eval_results = await deepeval_answers(instances, answers, eval_metrics[\"deepeval_metrics\"]) "
"eval_results = await deepeval_answers(instances, answers, eval_metrics[\"deepeval_metrics\"])"
]
},
{
@ -611,9 +597,7 @@
"metadata": {},
"outputs": [],
"source": [
"EM = statistics.mean(\n",
" [result.metrics_data[0].score for result in eval_results.test_results]\n",
")\n",
"EM = statistics.mean([result.metrics_data[0].score for result in eval_results.test_results])\n",
"print(EM)"
]
},
@ -650,12 +634,12 @@
"random.seed(42)\n",
"instances = dataset if not num_samples else random.sample(dataset, num_samples)\n",
"\n",
"out_path = \"out\" \n",
"out_path = \"out\"\n",
"if not Path(out_path).exists():\n",
" Path(out_path).mkdir()\n",
"contexts_filename = out_path / Path(\n",
" f\"contexts_{dataset_name_or_filename.split('.')[0]}_{context_provider_name}.json\"\n",
" )\n",
" f\"contexts_{dataset_name_or_filename.split('.')[0]}_{context_provider_name}.json\"\n",
")\n",
"\n",
"answers = []\n",
"for instance in tqdm(instances, desc=\"Getting answers\"):\n",
@ -685,7 +669,7 @@
"source": [
"metric_name_list = [\"Correctness\"]\n",
"eval_metrics = get_metrics(metric_name_list)\n",
"eval_results = await deepeval_answers(instances, answers, eval_metrics[\"deepeval_metrics\"]) "
"eval_results = await deepeval_answers(instances, answers, eval_metrics[\"deepeval_metrics\"])"
]
},
{
@ -715,7 +699,7 @@
"source": [
"metric_name_list = [\"Comprehensiveness\"]\n",
"eval_metrics = get_metrics(metric_name_list)\n",
"eval_results = await deepeval_answers(instances, answers, eval_metrics[\"deepeval_metrics\"]) "
"eval_results = await deepeval_answers(instances, answers, eval_metrics[\"deepeval_metrics\"])"
]
},
{
@ -745,7 +729,7 @@
"source": [
"metric_name_list = [\"Diversity\"]\n",
"eval_metrics = get_metrics(metric_name_list)\n",
"eval_results = await deepeval_answers(instances, answers, eval_metrics[\"deepeval_metrics\"]) "
"eval_results = await deepeval_answers(instances, answers, eval_metrics[\"deepeval_metrics\"])"
]
},
{
@ -754,9 +738,7 @@
"metadata": {},
"outputs": [],
"source": [
"Diversity = statistics.mean(\n",
" [result.metrics_data[0].score for result in eval_results.test_results]\n",
")\n",
"Diversity = statistics.mean([result.metrics_data[0].score for result in eval_results.test_results])\n",
"print(Diversity)"
]
},
@ -775,7 +757,7 @@
"source": [
"metric_name_list = [\"Empowerment\"]\n",
"eval_metrics = get_metrics(metric_name_list)\n",
"eval_results = await deepeval_answers(instances, answers, eval_metrics[\"deepeval_metrics\"]) "
"eval_results = await deepeval_answers(instances, answers, eval_metrics[\"deepeval_metrics\"])"
]
},
{
@ -805,7 +787,7 @@
"source": [
"metric_name_list = [\"Directness\"]\n",
"eval_metrics = get_metrics(metric_name_list)\n",
"eval_results = await deepeval_answers(instances, answers, eval_metrics[\"deepeval_metrics\"]) "
"eval_results = await deepeval_answers(instances, answers, eval_metrics[\"deepeval_metrics\"])"
]
},
{
@ -814,9 +796,7 @@
"metadata": {},
"outputs": [],
"source": [
"Directness = statistics.mean(\n",
" [result.metrics_data[0].score for result in eval_results.test_results]\n",
")\n",
"Directness = statistics.mean([result.metrics_data[0].score for result in eval_results.test_results])\n",
"print(Directness)"
]
},
@ -843,7 +823,7 @@
"metadata": {},
"outputs": [],
"source": [
"eval_results = await deepeval_answers(instances, answers, eval_metrics[\"deepeval_metrics\"]) "
"eval_results = await deepeval_answers(instances, answers, eval_metrics[\"deepeval_metrics\"])"
]
},
{
@ -852,9 +832,7 @@
"metadata": {},
"outputs": [],
"source": [
"F1_score = statistics.mean(\n",
" [result.metrics_data[0].score for result in eval_results.test_results]\n",
")\n",
"F1_score = statistics.mean([result.metrics_data[0].score for result in eval_results.test_results])\n",
"print(F1_score)"
]
},
@ -873,7 +851,7 @@
"source": [
"metric_name_list = [\"EM\"]\n",
"eval_metrics = get_metrics(metric_name_list)\n",
"eval_results = await deepeval_answers(instances, answers, eval_metrics[\"deepeval_metrics\"]) "
"eval_results = await deepeval_answers(instances, answers, eval_metrics[\"deepeval_metrics\"])"
]
},
{
@ -882,9 +860,7 @@
"metadata": {},
"outputs": [],
"source": [
"EM = statistics.mean(\n",
" [result.metrics_data[0].score for result in eval_results.test_results]\n",
")\n",
"EM = statistics.mean([result.metrics_data[0].score for result in eval_results.test_results])\n",
"print(EM)"
]
},
@ -921,12 +897,12 @@
"random.seed(42)\n",
"instances = dataset if not num_samples else random.sample(dataset, num_samples)\n",
"\n",
"out_path = \"out\" \n",
"out_path = \"out\"\n",
"if not Path(out_path).exists():\n",
" Path(out_path).mkdir()\n",
"contexts_filename = out_path / Path(\n",
" f\"contexts_{dataset_name_or_filename.split('.')[0]}_{context_provider_name}.json\"\n",
" )\n",
" f\"contexts_{dataset_name_or_filename.split('.')[0]}_{context_provider_name}.json\"\n",
")\n",
"\n",
"answers = []\n",
"for instance in tqdm(instances, desc=\"Getting answers\"):\n",
@ -956,7 +932,7 @@
"source": [
"metric_name_list = [\"Correctness\"]\n",
"eval_metrics = get_metrics(metric_name_list)\n",
"eval_results = await deepeval_answers(instances, answers, eval_metrics[\"deepeval_metrics\"]) "
"eval_results = await deepeval_answers(instances, answers, eval_metrics[\"deepeval_metrics\"])"
]
},
{
@ -986,7 +962,7 @@
"source": [
"metric_name_list = [\"Comprehensiveness\"]\n",
"eval_metrics = get_metrics(metric_name_list)\n",
"eval_results = await deepeval_answers(instances, answers, eval_metrics[\"deepeval_metrics\"]) "
"eval_results = await deepeval_answers(instances, answers, eval_metrics[\"deepeval_metrics\"])"
]
},
{
@ -1016,7 +992,7 @@
"source": [
"metric_name_list = [\"Diversity\"]\n",
"eval_metrics = get_metrics(metric_name_list)\n",
"eval_results = await deepeval_answers(instances, answers, eval_metrics[\"deepeval_metrics\"]) "
"eval_results = await deepeval_answers(instances, answers, eval_metrics[\"deepeval_metrics\"])"
]
},
{
@ -1025,9 +1001,7 @@
"metadata": {},
"outputs": [],
"source": [
"Diversity = statistics.mean(\n",
" [result.metrics_data[0].score for result in eval_results.test_results]\n",
")\n",
"Diversity = statistics.mean([result.metrics_data[0].score for result in eval_results.test_results])\n",
"print(Diversity)"
]
},
@ -1046,7 +1020,7 @@
"source": [
"metric_name_list = [\"Empowerment\"]\n",
"eval_metrics = get_metrics(metric_name_list)\n",
"eval_results = await deepeval_answers(instances, answers, eval_metrics[\"deepeval_metrics\"]) "
"eval_results = await deepeval_answers(instances, answers, eval_metrics[\"deepeval_metrics\"])"
]
},
{
@ -1076,7 +1050,7 @@
"source": [
"metric_name_list = [\"Directness\"]\n",
"eval_metrics = get_metrics(metric_name_list)\n",
"eval_results = await deepeval_answers(instances, answers, eval_metrics[\"deepeval_metrics\"]) "
"eval_results = await deepeval_answers(instances, answers, eval_metrics[\"deepeval_metrics\"])"
]
},
{
@ -1085,9 +1059,7 @@
"metadata": {},
"outputs": [],
"source": [
"Directness = statistics.mean(\n",
" [result.metrics_data[0].score for result in eval_results.test_results]\n",
")\n",
"Directness = statistics.mean([result.metrics_data[0].score for result in eval_results.test_results])\n",
"print(Directness)"
]
},
@ -1114,7 +1086,7 @@
"metadata": {},
"outputs": [],
"source": [
"eval_results = await deepeval_answers(instances, answers, eval_metrics[\"deepeval_metrics\"]) "
"eval_results = await deepeval_answers(instances, answers, eval_metrics[\"deepeval_metrics\"])"
]
},
{
@ -1123,9 +1095,7 @@
"metadata": {},
"outputs": [],
"source": [
"F1_score = statistics.mean(\n",
" [result.metrics_data[0].score for result in eval_results.test_results]\n",
")\n",
"F1_score = statistics.mean([result.metrics_data[0].score for result in eval_results.test_results])\n",
"print(F1_score)"
]
},
@ -1144,7 +1114,7 @@
"source": [
"metric_name_list = [\"EM\"]\n",
"eval_metrics = get_metrics(metric_name_list)\n",
"eval_results = await deepeval_answers(instances, answers, eval_metrics[\"deepeval_metrics\"]) "
"eval_results = await deepeval_answers(instances, answers, eval_metrics[\"deepeval_metrics\"])"
]
},
{
@ -1153,9 +1123,7 @@
"metadata": {},
"outputs": [],
"source": [
"EM = statistics.mean(\n",
" [result.metrics_data[0].score for result in eval_results.test_results]\n",
")\n",
"EM = statistics.mean([result.metrics_data[0].score for result in eval_results.test_results])\n",
"print(EM)"
]
}

View file

@ -682,12 +682,12 @@
"random.seed(42)\n",
"instances = dataset if not num_samples else random.sample(dataset, num_samples)\n",
"\n",
"out_path = \"out\" \n",
"out_path = \"out\"\n",
"if not Path(out_path).exists():\n",
" Path(out_path).mkdir()\n",
"contexts_filename = out_path / Path(\n",
" f\"contexts_{dataset_name_or_filename.split('.')[0]}_{context_provider_name}.json\"\n",
" )\n",
" f\"contexts_{dataset_name_or_filename.split('.')[0]}_{context_provider_name}.json\"\n",
")\n",
"\n",
"answers = []\n",
"for instance in tqdm(instances, desc=\"Getting answers\"):\n",
@ -728,7 +728,7 @@
"source": [
"metric_name_list = [\"Correctness\"]\n",
"eval_metrics = get_metrics(metric_name_list)\n",
"eval_results = await deepeval_answers(instances, answers, eval_metrics[\"deepeval_metrics\"]) "
"eval_results = await deepeval_answers(instances, answers, eval_metrics[\"deepeval_metrics\"])"
]
},
{
@ -761,7 +761,7 @@
"source": [
"metric_name_list = [\"Comprehensiveness\"]\n",
"eval_metrics = get_metrics(metric_name_list)\n",
"eval_results = await deepeval_answers(instances, answers, eval_metrics[\"deepeval_metrics\"]) "
"eval_results = await deepeval_answers(instances, answers, eval_metrics[\"deepeval_metrics\"])"
]
},
{
@ -794,7 +794,7 @@
"source": [
"metric_name_list = [\"Diversity\"]\n",
"eval_metrics = get_metrics(metric_name_list)\n",
"eval_results = await deepeval_answers(instances, answers, eval_metrics[\"deepeval_metrics\"]) "
"eval_results = await deepeval_answers(instances, answers, eval_metrics[\"deepeval_metrics\"])"
]
},
{
@ -804,9 +804,7 @@
"metadata": {},
"outputs": [],
"source": [
"Diversity = statistics.mean(\n",
" [result.metrics_data[0].score for result in eval_results.test_results]\n",
")\n",
"Diversity = statistics.mean([result.metrics_data[0].score for result in eval_results.test_results])\n",
"print(Diversity)"
]
},
@ -827,7 +825,7 @@
"source": [
"metric_name_list = [\"Empowerment\"]\n",
"eval_metrics = get_metrics(metric_name_list)\n",
"eval_results = await deepeval_answers(instances, answers, eval_metrics[\"deepeval_metrics\"]) "
"eval_results = await deepeval_answers(instances, answers, eval_metrics[\"deepeval_metrics\"])"
]
},
{
@ -860,7 +858,7 @@
"source": [
"metric_name_list = [\"Directness\"]\n",
"eval_metrics = get_metrics(metric_name_list)\n",
"eval_results = await deepeval_answers(instances, answers, eval_metrics[\"deepeval_metrics\"]) "
"eval_results = await deepeval_answers(instances, answers, eval_metrics[\"deepeval_metrics\"])"
]
},
{
@ -870,9 +868,7 @@
"metadata": {},
"outputs": [],
"source": [
"Directness = statistics.mean(\n",
" [result.metrics_data[0].score for result in eval_results.test_results]\n",
")\n",
"Directness = statistics.mean([result.metrics_data[0].score for result in eval_results.test_results])\n",
"print(Directness)"
]
},
@ -902,7 +898,7 @@
"metadata": {},
"outputs": [],
"source": [
"eval_results = await deepeval_answers(instances, answers, eval_metrics[\"deepeval_metrics\"]) "
"eval_results = await deepeval_answers(instances, answers, eval_metrics[\"deepeval_metrics\"])"
]
},
{
@ -912,9 +908,7 @@
"metadata": {},
"outputs": [],
"source": [
"F1_score = statistics.mean(\n",
" [result.metrics_data[0].score for result in eval_results.test_results]\n",
")\n",
"F1_score = statistics.mean([result.metrics_data[0].score for result in eval_results.test_results])\n",
"print(F1_score)"
]
},
@ -935,7 +929,7 @@
"source": [
"metric_name_list = [\"EM\"]\n",
"eval_metrics = get_metrics(metric_name_list)\n",
"eval_results = await deepeval_answers(instances, answers, eval_metrics[\"deepeval_metrics\"]) "
"eval_results = await deepeval_answers(instances, answers, eval_metrics[\"deepeval_metrics\"])"
]
},
{
@ -945,9 +939,7 @@
"metadata": {},
"outputs": [],
"source": [
"EM = statistics.mean(\n",
" [result.metrics_data[0].score for result in eval_results.test_results]\n",
")\n",
"EM = statistics.mean([result.metrics_data[0].score for result in eval_results.test_results])\n",
"print(EM)"
]
},

View file

@ -1,13 +1,13 @@
{
"cells": [
{
"metadata": {},
"cell_type": "markdown",
"metadata": {},
"source": "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1EpokQ8Y_5jIJ7HdixZms81Oqgh2sp7-E?usp=sharing)"
},
{
"metadata": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## LlamaIndex Cognee GraphRAG Integration\n",
"\n",
@ -53,17 +53,15 @@
]
},
{
"metadata": {},
"cell_type": "code",
"outputs": [],
"execution_count": null,
"source": "!pip install llama-index-graph-rag-cognee==0.1.3"
},
{
"metadata": {},
"cell_type": "code",
"outputs": [],
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import asyncio\n",
@ -75,8 +73,8 @@
]
},
{
"metadata": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"Ensure youve set up your API keys and installed necessary dependencies.\n",
"\n",
@ -86,24 +84,24 @@
]
},
{
"metadata": {},
"cell_type": "code",
"outputs": [],
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"documents = [\n",
" Document(\n",
" text=\"Jessica Miller, Experienced Sales Manager with a strong track record in driving sales growth and building high-performing teams.\"\n",
" ),\n",
" Document(\n",
" text=\"David Thompson, Creative Graphic Designer with over 8 years of experience in visual design and branding.\"\n",
" ),\n",
" ]"
" Document(\n",
" text=\"Jessica Miller, Experienced Sales Manager with a strong track record in driving sales growth and building high-performing teams.\"\n",
" ),\n",
" Document(\n",
" text=\"David Thompson, Creative Graphic Designer with over 8 years of experience in visual design and branding.\"\n",
" ),\n",
"]"
]
},
{
"metadata": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"### 3. Initializing CogneeGraphRAG\n",
"\n",
@ -111,10 +109,10 @@
]
},
{
"metadata": {},
"cell_type": "code",
"outputs": [],
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"cogneeRAG = CogneeGraphRAG(\n",
" llm_api_key=os.environ[\"OPENAI_API_KEY\"],\n",
@ -128,8 +126,8 @@
]
},
{
"metadata": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"### 4. Adding Data to Cognee\n",
"\n",
@ -137,15 +135,17 @@
]
},
{
"metadata": {},
"cell_type": "code",
"outputs": [],
"execution_count": null,
"source": "await cogneeRAG.add(documents, \"test\")"
"metadata": {},
"outputs": [],
"source": [
"await cogneeRAG.add(documents, \"test\")"
]
},
{
"metadata": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"This step prepares the data for graph-based processing.\n",
"\n",
@ -155,15 +155,17 @@
]
},
{
"metadata": {},
"cell_type": "code",
"outputs": [],
"execution_count": null,
"source": "await cogneeRAG.process_data(\"test\")"
"metadata": {},
"outputs": [],
"source": [
"await cogneeRAG.process_data(\"test\")"
]
},
{
"metadata": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"The graph now contains nodes and relationships derived from the dataset, creating a powerful structure for exploration.\n",
"\n",
@ -173,10 +175,10 @@
]
},
{
"metadata": {},
"cell_type": "code",
"outputs": [],
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"search_results = await cogneeRAG.search(\"Tell me who are the people mentioned?\")\n",
"\n",
@ -186,15 +188,15 @@
]
},
{
"metadata": {},
"cell_type": "markdown",
"metadata": {},
"source": "### Answer prompt based on RAG approach:"
},
{
"metadata": {},
"cell_type": "code",
"outputs": [],
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"search_results = await cogneeRAG.rag_search(\"Tell me who are the people mentioned?\")\n",
"\n",
@ -204,13 +206,13 @@
]
},
{
"metadata": {},
"cell_type": "markdown",
"metadata": {},
"source": "In conclusion, the results demonstrate a significant advantage of the knowledge graph-based approach (Graphrag) over the RAG approach. Graphrag successfully identified all the mentioned individuals across multiple documents, showcasing its ability to aggregate and infer information from a global context. In contrast, the RAG approach was limited to identifying individuals within a single document due to its chunking-based processing constraints. This highlights Graphrag's superior capability in comprehensively resolving queries that span across a broader corpus of interconnected data."
},
{
"metadata": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"### 7. Finding Related Nodes\n",
"\n",
@ -218,10 +220,10 @@
]
},
{
"metadata": {},
"cell_type": "code",
"outputs": [],
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"related_nodes = await cogneeRAG.get_related_nodes(\"person\")\n",
"\n",
@ -231,8 +233,8 @@
]
},
{
"metadata": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## Why Choose Cognee and LlamaIndex?\n",
"\n",