chore: Fix and update visualization (#518)


## Description
<!-- Provide a clear description of the changes in this PR -->

## DCO Affirmation
I affirm that all code in every commit of this pull request conforms to
the terms of the Topoteretes Developer Certificate of Origin.


<!-- This is an auto-generated comment: release notes by coderabbit.ai -->
## Summary by CodeRabbit

- **New Features**
  - Introduced enhanced visualization capabilities that let users launch a dedicated server for visual displays.

- **Documentation**
  - Updated several interactive notebooks to include execution outputs and expanded explanatory content for better user guidance.

- **Style**
  - Refined formatting and layout across notebooks to ensure consistent presentation and improved readability.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Co-authored-by: Igor Ilic <30923996+dexters1@users.noreply.github.com>
Commit 9ba2e0d6c1 (parent 1b630366c9) by Vasilije, committed via GitHub on 2025-02-11 13:25:01 -05:00
GPG key ID: B5690EEEBB952194 (no known key found for this signature in database)
10 changed files with 630 additions and 302 deletions

View file

@@ -4,7 +4,7 @@ from .api.v1.config.config import config
 from .api.v1.datasets.datasets import datasets
 from .api.v1.prune import prune
 from .api.v1.search import SearchType, get_search_history, search
-from .api.v1.visualize import visualize_graph
+from .api.v1.visualize import visualize_graph, start_visualization_server
 from cognee.modules.visualization.cognee_network_visualization import (
     cognee_network_visualization,
 )

View file

@@ -1 +1,2 @@
 from .visualize import visualize_graph
+from .start_visualization_server import visualization_server

View file

@@ -0,0 +1,17 @@
+from cognee.shared.utils import start_visualization_server
+
+
+def visualization_server(port):
+    """
+    Start a visualization server on the specified port.
+
+    Args:
+        port (int): The port number to run the server on
+
+    Returns:
+        callable: A shutdown function that can be called to stop the server
+
+    Raises:
+        ValueError: If port is not a valid port number
+    """
+    return start_visualization_server(port=port)
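
As a quick orientation for reviewers, a minimal usage sketch of the wrapper added above (hedged: the import path assumes the `cognee.api.v1.visualize` package shown in this diff, and the port value is an arbitrary example):

```python
# Hypothetical usage of the new wrapper; the port is an example value.
from cognee.api.v1.visualize import visualization_server

shutdown = visualization_server(port=8001)  # starts serving in a background thread

# ... open http://localhost:8001 in a browser to inspect the rendered visualization ...

shutdown()  # stop the server when finished
```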

View file

@@ -12,7 +12,6 @@ from cognee.shared.utils import setup_logging
 async def visualize_graph(destination_file_path: str = None):
     graph_engine = await get_graph_engine()
     graph_data = await graph_engine.get_graph_data()
-    logging.info(graph_data)
     graph = await cognee_network_visualization(graph_data, destination_file_path)
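
For reference, a hedged sketch of calling this endpoint directly from a script (the destination path is a placeholder, and a configured graph engine is assumed):

```python
import asyncio

from cognee.api.v1.visualize import visualize_graph  # exported by the package __init__ above

# Render the current graph to an example HTML file (requires a configured graph engine).
asyncio.run(visualize_graph(destination_file_path="./graph_visualization.html"))
```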

View file

@@ -10,7 +10,9 @@ import graphistry
 import networkx as nx
 import pandas as pd
 import matplotlib.pyplot as plt
+import http.server
+import socketserver
+from threading import Thread
 import logging
 import sys
@@ -364,3 +366,41 @@ def setup_logging(log_level=logging.INFO):
     root_logger.addHandler(stream_handler)
     root_logger.setLevel(log_level)
+
+
+def start_visualization_server(
+    host="0.0.0.0", port=8001, handler_class=http.server.SimpleHTTPRequestHandler
+):
+    """
+    Spin up a simple HTTP server in a background thread to serve files.
+    This is especially handy for quick demos or visualization purposes.
+    Returns a shutdown() function that can be called to stop the server.
+
+    :param host: Host/IP to bind to. Defaults to '0.0.0.0'.
+    :param port: Port to listen on. Defaults to 8001.
+    :param handler_class: A handler class, defaults to SimpleHTTPRequestHandler.
+    :return: A no-argument function `shutdown` which, when called, stops the server.
+    """
+    # Create the server
+    server = socketserver.TCPServer((host, port), handler_class)
+
+    def _serve_forever():
+        print(f"Visualization server running at: http://{host}:{port}")
+        server.serve_forever()
+
+    # Start the server in a background thread
+    thread = Thread(target=_serve_forever, daemon=True)
+    thread.start()
+
+    def shutdown():
+        """
+        Shuts down the server and blocks until the thread is joined.
+        """
+        server.shutdown()  # Signals the serve_forever() loop to stop
+        server.server_close()  # Frees up the socket
+        thread.join()
+        print(f"Visualization server on port {port} has been shut down.")
+
+    # Return only the shutdown function (the server runs in the background)
+    return shutdown
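
A short usage sketch of this helper (host and port are placeholders; since it defaults to `SimpleHTTPRequestHandler`, it serves files from the process's current working directory):

```python
# Hedged example: serve the working directory in the background, then stop cleanly.
from cognee.shared.utils import start_visualization_server

shutdown = start_visualization_server(host="127.0.0.1", port=8001)
try:
    # ... fetch http://127.0.0.1:8001/<generated_visualization>.html while the server runs ...
    pass
finally:
    shutdown()  # stops serve_forever(), closes the socket, and joins the thread
```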

File diff suppressed because one or more lines are too long

View file

@@ -16,7 +16,9 @@
   },
   {
    "cell_type": "code",
+   "execution_count": null,
    "metadata": {},
+   "outputs": [],
    "source": [
     "import cognee\n",
     "import logging\n",
@@ -34,9 +36,7 @@
     "from cognee.tasks.completion.graph_query_completion import retrieved_edges_to_string\n",
     "from cognee.infrastructure.llm.prompts import read_query_prompt, render_prompt\n",
     "from cognee.infrastructure.llm.get_llm_client import get_llm_client"
-   ],
-   "outputs": [],
-   "execution_count": null
+   ]
   },
   {
    "cell_type": "markdown",
@@ -47,17 +47,19 @@
   },
   {
    "cell_type": "code",
+   "execution_count": 2,
    "metadata": {
     "ExecuteTime": {
      "end_time": "2025-01-15T10:43:57.893763Z",
      "start_time": "2025-01-15T10:43:57.891332Z"
     }
    },
+   "outputs": [],
    "source": [
     "import os\n",
     "\n",
     "# We ignore warnigns for now\n",
-    "warnings.filterwarnings('ignore')\n",
+    "warnings.filterwarnings(\"ignore\")\n",
     "\n",
     "# API key for cognee\n",
     "if \"LLM_API_KEY\" not in os.environ:\n",
@@ -68,10 +70,10 @@
     "    os.environ[\"OPENAI_API_KEY\"] = \"\"\n",
     "\n",
     "# Graphiti integration is only tested with neo4j + pgvector + postgres for now\n",
-    "GRAPH_DATABASE_PROVIDER=\"neo4j\"\n",
-    "GRAPH_DATABASE_URL=\"bolt://localhost:7687\"\n",
-    "GRAPH_DATABASE_USERNAME=\"neo4j\"\n",
-    "GRAPH_DATABASE_PASSWORD=\"pleaseletmein\"\n",
+    "GRAPH_DATABASE_PROVIDER = \"neo4j\"\n",
+    "GRAPH_DATABASE_URL = \"bolt://localhost:7687\"\n",
+    "GRAPH_DATABASE_USERNAME = \"neo4j\"\n",
+    "GRAPH_DATABASE_PASSWORD = \"pleaseletmein\"\n",
     "\n",
     "os.environ[\"VECTOR_DB_PROVIDER\"] = \"pgvector\"\n",
     "\n",
@@ -79,13 +81,11 @@
     "\n",
     "os.environ[\"DB_NAME\"] = \"cognee_db\"\n",
     "\n",
-    "os.environ[\"DB_HOST\"]=\"127.0.0.1\"\n",
-    "os.environ[\"DB_PORT\"]=\"5432\"\n",
-    "os.environ[\"DB_USERNAME\"]=\"cognee\"\n",
-    "os.environ[\"DB_PASSWORD\"]=\"cognee\""
-   ],
-   "outputs": [],
-   "execution_count": 2
+    "os.environ[\"DB_HOST\"] = \"127.0.0.1\"\n",
+    "os.environ[\"DB_PORT\"] = \"5432\"\n",
+    "os.environ[\"DB_USERNAME\"] = \"cognee\"\n",
+    "os.environ[\"DB_PASSWORD\"] = \"cognee\""
+   ]
   },
   {
    "cell_type": "markdown",
@@ -94,21 +94,21 @@
   },
   {
    "cell_type": "code",
+   "execution_count": 3,
    "metadata": {
     "ExecuteTime": {
      "end_time": "2025-01-15T10:43:57.928664Z",
      "start_time": "2025-01-15T10:43:57.927105Z"
     }
    },
+   "outputs": [],
    "source": [
     "text_list = [\n",
     "    \"Kamala Harris is the Attorney General of California. She was previously \"\n",
     "    \"the district attorney for San Francisco.\",\n",
     "    \"As AG, Harris was in office from January 3, 2011 January 3, 2017\",\n",
     "]"
-   ],
-   "outputs": [],
-   "execution_count": 3
+   ]
   },
   {
    "cell_type": "markdown",
@@ -117,12 +117,36 @@
   },
   {
    "cell_type": "code",
+   "execution_count": 4,
    "metadata": {
     "ExecuteTime": {
      "end_time": "2025-01-15T10:44:25.008501Z",
      "start_time": "2025-01-15T10:43:57.932240Z"
     }
    },
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Database deleted successfully.\n",
+      "Database deleted successfully.\n",
+      "User d3b51a32-38e1-4fe5-8270-6dc1d6ebfdf0 has registered.\n",
+      "Pipeline file_load_from_filesystem load step completed in 0.10 seconds\n",
+      "1 load package(s) were loaded to destination sqlalchemy and into dataset public\n",
+      "The sqlalchemy destination used postgresql://cognee:***@127.0.0.1:5432/cognee_db location to store data\n",
+      "Load package 1736937839.7739599 is LOADED and contains no failed jobs\n",
+      "Pipeline file_load_from_filesystem load step completed in 0.06 seconds\n",
+      "1 load package(s) were loaded to destination sqlalchemy and into dataset public\n",
+      "The sqlalchemy destination used postgresql://cognee:***@127.0.0.1:5432/cognee_db location to store data\n",
+      "Load package 1736937841.8467042 is LOADED and contains no failed jobs\n",
+      "Graph database initialized.\n",
+      "Added text: Kamala Harris is the Attorney Gener...\n",
+      "Added text: As AG, Harris was in office from Ja...\n",
+      "✅ Result Processed: <graphiti_core.graphiti.Graphiti object at 0x326fe0ce0>\n"
+     ]
+    }
+   ],
    "source": [
     "# 🔧 Setting Up Logging to Suppress Errors\n",
     "setup_logging(logging.ERROR) # Keeping logs clean and focused\n",
@@ -152,45 +176,31 @@
     "\n",
     "# 🔄 Indexing and Transforming Graph Data\n",
     "await index_and_transform_graphiti_nodes_and_edges()"
-   ],
-   "outputs": [
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "Database deleted successfully.\n",
-      "Database deleted successfully.\n",
-      "User d3b51a32-38e1-4fe5-8270-6dc1d6ebfdf0 has registered.\n",
-      "Pipeline file_load_from_filesystem load step completed in 0.10 seconds\n",
-      "1 load package(s) were loaded to destination sqlalchemy and into dataset public\n",
-      "The sqlalchemy destination used postgresql://cognee:***@127.0.0.1:5432/cognee_db location to store data\n",
-      "Load package 1736937839.7739599 is LOADED and contains no failed jobs\n",
-      "Pipeline file_load_from_filesystem load step completed in 0.06 seconds\n",
-      "1 load package(s) were loaded to destination sqlalchemy and into dataset public\n",
-      "The sqlalchemy destination used postgresql://cognee:***@127.0.0.1:5432/cognee_db location to store data\n",
-      "Load package 1736937841.8467042 is LOADED and contains no failed jobs\n",
-      "Graph database initialized.\n",
-      "Added text: Kamala Harris is the Attorney Gener...\n",
-      "Added text: As AG, Harris was in office from Ja...\n",
-      "✅ Result Processed: <graphiti_core.graphiti.Graphiti object at 0x326fe0ce0>\n"
-     ]
-    }
-   ],
-   "execution_count": 4
+   ]
   },
   {
-   "metadata": {},
    "cell_type": "markdown",
+   "metadata": {},
    "source": "## Retrieving and generating answer from graphiti graph with cognee retriever"
   },
   {
+   "cell_type": "code",
+   "execution_count": 5,
    "metadata": {
     "ExecuteTime": {
      "end_time": "2025-01-15T10:44:27.844438Z",
      "start_time": "2025-01-15T10:44:25.013325Z"
     }
    },
-   "cell_type": "code",
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "💡 Answer: Kamala Harris was in office as Attorney General of California from January 3, 2011, to January 3, 2017.\n"
+     ]
+    }
+   ],
    "source": [
     "# Step 1: Formulating the Query 🔍\n",
     "query = \"When was Kamala Harris in office?\"\n",
@@ -205,10 +215,7 @@
     "# Step 3: Preparing the Context for the LLM\n",
     "context = retrieved_edges_to_string(triplets)\n",
     "\n",
-    "args = {\n",
-    "    \"question\": query,\n",
-    "    \"context\": context\n",
-    "}\n",
+    "args = {\"question\": query, \"context\": context}\n",
     "\n",
     "# Step 4: Generating Prompts ✍️\n",
     "user_prompt = render_prompt(\"graph_context_for_question.txt\", args)\n",
@@ -217,24 +224,14 @@
     "# Step 5: Interacting with the LLM 🤖\n",
     "llm_client = get_llm_client()\n",
     "computed_answer = await llm_client.acreate_structured_output(\n",
-    "    text_input=user_prompt, # Input prompt for the user context\n",
-    "    system_prompt=system_prompt, # System-level instructions for the model\n",
+    "    text_input=user_prompt,  # Input prompt for the user context\n",
+    "    system_prompt=system_prompt,  # System-level instructions for the model\n",
     "    response_model=str,\n",
     ")\n",
     "\n",
     "# Step 6: Displaying the Computed Answer ✨\n",
     "print(f\"💡 Answer: {computed_answer}\")"
-   ],
-   "outputs": [
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "💡 Answer: Kamala Harris was in office as Attorney General of California from January 3, 2011, to January 3, 2017.\n"
-     ]
-    }
-   ],
-   "execution_count": 5
+   ]
   }
  ],
 "metadata": {

View file

@ -98,12 +98,12 @@
"random.seed(42)\n", "random.seed(42)\n",
"instances = dataset if not num_samples else random.sample(dataset, num_samples)\n", "instances = dataset if not num_samples else random.sample(dataset, num_samples)\n",
"\n", "\n",
"out_path = \"out\" \n", "out_path = \"out\"\n",
"if not Path(out_path).exists():\n", "if not Path(out_path).exists():\n",
" Path(out_path).mkdir()\n", " Path(out_path).mkdir()\n",
"contexts_filename = out_path / Path(\n", "contexts_filename = out_path / Path(\n",
" f\"contexts_{dataset_name_or_filename.split('.')[0]}_{context_provider_name}.json\"\n", " f\"contexts_{dataset_name_or_filename.split('.')[0]}_{context_provider_name}.json\"\n",
" )\n", ")\n",
"\n", "\n",
"answers = []\n", "answers = []\n",
"for instance in tqdm(instances, desc=\"Getting answers\"):\n", "for instance in tqdm(instances, desc=\"Getting answers\"):\n",
@ -143,7 +143,7 @@
"source": [ "source": [
"metric_name_list = [\"Correctness\"]\n", "metric_name_list = [\"Correctness\"]\n",
"eval_metrics = get_metrics(metric_name_list)\n", "eval_metrics = get_metrics(metric_name_list)\n",
"eval_results = await deepeval_answers(instances, answers, eval_metrics[\"deepeval_metrics\"]) " "eval_results = await deepeval_answers(instances, answers, eval_metrics[\"deepeval_metrics\"])"
] ]
}, },
{ {
@ -173,7 +173,7 @@
"source": [ "source": [
"metric_name_list = [\"Comprehensiveness\"]\n", "metric_name_list = [\"Comprehensiveness\"]\n",
"eval_metrics = get_metrics(metric_name_list)\n", "eval_metrics = get_metrics(metric_name_list)\n",
"eval_results = await deepeval_answers(instances, answers, eval_metrics[\"deepeval_metrics\"]) " "eval_results = await deepeval_answers(instances, answers, eval_metrics[\"deepeval_metrics\"])"
] ]
}, },
{ {
@ -203,7 +203,7 @@
"source": [ "source": [
"metric_name_list = [\"Diversity\"]\n", "metric_name_list = [\"Diversity\"]\n",
"eval_metrics = get_metrics(metric_name_list)\n", "eval_metrics = get_metrics(metric_name_list)\n",
"eval_results = await deepeval_answers(instances, answers, eval_metrics[\"deepeval_metrics\"]) " "eval_results = await deepeval_answers(instances, answers, eval_metrics[\"deepeval_metrics\"])"
] ]
}, },
{ {
@ -212,9 +212,7 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"Diversity = statistics.mean(\n", "Diversity = statistics.mean([result.metrics_data[0].score for result in eval_results.test_results])\n",
" [result.metrics_data[0].score for result in eval_results.test_results]\n",
")\n",
"print(Diversity)" "print(Diversity)"
] ]
}, },
@ -233,7 +231,7 @@
"source": [ "source": [
"metric_name_list = [\"Empowerment\"]\n", "metric_name_list = [\"Empowerment\"]\n",
"eval_metrics = get_metrics(metric_name_list)\n", "eval_metrics = get_metrics(metric_name_list)\n",
"eval_results = await deepeval_answers(instances, answers, eval_metrics[\"deepeval_metrics\"]) " "eval_results = await deepeval_answers(instances, answers, eval_metrics[\"deepeval_metrics\"])"
] ]
}, },
{ {
@ -263,7 +261,7 @@
"source": [ "source": [
"metric_name_list = [\"Directness\"]\n", "metric_name_list = [\"Directness\"]\n",
"eval_metrics = get_metrics(metric_name_list)\n", "eval_metrics = get_metrics(metric_name_list)\n",
"eval_results = await deepeval_answers(instances, answers, eval_metrics[\"deepeval_metrics\"]) " "eval_results = await deepeval_answers(instances, answers, eval_metrics[\"deepeval_metrics\"])"
] ]
}, },
{ {
@ -272,9 +270,7 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"Directness = statistics.mean(\n", "Directness = statistics.mean([result.metrics_data[0].score for result in eval_results.test_results])\n",
" [result.metrics_data[0].score for result in eval_results.test_results]\n",
")\n",
"print(Directness)" "print(Directness)"
] ]
}, },
@ -301,7 +297,7 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"eval_results = await deepeval_answers(instances, answers, eval_metrics[\"deepeval_metrics\"]) " "eval_results = await deepeval_answers(instances, answers, eval_metrics[\"deepeval_metrics\"])"
] ]
}, },
{ {
@ -310,9 +306,7 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"F1_score = statistics.mean(\n", "F1_score = statistics.mean([result.metrics_data[0].score for result in eval_results.test_results])\n",
" [result.metrics_data[0].score for result in eval_results.test_results]\n",
")\n",
"print(F1_score)" "print(F1_score)"
] ]
}, },
@ -331,7 +325,7 @@
"source": [ "source": [
"metric_name_list = [\"EM\"]\n", "metric_name_list = [\"EM\"]\n",
"eval_metrics = get_metrics(metric_name_list)\n", "eval_metrics = get_metrics(metric_name_list)\n",
"eval_results = await deepeval_answers(instances, answers, eval_metrics[\"deepeval_metrics\"]) " "eval_results = await deepeval_answers(instances, answers, eval_metrics[\"deepeval_metrics\"])"
] ]
}, },
{ {
@ -340,9 +334,7 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"EM = statistics.mean(\n", "EM = statistics.mean([result.metrics_data[0].score for result in eval_results.test_results])\n",
" [result.metrics_data[0].score for result in eval_results.test_results]\n",
")\n",
"print(EM)" "print(EM)"
] ]
}, },
@ -379,12 +371,12 @@
"random.seed(42)\n", "random.seed(42)\n",
"instances = dataset if not num_samples else random.sample(dataset, num_samples)\n", "instances = dataset if not num_samples else random.sample(dataset, num_samples)\n",
"\n", "\n",
"out_path = \"out\" \n", "out_path = \"out\"\n",
"if not Path(out_path).exists():\n", "if not Path(out_path).exists():\n",
" Path(out_path).mkdir()\n", " Path(out_path).mkdir()\n",
"contexts_filename = out_path / Path(\n", "contexts_filename = out_path / Path(\n",
" f\"contexts_{dataset_name_or_filename.split('.')[0]}_{context_provider_name}.json\"\n", " f\"contexts_{dataset_name_or_filename.split('.')[0]}_{context_provider_name}.json\"\n",
" )\n", ")\n",
"\n", "\n",
"answers = []\n", "answers = []\n",
"for instance in tqdm(instances, desc=\"Getting answers\"):\n", "for instance in tqdm(instances, desc=\"Getting answers\"):\n",
@ -414,7 +406,7 @@
"source": [ "source": [
"metric_name_list = [\"Correctness\"]\n", "metric_name_list = [\"Correctness\"]\n",
"eval_metrics = get_metrics(metric_name_list)\n", "eval_metrics = get_metrics(metric_name_list)\n",
"eval_results = await deepeval_answers(instances, answers, eval_metrics[\"deepeval_metrics\"]) " "eval_results = await deepeval_answers(instances, answers, eval_metrics[\"deepeval_metrics\"])"
] ]
}, },
{ {
@ -444,7 +436,7 @@
"source": [ "source": [
"metric_name_list = [\"Comprehensiveness\"]\n", "metric_name_list = [\"Comprehensiveness\"]\n",
"eval_metrics = get_metrics(metric_name_list)\n", "eval_metrics = get_metrics(metric_name_list)\n",
"eval_results = await deepeval_answers(instances, answers, eval_metrics[\"deepeval_metrics\"]) " "eval_results = await deepeval_answers(instances, answers, eval_metrics[\"deepeval_metrics\"])"
] ]
}, },
{ {
@ -474,7 +466,7 @@
"source": [ "source": [
"metric_name_list = [\"Diversity\"]\n", "metric_name_list = [\"Diversity\"]\n",
"eval_metrics = get_metrics(metric_name_list)\n", "eval_metrics = get_metrics(metric_name_list)\n",
"eval_results = await deepeval_answers(instances, answers, eval_metrics[\"deepeval_metrics\"]) " "eval_results = await deepeval_answers(instances, answers, eval_metrics[\"deepeval_metrics\"])"
] ]
}, },
{ {
@ -483,9 +475,7 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"Diversity = statistics.mean(\n", "Diversity = statistics.mean([result.metrics_data[0].score for result in eval_results.test_results])\n",
" [result.metrics_data[0].score for result in eval_results.test_results]\n",
")\n",
"print(Diversity)" "print(Diversity)"
] ]
}, },
@ -504,7 +494,7 @@
"source": [ "source": [
"metric_name_list = [\"Empowerment\"]\n", "metric_name_list = [\"Empowerment\"]\n",
"eval_metrics = get_metrics(metric_name_list)\n", "eval_metrics = get_metrics(metric_name_list)\n",
"eval_results = await deepeval_answers(instances, answers, eval_metrics[\"deepeval_metrics\"]) " "eval_results = await deepeval_answers(instances, answers, eval_metrics[\"deepeval_metrics\"])"
] ]
}, },
{ {
@ -534,7 +524,7 @@
"source": [ "source": [
"metric_name_list = [\"Directness\"]\n", "metric_name_list = [\"Directness\"]\n",
"eval_metrics = get_metrics(metric_name_list)\n", "eval_metrics = get_metrics(metric_name_list)\n",
"eval_results = await deepeval_answers(instances, answers, eval_metrics[\"deepeval_metrics\"]) " "eval_results = await deepeval_answers(instances, answers, eval_metrics[\"deepeval_metrics\"])"
] ]
}, },
{ {
@ -543,9 +533,7 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"Directness = statistics.mean(\n", "Directness = statistics.mean([result.metrics_data[0].score for result in eval_results.test_results])\n",
" [result.metrics_data[0].score for result in eval_results.test_results]\n",
")\n",
"print(Directness)" "print(Directness)"
] ]
}, },
@ -572,7 +560,7 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"eval_results = await deepeval_answers(instances, answers, eval_metrics[\"deepeval_metrics\"]) " "eval_results = await deepeval_answers(instances, answers, eval_metrics[\"deepeval_metrics\"])"
] ]
}, },
{ {
@ -581,9 +569,7 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"F1_score = statistics.mean(\n", "F1_score = statistics.mean([result.metrics_data[0].score for result in eval_results.test_results])\n",
" [result.metrics_data[0].score for result in eval_results.test_results]\n",
")\n",
"print(F1_score)" "print(F1_score)"
] ]
}, },
@ -602,7 +588,7 @@
"source": [ "source": [
"metric_name_list = [\"EM\"]\n", "metric_name_list = [\"EM\"]\n",
"eval_metrics = get_metrics(metric_name_list)\n", "eval_metrics = get_metrics(metric_name_list)\n",
"eval_results = await deepeval_answers(instances, answers, eval_metrics[\"deepeval_metrics\"]) " "eval_results = await deepeval_answers(instances, answers, eval_metrics[\"deepeval_metrics\"])"
] ]
}, },
{ {
@ -611,9 +597,7 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"EM = statistics.mean(\n", "EM = statistics.mean([result.metrics_data[0].score for result in eval_results.test_results])\n",
" [result.metrics_data[0].score for result in eval_results.test_results]\n",
")\n",
"print(EM)" "print(EM)"
] ]
}, },
@ -650,12 +634,12 @@
"random.seed(42)\n", "random.seed(42)\n",
"instances = dataset if not num_samples else random.sample(dataset, num_samples)\n", "instances = dataset if not num_samples else random.sample(dataset, num_samples)\n",
"\n", "\n",
"out_path = \"out\" \n", "out_path = \"out\"\n",
"if not Path(out_path).exists():\n", "if not Path(out_path).exists():\n",
" Path(out_path).mkdir()\n", " Path(out_path).mkdir()\n",
"contexts_filename = out_path / Path(\n", "contexts_filename = out_path / Path(\n",
" f\"contexts_{dataset_name_or_filename.split('.')[0]}_{context_provider_name}.json\"\n", " f\"contexts_{dataset_name_or_filename.split('.')[0]}_{context_provider_name}.json\"\n",
" )\n", ")\n",
"\n", "\n",
"answers = []\n", "answers = []\n",
"for instance in tqdm(instances, desc=\"Getting answers\"):\n", "for instance in tqdm(instances, desc=\"Getting answers\"):\n",
@ -685,7 +669,7 @@
"source": [ "source": [
"metric_name_list = [\"Correctness\"]\n", "metric_name_list = [\"Correctness\"]\n",
"eval_metrics = get_metrics(metric_name_list)\n", "eval_metrics = get_metrics(metric_name_list)\n",
"eval_results = await deepeval_answers(instances, answers, eval_metrics[\"deepeval_metrics\"]) " "eval_results = await deepeval_answers(instances, answers, eval_metrics[\"deepeval_metrics\"])"
] ]
}, },
{ {
@ -715,7 +699,7 @@
"source": [ "source": [
"metric_name_list = [\"Comprehensiveness\"]\n", "metric_name_list = [\"Comprehensiveness\"]\n",
"eval_metrics = get_metrics(metric_name_list)\n", "eval_metrics = get_metrics(metric_name_list)\n",
"eval_results = await deepeval_answers(instances, answers, eval_metrics[\"deepeval_metrics\"]) " "eval_results = await deepeval_answers(instances, answers, eval_metrics[\"deepeval_metrics\"])"
] ]
}, },
{ {
@ -745,7 +729,7 @@
"source": [ "source": [
"metric_name_list = [\"Diversity\"]\n", "metric_name_list = [\"Diversity\"]\n",
"eval_metrics = get_metrics(metric_name_list)\n", "eval_metrics = get_metrics(metric_name_list)\n",
"eval_results = await deepeval_answers(instances, answers, eval_metrics[\"deepeval_metrics\"]) " "eval_results = await deepeval_answers(instances, answers, eval_metrics[\"deepeval_metrics\"])"
] ]
}, },
{ {
@ -754,9 +738,7 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"Diversity = statistics.mean(\n", "Diversity = statistics.mean([result.metrics_data[0].score for result in eval_results.test_results])\n",
" [result.metrics_data[0].score for result in eval_results.test_results]\n",
")\n",
"print(Diversity)" "print(Diversity)"
] ]
}, },
@ -775,7 +757,7 @@
"source": [ "source": [
"metric_name_list = [\"Empowerment\"]\n", "metric_name_list = [\"Empowerment\"]\n",
"eval_metrics = get_metrics(metric_name_list)\n", "eval_metrics = get_metrics(metric_name_list)\n",
"eval_results = await deepeval_answers(instances, answers, eval_metrics[\"deepeval_metrics\"]) " "eval_results = await deepeval_answers(instances, answers, eval_metrics[\"deepeval_metrics\"])"
] ]
}, },
{ {
@ -805,7 +787,7 @@
"source": [ "source": [
"metric_name_list = [\"Directness\"]\n", "metric_name_list = [\"Directness\"]\n",
"eval_metrics = get_metrics(metric_name_list)\n", "eval_metrics = get_metrics(metric_name_list)\n",
"eval_results = await deepeval_answers(instances, answers, eval_metrics[\"deepeval_metrics\"]) " "eval_results = await deepeval_answers(instances, answers, eval_metrics[\"deepeval_metrics\"])"
] ]
}, },
{ {
@ -814,9 +796,7 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"Directness = statistics.mean(\n", "Directness = statistics.mean([result.metrics_data[0].score for result in eval_results.test_results])\n",
" [result.metrics_data[0].score for result in eval_results.test_results]\n",
")\n",
"print(Directness)" "print(Directness)"
] ]
}, },
@ -843,7 +823,7 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"eval_results = await deepeval_answers(instances, answers, eval_metrics[\"deepeval_metrics\"]) " "eval_results = await deepeval_answers(instances, answers, eval_metrics[\"deepeval_metrics\"])"
] ]
}, },
{ {
@ -852,9 +832,7 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"F1_score = statistics.mean(\n", "F1_score = statistics.mean([result.metrics_data[0].score for result in eval_results.test_results])\n",
" [result.metrics_data[0].score for result in eval_results.test_results]\n",
")\n",
"print(F1_score)" "print(F1_score)"
] ]
}, },
@ -873,7 +851,7 @@
"source": [ "source": [
"metric_name_list = [\"EM\"]\n", "metric_name_list = [\"EM\"]\n",
"eval_metrics = get_metrics(metric_name_list)\n", "eval_metrics = get_metrics(metric_name_list)\n",
"eval_results = await deepeval_answers(instances, answers, eval_metrics[\"deepeval_metrics\"]) " "eval_results = await deepeval_answers(instances, answers, eval_metrics[\"deepeval_metrics\"])"
] ]
}, },
{ {
@ -882,9 +860,7 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"EM = statistics.mean(\n", "EM = statistics.mean([result.metrics_data[0].score for result in eval_results.test_results])\n",
" [result.metrics_data[0].score for result in eval_results.test_results]\n",
")\n",
"print(EM)" "print(EM)"
] ]
}, },
@ -921,12 +897,12 @@
"random.seed(42)\n", "random.seed(42)\n",
"instances = dataset if not num_samples else random.sample(dataset, num_samples)\n", "instances = dataset if not num_samples else random.sample(dataset, num_samples)\n",
"\n", "\n",
"out_path = \"out\" \n", "out_path = \"out\"\n",
"if not Path(out_path).exists():\n", "if not Path(out_path).exists():\n",
" Path(out_path).mkdir()\n", " Path(out_path).mkdir()\n",
"contexts_filename = out_path / Path(\n", "contexts_filename = out_path / Path(\n",
" f\"contexts_{dataset_name_or_filename.split('.')[0]}_{context_provider_name}.json\"\n", " f\"contexts_{dataset_name_or_filename.split('.')[0]}_{context_provider_name}.json\"\n",
" )\n", ")\n",
"\n", "\n",
"answers = []\n", "answers = []\n",
"for instance in tqdm(instances, desc=\"Getting answers\"):\n", "for instance in tqdm(instances, desc=\"Getting answers\"):\n",
@ -956,7 +932,7 @@
"source": [ "source": [
"metric_name_list = [\"Correctness\"]\n", "metric_name_list = [\"Correctness\"]\n",
"eval_metrics = get_metrics(metric_name_list)\n", "eval_metrics = get_metrics(metric_name_list)\n",
"eval_results = await deepeval_answers(instances, answers, eval_metrics[\"deepeval_metrics\"]) " "eval_results = await deepeval_answers(instances, answers, eval_metrics[\"deepeval_metrics\"])"
] ]
}, },
{ {
@ -986,7 +962,7 @@
"source": [ "source": [
"metric_name_list = [\"Comprehensiveness\"]\n", "metric_name_list = [\"Comprehensiveness\"]\n",
"eval_metrics = get_metrics(metric_name_list)\n", "eval_metrics = get_metrics(metric_name_list)\n",
"eval_results = await deepeval_answers(instances, answers, eval_metrics[\"deepeval_metrics\"]) " "eval_results = await deepeval_answers(instances, answers, eval_metrics[\"deepeval_metrics\"])"
] ]
}, },
{ {
@ -1016,7 +992,7 @@
"source": [ "source": [
"metric_name_list = [\"Diversity\"]\n", "metric_name_list = [\"Diversity\"]\n",
"eval_metrics = get_metrics(metric_name_list)\n", "eval_metrics = get_metrics(metric_name_list)\n",
"eval_results = await deepeval_answers(instances, answers, eval_metrics[\"deepeval_metrics\"]) " "eval_results = await deepeval_answers(instances, answers, eval_metrics[\"deepeval_metrics\"])"
] ]
}, },
{ {
@ -1025,9 +1001,7 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"Diversity = statistics.mean(\n", "Diversity = statistics.mean([result.metrics_data[0].score for result in eval_results.test_results])\n",
" [result.metrics_data[0].score for result in eval_results.test_results]\n",
")\n",
"print(Diversity)" "print(Diversity)"
] ]
}, },
@ -1046,7 +1020,7 @@
"source": [ "source": [
"metric_name_list = [\"Empowerment\"]\n", "metric_name_list = [\"Empowerment\"]\n",
"eval_metrics = get_metrics(metric_name_list)\n", "eval_metrics = get_metrics(metric_name_list)\n",
"eval_results = await deepeval_answers(instances, answers, eval_metrics[\"deepeval_metrics\"]) " "eval_results = await deepeval_answers(instances, answers, eval_metrics[\"deepeval_metrics\"])"
] ]
}, },
{ {
@ -1076,7 +1050,7 @@
"source": [ "source": [
"metric_name_list = [\"Directness\"]\n", "metric_name_list = [\"Directness\"]\n",
"eval_metrics = get_metrics(metric_name_list)\n", "eval_metrics = get_metrics(metric_name_list)\n",
"eval_results = await deepeval_answers(instances, answers, eval_metrics[\"deepeval_metrics\"]) " "eval_results = await deepeval_answers(instances, answers, eval_metrics[\"deepeval_metrics\"])"
] ]
}, },
{ {
@ -1085,9 +1059,7 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"Directness = statistics.mean(\n", "Directness = statistics.mean([result.metrics_data[0].score for result in eval_results.test_results])\n",
" [result.metrics_data[0].score for result in eval_results.test_results]\n",
")\n",
"print(Directness)" "print(Directness)"
] ]
}, },
@ -1114,7 +1086,7 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"eval_results = await deepeval_answers(instances, answers, eval_metrics[\"deepeval_metrics\"]) " "eval_results = await deepeval_answers(instances, answers, eval_metrics[\"deepeval_metrics\"])"
] ]
}, },
{ {
@ -1123,9 +1095,7 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"F1_score = statistics.mean(\n", "F1_score = statistics.mean([result.metrics_data[0].score for result in eval_results.test_results])\n",
" [result.metrics_data[0].score for result in eval_results.test_results]\n",
")\n",
"print(F1_score)" "print(F1_score)"
] ]
}, },
@ -1144,7 +1114,7 @@
"source": [ "source": [
"metric_name_list = [\"EM\"]\n", "metric_name_list = [\"EM\"]\n",
"eval_metrics = get_metrics(metric_name_list)\n", "eval_metrics = get_metrics(metric_name_list)\n",
"eval_results = await deepeval_answers(instances, answers, eval_metrics[\"deepeval_metrics\"]) " "eval_results = await deepeval_answers(instances, answers, eval_metrics[\"deepeval_metrics\"])"
] ]
}, },
{ {
@ -1153,9 +1123,7 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"EM = statistics.mean(\n", "EM = statistics.mean([result.metrics_data[0].score for result in eval_results.test_results])\n",
" [result.metrics_data[0].score for result in eval_results.test_results]\n",
")\n",
"print(EM)" "print(EM)"
] ]
} }

View file

@ -682,12 +682,12 @@
"random.seed(42)\n", "random.seed(42)\n",
"instances = dataset if not num_samples else random.sample(dataset, num_samples)\n", "instances = dataset if not num_samples else random.sample(dataset, num_samples)\n",
"\n", "\n",
"out_path = \"out\" \n", "out_path = \"out\"\n",
"if not Path(out_path).exists():\n", "if not Path(out_path).exists():\n",
" Path(out_path).mkdir()\n", " Path(out_path).mkdir()\n",
"contexts_filename = out_path / Path(\n", "contexts_filename = out_path / Path(\n",
" f\"contexts_{dataset_name_or_filename.split('.')[0]}_{context_provider_name}.json\"\n", " f\"contexts_{dataset_name_or_filename.split('.')[0]}_{context_provider_name}.json\"\n",
" )\n", ")\n",
"\n", "\n",
"answers = []\n", "answers = []\n",
"for instance in tqdm(instances, desc=\"Getting answers\"):\n", "for instance in tqdm(instances, desc=\"Getting answers\"):\n",
@ -728,7 +728,7 @@
"source": [ "source": [
"metric_name_list = [\"Correctness\"]\n", "metric_name_list = [\"Correctness\"]\n",
"eval_metrics = get_metrics(metric_name_list)\n", "eval_metrics = get_metrics(metric_name_list)\n",
"eval_results = await deepeval_answers(instances, answers, eval_metrics[\"deepeval_metrics\"]) " "eval_results = await deepeval_answers(instances, answers, eval_metrics[\"deepeval_metrics\"])"
] ]
}, },
{ {
@ -761,7 +761,7 @@
"source": [ "source": [
"metric_name_list = [\"Comprehensiveness\"]\n", "metric_name_list = [\"Comprehensiveness\"]\n",
"eval_metrics = get_metrics(metric_name_list)\n", "eval_metrics = get_metrics(metric_name_list)\n",
"eval_results = await deepeval_answers(instances, answers, eval_metrics[\"deepeval_metrics\"]) " "eval_results = await deepeval_answers(instances, answers, eval_metrics[\"deepeval_metrics\"])"
] ]
}, },
{ {
@ -794,7 +794,7 @@
"source": [ "source": [
"metric_name_list = [\"Diversity\"]\n", "metric_name_list = [\"Diversity\"]\n",
"eval_metrics = get_metrics(metric_name_list)\n", "eval_metrics = get_metrics(metric_name_list)\n",
"eval_results = await deepeval_answers(instances, answers, eval_metrics[\"deepeval_metrics\"]) " "eval_results = await deepeval_answers(instances, answers, eval_metrics[\"deepeval_metrics\"])"
] ]
}, },
{ {
@ -804,9 +804,7 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"Diversity = statistics.mean(\n", "Diversity = statistics.mean([result.metrics_data[0].score for result in eval_results.test_results])\n",
" [result.metrics_data[0].score for result in eval_results.test_results]\n",
")\n",
"print(Diversity)" "print(Diversity)"
] ]
}, },
@ -827,7 +825,7 @@
"source": [ "source": [
"metric_name_list = [\"Empowerment\"]\n", "metric_name_list = [\"Empowerment\"]\n",
"eval_metrics = get_metrics(metric_name_list)\n", "eval_metrics = get_metrics(metric_name_list)\n",
"eval_results = await deepeval_answers(instances, answers, eval_metrics[\"deepeval_metrics\"]) " "eval_results = await deepeval_answers(instances, answers, eval_metrics[\"deepeval_metrics\"])"
] ]
}, },
{ {
@ -860,7 +858,7 @@
"source": [ "source": [
"metric_name_list = [\"Directness\"]\n", "metric_name_list = [\"Directness\"]\n",
"eval_metrics = get_metrics(metric_name_list)\n", "eval_metrics = get_metrics(metric_name_list)\n",
"eval_results = await deepeval_answers(instances, answers, eval_metrics[\"deepeval_metrics\"]) " "eval_results = await deepeval_answers(instances, answers, eval_metrics[\"deepeval_metrics\"])"
] ]
}, },
{ {
@ -870,9 +868,7 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"Directness = statistics.mean(\n", "Directness = statistics.mean([result.metrics_data[0].score for result in eval_results.test_results])\n",
" [result.metrics_data[0].score for result in eval_results.test_results]\n",
")\n",
"print(Directness)" "print(Directness)"
] ]
}, },
@ -902,7 +898,7 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"eval_results = await deepeval_answers(instances, answers, eval_metrics[\"deepeval_metrics\"]) " "eval_results = await deepeval_answers(instances, answers, eval_metrics[\"deepeval_metrics\"])"
] ]
}, },
{ {
@ -912,9 +908,7 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"F1_score = statistics.mean(\n", "F1_score = statistics.mean([result.metrics_data[0].score for result in eval_results.test_results])\n",
" [result.metrics_data[0].score for result in eval_results.test_results]\n",
")\n",
"print(F1_score)" "print(F1_score)"
] ]
}, },
@ -935,7 +929,7 @@
"source": [ "source": [
"metric_name_list = [\"EM\"]\n", "metric_name_list = [\"EM\"]\n",
"eval_metrics = get_metrics(metric_name_list)\n", "eval_metrics = get_metrics(metric_name_list)\n",
"eval_results = await deepeval_answers(instances, answers, eval_metrics[\"deepeval_metrics\"]) " "eval_results = await deepeval_answers(instances, answers, eval_metrics[\"deepeval_metrics\"])"
] ]
}, },
{ {
@ -945,9 +939,7 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"EM = statistics.mean(\n", "EM = statistics.mean([result.metrics_data[0].score for result in eval_results.test_results])\n",
" [result.metrics_data[0].score for result in eval_results.test_results]\n",
")\n",
"print(EM)" "print(EM)"
] ]
}, },

View file

@ -1,13 +1,13 @@
{ {
"cells": [ "cells": [
{ {
"metadata": {},
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {},
"source": "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1EpokQ8Y_5jIJ7HdixZms81Oqgh2sp7-E?usp=sharing)" "source": "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1EpokQ8Y_5jIJ7HdixZms81Oqgh2sp7-E?usp=sharing)"
}, },
{ {
"metadata": {},
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {},
"source": [ "source": [
"## LlamaIndex Cognee GraphRAG Integration\n", "## LlamaIndex Cognee GraphRAG Integration\n",
"\n", "\n",
@ -53,17 +53,15 @@
] ]
}, },
{ {
"metadata": {},
"cell_type": "code", "cell_type": "code",
"outputs": [],
"execution_count": null, "execution_count": null,
"source": "!pip install llama-index-graph-rag-cognee==0.1.3" "source": "!pip install llama-index-graph-rag-cognee==0.1.3"
}, },
{ {
"metadata": {},
"cell_type": "code", "cell_type": "code",
"outputs": [],
"execution_count": null, "execution_count": null,
"metadata": {},
"outputs": [],
"source": [ "source": [
"import os\n", "import os\n",
"import asyncio\n", "import asyncio\n",
@ -75,8 +73,8 @@
] ]
}, },
{ {
"metadata": {},
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {},
"source": [ "source": [
"Ensure youve set up your API keys and installed necessary dependencies.\n", "Ensure youve set up your API keys and installed necessary dependencies.\n",
"\n", "\n",
@ -86,24 +84,24 @@
] ]
}, },
{ {
"metadata": {},
"cell_type": "code", "cell_type": "code",
"outputs": [],
"execution_count": null, "execution_count": null,
"metadata": {},
"outputs": [],
"source": [ "source": [
"documents = [\n", "documents = [\n",
" Document(\n", " Document(\n",
" text=\"Jessica Miller, Experienced Sales Manager with a strong track record in driving sales growth and building high-performing teams.\"\n", " text=\"Jessica Miller, Experienced Sales Manager with a strong track record in driving sales growth and building high-performing teams.\"\n",
" ),\n", " ),\n",
" Document(\n", " Document(\n",
" text=\"David Thompson, Creative Graphic Designer with over 8 years of experience in visual design and branding.\"\n", " text=\"David Thompson, Creative Graphic Designer with over 8 years of experience in visual design and branding.\"\n",
" ),\n", " ),\n",
" ]" "]"
] ]
}, },
{ {
"metadata": {},
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {},
"source": [ "source": [
"### 3. Initializing CogneeGraphRAG\n", "### 3. Initializing CogneeGraphRAG\n",
"\n", "\n",
@ -111,10 +109,10 @@
] ]
}, },
{ {
"metadata": {},
"cell_type": "code", "cell_type": "code",
"outputs": [],
"execution_count": null, "execution_count": null,
"metadata": {},
"outputs": [],
"source": [ "source": [
"cogneeRAG = CogneeGraphRAG(\n", "cogneeRAG = CogneeGraphRAG(\n",
" llm_api_key=os.environ[\"OPENAI_API_KEY\"],\n", " llm_api_key=os.environ[\"OPENAI_API_KEY\"],\n",
@ -128,8 +126,8 @@
] ]
}, },
{ {
"metadata": {},
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {},
"source": [ "source": [
"### 4. Adding Data to Cognee\n", "### 4. Adding Data to Cognee\n",
"\n", "\n",
@ -137,15 +135,17 @@
] ]
}, },
{ {
"metadata": {},
"cell_type": "code", "cell_type": "code",
"outputs": [],
"execution_count": null, "execution_count": null,
"source": "await cogneeRAG.add(documents, \"test\")" "metadata": {},
"outputs": [],
"source": [
"await cogneeRAG.add(documents, \"test\")"
]
}, },
{ {
"metadata": {},
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {},
"source": [ "source": [
"This step prepares the data for graph-based processing.\n", "This step prepares the data for graph-based processing.\n",
"\n", "\n",
@ -155,15 +155,17 @@
] ]
}, },
{ {
"metadata": {},
"cell_type": "code", "cell_type": "code",
"outputs": [],
"execution_count": null, "execution_count": null,
"source": "await cogneeRAG.process_data(\"test\")" "metadata": {},
"outputs": [],
"source": [
"await cogneeRAG.process_data(\"test\")"
]
}, },
{ {
"metadata": {},
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {},
"source": [ "source": [
"The graph now contains nodes and relationships derived from the dataset, creating a powerful structure for exploration.\n", "The graph now contains nodes and relationships derived from the dataset, creating a powerful structure for exploration.\n",
"\n", "\n",
@ -173,10 +175,10 @@
] ]
}, },
{ {
"metadata": {},
"cell_type": "code", "cell_type": "code",
"outputs": [],
"execution_count": null, "execution_count": null,
"metadata": {},
"outputs": [],
"source": [ "source": [
"search_results = await cogneeRAG.search(\"Tell me who are the people mentioned?\")\n", "search_results = await cogneeRAG.search(\"Tell me who are the people mentioned?\")\n",
"\n", "\n",
@ -186,15 +188,15 @@
] ]
}, },
{ {
"metadata": {},
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {},
"source": "### Answer prompt based on RAG approach:" "source": "### Answer prompt based on RAG approach:"
}, },
{ {
"metadata": {},
"cell_type": "code", "cell_type": "code",
"outputs": [],
"execution_count": null, "execution_count": null,
"metadata": {},
"outputs": [],
"source": [ "source": [
"search_results = await cogneeRAG.rag_search(\"Tell me who are the people mentioned?\")\n", "search_results = await cogneeRAG.rag_search(\"Tell me who are the people mentioned?\")\n",
"\n", "\n",
@ -204,13 +206,13 @@
] ]
}, },
{ {
"metadata": {},
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {},
"source": "In conclusion, the results demonstrate a significant advantage of the knowledge graph-based approach (Graphrag) over the RAG approach. Graphrag successfully identified all the mentioned individuals across multiple documents, showcasing its ability to aggregate and infer information from a global context. In contrast, the RAG approach was limited to identifying individuals within a single document due to its chunking-based processing constraints. This highlights Graphrag's superior capability in comprehensively resolving queries that span across a broader corpus of interconnected data." "source": "In conclusion, the results demonstrate a significant advantage of the knowledge graph-based approach (Graphrag) over the RAG approach. Graphrag successfully identified all the mentioned individuals across multiple documents, showcasing its ability to aggregate and infer information from a global context. In contrast, the RAG approach was limited to identifying individuals within a single document due to its chunking-based processing constraints. This highlights Graphrag's superior capability in comprehensively resolving queries that span across a broader corpus of interconnected data."
}, },
{ {
"metadata": {},
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {},
"source": [ "source": [
"### 7. Finding Related Nodes\n", "### 7. Finding Related Nodes\n",
"\n", "\n",
@ -218,10 +220,10 @@
] ]
}, },
{ {
"metadata": {},
"cell_type": "code", "cell_type": "code",
"outputs": [],
"execution_count": null, "execution_count": null,
"metadata": {},
"outputs": [],
"source": [ "source": [
"related_nodes = await cogneeRAG.get_related_nodes(\"person\")\n", "related_nodes = await cogneeRAG.get_related_nodes(\"person\")\n",
"\n", "\n",
@ -231,8 +233,8 @@
] ]
}, },
{ {
"metadata": {},
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {},
"source": [ "source": [
"## Why Choose Cognee and LlamaIndex?\n", "## Why Choose Cognee and LlamaIndex?\n",
"\n", "\n",