Merge branch 'dev' into feature/cog-186-run-cognee-on-windows

This commit is contained in:
hajdul88 2025-01-17 09:06:00 +01:00 committed by GitHub
commit 6e691885e6
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
41 changed files with 1229 additions and 794 deletions

7
.github/pull_request_template.md vendored Normal file
View file

@ -0,0 +1,7 @@
<!-- .github/pull_request_template.md -->
## Description
<!-- Provide a clear description of the changes in this PR -->
## DCO Affirmation
I affirm that all code in every commit of this pull request conforms to the terms of the Topoteretes Developer Certificate of Origin

53
.github/workflows/approve_dco.yaml vendored Normal file
View file

@ -0,0 +1,53 @@
name: DCO Check
on:
pull_request:
types: [opened, edited, reopened, synchronize, ready_for_review]
jobs:
check-dco:
runs-on: ubuntu-latest
steps:
- name: Validate Developer Certificate of Origin statement
uses: actions/github-script@v6
with:
# If using the built-in GITHUB_TOKEN, ensure it has 'read:org' permission.
# In GitHub Enterprise or private orgs, you might need a PAT (personal access token) with read:org scope.
github-token: ${{ secrets.GITHUB_TOKEN }}
script: |
const orgName = 'YOUR_ORGANIZATION_NAME'; // Replace with your org
const prUser = context.payload.pull_request.user.login;
const prBody = context.payload.pull_request.body || '';
// Exact text you require in the PR body
const requiredStatement = "I affirm that all code in every commit of this pull request conforms to the terms of the Topoteretes Developer Certificate of Origin";
// 1. Check if user is in the org
let isOrgMember = false;
try {
// Attempt to get membership info
const membership = await github.rest.orgs.getMembershipForUser({
org: orgName,
username: prUser,
});
// If we get here without an error, user is in the org
isOrgMember = true;
console.log(`${prUser} is a member of ${orgName}. Skipping DCO check.`);
} catch (error) {
// If we get a 404, user is NOT an org member
if (error.status === 404) {
console.log(`${prUser} is NOT a member of ${orgName}. Enforcing DCO check.`);
} else {
// Some other error—fail the workflow or handle accordingly
core.setFailed(`Error checking organization membership: ${error.message}`);
}
}
// 2. If user is not in the org, enforce the DCO statement
if (!isOrgMember) {
if (!prBody.includes(requiredStatement)) {
core.setFailed(
`DCO check failed. The PR body must include the following statement:\n\n${requiredStatement}`
);
}
}

View file

@ -1,8 +1,9 @@
name: build | Build and Push Docker Image to DockerHub
name: build | Build and Push Docker Image to dockerhub
on:
push:
branches:
- dev
- main
jobs:
@ -10,42 +11,38 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Checkout repository
uses: actions/checkout@v4
- name: Checkout repository
uses: actions/checkout@v4
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Log in to Docker Hub
uses: docker/login-action@v3
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_PASSWORD }}
- name: Log in to Docker Hub
uses: docker/login-action@v3
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_PASSWORD }}
- name: Extract Git information
id: git-info
run: |
echo "BRANCH_NAME=${GITHUB_REF_NAME}" >> "$GITHUB_ENV"
echo "COMMIT_SHA=${GITHUB_SHA::7}" >> "$GITHUB_ENV"
- name: Extract metadata
id: meta
uses: docker/metadata-action@v5
with:
images: cognee/cognee
tags: |
type=ref,event=branch
type=sha,prefix={{branch}}-
type=raw,value=latest,enable={{is_default_branch}}
- name: Build and Push Docker Image
run: |
IMAGE_NAME=cognee/cognee
TAG_VERSION="${BRANCH_NAME}-${COMMIT_SHA}"
- name: Build and push
uses: docker/build-push-action@v5
with:
context: .
platforms: linux/amd64,linux/arm64
push: true
tags: ${{ steps.meta.outputs.tags }}
labels: ${{ steps.meta.outputs.labels }}
cache-from: type=registry,ref=cognee/cognee:buildcache
cache-to: type=registry,ref=cognee/cognee:buildcache,mode=max
echo "Building image: ${IMAGE_NAME}:${TAG_VERSION}"
docker buildx build \
--platform linux/amd64,linux/arm64 \
--push \
--tag "${IMAGE_NAME}:${TAG_VERSION}" \
--tag "${IMAGE_NAME}:latest" \
.
- name: Verify pushed Docker images
run: |
# Verify both platform variants
for PLATFORM in "linux/amd64" "linux/arm64"; do
echo "Verifying image for $PLATFORM..."
docker buildx imagetools inspect "${IMAGE_NAME}:${TAG_VERSION}" --format "{{.Manifest.$PLATFORM.Digest}}"
done
echo "Successfully verified images in Docker Hub"
- name: Image digest
run: echo ${{ steps.build.outputs.digest }}

View file

@ -42,6 +42,10 @@ jobs:
- name: Install dependencies
run: poetry install --no-interaction -E docs
- name: Download NLTK tokenizer data
run: |
poetry run python -m nltk.downloader punkt_tab averaged_perceptron_tagger_eng
- name: Run unit tests
run: poetry run pytest cognee/tests/unit/

View file

@ -44,6 +44,11 @@ jobs:
- name: Install dependencies
run: poetry install --no-interaction -E docs
- name: Download NLTK tokenizer data
run: |
poetry run python -m nltk.downloader punkt_tab averaged_perceptron_tagger_eng
- name: Run unit tests
run: poetry run pytest cognee/tests/unit/

View file

@ -43,6 +43,9 @@ jobs:
- name: Install dependencies
run: poetry install --no-interaction -E docs
- name: Download NLTK tokenizer data
run: |
poetry run python -m nltk.downloader punkt_tab averaged_perceptron_tagger_eng
- name: Run unit tests
run: poetry run pytest cognee/tests/unit/

View file

@ -79,6 +79,9 @@ $ git config alias.cos "commit -s"
Will allow you to write `git cos`, which automatically signs off your commit. By signing a commit you agree to the DCO and accept that you will be banned from the topoteretes GitHub organisation and Discord server if you violate it.
When a commit is ready to be merged, please use the following template to agree to our Developer Certificate of Origin:
'I affirm that all code in every commit of this pull request conforms to the terms of the Topoteretes Developer Certificate of Origin'
We consider the following to be violations of the DCO:
Signing the DCO with a fake name or pseudonym; if you are registered on GitHub or another platform under a fake name, you will not be able to contribute to topoteretes until you update your name;

View file

@ -6,7 +6,7 @@
### Installing Manually
A MCP server project
=======
1. Clone the [cognee](www.github.com/topoteretes/cognee) repo
1. Clone the [cognee](https://github.com/topoteretes/cognee) repo
@ -37,7 +37,15 @@ source .venv/bin/activate
4. Add the new server to your Claude config:
The file should be located here: ~/Library/Application\ Support/Claude/
```
cd ~/Library/Application\ Support/Claude/
```
You need to create claude_desktop_config.json in this folder if it doesn't exist.
Make sure to add your paths and LLM API key to the file below.
Use your editor of choice, for example Nano:
```
nano claude_desktop_config.json
```
```
@ -83,3 +91,17 @@ npx -y @smithery/cli install cognee --client claude
Define the cognify tool in server.py
Restart your Claude desktop.
To use the debugger, run:
```bash
npx @modelcontextprotocol/inspector uv --directory /Users/name/folder run cognee
```
To apply new changes during development:
1. Run `poetry lock` in the cognee folder
2. Run `uv sync --dev --all-extras --reinstall`
3. Run `npx @modelcontextprotocol/inspector uv --directory /Users/vasilije/cognee/cognee-mcp run cognee`

View file

@ -3,6 +3,8 @@ import os
import asyncio
from contextlib import redirect_stderr, redirect_stdout
from sqlalchemy.testing.plugin.plugin_base import logging
import cognee
import mcp.server.stdio
import mcp.types as types
@ -10,6 +12,8 @@ from cognee.api.v1.search import SearchType
from cognee.shared.data_models import KnowledgeGraph
from mcp.server import NotificationOptions, Server
from mcp.server.models import InitializationOptions
from PIL import Image
server = Server("cognee-mcp")
@ -87,9 +91,46 @@ async def handle_list_tools() -> list[types.Tool]:
},
},
),
types.Tool(
name="visualize",
description="Visualize the knowledge graph.",
inputSchema={
"type": "object",
"properties": {
"query": {"type": "string"},
},
},
),
]
def get_freshest_png(directory: str) -> Image.Image:
if not os.path.exists(directory):
raise FileNotFoundError(f"Directory {directory} does not exist")
# List all files in 'directory' that end with .png
files = [f for f in os.listdir(directory) if f.endswith(".png")]
if not files:
raise FileNotFoundError("No PNG files found in the given directory.")
# Sort by integer value of the filename (minus the '.png')
# Example filename: 1673185134.png -> integer 1673185134
try:
files_sorted = sorted(files, key=lambda x: int(x.replace(".png", "")))
except ValueError as e:
raise ValueError("Invalid PNG filename format. Expected timestamp format.") from e
# The "freshest" file has the largest timestamp
freshest_filename = files_sorted[-1]
freshest_path = os.path.join(directory, freshest_filename)
# Open the image with PIL and return the PIL Image object
try:
return Image.open(freshest_path)
except (IOError, OSError) as e:
raise IOError(f"Failed to open PNG file {freshest_path}") from e
@server.call_tool()
async def handle_call_tool(
name: str, arguments: dict | None
@ -154,6 +195,20 @@ async def handle_call_tool(
text="Pruned",
)
]
elif name == "visualize":
with open(os.devnull, "w") as fnull:
with redirect_stdout(fnull), redirect_stderr(fnull):
try:
results = await cognee.visualize_graph()
return [
types.TextContent(
type="text",
text=results,
)
]
except (FileNotFoundError, IOError, ValueError) as e:
raise ValueError(f"Failed to create visualization: {str(e)}")
else:
raise ValueError(f"Unknown tool: {name}")

View file

@ -4,6 +4,7 @@ version = "0.1.0"
description = "A MCP server project"
readme = "README.md"
requires-python = ">=3.10"
dependencies = [
"mcp>=1.1.1",
"openai==1.59.4",
@ -51,7 +52,7 @@ dependencies = [
"pydantic-settings>=2.2.1,<3.0.0",
"anthropic>=0.26.1,<1.0.0",
"sentry-sdk[fastapi]>=2.9.0,<3.0.0",
"fastapi-users[sqlalchemy]", # Optional
"fastapi-users[sqlalchemy]>=14.0.0", # Optional
"alembic>=1.13.3,<2.0.0",
"asyncpg==0.30.0", # Optional
"pgvector>=0.3.5,<0.4.0", # Optional
@ -91,4 +92,4 @@ dev = [
]
[project.scripts]
cognee = "cognee_mcp:main"
cognee = "cognee_mcp:main"

18
cognee-mcp/uv.lock generated
View file

@ -561,7 +561,7 @@ name = "click"
version = "8.1.7"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "colorama", marker = "platform_system == 'Windows'" },
{ name = "colorama", marker = "sys_platform == 'win32'" },
]
sdist = { url = "https://files.pythonhosted.org/packages/96/d3/f04c7bfcf5c1862a2a5b845c6b2b360488cf47af55dfa79c98f6a6bf98b5/click-8.1.7.tar.gz", hash = "sha256:ca9853ad459e787e2192211578cc907e7594e294c7ccc834310722b41b9ca6de", size = 336121 }
wheels = [
@ -570,7 +570,7 @@ wheels = [
[[package]]
name = "cognee"
version = "0.1.21"
version = "0.1.22"
source = { directory = "../" }
dependencies = [
{ name = "aiofiles" },
@ -633,7 +633,7 @@ requires-dist = [
{ name = "dlt", extras = ["sqlalchemy"], specifier = ">=1.4.1,<2.0.0" },
{ name = "falkordb", marker = "extra == 'falkordb'", specifier = "==1.0.9" },
{ name = "fastapi", specifier = ">=0.109.2,<0.116.0" },
{ name = "fastapi-users", extras = ["sqlalchemy"] },
{ name = "fastapi-users", extras = ["sqlalchemy"], specifier = "==14.0.0" },
{ name = "filetype", specifier = ">=1.2.0,<2.0.0" },
{ name = "graphistry", specifier = ">=0.33.5,<0.34.0" },
{ name = "groq", marker = "extra == 'groq'", specifier = "==0.8.0" },
@ -647,12 +647,12 @@ requires-dist = [
{ name = "langfuse", specifier = ">=2.32.0,<3.0.0" },
{ name = "langsmith", marker = "extra == 'langchain'", specifier = "==0.2.3" },
{ name = "litellm", specifier = "==1.57.2" },
{ name = "llama-index-core", marker = "extra == 'llama-index'", specifier = ">=0.12.10.post1,<0.13.0" },
{ name = "llama-index-core", marker = "extra == 'llama-index'", specifier = ">=0.12.11,<0.13.0" },
{ name = "matplotlib", specifier = ">=3.8.3,<4.0.0" },
{ name = "neo4j", marker = "extra == 'neo4j'", specifier = ">=5.20.0,<6.0.0" },
{ name = "nest-asyncio", specifier = "==1.6.0" },
{ name = "networkx", specifier = ">=3.2.1,<4.0.0" },
{ name = "nltk", specifier = ">=3.8.1,<4.0.0" },
{ name = "nltk", specifier = "==3.9.1" },
{ name = "numpy", specifier = "==1.26.4" },
{ name = "openai", specifier = "==1.59.4" },
{ name = "pandas", specifier = "==2.2.3" },
@ -674,7 +674,7 @@ requires-dist = [
{ name = "tiktoken", specifier = "==0.7.0" },
{ name = "transformers", specifier = ">=4.46.3,<5.0.0" },
{ name = "typing-extensions", specifier = "==4.12.2" },
{ name = "unstructured", extras = ["csv", "doc", "docx", "epub", "md", "odt", "org", "ppt", "pptx", "rst", "rtf", "tsv", "xlsx"], marker = "extra == 'docs'", specifier = ">=0.16.10,<0.17.0" },
{ name = "unstructured", extras = ["csv", "doc", "docx", "epub", "md", "odt", "org", "ppt", "pptx", "rst", "rtf", "tsv", "xlsx"], marker = "extra == 'docs'", specifier = ">=0.16.13,<0.17.0" },
{ name = "uvicorn", specifier = "==0.22.0" },
{ name = "weaviate-client", marker = "extra == 'weaviate'", specifier = "==4.9.6" },
]
@ -777,7 +777,7 @@ requires-dist = [
{ name = "dlt", extras = ["sqlalchemy"], specifier = ">=1.4.1,<2.0.0" },
{ name = "falkordb", specifier = "==1.0.9" },
{ name = "fastapi", specifier = ">=0.109.2,<0.110.0" },
{ name = "fastapi-users", extras = ["sqlalchemy"] },
{ name = "fastapi-users", extras = ["sqlalchemy"], specifier = ">=14.0.0" },
{ name = "filetype", specifier = ">=1.2.0,<2.0.0" },
{ name = "gitpython", specifier = ">=3.1.43,<4.0.0" },
{ name = "graphistry", specifier = ">=0.33.5,<0.34.0" },
@ -3359,7 +3359,7 @@ name = "portalocker"
version = "2.10.1"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "pywin32", marker = "platform_system == 'Windows'" },
{ name = "pywin32", marker = "sys_platform == 'win32'" },
]
sdist = { url = "https://files.pythonhosted.org/packages/ed/d3/c6c64067759e87af98cc668c1cc75171347d0f1577fab7ca3749134e3cd4/portalocker-2.10.1.tar.gz", hash = "sha256:ef1bf844e878ab08aee7e40184156e1151f228f103aa5c6bd0724cc330960f8f", size = 40891 }
wheels = [
@ -4954,7 +4954,7 @@ name = "tqdm"
version = "4.67.1"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "colorama", marker = "platform_system == 'Windows'" },
{ name = "colorama", marker = "sys_platform == 'win32'" },
]
sdist = { url = "https://files.pythonhosted.org/packages/a8/4b/29b4ef32e036bb34e4ab51796dd745cdba7ed47ad142a9f4a1eb8e0c744d/tqdm-4.67.1.tar.gz", hash = "sha256:f8aef9c52c08c13a65f30ea34f4e5aac3fd1a34959879d7e59e63027286627f2", size = 169737 }
wheels = [

View file

@ -4,7 +4,7 @@ from .api.v1.config.config import config
from .api.v1.datasets.datasets import datasets
from .api.v1.prune import prune
from .api.v1.search import SearchType, get_search_history, search
from .api.v1.visualize import visualize
from .api.v1.visualize import visualize_graph
from .shared.utils import create_cognee_style_network_with_logo
# Pipelines

View file

@ -10,5 +10,6 @@ async def visualize_graph(label: str = "name"):
logging.info(graph_data)
graph = await create_cognee_style_network_with_logo(graph_data, label=label)
logging.info("The HTML file has been stored on your home directory! Navigate there with cd ~")
return graph
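A minimal usage sketch of the public API above, assuming cognee is installed and data has already been cognified:

```python
# Hedged example: visualize_graph returns the rendered HTML and, per the log message
# above, also writes the HTML file to your home directory.
import asyncio
import cognee

async def main():
    html = await cognee.visualize_graph()
    print(f"Rendered {len(html)} characters of HTML")

asyncio.run(main())
```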

View file

@ -62,10 +62,12 @@ class Neo4jAdapter(GraphDBInterface):
async def add_node(self, node: DataPoint):
serialized_properties = self.serialize_properties(node.model_dump())
query = dedent("""MERGE (node {id: $node_id})
query = dedent(
"""MERGE (node {id: $node_id})
ON CREATE SET node += $properties, node.updated_at = timestamp()
ON MATCH SET node += $properties, node.updated_at = timestamp()
RETURN ID(node) AS internal_id, node.id AS nodeId""")
RETURN ID(node) AS internal_id, node.id AS nodeId"""
)
params = {
"node_id": str(node.id),
@ -182,13 +184,15 @@ class Neo4jAdapter(GraphDBInterface):
):
serialized_properties = self.serialize_properties(edge_properties)
query = dedent("""MATCH (from_node {id: $from_node}),
query = dedent(
"""MATCH (from_node {id: $from_node}),
(to_node {id: $to_node})
MERGE (from_node)-[r]->(to_node)
ON CREATE SET r += $properties, r.updated_at = timestamp(), r.type = $relationship_name
ON MATCH SET r += $properties, r.updated_at = timestamp()
RETURN r
""")
"""
)
params = {
"from_node": str(from_node),

View file

@ -88,23 +88,27 @@ class FalkorDBAdapter(VectorDBInterface, GraphDBInterface):
}
)
return dedent(f"""
return dedent(
f"""
MERGE (node:{node_label} {{id: '{str(data_point.id)}'}})
ON CREATE SET node += ({{{node_properties}}}), node.updated_at = timestamp()
ON MATCH SET node += ({{{node_properties}}}), node.updated_at = timestamp()
""").strip()
"""
).strip()
async def create_edge_query(self, edge: tuple[str, str, str, dict]) -> str:
properties = await self.stringify_properties(edge[3])
properties = f"{{{properties}}}"
return dedent(f"""
return dedent(
f"""
MERGE (source {{id:'{edge[0]}'}})
MERGE (target {{id: '{edge[1]}'}})
MERGE (source)-[edge:{edge[2]} {properties}]->(target)
ON MATCH SET edge.updated_at = timestamp()
ON CREATE SET edge.updated_at = timestamp()
""").strip()
"""
).strip()
async def create_collection(self, collection_name: str):
pass
@ -195,12 +199,14 @@ class FalkorDBAdapter(VectorDBInterface, GraphDBInterface):
self.query(query)
async def has_edges(self, edges):
query = dedent("""
query = dedent(
"""
UNWIND $edges AS edge
MATCH (a)-[r]->(b)
WHERE id(a) = edge.from_node AND id(b) = edge.to_node AND type(r) = edge.relationship_name
RETURN edge.from_node AS from_node, edge.to_node AS to_node, edge.relationship_name AS relationship_name, count(r) > 0 AS edge_exists
""").strip()
"""
).strip()
params = {
"edges": [
@ -279,14 +285,16 @@ class FalkorDBAdapter(VectorDBInterface, GraphDBInterface):
[label, attribute_name] = collection_name.split(".")
query = dedent(f"""
query = dedent(
f"""
CALL db.idx.vector.queryNodes(
'{label}',
'{attribute_name}',
{limit},
vecf32({query_vector})
) YIELD node, score
""").strip()
"""
).strip()
result = self.query(query)

View file

@ -93,10 +93,12 @@ class SQLAlchemyAdapter:
if self.engine.dialect.name == "postgresql":
async with self.engine.begin() as connection:
result = await connection.execute(
text("""
text(
"""
SELECT schema_name FROM information_schema.schemata
WHERE schema_name NOT IN ('pg_catalog', 'pg_toast', 'information_schema');
""")
"""
)
)
return [schema[0] for schema in result.fetchall()]
return []

View file

@ -1,24 +1,34 @@
from datetime import datetime, timezone
from typing import Optional
from typing import Optional, Any, Dict
from uuid import UUID, uuid4
from pydantic import BaseModel, Field
from typing_extensions import TypedDict
import pickle
# Define metadata type
class MetaData(TypedDict):
index_fields: list[str]
# Updated DataPoint model with versioning and new fields
class DataPoint(BaseModel):
__tablename__ = "data_point"
id: UUID = Field(default_factory=uuid4)
updated_at: Optional[datetime] = datetime.now(timezone.utc)
created_at: int = Field(
default_factory=lambda: int(datetime.now(timezone.utc).timestamp() * 1000)
)
updated_at: int = Field(
default_factory=lambda: int(datetime.now(timezone.utc).timestamp() * 1000)
)
version: int = 1 # Default version
topological_rank: Optional[int] = 0
_metadata: Optional[MetaData] = {"index_fields": [], "type": "DataPoint"}
# class Config:
# underscore_attrs_are_private = True
# Override the Pydantic configuration
class Config:
underscore_attrs_are_private = True
@classmethod
def get_embeddable_data(self, data_point):
@ -31,11 +41,11 @@ class DataPoint(BaseModel):
if isinstance(attribute, str):
return attribute.strip()
else:
return attribute
return attribute
@classmethod
def get_embeddable_properties(self, data_point):
"""Retrieve all embeddable properties."""
if data_point._metadata and len(data_point._metadata["index_fields"]) > 0:
return [
getattr(data_point, field, None) for field in data_point._metadata["index_fields"]
@ -45,4 +55,40 @@ class DataPoint(BaseModel):
@classmethod
def get_embeddable_property_names(self, data_point):
"""Retrieve names of embeddable properties."""
return data_point._metadata["index_fields"] or []
def update_version(self):
"""Update the version and updated_at timestamp."""
self.version += 1
self.updated_at = int(datetime.now(timezone.utc).timestamp() * 1000)
# JSON Serialization
def to_json(self) -> str:
"""Serialize the instance to a JSON string."""
return self.json()
@classmethod
def from_json(self, json_str: str):
"""Deserialize the instance from a JSON string."""
return self.model_validate_json(json_str)
# Pickle Serialization
def to_pickle(self) -> bytes:
"""Serialize the instance to pickle-compatible bytes."""
return pickle.dumps(self.dict())
@classmethod
def from_pickle(self, pickled_data: bytes):
"""Deserialize the instance from pickled bytes."""
data = pickle.loads(pickled_data)
return self(**data)
def to_dict(self, **kwargs) -> Dict[str, Any]:
"""Serialize model to a dictionary."""
return self.model_dump(**kwargs)
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "DataPoint":
"""Deserialize model from a dictionary."""
return cls.model_validate(data)
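A hedged round-trip sketch for the serialization helpers above, assuming the DataPoint model shown here is in scope; field values are illustrative:

```python
# Dict round-trip
point = DataPoint(topological_rank=1)
restored = DataPoint.from_dict(point.to_dict())
assert restored.id == point.id

# Pickle round-trip
again = DataPoint.from_pickle(point.to_pickle())
assert again.version == point.version

# Versioning helper: bumps version and refreshes updated_at
point.update_version()
assert point.version == 2
```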

View file

@ -11,9 +11,7 @@ import networkx as nx
import pandas as pd
import matplotlib.pyplot as plt
import tiktoken
import nltk
import base64
import time
import logging
import sys
@ -23,13 +21,40 @@ from cognee.infrastructure.databases.graph import get_graph_engine
from uuid import uuid4
import pathlib
import nltk
from cognee.shared.exceptions import IngestionError
# Analytics Proxy Url, currently hosted by Vercel
proxy_url = "https://test.prometh.ai"
def get_entities(tagged_tokens):
nltk.download("maxent_ne_chunker", quiet=True)
from nltk.chunk import ne_chunk
return ne_chunk(tagged_tokens)
def extract_pos_tags(sentence):
"""Extract Part-of-Speech (POS) tags for words in a sentence."""
# Ensure that the necessary NLTK resources are downloaded
nltk.download("words", quiet=True)
nltk.download("punkt", quiet=True)
nltk.download("averaged_perceptron_tagger", quiet=True)
from nltk.tag import pos_tag
from nltk.tokenize import word_tokenize
# Tokenize the sentence into words
tokens = word_tokenize(sentence)
# Tag each word with its corresponding POS tag
pos_tags = pos_tag(tokens)
return pos_tags
def get_anonymous_id():
"""Creates or reads a anonymous user id"""
home_dir = str(pathlib.Path(pathlib.Path(__file__).parent.parent.parent.resolve()))
@ -243,33 +268,6 @@ async def render_graph(
# return df.replace([np.inf, -np.inf, np.nan], None)
def get_entities(tagged_tokens):
nltk.download("maxent_ne_chunker", quiet=True)
from nltk.chunk import ne_chunk
return ne_chunk(tagged_tokens)
def extract_pos_tags(sentence):
"""Extract Part-of-Speech (POS) tags for words in a sentence."""
# Ensure that the necessary NLTK resources are downloaded
nltk.download("words", quiet=True)
nltk.download("punkt", quiet=True)
nltk.download("averaged_perceptron_tagger", quiet=True)
from nltk.tag import pos_tag
from nltk.tokenize import word_tokenize
# Tokenize the sentence into words
tokens = word_tokenize(sentence)
# Tag each word with its corresponding POS tag
pos_tags = pos_tag(tokens)
return pos_tags
logging.basicConfig(level=logging.INFO)
@ -396,6 +394,7 @@ async def create_cognee_style_network_with_logo(
from bokeh.embed import file_html
from bokeh.resources import CDN
from bokeh.io import export_png
logging.info("Converting graph to serializable format...")
G = await convert_to_serializable_graph(G)
@ -445,13 +444,14 @@ async def create_cognee_style_network_with_logo(
logging.info(f"Saving visualization to {output_filename}...")
html_content = file_html(p, CDN, title)
with open(output_filename, "w") as f:
home_dir = os.path.expanduser("~")
# Construct the final output file path
output_filepath = os.path.join(home_dir, output_filename)
with open(output_filepath, "w") as f:
f.write(html_content)
logging.info("Visualization complete.")
if bokeh_object:
return p
return html_content
@ -512,7 +512,7 @@ if __name__ == "__main__":
G,
output_filename="example_network.html",
title="Example Cognee Network",
node_attribute="group", # Attribute to use for coloring nodes
label="group", # Attribute to use for coloring nodes
layout_func=nx.spring_layout, # Layout function
layout_scale=3.0, # Scale for the layout
logo_alpha=0.2,

View file

@ -19,9 +19,11 @@ async def index_and_transform_graphiti_nodes_and_edges():
raise RuntimeError("Initialization error") from e
await graph_engine.query("""MATCH (n) SET n.id = n.uuid RETURN n""")
await graph_engine.query("""MATCH (source)-[r]->(target) SET r.source_node_id = source.id,
await graph_engine.query(
"""MATCH (source)-[r]->(target) SET r.source_node_id = source.id,
r.target_node_id = target.id,
r.relationship_name = type(r) RETURN r""")
r.relationship_name = type(r) RETURN r"""
)
await graph_engine.query("""MATCH (n) SET n.text = COALESCE(n.summary, n.content) RETURN n""")
nodes_data, edges_data = await graph_engine.get_model_independent_graph_data()

View file

@ -36,12 +36,12 @@ def test_AudioDocument():
for ground_truth, paragraph_data in zip(
GROUND_TRUTH, document.read(chunk_size=64, chunker="text_chunker")
):
assert ground_truth["word_count"] == paragraph_data.word_count, (
f'{ground_truth["word_count"] = } != {paragraph_data.word_count = }'
)
assert ground_truth["len_text"] == len(paragraph_data.text), (
f'{ground_truth["len_text"] = } != {len(paragraph_data.text) = }'
)
assert ground_truth["cut_type"] == paragraph_data.cut_type, (
f'{ground_truth["cut_type"] = } != {paragraph_data.cut_type = }'
)
assert (
ground_truth["word_count"] == paragraph_data.word_count
), f'{ground_truth["word_count"] = } != {paragraph_data.word_count = }'
assert ground_truth["len_text"] == len(
paragraph_data.text
), f'{ground_truth["len_text"] = } != {len(paragraph_data.text) = }'
assert (
ground_truth["cut_type"] == paragraph_data.cut_type
), f'{ground_truth["cut_type"] = } != {paragraph_data.cut_type = }'

View file

@ -25,12 +25,12 @@ def test_ImageDocument():
for ground_truth, paragraph_data in zip(
GROUND_TRUTH, document.read(chunk_size=64, chunker="text_chunker")
):
assert ground_truth["word_count"] == paragraph_data.word_count, (
f'{ground_truth["word_count"] = } != {paragraph_data.word_count = }'
)
assert ground_truth["len_text"] == len(paragraph_data.text), (
f'{ground_truth["len_text"] = } != {len(paragraph_data.text) = }'
)
assert ground_truth["cut_type"] == paragraph_data.cut_type, (
f'{ground_truth["cut_type"] = } != {paragraph_data.cut_type = }'
)
assert (
ground_truth["word_count"] == paragraph_data.word_count
), f'{ground_truth["word_count"] = } != {paragraph_data.word_count = }'
assert ground_truth["len_text"] == len(
paragraph_data.text
), f'{ground_truth["len_text"] = } != {len(paragraph_data.text) = }'
assert (
ground_truth["cut_type"] == paragraph_data.cut_type
), f'{ground_truth["cut_type"] = } != {paragraph_data.cut_type = }'

View file

@ -27,12 +27,12 @@ def test_PdfDocument():
for ground_truth, paragraph_data in zip(
GROUND_TRUTH, document.read(chunk_size=1024, chunker="text_chunker")
):
assert ground_truth["word_count"] == paragraph_data.word_count, (
f'{ground_truth["word_count"] = } != {paragraph_data.word_count = }'
)
assert ground_truth["len_text"] == len(paragraph_data.text), (
f'{ground_truth["len_text"] = } != {len(paragraph_data.text) = }'
)
assert ground_truth["cut_type"] == paragraph_data.cut_type, (
f'{ground_truth["cut_type"] = } != {paragraph_data.cut_type = }'
)
assert (
ground_truth["word_count"] == paragraph_data.word_count
), f'{ground_truth["word_count"] = } != {paragraph_data.word_count = }'
assert ground_truth["len_text"] == len(
paragraph_data.text
), f'{ground_truth["len_text"] = } != {len(paragraph_data.text) = }'
assert (
ground_truth["cut_type"] == paragraph_data.cut_type
), f'{ground_truth["cut_type"] = } != {paragraph_data.cut_type = }'

View file

@ -39,12 +39,12 @@ def test_TextDocument(input_file, chunk_size):
for ground_truth, paragraph_data in zip(
GROUND_TRUTH[input_file], document.read(chunk_size=chunk_size, chunker="text_chunker")
):
assert ground_truth["word_count"] == paragraph_data.word_count, (
f'{ground_truth["word_count"] = } != {paragraph_data.word_count = }'
)
assert ground_truth["len_text"] == len(paragraph_data.text), (
f'{ground_truth["len_text"] = } != {len(paragraph_data.text) = }'
)
assert ground_truth["cut_type"] == paragraph_data.cut_type, (
f'{ground_truth["cut_type"] = } != {paragraph_data.cut_type = }'
)
assert (
ground_truth["word_count"] == paragraph_data.word_count
), f'{ground_truth["word_count"] = } != {paragraph_data.word_count = }'
assert ground_truth["len_text"] == len(
paragraph_data.text
), f'{ground_truth["len_text"] = } != {len(paragraph_data.text) = }'
assert (
ground_truth["cut_type"] == paragraph_data.cut_type
), f'{ground_truth["cut_type"] = } != {paragraph_data.cut_type = }'

View file

@ -71,32 +71,32 @@ def test_UnstructuredDocument():
for paragraph_data in pptx_document.read(chunk_size=1024, chunker="text_chunker"):
assert 19 == paragraph_data.word_count, f" 19 != {paragraph_data.word_count = }"
assert 104 == len(paragraph_data.text), f" 104 != {len(paragraph_data.text) = }"
assert "sentence_cut" == paragraph_data.cut_type, (
f" sentence_cut != {paragraph_data.cut_type = }"
)
assert (
"sentence_cut" == paragraph_data.cut_type
), f" sentence_cut != {paragraph_data.cut_type = }"
# Test DOCX
for paragraph_data in docx_document.read(chunk_size=1024, chunker="text_chunker"):
assert 16 == paragraph_data.word_count, f" 16 != {paragraph_data.word_count = }"
assert 145 == len(paragraph_data.text), f" 145 != {len(paragraph_data.text) = }"
assert "sentence_end" == paragraph_data.cut_type, (
f" sentence_end != {paragraph_data.cut_type = }"
)
assert (
"sentence_end" == paragraph_data.cut_type
), f" sentence_end != {paragraph_data.cut_type = }"
# TEST CSV
for paragraph_data in csv_document.read(chunk_size=1024, chunker="text_chunker"):
assert 15 == paragraph_data.word_count, f" 15 != {paragraph_data.word_count = }"
assert "A A A A A A A A A,A A A A A A,A A" == paragraph_data.text, (
f"Read text doesn't match expected text: {paragraph_data.text}"
)
assert "sentence_cut" == paragraph_data.cut_type, (
f" sentence_cut != {paragraph_data.cut_type = }"
)
assert (
"A A A A A A A A A,A A A A A A,A A" == paragraph_data.text
), f"Read text doesn't match expected text: {paragraph_data.text}"
assert (
"sentence_cut" == paragraph_data.cut_type
), f" sentence_cut != {paragraph_data.cut_type = }"
# Test XLSX
for paragraph_data in xlsx_document.read(chunk_size=1024, chunker="text_chunker"):
assert 36 == paragraph_data.word_count, f" 36 != {paragraph_data.word_count = }"
assert 171 == len(paragraph_data.text), f" 171 != {len(paragraph_data.text) = }"
assert "sentence_cut" == paragraph_data.cut_type, (
f" sentence_cut != {paragraph_data.cut_type = }"
)
assert (
"sentence_cut" == paragraph_data.cut_type
), f" sentence_cut != {paragraph_data.cut_type = }"

View file

@ -30,9 +30,9 @@ async def test_deduplication():
result = await relational_engine.get_all_data_from_table("data")
assert len(result) == 1, "More than one data entity was found."
assert result[0]["name"] == "Natural_language_processing_copy", (
"Result name does not match expected value."
)
assert (
result[0]["name"] == "Natural_language_processing_copy"
), "Result name does not match expected value."
result = await relational_engine.get_all_data_from_table("datasets")
assert len(result) == 2, "Unexpected number of datasets found."
@ -61,9 +61,9 @@ async def test_deduplication():
result = await relational_engine.get_all_data_from_table("data")
assert len(result) == 1, "More than one data entity was found."
assert hashlib.md5(text.encode("utf-8")).hexdigest() in result[0]["name"], (
"Content hash is not a part of file name."
)
assert (
hashlib.md5(text.encode("utf-8")).hexdigest() in result[0]["name"]
), "Content hash is not a part of file name."
await cognee.prune.prune_data()
await cognee.prune.prune_system(metadata=True)

View file

@ -85,9 +85,9 @@ async def main():
from cognee.infrastructure.databases.relational import get_relational_engine
assert not os.path.exists(get_relational_engine().db_path), (
"SQLite relational database is not empty"
)
assert not os.path.exists(
get_relational_engine().db_path
), "SQLite relational database is not empty"
from cognee.infrastructure.databases.graph import get_graph_config

View file

@ -82,9 +82,9 @@ async def main():
from cognee.infrastructure.databases.relational import get_relational_engine
assert not os.path.exists(get_relational_engine().db_path), (
"SQLite relational database is not empty"
)
assert not os.path.exists(
get_relational_engine().db_path
), "SQLite relational database is not empty"
from cognee.infrastructure.databases.graph import get_graph_config

View file

@ -24,28 +24,28 @@ async def test_local_file_deletion(data_text, file_location):
data_hash = hashlib.md5(encoded_text).hexdigest()
# Get data entry from database based on hash contents
data = (await session.scalars(select(Data).where(Data.content_hash == data_hash))).one()
assert os.path.isfile(data.raw_data_location), (
f"Data location doesn't exist: {data.raw_data_location}"
)
assert os.path.isfile(
data.raw_data_location
), f"Data location doesn't exist: {data.raw_data_location}"
# Test deletion of data along with local files created by cognee
await engine.delete_data_entity(data.id)
assert not os.path.exists(data.raw_data_location), (
f"Data location still exists after deletion: {data.raw_data_location}"
)
assert not os.path.exists(
data.raw_data_location
), f"Data location still exists after deletion: {data.raw_data_location}"
async with engine.get_async_session() as session:
# Get data entry from database based on file path
data = (
await session.scalars(select(Data).where(Data.raw_data_location == file_location))
).one()
assert os.path.isfile(data.raw_data_location), (
f"Data location doesn't exist: {data.raw_data_location}"
)
assert os.path.isfile(
data.raw_data_location
), f"Data location doesn't exist: {data.raw_data_location}"
# Test local files not created by cognee won't get deleted
await engine.delete_data_entity(data.id)
assert os.path.exists(data.raw_data_location), (
f"Data location doesn't exists: {data.raw_data_location}"
)
assert os.path.exists(
data.raw_data_location
), f"Data location doesn't exists: {data.raw_data_location}"
async def test_getting_of_documents(dataset_name_1):
@ -54,16 +54,16 @@ async def test_getting_of_documents(dataset_name_1):
user = await get_default_user()
document_ids = await get_document_ids_for_user(user.id, [dataset_name_1])
assert len(document_ids) == 1, (
f"Number of expected documents doesn't match {len(document_ids)} != 1"
)
assert (
len(document_ids) == 1
), f"Number of expected documents doesn't match {len(document_ids)} != 1"
# Test getting of documents for search when no dataset is provided
user = await get_default_user()
document_ids = await get_document_ids_for_user(user.id)
assert len(document_ids) == 2, (
f"Number of expected documents doesn't match {len(document_ids)} != 2"
)
assert (
len(document_ids) == 2
), f"Number of expected documents doesn't match {len(document_ids)} != 2"
async def main():

View file

@ -17,9 +17,9 @@ batch_paragraphs_vals = [True, False]
def test_chunk_by_paragraph_isomorphism(input_text, paragraph_length, batch_paragraphs):
chunks = chunk_by_paragraph(input_text, paragraph_length, batch_paragraphs)
reconstructed_text = "".join([chunk["text"] for chunk in chunks])
assert reconstructed_text == input_text, (
f"texts are not identical: {len(input_text) = }, {len(reconstructed_text) = }"
)
assert (
reconstructed_text == input_text
), f"texts are not identical: {len(input_text) = }, {len(reconstructed_text) = }"
@pytest.mark.parametrize(
@ -36,9 +36,9 @@ def test_paragraph_chunk_length(input_text, paragraph_length, batch_paragraphs):
chunk_lengths = np.array([len(list(chunk_by_word(chunk["text"]))) for chunk in chunks])
larger_chunks = chunk_lengths[chunk_lengths > paragraph_length]
assert np.all(chunk_lengths <= paragraph_length), (
f"{paragraph_length = }: {larger_chunks} are too large"
)
assert np.all(
chunk_lengths <= paragraph_length
), f"{paragraph_length = }: {larger_chunks} are too large"
@pytest.mark.parametrize(
@ -50,6 +50,6 @@ def test_chunk_by_paragraph_chunk_numbering(input_text, paragraph_length, batch_
data=input_text, paragraph_length=paragraph_length, batch_paragraphs=batch_paragraphs
)
chunk_indices = np.array([chunk["chunk_index"] for chunk in chunks])
assert np.all(chunk_indices == np.arange(len(chunk_indices))), (
f"{chunk_indices = } are not monotonically increasing"
)
assert np.all(
chunk_indices == np.arange(len(chunk_indices))
), f"{chunk_indices = } are not monotonically increasing"

View file

@ -58,9 +58,9 @@ def run_chunking_test(test_text, expected_chunks):
for expected_chunks_item, chunk in zip(expected_chunks, chunks):
for key in ["text", "word_count", "cut_type"]:
assert chunk[key] == expected_chunks_item[key], (
f"{key = }: {chunk[key] = } != {expected_chunks_item[key] = }"
)
assert (
chunk[key] == expected_chunks_item[key]
), f"{key = }: {chunk[key] = } != {expected_chunks_item[key] = }"
def test_chunking_whole_text():

View file

@ -16,9 +16,9 @@ maximum_length_vals = [None, 8, 64]
def test_chunk_by_sentence_isomorphism(input_text, maximum_length):
chunks = chunk_by_sentence(input_text, maximum_length)
reconstructed_text = "".join([chunk[1] for chunk in chunks])
assert reconstructed_text == input_text, (
f"texts are not identical: {len(input_text) = }, {len(reconstructed_text) = }"
)
assert (
reconstructed_text == input_text
), f"texts are not identical: {len(input_text) = }, {len(reconstructed_text) = }"
@pytest.mark.parametrize(
@ -36,6 +36,6 @@ def test_paragraph_chunk_length(input_text, maximum_length):
chunk_lengths = np.array([len(list(chunk_by_word(chunk[1]))) for chunk in chunks])
larger_chunks = chunk_lengths[chunk_lengths > maximum_length]
assert np.all(chunk_lengths <= maximum_length), (
f"{maximum_length = }: {larger_chunks} are too large"
)
assert np.all(
chunk_lengths <= maximum_length
), f"{maximum_length = }: {larger_chunks} are too large"

View file

@ -17,9 +17,9 @@ from cognee.tests.unit.processing.chunks.test_input import INPUT_TEXTS
def test_chunk_by_word_isomorphism(input_text):
chunks = chunk_by_word(input_text)
reconstructed_text = "".join([chunk[0] for chunk in chunks])
assert reconstructed_text == input_text, (
f"texts are not identical: {len(input_text) = }, {len(reconstructed_text) = }"
)
assert (
reconstructed_text == input_text
), f"texts are not identical: {len(input_text) = }, {len(reconstructed_text) = }"
@pytest.mark.parametrize(

View file

@ -8,7 +8,7 @@ import logging
from cognee.infrastructure.llm.get_llm_client import get_llm_client
from cognee.infrastructure.llm.prompts import read_query_prompt, render_prompt
from evals.qa_dataset_utils import load_qa_dataset
from evals.qa_metrics_utils import get_metric
from evals.qa_metrics_utils import get_metrics
from evals.qa_context_provider_utils import qa_context_providers
logger = logging.getLogger(__name__)
@ -34,7 +34,7 @@ async def answer_qa_instance(instance, context_provider):
return answer_prediction
async def deepeval_answers(instances, answers, eval_metric):
async def deepeval_answers(instances, answers, eval_metrics):
test_cases = []
for instance, answer in zip(instances, answers):
@ -44,37 +44,54 @@ async def deepeval_answers(instances, answers, eval_metric):
test_cases.append(test_case)
eval_set = EvaluationDataset(test_cases)
eval_results = eval_set.evaluate([eval_metric])
eval_results = eval_set.evaluate(eval_metrics)
return eval_results
async def deepeval_on_instances(instances, context_provider, eval_metric):
async def deepeval_on_instances(instances, context_provider, eval_metrics):
answers = []
for instance in tqdm(instances, desc="Getting answers"):
answer = await answer_qa_instance(instance, context_provider)
answers.append(answer)
eval_results = await deepeval_answers(instances, answers, eval_metric)
avg_score = statistics.mean(
[result.metrics_data[0].score for result in eval_results.test_results]
)
eval_results = await deepeval_answers(instances, answers, eval_metrics)
score_lists_dict = {}
for instance_result in eval_results.test_results:
for metric_result in instance_result.metrics_data:
if metric_result.name not in score_lists_dict:
score_lists_dict[metric_result.name] = []
score_lists_dict[metric_result.name].append(metric_result.score)
return avg_score
avg_scores = {
metric_name: statistics.mean(scorelist)
for metric_name, scorelist in score_lists_dict.items()
}
return avg_scores
async def eval_on_QA_dataset(
dataset_name_or_filename: str, context_provider_name, num_samples, eval_metric_name
dataset_name_or_filename: str, context_provider_name, num_samples, metric_name_list
):
dataset = load_qa_dataset(dataset_name_or_filename)
context_provider = qa_context_providers[context_provider_name]
eval_metric = get_metric(eval_metric_name)
eval_metrics = get_metrics(metric_name_list)
instances = dataset if not num_samples else dataset[:num_samples]
if eval_metric_name.startswith("promptfoo"):
return await eval_metric.measure(instances, context_provider)
if "promptfoo_metrics" in eval_metrics:
promptfoo_results = await eval_metrics["promptfoo_metrics"].measure(
instances, context_provider
)
else:
return await deepeval_on_instances(instances, context_provider, eval_metric)
promptfoo_results = {}
deepeval_results = await deepeval_on_instances(
instances, context_provider, eval_metrics["deepeval_metrics"]
)
results = promptfoo_results | deepeval_results
return results
if __name__ == "__main__":
@ -89,11 +106,11 @@ if __name__ == "__main__":
help="RAG option to use for providing context",
)
parser.add_argument("--num_samples", type=int, default=500)
parser.add_argument("--metric_name", type=str, default="Correctness")
parser.add_argument("--metrics", type=str, nargs="+", default=["Correctness"])
args = parser.parse_args()
avg_score = asyncio.run(
eval_on_QA_dataset(args.dataset, args.rag_option, args.num_samples, args.metric_name)
avg_scores = asyncio.run(
eval_on_QA_dataset(args.dataset, args.rag_option, args.num_samples, args.metrics)
)
logger.info(f"Average {args.metric_name}: {avg_score}")
logger.info(f"{avg_scores}")

View file

@ -3,19 +3,42 @@ import os
import yaml
import json
import shutil
from cognee.infrastructure.llm.prompts.llm_judge_prompts import llm_judge_prompts
def is_valid_promptfoo_metric(metric_name: str):
try:
prefix, suffix = metric_name.split(".")
except ValueError:
return False
if prefix != "promptfoo":
return False
if suffix not in llm_judge_prompts:
return False
return True
class PromptfooMetric:
def __init__(self, judge_prompt):
def __init__(self, metric_name_list):
promptfoo_path = shutil.which("promptfoo")
self.wrapper = PromptfooWrapper(promptfoo_path=promptfoo_path)
self.judge_prompt = judge_prompt
self.prompts = {}
for metric_name in metric_name_list:
if is_valid_promptfoo_metric(metric_name):
self.prompts[metric_name] = llm_judge_prompts[metric_name.split(".")[1]]
else:
raise Exception(f"{metric_name} is not a valid promptfoo metric")
async def measure(self, instances, context_provider):
with open(os.path.join(os.getcwd(), "evals/promptfoo_config_template.yaml"), "r") as file:
config = yaml.safe_load(file)
config["defaultTest"] = [{"assert": {"type": "llm_rubric", "value": self.judge_prompt}}]
config["defaultTest"] = {
"assert": [
{"type": "llm-rubric", "value": prompt, "name": metric_name}
for metric_name, prompt in self.prompts.items()
]
}
# Fill config file with test cases
tests = []
@ -48,6 +71,9 @@ class PromptfooMetric:
with open(file_path, "r") as file:
results = json.load(file)
self.score = results["results"]["prompts"][0]["metrics"]["score"]
scores = {}
return self.score
for result in results["results"]["results"][0]["gradingResult"]["componentResults"]:
scores[result["assertion"]["name"]] = result["score"]
return scores

View file

@ -21,9 +21,11 @@ async def cognify_instance(instance: dict):
async def get_context_with_cognee(instance: dict) -> str:
await cognify_instance(instance)
insights = await cognee.search(SearchType.INSIGHTS, query_text=instance["question"])
# TODO: Fix insights
# insights = await cognee.search(SearchType.INSIGHTS, query_text=instance["question"])
summaries = await cognee.search(SearchType.SUMMARIES, query_text=instance["question"])
search_results = insights + summaries
# search_results = insights + summaries
search_results = summaries
search_results_str = "\n".join([context_item["text"] for context_item in search_results])
@ -31,7 +33,11 @@ async def get_context_with_cognee(instance: dict) -> str:
async def get_context_with_simple_rag(instance: dict) -> str:
await cognify_instance(instance)
await cognee.prune.prune_data()
await cognee.prune.prune_system(metadata=True)
for title, sentences in instance["context"]:
await cognee.add("\n".join(sentences), dataset_name="QA")
vector_engine = get_vector_engine()
found_chunks = await vector_engine.search("document_chunk_text", instance["question"], limit=5)

View file

@ -0,0 +1,18 @@
{
"dataset": [
"hotpotqa"
],
"rag_option": [
"no_rag",
"cognee",
"simple_rag",
"brute_force"
],
"num_samples": [
2
],
"metric_names": [
"Correctness",
"Comprehensiveness"
]
}

60
evals/qa_eval_utils.py Normal file
View file

@ -0,0 +1,60 @@
import itertools
import matplotlib.pyplot as plt
from jsonschema import ValidationError, validate
import pandas as pd
from pathlib import Path
paramset_json_schema = {
"type": "object",
"properties": {
"dataset": {
"type": "array",
"items": {"type": "string"},
},
"rag_option": {
"type": "array",
"items": {"type": "string"},
},
"num_samples": {
"type": "array",
"items": {"type": "integer", "minimum": 1},
},
"metric_names": {
"type": "array",
"items": {"type": "string"},
},
},
"required": ["dataset", "rag_option", "num_samples", "metric_names"],
"additionalProperties": False,
}
def save_table_as_image(df, image_path):
plt.figure(figsize=(10, 6))
plt.axis("tight")
plt.axis("off")
plt.table(cellText=df.values, colLabels=df.columns, rowLabels=df.index, loc="center")
plt.title(f"{df.index.name}")
plt.savefig(image_path, bbox_inches="tight")
plt.close()
def save_results_as_image(results, out_path):
for dataset, num_samples_data in results.items():
for num_samples, table_data in num_samples_data.items():
df = pd.DataFrame.from_dict(table_data, orient="index")
df.index.name = f"Dataset: {dataset}, Num Samples: {num_samples}"
image_path = Path(out_path) / Path(f"table_{dataset}_{num_samples}.png")
save_table_as_image(df, image_path)
def get_combinations(parameters):
try:
validate(instance=parameters, schema=paramset_json_schema)
except ValidationError as e:
raise ValidationError(f"Invalid parameter set: {e.message}")
params_for_combos = {k: v for k, v in parameters.items() if k != "metric_name"}
keys, values = zip(*params_for_combos.items())
combinations = [dict(zip(keys, combo)) for combo in itertools.product(*values)]
return combinations
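An illustrative call, using a parameter set shaped like the JSON example earlier in this PR:

```python
from evals.qa_eval_utils import get_combinations

paramset = {
    "dataset": ["hotpotqa"],
    "rag_option": ["cognee", "simple_rag"],
    "num_samples": [2],
    "metric_names": ["Correctness"],
}
for combo in get_combinations(paramset):
    print(combo)  # e.g. {'dataset': 'hotpotqa', 'rag_option': 'cognee', 'num_samples': 2}
```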

View file

@ -7,10 +7,9 @@ from evals.deepeval_metrics import (
f1_score_metric,
em_score_metric,
)
from evals.promptfoo_metrics import PromptfooMetric
from deepeval.metrics import AnswerRelevancyMetric
import deepeval.metrics
from cognee.infrastructure.llm.prompts.llm_judge_prompts import llm_judge_prompts
from evals.promptfoo_metrics import is_valid_promptfoo_metric, PromptfooMetric
native_deepeval_metrics = {"AnswerRelevancy": AnswerRelevancyMetric}
@ -24,18 +23,10 @@ custom_deepeval_metrics = {
"EM": em_score_metric,
}
promptfoo_metrics = {
"promptfoo.correctness": PromptfooMetric(llm_judge_prompts["correctness"]),
"promptfoo.comprehensiveness": PromptfooMetric(llm_judge_prompts["comprehensiveness"]),
"promptfoo.diversity": PromptfooMetric(llm_judge_prompts["diversity"]),
"promptfoo.empowerment": PromptfooMetric(llm_judge_prompts["empowerment"]),
"promptfoo.directness": PromptfooMetric(llm_judge_prompts["directness"]),
}
qa_metrics = native_deepeval_metrics | custom_deepeval_metrics | promptfoo_metrics
qa_metrics = native_deepeval_metrics | custom_deepeval_metrics
def get_metric(metric_name: str):
def get_deepeval_metric(metric_name: str):
if metric_name in qa_metrics:
metric = qa_metrics[metric_name]
else:
@ -49,3 +40,27 @@ def get_metric(metric_name: str):
metric = metric()
return metric
def get_metrics(metric_name_list: list[str]):
metrics = {
"deepeval_metrics": [],
}
promptfoo_metric_names = []
for metric_name in metric_name_list:
if (
(metric_name in native_deepeval_metrics)
or (metric_name in custom_deepeval_metrics)
or hasattr(deepeval.metrics, metric_name)
):
metric = get_deepeval_metric(metric_name)
metrics["deepeval_metrics"].append(metric)
elif is_valid_promptfoo_metric(metric_name):
promptfoo_metric_names.append(metric_name)
if len(promptfoo_metric_names) > 0:
metrics["promptfoo_metrics"] = PromptfooMetric(promptfoo_metric_names)
return metrics
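A hedged usage sketch of `get_metrics`; the metric names are examples drawn from this PR's defaults:

```python
from evals.qa_metrics_utils import get_metrics

metrics = get_metrics(["Correctness", "promptfoo.correctness"])
deepeval_metrics = metrics["deepeval_metrics"]       # list of deepeval metric instances
promptfoo_metric = metrics.get("promptfoo_metrics")  # PromptfooMetric, present only when promptfoo.* names are requested
```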

57
evals/run_qa_eval.py Normal file
View file

@ -0,0 +1,57 @@
import asyncio
from evals.eval_on_hotpot import eval_on_QA_dataset
from evals.qa_eval_utils import get_combinations, save_results_as_image
import argparse
from pathlib import Path
import json
async def run_evals_on_paramset(paramset: dict, out_path: str):
combinations = get_combinations(paramset)
json_path = Path(out_path) / Path("results.json")
results = {}
for params in combinations:
dataset = params["dataset"]
num_samples = params["num_samples"]
rag_option = params["rag_option"]
result = await eval_on_QA_dataset(
dataset,
rag_option,
num_samples,
paramset["metric_names"],
)
if dataset not in results:
results[dataset] = {}
if num_samples not in results[dataset]:
results[dataset][num_samples] = {}
results[dataset][num_samples][rag_option] = result
with open(json_path, "w") as file:
json.dump(results, file, indent=1)
save_results_as_image(results, out_path)
return results
async def main():
parser = argparse.ArgumentParser()
parser.add_argument(
"--params_file", type=str, required=True, help="Which dataset to evaluate on"
)
parser.add_argument("--out_dir", type=str, help="Dir to save eval results")
args = parser.parse_args()
with open(args.params_file, "r") as file:
parameters = json.load(file)
await run_evals_on_paramset(parameters, args.out_dir)
if __name__ == "__main__":
asyncio.run(main())
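A hedged usage sketch; the params-file path and output directory are placeholders:

```python
import asyncio
import json

from evals.run_qa_eval import run_evals_on_paramset

# A JSON file shaped like the parameter-set example above (path is hypothetical)
with open("evals/eval_params.json") as f:
    params = json.load(f)

results = asyncio.run(run_evals_on_paramset(params, out_path="."))

# Equivalent CLI, using the argparse flags defined above:
#   python evals/run_qa_eval.py --params_file evals/eval_params.json --out_dir .
```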

1109
poetry.lock generated

File diff suppressed because it is too large.

View file

@ -40,7 +40,6 @@ networkx = "^3.2.1"
aiosqlite = "^0.20.0"
pandas = "2.2.3"
filetype = "^1.2.0"
nltk = "^3.8.1"
dlt = {extras = ["sqlalchemy"], version = "^1.4.1"}
aiofiles = "^23.2.1"
qdrant-client = {version = "^1.9.0", optional = true}
@ -64,19 +63,20 @@ langfuse = "^2.32.0"
pydantic-settings = "^2.2.1"
anthropic = "^0.26.1"
sentry-sdk = {extras = ["fastapi"], version = "^2.9.0"}
fastapi-users = {version = "*", extras = ["sqlalchemy"]}
fastapi-users = {version = "14.0.0", extras = ["sqlalchemy"]}
alembic = "^1.13.3"
asyncpg = {version = "0.30.0", optional = true}
pgvector = {version = "^0.3.5", optional = true}
psycopg2 = {version = "^2.9.10", optional = true}
llama-index-core = {version = "^0.12.10.post1", optional = true}
llama-index-core = {version = "^0.12.11", optional = true}
deepeval = {version = "^2.0.1", optional = true}
transformers = "^4.46.3"
pymilvus = {version = "^2.5.0", optional = true}
unstructured = { extras = ["csv", "doc", "docx", "epub", "md", "odt", "org", "ppt", "pptx", "rst", "rtf", "tsv", "xlsx"], version = "^0.16.10", optional = true }
unstructured = { extras = ["csv", "doc", "docx", "epub", "md", "odt", "org", "ppt", "pptx", "rst", "rtf", "tsv", "xlsx"], version = "^0.16.13", optional = true }
pre-commit = "^4.0.1"
httpx = "0.27.0"
bokeh="^3.6.2"
nltk = "3.9.1"