commit 6e691885e6
Merge branch 'dev' into feature/cog-186-run-cognee-on-windows

41 changed files with 1229 additions and 794 deletions
7  .github/pull_request_template.md  vendored  Normal file
@@ -0,0 +1,7 @@
+<!-- .github/pull_request_template.md -->
+
+## Description
+<!-- Provide a clear description of the changes in this PR -->
+
+## DCO Affirmation
+I affirm that all code in every commit of this pull request conforms to the terms of the Topoteretes Developer Certificate of Origin
53  .github/workflows/approve_dco.yaml  vendored  Normal file
@@ -0,0 +1,53 @@
+name: DCO Check
+
+on:
+  pull_request:
+    types: [opened, edited, reopened, synchronize, ready_for_review]
+
+jobs:
+  check-dco:
+    runs-on: ubuntu-latest
+    steps:
+      - name: Validate Developer Certificate of Origin statement
+        uses: actions/github-script@v6
+        with:
+          # If using the built-in GITHUB_TOKEN, ensure it has 'read:org' permission.
+          # In GitHub Enterprise or private orgs, you might need a PAT (personal access token) with read:org scope.
+          github-token: ${{ secrets.GITHUB_TOKEN }}
+          script: |
+            const orgName = 'YOUR_ORGANIZATION_NAME'; // Replace with your org
+            const prUser = context.payload.pull_request.user.login;
+            const prBody = context.payload.pull_request.body || '';
+
+            // Exact text you require in the PR body
+            const requiredStatement = "I affirm that all code in every commit of this pull request conforms to the terms of the Topoteretes Developer Certificate of Origin";
+
+            // 1. Check if user is in the org
+            let isOrgMember = false;
+            try {
+              // Attempt to get membership info
+              const membership = await github.rest.orgs.getMembershipForUser({
+                org: orgName,
+                username: prUser,
+              });
+              // If we get here without an error, user is in the org
+              isOrgMember = true;
+              console.log(`${prUser} is a member of ${orgName}. Skipping DCO check.`);
+            } catch (error) {
+              // If we get a 404, user is NOT an org member
+              if (error.status === 404) {
+                console.log(`${prUser} is NOT a member of ${orgName}. Enforcing DCO check.`);
+              } else {
+                // Some other error: fail the workflow or handle accordingly
+                core.setFailed(`Error checking organization membership: ${error.message}`);
+              }
+            }
+
+            // 2. If user is not in the org, enforce the DCO statement
+            if (!isOrgMember) {
+              if (!prBody.includes(requiredStatement)) {
+                core.setFailed(
+                  `DCO check failed. The PR body must include the following statement:\n\n${requiredStatement}`
+                );
+              }
+            }
67  .github/workflows/dockerhub.yml  vendored
@@ -1,8 +1,9 @@
-name: build | Build and Push Docker Image to DockerHub
+name: build | Build and Push Docker Image to dockerhub

 on:
   push:
     branches:
       - dev
       - main

 jobs:
@@ -10,42 +11,38 @@ jobs:
     runs-on: ubuntu-latest

     steps:
-    - name: Checkout repository
-      uses: actions/checkout@v4
+      - name: Checkout repository
+        uses: actions/checkout@v4

-    - name: Set up Docker Buildx
-      uses: docker/setup-buildx-action@v3
+      - name: Set up Docker Buildx
+        uses: docker/setup-buildx-action@v3

-    - name: Log in to Docker Hub
-      uses: docker/login-action@v3
-      with:
-        username: ${{ secrets.DOCKER_USERNAME }}
-        password: ${{ secrets.DOCKER_PASSWORD }}
+      - name: Log in to Docker Hub
+        uses: docker/login-action@v3
+        with:
+          username: ${{ secrets.DOCKER_USERNAME }}
+          password: ${{ secrets.DOCKER_PASSWORD }}

-    - name: Extract Git information
-      id: git-info
-      run: |
-        echo "BRANCH_NAME=${GITHUB_REF_NAME}" >> "$GITHUB_ENV"
-        echo "COMMIT_SHA=${GITHUB_SHA::7}" >> "$GITHUB_ENV"
+      - name: Extract metadata
+        id: meta
+        uses: docker/metadata-action@v5
+        with:
+          images: cognee/cognee
+          tags: |
+            type=ref,event=branch
+            type=sha,prefix={{branch}}-
+            type=raw,value=latest,enable={{is_default_branch}}

-    - name: Build and Push Docker Image
-      run: |
-        IMAGE_NAME=cognee/cognee
-        TAG_VERSION="${BRANCH_NAME}-${COMMIT_SHA}"
-        echo "Building image: ${IMAGE_NAME}:${TAG_VERSION}"
-        docker buildx build \
-          --platform linux/amd64,linux/arm64 \
-          --push \
-          --tag "${IMAGE_NAME}:${TAG_VERSION}" \
-          --tag "${IMAGE_NAME}:latest" \
-          .
+      - name: Build and push
+        uses: docker/build-push-action@v5
+        with:
+          context: .
+          platforms: linux/amd64,linux/arm64
+          push: true
+          tags: ${{ steps.meta.outputs.tags }}
+          labels: ${{ steps.meta.outputs.labels }}
+          cache-from: type=registry,ref=cognee/cognee:buildcache
+          cache-to: type=registry,ref=cognee/cognee:buildcache,mode=max

-    - name: Verify pushed Docker images
-      run: |
-        # Verify both platform variants
-        for PLATFORM in "linux/amd64" "linux/arm64"; do
-          echo "Verifying image for $PLATFORM..."
-          docker buildx imagetools inspect "${IMAGE_NAME}:${TAG_VERSION}" --format "{{.Manifest.$PLATFORM.Digest}}"
-        done
-        echo "Successfully verified images in Docker Hub"
+      - name: Image digest
+        run: echo ${{ steps.build.outputs.digest }}
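Reviewer note: as a sketch of what the new `docker/metadata-action` configuration should emit for a push to `dev` (the short SHA `abc1234` is a made-up placeholder, and `latest` only fires on the default branch):

```
cognee/cognee:dev          # from type=ref,event=branch
cognee/cognee:dev-abc1234  # from type=sha,prefix={{branch}}-
cognee/cognee:latest       # from type=raw,value=latest,enable={{is_default_branch}}
```

Also note the final `Image digest` step reads `steps.build.outputs.digest`, which assumes the build step carries `id: build`.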
4  .github/workflows/test_python_3_10.yml  vendored
@@ -42,6 +42,10 @@ jobs:

       - name: Install dependencies
         run: poetry install --no-interaction -E docs
+      - name: Download NLTK tokenizer data
+        run: |
+          poetry run python -m nltk.downloader punkt_tab averaged_perceptron_tagger_eng

       - name: Run unit tests
         run: poetry run pytest cognee/tests/unit/
5  .github/workflows/test_python_3_11.yml  vendored
@@ -44,6 +44,11 @@ jobs:
       - name: Install dependencies
         run: poetry install --no-interaction -E docs
+
+      - name: Download NLTK tokenizer data
+        run: |
+          poetry run python -m nltk.downloader punkt_tab averaged_perceptron_tagger_eng

       - name: Run unit tests
         run: poetry run pytest cognee/tests/unit/
3  .github/workflows/test_python_3_12.yml  vendored
@@ -43,6 +43,9 @@ jobs:

       - name: Install dependencies
         run: poetry install --no-interaction -E docs
+      - name: Download NLTK tokenizer data
+        run: |
+          poetry run python -m nltk.downloader punkt_tab averaged_perceptron_tagger_eng

       - name: Run unit tests
         run: poetry run pytest cognee/tests/unit/
@@ -79,6 +79,9 @@ $ git config alias.cos "commit -s"

 will allow you to write `git cos`, which will automatically sign off your commit. By signing a commit you agree to the DCO, and you agree that you will be banned from the topoteretes GitHub organisation and Discord server if you violate the DCO.

+When a commit is ready to be merged, please use the following template to agree to our developer certificate of origin:
+'I affirm that all code in every commit of this pull request conforms to the terms of the Topoteretes Developer Certificate of Origin'
+
 We consider the following as violations of the DCO:

 Signing the DCO with a fake name or pseudonym; if you are registered on GitHub or another platform with a fake name, you will not be able to contribute to topoteretes before updating your name;
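For illustration, the alias behaves exactly like `git commit -s`, appending a `Signed-off-by` trailer taken from your git identity (the message and identity below are hypothetical):

```
$ git cos -m "fix: handle Windows paths"
# produces a commit whose message ends with:
#   Signed-off-by: Jane Developer <jane@example.com>
```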
@@ -6,7 +6,7 @@
 ### Installing Manually
 A MCP server project
 =======
-1. Clone the [cognee](www.github.com/topoteretes/cognee) repo
+1. Clone the [cognee](https://github.com/topoteretes/cognee) repo

@@ -37,7 +37,15 @@ source .venv/bin/activate
 4. Add the new server to your Claude config:

+The file should be located here: ~/Library/Application\ Support/Claude/
+```
+cd ~/Library/Application\ Support/Claude/
+```
+You need to create claude_desktop_config.json in this folder if it doesn't exist.
+Make sure to add your paths and LLM API key to the file below.
+Use your editor of choice, for example Nano:
+```
+nano claude_desktop_config.json
+```

 ```

@@ -83,3 +91,17 @@ npx -y @smithery/cli install cognee --client claude

 Define the cognify tool in server.py
 Restart your Claude desktop.

+To use the debugger, run:
+```bash
+npx @modelcontextprotocol/inspector uv --directory /Users/name/folder run cognee
+```
+
+To apply new changes while developing:
+
+1. Run `poetry lock` in the cognee folder
+2. uv sync --dev --all-extras --reinstall
+3. npx @modelcontextprotocol/inspector uv --directory /Users/vasilije/cognee/cognee-mcp run cognee
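For readers following the config step above, a minimal sketch of what `claude_desktop_config.json` could contain; the directory path and API key are placeholders, not values taken from this commit:

```json
{
  "mcpServers": {
    "cognee": {
      "command": "uv",
      "args": ["--directory", "/Users/name/cognee/cognee-mcp", "run", "cognee"],
      "env": {
        "LLM_API_KEY": "sk-..."
      }
    }
  }
}
```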
@@ -3,6 +3,8 @@ import os
 import asyncio
+from contextlib import redirect_stderr, redirect_stdout
+
+from sqlalchemy.testing.plugin.plugin_base import logging

 import cognee
 import mcp.server.stdio
 import mcp.types as types
@@ -10,6 +12,8 @@ from cognee.api.v1.search import SearchType
 from cognee.shared.data_models import KnowledgeGraph
 from mcp.server import NotificationOptions, Server
 from mcp.server.models import InitializationOptions
+from PIL import Image


 server = Server("cognee-mcp")

@@ -87,9 +91,46 @@ async def handle_list_tools() -> list[types.Tool]:
                 },
             },
         ),
+        types.Tool(
+            name="visualize",
+            description="Visualize the knowledge graph.",
+            inputSchema={
+                "type": "object",
+                "properties": {
+                    "query": {"type": "string"},
+                },
+            },
+        ),
     ]


+def get_freshest_png(directory: str) -> Image.Image:
+    if not os.path.exists(directory):
+        raise FileNotFoundError(f"Directory {directory} does not exist")
+
+    # List all files in 'directory' that end with .png
+    files = [f for f in os.listdir(directory) if f.endswith(".png")]
+    if not files:
+        raise FileNotFoundError("No PNG files found in the given directory.")
+
+    # Sort by integer value of the filename (minus the '.png')
+    # Example filename: 1673185134.png -> integer 1673185134
+    try:
+        files_sorted = sorted(files, key=lambda x: int(x.replace(".png", "")))
+    except ValueError as e:
+        raise ValueError("Invalid PNG filename format. Expected timestamp format.") from e
+
+    # The "freshest" file has the largest timestamp
+    freshest_filename = files_sorted[-1]
+    freshest_path = os.path.join(directory, freshest_filename)
+
+    # Open the image with PIL and return the PIL Image object
+    try:
+        return Image.open(freshest_path)
+    except (IOError, OSError) as e:
+        raise IOError(f"Failed to open PNG file {freshest_path}") from e
+
+
 @server.call_tool()
 async def handle_call_tool(
     name: str, arguments: dict | None
@@ -154,6 +195,20 @@ async def handle_call_tool(
                 text="Pruned",
             )
         ]
+    elif name == "visualize":
+        with open(os.devnull, "w") as fnull:
+            with redirect_stdout(fnull), redirect_stderr(fnull):
+                try:
+                    results = await cognee.visualize_graph()
+                    return [
+                        types.TextContent(
+                            type="text",
+                            text=results,
+                        )
+                    ]
+                except (FileNotFoundError, IOError, ValueError) as e:
+                    raise ValueError(f"Failed to create visualization: {str(e)}")
     else:
         raise ValueError(f"Unknown tool: {name}")
@@ -4,6 +4,7 @@ version = "0.1.0"
 description = "A MCP server project"
 readme = "README.md"
 requires-python = ">=3.10"
+
 dependencies = [
     "mcp>=1.1.1",
     "openai==1.59.4",
@@ -51,7 +52,7 @@ dependencies = [
     "pydantic-settings>=2.2.1,<3.0.0",
     "anthropic>=0.26.1,<1.0.0",
     "sentry-sdk[fastapi]>=2.9.0,<3.0.0",
-    "fastapi-users[sqlalchemy]", # Optional
+    "fastapi-users[sqlalchemy]>=14.0.0", # Optional
     "alembic>=1.13.3,<2.0.0",
     "asyncpg==0.30.0", # Optional
     "pgvector>=0.3.5,<0.4.0", # Optional
@@ -91,4 +92,4 @@ dev = [
 ]

 [project.scripts]
-cognee = "cognee_mcp:main"
\ No newline at end of file
+cognee = "cognee_mcp:main"
18  cognee-mcp/uv.lock  generated
@@ -561,7 +561,7 @@ name = "click"
 version = "8.1.7"
 source = { registry = "https://pypi.org/simple" }
 dependencies = [
-    { name = "colorama", marker = "platform_system == 'Windows'" },
+    { name = "colorama", marker = "sys_platform == 'win32'" },
 ]
 sdist = { url = "https://files.pythonhosted.org/packages/96/d3/f04c7bfcf5c1862a2a5b845c6b2b360488cf47af55dfa79c98f6a6bf98b5/click-8.1.7.tar.gz", hash = "sha256:ca9853ad459e787e2192211578cc907e7594e294c7ccc834310722b41b9ca6de", size = 336121 }
 wheels = [
@@ -570,7 +570,7 @@ wheels = [

 [[package]]
 name = "cognee"
-version = "0.1.21"
+version = "0.1.22"
 source = { directory = "../" }
 dependencies = [
     { name = "aiofiles" },
@@ -633,7 +633,7 @@ requires-dist = [
     { name = "dlt", extras = ["sqlalchemy"], specifier = ">=1.4.1,<2.0.0" },
     { name = "falkordb", marker = "extra == 'falkordb'", specifier = "==1.0.9" },
     { name = "fastapi", specifier = ">=0.109.2,<0.116.0" },
-    { name = "fastapi-users", extras = ["sqlalchemy"] },
+    { name = "fastapi-users", extras = ["sqlalchemy"], specifier = "==14.0.0" },
     { name = "filetype", specifier = ">=1.2.0,<2.0.0" },
     { name = "graphistry", specifier = ">=0.33.5,<0.34.0" },
     { name = "groq", marker = "extra == 'groq'", specifier = "==0.8.0" },
@@ -647,12 +647,12 @@ requires-dist = [
     { name = "langfuse", specifier = ">=2.32.0,<3.0.0" },
     { name = "langsmith", marker = "extra == 'langchain'", specifier = "==0.2.3" },
     { name = "litellm", specifier = "==1.57.2" },
-    { name = "llama-index-core", marker = "extra == 'llama-index'", specifier = ">=0.12.10.post1,<0.13.0" },
+    { name = "llama-index-core", marker = "extra == 'llama-index'", specifier = ">=0.12.11,<0.13.0" },
     { name = "matplotlib", specifier = ">=3.8.3,<4.0.0" },
     { name = "neo4j", marker = "extra == 'neo4j'", specifier = ">=5.20.0,<6.0.0" },
     { name = "nest-asyncio", specifier = "==1.6.0" },
     { name = "networkx", specifier = ">=3.2.1,<4.0.0" },
-    { name = "nltk", specifier = ">=3.8.1,<4.0.0" },
+    { name = "nltk", specifier = "==3.9.1" },
     { name = "numpy", specifier = "==1.26.4" },
     { name = "openai", specifier = "==1.59.4" },
     { name = "pandas", specifier = "==2.2.3" },
@@ -674,7 +674,7 @@ requires-dist = [
     { name = "tiktoken", specifier = "==0.7.0" },
     { name = "transformers", specifier = ">=4.46.3,<5.0.0" },
     { name = "typing-extensions", specifier = "==4.12.2" },
-    { name = "unstructured", extras = ["csv", "doc", "docx", "epub", "md", "odt", "org", "ppt", "pptx", "rst", "rtf", "tsv", "xlsx"], marker = "extra == 'docs'", specifier = ">=0.16.10,<0.17.0" },
+    { name = "unstructured", extras = ["csv", "doc", "docx", "epub", "md", "odt", "org", "ppt", "pptx", "rst", "rtf", "tsv", "xlsx"], marker = "extra == 'docs'", specifier = ">=0.16.13,<0.17.0" },
     { name = "uvicorn", specifier = "==0.22.0" },
     { name = "weaviate-client", marker = "extra == 'weaviate'", specifier = "==4.9.6" },
 ]
@@ -777,7 +777,7 @@ requires-dist = [
     { name = "dlt", extras = ["sqlalchemy"], specifier = ">=1.4.1,<2.0.0" },
     { name = "falkordb", specifier = "==1.0.9" },
     { name = "fastapi", specifier = ">=0.109.2,<0.110.0" },
-    { name = "fastapi-users", extras = ["sqlalchemy"] },
+    { name = "fastapi-users", extras = ["sqlalchemy"], specifier = ">=14.0.0" },
     { name = "filetype", specifier = ">=1.2.0,<2.0.0" },
     { name = "gitpython", specifier = ">=3.1.43,<4.0.0" },
     { name = "graphistry", specifier = ">=0.33.5,<0.34.0" },
@@ -3359,7 +3359,7 @@ name = "portalocker"
 version = "2.10.1"
 source = { registry = "https://pypi.org/simple" }
 dependencies = [
-    { name = "pywin32", marker = "platform_system == 'Windows'" },
+    { name = "pywin32", marker = "sys_platform == 'win32'" },
 ]
 sdist = { url = "https://files.pythonhosted.org/packages/ed/d3/c6c64067759e87af98cc668c1cc75171347d0f1577fab7ca3749134e3cd4/portalocker-2.10.1.tar.gz", hash = "sha256:ef1bf844e878ab08aee7e40184156e1151f228f103aa5c6bd0724cc330960f8f", size = 40891 }
 wheels = [
@@ -4954,7 +4954,7 @@ name = "tqdm"
 version = "4.67.1"
 source = { registry = "https://pypi.org/simple" }
 dependencies = [
-    { name = "colorama", marker = "platform_system == 'Windows'" },
+    { name = "colorama", marker = "sys_platform == 'win32'" },
 ]
 sdist = { url = "https://files.pythonhosted.org/packages/a8/4b/29b4ef32e036bb34e4ab51796dd745cdba7ed47ad142a9f4a1eb8e0c744d/tqdm-4.67.1.tar.gz", hash = "sha256:f8aef9c52c08c13a65f30ea34f4e5aac3fd1a34959879d7e59e63027286627f2", size = 169737 }
 wheels = [
@@ -4,7 +4,7 @@ from .api.v1.config.config import config
 from .api.v1.datasets.datasets import datasets
 from .api.v1.prune import prune
 from .api.v1.search import SearchType, get_search_history, search
-from .api.v1.visualize import visualize
+from .api.v1.visualize import visualize_graph
+from .shared.utils import create_cognee_style_network_with_logo

 # Pipelines
@@ -10,5 +10,6 @@ async def visualize_graph(label: str = "name"):
     logging.info(graph_data)

     graph = await create_cognee_style_network_with_logo(graph_data, label=label)
+    logging.info("The HTML file has been stored in your home directory! Navigate there with cd ~")

     return graph
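A minimal usage sketch for the renamed API (assumes a knowledge graph has already been built, e.g. via `cognee.cognify()`):

```python
import asyncio

import cognee

# Renders the Cognee-styled network and writes the HTML file to your home directory
html = asyncio.run(cognee.visualize_graph())
```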
@@ -62,10 +62,12 @@ class Neo4jAdapter(GraphDBInterface):
     async def add_node(self, node: DataPoint):
         serialized_properties = self.serialize_properties(node.model_dump())

-        query = dedent("""MERGE (node {id: $node_id})
+        query = dedent(
+            """MERGE (node {id: $node_id})
                 ON CREATE SET node += $properties, node.updated_at = timestamp()
                 ON MATCH SET node += $properties, node.updated_at = timestamp()
-                RETURN ID(node) AS internal_id, node.id AS nodeId""")
+                RETURN ID(node) AS internal_id, node.id AS nodeId"""
+        )

         params = {
             "node_id": str(node.id),
@@ -182,13 +184,15 @@ class Neo4jAdapter(GraphDBInterface):
     ):
         serialized_properties = self.serialize_properties(edge_properties)

-        query = dedent("""MATCH (from_node {id: $from_node}),
+        query = dedent(
+            """MATCH (from_node {id: $from_node}),
             (to_node {id: $to_node})
             MERGE (from_node)-[r]->(to_node)
             ON CREATE SET r += $properties, r.updated_at = timestamp(), r.type = $relationship_name
             ON MATCH SET r += $properties, r.updated_at = timestamp()
             RETURN r
-            """)
+            """
+        )

         params = {
             "from_node": str(from_node),
@@ -88,23 +88,27 @@ class FalkorDBAdapter(VectorDBInterface, GraphDBInterface):
             }
         )

-        return dedent(f"""
+        return dedent(
+            f"""
             MERGE (node:{node_label} {{id: '{str(data_point.id)}'}})
             ON CREATE SET node += ({{{node_properties}}}), node.updated_at = timestamp()
             ON MATCH SET node += ({{{node_properties}}}), node.updated_at = timestamp()
-        """).strip()
+        """
+        ).strip()

     async def create_edge_query(self, edge: tuple[str, str, str, dict]) -> str:
         properties = await self.stringify_properties(edge[3])
         properties = f"{{{properties}}}"

-        return dedent(f"""
+        return dedent(
+            f"""
             MERGE (source {{id:'{edge[0]}'}})
             MERGE (target {{id: '{edge[1]}'}})
             MERGE (source)-[edge:{edge[2]} {properties}]->(target)
             ON MATCH SET edge.updated_at = timestamp()
             ON CREATE SET edge.updated_at = timestamp()
-        """).strip()
+        """
+        ).strip()

     async def create_collection(self, collection_name: str):
         pass
@@ -195,12 +199,14 @@ class FalkorDBAdapter(VectorDBInterface, GraphDBInterface):
         self.query(query)

     async def has_edges(self, edges):
-        query = dedent("""
+        query = dedent(
+            """
             UNWIND $edges AS edge
             MATCH (a)-[r]->(b)
             WHERE id(a) = edge.from_node AND id(b) = edge.to_node AND type(r) = edge.relationship_name
             RETURN edge.from_node AS from_node, edge.to_node AS to_node, edge.relationship_name AS relationship_name, count(r) > 0 AS edge_exists
-        """).strip()
+        """
+        ).strip()

         params = {
             "edges": [
@@ -279,14 +285,16 @@ class FalkorDBAdapter(VectorDBInterface, GraphDBInterface):

         [label, attribute_name] = collection_name.split(".")

-        query = dedent(f"""
+        query = dedent(
+            f"""
             CALL db.idx.vector.queryNodes(
                 '{label}',
                 '{attribute_name}',
                 {limit},
                 vecf32({query_vector})
             ) YIELD node, score
-        """).strip()
+        """
+        ).strip()

         result = self.query(query)
@@ -93,10 +93,12 @@ class SQLAlchemyAdapter:
         if self.engine.dialect.name == "postgresql":
             async with self.engine.begin() as connection:
                 result = await connection.execute(
-                    text("""
+                    text(
+                        """
                         SELECT schema_name FROM information_schema.schemata
                         WHERE schema_name NOT IN ('pg_catalog', 'pg_toast', 'information_schema');
-                    """)
+                    """
+                    )
                 )
                 return [schema[0] for schema in result.fetchall()]
         return []
@@ -1,24 +1,34 @@
 from datetime import datetime, timezone
-from typing import Optional
+from typing import Optional, Any, Dict
 from uuid import UUID, uuid4

 from pydantic import BaseModel, Field
 from typing_extensions import TypedDict
+import pickle


 # Define metadata type
 class MetaData(TypedDict):
     index_fields: list[str]


+# Updated DataPoint model with versioning and new fields
 class DataPoint(BaseModel):
     __tablename__ = "data_point"
     id: UUID = Field(default_factory=uuid4)
-    updated_at: Optional[datetime] = datetime.now(timezone.utc)
+    created_at: int = Field(
+        default_factory=lambda: int(datetime.now(timezone.utc).timestamp() * 1000)
+    )
+    updated_at: int = Field(
+        default_factory=lambda: int(datetime.now(timezone.utc).timestamp() * 1000)
+    )
+    version: int = 1  # Default version
     topological_rank: Optional[int] = 0
     _metadata: Optional[MetaData] = {"index_fields": [], "type": "DataPoint"}

-    # class Config:
-    #     underscore_attrs_are_private = True
+    # Override the Pydantic configuration
+    class Config:
+        underscore_attrs_are_private = True

     @classmethod
     def get_embeddable_data(self, data_point):
@@ -31,11 +41,11 @@ class DataPoint(BaseModel):

         if isinstance(attribute, str):
             return attribute.strip()
-        else:
-            return attribute
+        return attribute

     @classmethod
     def get_embeddable_properties(self, data_point):
         """Retrieve all embeddable properties."""
         if data_point._metadata and len(data_point._metadata["index_fields"]) > 0:
             return [
                 getattr(data_point, field, None) for field in data_point._metadata["index_fields"]
@@ -45,4 +55,40 @@ class DataPoint(BaseModel):

     @classmethod
     def get_embeddable_property_names(self, data_point):
         """Retrieve names of embeddable properties."""
         return data_point._metadata["index_fields"] or []

+    def update_version(self):
+        """Update the version and updated_at timestamp."""
+        self.version += 1
+        self.updated_at = int(datetime.now(timezone.utc).timestamp() * 1000)
+
+    # JSON Serialization
+    def to_json(self) -> str:
+        """Serialize the instance to a JSON string."""
+        return self.json()
+
+    @classmethod
+    def from_json(self, json_str: str):
+        """Deserialize the instance from a JSON string."""
+        return self.model_validate_json(json_str)
+
+    # Pickle Serialization
+    def to_pickle(self) -> bytes:
+        """Serialize the instance to pickle-compatible bytes."""
+        return pickle.dumps(self.dict())
+
+    @classmethod
+    def from_pickle(self, pickled_data: bytes):
+        """Deserialize the instance from pickled bytes."""
+        data = pickle.loads(pickled_data)
+        return self(**data)
+
+    def to_dict(self, **kwargs) -> Dict[str, Any]:
+        """Serialize model to a dictionary."""
+        return self.model_dump(**kwargs)
+
+    @classmethod
+    def from_dict(cls, data: Dict[str, Any]) -> "DataPoint":
+        """Deserialize model from a dictionary."""
+        return cls.model_validate(data)
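A small round-trip sketch of the new serialization helpers (field values are made up; note the methods mix pydantic v1-style `self.json()`/`self.dict()` with v2-style `model_validate_json`/`model_dump`, so this assumes a pydantic version that accepts both):

```python
point = DataPoint(topological_rank=3)

# JSON round trip
restored = DataPoint.from_json(point.to_json())
assert restored.id == point.id

# dict round trip
assert DataPoint.from_dict(point.to_dict()).version == 1

# update_version bumps the version and refreshes updated_at
point.update_version()
assert point.version == 2
```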
@@ -11,9 +11,7 @@ import networkx as nx
 import pandas as pd
 import matplotlib.pyplot as plt
 import tiktoken
-import nltk
 import base64
-
 import time

 import logging
 import sys
@@ -23,13 +21,40 @@ from cognee.infrastructure.databases.graph import get_graph_engine

 from uuid import uuid4
 import pathlib

+import nltk
+from cognee.shared.exceptions import IngestionError

 # Analytics Proxy Url, currently hosted by Vercel
 proxy_url = "https://test.prometh.ai"


+def get_entities(tagged_tokens):
+    nltk.download("maxent_ne_chunker", quiet=True)
+    from nltk.chunk import ne_chunk
+
+    return ne_chunk(tagged_tokens)
+
+
+def extract_pos_tags(sentence):
+    """Extract Part-of-Speech (POS) tags for words in a sentence."""
+
+    # Ensure that the necessary NLTK resources are downloaded
+    nltk.download("words", quiet=True)
+    nltk.download("punkt", quiet=True)
+    nltk.download("averaged_perceptron_tagger", quiet=True)
+
+    from nltk.tag import pos_tag
+    from nltk.tokenize import word_tokenize
+
+    # Tokenize the sentence into words
+    tokens = word_tokenize(sentence)
+
+    # Tag each word with its corresponding POS tag
+    pos_tags = pos_tag(tokens)
+
+    return pos_tags
+
+
 def get_anonymous_id():
     """Creates or reads an anonymous user id"""
     home_dir = str(pathlib.Path(pathlib.Path(__file__).parent.parent.parent.resolve()))
@@ -243,33 +268,6 @@ async def render_graph(
     # return df.replace([np.inf, -np.inf, np.nan], None)


-def get_entities(tagged_tokens):
-    nltk.download("maxent_ne_chunker", quiet=True)
-    from nltk.chunk import ne_chunk
-
-    return ne_chunk(tagged_tokens)
-
-
-def extract_pos_tags(sentence):
-    """Extract Part-of-Speech (POS) tags for words in a sentence."""
-
-    # Ensure that the necessary NLTK resources are downloaded
-    nltk.download("words", quiet=True)
-    nltk.download("punkt", quiet=True)
-    nltk.download("averaged_perceptron_tagger", quiet=True)
-
-    from nltk.tag import pos_tag
-    from nltk.tokenize import word_tokenize
-
-    # Tokenize the sentence into words
-    tokens = word_tokenize(sentence)
-
-    # Tag each word with its corresponding POS tag
-    pos_tags = pos_tag(tokens)
-
-    return pos_tags
-
-
 logging.basicConfig(level=logging.INFO)

@@ -396,6 +394,7 @@ async def create_cognee_style_network_with_logo(

     from bokeh.embed import file_html
     from bokeh.resources import CDN
+    from bokeh.io import export_png

     logging.info("Converting graph to serializable format...")
     G = await convert_to_serializable_graph(G)
@@ -445,13 +444,14 @@ async def create_cognee_style_network_with_logo(

     logging.info(f"Saving visualization to {output_filename}...")
     html_content = file_html(p, CDN, title)
-    with open(output_filename, "w") as f:
+
+    home_dir = os.path.expanduser("~")
+
+    # Construct the final output file path
+    output_filepath = os.path.join(home_dir, output_filename)
+    with open(output_filepath, "w") as f:
         f.write(html_content)

     logging.info("Visualization complete.")

     if bokeh_object:
         return p
     return html_content

@@ -512,7 +512,7 @@ if __name__ == "__main__":
         G,
         output_filename="example_network.html",
         title="Example Cognee Network",
-        node_attribute="group",  # Attribute to use for coloring nodes
+        label="group",  # Attribute to use for coloring nodes
         layout_func=nx.spring_layout,  # Layout function
         layout_scale=3.0,  # Scale for the layout
         logo_alpha=0.2,
@@ -19,9 +19,11 @@ async def index_and_transform_graphiti_nodes_and_edges():
         raise RuntimeError("Initialization error") from e

     await graph_engine.query("""MATCH (n) SET n.id = n.uuid RETURN n""")
-    await graph_engine.query("""MATCH (source)-[r]->(target) SET r.source_node_id = source.id,
+    await graph_engine.query(
+        """MATCH (source)-[r]->(target) SET r.source_node_id = source.id,
         r.target_node_id = target.id,
-        r.relationship_name = type(r) RETURN r""")
+        r.relationship_name = type(r) RETURN r"""
+    )
     await graph_engine.query("""MATCH (n) SET n.text = COALESCE(n.summary, n.content) RETURN n""")

     nodes_data, edges_data = await graph_engine.get_model_independent_graph_data()
@@ -36,12 +36,12 @@ def test_AudioDocument():
     for ground_truth, paragraph_data in zip(
         GROUND_TRUTH, document.read(chunk_size=64, chunker="text_chunker")
     ):
-        assert ground_truth["word_count"] == paragraph_data.word_count, (
-            f'{ground_truth["word_count"] = } != {paragraph_data.word_count = }'
-        )
-        assert ground_truth["len_text"] == len(paragraph_data.text), (
-            f'{ground_truth["len_text"] = } != {len(paragraph_data.text) = }'
-        )
-        assert ground_truth["cut_type"] == paragraph_data.cut_type, (
-            f'{ground_truth["cut_type"] = } != {paragraph_data.cut_type = }'
-        )
+        assert (
+            ground_truth["word_count"] == paragraph_data.word_count
+        ), f'{ground_truth["word_count"] = } != {paragraph_data.word_count = }'
+        assert ground_truth["len_text"] == len(
+            paragraph_data.text
+        ), f'{ground_truth["len_text"] = } != {len(paragraph_data.text) = }'
+        assert (
+            ground_truth["cut_type"] == paragraph_data.cut_type
+        ), f'{ground_truth["cut_type"] = } != {paragraph_data.cut_type = }'
@@ -25,12 +25,12 @@ def test_ImageDocument():
     for ground_truth, paragraph_data in zip(
         GROUND_TRUTH, document.read(chunk_size=64, chunker="text_chunker")
     ):
-        assert ground_truth["word_count"] == paragraph_data.word_count, (
-            f'{ground_truth["word_count"] = } != {paragraph_data.word_count = }'
-        )
-        assert ground_truth["len_text"] == len(paragraph_data.text), (
-            f'{ground_truth["len_text"] = } != {len(paragraph_data.text) = }'
-        )
-        assert ground_truth["cut_type"] == paragraph_data.cut_type, (
-            f'{ground_truth["cut_type"] = } != {paragraph_data.cut_type = }'
-        )
+        assert (
+            ground_truth["word_count"] == paragraph_data.word_count
+        ), f'{ground_truth["word_count"] = } != {paragraph_data.word_count = }'
+        assert ground_truth["len_text"] == len(
+            paragraph_data.text
+        ), f'{ground_truth["len_text"] = } != {len(paragraph_data.text) = }'
+        assert (
+            ground_truth["cut_type"] == paragraph_data.cut_type
+        ), f'{ground_truth["cut_type"] = } != {paragraph_data.cut_type = }'
@@ -27,12 +27,12 @@ def test_PdfDocument():
     for ground_truth, paragraph_data in zip(
         GROUND_TRUTH, document.read(chunk_size=1024, chunker="text_chunker")
     ):
-        assert ground_truth["word_count"] == paragraph_data.word_count, (
-            f'{ground_truth["word_count"] = } != {paragraph_data.word_count = }'
-        )
-        assert ground_truth["len_text"] == len(paragraph_data.text), (
-            f'{ground_truth["len_text"] = } != {len(paragraph_data.text) = }'
-        )
-        assert ground_truth["cut_type"] == paragraph_data.cut_type, (
-            f'{ground_truth["cut_type"] = } != {paragraph_data.cut_type = }'
-        )
+        assert (
+            ground_truth["word_count"] == paragraph_data.word_count
+        ), f'{ground_truth["word_count"] = } != {paragraph_data.word_count = }'
+        assert ground_truth["len_text"] == len(
+            paragraph_data.text
+        ), f'{ground_truth["len_text"] = } != {len(paragraph_data.text) = }'
+        assert (
+            ground_truth["cut_type"] == paragraph_data.cut_type
+        ), f'{ground_truth["cut_type"] = } != {paragraph_data.cut_type = }'
@@ -39,12 +39,12 @@ def test_TextDocument(input_file, chunk_size):
     for ground_truth, paragraph_data in zip(
         GROUND_TRUTH[input_file], document.read(chunk_size=chunk_size, chunker="text_chunker")
     ):
-        assert ground_truth["word_count"] == paragraph_data.word_count, (
-            f'{ground_truth["word_count"] = } != {paragraph_data.word_count = }'
-        )
-        assert ground_truth["len_text"] == len(paragraph_data.text), (
-            f'{ground_truth["len_text"] = } != {len(paragraph_data.text) = }'
-        )
-        assert ground_truth["cut_type"] == paragraph_data.cut_type, (
-            f'{ground_truth["cut_type"] = } != {paragraph_data.cut_type = }'
-        )
+        assert (
+            ground_truth["word_count"] == paragraph_data.word_count
+        ), f'{ground_truth["word_count"] = } != {paragraph_data.word_count = }'
+        assert ground_truth["len_text"] == len(
+            paragraph_data.text
+        ), f'{ground_truth["len_text"] = } != {len(paragraph_data.text) = }'
+        assert (
+            ground_truth["cut_type"] == paragraph_data.cut_type
+        ), f'{ground_truth["cut_type"] = } != {paragraph_data.cut_type = }'
@@ -71,32 +71,32 @@ def test_UnstructuredDocument():
     for paragraph_data in pptx_document.read(chunk_size=1024, chunker="text_chunker"):
         assert 19 == paragraph_data.word_count, f" 19 != {paragraph_data.word_count = }"
         assert 104 == len(paragraph_data.text), f" 104 != {len(paragraph_data.text) = }"
-        assert "sentence_cut" == paragraph_data.cut_type, (
-            f" sentence_cut != {paragraph_data.cut_type = }"
-        )
+        assert (
+            "sentence_cut" == paragraph_data.cut_type
+        ), f" sentence_cut != {paragraph_data.cut_type = }"

     # Test DOCX
     for paragraph_data in docx_document.read(chunk_size=1024, chunker="text_chunker"):
         assert 16 == paragraph_data.word_count, f" 16 != {paragraph_data.word_count = }"
         assert 145 == len(paragraph_data.text), f" 145 != {len(paragraph_data.text) = }"
-        assert "sentence_end" == paragraph_data.cut_type, (
-            f" sentence_end != {paragraph_data.cut_type = }"
-        )
+        assert (
+            "sentence_end" == paragraph_data.cut_type
+        ), f" sentence_end != {paragraph_data.cut_type = }"

     # TEST CSV
     for paragraph_data in csv_document.read(chunk_size=1024, chunker="text_chunker"):
         assert 15 == paragraph_data.word_count, f" 15 != {paragraph_data.word_count = }"
-        assert "A A A A A A A A A,A A A A A A,A A" == paragraph_data.text, (
-            f"Read text doesn't match expected text: {paragraph_data.text}"
-        )
-        assert "sentence_cut" == paragraph_data.cut_type, (
-            f" sentence_cut != {paragraph_data.cut_type = }"
-        )
+        assert (
+            "A A A A A A A A A,A A A A A A,A A" == paragraph_data.text
+        ), f"Read text doesn't match expected text: {paragraph_data.text}"
+        assert (
+            "sentence_cut" == paragraph_data.cut_type
+        ), f" sentence_cut != {paragraph_data.cut_type = }"

     # Test XLSX
     for paragraph_data in xlsx_document.read(chunk_size=1024, chunker="text_chunker"):
         assert 36 == paragraph_data.word_count, f" 36 != {paragraph_data.word_count = }"
         assert 171 == len(paragraph_data.text), f" 171 != {len(paragraph_data.text) = }"
-        assert "sentence_cut" == paragraph_data.cut_type, (
-            f" sentence_cut != {paragraph_data.cut_type = }"
-        )
+        assert (
+            "sentence_cut" == paragraph_data.cut_type
+        ), f" sentence_cut != {paragraph_data.cut_type = }"
@@ -30,9 +30,9 @@ async def test_deduplication():

     result = await relational_engine.get_all_data_from_table("data")
     assert len(result) == 1, "More than one data entity was found."
-    assert result[0]["name"] == "Natural_language_processing_copy", (
-        "Result name does not match expected value."
-    )
+    assert (
+        result[0]["name"] == "Natural_language_processing_copy"
+    ), "Result name does not match expected value."

     result = await relational_engine.get_all_data_from_table("datasets")
     assert len(result) == 2, "Unexpected number of datasets found."
@@ -61,9 +61,9 @@ async def test_deduplication():

     result = await relational_engine.get_all_data_from_table("data")
     assert len(result) == 1, "More than one data entity was found."
-    assert hashlib.md5(text.encode("utf-8")).hexdigest() in result[0]["name"], (
-        "Content hash is not a part of file name."
-    )
+    assert (
+        hashlib.md5(text.encode("utf-8")).hexdigest() in result[0]["name"]
+    ), "Content hash is not a part of file name."

     await cognee.prune.prune_data()
     await cognee.prune.prune_system(metadata=True)
@@ -85,9 +85,9 @@ async def main():

     from cognee.infrastructure.databases.relational import get_relational_engine

-    assert not os.path.exists(get_relational_engine().db_path), (
-        "SQLite relational database is not empty"
-    )
+    assert not os.path.exists(
+        get_relational_engine().db_path
+    ), "SQLite relational database is not empty"

     from cognee.infrastructure.databases.graph import get_graph_config
@@ -82,9 +82,9 @@ async def main():

     from cognee.infrastructure.databases.relational import get_relational_engine

-    assert not os.path.exists(get_relational_engine().db_path), (
-        "SQLite relational database is not empty"
-    )
+    assert not os.path.exists(
+        get_relational_engine().db_path
+    ), "SQLite relational database is not empty"

     from cognee.infrastructure.databases.graph import get_graph_config
@@ -24,28 +24,28 @@ async def test_local_file_deletion(data_text, file_location):
         data_hash = hashlib.md5(encoded_text).hexdigest()
         # Get data entry from database based on hash contents
         data = (await session.scalars(select(Data).where(Data.content_hash == data_hash))).one()
-        assert os.path.isfile(data.raw_data_location), (
-            f"Data location doesn't exist: {data.raw_data_location}"
-        )
+        assert os.path.isfile(
+            data.raw_data_location
+        ), f"Data location doesn't exist: {data.raw_data_location}"
         # Test deletion of data along with local files created by cognee
         await engine.delete_data_entity(data.id)
-        assert not os.path.exists(data.raw_data_location), (
-            f"Data location still exists after deletion: {data.raw_data_location}"
-        )
+        assert not os.path.exists(
+            data.raw_data_location
+        ), f"Data location still exists after deletion: {data.raw_data_location}"

     async with engine.get_async_session() as session:
         # Get data entry from database based on file path
         data = (
             await session.scalars(select(Data).where(Data.raw_data_location == file_location))
         ).one()
-        assert os.path.isfile(data.raw_data_location), (
-            f"Data location doesn't exist: {data.raw_data_location}"
-        )
+        assert os.path.isfile(
+            data.raw_data_location
+        ), f"Data location doesn't exist: {data.raw_data_location}"
         # Test local files not created by cognee won't get deleted
         await engine.delete_data_entity(data.id)
-        assert os.path.exists(data.raw_data_location), (
-            f"Data location doesn't exists: {data.raw_data_location}"
-        )
+        assert os.path.exists(
+            data.raw_data_location
+        ), f"Data location doesn't exists: {data.raw_data_location}"


 async def test_getting_of_documents(dataset_name_1):
@@ -54,16 +54,16 @@ async def test_getting_of_documents(dataset_name_1):

     user = await get_default_user()
     document_ids = await get_document_ids_for_user(user.id, [dataset_name_1])
-    assert len(document_ids) == 1, (
-        f"Number of expected documents doesn't match {len(document_ids)} != 1"
-    )
+    assert (
+        len(document_ids) == 1
+    ), f"Number of expected documents doesn't match {len(document_ids)} != 1"

     # Test getting of documents for search when no dataset is provided
     user = await get_default_user()
     document_ids = await get_document_ids_for_user(user.id)
-    assert len(document_ids) == 2, (
-        f"Number of expected documents doesn't match {len(document_ids)} != 2"
-    )
+    assert (
+        len(document_ids) == 2
+    ), f"Number of expected documents doesn't match {len(document_ids)} != 2"


 async def main():
@@ -17,9 +17,9 @@ batch_paragraphs_vals = [True, False]
 def test_chunk_by_paragraph_isomorphism(input_text, paragraph_length, batch_paragraphs):
     chunks = chunk_by_paragraph(input_text, paragraph_length, batch_paragraphs)
     reconstructed_text = "".join([chunk["text"] for chunk in chunks])
-    assert reconstructed_text == input_text, (
-        f"texts are not identical: {len(input_text) = }, {len(reconstructed_text) = }"
-    )
+    assert (
+        reconstructed_text == input_text
+    ), f"texts are not identical: {len(input_text) = }, {len(reconstructed_text) = }"


 @pytest.mark.parametrize(
@@ -36,9 +36,9 @@ def test_paragraph_chunk_length(input_text, paragraph_length, batch_paragraphs):
     chunk_lengths = np.array([len(list(chunk_by_word(chunk["text"]))) for chunk in chunks])

     larger_chunks = chunk_lengths[chunk_lengths > paragraph_length]
-    assert np.all(chunk_lengths <= paragraph_length), (
-        f"{paragraph_length = }: {larger_chunks} are too large"
-    )
+    assert np.all(
+        chunk_lengths <= paragraph_length
+    ), f"{paragraph_length = }: {larger_chunks} are too large"


 @pytest.mark.parametrize(
@@ -50,6 +50,6 @@ def test_chunk_by_paragraph_chunk_numbering(input_text, paragraph_length, batch_
         data=input_text, paragraph_length=paragraph_length, batch_paragraphs=batch_paragraphs
     )
     chunk_indices = np.array([chunk["chunk_index"] for chunk in chunks])
-    assert np.all(chunk_indices == np.arange(len(chunk_indices))), (
-        f"{chunk_indices = } are not monotonically increasing"
-    )
+    assert np.all(
+        chunk_indices == np.arange(len(chunk_indices))
+    ), f"{chunk_indices = } are not monotonically increasing"
@@ -58,9 +58,9 @@ def run_chunking_test(test_text, expected_chunks):

     for expected_chunks_item, chunk in zip(expected_chunks, chunks):
         for key in ["text", "word_count", "cut_type"]:
-            assert chunk[key] == expected_chunks_item[key], (
-                f"{key = }: {chunk[key] = } != {expected_chunks_item[key] = }"
-            )
+            assert (
+                chunk[key] == expected_chunks_item[key]
+            ), f"{key = }: {chunk[key] = } != {expected_chunks_item[key] = }"


 def test_chunking_whole_text():
@@ -16,9 +16,9 @@ maximum_length_vals = [None, 8, 64]
 def test_chunk_by_sentence_isomorphism(input_text, maximum_length):
     chunks = chunk_by_sentence(input_text, maximum_length)
     reconstructed_text = "".join([chunk[1] for chunk in chunks])
-    assert reconstructed_text == input_text, (
-        f"texts are not identical: {len(input_text) = }, {len(reconstructed_text) = }"
-    )
+    assert (
+        reconstructed_text == input_text
+    ), f"texts are not identical: {len(input_text) = }, {len(reconstructed_text) = }"


 @pytest.mark.parametrize(
@@ -36,6 +36,6 @@ def test_paragraph_chunk_length(input_text, maximum_length):
     chunk_lengths = np.array([len(list(chunk_by_word(chunk[1]))) for chunk in chunks])

     larger_chunks = chunk_lengths[chunk_lengths > maximum_length]
-    assert np.all(chunk_lengths <= maximum_length), (
-        f"{maximum_length = }: {larger_chunks} are too large"
-    )
+    assert np.all(
+        chunk_lengths <= maximum_length
+    ), f"{maximum_length = }: {larger_chunks} are too large"
@@ -17,9 +17,9 @@ from cognee.tests.unit.processing.chunks.test_input import INPUT_TEXTS
 def test_chunk_by_word_isomorphism(input_text):
     chunks = chunk_by_word(input_text)
     reconstructed_text = "".join([chunk[0] for chunk in chunks])
-    assert reconstructed_text == input_text, (
-        f"texts are not identical: {len(input_text) = }, {len(reconstructed_text) = }"
-    )
+    assert (
+        reconstructed_text == input_text
+    ), f"texts are not identical: {len(input_text) = }, {len(reconstructed_text) = }"


 @pytest.mark.parametrize(
@@ -8,7 +8,7 @@ import logging
 from cognee.infrastructure.llm.get_llm_client import get_llm_client
 from cognee.infrastructure.llm.prompts import read_query_prompt, render_prompt
 from evals.qa_dataset_utils import load_qa_dataset
-from evals.qa_metrics_utils import get_metric
+from evals.qa_metrics_utils import get_metrics
 from evals.qa_context_provider_utils import qa_context_providers

 logger = logging.getLogger(__name__)
@@ -34,7 +34,7 @@ async def answer_qa_instance(instance, context_provider):
     return answer_prediction


-async def deepeval_answers(instances, answers, eval_metric):
+async def deepeval_answers(instances, answers, eval_metrics):
     test_cases = []

     for instance, answer in zip(instances, answers):
@@ -44,37 +44,54 @@ async def deepeval_answers(instances, answers, eval_metric):
         test_cases.append(test_case)

     eval_set = EvaluationDataset(test_cases)
-    eval_results = eval_set.evaluate([eval_metric])
+    eval_results = eval_set.evaluate(eval_metrics)

     return eval_results


-async def deepeval_on_instances(instances, context_provider, eval_metric):
+async def deepeval_on_instances(instances, context_provider, eval_metrics):
     answers = []
     for instance in tqdm(instances, desc="Getting answers"):
         answer = await answer_qa_instance(instance, context_provider)
         answers.append(answer)

-    eval_results = await deepeval_answers(instances, answers, eval_metric)
-    avg_score = statistics.mean(
-        [result.metrics_data[0].score for result in eval_results.test_results]
-    )
+    eval_results = await deepeval_answers(instances, answers, eval_metrics)
+    score_lists_dict = {}
+    for instance_result in eval_results.test_results:
+        for metric_result in instance_result.metrics_data:
+            if metric_result.name not in score_lists_dict:
+                score_lists_dict[metric_result.name] = []
+            score_lists_dict[metric_result.name].append(metric_result.score)

-    return avg_score
+    avg_scores = {
+        metric_name: statistics.mean(scorelist)
+        for metric_name, scorelist in score_lists_dict.items()
+    }
+
+    return avg_scores


 async def eval_on_QA_dataset(
-    dataset_name_or_filename: str, context_provider_name, num_samples, eval_metric_name
+    dataset_name_or_filename: str, context_provider_name, num_samples, metric_name_list
 ):
     dataset = load_qa_dataset(dataset_name_or_filename)
     context_provider = qa_context_providers[context_provider_name]
-    eval_metric = get_metric(eval_metric_name)
+    eval_metrics = get_metrics(metric_name_list)
     instances = dataset if not num_samples else dataset[:num_samples]

-    if eval_metric_name.startswith("promptfoo"):
-        return await eval_metric.measure(instances, context_provider)
+    if "promptfoo_metrics" in eval_metrics:
+        promptfoo_results = await eval_metrics["promptfoo_metrics"].measure(
+            instances, context_provider
+        )
     else:
-        return await deepeval_on_instances(instances, context_provider, eval_metric)
+        promptfoo_results = {}
+    deepeval_results = await deepeval_on_instances(
+        instances, context_provider, eval_metrics["deepeval_metrics"]
+    )
+
+    results = promptfoo_results | deepeval_results
+
+    return results


 if __name__ == "__main__":
@@ -89,11 +106,11 @@ if __name__ == "__main__":
         help="RAG option to use for providing context",
     )
     parser.add_argument("--num_samples", type=int, default=500)
-    parser.add_argument("--metric_name", type=str, default="Correctness")
+    parser.add_argument("--metrics", type=str, nargs="+", default=["Correctness"])

     args = parser.parse_args()

-    avg_score = asyncio.run(
-        eval_on_QA_dataset(args.dataset, args.rag_option, args.num_samples, args.metric_name)
+    avg_scores = asyncio.run(
+        eval_on_QA_dataset(args.dataset, args.rag_option, args.num_samples, args.metrics)
     )
-    logger.info(f"Average {args.metric_name}: {avg_score}")
+    logger.info(f"{avg_scores}")
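With this change `eval_on_QA_dataset` returns one averaged score per metric rather than a single float. A sketch of the shape (the numbers are placeholders; deepeval keys follow each metric's reported name, promptfoo keys are the `promptfoo.*` metric names):

```python
avg_scores = {
    "Correctness": 0.78,
    "promptfoo.comprehensiveness": 0.64,
}
```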
@@ -3,19 +3,42 @@ import os
 import yaml
 import json
 import shutil
+from cognee.infrastructure.llm.prompts.llm_judge_prompts import llm_judge_prompts
+
+
+def is_valid_promptfoo_metric(metric_name: str):
+    try:
+        prefix, suffix = metric_name.split(".")
+    except ValueError:
+        return False
+    if prefix != "promptfoo":
+        return False
+    if suffix not in llm_judge_prompts:
+        return False
+    return True


 class PromptfooMetric:
-    def __init__(self, judge_prompt):
+    def __init__(self, metric_name_list):
         promptfoo_path = shutil.which("promptfoo")
         self.wrapper = PromptfooWrapper(promptfoo_path=promptfoo_path)
-        self.judge_prompt = judge_prompt
+        self.prompts = {}
+        for metric_name in metric_name_list:
+            if is_valid_promptfoo_metric(metric_name):
+                self.prompts[metric_name] = llm_judge_prompts[metric_name.split(".")[1]]
+            else:
+                raise Exception(f"{metric_name} is not a valid promptfoo metric")

     async def measure(self, instances, context_provider):
         with open(os.path.join(os.getcwd(), "evals/promptfoo_config_template.yaml"), "r") as file:
             config = yaml.safe_load(file)

-        config["defaultTest"] = [{"assert": {"type": "llm_rubric", "value": self.judge_prompt}}]
+        config["defaultTest"] = {
+            "assert": [
+                {"type": "llm-rubric", "value": prompt, "name": metric_name}
+                for metric_name, prompt in self.prompts.items()
+            ]
+        }

         # Fill config file with test cases
         tests = []
@@ -48,6 +71,9 @@ class PromptfooMetric:
         with open(file_path, "r") as file:
             results = json.load(file)

-        self.score = results["results"]["prompts"][0]["metrics"]["score"]
-
-        return self.score
+        scores = {}
+        for result in results["results"]["results"][0]["gradingResult"]["componentResults"]:
+            scores[result["assertion"]["name"]] = result["score"]
+
+        return scores
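From the definition above, a valid promptfoo metric name is `promptfoo.<key>` where the key exists in `llm_judge_prompts`. A quick sketch (assuming `correctness` is one of the judge prompt keys, as the old `qa_metrics_utils` registry suggests):

```python
assert is_valid_promptfoo_metric("promptfoo.correctness")
assert not is_valid_promptfoo_metric("correctness")        # missing "promptfoo." prefix
assert not is_valid_promptfoo_metric("promptfoo.unknown")  # not a judge prompt key
```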
@@ -21,9 +21,11 @@ async def cognify_instance(instance: dict):
 async def get_context_with_cognee(instance: dict) -> str:
     await cognify_instance(instance)

-    insights = await cognee.search(SearchType.INSIGHTS, query_text=instance["question"])
+    # TODO: Fix insights
+    # insights = await cognee.search(SearchType.INSIGHTS, query_text=instance["question"])
     summaries = await cognee.search(SearchType.SUMMARIES, query_text=instance["question"])
-    search_results = insights + summaries
+    # search_results = insights + summaries
+    search_results = summaries

     search_results_str = "\n".join([context_item["text"] for context_item in search_results])

@@ -31,7 +33,11 @@ async def get_context_with_cognee(instance: dict) -> str:


 async def get_context_with_simple_rag(instance: dict) -> str:
-    await cognify_instance(instance)
+    await cognee.prune.prune_data()
+    await cognee.prune.prune_system(metadata=True)
+
+    for title, sentences in instance["context"]:
+        await cognee.add("\n".join(sentences), dataset_name="QA")

     vector_engine = get_vector_engine()
     found_chunks = await vector_engine.search("document_chunk_text", instance["question"], limit=5)
18  evals/qa_eval_parameters.json  Normal file
@@ -0,0 +1,18 @@
+{
+    "dataset": [
+        "hotpotqa"
+    ],
+    "rag_option": [
+        "no_rag",
+        "cognee",
+        "simple_rag",
+        "brute_force"
+    ],
+    "num_samples": [
+        2
+    ],
+    "metric_names": [
+        "Correctness",
+        "Comprehensiveness"
+    ]
+}
60  evals/qa_eval_utils.py  Normal file
@@ -0,0 +1,60 @@
+import itertools
+import matplotlib.pyplot as plt
+from jsonschema import ValidationError, validate
+import pandas as pd
+from pathlib import Path
+
+paramset_json_schema = {
+    "type": "object",
+    "properties": {
+        "dataset": {
+            "type": "array",
+            "items": {"type": "string"},
+        },
+        "rag_option": {
+            "type": "array",
+            "items": {"type": "string"},
+        },
+        "num_samples": {
+            "type": "array",
+            "items": {"type": "integer", "minimum": 1},
+        },
+        "metric_names": {
+            "type": "array",
+            "items": {"type": "string"},
+        },
+    },
+    "required": ["dataset", "rag_option", "num_samples", "metric_names"],
+    "additionalProperties": False,
+}
+
+
+def save_table_as_image(df, image_path):
+    plt.figure(figsize=(10, 6))
+    plt.axis("tight")
+    plt.axis("off")
+    plt.table(cellText=df.values, colLabels=df.columns, rowLabels=df.index, loc="center")
+    plt.title(f"{df.index.name}")
+    plt.savefig(image_path, bbox_inches="tight")
+    plt.close()
+
+
+def save_results_as_image(results, out_path):
+    for dataset, num_samples_data in results.items():
+        for num_samples, table_data in num_samples_data.items():
+            df = pd.DataFrame.from_dict(table_data, orient="index")
+            df.index.name = f"Dataset: {dataset}, Num Samples: {num_samples}"
+            image_path = Path(out_path) / Path(f"table_{dataset}_{num_samples}.png")
+            save_table_as_image(df, image_path)
+
+
+def get_combinations(parameters):
+    try:
+        validate(instance=parameters, schema=paramset_json_schema)
+    except ValidationError as e:
+        raise ValidationError(f"Invalid parameter set: {e.message}")
+
+    params_for_combos = {k: v for k, v in parameters.items() if k != "metric_name"}
+    keys, values = zip(*params_for_combos.items())
+    combinations = [dict(zip(keys, combo)) for combo in itertools.product(*values)]
+    return combinations
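A sketch of what `get_combinations` yields for the shipped `evals/qa_eval_parameters.json`. Note that the exclusion filter uses the key `"metric_name"` while the schema key is `"metric_names"`, so as written the metric names would also enter the product; the output below shows the apparent intent of excluding them:

```python
import json

with open("evals/qa_eval_parameters.json") as f:
    combos = get_combinations(json.load(f))

# [{'dataset': 'hotpotqa', 'rag_option': 'no_rag', 'num_samples': 2},
#  {'dataset': 'hotpotqa', 'rag_option': 'cognee', 'num_samples': 2},
#  {'dataset': 'hotpotqa', 'rag_option': 'simple_rag', 'num_samples': 2},
#  {'dataset': 'hotpotqa', 'rag_option': 'brute_force', 'num_samples': 2}]
```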
@@ -7,10 +7,9 @@ from evals.deepeval_metrics import (
     f1_score_metric,
     em_score_metric,
 )
-from evals.promptfoo_metrics import PromptfooMetric
 from deepeval.metrics import AnswerRelevancyMetric
-from cognee.infrastructure.llm.prompts.llm_judge_prompts import llm_judge_prompts
+import deepeval.metrics
+from evals.promptfoo_metrics import is_valid_promptfoo_metric, PromptfooMetric

 native_deepeval_metrics = {"AnswerRelevancy": AnswerRelevancyMetric}

@@ -24,18 +23,10 @@ custom_deepeval_metrics = {
     "EM": em_score_metric,
 }

-promptfoo_metrics = {
-    "promptfoo.correctness": PromptfooMetric(llm_judge_prompts["correctness"]),
-    "promptfoo.comprehensiveness": PromptfooMetric(llm_judge_prompts["comprehensiveness"]),
-    "promptfoo.diversity": PromptfooMetric(llm_judge_prompts["diversity"]),
-    "promptfoo.empowerment": PromptfooMetric(llm_judge_prompts["empowerment"]),
-    "promptfoo.directness": PromptfooMetric(llm_judge_prompts["directness"]),
-}
-
-qa_metrics = native_deepeval_metrics | custom_deepeval_metrics | promptfoo_metrics
+qa_metrics = native_deepeval_metrics | custom_deepeval_metrics


-def get_metric(metric_name: str):
+def get_deepeval_metric(metric_name: str):
     if metric_name in qa_metrics:
         metric = qa_metrics[metric_name]
     else:
@@ -49,3 +40,27 @@ def get_metric(metric_name: str):
         metric = metric()

     return metric


+def get_metrics(metric_name_list: list[str]):
+    metrics = {
+        "deepeval_metrics": [],
+    }
+
+    promptfoo_metric_names = []
+
+    for metric_name in metric_name_list:
+        if (
+            (metric_name in native_deepeval_metrics)
+            or (metric_name in custom_deepeval_metrics)
+            or hasattr(deepeval.metrics, metric_name)
+        ):
+            metric = get_deepeval_metric(metric_name)
+            metrics["deepeval_metrics"].append(metric)
+        elif is_valid_promptfoo_metric(metric_name):
+            promptfoo_metric_names.append(metric_name)
+
+    if len(promptfoo_metric_names) > 0:
+        metrics["promptfoo_metrics"] = PromptfooMetric(promptfoo_metric_names)
+
+    return metrics
57  evals/run_qa_eval.py  Normal file
@@ -0,0 +1,57 @@
+import asyncio
+from evals.eval_on_hotpot import eval_on_QA_dataset
+from evals.qa_eval_utils import get_combinations, save_results_as_image
+import argparse
+from pathlib import Path
+import json
+
+
+async def run_evals_on_paramset(paramset: dict, out_path: str):
+    combinations = get_combinations(paramset)
+    json_path = Path(out_path) / Path("results.json")
+    results = {}
+    for params in combinations:
+        dataset = params["dataset"]
+        num_samples = params["num_samples"]
+        rag_option = params["rag_option"]
+
+        result = await eval_on_QA_dataset(
+            dataset,
+            rag_option,
+            num_samples,
+            paramset["metric_names"],
+        )
+
+        if dataset not in results:
+            results[dataset] = {}
+        if num_samples not in results[dataset]:
+            results[dataset][num_samples] = {}
+
+        results[dataset][num_samples][rag_option] = result
+
+    with open(json_path, "w") as file:
+        json.dump(results, file, indent=1)
+
+    save_results_as_image(results, out_path)
+
+    return results
+
+
+async def main():
+    parser = argparse.ArgumentParser()
+
+    parser.add_argument(
+        "--params_file", type=str, required=True, help="Which dataset to evaluate on"
+    )
+    parser.add_argument("--out_dir", type=str, help="Dir to save eval results")
+
+    args = parser.parse_args()
+
+    with open(args.params_file, "r") as file:
+        parameters = json.load(file)
+
+    await run_evals_on_paramset(parameters, args.out_dir)
+
+
+if __name__ == "__main__":
+    asyncio.run(main())
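A usage sketch for the new runner (run from the repository root; the output directory name is arbitrary):

```bash
python -m evals.run_qa_eval --params_file evals/qa_eval_parameters.json --out_dir eval_results
# writes eval_results/results.json plus one table_<dataset>_<num_samples>.png per dataset/sample-count pair
```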
1109  poetry.lock  generated
(File diff suppressed because it is too large.)
@@ -40,7 +40,6 @@ networkx = "^3.2.1"
 aiosqlite = "^0.20.0"
 pandas = "2.2.3"
 filetype = "^1.2.0"
-nltk = "^3.8.1"
 dlt = {extras = ["sqlalchemy"], version = "^1.4.1"}
 aiofiles = "^23.2.1"
 qdrant-client = {version = "^1.9.0", optional = true}
@@ -64,19 +63,20 @@ langfuse = "^2.32.0"
 pydantic-settings = "^2.2.1"
 anthropic = "^0.26.1"
 sentry-sdk = {extras = ["fastapi"], version = "^2.9.0"}
-fastapi-users = {version = "*", extras = ["sqlalchemy"]}
+fastapi-users = {version = "14.0.0", extras = ["sqlalchemy"]}
 alembic = "^1.13.3"
 asyncpg = {version = "0.30.0", optional = true}
 pgvector = {version = "^0.3.5", optional = true}
 psycopg2 = {version = "^2.9.10", optional = true}
-llama-index-core = {version = "^0.12.10.post1", optional = true}
+llama-index-core = {version = "^0.12.11", optional = true}
 deepeval = {version = "^2.0.1", optional = true}
 transformers = "^4.46.3"
 pymilvus = {version = "^2.5.0", optional = true}
-unstructured = { extras = ["csv", "doc", "docx", "epub", "md", "odt", "org", "ppt", "pptx", "rst", "rtf", "tsv", "xlsx"], version = "^0.16.10", optional = true }
+unstructured = { extras = ["csv", "doc", "docx", "epub", "md", "odt", "org", "ppt", "pptx", "rst", "rtf", "tsv", "xlsx"], version = "^0.16.13", optional = true }
 pre-commit = "^4.0.1"
 httpx = "0.27.0"
 bokeh="^3.6.2"
+nltk = "3.9.1"