Merge branch 'topoteretes:main' into main
This commit is contained in:
commit
879bcc1edf
94 changed files with 12620 additions and 9150 deletions
2
.github/actions/cognee_setup/action.yml
vendored
2
.github/actions/cognee_setup/action.yml
vendored
|
|
@ -24,4 +24,4 @@ runs:
|
|||
|
||||
- name: Install dependencies
|
||||
shell: bash
|
||||
run: poetry install --no-interaction -E api -E docs -E evals -E gemini -E codegraph -E ollama
|
||||
run: poetry install --no-interaction -E api -E docs -E evals -E gemini -E codegraph -E ollama -E dev
|
||||
|
|
|
|||
19
.github/workflows/python_version_tests.yml
vendored
19
.github/workflows/python_version_tests.yml
vendored
|
|
@ -58,8 +58,10 @@ jobs:
|
|||
python-version: ${{ inputs.python-version }}
|
||||
|
||||
- name: Run unit tests
|
||||
shell: bash
|
||||
run: poetry run pytest cognee/tests/unit/
|
||||
env:
|
||||
PYTHONUTF8: 1
|
||||
LLM_PROVIDER: openai
|
||||
LLM_MODEL: ${{ secrets.LLM_MODEL }}
|
||||
LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
|
||||
|
|
@ -74,10 +76,26 @@ jobs:
|
|||
|
||||
- name: Run integration tests
|
||||
if: ${{ !contains(matrix.os, 'windows') }}
|
||||
shell: bash
|
||||
run: poetry run pytest cognee/tests/integration/
|
||||
env:
|
||||
PYTHONUTF8: 1
|
||||
LLM_PROVIDER: openai
|
||||
LLM_MODEL: ${{ secrets.LLM_MODEL }}
|
||||
LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
|
||||
LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
|
||||
LLM_API_VERSION: ${{ secrets.LLM_API_VERSION }}
|
||||
|
||||
EMBEDDING_PROVIDER: openai
|
||||
EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }}
|
||||
EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }}
|
||||
EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }}
|
||||
EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
|
||||
|
||||
- name: Run default basic pipeline
|
||||
shell: bash
|
||||
env:
|
||||
PYTHONUTF8: 1
|
||||
GRAPHISTRY_USERNAME: ${{ secrets.GRAPHISTRY_USERNAME }}
|
||||
GRAPHISTRY_PASSWORD: ${{ secrets.GRAPHISTRY_PASSWORD }}
|
||||
|
||||
|
|
@ -95,6 +113,7 @@ jobs:
|
|||
run: poetry run python ./cognee/tests/test_library.py
|
||||
|
||||
- name: Build with Poetry
|
||||
shell: bash
|
||||
run: poetry build
|
||||
|
||||
- name: Install Package
|
||||
|
|
|
|||
57
.github/workflows/test_memgraph.yml
vendored
Normal file
57
.github/workflows/test_memgraph.yml
vendored
Normal file
|
|
@ -0,0 +1,57 @@
|
|||
name: test | memgraph
|
||||
|
||||
# on:
|
||||
# workflow_dispatch:
|
||||
# pull_request:
|
||||
# types: [labeled, synchronize]
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}
|
||||
cancel-in-progress: true
|
||||
|
||||
env:
|
||||
RUNTIME__LOG_LEVEL: ERROR
|
||||
|
||||
jobs:
|
||||
run_memgraph_integration_test:
|
||||
name: test
|
||||
runs-on: ubuntu-22.04
|
||||
|
||||
defaults:
|
||||
run:
|
||||
shell: bash
|
||||
|
||||
steps:
|
||||
- name: Check out
|
||||
uses: actions/checkout@master
|
||||
|
||||
- name: Setup Python
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: '3.10.x'
|
||||
|
||||
- name: Install Poetry
|
||||
uses: snok/install-poetry@v1.4.1
|
||||
with:
|
||||
virtualenvs-create: true
|
||||
virtualenvs-in-project: true
|
||||
installer-parallel: true
|
||||
|
||||
- name: Install dependencies
|
||||
run: poetry install -E memgraph --no-interaction
|
||||
|
||||
- name: Run default Memgraph
|
||||
env:
|
||||
ENV: 'dev'
|
||||
LLM_MODEL: ${{ secrets.LLM_MODEL }}
|
||||
LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
|
||||
LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
|
||||
LLM_API_VERSION: ${{ secrets.LLM_API_VERSION }}
|
||||
EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }}
|
||||
EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }}
|
||||
EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }}
|
||||
EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
|
||||
GRAPH_DATABASE_URL: ${{ secrets.MEMGRAPH_API_URL }}
|
||||
GRAPH_DATABASE_PASSWORD: ${{ secrets.MEMGRAPH_API_KEY }}
|
||||
GRAPH_DATABASE_USERNAME: " "
|
||||
run: poetry run python ./cognee/tests/test_memgraph.py
|
||||
|
|
@ -97,7 +97,7 @@ git push origin feature/your-feature-name
|
|||
|
||||
2. Create a Pull Request:
|
||||
- Go to the [**cognee** repository](https://github.com/topoteretes/cognee)
|
||||
- Click "Compare & Pull Request" and make sure to open PR against dev branch
|
||||
- Click "Compare & Pull Request"
|
||||
- Fill in the PR template with details about your changes
|
||||
|
||||
## 5. 📜 Developer Certificate of Origin (DCO)
|
||||
|
|
|
|||
82
Dockerfile
82
Dockerfile
|
|
@ -1,59 +1,61 @@
|
|||
FROM python:3.11-slim
|
||||
# Use a Python image with uv pre-installed
|
||||
FROM ghcr.io/astral-sh/uv:python3.12-bookworm-slim AS uv
|
||||
|
||||
# Define Poetry extras to install
|
||||
ARG POETRY_EXTRAS="\
|
||||
# API \
|
||||
api \
|
||||
# Storage & Databases \
|
||||
postgres weaviate qdrant neo4j falkordb milvus kuzu chromadb \
|
||||
# Notebooks & Interactive Environments \
|
||||
notebook \
|
||||
# LLM & AI Frameworks \
|
||||
langchain llama-index gemini huggingface ollama mistral groq anthropic \
|
||||
# Evaluation & Monitoring \
|
||||
deepeval evals posthog \
|
||||
# Graph Processing & Code Analysis \
|
||||
codegraph graphiti \
|
||||
# Document Processing \
|
||||
docs"
|
||||
# Install the project into `/app`
|
||||
WORKDIR /app
|
||||
|
||||
# Enable bytecode compilation
|
||||
# ENV UV_COMPILE_BYTECODE=1
|
||||
|
||||
# Copy from the cache instead of linking since it's a mounted volume
|
||||
ENV UV_LINK_MODE=copy
|
||||
|
||||
# Set build argument
|
||||
ARG DEBUG
|
||||
|
||||
# Set environment variable based on the build argument
|
||||
ENV DEBUG=${DEBUG}
|
||||
ENV PIP_NO_CACHE_DIR=true
|
||||
ENV PATH="${PATH}:/root/.poetry/bin"
|
||||
|
||||
RUN apt-get update
|
||||
# Install system dependencies
|
||||
RUN apt-get update && apt-get install -y \
|
||||
gcc \
|
||||
libpq-dev \
|
||||
git \
|
||||
curl \
|
||||
clang \
|
||||
build-essential \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
RUN apt-get install -y \
|
||||
gcc \
|
||||
libpq-dev
|
||||
# Copy pyproject.toml and lockfile first for better caching
|
||||
COPY README.md pyproject.toml uv.lock entrypoint.sh ./
|
||||
|
||||
WORKDIR /app
|
||||
COPY pyproject.toml poetry.lock /app/
|
||||
|
||||
RUN pip install poetry
|
||||
|
||||
# Don't create virtualenv since Docker is already isolated
|
||||
RUN poetry config virtualenvs.create false
|
||||
|
||||
# Install the dependencies using the defined extras
|
||||
RUN poetry install --extras "${POETRY_EXTRAS}" --no-root --without dev
|
||||
|
||||
# Set the PYTHONPATH environment variable to include the /app directory
|
||||
ENV PYTHONPATH=/app
|
||||
|
||||
COPY cognee/ /app/cognee
|
||||
# Install the project's dependencies using the lockfile and settings
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
uv sync --extra debug --extra api --extra postgres --extra weaviate --extra qdrant --extra neo4j --extra kuzu --extra llama-index --extra gemini --extra ollama --extra mistral --extra groq --extra anthropic --frozen --no-install-project --no-dev --no-editable
|
||||
|
||||
# Copy Alembic configuration
|
||||
COPY alembic.ini /app/alembic.ini
|
||||
COPY alembic/ /app/alembic
|
||||
|
||||
COPY entrypoint.sh /app/entrypoint.sh
|
||||
# Then, add the rest of the project source code and install it
|
||||
# Installing separately from its dependencies allows optimal layer caching
|
||||
COPY ./cognee /app/cognee
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
uv sync --extra debug --extra api --extra postgres --extra weaviate --extra qdrant --extra neo4j --extra kuzu --extra llama-index --extra gemini --extra ollama --extra mistral --extra groq --extra anthropic --frozen --no-dev --no-editable
|
||||
|
||||
FROM python:3.12-slim-bookworm
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
COPY --from=uv /app /app
|
||||
# COPY --from=uv /app/.venv /app/.venv
|
||||
# COPY --from=uv /root/.local /root/.local
|
||||
|
||||
RUN chmod +x /app/entrypoint.sh
|
||||
|
||||
RUN sed -i 's/\r$//' /app/entrypoint.sh
|
||||
# Place executables in the environment at the front of the path
|
||||
ENV PATH="/app/.venv/bin:$PATH"
|
||||
|
||||
ENV PYTHONPATH=/app
|
||||
|
||||
ENTRYPOINT ["/app/entrypoint.sh"]
|
||||
|
|
|
|||
11
README.md
11
README.md
|
|
@ -64,7 +64,6 @@ More on [use-cases](https://docs.cognee.ai/use-cases) and [evals](https://github
|
|||
Get started quickly with a Google Colab <a href="https://colab.research.google.com/drive/1jHbWVypDgCLwjE71GSXhRL3YxYhCZzG1?usp=sharing">notebook</a> or <a href="https://github.com/topoteretes/cognee-starter">starter repo</a>
|
||||
|
||||
|
||||
|
||||
## Contributing
|
||||
Your contributions are at the core of making this a true open source project. Any contributions you make are **greatly appreciated**. See [`CONTRIBUTING.md`](CONTRIBUTING.md) for more information.
|
||||
|
||||
|
|
@ -130,12 +129,14 @@ Example output:
|
|||
Natural Language Processing (NLP) is a cross-disciplinary and interdisciplinary field that involves computer science and information retrieval. It focuses on the interaction between computers and human language, enabling machines to understand and process natural language.
|
||||
|
||||
```
|
||||
Graph visualization:
|
||||
<a href="https://rawcdn.githack.com/topoteretes/cognee/refs/heads/main/assets/graph_visualization.html"><img src="assets/graph_visualization.png" width="100%" alt="Graph Visualization"></a>
|
||||
Open in [browser](https://rawcdn.githack.com/topoteretes/cognee/refs/heads/main/assets/graph_visualization.html).
|
||||
|
||||
For more advanced usage, have a look at our <a href="https://docs.cognee.ai"> documentation</a>.
|
||||
### cognee UI
|
||||
|
||||
You can also cognify your files and query using cognee UI.
|
||||
|
||||
<img src="assets/cognee-ui-2.webp" width="100%" alt="Cognee UI 2"></a>
|
||||
|
||||
Try cognee UI out locally [here](https://docs.cognee.ai/how-to-guides/cognee-ui).
|
||||
|
||||
## Understand our architecture
|
||||
|
||||
|
|
|
|||
BIN
assets/cognee-ui-1.webp
Normal file
BIN
assets/cognee-ui-1.webp
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 878 KiB |
BIN
assets/cognee-ui-2.webp
Normal file
BIN
assets/cognee-ui-2.webp
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 936 KiB |
|
|
@ -1,128 +0,0 @@
|
|||
|
||||
<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<meta charset="utf-8">
|
||||
<script src="https://d3js.org/d3.v5.min.js"></script>
|
||||
<style>
|
||||
body, html { margin: 0; padding: 0; width: 100%; height: 100%; overflow: hidden; background: linear-gradient(90deg, #101010, #1a1a2e); color: white; font-family: 'Inter', sans-serif; }
|
||||
|
||||
svg { width: 100vw; height: 100vh; display: block; }
|
||||
.links line { stroke: rgba(255, 255, 255, 0.4); stroke-width: 2px; }
|
||||
.nodes circle { stroke: white; stroke-width: 0.5px; filter: drop-shadow(0 0 5px rgba(255,255,255,0.3)); }
|
||||
.node-label { font-size: 5px; font-weight: bold; fill: white; text-anchor: middle; dominant-baseline: middle; font-family: 'Inter', sans-serif; pointer-events: none; }
|
||||
.edge-label { font-size: 3px; fill: rgba(255, 255, 255, 0.7); text-anchor: middle; dominant-baseline: middle; font-family: 'Inter', sans-serif; pointer-events: none; }
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<svg></svg>
|
||||
<script>
|
||||
var nodes = [{"version": 1, "topological_rank": 0, "metadata": {"index_fields": ["text"]}, "type": "DocumentChunk", "text": "Natural language processing (NLP) is an interdisciplinary subfield of computer science and information retrieval.", "chunk_size": 34, "chunk_index": 0, "cut_type": "sentence_end", "id": "b5b7b6b3-3bb7-5efd-a975-5a01e0d40220", "color": "#801212", "name": "b5b7b6b3-3bb7-5efd-a975-5a01e0d40220"}, {"version": 1, "topological_rank": 0, "metadata": {"index_fields": ["name"]}, "type": "Entity", "name": "natural language processing", "description": "An interdisciplinary subfield of computer science and information retrieval.", "ontology_valid": false, "id": "bc338a39-64d6-549a-acec-da60846dd90d", "color": "#f47710"}, {"version": 1, "topological_rank": 0, "metadata": {"index_fields": ["name"]}, "type": "EntityType", "name": "concept", "description": "concept", "ontology_valid": false, "id": "dd9713b7-dc20-5101-aad0-1c4216811147", "color": "#6510f4"}, {"version": 1, "topological_rank": 0, "metadata": {"index_fields": ["name"]}, "type": "Entity", "name": "information retrieval", "description": "The activity of obtaining information system resources that are relevant to an information need.", "ontology_valid": false, "id": "02bdab9a-0981-518c-a0d4-1684e0329447", "color": "#f47710"}, {"version": 1, "topological_rank": 0, "metadata": {"index_fields": ["name"]}, "type": "EntityType", "name": "field", "description": "field", "ontology_valid": false, "id": "0198571b-3e94-50ea-8b9f-19e3a31080c0", "color": "#6510f4"}, {"version": 1, "topological_rank": 0, "metadata": {"index_fields": ["name"]}, "type": "Entity", "name": "computer science", "description": "The study of computers and computational systems.", "ontology_valid": false, "id": "6218dbab-eb6a-5759-a864-b3419755ffe0", "color": "#f47710"}, {"version": 1, "topological_rank": 0, "metadata": {"index_fields": ["name"]}, "type": "TextDocument", "name": "text_46d2fce36f0f7b6ebc0575e353fdba5c", 
"raw_data_location": "/Users/handekafkas/Documents/local-code/new-cognee/cognee/cognee/.data_storage/data/text_46d2fce36f0f7b6ebc0575e353fdba5c.txt", "external_metadata": "{}", "mime_type": "text/plain", "id": "c07949fe-5a9f-53b9-ac90-5cb48a8a4303", "color": "#D3D3D3"}, {"version": 1, "topological_rank": 0, "metadata": {"index_fields": ["text"]}, "type": "TextSummary", "text": "Natural language processing (NLP) is a cross-disciplinary area of computer science and information extraction.", "id": "9da41e72-8150-5055-9217-eea49d1bc447", "color": "#1077f4", "name": "9da41e72-8150-5055-9217-eea49d1bc447"}];
|
||||
var links = [{"source": "b5b7b6b3-3bb7-5efd-a975-5a01e0d40220", "target": "bc338a39-64d6-549a-acec-da60846dd90d", "relation": "contains"}, {"source": "b5b7b6b3-3bb7-5efd-a975-5a01e0d40220", "target": "02bdab9a-0981-518c-a0d4-1684e0329447", "relation": "contains"}, {"source": "b5b7b6b3-3bb7-5efd-a975-5a01e0d40220", "target": "6218dbab-eb6a-5759-a864-b3419755ffe0", "relation": "contains"}, {"source": "b5b7b6b3-3bb7-5efd-a975-5a01e0d40220", "target": "c07949fe-5a9f-53b9-ac90-5cb48a8a4303", "relation": "is_part_of"}, {"source": "bc338a39-64d6-549a-acec-da60846dd90d", "target": "dd9713b7-dc20-5101-aad0-1c4216811147", "relation": "is_a"}, {"source": "bc338a39-64d6-549a-acec-da60846dd90d", "target": "6218dbab-eb6a-5759-a864-b3419755ffe0", "relation": "is_a_subfield_of"}, {"source": "bc338a39-64d6-549a-acec-da60846dd90d", "target": "02bdab9a-0981-518c-a0d4-1684e0329447", "relation": "is_a_subfield_of"}, {"source": "02bdab9a-0981-518c-a0d4-1684e0329447", "target": "0198571b-3e94-50ea-8b9f-19e3a31080c0", "relation": "is_a"}, {"source": "6218dbab-eb6a-5759-a864-b3419755ffe0", "target": "0198571b-3e94-50ea-8b9f-19e3a31080c0", "relation": "is_a"}, {"source": "9da41e72-8150-5055-9217-eea49d1bc447", "target": "b5b7b6b3-3bb7-5efd-a975-5a01e0d40220", "relation": "made_from"}];
|
||||
|
||||
var svg = d3.select("svg"),
|
||||
width = window.innerWidth,
|
||||
height = window.innerHeight;
|
||||
|
||||
var container = svg.append("g");
|
||||
|
||||
var simulation = d3.forceSimulation(nodes)
|
||||
.force("link", d3.forceLink(links).id(d => d.id).strength(0.1))
|
||||
.force("charge", d3.forceManyBody().strength(-275))
|
||||
.force("center", d3.forceCenter(width / 2, height / 2))
|
||||
.force("x", d3.forceX().strength(0.1).x(width / 2))
|
||||
.force("y", d3.forceY().strength(0.1).y(height / 2));
|
||||
|
||||
var link = container.append("g")
|
||||
.attr("class", "links")
|
||||
.selectAll("line")
|
||||
.data(links)
|
||||
.enter().append("line")
|
||||
.attr("stroke-width", 2);
|
||||
|
||||
var edgeLabels = container.append("g")
|
||||
.attr("class", "edge-labels")
|
||||
.selectAll("text")
|
||||
.data(links)
|
||||
.enter().append("text")
|
||||
.attr("class", "edge-label")
|
||||
.text(d => d.relation);
|
||||
|
||||
var nodeGroup = container.append("g")
|
||||
.attr("class", "nodes")
|
||||
.selectAll("g")
|
||||
.data(nodes)
|
||||
.enter().append("g");
|
||||
|
||||
var node = nodeGroup.append("circle")
|
||||
.attr("r", 13)
|
||||
.attr("fill", d => d.color)
|
||||
.call(d3.drag()
|
||||
.on("start", dragstarted)
|
||||
.on("drag", dragged)
|
||||
.on("end", dragended));
|
||||
|
||||
nodeGroup.append("text")
|
||||
.attr("class", "node-label")
|
||||
.attr("dy", 4)
|
||||
.attr("text-anchor", "middle")
|
||||
.text(d => d.name);
|
||||
|
||||
node.append("title").text(d => JSON.stringify(d));
|
||||
|
||||
simulation.on("tick", function() {
|
||||
link.attr("x1", d => d.source.x)
|
||||
.attr("y1", d => d.source.y)
|
||||
.attr("x2", d => d.target.x)
|
||||
.attr("y2", d => d.target.y);
|
||||
|
||||
edgeLabels
|
||||
.attr("x", d => (d.source.x + d.target.x) / 2)
|
||||
.attr("y", d => (d.source.y + d.target.y) / 2 - 5);
|
||||
|
||||
node.attr("cx", d => d.x)
|
||||
.attr("cy", d => d.y);
|
||||
|
||||
nodeGroup.select("text")
|
||||
.attr("x", d => d.x)
|
||||
.attr("y", d => d.y)
|
||||
.attr("dy", 4)
|
||||
.attr("text-anchor", "middle");
|
||||
});
|
||||
|
||||
svg.call(d3.zoom().on("zoom", function() {
|
||||
container.attr("transform", d3.event.transform);
|
||||
}));
|
||||
|
||||
function dragstarted(d) {
|
||||
if (!d3.event.active) simulation.alphaTarget(0.3).restart();
|
||||
d.fx = d.x;
|
||||
d.fy = d.y;
|
||||
}
|
||||
|
||||
function dragged(d) {
|
||||
d.fx = d3.event.x;
|
||||
d.fy = d3.event.y;
|
||||
}
|
||||
|
||||
function dragended(d) {
|
||||
if (!d3.event.active) simulation.alphaTarget(0);
|
||||
d.fx = null;
|
||||
d.fy = null;
|
||||
}
|
||||
|
||||
window.addEventListener("resize", function() {
|
||||
width = window.innerWidth;
|
||||
height = window.innerHeight;
|
||||
svg.attr("width", width).attr("height", height);
|
||||
simulation.force("center", d3.forceCenter(width / 2, height / 2));
|
||||
simulation.alpha(1).restart();
|
||||
});
|
||||
</script>
|
||||
|
||||
<svg style="position: fixed; bottom: 10px; right: 10px; width: 150px; height: auto; z-index: 9999;" viewBox="0 0 158 44" fill="none" xmlns="http://www.w3.org/2000/svg">
|
||||
<path fill-rule="evenodd" clip-rule="evenodd" d="M11.7496 4.92654C7.83308 4.92654 4.8585 7.94279 4.8585 11.3612V14.9304C4.8585 18.3488 7.83308 21.3651 11.7496 21.3651C13.6831 21.3651 15.0217 20.8121 16.9551 19.3543C18.0458 18.5499 19.5331 18.8013 20.3263 19.9072C21.1195 21.0132 20.8717 22.5213 19.781 23.3257C17.3518 25.0851 15.0217 26.2414 11.7 26.2414C5.35425 26.2414 0 21.2646 0 14.9304V11.3612C0 4.97681 5.35425 0.0502739 11.7 0.0502739C15.0217 0.0502739 17.3518 1.2065 19.781 2.96598C20.8717 3.77032 21.1195 5.27843 20.3263 6.38439C19.5331 7.49035 18.0458 7.69144 16.9551 6.93737C15.0217 5.52979 13.6831 4.92654 11.7496 4.92654ZM35.5463 4.92654C31.7289 4.92654 28.6552 8.04333 28.6552 11.8639V14.478C28.6552 18.2986 31.7289 21.4154 35.5463 21.4154C39.3141 21.4154 42.3878 18.2986 42.3878 14.478V11.8639C42.3878 8.04333 39.3141 4.92654 35.5463 4.92654ZM23.7967 11.8639C23.7967 5.32871 29.0518 0 35.5463 0C42.0408 0 47.2463 5.32871 47.2463 11.8639V14.478C47.2463 21.0132 42.0408 26.3419 35.5463 26.3419C29.0518 26.3419 23.7967 21.0635 23.7967 14.478V11.8639ZM63.3091 5.07736C59.4917 5.07736 56.418 8.19415 56.418 12.0147C56.418 15.8353 59.4917 18.9521 63.3091 18.9521C67.1265 18.9521 70.1506 15.8856 70.1506 12.0147C70.1506 8.14388 67.0769 5.07736 63.3091 5.07736ZM51.5595 11.9645C51.5595 5.42925 56.8146 0.150814 63.3091 0.150814C66.0854 0.150814 68.5642 1.10596 70.5968 2.71463L72.4311 0.904876C73.3731 -0.0502693 74.9099 -0.0502693 75.8519 0.904876C76.7938 1.86002 76.7938 3.41841 75.8519 4.37356L73.7201 6.53521C74.5629 8.19414 75.0587 10.0542 75.0587 12.0147C75.0587 18.4997 69.8532 23.8284 63.3587 23.8284C63.3091 23.8284 63.2099 23.8284 63.1603 23.8284H58.0044C57.1616 23.8284 56.4675 24.5322 56.4675 25.3868C56.4675 26.2414 57.1616 26.9452 58.0044 26.9452H64.6476H66.7794C68.5146 26.9452 70.3489 27.4479 71.7866 28.6041C73.2739 29.8106 74.2159 31.5701 74.4142 33.7317C74.7116 37.6026 72.0345 40.2166 69.8532 41.0713L63.8048 43.7859C62.5654 44.3389 61.1277 43.7859 60.6319 42.5291C60.0866 
41.2723 60.6319 39.8648 61.8714 39.3118L68.0188 36.5972C68.0684 36.5972 68.118 36.5469 68.1675 36.5469C68.4154 36.4463 68.8616 36.1447 69.2087 35.6923C69.5061 35.2398 69.7044 34.7371 69.6548 34.1339C69.6053 33.229 69.2582 32.7263 68.8616 32.4247C68.4154 32.0728 67.7214 31.8214 66.8786 31.8214H58.2027C58.1531 31.8214 58.1531 31.8214 58.1035 31.8214H58.054C54.534 31.8214 51.6586 28.956 51.6586 25.3868C51.6586 23.0743 52.8485 21.0635 54.6828 19.9072C52.6997 17.7959 51.5595 15.031 51.5595 11.9645ZM90.8736 5.07736C87.0562 5.07736 83.9824 8.19415 83.9824 12.0147V23.9289C83.9824 25.2862 82.8917 26.3922 81.5532 26.3922C80.2146 26.3922 79.1239 25.2862 79.1239 23.9289V11.9645C79.1239 5.42925 84.379 0.150814 90.824 0.150814C97.2689 0.150814 102.524 5.42925 102.524 11.9645V23.8786C102.524 25.2359 101.433 26.3419 100.095 26.3419C98.7562 26.3419 97.6655 25.2359 97.6655 23.8786V11.9645C97.7647 8.14387 94.6414 5.07736 90.8736 5.07736ZM119.43 5.07736C115.513 5.07736 112.39 8.24441 112.39 12.065V14.5785C112.39 18.4494 115.513 21.5662 119.43 21.5662C120.768 21.5662 122.057 21.164 123.098 20.5105C124.238 19.8067 125.726 20.1586 126.42 21.3148C127.114 22.4711 126.767 23.9792 125.627 24.683C123.842 25.7889 121.71 26.4425 119.43 26.4425C112.885 26.4425 107.581 21.1137 107.581 14.5785V12.065C107.581 5.47952 112.935 0.201088 119.43 0.201088C125.032 0.201088 129.692 4.07194 130.931 9.3001L131.427 11.3612L121.115 15.584C119.876 16.0867 118.488 15.4834 117.942 14.2266C117.447 12.9699 118.041 11.5623 119.281 11.0596L125.478 8.54604C124.238 6.43466 122.008 5.07736 119.43 5.07736ZM146.003 5.07736C142.086 5.07736 138.963 8.24441 138.963 12.065V14.5785C138.963 18.4494 142.086 21.5662 146.003 21.5662C147.341 21.5662 148.63 21.164 149.671 20.5105C150.217 20.1586 150.663 19.8067 151.109 19.304C152.001 18.2986 153.538 18.2483 154.53 19.2034C155.521 20.1083 155.571 21.6667 154.629 22.6721C153.935 23.4262 153.092 24.13 152.2 24.683C150.415 25.7889 148.283 26.4425 146.003 26.4425C139.458 26.4425 134.154 
21.1137 134.154 14.5785V12.065C134.154 5.47952 139.508 0.201088 146.003 0.201088C151.605 0.201088 156.265 4.07194 157.504 9.3001L158 11.3612L147.688 15.584C146.449 16.0867 145.061 15.4834 144.515 14.2266C144.019 12.9699 144.614 11.5623 145.854 11.0596L152.051 8.54604C150.762 6.43466 148.58 5.07736 146.003 5.07736Z" fill="white"/>
|
||||
</svg>
|
||||
</body>
|
||||
</html>
|
||||
|
||||
Binary file not shown.
|
Before Width: | Height: | Size: 365 KiB |
|
|
@ -43,7 +43,7 @@ export default function Home() {
|
|||
const onDataAdd = useCallback((dataset: { id: string }, files: File[]) => {
|
||||
return addData(dataset, files)
|
||||
.then(() => {
|
||||
showNotification("Data added successfully.", 5000);
|
||||
showNotification("Data added successfully. Please run \"Cognify\" when ready.", 5000);
|
||||
openDatasetData(dataset);
|
||||
});
|
||||
}, [showNotification])
|
||||
|
|
@ -60,6 +60,14 @@ export default function Home() {
|
|||
});
|
||||
}, [showNotification]);
|
||||
|
||||
const onCognify = useCallback(() => {
|
||||
const dataset = datasets.find((dataset) => dataset.id === selectedDataset);
|
||||
return onDatasetCognify({
|
||||
id: dataset!.id,
|
||||
name: dataset!.name,
|
||||
});
|
||||
}, [datasets, onDatasetCognify, selectedDataset]);
|
||||
|
||||
const {
|
||||
value: isSettingsModalOpen,
|
||||
setTrue: openSettingsModal,
|
||||
|
|
@ -95,6 +103,7 @@ export default function Home() {
|
|||
datasetId={selectedDataset}
|
||||
onClose={closeDatasetData}
|
||||
onDataAdd={onDataAdd}
|
||||
onCognify={onCognify}
|
||||
/>
|
||||
</div>
|
||||
)}
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@ import cognifyDataset from '@/modules/datasets/cognifyDataset';
|
|||
|
||||
interface ConfigStepProps {
|
||||
onNext: () => void;
|
||||
dataset: { id: string }
|
||||
dataset: { name: string }
|
||||
}
|
||||
|
||||
export default function CognifyStep({ onNext, dataset }: ConfigStepProps) {
|
||||
|
|
|
|||
|
|
@ -2,13 +2,13 @@ import { Explorer } from '@/ui/Partials';
|
|||
import { Spacer } from 'ohmy-ui';
|
||||
|
||||
interface ExploreStepProps {
|
||||
dataset: { id: string };
|
||||
dataset: { name: string };
|
||||
}
|
||||
|
||||
export default function ExploreStep({ dataset }: ExploreStepProps) {
|
||||
return (
|
||||
<Spacer horizontal="3">
|
||||
<Explorer dataset={dataset!} />
|
||||
<Explorer dataset={dataset} />
|
||||
</Spacer>
|
||||
)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -29,7 +29,7 @@ export default function WizardPage({
|
|||
setFalse: closeSettingsModal,
|
||||
} = useBoolean(false);
|
||||
|
||||
const dataset = { id: 'main' };
|
||||
const dataset = { name: 'main' };
|
||||
|
||||
return (
|
||||
<main className={styles.main}>
|
||||
|
|
|
|||
|
|
@ -1,13 +1,13 @@
|
|||
import { fetch } from '@/utils';
|
||||
|
||||
export default function cognifyDataset(dataset: { id: string, name: string }) {
|
||||
export default function cognifyDataset(dataset: { id?: string, name?: string }) {
|
||||
return fetch('/v1/cognify', {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
},
|
||||
body: JSON.stringify({
|
||||
datasets: [dataset.id],
|
||||
datasets: [dataset.id || dataset.name],
|
||||
}),
|
||||
}).then((response) => response.json());
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
import { fetch } from '@/utils';
|
||||
|
||||
export default function getExplorationGraphUrl(dataset: { id: string }) {
|
||||
export default function getExplorationGraphUrl(dataset: { name: string }) {
|
||||
return fetch('/v1/visualize')
|
||||
.then(async (response) => {
|
||||
if (response.status !== 200) {
|
||||
|
|
|
|||
|
|
@ -6,6 +6,8 @@ import {
|
|||
Text,
|
||||
UploadInput,
|
||||
CloseIcon,
|
||||
CTAButton,
|
||||
useBoolean,
|
||||
} from "ohmy-ui";
|
||||
import { fetch } from '@/utils';
|
||||
import RawDataPreview from './RawDataPreview';
|
||||
|
|
@ -28,9 +30,10 @@ interface DataViewProps {
|
|||
datasetId: string;
|
||||
onClose: () => void;
|
||||
onDataAdd: (dataset: DatasetLike, files: File[]) => void;
|
||||
onCognify: () => Promise<any>;
|
||||
}
|
||||
|
||||
export default function DataView({ datasetId, data, onClose, onDataAdd }: DataViewProps) {
|
||||
export default function DataView({ datasetId, data, onClose, onDataAdd, onCognify }: DataViewProps) {
|
||||
// const handleDataDelete = () => {};
|
||||
const [rawData, setRawData] = useState<ArrayBuffer | null>(null);
|
||||
const [selectedData, setSelectedData] = useState<Data | null>(null);
|
||||
|
|
@ -52,7 +55,19 @@ export default function DataView({ datasetId, data, onClose, onDataAdd }: DataVi
|
|||
|
||||
const handleDataAdd = (files: File[]) => {
|
||||
onDataAdd({ id: datasetId }, files);
|
||||
}
|
||||
};
|
||||
|
||||
const {
|
||||
value: isCognifyButtonDisabled,
|
||||
setTrue: disableCognifyButton,
|
||||
setFalse: enableCognifyButton,
|
||||
} = useBoolean(false);
|
||||
|
||||
const handleCognify = () => {
|
||||
disableCognifyButton();
|
||||
onCognify()
|
||||
.finally(() => enableCognifyButton());
|
||||
};
|
||||
|
||||
return (
|
||||
<Stack orientation="vertical" gap="4">
|
||||
|
|
@ -62,6 +77,11 @@ export default function DataView({ datasetId, data, onClose, onDataAdd }: DataVi
|
|||
<Text>Add data</Text>
|
||||
</UploadInput>
|
||||
</div>
|
||||
<div>
|
||||
<CTAButton disabled={isCognifyButtonDisabled} onClick={handleCognify}>
|
||||
<Text>Cognify</Text>
|
||||
</CTAButton>
|
||||
</div>
|
||||
<GhostButton hugContent onClick={onClose}>
|
||||
<CloseIcon />
|
||||
</GhostButton>
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@ import { getExplorationGraphUrl } from '@/modules/exploration';
|
|||
import styles from './Explorer.module.css';
|
||||
|
||||
interface ExplorerProps {
|
||||
dataset: { id: string };
|
||||
dataset: { name: string };
|
||||
className?: string;
|
||||
style?: React.CSSProperties;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -28,9 +28,6 @@ export default function SearchView() {
|
|||
}, []);
|
||||
|
||||
const searchOptions = [{
|
||||
value: 'INSIGHTS',
|
||||
label: 'Query insights from documents',
|
||||
}, {
|
||||
value: 'GRAPH_COMPLETION',
|
||||
label: 'Completion using Cognee\'s graph based memory',
|
||||
}, {
|
||||
|
|
@ -81,6 +78,8 @@ export default function SearchView() {
|
|||
|
||||
scrollToBottom();
|
||||
|
||||
setInputValue('');
|
||||
|
||||
const searchTypeValue = searchType.value;
|
||||
|
||||
fetch('/v1/search', {
|
||||
|
|
@ -103,10 +102,12 @@ export default function SearchView() {
|
|||
text: convertToSearchTypeOutput(systemMessage, searchTypeValue),
|
||||
},
|
||||
]);
|
||||
setInputValue('');
|
||||
|
||||
scrollToBottom();
|
||||
})
|
||||
.catch(() => {
|
||||
setInputValue(inputValue);
|
||||
});
|
||||
}, [inputValue, scrollToBottom, searchType.value]);
|
||||
|
||||
const {
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
import handleServerErrors from './handleServerErrors';
|
||||
|
||||
export default function fetch(url: string, options: RequestInit = {}): Promise<Response> {
|
||||
return global.fetch('http://127.0.0.1:8000/api' + url, {
|
||||
return global.fetch('http://localhost:8000/api' + url, {
|
||||
...options,
|
||||
headers: {
|
||||
...options.headers,
|
||||
|
|
|
|||
|
|
@ -1,12 +1,15 @@
|
|||
[project]
|
||||
name = "cognee-mcp"
|
||||
version = "0.2.3"
|
||||
version = "0.3.0"
|
||||
description = "A MCP server project"
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.10"
|
||||
|
||||
dependencies = [
|
||||
"cognee[postgres,codegraph,gemini,huggingface]==0.1.39",
|
||||
# For local cognee repo usage remove comment bellow and add absolute path to cognee
|
||||
#"cognee[postgres,codegraph,gemini,huggingface] @ file:/Users/<username>/Desktop/cognee",
|
||||
"cognee[postgres,codegraph,gemini,huggingface]==0.1.40",
|
||||
"fastmcp>=1.0",
|
||||
"mcp==1.5.0",
|
||||
"uv>=0.6.3",
|
||||
]
|
||||
|
|
@ -27,5 +30,8 @@ dev = [
|
|||
"debugpy>=1.8.12",
|
||||
]
|
||||
|
||||
[tool.hatch.metadata]
|
||||
allow-direct-references = true
|
||||
|
||||
[project.scripts]
|
||||
cognee = "src:main"
|
||||
|
|
|
|||
|
|
@ -1,253 +1,164 @@
|
|||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
import argparse
|
||||
import cognee
|
||||
import asyncio
|
||||
from cognee.shared.logging_utils import get_logger, get_log_file_location
|
||||
import importlib.util
|
||||
from contextlib import redirect_stdout
|
||||
|
||||
# from PIL import Image as PILImage
|
||||
import mcp.types as types
|
||||
from mcp.server import Server, NotificationOptions
|
||||
from mcp.server.models import InitializationOptions
|
||||
from mcp.server import FastMCP
|
||||
from cognee.modules.pipelines.operations.get_pipeline_status import get_pipeline_status
|
||||
from cognee.modules.data.methods.get_unique_dataset_id import get_unique_dataset_id
|
||||
from cognee.modules.users.methods import get_default_user
|
||||
from cognee.api.v1.cognify.code_graph_pipeline import run_code_graph_pipeline
|
||||
from cognee.modules.search.types import SearchType
|
||||
from cognee.shared.data_models import KnowledgeGraph
|
||||
from cognee.modules.storage.utils import JSONEncoder
|
||||
|
||||
mcp = Server("cognee")
|
||||
mcp = FastMCP("Cognee")
|
||||
|
||||
logger = get_logger()
|
||||
log_file = get_log_file_location()
|
||||
|
||||
|
||||
@mcp.list_tools()
|
||||
async def list_tools() -> list[types.Tool]:
|
||||
return [
|
||||
types.Tool(
|
||||
name="cognify",
|
||||
description="Cognifies text into knowledge graph",
|
||||
inputSchema={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"text": {
|
||||
"type": "string",
|
||||
"description": "The text to cognify",
|
||||
},
|
||||
"graph_model_file": {
|
||||
"type": "string",
|
||||
"description": "The path to the graph model file (Optional)",
|
||||
},
|
||||
"graph_model_name": {
|
||||
"type": "string",
|
||||
"description": "The name of the graph model (Optional)",
|
||||
},
|
||||
},
|
||||
"required": ["text"],
|
||||
},
|
||||
),
|
||||
types.Tool(
|
||||
name="codify",
|
||||
description="Transforms codebase into knowledge graph",
|
||||
inputSchema={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"repo_path": {
|
||||
"type": "string",
|
||||
},
|
||||
},
|
||||
"required": ["repo_path"],
|
||||
},
|
||||
),
|
||||
types.Tool(
|
||||
name="search",
|
||||
description="Searches for information in knowledge graph",
|
||||
inputSchema={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"search_query": {
|
||||
"type": "string",
|
||||
"description": "The query to search for",
|
||||
},
|
||||
"search_type": {
|
||||
"type": "string",
|
||||
"description": "The type of search to perform (e.g., INSIGHTS, CODE)",
|
||||
},
|
||||
},
|
||||
"required": ["search_query"],
|
||||
},
|
||||
),
|
||||
types.Tool(
|
||||
name="prune",
|
||||
description="Prunes knowledge graph",
|
||||
inputSchema={
|
||||
"type": "object",
|
||||
"properties": {},
|
||||
},
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
@mcp.call_tool()
|
||||
async def call_tools(name: str, arguments: dict) -> list[types.TextContent]:
|
||||
try:
|
||||
@mcp.tool()
|
||||
async def cognify(text: str, graph_model_file: str = None, graph_model_name: str = None) -> list:
|
||||
async def cognify_task(
|
||||
text: str, graph_model_file: str = None, graph_model_name: str = None
|
||||
) -> str:
|
||||
"""Build knowledge graph from the input text"""
|
||||
# NOTE: MCP uses stdout to communicate, we must redirect all output
|
||||
# going to stdout ( like the print function ) to stderr.
|
||||
with redirect_stdout(sys.stderr):
|
||||
log_file = get_log_file_location()
|
||||
logger.info("Cognify process starting.")
|
||||
if graph_model_file and graph_model_name:
|
||||
graph_model = load_class(graph_model_file, graph_model_name)
|
||||
else:
|
||||
graph_model = KnowledgeGraph
|
||||
|
||||
if name == "cognify":
|
||||
asyncio.create_task(
|
||||
cognify(
|
||||
text=arguments["text"],
|
||||
graph_model_file=arguments.get("graph_model_file"),
|
||||
graph_model_name=arguments.get("graph_model_name"),
|
||||
)
|
||||
)
|
||||
await cognee.add(text)
|
||||
|
||||
text = (
|
||||
f"Background process launched due to MCP timeout limitations.\n"
|
||||
f"Average completion time is around 4 minutes.\n"
|
||||
f"For current cognify status you can check the log file at: {log_file}"
|
||||
)
|
||||
try:
|
||||
await cognee.cognify(graph_model=graph_model)
|
||||
logger.info("Cognify process finished.")
|
||||
except Exception as e:
|
||||
logger.error("Cognify process failed.")
|
||||
raise ValueError(f"Failed to cognify: {str(e)}")
|
||||
|
||||
return [
|
||||
types.TextContent(
|
||||
type="text",
|
||||
text=text,
|
||||
)
|
||||
]
|
||||
if name == "codify":
|
||||
asyncio.create_task(codify(arguments.get("repo_path")))
|
||||
|
||||
text = (
|
||||
f"Background process launched due to MCP timeout limitations.\n"
|
||||
f"Average completion time is around 4 minutes.\n"
|
||||
f"For current codify status you can check the log file at: {log_file}"
|
||||
)
|
||||
|
||||
return [
|
||||
types.TextContent(
|
||||
type="text",
|
||||
text=text,
|
||||
)
|
||||
]
|
||||
elif name == "search":
|
||||
search_results = await search(arguments["search_query"], arguments["search_type"])
|
||||
|
||||
return [types.TextContent(type="text", text=search_results)]
|
||||
elif name == "prune":
|
||||
await prune()
|
||||
|
||||
return [types.TextContent(type="text", text="Pruned")]
|
||||
except Exception as e:
|
||||
logger.error(f"Error calling tool '{name}': {str(e)}")
|
||||
return [types.TextContent(type="text", text=f"Error calling tool '{name}': {str(e)}")]
|
||||
|
||||
|
||||
async def cognify(text: str, graph_model_file: str = None, graph_model_name: str = None) -> str:
|
||||
"""Build knowledge graph from the input text"""
|
||||
# NOTE: MCP uses stdout to communicate, we must redirect all output
|
||||
# going to stdout ( like the print function ) to stderr.
|
||||
# As cognify is an async background job the output had to be redirected again.
|
||||
with redirect_stdout(sys.stderr):
|
||||
logger.info("Cognify process starting.")
|
||||
if graph_model_file and graph_model_name:
|
||||
graph_model = load_class(graph_model_file, graph_model_name)
|
||||
else:
|
||||
graph_model = KnowledgeGraph
|
||||
|
||||
await cognee.add(text)
|
||||
|
||||
try:
|
||||
await cognee.cognify(graph_model=graph_model)
|
||||
logger.info("Cognify process finished.")
|
||||
except Exception as e:
|
||||
logger.error("Cognify process failed.")
|
||||
raise ValueError(f"Failed to cognify: {str(e)}")
|
||||
|
||||
|
||||
async def codify(repo_path: str):
|
||||
# NOTE: MCP uses stdout to communicate, we must redirect all output
|
||||
# going to stdout ( like the print function ) to stderr.
|
||||
# As codify is an async background job the output had to be redirected again.
|
||||
with redirect_stdout(sys.stderr):
|
||||
logger.info("Codify process starting.")
|
||||
results = []
|
||||
async for result in run_code_graph_pipeline(repo_path, False):
|
||||
results.append(result)
|
||||
logger.info(result)
|
||||
if all(results):
|
||||
logger.info("Codify process finished succesfully.")
|
||||
else:
|
||||
logger.info("Codify process failed.")
|
||||
|
||||
|
||||
async def search(search_query: str, search_type: str) -> str:
|
||||
"""Search the knowledge graph"""
|
||||
# NOTE: MCP uses stdout to communicate, we must redirect all output
|
||||
# going to stdout ( like the print function ) to stderr.
|
||||
with redirect_stdout(sys.stderr):
|
||||
search_results = await cognee.search(
|
||||
query_type=SearchType[search_type.upper()], query_text=search_query
|
||||
asyncio.create_task(
|
||||
cognify_task(
|
||||
text=text,
|
||||
graph_model_file=graph_model_file,
|
||||
graph_model_name=graph_model_name,
|
||||
)
|
||||
)
|
||||
|
||||
if search_type.upper() == "CODE":
|
||||
return json.dumps(search_results, cls=JSONEncoder)
|
||||
elif search_type.upper() == "GRAPH_COMPLETION" or search_type.upper() == "RAG_COMPLETION":
|
||||
return search_results[0]
|
||||
elif search_type.upper() == "CHUNKS":
|
||||
return str(search_results)
|
||||
elif search_type.upper() == "INSIGHTS":
|
||||
results = retrieved_edges_to_string(search_results)
|
||||
return results
|
||||
else:
|
||||
return str(search_results)
|
||||
text = (
|
||||
f"Background process launched due to MCP timeout limitations.\n"
|
||||
f"To check current cognify status use the cognify_status tool\n"
|
||||
f"or check the log file at: {log_file}"
|
||||
)
|
||||
|
||||
return [
|
||||
types.TextContent(
|
||||
type="text",
|
||||
text=text,
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
async def prune():
|
||||
"""Reset the knowledge graph"""
|
||||
await cognee.prune.prune_data()
|
||||
await cognee.prune.prune_system(metadata=True)
|
||||
@mcp.tool()
|
||||
async def codify(repo_path: str) -> list:
|
||||
async def codify_task(repo_path: str):
|
||||
# NOTE: MCP uses stdout to communicate, we must redirect all output
|
||||
# going to stdout ( like the print function ) to stderr.
|
||||
with redirect_stdout(sys.stderr):
|
||||
logger.info("Codify process starting.")
|
||||
results = []
|
||||
async for result in run_code_graph_pipeline(repo_path, False):
|
||||
results.append(result)
|
||||
logger.info(result)
|
||||
if all(results):
|
||||
logger.info("Codify process finished succesfully.")
|
||||
else:
|
||||
logger.info("Codify process failed.")
|
||||
|
||||
asyncio.create_task(codify_task(repo_path))
|
||||
|
||||
text = (
|
||||
f"Background process launched due to MCP timeout limitations.\n"
|
||||
f"To check current codify status use the codify_status tool\n"
|
||||
f"or you can check the log file at: {log_file}"
|
||||
)
|
||||
|
||||
return [
|
||||
types.TextContent(
|
||||
type="text",
|
||||
text=text,
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
async def main():
|
||||
try:
|
||||
from mcp.server.stdio import stdio_server
|
||||
|
||||
logger.info("Cognee MCP server started...")
|
||||
|
||||
async with stdio_server() as (read_stream, write_stream):
|
||||
await mcp.run(
|
||||
read_stream=read_stream,
|
||||
write_stream=write_stream,
|
||||
initialization_options=InitializationOptions(
|
||||
server_name="cognee",
|
||||
server_version="0.1.0",
|
||||
capabilities=mcp.get_capabilities(
|
||||
notification_options=NotificationOptions(),
|
||||
experimental_capabilities={},
|
||||
),
|
||||
),
|
||||
raise_exceptions=True,
|
||||
@mcp.tool()
|
||||
async def search(search_query: str, search_type: str) -> list:
|
||||
async def search_task(search_query: str, search_type: str) -> str:
|
||||
"""Search the knowledge graph"""
|
||||
# NOTE: MCP uses stdout to communicate, we must redirect all output
|
||||
# going to stdout ( like the print function ) to stderr.
|
||||
with redirect_stdout(sys.stderr):
|
||||
search_results = await cognee.search(
|
||||
query_type=SearchType[search_type.upper()], query_text=search_query
|
||||
)
|
||||
|
||||
logger.info("Cognee MCP server closed.")
|
||||
if search_type.upper() == "CODE":
|
||||
return json.dumps(search_results, cls=JSONEncoder)
|
||||
elif (
|
||||
search_type.upper() == "GRAPH_COMPLETION" or search_type.upper() == "RAG_COMPLETION"
|
||||
):
|
||||
return search_results[0]
|
||||
elif search_type.upper() == "CHUNKS":
|
||||
return str(search_results)
|
||||
elif search_type.upper() == "INSIGHTS":
|
||||
results = retrieved_edges_to_string(search_results)
|
||||
return results
|
||||
else:
|
||||
return str(search_results)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Server failed to start: {str(e)}", exc_info=True)
|
||||
raise
|
||||
search_results = await search_task(search_query, search_type)
|
||||
return [types.TextContent(type="text", text=search_results)]
|
||||
|
||||
|
||||
# async def visualize() -> Image:
|
||||
# """Visualize the knowledge graph"""
|
||||
# try:
|
||||
# image_path = await cognee.visualize_graph()
|
||||
@mcp.tool()
|
||||
async def prune():
|
||||
"""Reset the knowledge graph"""
|
||||
with redirect_stdout(sys.stderr):
|
||||
await cognee.prune.prune_data()
|
||||
await cognee.prune.prune_system(metadata=True)
|
||||
return [types.TextContent(type="text", text="Pruned")]
|
||||
|
||||
# img = PILImage.open(image_path)
|
||||
# return Image(data=img.tobytes(), format="png")
|
||||
# except (FileNotFoundError, IOError, ValueError) as e:
|
||||
# raise ValueError(f"Failed to create visualization: {str(e)}")
|
||||
|
||||
@mcp.tool()
|
||||
async def cognify_status():
|
||||
"""Get status of cognify pipeline"""
|
||||
with redirect_stdout(sys.stderr):
|
||||
user = await get_default_user()
|
||||
status = await get_pipeline_status(
|
||||
[await get_unique_dataset_id("main_dataset", user)], "cognify_pipeline"
|
||||
)
|
||||
return [types.TextContent(type="text", text=str(status))]
|
||||
|
||||
|
||||
@mcp.tool()
|
||||
async def codify_status():
|
||||
"""Get status of codify pipeline"""
|
||||
with redirect_stdout(sys.stderr):
|
||||
user = await get_default_user()
|
||||
status = await get_pipeline_status(
|
||||
[await get_unique_dataset_id("codebase", user)], "cognify_code_pipeline"
|
||||
)
|
||||
return [types.TextContent(type="text", text=str(status))]
|
||||
|
||||
|
||||
def node_to_string(node):
|
||||
|
|
@ -265,6 +176,7 @@ def retrieved_edges_to_string(search_results):
|
|||
relationship_type = edge["relationship_name"]
|
||||
edge_str = f"{node_to_string(node1)} {relationship_type} {node_to_string(node2)}"
|
||||
edge_strings.append(edge_str)
|
||||
|
||||
return "\n".join(edge_strings)
|
||||
|
||||
|
||||
|
|
@ -279,32 +191,31 @@ def load_class(model_file, model_name):
|
|||
return model_class
|
||||
|
||||
|
||||
# def get_freshest_png(directory: str) -> Image:
|
||||
# if not os.path.exists(directory):
|
||||
# raise FileNotFoundError(f"Directory {directory} does not exist")
|
||||
async def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
# # List all files in 'directory' that end with .png
|
||||
# files = [f for f in os.listdir(directory) if f.endswith(".png")]
|
||||
# if not files:
|
||||
# raise FileNotFoundError("No PNG files found in the given directory.")
|
||||
parser.add_argument(
|
||||
"--transport",
|
||||
choices=["sse", "stdio"],
|
||||
default="stdio",
|
||||
help="Transport to use for communication with the client. (default: stdio)",
|
||||
)
|
||||
|
||||
# # Sort by integer value of the filename (minus the '.png')
|
||||
# # Example filename: 1673185134.png -> integer 1673185134
|
||||
# try:
|
||||
# files_sorted = sorted(files, key=lambda x: int(x.replace(".png", "")))
|
||||
# except ValueError as e:
|
||||
# raise ValueError("Invalid PNG filename format. Expected timestamp format.") from e
|
||||
args = parser.parse_args()
|
||||
|
||||
# # The "freshest" file has the largest timestamp
|
||||
# freshest_filename = files_sorted[-1]
|
||||
# freshest_path = os.path.join(directory, freshest_filename)
|
||||
logger.info(f"Starting MCP server with transport: {args.transport}")
|
||||
if args.transport == "stdio":
|
||||
await mcp.run_stdio_async()
|
||||
elif args.transport == "sse":
|
||||
logger.info(
|
||||
f"Running MCP server with SSE transport on {mcp.settings.host}:{mcp.settings.port}"
|
||||
)
|
||||
await mcp.run_sse_async()
|
||||
|
||||
# # Open the image with PIL and return the PIL Image object
|
||||
# try:
|
||||
# return PILImage.open(freshest_path)
|
||||
# except (IOError, OSError) as e:
|
||||
# raise IOError(f"Failed to open PNG file {freshest_path}") from e
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Initialize and run the server
|
||||
asyncio.run(main())
|
||||
try:
|
||||
asyncio.run(main())
|
||||
except Exception as e:
|
||||
logger.error(f"Error initializing Cognee MCP server: {str(e)}")
|
||||
raise
|
||||
|
|
|
|||
4493
cognee-mcp/uv.lock
generated
4493
cognee-mcp/uv.lock
generated
File diff suppressed because it is too large
Load diff
|
|
@ -1,8 +1,9 @@
|
|||
from typing import Union, BinaryIO, List, Optional
|
||||
from cognee.modules.users.models import User
|
||||
|
||||
from cognee.modules.pipelines import Task
|
||||
from cognee.tasks.ingestion import ingest_data, resolve_data_directories
|
||||
from cognee.modules.users.models import User
|
||||
from cognee.modules.pipelines import cognee_pipeline
|
||||
from cognee.tasks.ingestion import ingest_data, resolve_data_directories
|
||||
|
||||
|
||||
async def add(
|
||||
|
|
|
|||
|
|
@ -1,19 +1,19 @@
|
|||
import os
|
||||
import pathlib
|
||||
import asyncio
|
||||
from cognee.shared.logging_utils import get_logger
|
||||
from uuid import NAMESPACE_OID, uuid5
|
||||
from cognee.shared.logging_utils import get_logger
|
||||
from cognee.modules.observability.get_observe import get_observe
|
||||
|
||||
from cognee.api.v1.search import SearchType, search
|
||||
from cognee.api.v1.visualize.visualize import visualize_graph
|
||||
from cognee.base_config import get_base_config
|
||||
from cognee.modules.cognify.config import get_cognify_config
|
||||
from cognee.modules.pipelines import run_tasks
|
||||
from cognee.modules.pipelines.tasks.task import Task
|
||||
from cognee.modules.users.methods import get_default_user
|
||||
from cognee.shared.data_models import KnowledgeGraph, MonitoringTool
|
||||
from cognee.shared.utils import render_graph
|
||||
from cognee.shared.data_models import KnowledgeGraph
|
||||
from cognee.tasks.documents import classify_documents, extract_chunks_from_documents
|
||||
from cognee.modules.data.methods.get_unique_dataset_id import get_unique_dataset_id
|
||||
from cognee.tasks.graph import extract_graph_from_data
|
||||
from cognee.tasks.ingestion import ingest_data
|
||||
from cognee.tasks.repo_processor import get_non_py_files, get_repo_file_dependencies
|
||||
|
|
@ -22,11 +22,7 @@ from cognee.tasks.storage import add_data_points
|
|||
from cognee.tasks.summarization import summarize_text
|
||||
from cognee.infrastructure.llm import get_max_chunk_tokens
|
||||
|
||||
monitoring = get_base_config().monitoring_tool
|
||||
|
||||
if monitoring == MonitoringTool.LANGFUSE:
|
||||
from langfuse.decorators import observe
|
||||
|
||||
observe = get_observe()
|
||||
|
||||
logger = get_logger("code_graph_pipeline")
|
||||
|
||||
|
|
@ -69,7 +65,7 @@ async def run_code_graph_pipeline(repo_path, include_docs=False):
|
|||
),
|
||||
]
|
||||
|
||||
dataset_id = uuid5(NAMESPACE_OID, "codebase")
|
||||
dataset_id = await get_unique_dataset_id("codebase", user)
|
||||
|
||||
if include_docs:
|
||||
non_code_pipeline_run = run_tasks(
|
||||
|
|
|
|||
|
|
@ -34,7 +34,9 @@ async def cognify(
|
|||
):
|
||||
tasks = await get_default_tasks(user, graph_model, chunker, chunk_size, ontology_file_path)
|
||||
|
||||
return await cognee_pipeline(tasks=tasks, datasets=datasets, user=user)
|
||||
return await cognee_pipeline(
|
||||
tasks=tasks, datasets=datasets, user=user, pipeline_name="cognify_pipeline"
|
||||
)
|
||||
|
||||
|
||||
async def get_default_tasks( # TODO: Find out a better way to do this (Boris's comment)
|
||||
|
|
|
|||
|
|
@ -28,7 +28,7 @@ class datasets:
|
|||
|
||||
@staticmethod
|
||||
async def get_status(dataset_ids: list[UUID]) -> dict:
|
||||
return await get_pipeline_status(dataset_ids)
|
||||
return await get_pipeline_status(dataset_ids, pipeline_name="cognify_pipeline")
|
||||
|
||||
@staticmethod
|
||||
async def delete_dataset(dataset_id: str):
|
||||
|
|
|
|||
|
|
@ -1,14 +1,14 @@
|
|||
import os
|
||||
from typing import Optional
|
||||
from functools import lru_cache
|
||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
from cognee.root_dir import get_absolute_path
|
||||
from cognee.shared.data_models import MonitoringTool
|
||||
from cognee.modules.observability.observers import Observer
|
||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
|
||||
|
||||
class BaseConfig(BaseSettings):
|
||||
data_root_directory: str = get_absolute_path(".data_storage")
|
||||
monitoring_tool: object = MonitoringTool.LANGFUSE
|
||||
monitoring_tool: object = Observer.LANGFUSE
|
||||
graphistry_username: Optional[str] = os.getenv("GRAPHISTRY_USERNAME")
|
||||
graphistry_password: Optional[str] = os.getenv("GRAPHISTRY_PASSWORD")
|
||||
langfuse_public_key: Optional[str] = os.getenv("LANGFUSE_PUBLIC_KEY")
|
||||
|
|
|
|||
|
|
@ -12,13 +12,22 @@ class CogneeApiError(Exception):
|
|||
message: str = "Service is unavailable.",
|
||||
name: str = "Cognee",
|
||||
status_code=status.HTTP_418_IM_A_TEAPOT,
|
||||
log=True,
|
||||
log_level="ERROR",
|
||||
):
|
||||
self.message = message
|
||||
self.name = name
|
||||
self.status_code = status_code
|
||||
|
||||
# Automatically log the exception details
|
||||
logger.error(f"{self.name}: {self.message} (Status code: {self.status_code})")
|
||||
if log and (log_level == "ERROR"):
|
||||
logger.error(f"{self.name}: {self.message} (Status code: {self.status_code})")
|
||||
elif log and (log_level == "WARNING"):
|
||||
logger.warning(f"{self.name}: {self.message} (Status code: {self.status_code})")
|
||||
elif log and (log_level == "INFO"):
|
||||
logger.info(f"{self.name}: {self.message} (Status code: {self.status_code})")
|
||||
elif log and (log_level == "DEBUG"):
|
||||
logger.debug(f"{self.name}: {self.message} (Status code: {self.status_code})")
|
||||
|
||||
super().__init__(self.message, self.name)
|
||||
|
||||
|
|
|
|||
|
|
@ -67,6 +67,18 @@ def create_graph_engine(
|
|||
|
||||
return KuzuAdapter(db_path=graph_file_path)
|
||||
|
||||
elif graph_database_provider == "memgraph":
|
||||
if not (graph_database_url and graph_database_username and graph_database_password):
|
||||
raise EnvironmentError("Missing required Memgraph credentials.")
|
||||
|
||||
from .memgraph.memgraph_adapter import MemgraphAdapter
|
||||
|
||||
return MemgraphAdapter(
|
||||
graph_database_url=graph_database_url,
|
||||
graph_database_username=graph_database_username,
|
||||
graph_database_password=graph_database_password,
|
||||
)
|
||||
|
||||
from .networkx.adapter import NetworkXAdapter
|
||||
|
||||
graph_client = NetworkXAdapter(filename=graph_file_path)
|
||||
|
|
|
|||
|
|
@ -58,7 +58,7 @@ def record_graph_changes(func):
|
|||
session.add(relationship)
|
||||
await session.flush()
|
||||
except Exception as e:
|
||||
logger.error(f"Error adding relationship: {e}")
|
||||
logger.debug(f"Error adding relationship: {e}")
|
||||
await session.rollback()
|
||||
continue
|
||||
|
||||
|
|
@ -78,14 +78,14 @@ def record_graph_changes(func):
|
|||
session.add(relationship)
|
||||
await session.flush()
|
||||
except Exception as e:
|
||||
logger.error(f"Error adding relationship: {e}")
|
||||
logger.debug(f"Error adding relationship: {e}")
|
||||
await session.rollback()
|
||||
continue
|
||||
|
||||
try:
|
||||
await session.commit()
|
||||
except Exception as e:
|
||||
logger.error(f"Error committing session: {e}")
|
||||
logger.debug(f"Error committing session: {e}")
|
||||
|
||||
return result
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,690 @@
|
|||
"""Memgraph Adapter for Graph Database"""
|
||||
|
||||
import json
|
||||
from cognee.shared.logging_utils import get_logger, ERROR
|
||||
import asyncio
|
||||
from textwrap import dedent
|
||||
from typing import Optional, Any, List, Dict
|
||||
from contextlib import asynccontextmanager
|
||||
from uuid import UUID
|
||||
from neo4j import AsyncSession
|
||||
from neo4j import AsyncGraphDatabase
|
||||
from neo4j.exceptions import Neo4jError
|
||||
from cognee.infrastructure.engine import DataPoint
|
||||
from cognee.infrastructure.databases.graph.graph_db_interface import GraphDBInterface
|
||||
from cognee.modules.storage.utils import JSONEncoder
|
||||
|
||||
logger = get_logger("MemgraphAdapter", level=ERROR)
|
||||
|
||||
|
||||
class MemgraphAdapter(GraphDBInterface):
|
||||
def __init__(
|
||||
self,
|
||||
graph_database_url: str,
|
||||
graph_database_username: str,
|
||||
graph_database_password: str,
|
||||
driver: Optional[Any] = None,
|
||||
):
|
||||
self.driver = driver or AsyncGraphDatabase.driver(
|
||||
graph_database_url,
|
||||
auth=(graph_database_username, graph_database_password),
|
||||
max_connection_lifetime=120,
|
||||
)
|
||||
|
||||
@asynccontextmanager
|
||||
async def get_session(self) -> AsyncSession:
|
||||
async with self.driver.session() as session:
|
||||
yield session
|
||||
|
||||
async def query(
|
||||
self,
|
||||
query: str,
|
||||
params: Optional[Dict[str, Any]] = None,
|
||||
) -> List[Dict[str, Any]]:
|
||||
try:
|
||||
async with self.get_session() as session:
|
||||
result = await session.run(query, params)
|
||||
data = await result.data()
|
||||
return data
|
||||
except Neo4jError as error:
|
||||
logger.error("Memgraph query error: %s", error, exc_info=True)
|
||||
raise error
|
||||
|
||||
async def has_node(self, node_id: str) -> bool:
|
||||
results = await self.query(
|
||||
"""
|
||||
MATCH (n)
|
||||
WHERE n.id = $node_id
|
||||
RETURN COUNT(n) > 0 AS node_exists
|
||||
""",
|
||||
{"node_id": node_id},
|
||||
)
|
||||
return results[0]["node_exists"] if len(results) > 0 else False
|
||||
|
||||
async def add_node(self, node: DataPoint):
|
||||
serialized_properties = self.serialize_properties(node.model_dump())
|
||||
|
||||
query = """
|
||||
MERGE (node {id: $node_id})
|
||||
ON CREATE SET node:$node_label, node += $properties, node.updated_at = timestamp()
|
||||
ON MATCH SET node:$node_label, node += $properties, node.updated_at = timestamp()
|
||||
RETURN ID(node) AS internal_id,node.id AS nodeId
|
||||
"""
|
||||
|
||||
params = {
|
||||
"node_id": str(node.id),
|
||||
"node_label": type(node).__name__,
|
||||
"properties": serialized_properties,
|
||||
}
|
||||
return await self.query(query, params)
|
||||
|
||||
async def add_nodes(self, nodes: list[DataPoint]) -> None:
|
||||
query = """
|
||||
UNWIND $nodes AS node
|
||||
MERGE (n {id: node.node_id})
|
||||
ON CREATE SET n:node.label, n += node.properties, n.updated_at = timestamp()
|
||||
ON MATCH SET n:node.label, n += node.properties, n.updated_at = timestamp()
|
||||
RETURN ID(n) AS internal_id, n.id AS nodeId
|
||||
"""
|
||||
|
||||
nodes = [
|
||||
{
|
||||
"node_id": str(node.id),
|
||||
"label": type(node).__name__,
|
||||
"properties": self.serialize_properties(node.model_dump()),
|
||||
}
|
||||
for node in nodes
|
||||
]
|
||||
|
||||
results = await self.query(query, dict(nodes=nodes))
|
||||
return results
|
||||
|
||||
async def extract_node(self, node_id: str):
|
||||
results = await self.extract_nodes([node_id])
|
||||
|
||||
return results[0] if len(results) > 0 else None
|
||||
|
||||
async def extract_nodes(self, node_ids: List[str]):
|
||||
query = """
|
||||
UNWIND $node_ids AS id
|
||||
MATCH (node {id: id})
|
||||
RETURN node"""
|
||||
|
||||
params = {"node_ids": node_ids}
|
||||
|
||||
results = await self.query(query, params)
|
||||
|
||||
return [result["node"] for result in results]
|
||||
|
||||
async def delete_node(self, node_id: str):
|
||||
sanitized_id = node_id.replace(":", "_")
|
||||
|
||||
query = "MATCH (node: {{id: $node_id}}) DETACH DELETE node"
|
||||
params = {"node_id": sanitized_id}
|
||||
|
||||
return await self.query(query, params)
|
||||
|
||||
async def delete_nodes(self, node_ids: list[str]) -> None:
|
||||
query = """
|
||||
UNWIND $node_ids AS id
|
||||
MATCH (node {id: id})
|
||||
DETACH DELETE node"""
|
||||
|
||||
params = {"node_ids": node_ids}
|
||||
|
||||
return await self.query(query, params)
|
||||
|
||||
async def has_edge(self, from_node: UUID, to_node: UUID, edge_label: str) -> bool:
|
||||
query = """
|
||||
MATCH (from_node)-[relationship]->(to_node)
|
||||
WHERE from_node.id = $from_node_id AND to_node.id = $to_node_id AND type(relationship) = $edge_label
|
||||
RETURN COUNT(relationship) > 0 AS edge_exists
|
||||
"""
|
||||
|
||||
params = {
|
||||
"from_node_id": str(from_node),
|
||||
"to_node_id": str(to_node),
|
||||
"edge_label": edge_label,
|
||||
}
|
||||
|
||||
records = await self.query(query, params)
|
||||
return records[0]["edge_exists"] if records else False
|
||||
|
||||
async def has_edges(self, edges):
|
||||
query = """
|
||||
UNWIND $edges AS edge
|
||||
MATCH (a)-[r]->(b)
|
||||
WHERE id(a) = edge.from_node AND id(b) = edge.to_node AND type(r) = edge.relationship_name
|
||||
RETURN edge.from_node AS from_node, edge.to_node AS to_node, edge.relationship_name AS relationship_name, count(r) > 0 AS edge_exists
|
||||
"""
|
||||
|
||||
try:
|
||||
params = {
|
||||
"edges": [
|
||||
{
|
||||
"from_node": str(edge[0]),
|
||||
"to_node": str(edge[1]),
|
||||
"relationship_name": edge[2],
|
||||
}
|
||||
for edge in edges
|
||||
],
|
||||
}
|
||||
|
||||
results = await self.query(query, params)
|
||||
return [result["edge_exists"] for result in results]
|
||||
except Neo4jError as error:
|
||||
logger.error("Memgraph query error: %s", error, exc_info=True)
|
||||
raise error
|
||||
|
||||
async def add_edge(
|
||||
self,
|
||||
from_node: UUID,
|
||||
to_node: UUID,
|
||||
relationship_name: str,
|
||||
edge_properties: Optional[Dict[str, Any]] = None,
|
||||
):
|
||||
serialized_properties = self.serialize_properties(edge_properties or {})
|
||||
|
||||
query = dedent(
|
||||
f"""\
|
||||
MATCH (from_node {{id: $from_node}}),
|
||||
(to_node {{id: $to_node}})
|
||||
MERGE (from_node)-[r:{relationship_name}]->(to_node)
|
||||
ON CREATE SET r += $properties, r.updated_at = timestamp()
|
||||
ON MATCH SET r += $properties, r.updated_at = timestamp()
|
||||
RETURN r
|
||||
"""
|
||||
)
|
||||
|
||||
params = {
|
||||
"from_node": str(from_node),
|
||||
"to_node": str(to_node),
|
||||
"relationship_name": relationship_name,
|
||||
"properties": serialized_properties,
|
||||
}
|
||||
|
||||
return await self.query(query, params)
|
||||
|
||||
async def add_edges(self, edges: list[tuple[str, str, str, dict[str, Any]]]) -> None:
|
||||
query = """
|
||||
UNWIND $edges AS edge
|
||||
MATCH (from_node {id: edge.from_node})
|
||||
MATCH (to_node {id: edge.to_node})
|
||||
CALL merge.relationship(
|
||||
from_node,
|
||||
edge.relationship_name,
|
||||
{
|
||||
source_node_id: edge.from_node,
|
||||
target_node_id: edge.to_node
|
||||
},
|
||||
edge.properties,
|
||||
to_node,
|
||||
{}
|
||||
) YIELD rel
|
||||
RETURN rel"""
|
||||
|
||||
edges = [
|
||||
{
|
||||
"from_node": str(edge[0]),
|
||||
"to_node": str(edge[1]),
|
||||
"relationship_name": edge[2],
|
||||
"properties": {
|
||||
**(edge[3] if edge[3] else {}),
|
||||
"source_node_id": str(edge[0]),
|
||||
"target_node_id": str(edge[1]),
|
||||
},
|
||||
}
|
||||
for edge in edges
|
||||
]
|
||||
|
||||
try:
|
||||
results = await self.query(query, dict(edges=edges))
|
||||
return results
|
||||
except Neo4jError as error:
|
||||
logger.error("Memgraph query error: %s", error, exc_info=True)
|
||||
raise error
|
||||
|
||||
async def get_edges(self, node_id: str):
|
||||
query = """
|
||||
MATCH (n {id: $node_id})-[r]-(m)
|
||||
RETURN n, r, m
|
||||
"""
|
||||
|
||||
results = await self.query(query, dict(node_id=node_id))
|
||||
|
||||
return [
|
||||
(result["n"]["id"], result["m"]["id"], {"relationship_name": result["r"][1]})
|
||||
for result in results
|
||||
]
|
||||
|
||||
async def get_disconnected_nodes(self) -> list[str]:
|
||||
query = """
|
||||
// Step 1: Collect all nodes
|
||||
MATCH (n)
|
||||
WITH COLLECT(n) AS nodes
|
||||
|
||||
// Step 2: Find all connected components
|
||||
WITH nodes
|
||||
CALL {
|
||||
WITH nodes
|
||||
UNWIND nodes AS startNode
|
||||
MATCH path = (startNode)-[*]-(connectedNode)
|
||||
WITH startNode, COLLECT(DISTINCT connectedNode) AS component
|
||||
RETURN component
|
||||
}
|
||||
|
||||
// Step 3: Aggregate components
|
||||
WITH COLLECT(component) AS components
|
||||
|
||||
// Step 4: Identify the largest connected component
|
||||
UNWIND components AS component
|
||||
WITH component
|
||||
ORDER BY SIZE(component) DESC
|
||||
LIMIT 1
|
||||
WITH component AS largestComponent
|
||||
|
||||
// Step 5: Find nodes not in the largest connected component
|
||||
MATCH (n)
|
||||
WHERE NOT n IN largestComponent
|
||||
RETURN COLLECT(ID(n)) AS ids
|
||||
"""
|
||||
|
||||
results = await self.query(query)
|
||||
return results[0]["ids"] if len(results) > 0 else []
|
||||
|
||||
async def get_predecessors(self, node_id: str, edge_label: str = None) -> list[str]:
|
||||
if edge_label is not None:
|
||||
query = """
|
||||
MATCH (node)<-[r]-(predecessor)
|
||||
WHERE node.id = $node_id AND type(r) = $edge_label
|
||||
RETURN predecessor
|
||||
"""
|
||||
|
||||
results = await self.query(
|
||||
query,
|
||||
dict(
|
||||
node_id=node_id,
|
||||
edge_label=edge_label,
|
||||
),
|
||||
)
|
||||
|
||||
return [result["predecessor"] for result in results]
|
||||
else:
|
||||
query = """
|
||||
MATCH (node)<-[r]-(predecessor)
|
||||
WHERE node.id = $node_id
|
||||
RETURN predecessor
|
||||
"""
|
||||
|
||||
results = await self.query(
|
||||
query,
|
||||
dict(
|
||||
node_id=node_id,
|
||||
),
|
||||
)
|
||||
|
||||
return [result["predecessor"] for result in results]
|
||||
|
||||
async def get_successors(self, node_id: str, edge_label: str = None) -> list[str]:
    """Return all nodes reached by an outgoing edge from the node whose
    ``id`` property equals ``node_id``.

    Parameters:
        node_id: Value matched against the source node's ``id`` property.
        edge_label: Optional relationship type; when given, only successors
            connected through an edge of exactly this type are returned.

    Returns:
        A list of successor node records, one per matching edge.
    """
    # Deduplicated: the two original branches repeated the full
    # query/execute/collect sequence and differed only in the type(r) filter.
    if edge_label is not None:
        query = """
        MATCH (node)-[r]->(successor)
        WHERE node.id = $node_id AND type(r) = $edge_label
        RETURN successor
        """
        params = dict(node_id=node_id, edge_label=edge_label)
    else:
        query = """
        MATCH (node)-[r]->(successor)
        WHERE node.id = $node_id
        RETURN successor
        """
        params = dict(node_id=node_id)

    results = await self.query(query, params)

    return [result["successor"] for result in results]
|
||||
async def get_neighbours(self, node_id: str) -> List[Dict[str, Any]]:
    """Return every node directly connected to ``node_id``, regardless of
    edge direction (predecessors first, then successors)."""
    # Fetch both directions concurrently; each helper returns a list.
    incoming, outgoing = await asyncio.gather(
        self.get_predecessors(node_id),
        self.get_successors(node_id),
    )

    return [*incoming, *outgoing]
|
||||
async def get_connections(self, node_id: UUID) -> list:
    """Return (source, {"relationship_name": type}, target) triples for every
    edge touching the node whose ``id`` property equals ``node_id``."""
    predecessors_query = """
    MATCH (node)<-[relation]-(neighbour)
    WHERE node.id = $node_id
    RETURN neighbour, relation, node
    """
    successors_query = """
    MATCH (node)-[relation]->(neighbour)
    WHERE node.id = $node_id
    RETURN node, relation, neighbour
    """

    predecessors, successors = await asyncio.gather(
        self.query(predecessors_query, dict(node_id=str(node_id))),
        self.query(successors_query, dict(node_id=str(node_id))),
    )

    connections = []

    # Each row's "relation" entry is an indexable (start, type, end) triple;
    # reshape the middle element into a payload dict. One loop handles both
    # directions instead of two copies with a rebinding loop variable.
    for row in (*predecessors, *successors):
        relation = row["relation"]
        connections.append((relation[0], {"relationship_name": relation[1]}, relation[2]))

    return connections
|
||||
async def remove_connection_to_predecessors_of(
    self, node_ids: list[str], edge_label: str
) -> None:
    """Delete every outgoing edge of type ``edge_label`` from each node in
    ``node_ids`` (matched by the node's ``id`` property).

    Parameters:
        node_ids: ``id`` property values of the nodes whose edges to drop.
        edge_label: Relationship type to delete.
    """
    # BUG FIX: this was an f-string, so Cypher's inline property map
    # `{id: nid}` was parsed as a Python replacement field
    # (format(id, ' nid')) and raised at runtime. A plain string keeps the
    # map intact; edge_label was already passed as a query parameter.
    query = """
    UNWIND $node_ids AS nid
    MATCH (node {id: nid})-[r]->(predecessor)
    WHERE type(r) = $edge_label
    DELETE r;
    """

    params = {"node_ids": node_ids, "edge_label": edge_label}

    return await self.query(query, params)
|
||||
async def remove_connection_to_successors_of(
    self, node_ids: list[str], edge_label: str
) -> None:
    """Delete every incoming edge of type ``edge_label`` for each node in
    ``node_ids`` (matched by the node's ``id`` property).

    Parameters:
        node_ids: ``id`` property values of the nodes whose edges to drop.
        edge_label: Relationship type to delete.
    """
    # BUG FIX: the original f-string interpolated the Python builtin `id`
    # function into the node label (`(node:`{id}`)`), producing a garbage
    # label, and spliced edge_label directly into the query text. Match on
    # the `id` property and filter the relationship type via a parameter,
    # mirroring remove_connection_to_predecessors_of.
    query = """
    UNWIND $node_ids AS nid
    MATCH (node {id: nid})<-[r]-(successor)
    WHERE type(r) = $edge_label
    DELETE r;
    """

    params = {"node_ids": node_ids, "edge_label": edge_label}

    return await self.query(query, params)
|
||||
async def delete_graph(self):
    """Delete every node in the database; DETACH also removes all edges."""
    query = """MATCH (node)
    DETACH DELETE node;"""

    return await self.query(query)
|
||||
def serialize_properties(self, properties=None):
    """Convert property values into storage-safe primitives.

    UUIDs become their canonical string form and nested dicts are JSON
    encoded; every other value passes through unchanged.

    Parameters:
        properties: Mapping of property names to raw values. Defaults to an
            empty mapping. (The old signature used ``properties=dict()`` —
            a shared mutable default evaluated once at definition time.)

    Returns:
        A new dict with the same keys and serialized values.
    """
    if properties is None:
        properties = {}

    def _serialize(value):
        # UUID objects are not JSON/graph friendly — store the string form.
        if isinstance(value, UUID):
            return str(value)
        # Nested mappings are stored as JSON text; JSONEncoder handles
        # project-specific value types.
        if isinstance(value, dict):
            return json.dumps(value, cls=JSONEncoder)
        return value

    return {key: _serialize(value) for key, value in properties.items()}
|
||||
async def get_model_independent_graph_data(self):
    """Fetch the raw graph as two query results: all nodes collected into one
    row, and all (n, r, m) edge triples collected into one row."""
    node_records = await self.query("MATCH (n) RETURN collect(n) AS nodes")

    edge_records = await self.query("MATCH (n)-[r]->(m) RETURN collect([n, r, m]) AS elements")

    return (node_records, edge_records)
|
||||
async def get_graph_data(self):
    """Return the whole graph as ``(nodes, edges)``.

    Nodes are ``(internal_id, properties)`` pairs; edges are
    ``(source_node_id, target_node_id, type, properties)`` tuples taken from
    the edge's own properties.
    """
    node_query = "MATCH (n) RETURN ID(n) AS id, labels(n) AS labels, properties(n) AS properties"

    node_records = await self.query(node_query)

    nodes = [(record["id"], record["properties"]) for record in node_records]

    edge_query = """
    MATCH (n)-[r]->(m)
    RETURN ID(n) AS source, ID(m) AS target, TYPE(r) AS type, properties(r) AS properties
    """
    edge_records = await self.query(edge_query)

    # Endpoints come from the stored source/target properties, not from the
    # internal IDs returned by the query.
    edges = [
        (
            record["properties"]["source_node_id"],
            record["properties"]["target_node_id"],
            record["type"],
            record["properties"],
        )
        for record in edge_records
    ]

    return (nodes, edges)
|
||||
async def get_filtered_graph_data(self, attribute_filters):
    """
    Fetches nodes and relationships filtered by specified attribute values.

    Args:
        attribute_filters (list of dict): A list of dictionaries where keys are attributes and values are lists of values to filter on.
            Example: [{"community": ["1", "2"]}]
            NOTE(review): only the FIRST dict in the list is used below — confirm
            whether multiple filter dicts were ever intended to be supported.

    Returns:
        tuple: A tuple containing two lists: nodes and edges.
    """
    # Build "n.<attr> IN [...]" clauses by string interpolation.
    # NOTE(review): values are spliced into the query text, not passed as
    # parameters — values containing quotes will break the query and this is
    # injection-prone if filter values are untrusted.
    where_clauses = []
    for attribute, values in attribute_filters[0].items():
        values_str = ", ".join(
            f"'{value}'" if isinstance(value, str) else str(value) for value in values
        )
        where_clauses.append(f"n.{attribute} IN [{values_str}]")

    where_clause = " AND ".join(where_clauses)

    query_nodes = f"""
    MATCH (n)
    WHERE {where_clause}
    RETURN ID(n) AS id, labels(n) AS labels, properties(n) AS properties
    """
    result_nodes = await self.query(query_nodes)

    # Nodes are (internal_id, properties) pairs.
    nodes = [
        (
            record["id"],
            record["properties"],
        )
        for record in result_nodes
    ]

    # Edges are kept only when BOTH endpoints satisfy the filter.
    # NOTE(review): the textual replace("n.", "m.") rewrites any occurrence of
    # "n." in the clause, including inside literal values — fragile; confirm
    # filter values never contain "n.".
    query_edges = f"""
    MATCH (n)-[r]->(m)
    WHERE {where_clause} AND {where_clause.replace("n.", "m.")}
    RETURN ID(n) AS source, ID(m) AS target, TYPE(r) AS type, properties(r) AS properties
    """
    result_edges = await self.query(query_edges)

    edges = [
        (
            record["source"],
            record["target"],
            record["type"],
            record["properties"],
        )
        for record in result_edges
    ]

    return (nodes, edges)
|
||||
async def get_node_labels_string(self):
    """Return every distinct node label formatted as a Cypher list literal,
    e.g. ``['Person', 'Company']``.

    Raises:
        ValueError: When the database contains no node labels.
    """
    node_labels_query = """
    MATCH (n)
    WITH DISTINCT labels(n) AS labelList
    UNWIND labelList AS label
    RETURN collect(DISTINCT label) AS labels;
    """
    query_result = await self.query(node_labels_query)
    labels = query_result[0]["labels"] if query_result else []

    if not labels:
        raise ValueError("No node labels found in the database")

    quoted_labels = (f"'{label}'" for label in labels)
    return "[" + ", ".join(quoted_labels) + "]"
|
||||
async def get_relationship_labels_string(self):
    """Return every distinct relationship type formatted as a projection map
    marking each as undirected, e.g. ``{KNOWS: {orientation: 'UNDIRECTED'}}``.

    Raises:
        ValueError: When the database contains no relationship types.
    """
    relationship_types_query = (
        "MATCH ()-[r]->() RETURN collect(DISTINCT type(r)) AS relationships;"
    )
    query_result = await self.query(relationship_types_query)
    rel_types = query_result[0]["relationships"] if query_result else []

    if not rel_types:
        raise ValueError("No relationship types found in the database.")

    entries = ", ".join(f"{rel}" + ": {orientation: 'UNDIRECTED'}" for rel in rel_types)
    return "{" + entries + "}"
|
||||
async def get_graph_metrics(self, include_optional=False):
    """For the definition of these metrics, please refer to
    https://docs.cognee.ai/core_concepts/graph_generation/descriptive_metrics

    Args:
        include_optional: When True, additionally computes self-loops,
            diameter, average shortest path length, and average clustering
            coefficient (these queries are much more expensive).

    Returns:
        dict: Mandatory metrics plus optional metrics; optional metrics are
        -1 placeholders when not requested. On any failure a zero/-1-filled
        dict is returned instead of raising.
    """

    try:
        # Basic metrics
        node_count = await self.query("MATCH (n) RETURN count(n)")
        edge_count = await self.query("MATCH ()-[r]->() RETURN count(r)")
        num_nodes = node_count[0][0] if node_count else 0
        num_edges = edge_count[0][0] if edge_count else 0

        # Calculate mandatory metrics.
        # mean_degree treats edges as undirected (each edge contributes 2);
        # edge_density uses the directed-graph denominator n*(n-1).
        mandatory_metrics = {
            "num_nodes": num_nodes,
            "num_edges": num_edges,
            "mean_degree": (2 * num_edges) / num_nodes if num_nodes > 0 else 0,
            "edge_density": (num_edges) / (num_nodes * (num_nodes - 1)) if num_nodes > 1 else 0,
        }

        # Calculate connected components.
        # NOTE(review): this query collects node ids reachable via :EDGE paths;
        # confirm it actually yields one collection per component rather than a
        # single aggregate — the WITH/COLLECT structure looks like it may
        # collapse everything into one component.
        components_query = """
        MATCH (n:Node)
        WITH n.id AS node_id
        MATCH path = (n)-[:EDGE*0..]-()
        WITH COLLECT(DISTINCT node_id) AS component
        RETURN COLLECT(component) AS components
        """
        components_result = await self.query(components_query)
        component_sizes = (
            [len(comp) for comp in components_result[0][0]] if components_result else []
        )

        mandatory_metrics.update(
            {
                "num_connected_components": len(component_sizes),
                "sizes_of_connected_components": component_sizes,
            }
        )

        if include_optional:
            # Self-loops: edges whose source and target are the same node.
            self_loops_query = """
            MATCH (n:Node)-[r:EDGE]->(n)
            RETURN COUNT(r)
            """
            self_loops = await self.query(self_loops_query)
            num_selfloops = self_loops[0][0] if self_loops else 0

            # Shortest paths (simplified for Kuzu).
            # NOTE(review): MIN(...) aggregates over ALL pairs, so this appears
            # to return a single global minimum rather than one length per
            # pair — diameter/avg below would then be computed from one value.
            paths_query = """
            MATCH (n:Node), (m:Node)
            WHERE n.id < m.id
            MATCH path = (n)-[:EDGE*]-(m)
            RETURN MIN(LENGTH(path)) AS length
            """
            paths = await self.query(paths_query)
            path_lengths = [p[0] for p in paths if p[0] is not None]

            # Local clustering coefficient
            clustering_query = """
            /// Step 1: Get each node with its neighbors and degree
            MATCH (n:Node)-[:EDGE]-(neighbor)
            WITH n, COLLECT(DISTINCT neighbor) AS neighbors, COUNT(DISTINCT neighbor) AS degree

            // Step 2: Pair up neighbors and check if they are connected
            UNWIND neighbors AS n1
            UNWIND neighbors AS n2
            WITH n, degree, n1, n2
            WHERE id(n1) < id(n2) // avoid duplicate pairs

            // Step 3: Use OPTIONAL MATCH to see if n1 and n2 are connected
            OPTIONAL MATCH (n1)-[:EDGE]-(n2)
            WITH n, degree, COUNT(n2) AS triangle_count

            // Step 4: Compute local clustering coefficient
            WITH n, degree,
                 CASE WHEN degree <= 1 THEN 0.0
                      ELSE (1.0 * triangle_count) / (degree * (degree - 1) / 2.0)
                 END AS local_cc

            // Step 5: Compute average
            RETURN AVG(local_cc) AS avg_clustering_coefficient
            """
            clustering = await self.query(clustering_query)

            # -1 is the sentinel for "could not be computed".
            optional_metrics = {
                "num_selfloops": num_selfloops,
                "diameter": max(path_lengths) if path_lengths else -1,
                "avg_shortest_path_length": sum(path_lengths) / len(path_lengths)
                if path_lengths
                else -1,
                "avg_clustering": clustering[0][0] if clustering and clustering[0][0] else -1,
            }
        else:
            # Optional metrics not requested: return -1 placeholders.
            optional_metrics = {
                "num_selfloops": -1,
                "diameter": -1,
                "avg_shortest_path_length": -1,
                "avg_clustering": -1,
            }

        return {**mandatory_metrics, **optional_metrics}

    except Exception as e:
        # Best-effort: log and fall back to an empty metrics dict so callers
        # never have to handle a raise from metrics collection.
        logger.error(f"Failed to get graph metrics: {e}")
        return {
            "num_nodes": 0,
            "num_edges": 0,
            "mean_degree": 0,
            "edge_density": 0,
            "num_connected_components": 0,
            "sizes_of_connected_components": [],
            "num_selfloops": -1,
            "diameter": -1,
            "avg_shortest_path_length": -1,
            "avg_clustering": -1,
        }
|
|
@ -42,7 +42,7 @@ class NetworkXAdapter(GraphDBInterface):
|
|||
async def query(self, query: str, params: dict):
|
||||
pass
|
||||
|
||||
async def has_node(self, node_id: str) -> bool:
|
||||
async def has_node(self, node_id: UUID) -> bool:
|
||||
return self.graph.has_node(node_id)
|
||||
|
||||
async def add_node(self, node: DataPoint) -> None:
|
||||
|
|
@ -136,7 +136,7 @@ class NetworkXAdapter(GraphDBInterface):
|
|||
logger.error(f"Failed to add edges: {e}")
|
||||
raise
|
||||
|
||||
async def get_edges(self, node_id: str):
|
||||
async def get_edges(self, node_id: UUID):
|
||||
return list(self.graph.in_edges(node_id, data=True)) + list(
|
||||
self.graph.out_edges(node_id, data=True)
|
||||
)
|
||||
|
|
@ -174,13 +174,13 @@ class NetworkXAdapter(GraphDBInterface):
|
|||
|
||||
return disconnected_nodes
|
||||
|
||||
async def extract_node(self, node_id: str) -> dict:
|
||||
async def extract_node(self, node_id: UUID) -> dict:
|
||||
if self.graph.has_node(node_id):
|
||||
return self.graph.nodes[node_id]
|
||||
|
||||
return None
|
||||
|
||||
async def extract_nodes(self, node_ids: List[str]) -> List[dict]:
|
||||
async def extract_nodes(self, node_ids: List[UUID]) -> List[dict]:
|
||||
return [self.graph.nodes[node_id] for node_id in node_ids if self.graph.has_node(node_id)]
|
||||
|
||||
async def get_predecessors(self, node_id: UUID, edge_label: str = None) -> list:
|
||||
|
|
@ -215,7 +215,7 @@ class NetworkXAdapter(GraphDBInterface):
|
|||
|
||||
return nodes
|
||||
|
||||
async def get_neighbors(self, node_id: str) -> list:
|
||||
async def get_neighbors(self, node_id: UUID) -> list:
|
||||
if not self.graph.has_node(node_id):
|
||||
return []
|
||||
|
||||
|
|
@ -264,7 +264,7 @@ class NetworkXAdapter(GraphDBInterface):
|
|||
return connections
|
||||
|
||||
async def remove_connection_to_predecessors_of(
|
||||
self, node_ids: list[str], edge_label: str
|
||||
self, node_ids: list[UUID], edge_label: str
|
||||
) -> None:
|
||||
for node_id in node_ids:
|
||||
if self.graph.has_node(node_id):
|
||||
|
|
@ -275,7 +275,7 @@ class NetworkXAdapter(GraphDBInterface):
|
|||
await self.save_graph_to_file(self.filename)
|
||||
|
||||
async def remove_connection_to_successors_of(
|
||||
self, node_ids: list[str], edge_label: str
|
||||
self, node_ids: list[UUID], edge_label: str
|
||||
) -> None:
|
||||
for node_id in node_ids:
|
||||
if self.graph.has_node(node_id):
|
||||
|
|
@ -621,12 +621,12 @@ class NetworkXAdapter(GraphDBInterface):
|
|||
nodes.append(node_data)
|
||||
return nodes
|
||||
|
||||
async def get_node(self, node_id: str) -> dict:
|
||||
async def get_node(self, node_id: UUID) -> dict:
|
||||
if self.graph.has_node(node_id):
|
||||
return self.graph.nodes[node_id]
|
||||
return None
|
||||
|
||||
async def get_nodes(self, node_ids: List[str] = None) -> List[dict]:
|
||||
async def get_nodes(self, node_ids: List[UUID] = None) -> List[dict]:
|
||||
if node_ids is None:
|
||||
return [{"id": node_id, **data} for node_id, data in self.graph.nodes(data=True)]
|
||||
return [
|
||||
|
|
|
|||
|
|
@ -69,7 +69,7 @@ class SQLAlchemyAdapter:
|
|||
async def delete_table(self, table_name: str, schema_name: Optional[str] = "public"):
|
||||
async with self.engine.begin() as connection:
|
||||
if self.engine.dialect.name == "sqlite":
|
||||
# SQLite doesn’t support schema namespaces and the CASCADE keyword.
|
||||
# SQLite doesn't support schema namespaces and the CASCADE keyword.
|
||||
# However, foreign key constraint can be defined with ON DELETE CASCADE during table creation.
|
||||
await connection.execute(text(f'DROP TABLE IF EXISTS "{table_name}";'))
|
||||
else:
|
||||
|
|
@ -327,10 +327,10 @@ class SQLAlchemyAdapter:
|
|||
file.write("")
|
||||
else:
|
||||
async with self.engine.begin() as connection:
|
||||
schema_list = await self.get_schema_list()
|
||||
# Create a MetaData instance to load table information
|
||||
metadata = MetaData()
|
||||
# Drop all tables from all schemas
|
||||
# Drop all tables from the public schema
|
||||
schema_list = ["public", "public_staging"]
|
||||
for schema_name in schema_list:
|
||||
# Load the schema information into the MetaData object
|
||||
await connection.run_sync(metadata.reflect, schema=schema_name)
|
||||
|
|
|
|||
|
|
@ -6,8 +6,9 @@ from chromadb import AsyncHttpClient, Settings
|
|||
from cognee.exceptions import InvalidValueError
|
||||
from cognee.shared.logging_utils import get_logger
|
||||
from cognee.modules.storage.utils import get_own_properties
|
||||
from cognee.infrastructure.engine.utils import parse_id
|
||||
from cognee.infrastructure.engine import DataPoint
|
||||
from cognee.infrastructure.engine.utils import parse_id
|
||||
from cognee.infrastructure.databases.vector.exceptions import CollectionNotFoundError
|
||||
from cognee.infrastructure.databases.vector.models.ScoredResult import ScoredResult
|
||||
|
||||
from ..embeddings.EmbeddingEngine import EmbeddingEngine
|
||||
|
|
@ -108,9 +109,7 @@ class ChromaDBAdapter(VectorDBInterface):
|
|||
return await self.embedding_engine.embed_text(data)
|
||||
|
||||
async def has_collection(self, collection_name: str) -> bool:
|
||||
client = await self.get_connection()
|
||||
collections = await client.list_collections()
|
||||
# In ChromaDB v0.6.0, list_collections returns collection names directly
|
||||
collections = await self.get_collection_names()
|
||||
return collection_name in collections
|
||||
|
||||
async def create_collection(self, collection_name: str, payload_schema=None):
|
||||
|
|
@ -119,13 +118,17 @@ class ChromaDBAdapter(VectorDBInterface):
|
|||
if not await self.has_collection(collection_name):
|
||||
await client.create_collection(name=collection_name, metadata={"hnsw:space": "cosine"})
|
||||
|
||||
async def create_data_points(self, collection_name: str, data_points: list[DataPoint]):
|
||||
client = await self.get_connection()
|
||||
|
||||
async def get_collection(self, collection_name: str) -> AsyncHttpClient:
|
||||
if not await self.has_collection(collection_name):
|
||||
await self.create_collection(collection_name)
|
||||
raise CollectionNotFoundError(f"Collection '{collection_name}' not found!")
|
||||
|
||||
collection = await client.get_collection(collection_name)
|
||||
client = await self.get_connection()
|
||||
return await client.get_collection(collection_name)
|
||||
|
||||
async def create_data_points(self, collection_name: str, data_points: list[DataPoint]):
|
||||
await self.create_collection(collection_name)
|
||||
|
||||
collection = await self.get_collection(collection_name)
|
||||
|
||||
texts = [DataPoint.get_embeddable_data(data_point) for data_point in data_points]
|
||||
embeddings = await self.embed_data(texts)
|
||||
|
|
@ -161,8 +164,7 @@ class ChromaDBAdapter(VectorDBInterface):
|
|||
|
||||
async def retrieve(self, collection_name: str, data_point_ids: list[str]):
|
||||
"""Retrieve data points by their IDs from a collection."""
|
||||
client = await self.get_connection()
|
||||
collection = await client.get_collection(collection_name)
|
||||
collection = await self.get_collection(collection_name)
|
||||
results = await collection.get(ids=data_point_ids, include=["metadatas"])
|
||||
|
||||
return [
|
||||
|
|
@ -174,62 +176,12 @@ class ChromaDBAdapter(VectorDBInterface):
|
|||
for id, metadata in zip(results["ids"], results["metadatas"])
|
||||
]
|
||||
|
||||
async def get_distance_from_collection_elements(
|
||||
self, collection_name: str, query_text: str = None, query_vector: List[float] = None
|
||||
):
|
||||
"""Calculate distance between query and all elements in a collection."""
|
||||
if query_text is None and query_vector is None:
|
||||
raise InvalidValueError(message="One of query_text or query_vector must be provided!")
|
||||
|
||||
if query_text and not query_vector:
|
||||
query_vector = (await self.embedding_engine.embed_text([query_text]))[0]
|
||||
|
||||
client = await self.get_connection()
|
||||
try:
|
||||
collection = await client.get_collection(collection_name)
|
||||
|
||||
collection_count = await collection.count()
|
||||
|
||||
results = await collection.query(
|
||||
query_embeddings=[query_vector],
|
||||
include=["metadatas", "distances"],
|
||||
n_results=collection_count,
|
||||
)
|
||||
|
||||
result_values = []
|
||||
for i, (id, metadata, distance) in enumerate(
|
||||
zip(results["ids"][0], results["metadatas"][0], results["distances"][0])
|
||||
):
|
||||
result_values.append(
|
||||
{
|
||||
"id": parse_id(id),
|
||||
"payload": restore_data_from_chroma(metadata),
|
||||
"_distance": distance,
|
||||
}
|
||||
)
|
||||
|
||||
normalized_values = normalize_distances(result_values)
|
||||
|
||||
scored_results = []
|
||||
for i, result in enumerate(result_values):
|
||||
scored_results.append(
|
||||
ScoredResult(
|
||||
id=result["id"],
|
||||
payload=result["payload"],
|
||||
score=normalized_values[i],
|
||||
)
|
||||
)
|
||||
|
||||
return scored_results
|
||||
except Exception:
|
||||
return []
|
||||
|
||||
async def search(
|
||||
self,
|
||||
collection_name: str,
|
||||
query_text: str = None,
|
||||
query_vector: List[float] = None,
|
||||
limit: int = 5,
|
||||
limit: int = 15,
|
||||
with_vector: bool = False,
|
||||
normalized: bool = True,
|
||||
):
|
||||
|
|
@ -241,8 +193,10 @@ class ChromaDBAdapter(VectorDBInterface):
|
|||
query_vector = (await self.embedding_engine.embed_text([query_text]))[0]
|
||||
|
||||
try:
|
||||
client = await self.get_connection()
|
||||
collection = await client.get_collection(collection_name)
|
||||
collection = await self.get_collection(collection_name)
|
||||
|
||||
if limit == 0:
|
||||
limit = await collection.count()
|
||||
|
||||
results = await collection.query(
|
||||
query_embeddings=[query_vector],
|
||||
|
|
@ -296,8 +250,7 @@ class ChromaDBAdapter(VectorDBInterface):
|
|||
"""Perform multiple searches in a single request for efficiency."""
|
||||
query_vectors = await self.embed_data(query_texts)
|
||||
|
||||
client = await self.get_connection()
|
||||
collection = await client.get_collection(collection_name)
|
||||
collection = await self.get_collection(collection_name)
|
||||
|
||||
results = await collection.query(
|
||||
query_embeddings=query_vectors,
|
||||
|
|
@ -346,15 +299,14 @@ class ChromaDBAdapter(VectorDBInterface):
|
|||
|
||||
async def delete_data_points(self, collection_name: str, data_point_ids: list[str]):
|
||||
"""Remove data points from a collection by their IDs."""
|
||||
client = await self.get_connection()
|
||||
collection = await client.get_collection(collection_name)
|
||||
collection = await self.get_collection(collection_name)
|
||||
await collection.delete(ids=data_point_ids)
|
||||
return True
|
||||
|
||||
async def prune(self):
|
||||
"""Delete all collections in the ChromaDB database."""
|
||||
client = await self.get_connection()
|
||||
collections = await client.list_collections()
|
||||
collections = await self.list_collections()
|
||||
for collection_name in collections:
|
||||
await client.delete_collection(collection_name)
|
||||
return True
|
||||
|
|
@ -362,4 +314,8 @@ class ChromaDBAdapter(VectorDBInterface):
|
|||
async def get_collection_names(self):
|
||||
"""Get a list of all collection names in the database."""
|
||||
client = await self.get_connection()
|
||||
return await client.list_collections()
|
||||
collections = await client.list_collections()
|
||||
return [
|
||||
collection.name if hasattr(collection, "name") else collection["name"]
|
||||
for collection in collections
|
||||
]
|
||||
|
|
|
|||
|
|
@ -6,7 +6,9 @@ class CollectionNotFoundError(CriticalError):
|
|||
def __init__(
|
||||
self,
|
||||
message,
|
||||
name: str = "DatabaseNotCreatedError",
|
||||
name: str = "CollectionNotFoundError",
|
||||
status_code: int = status.HTTP_422_UNPROCESSABLE_ENTITY,
|
||||
log=True,
|
||||
log_level="DEBUG",
|
||||
):
|
||||
super().__init__(message, name, status_code)
|
||||
super().__init__(message, name, status_code, log, log_level)
|
||||
|
|
|
|||
|
|
@ -1,9 +1,8 @@
|
|||
import asyncio
|
||||
from typing import Generic, List, Optional, TypeVar, Union, get_args, get_origin, get_type_hints
|
||||
|
||||
import lancedb
|
||||
from lancedb.pydantic import LanceModel, Vector
|
||||
from pydantic import BaseModel
|
||||
from lancedb.pydantic import LanceModel, Vector
|
||||
from typing import Generic, List, Optional, TypeVar, Union, get_args, get_origin, get_type_hints
|
||||
|
||||
from cognee.exceptions import InvalidValueError
|
||||
from cognee.infrastructure.engine import DataPoint
|
||||
|
|
@ -17,8 +16,6 @@ from ..models.ScoredResult import ScoredResult
|
|||
from ..utils import normalize_distances
|
||||
from ..vector_db_interface import VectorDBInterface
|
||||
|
||||
from tenacity import retry, stop_after_attempt, wait_exponential
|
||||
|
||||
|
||||
class IndexSchema(DataPoint):
|
||||
id: str
|
||||
|
|
@ -76,9 +73,14 @@ class LanceDBAdapter(VectorDBInterface):
|
|||
exist_ok=True,
|
||||
)
|
||||
|
||||
async def create_data_points(self, collection_name: str, data_points: list[DataPoint]):
|
||||
connection = await self.get_connection()
|
||||
async def get_collection(self, collection_name: str):
|
||||
if not await self.has_collection(collection_name):
|
||||
raise CollectionNotFoundError(f"Collection '{collection_name}' not found!")
|
||||
|
||||
connection = await self.get_connection()
|
||||
return await connection.open_table(collection_name)
|
||||
|
||||
async def create_data_points(self, collection_name: str, data_points: list[DataPoint]):
|
||||
payload_schema = type(data_points[0])
|
||||
|
||||
if not await self.has_collection(collection_name):
|
||||
|
|
@ -87,7 +89,7 @@ class LanceDBAdapter(VectorDBInterface):
|
|||
payload_schema,
|
||||
)
|
||||
|
||||
collection = await connection.open_table(collection_name)
|
||||
collection = await self.get_collection(collection_name)
|
||||
|
||||
data_vectors = await self.embed_data(
|
||||
[DataPoint.get_embeddable_data(data_point) for data_point in data_points]
|
||||
|
|
@ -125,8 +127,7 @@ class LanceDBAdapter(VectorDBInterface):
|
|||
)
|
||||
|
||||
async def retrieve(self, collection_name: str, data_point_ids: list[str]):
|
||||
connection = await self.get_connection()
|
||||
collection = await connection.open_table(collection_name)
|
||||
collection = await self.get_collection(collection_name)
|
||||
|
||||
if len(data_point_ids) == 1:
|
||||
results = await collection.query().where(f"id = '{data_point_ids[0]}'").to_pandas()
|
||||
|
|
@ -142,48 +143,12 @@ class LanceDBAdapter(VectorDBInterface):
|
|||
for result in results.to_dict("index").values()
|
||||
]
|
||||
|
||||
async def get_distance_from_collection_elements(
|
||||
self, collection_name: str, query_text: str = None, query_vector: List[float] = None
|
||||
):
|
||||
if query_text is None and query_vector is None:
|
||||
raise InvalidValueError(message="One of query_text or query_vector must be provided!")
|
||||
|
||||
if query_text and not query_vector:
|
||||
query_vector = (await self.embedding_engine.embed_text([query_text]))[0]
|
||||
|
||||
connection = await self.get_connection()
|
||||
|
||||
try:
|
||||
collection = await connection.open_table(collection_name)
|
||||
|
||||
collection_size = await collection.count_rows()
|
||||
|
||||
results = (
|
||||
await collection.vector_search(query_vector).limit(collection_size).to_pandas()
|
||||
)
|
||||
|
||||
result_values = list(results.to_dict("index").values())
|
||||
|
||||
normalized_values = normalize_distances(result_values)
|
||||
|
||||
return [
|
||||
ScoredResult(
|
||||
id=parse_id(result["id"]),
|
||||
payload=result["payload"],
|
||||
score=normalized_values[value_index],
|
||||
)
|
||||
for value_index, result in enumerate(result_values)
|
||||
]
|
||||
except ValueError:
|
||||
# Ignore if collection doesn't exist
|
||||
return []
|
||||
|
||||
async def search(
|
||||
self,
|
||||
collection_name: str,
|
||||
query_text: str = None,
|
||||
query_vector: List[float] = None,
|
||||
limit: int = 5,
|
||||
limit: int = 15,
|
||||
with_vector: bool = False,
|
||||
normalized: bool = True,
|
||||
):
|
||||
|
|
@ -193,12 +158,10 @@ class LanceDBAdapter(VectorDBInterface):
|
|||
if query_text and not query_vector:
|
||||
query_vector = (await self.embedding_engine.embed_text([query_text]))[0]
|
||||
|
||||
connection = await self.get_connection()
|
||||
collection = await self.get_collection(collection_name)
|
||||
|
||||
try:
|
||||
collection = await connection.open_table(collection_name)
|
||||
except ValueError:
|
||||
raise CollectionNotFoundError(f"Collection '{collection_name}' not found!")
|
||||
if limit == 0:
|
||||
limit = await collection.count_rows()
|
||||
|
||||
results = await collection.vector_search(query_vector).limit(limit).to_pandas()
|
||||
|
||||
|
|
@ -239,30 +202,12 @@ class LanceDBAdapter(VectorDBInterface):
|
|||
]
|
||||
)
|
||||
|
||||
def delete_data_points(self, collection_name: str, data_point_ids: list[str]):
|
||||
@retry(stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1, min=4, max=10))
|
||||
async def _delete_data_points():
|
||||
connection = await self.get_connection()
|
||||
collection = await connection.open_table(collection_name)
|
||||
async def delete_data_points(self, collection_name: str, data_point_ids: list[str]):
|
||||
collection = await self.get_collection(collection_name)
|
||||
|
||||
# Delete one at a time to avoid commit conflicts
|
||||
for data_point_id in data_point_ids:
|
||||
await collection.delete(f"id = '{data_point_id}'")
|
||||
|
||||
return True
|
||||
|
||||
# Check if we're in an event loop
|
||||
try:
|
||||
loop = asyncio.get_event_loop()
|
||||
except RuntimeError:
|
||||
loop = None
|
||||
|
||||
if loop and loop.is_running():
|
||||
# If we're in a running event loop, create a new task
|
||||
return loop.create_task(_delete_data_points())
|
||||
else:
|
||||
# If we're not in an event loop, run it synchronously
|
||||
return asyncio.run(_delete_data_points())
|
||||
# Delete one at a time to avoid commit conflicts
|
||||
for data_point_id in data_point_ids:
|
||||
await collection.delete(f"id = '{data_point_id}'")
|
||||
|
||||
async def create_vector_index(self, index_name: str, index_property_name: str):
|
||||
await self.create_collection(
|
||||
|
|
@ -288,7 +233,7 @@ class LanceDBAdapter(VectorDBInterface):
|
|||
collection_names = await connection.table_names()
|
||||
|
||||
for collection_name in collection_names:
|
||||
collection = await connection.open_table(collection_name)
|
||||
collection = await self.get_collection(collection_name)
|
||||
await collection.delete("id IS NOT NULL")
|
||||
await connection.drop_table(collection_name)
|
||||
|
||||
|
|
|
|||
|
|
@ -1,11 +1,12 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from cognee.shared.logging_utils import get_logger
|
||||
from uuid import UUID
|
||||
from typing import List, Optional
|
||||
|
||||
from cognee.shared.logging_utils import get_logger
|
||||
from cognee.infrastructure.engine import DataPoint
|
||||
from cognee.infrastructure.engine.utils import parse_id
|
||||
from cognee.infrastructure.databases.vector.exceptions import CollectionNotFoundError
|
||||
|
||||
from ..embeddings.EmbeddingEngine import EmbeddingEngine
|
||||
from ..models.ScoredResult import ScoredResult
|
||||
|
|
@ -96,7 +97,7 @@ class MilvusAdapter(VectorDBInterface):
|
|||
raise e
|
||||
|
||||
async def create_data_points(self, collection_name: str, data_points: List[DataPoint]):
|
||||
from pymilvus import MilvusException
|
||||
from pymilvus import MilvusException, exceptions
|
||||
|
||||
client = self.get_milvus_client()
|
||||
data_vectors = await self.embed_data(
|
||||
|
|
@ -118,6 +119,10 @@ class MilvusAdapter(VectorDBInterface):
|
|||
f"Inserted {result.get('insert_count', 0)} data points into collection '{collection_name}'."
|
||||
)
|
||||
return result
|
||||
except exceptions.CollectionNotExistException as error:
|
||||
raise CollectionNotFoundError(
|
||||
f"Collection '{collection_name}' does not exist!"
|
||||
) from error
|
||||
except MilvusException as e:
|
||||
logger.error(
|
||||
f"Error inserting data points into collection '{collection_name}': {str(e)}"
|
||||
|
|
@ -140,8 +145,8 @@ class MilvusAdapter(VectorDBInterface):
|
|||
collection_name = f"{index_name}_{index_property_name}"
|
||||
await self.create_data_points(collection_name, formatted_data_points)
|
||||
|
||||
async def retrieve(self, collection_name: str, data_point_ids: list[str]):
|
||||
from pymilvus import MilvusException
|
||||
async def retrieve(self, collection_name: str, data_point_ids: list[UUID]):
|
||||
from pymilvus import MilvusException, exceptions
|
||||
|
||||
client = self.get_milvus_client()
|
||||
try:
|
||||
|
|
@ -153,6 +158,10 @@ class MilvusAdapter(VectorDBInterface):
|
|||
output_fields=["*"],
|
||||
)
|
||||
return results
|
||||
except exceptions.CollectionNotExistException as error:
|
||||
raise CollectionNotFoundError(
|
||||
f"Collection '{collection_name}' does not exist!"
|
||||
) from error
|
||||
except MilvusException as e:
|
||||
logger.error(
|
||||
f"Error retrieving data points from collection '{collection_name}': {str(e)}"
|
||||
|
|
@ -164,10 +173,10 @@ class MilvusAdapter(VectorDBInterface):
|
|||
collection_name: str,
|
||||
query_text: Optional[str] = None,
|
||||
query_vector: Optional[List[float]] = None,
|
||||
limit: int = 5,
|
||||
limit: int = 15,
|
||||
with_vector: bool = False,
|
||||
):
|
||||
from pymilvus import MilvusException
|
||||
from pymilvus import MilvusException, exceptions
|
||||
|
||||
client = self.get_milvus_client()
|
||||
if query_text is None and query_vector is None:
|
||||
|
|
@ -184,7 +193,7 @@ class MilvusAdapter(VectorDBInterface):
|
|||
collection_name=collection_name,
|
||||
data=[query_vector],
|
||||
anns_field="vector",
|
||||
limit=limit,
|
||||
limit=limit if limit > 0 else None,
|
||||
output_fields=output_fields,
|
||||
search_params={
|
||||
"metric_type": "COSINE",
|
||||
|
|
@ -199,6 +208,10 @@ class MilvusAdapter(VectorDBInterface):
|
|||
)
|
||||
for result in results[0]
|
||||
]
|
||||
except exceptions.CollectionNotExistException as error:
|
||||
raise CollectionNotFoundError(
|
||||
f"Collection '{collection_name}' does not exist!"
|
||||
) from error
|
||||
except MilvusException as e:
|
||||
logger.error(f"Error during search in collection '{collection_name}': {str(e)}")
|
||||
raise e
|
||||
|
|
@ -220,7 +233,7 @@ class MilvusAdapter(VectorDBInterface):
|
|||
]
|
||||
)
|
||||
|
||||
async def delete_data_points(self, collection_name: str, data_point_ids: list[str]):
|
||||
async def delete_data_points(self, collection_name: str, data_point_ids: list[UUID]):
|
||||
from pymilvus import MilvusException
|
||||
|
||||
client = self.get_milvus_client()
|
||||
|
|
|
|||
|
|
@ -7,19 +7,18 @@ from sqlalchemy import JSON, Column, Table, select, delete, MetaData
|
|||
from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker
|
||||
|
||||
from cognee.exceptions import InvalidValueError
|
||||
from cognee.infrastructure.databases.exceptions import EntityNotFoundError
|
||||
from cognee.infrastructure.databases.vector.exceptions import CollectionNotFoundError
|
||||
from cognee.infrastructure.engine import DataPoint
|
||||
from cognee.infrastructure.engine.utils import parse_id
|
||||
from cognee.infrastructure.databases.relational import get_relational_engine
|
||||
|
||||
from ...relational.ModelBase import Base
|
||||
from ...relational.sqlalchemy.SqlAlchemyAdapter import SQLAlchemyAdapter
|
||||
from ..embeddings.EmbeddingEngine import EmbeddingEngine
|
||||
from ..models.ScoredResult import ScoredResult
|
||||
from ..vector_db_interface import VectorDBInterface
|
||||
from .serialize_data import serialize_data
|
||||
from ..utils import normalize_distances
|
||||
from ..models.ScoredResult import ScoredResult
|
||||
from ..exceptions import CollectionNotFoundError
|
||||
from ..vector_db_interface import VectorDBInterface
|
||||
from ..embeddings.EmbeddingEngine import EmbeddingEngine
|
||||
from .serialize_data import serialize_data
|
||||
|
||||
|
||||
class IndexSchema(DataPoint):
|
||||
|
|
@ -184,7 +183,9 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
|
|||
if collection_name in metadata.tables:
|
||||
return metadata.tables[collection_name]
|
||||
else:
|
||||
raise CollectionNotFoundError(f"Collection '{collection_name}' not found!")
|
||||
raise CollectionNotFoundError(
|
||||
f"Collection '{collection_name}' not found!", log_level="DEBUG"
|
||||
)
|
||||
|
||||
async def retrieve(self, collection_name: str, data_point_ids: List[str]):
|
||||
# Get PGVectorDataPoint Table from database
|
||||
|
|
@ -201,60 +202,12 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
|
|||
for result in results
|
||||
]
|
||||
|
||||
async def get_distance_from_collection_elements(
|
||||
self,
|
||||
collection_name: str,
|
||||
query_text: str = None,
|
||||
query_vector: List[float] = None,
|
||||
with_vector: bool = False,
|
||||
) -> List[ScoredResult]:
|
||||
if query_text is None and query_vector is None:
|
||||
raise ValueError("One of query_text or query_vector must be provided!")
|
||||
|
||||
if query_text and not query_vector:
|
||||
query_vector = (await self.embedding_engine.embed_text([query_text]))[0]
|
||||
|
||||
try:
|
||||
# Get PGVectorDataPoint Table from database
|
||||
PGVectorDataPoint = await self.get_table(collection_name)
|
||||
|
||||
# Use async session to connect to the database
|
||||
async with self.get_async_session() as session:
|
||||
# Find closest vectors to query_vector
|
||||
closest_items = await session.execute(
|
||||
select(
|
||||
PGVectorDataPoint,
|
||||
PGVectorDataPoint.c.vector.cosine_distance(query_vector).label(
|
||||
"similarity"
|
||||
),
|
||||
).order_by("similarity")
|
||||
)
|
||||
|
||||
vector_list = []
|
||||
|
||||
# Extract distances and find min/max for normalization
|
||||
for vector in closest_items:
|
||||
# TODO: Add normalization of similarity score
|
||||
vector_list.append(vector)
|
||||
|
||||
# Create and return ScoredResult objects
|
||||
return [
|
||||
ScoredResult(id=parse_id(str(row.id)), payload=row.payload, score=row.similarity)
|
||||
for row in vector_list
|
||||
]
|
||||
except EntityNotFoundError:
|
||||
# Ignore if collection does not exist
|
||||
return []
|
||||
except CollectionNotFoundError:
|
||||
# Ignore if collection does not exist
|
||||
return []
|
||||
|
||||
async def search(
|
||||
self,
|
||||
collection_name: str,
|
||||
query_text: Optional[str] = None,
|
||||
query_vector: Optional[List[float]] = None,
|
||||
limit: int = 5,
|
||||
limit: int = 15,
|
||||
with_vector: bool = False,
|
||||
) -> List[ScoredResult]:
|
||||
if query_text is None and query_vector is None:
|
||||
|
|
@ -266,24 +219,26 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
|
|||
# Get PGVectorDataPoint Table from database
|
||||
PGVectorDataPoint = await self.get_table(collection_name)
|
||||
|
||||
# NOTE: This needs to be initialized in case search doesn't return a value
|
||||
closest_items = []
|
||||
|
||||
# Use async session to connect to the database
|
||||
async with self.get_async_session() as session:
|
||||
query = select(
|
||||
PGVectorDataPoint,
|
||||
PGVectorDataPoint.c.vector.cosine_distance(query_vector).label("similarity"),
|
||||
).order_by("similarity")
|
||||
|
||||
if limit > 0:
|
||||
query = query.limit(limit)
|
||||
|
||||
# Find closest vectors to query_vector
|
||||
closest_items = await session.execute(
|
||||
select(
|
||||
PGVectorDataPoint,
|
||||
PGVectorDataPoint.c.vector.cosine_distance(query_vector).label("similarity"),
|
||||
)
|
||||
.order_by("similarity")
|
||||
.limit(limit)
|
||||
)
|
||||
closest_items = await session.execute(query)
|
||||
|
||||
vector_list = []
|
||||
|
||||
# Extract distances and find min/max for normalization
|
||||
for vector in closest_items:
|
||||
for vector in closest_items.all():
|
||||
vector_list.append(
|
||||
{
|
||||
"id": parse_id(str(vector.id)),
|
||||
|
|
@ -292,6 +247,9 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
|
|||
}
|
||||
)
|
||||
|
||||
if len(vector_list) == 0:
|
||||
return []
|
||||
|
||||
# Normalize vector distance and add this as score information to vector_list
|
||||
normalized_values = normalize_distances(vector_list)
|
||||
for i in range(0, len(normalized_values)):
|
||||
|
|
|
|||
|
|
@ -1,12 +1,12 @@
|
|||
from cognee.shared.logging_utils import get_logger
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from cognee.infrastructure.engine.utils import parse_id
|
||||
from qdrant_client import AsyncQdrantClient, models
|
||||
|
||||
from cognee.shared.logging_utils import get_logger
|
||||
from cognee.infrastructure.engine.utils import parse_id
|
||||
from cognee.exceptions import InvalidValueError
|
||||
from cognee.infrastructure.databases.vector.models.ScoredResult import ScoredResult
|
||||
from cognee.infrastructure.engine import DataPoint
|
||||
from cognee.infrastructure.databases.vector.exceptions import CollectionNotFoundError
|
||||
from cognee.infrastructure.databases.vector.models.ScoredResult import ScoredResult
|
||||
|
||||
from ..embeddings.EmbeddingEngine import EmbeddingEngine
|
||||
from ..vector_db_interface import VectorDBInterface
|
||||
|
|
@ -97,6 +97,8 @@ class QDrantAdapter(VectorDBInterface):
|
|||
await client.close()
|
||||
|
||||
async def create_data_points(self, collection_name: str, data_points: List[DataPoint]):
|
||||
from qdrant_client.http.exceptions import UnexpectedResponse
|
||||
|
||||
client = self.get_qdrant_client()
|
||||
|
||||
data_vectors = await self.embed_data(
|
||||
|
|
@ -114,6 +116,13 @@ class QDrantAdapter(VectorDBInterface):
|
|||
|
||||
try:
|
||||
client.upload_points(collection_name=collection_name, points=points)
|
||||
except UnexpectedResponse as error:
|
||||
if "Collection not found" in str(error):
|
||||
raise CollectionNotFoundError(
|
||||
message=f"Collection {collection_name} not found!"
|
||||
) from error
|
||||
else:
|
||||
raise error
|
||||
except Exception as error:
|
||||
logger.error("Error uploading data points to Qdrant: %s", str(error))
|
||||
raise error
|
||||
|
|
@ -143,19 +152,22 @@ class QDrantAdapter(VectorDBInterface):
|
|||
await client.close()
|
||||
return results
|
||||
|
||||
async def get_distance_from_collection_elements(
|
||||
async def search(
|
||||
self,
|
||||
collection_name: str,
|
||||
query_text: str = None,
|
||||
query_vector: List[float] = None,
|
||||
query_text: Optional[str] = None,
|
||||
query_vector: Optional[List[float]] = None,
|
||||
limit: int = 15,
|
||||
with_vector: bool = False,
|
||||
) -> List[ScoredResult]:
|
||||
if query_text is None and query_vector is None:
|
||||
raise ValueError("One of query_text or query_vector must be provided!")
|
||||
):
|
||||
from qdrant_client.http.exceptions import UnexpectedResponse
|
||||
|
||||
client = self.get_qdrant_client()
|
||||
if query_text is None and query_vector is None:
|
||||
raise InvalidValueError(message="One of query_text or query_vector must be provided!")
|
||||
|
||||
try:
|
||||
client = self.get_qdrant_client()
|
||||
|
||||
results = await client.search(
|
||||
collection_name=collection_name,
|
||||
query_vector=models.NamedVector(
|
||||
|
|
@ -164,9 +176,12 @@ class QDrantAdapter(VectorDBInterface):
|
|||
if query_vector is not None
|
||||
else (await self.embed_data([query_text]))[0],
|
||||
),
|
||||
limit=limit if limit > 0 else None,
|
||||
with_vectors=with_vector,
|
||||
)
|
||||
|
||||
await client.close()
|
||||
|
||||
return [
|
||||
ScoredResult(
|
||||
id=parse_id(result.id),
|
||||
|
|
@ -178,51 +193,16 @@ class QDrantAdapter(VectorDBInterface):
|
|||
)
|
||||
for result in results
|
||||
]
|
||||
except ValueError:
|
||||
# Ignore if the collection doesn't exist
|
||||
return []
|
||||
except UnexpectedResponse as error:
|
||||
if "Collection not found" in str(error):
|
||||
raise CollectionNotFoundError(
|
||||
message=f"Collection {collection_name} not found!"
|
||||
) from error
|
||||
else:
|
||||
raise error
|
||||
finally:
|
||||
await client.close()
|
||||
|
||||
async def search(
|
||||
self,
|
||||
collection_name: str,
|
||||
query_text: Optional[str] = None,
|
||||
query_vector: Optional[List[float]] = None,
|
||||
limit: int = 5,
|
||||
with_vector: bool = False,
|
||||
):
|
||||
if query_text is None and query_vector is None:
|
||||
raise InvalidValueError(message="One of query_text or query_vector must be provided!")
|
||||
|
||||
client = self.get_qdrant_client()
|
||||
|
||||
results = await client.search(
|
||||
collection_name=collection_name,
|
||||
query_vector=models.NamedVector(
|
||||
name="text",
|
||||
vector=query_vector
|
||||
if query_vector is not None
|
||||
else (await self.embed_data([query_text]))[0],
|
||||
),
|
||||
limit=limit,
|
||||
with_vectors=with_vector,
|
||||
)
|
||||
|
||||
await client.close()
|
||||
|
||||
return [
|
||||
ScoredResult(
|
||||
id=parse_id(result.id),
|
||||
payload={
|
||||
**result.payload,
|
||||
"id": parse_id(result.id),
|
||||
},
|
||||
score=1 - result.score,
|
||||
)
|
||||
for result in results
|
||||
]
|
||||
|
||||
async def batch_search(
|
||||
self,
|
||||
collection_name: str,
|
||||
|
|
|
|||
|
|
@ -1,10 +1,10 @@
|
|||
import asyncio
|
||||
from cognee.shared.logging_utils import get_logger
|
||||
from typing import List, Optional
|
||||
|
||||
from cognee.shared.logging_utils import get_logger
|
||||
from cognee.exceptions import InvalidValueError
|
||||
from cognee.infrastructure.engine import DataPoint
|
||||
from cognee.infrastructure.engine.utils import parse_id
|
||||
from cognee.infrastructure.databases.vector.exceptions import CollectionNotFoundError
|
||||
|
||||
from ..embeddings.EmbeddingEngine import EmbeddingEngine
|
||||
from ..models.ScoredResult import ScoredResult
|
||||
|
|
@ -34,21 +34,23 @@ class WeaviateAdapter(VectorDBInterface):
|
|||
|
||||
self.embedding_engine = embedding_engine
|
||||
|
||||
self.client = weaviate.connect_to_wcs(
|
||||
self.client = weaviate.use_async_with_weaviate_cloud(
|
||||
cluster_url=url,
|
||||
auth_credentials=weaviate.auth.AuthApiKey(api_key),
|
||||
additional_config=wvc.init.AdditionalConfig(timeout=wvc.init.Timeout(init=30)),
|
||||
)
|
||||
|
||||
async def get_client(self):
|
||||
await self.client.connect()
|
||||
|
||||
return self.client
|
||||
|
||||
async def embed_data(self, data: List[str]) -> List[float]:
|
||||
return await self.embedding_engine.embed_text(data)
|
||||
|
||||
async def has_collection(self, collection_name: str) -> bool:
|
||||
future = asyncio.Future()
|
||||
|
||||
future.set_result(self.client.collections.exists(collection_name))
|
||||
|
||||
return await future
|
||||
client = await self.get_client()
|
||||
return await client.collections.exists(collection_name)
|
||||
|
||||
async def create_collection(
|
||||
self,
|
||||
|
|
@ -57,26 +59,25 @@ class WeaviateAdapter(VectorDBInterface):
|
|||
):
|
||||
import weaviate.classes.config as wvcc
|
||||
|
||||
future = asyncio.Future()
|
||||
|
||||
if not self.client.collections.exists(collection_name):
|
||||
future.set_result(
|
||||
self.client.collections.create(
|
||||
name=collection_name,
|
||||
properties=[
|
||||
wvcc.Property(
|
||||
name="text", data_type=wvcc.DataType.TEXT, skip_vectorization=True
|
||||
)
|
||||
],
|
||||
)
|
||||
if not await self.has_collection(collection_name):
|
||||
client = await self.get_client()
|
||||
return await client.collections.create(
|
||||
name=collection_name,
|
||||
properties=[
|
||||
wvcc.Property(
|
||||
name="text", data_type=wvcc.DataType.TEXT, skip_vectorization=True
|
||||
)
|
||||
],
|
||||
)
|
||||
else:
|
||||
future.set_result(self.get_collection(collection_name))
|
||||
return await self.get_collection(collection_name)
|
||||
|
||||
return await future
|
||||
async def get_collection(self, collection_name: str):
|
||||
if not await self.has_collection(collection_name):
|
||||
raise CollectionNotFoundError(f"Collection '{collection_name}' not found.")
|
||||
|
||||
def get_collection(self, collection_name: str):
|
||||
return self.client.collections.get(collection_name)
|
||||
client = await self.get_client()
|
||||
return client.collections.get(collection_name)
|
||||
|
||||
async def create_data_points(self, collection_name: str, data_points: List[DataPoint]):
|
||||
from weaviate.classes.data import DataObject
|
||||
|
|
@ -97,29 +98,30 @@ class WeaviateAdapter(VectorDBInterface):
|
|||
|
||||
data_points = [convert_to_weaviate_data_points(data_point) for data_point in data_points]
|
||||
|
||||
collection = self.get_collection(collection_name)
|
||||
collection = await self.get_collection(collection_name)
|
||||
|
||||
try:
|
||||
if len(data_points) > 1:
|
||||
with collection.batch.dynamic() as batch:
|
||||
for data_point in data_points:
|
||||
batch.add_object(
|
||||
uuid=data_point.uuid,
|
||||
vector=data_point.vector,
|
||||
properties=data_point.properties,
|
||||
references=data_point.references,
|
||||
)
|
||||
return await collection.data.insert_many(data_points)
|
||||
# with collection.batch.dynamic() as batch:
|
||||
# for data_point in data_points:
|
||||
# batch.add_object(
|
||||
# uuid=data_point.uuid,
|
||||
# vector=data_point.vector,
|
||||
# properties=data_point.properties,
|
||||
# references=data_point.references,
|
||||
# )
|
||||
else:
|
||||
data_point: DataObject = data_points[0]
|
||||
if collection.data.exists(data_point.uuid):
|
||||
return collection.data.update(
|
||||
return await collection.data.update(
|
||||
uuid=data_point.uuid,
|
||||
vector=data_point.vector,
|
||||
properties=data_point.properties,
|
||||
references=data_point.references,
|
||||
)
|
||||
else:
|
||||
return collection.data.insert(
|
||||
return await collection.data.insert(
|
||||
uuid=data_point.uuid,
|
||||
vector=data_point.vector,
|
||||
properties=data_point.properties,
|
||||
|
|
@ -130,12 +132,12 @@ class WeaviateAdapter(VectorDBInterface):
|
|||
raise error
|
||||
|
||||
async def create_vector_index(self, index_name: str, index_property_name: str):
|
||||
await self.create_collection(f"{index_name}_{index_property_name}")
|
||||
return await self.create_collection(f"{index_name}_{index_property_name}")
|
||||
|
||||
async def index_data_points(
|
||||
self, index_name: str, index_property_name: str, data_points: list[DataPoint]
|
||||
):
|
||||
await self.create_data_points(
|
||||
return await self.create_data_points(
|
||||
f"{index_name}_{index_property_name}",
|
||||
[
|
||||
IndexSchema(
|
||||
|
|
@ -149,9 +151,8 @@ class WeaviateAdapter(VectorDBInterface):
|
|||
async def retrieve(self, collection_name: str, data_point_ids: list[str]):
|
||||
from weaviate.classes.query import Filter
|
||||
|
||||
future = asyncio.Future()
|
||||
|
||||
data_points = self.get_collection(collection_name).query.fetch_objects(
|
||||
collection = await self.get_collection(collection_name)
|
||||
data_points = await collection.query.fetch_objects(
|
||||
filters=Filter.by_id().contains_any(data_point_ids)
|
||||
)
|
||||
|
||||
|
|
@ -160,30 +161,32 @@ class WeaviateAdapter(VectorDBInterface):
|
|||
data_point.id = data_point.uuid
|
||||
del data_point.properties
|
||||
|
||||
future.set_result(data_points.objects)
|
||||
return data_points.objects
|
||||
|
||||
return await future
|
||||
|
||||
async def get_distance_from_collection_elements(
|
||||
async def search(
|
||||
self,
|
||||
collection_name: str,
|
||||
query_text: str = None,
|
||||
query_vector: List[float] = None,
|
||||
query_text: Optional[str] = None,
|
||||
query_vector: Optional[List[float]] = None,
|
||||
limit: int = 15,
|
||||
with_vector: bool = False,
|
||||
) -> List[ScoredResult]:
|
||||
):
|
||||
import weaviate.classes as wvc
|
||||
import weaviate.exceptions
|
||||
|
||||
if query_text is None and query_vector is None:
|
||||
raise ValueError("One of query_text or query_vector must be provided!")
|
||||
raise InvalidValueError(message="One of query_text or query_vector must be provided!")
|
||||
|
||||
if query_vector is None:
|
||||
query_vector = (await self.embed_data([query_text]))[0]
|
||||
|
||||
collection = await self.get_collection(collection_name)
|
||||
|
||||
try:
|
||||
search_result = self.get_collection(collection_name).query.hybrid(
|
||||
search_result = await collection.query.hybrid(
|
||||
query=None,
|
||||
vector=query_vector,
|
||||
limit=limit if limit > 0 else None,
|
||||
include_vector=with_vector,
|
||||
return_metadata=wvc.query.MetadataQuery(score=True),
|
||||
)
|
||||
|
|
@ -196,43 +199,10 @@ class WeaviateAdapter(VectorDBInterface):
|
|||
)
|
||||
for result in search_result.objects
|
||||
]
|
||||
except weaviate.exceptions.UnexpectedStatusCodeError:
|
||||
except weaviate.exceptions.WeaviateInvalidInputError:
|
||||
# Ignore if the collection doesn't exist
|
||||
return []
|
||||
|
||||
async def search(
|
||||
self,
|
||||
collection_name: str,
|
||||
query_text: Optional[str] = None,
|
||||
query_vector: Optional[List[float]] = None,
|
||||
limit: int = None,
|
||||
with_vector: bool = False,
|
||||
):
|
||||
import weaviate.classes as wvc
|
||||
|
||||
if query_text is None and query_vector is None:
|
||||
raise InvalidValueError(message="One of query_text or query_vector must be provided!")
|
||||
|
||||
if query_vector is None:
|
||||
query_vector = (await self.embed_data([query_text]))[0]
|
||||
|
||||
search_result = self.get_collection(collection_name).query.hybrid(
|
||||
query=None,
|
||||
vector=query_vector,
|
||||
limit=limit,
|
||||
include_vector=with_vector,
|
||||
return_metadata=wvc.query.MetadataQuery(score=True),
|
||||
)
|
||||
|
||||
return [
|
||||
ScoredResult(
|
||||
id=parse_id(str(result.uuid)),
|
||||
payload=result.properties,
|
||||
score=1 - float(result.metadata.score),
|
||||
)
|
||||
for result in search_result.objects
|
||||
]
|
||||
|
||||
async def batch_search(
|
||||
self, collection_name: str, query_texts: List[str], limit: int, with_vectors: bool = False
|
||||
):
|
||||
|
|
@ -248,14 +218,13 @@ class WeaviateAdapter(VectorDBInterface):
|
|||
async def delete_data_points(self, collection_name: str, data_point_ids: list[str]):
|
||||
from weaviate.classes.query import Filter
|
||||
|
||||
future = asyncio.Future()
|
||||
|
||||
result = self.get_collection(collection_name).data.delete_many(
|
||||
collection = await self.get_collection(collection_name)
|
||||
result = await collection.data.delete_many(
|
||||
filters=Filter.by_id().contains_any(data_point_ids)
|
||||
)
|
||||
future.set_result(result)
|
||||
|
||||
return await future
|
||||
return result
|
||||
|
||||
async def prune(self):
|
||||
self.client.collections.delete_all()
|
||||
client = await self.get_client()
|
||||
await client.collections.delete_all()
|
||||
|
|
|
|||
|
|
@ -1,9 +1,10 @@
|
|||
from typing import Type, Optional
|
||||
from pydantic import BaseModel
|
||||
from cognee.shared.logging_utils import get_logger
|
||||
import litellm
|
||||
from pydantic import BaseModel
|
||||
from typing import Type, Optional
|
||||
from litellm import acompletion, JSONSchemaValidationError
|
||||
from cognee.shared.data_models import MonitoringTool
|
||||
|
||||
from cognee.shared.logging_utils import get_logger
|
||||
from cognee.modules.observability.get_observe import get_observe
|
||||
from cognee.exceptions import InvalidValueError
|
||||
from cognee.infrastructure.llm.llm_interface import LLMInterface
|
||||
from cognee.infrastructure.llm.prompts import read_query_prompt
|
||||
|
|
@ -11,14 +12,9 @@ from cognee.infrastructure.llm.rate_limiter import (
|
|||
rate_limit_async,
|
||||
sleep_and_retry_async,
|
||||
)
|
||||
from cognee.base_config import get_base_config
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
monitoring = get_base_config().monitoring_tool
|
||||
|
||||
if monitoring == MonitoringTool.LANGFUSE:
|
||||
from langfuse.decorators import observe
|
||||
observe = get_observe()
|
||||
|
||||
|
||||
class GeminiAdapter(LLMInterface):
|
||||
|
|
|
|||
|
|
@ -1,14 +1,11 @@
|
|||
import os
|
||||
import base64
|
||||
from pathlib import Path
|
||||
from typing import Type
|
||||
|
||||
import litellm
|
||||
import instructor
|
||||
from typing import Type
|
||||
from pydantic import BaseModel
|
||||
|
||||
from cognee.modules.data.processing.document_types.open_data_file import open_data_file
|
||||
from cognee.shared.data_models import MonitoringTool
|
||||
from cognee.exceptions import InvalidValueError
|
||||
from cognee.infrastructure.llm.llm_interface import LLMInterface
|
||||
from cognee.infrastructure.llm.prompts import read_query_prompt
|
||||
|
|
@ -18,12 +15,9 @@ from cognee.infrastructure.llm.rate_limiter import (
|
|||
sleep_and_retry_async,
|
||||
sleep_and_retry_sync,
|
||||
)
|
||||
from cognee.base_config import get_base_config
|
||||
from cognee.modules.observability.get_observe import get_observe
|
||||
|
||||
monitoring = get_base_config().monitoring_tool
|
||||
|
||||
if monitoring == MonitoringTool.LANGFUSE:
|
||||
from langfuse.decorators import observe
|
||||
observe = get_observe()
|
||||
|
||||
|
||||
class OpenAIAdapter(LLMInterface):
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@ from .get_datasets import get_datasets
|
|||
from .get_datasets_by_name import get_datasets_by_name
|
||||
from .get_dataset_data import get_dataset_data
|
||||
from .get_data import get_data
|
||||
from .get_unique_dataset_id import get_unique_dataset_id
|
||||
|
||||
# Delete
|
||||
from .delete_dataset import delete_dataset
|
||||
|
|
|
|||
|
|
@ -4,8 +4,13 @@ from sqlalchemy import select
|
|||
from sqlalchemy.orm import joinedload
|
||||
from cognee.modules.data.models import Dataset
|
||||
|
||||
from cognee.modules.data.methods.get_unique_dataset_id import get_unique_dataset_id
|
||||
from cognee.modules.users.models import User
|
||||
|
||||
|
||||
async def create_dataset(dataset_name: str, user: User, session: AsyncSession) -> Dataset:
|
||||
owner_id = user.id
|
||||
|
||||
async def create_dataset(dataset_name: str, owner_id: UUID, session: AsyncSession) -> Dataset:
|
||||
dataset = (
|
||||
await session.scalars(
|
||||
select(Dataset)
|
||||
|
|
@ -16,10 +21,9 @@ async def create_dataset(dataset_name: str, owner_id: UUID, session: AsyncSessio
|
|||
).first()
|
||||
|
||||
if dataset is None:
|
||||
# Dataset id should be generated based on dataset_name and owner_id so multiple users can use the same dataset_name
|
||||
dataset = Dataset(
|
||||
id=uuid5(NAMESPACE_OID, f"{dataset_name}{str(owner_id)}"), name=dataset_name, data=[]
|
||||
)
|
||||
# Dataset id should be generated based on dataset_name and owner_id/user so multiple users can use the same dataset_name
|
||||
dataset_id = await get_unique_dataset_id(dataset_name=dataset_name, user=user)
|
||||
dataset = Dataset(id=dataset_id, name=dataset_name, data=[])
|
||||
dataset.owner_id = owner_id
|
||||
|
||||
session.add(dataset)
|
||||
|
|
|
|||
6
cognee/modules/data/methods/get_unique_dataset_id.py
Normal file
6
cognee/modules/data/methods/get_unique_dataset_id.py
Normal file
|
|
@ -0,0 +1,6 @@
|
|||
from uuid import UUID, uuid5, NAMESPACE_OID
|
||||
from cognee.modules.users.models import User
|
||||
|
||||
|
||||
async def get_unique_dataset_id(dataset_name: str, user: User) -> UUID:
|
||||
return uuid5(NAMESPACE_OID, f"{dataset_name}{str(user.id)}")
|
||||
|
|
@ -128,8 +128,10 @@ class CogneeGraph(CogneeAbstractGraph):
|
|||
if query_vector is None or len(query_vector) == 0:
|
||||
raise ValueError("Failed to generate query embedding.")
|
||||
|
||||
edge_distances = await vector_engine.get_distance_from_collection_elements(
|
||||
"EdgeType_relationship_name", query_text=query
|
||||
edge_distances = await vector_engine.search(
|
||||
collection_name="EdgeType_relationship_name",
|
||||
query_text=query,
|
||||
limit=0,
|
||||
)
|
||||
|
||||
embedding_map = {result.payload["text"]: result.score for result in edge_distances}
|
||||
|
|
|
|||
11
cognee/modules/observability/get_observe.py
Normal file
11
cognee/modules/observability/get_observe.py
Normal file
|
|
@ -0,0 +1,11 @@
|
|||
from cognee.base_config import get_base_config
|
||||
from .observers import Observer
|
||||
|
||||
|
||||
def get_observe():
|
||||
monitoring = get_base_config().monitoring_tool
|
||||
|
||||
if monitoring == Observer.LANGFUSE:
|
||||
from langfuse.decorators import observe
|
||||
|
||||
return observe
|
||||
9
cognee/modules/observability/observers.py
Normal file
9
cognee/modules/observability/observers.py
Normal file
|
|
@ -0,0 +1,9 @@
|
|||
from enum import Enum
|
||||
|
||||
|
||||
class Observer(str, Enum):
|
||||
"""Monitoring tools"""
|
||||
|
||||
LANGFUSE = "langfuse"
|
||||
LLMLITE = "llmlite"
|
||||
LANGSMITH = "langsmith"
|
||||
|
|
@ -6,6 +6,7 @@ from cognee.infrastructure.databases.relational import Base
|
|||
|
||||
|
||||
class PipelineRunStatus(enum.Enum):
|
||||
DATASET_PROCESSING_INITIATED = "DATASET_PROCESSING_INITIATED"
|
||||
DATASET_PROCESSING_STARTED = "DATASET_PROCESSING_STARTED"
|
||||
DATASET_PROCESSING_COMPLETED = "DATASET_PROCESSING_COMPLETED"
|
||||
DATASET_PROCESSING_ERRORED = "DATASET_PROCESSING_ERRORED"
|
||||
|
|
|
|||
|
|
@ -1,3 +1,4 @@
|
|||
from .log_pipeline_run_initiated import log_pipeline_run_initiated
|
||||
from .log_pipeline_run_start import log_pipeline_run_start
|
||||
from .log_pipeline_run_complete import log_pipeline_run_complete
|
||||
from .log_pipeline_run_error import log_pipeline_run_error
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@ from ..models import PipelineRun
|
|||
from sqlalchemy.orm import aliased
|
||||
|
||||
|
||||
async def get_pipeline_status(dataset_ids: list[UUID]):
|
||||
async def get_pipeline_status(dataset_ids: list[UUID], pipeline_name: str):
|
||||
db_engine = get_relational_engine()
|
||||
|
||||
async with db_engine.get_async_session() as session:
|
||||
|
|
@ -20,6 +20,7 @@ async def get_pipeline_status(dataset_ids: list[UUID]):
|
|||
.label("rn"),
|
||||
)
|
||||
.filter(PipelineRun.dataset_id.in_(dataset_ids))
|
||||
.filter(PipelineRun.pipeline_name == pipeline_name)
|
||||
.subquery()
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,22 @@
|
|||
from uuid import UUID, uuid4
|
||||
from cognee.infrastructure.databases.relational import get_relational_engine
|
||||
from cognee.modules.pipelines.models import PipelineRun, PipelineRunStatus
|
||||
|
||||
|
||||
async def log_pipeline_run_initiated(pipeline_id: str, pipeline_name: str, dataset_id: UUID):
|
||||
pipeline_run = PipelineRun(
|
||||
pipeline_run_id=uuid4(),
|
||||
pipeline_name=pipeline_name,
|
||||
pipeline_id=pipeline_id,
|
||||
status=PipelineRunStatus.DATASET_PROCESSING_INITIATED,
|
||||
dataset_id=dataset_id,
|
||||
run_info={},
|
||||
)
|
||||
|
||||
db_engine = get_relational_engine()
|
||||
|
||||
async with db_engine.get_async_session() as session:
|
||||
session.add(pipeline_run)
|
||||
await session.commit()
|
||||
|
||||
return pipeline_run
|
||||
|
|
@ -1,10 +1,11 @@
|
|||
import asyncio
|
||||
from cognee.shared.logging_utils import get_logger
|
||||
from typing import Union
|
||||
from uuid import uuid5, NAMESPACE_OID
|
||||
from uuid import NAMESPACE_OID, uuid5
|
||||
|
||||
from cognee.modules.data.methods import get_datasets, get_datasets_by_name
|
||||
from cognee.shared.logging_utils import get_logger
|
||||
from cognee.modules.data.methods import get_datasets
|
||||
from cognee.modules.data.methods.get_dataset_data import get_dataset_data
|
||||
from cognee.modules.data.methods.get_unique_dataset_id import get_unique_dataset_id
|
||||
from cognee.modules.data.models import Data, Dataset
|
||||
from cognee.modules.pipelines.operations.run_tasks import run_tasks
|
||||
from cognee.modules.pipelines.models import PipelineRunStatus
|
||||
|
|
@ -12,6 +13,7 @@ from cognee.modules.pipelines.operations.get_pipeline_status import get_pipeline
|
|||
from cognee.modules.pipelines.tasks.task import Task
|
||||
from cognee.modules.users.methods import get_default_user
|
||||
from cognee.modules.users.models import User
|
||||
from cognee.modules.pipelines.operations import log_pipeline_run_initiated
|
||||
|
||||
from cognee.infrastructure.databases.relational import (
|
||||
create_db_and_tables as create_relational_db_and_tables,
|
||||
|
|
@ -58,15 +60,36 @@ async def cognee_pipeline(
|
|||
|
||||
# If no datasets are provided, work with all existing datasets.
|
||||
existing_datasets = await get_datasets(user.id)
|
||||
if datasets is None or len(datasets) == 0:
|
||||
|
||||
if not datasets:
|
||||
# Get datasets from database if none sent.
|
||||
datasets = existing_datasets
|
||||
if isinstance(datasets[0], str):
|
||||
datasets = await get_datasets_by_name(datasets, user.id)
|
||||
else:
|
||||
# Try to get datasets objects from database, if they don't exist use dataset name
|
||||
datasets_names = await get_datasets_by_name(datasets, user.id)
|
||||
if datasets_names:
|
||||
datasets = datasets_names
|
||||
# If dataset is already in database, use it, otherwise create a new instance.
|
||||
dataset_instances = []
|
||||
|
||||
for dataset_name in datasets:
|
||||
is_dataset_found = False
|
||||
|
||||
for existing_dataset in existing_datasets:
|
||||
if (
|
||||
existing_dataset.name == dataset_name
|
||||
or str(existing_dataset.id) == dataset_name
|
||||
):
|
||||
dataset_instances.append(existing_dataset)
|
||||
is_dataset_found = True
|
||||
break
|
||||
|
||||
if not is_dataset_found:
|
||||
dataset_instances.append(
|
||||
Dataset(
|
||||
id=await get_unique_dataset_id(dataset_name=dataset_name, user=user),
|
||||
name=dataset_name,
|
||||
owner_id=user.id,
|
||||
)
|
||||
)
|
||||
|
||||
datasets = dataset_instances
|
||||
|
||||
awaitables = []
|
||||
|
||||
|
|
@ -87,31 +110,48 @@ async def run_pipeline(
|
|||
data=None,
|
||||
pipeline_name: str = "custom_pipeline",
|
||||
):
|
||||
if isinstance(dataset, Dataset):
|
||||
check_dataset_name(dataset.name)
|
||||
dataset_id = dataset.id
|
||||
elif isinstance(dataset, str):
|
||||
check_dataset_name(dataset)
|
||||
# Generate id based on unique dataset_id formula
|
||||
dataset_id = uuid5(NAMESPACE_OID, f"{dataset}{str(user.id)}")
|
||||
check_dataset_name(dataset.name)
|
||||
|
||||
# Ugly hack, but no easier way to do this.
|
||||
if pipeline_name == "add_pipeline":
|
||||
# Refresh the add pipeline status so data is added to a dataset.
|
||||
# Without this the app_pipeline status will be DATASET_PROCESSING_COMPLETED and will skip the execution.
|
||||
dataset_id = uuid5(NAMESPACE_OID, f"{dataset.name}{str(user.id)}")
|
||||
|
||||
await log_pipeline_run_initiated(
|
||||
pipeline_id=uuid5(NAMESPACE_OID, "add_pipeline"),
|
||||
pipeline_name="add_pipeline",
|
||||
dataset_id=dataset_id,
|
||||
)
|
||||
|
||||
# Refresh the cognify pipeline status after we add new files.
|
||||
# Without this the cognify_pipeline status will be DATASET_PROCESSING_COMPLETED and will skip the execution.
|
||||
await log_pipeline_run_initiated(
|
||||
pipeline_id=uuid5(NAMESPACE_OID, "cognify_pipeline"),
|
||||
pipeline_name="cognify_pipeline",
|
||||
dataset_id=dataset_id,
|
||||
)
|
||||
|
||||
dataset_id = dataset.id
|
||||
|
||||
if not data:
|
||||
data: list[Data] = await get_dataset_data(dataset_id=dataset_id)
|
||||
|
||||
# async with update_status_lock: TODO: Add UI lock to prevent multiple backend requests
|
||||
if isinstance(dataset, Dataset):
|
||||
task_status = await get_pipeline_status([dataset_id])
|
||||
task_status = await get_pipeline_status([dataset_id], pipeline_name)
|
||||
else:
|
||||
task_status = [
|
||||
PipelineRunStatus.DATASET_PROCESSING_COMPLETED
|
||||
] # TODO: this is a random assignment, find permanent solution
|
||||
|
||||
if (
|
||||
str(dataset_id) in task_status
|
||||
and task_status[str(dataset_id)] == PipelineRunStatus.DATASET_PROCESSING_STARTED
|
||||
):
|
||||
logger.info("Dataset %s is already being processed.", dataset_id)
|
||||
return
|
||||
if str(dataset_id) in task_status:
|
||||
if task_status[str(dataset_id)] == PipelineRunStatus.DATASET_PROCESSING_STARTED:
|
||||
logger.info("Dataset %s is already being processed.", dataset_id)
|
||||
return
|
||||
if task_status[str(dataset_id)] == PipelineRunStatus.DATASET_PROCESSING_COMPLETED:
|
||||
logger.info("Dataset %s is already processed.", dataset_id)
|
||||
return
|
||||
|
||||
if not isinstance(tasks, list):
|
||||
raise ValueError("Tasks must be a list")
|
||||
|
|
|
|||
|
|
@ -20,7 +20,9 @@ from ..tasks.task import Task
|
|||
logger = get_logger("run_tasks(tasks: [Task], data)")
|
||||
|
||||
|
||||
async def run_tasks_with_telemetry(tasks: list[Task], data, user: User, pipeline_name: str):
|
||||
async def run_tasks_with_telemetry(
|
||||
tasks: list[Task], data, user: User, pipeline_name: str, context: dict = None
|
||||
):
|
||||
config = get_current_settings()
|
||||
|
||||
logger.debug("\nRunning pipeline with configuration:\n%s\n", json.dumps(config, indent=1))
|
||||
|
|
@ -36,7 +38,7 @@ async def run_tasks_with_telemetry(tasks: list[Task], data, user: User, pipeline
|
|||
| config,
|
||||
)
|
||||
|
||||
async for result in run_tasks_base(tasks, data, user):
|
||||
async for result in run_tasks_base(tasks, data, user, context):
|
||||
yield result
|
||||
|
||||
logger.info("Pipeline run completed: `%s`", pipeline_name)
|
||||
|
|
@ -72,6 +74,7 @@ async def run_tasks(
|
|||
data: Any = None,
|
||||
user: User = None,
|
||||
pipeline_name: str = "unknown_pipeline",
|
||||
context: dict = None,
|
||||
):
|
||||
pipeline_id = uuid5(NAMESPACE_OID, pipeline_name)
|
||||
|
||||
|
|
@ -82,7 +85,11 @@ async def run_tasks(
|
|||
|
||||
try:
|
||||
async for _ in run_tasks_with_telemetry(
|
||||
tasks=tasks, data=data, user=user, pipeline_name=pipeline_id
|
||||
tasks=tasks,
|
||||
data=data,
|
||||
user=user,
|
||||
pipeline_name=pipeline_id,
|
||||
context=context,
|
||||
):
|
||||
pass
|
||||
|
||||
|
|
|
|||
|
|
@ -14,6 +14,7 @@ async def handle_task(
|
|||
leftover_tasks: list[Task],
|
||||
next_task_batch_size: int,
|
||||
user: User,
|
||||
context: dict = None,
|
||||
):
|
||||
"""Handle common task workflow with logging, telemetry, and error handling around the core execution logic."""
|
||||
task_type = running_task.task_type
|
||||
|
|
@ -27,9 +28,16 @@ async def handle_task(
|
|||
},
|
||||
)
|
||||
|
||||
has_context = any(
|
||||
[key == "context" for key in inspect.signature(running_task.executable).parameters.keys()]
|
||||
)
|
||||
|
||||
if has_context:
|
||||
args.append(context)
|
||||
|
||||
try:
|
||||
async for result_data in running_task.execute(args, next_task_batch_size):
|
||||
async for result in run_tasks_base(leftover_tasks, result_data, user):
|
||||
async for result in run_tasks_base(leftover_tasks, result_data, user, context):
|
||||
yield result
|
||||
|
||||
logger.info(f"{task_type} task completed: `{running_task.executable.__name__}`")
|
||||
|
|
@ -55,7 +63,7 @@ async def handle_task(
|
|||
raise error
|
||||
|
||||
|
||||
async def run_tasks_base(tasks: list[Task], data=None, user: User = None):
|
||||
async def run_tasks_base(tasks: list[Task], data=None, user: User = None, context: dict = None):
|
||||
"""Base function to execute tasks in a pipeline, handling task type detection and execution."""
|
||||
if len(tasks) == 0:
|
||||
yield data
|
||||
|
|
@ -68,5 +76,7 @@ async def run_tasks_base(tasks: list[Task], data=None, user: User = None):
|
|||
next_task = leftover_tasks[0] if len(leftover_tasks) > 0 else None
|
||||
next_task_batch_size = next_task.task_config["batch_size"] if next_task else 1
|
||||
|
||||
async for result in handle_task(running_task, args, leftover_tasks, next_task_batch_size, user):
|
||||
async for result in handle_task(
|
||||
running_task, args, leftover_tasks, next_task_batch_size, user, context
|
||||
):
|
||||
yield result
|
||||
|
|
|
|||
|
|
@ -4,4 +4,4 @@ Custom exceptions for the Cognee API.
|
|||
This module defines a set of exceptions for handling various data errors
|
||||
"""
|
||||
|
||||
from .exceptions import SearchTypeNotSupported, CypherSearchError, CollectionDistancesNotFoundError
|
||||
from .exceptions import SearchTypeNotSupported, CypherSearchError
|
||||
|
|
|
|||
|
|
@ -2,16 +2,6 @@ from fastapi import status
|
|||
from cognee.exceptions import CogneeApiError, CriticalError
|
||||
|
||||
|
||||
class CollectionDistancesNotFoundError(CogneeApiError):
|
||||
def __init__(
|
||||
self,
|
||||
message: str = "No distances found between the query and collections. It is possible that the given collection names don't exist.",
|
||||
name: str = "CollectionDistancesNotFoundError",
|
||||
status_code: int = status.HTTP_404_NOT_FOUND,
|
||||
):
|
||||
super().__init__(message, name, status_code)
|
||||
|
||||
|
||||
class SearchTypeNotSupported(CogneeApiError):
|
||||
def __init__(
|
||||
self,
|
||||
|
|
|
|||
|
|
@ -3,7 +3,6 @@ from collections import Counter
|
|||
import string
|
||||
|
||||
from cognee.infrastructure.engine import DataPoint
|
||||
from cognee.modules.graph.exceptions.exceptions import EntityNotFoundError
|
||||
from cognee.modules.graph.utils.convert_node_to_data_point import get_all_subclasses
|
||||
from cognee.modules.retrieval.base_retriever import BaseRetriever
|
||||
from cognee.modules.retrieval.utils.brute_force_triplet_search import brute_force_triplet_search
|
||||
|
|
@ -76,10 +75,7 @@ class GraphCompletionRetriever(BaseRetriever):
|
|||
|
||||
async def get_context(self, query: str) -> str:
|
||||
"""Retrieves and resolves graph triplets into context."""
|
||||
try:
|
||||
triplets = await self.get_triplets(query)
|
||||
except EntityNotFoundError:
|
||||
return ""
|
||||
triplets = await self.get_triplets(query)
|
||||
|
||||
if len(triplets) == 0:
|
||||
return ""
|
||||
|
|
|
|||
|
|
@ -1,14 +1,15 @@
|
|||
import asyncio
|
||||
from cognee.shared.logging_utils import get_logger, ERROR
|
||||
from typing import List, Optional
|
||||
|
||||
from cognee.shared.logging_utils import get_logger, ERROR
|
||||
from cognee.modules.graph.exceptions.exceptions import EntityNotFoundError
|
||||
from cognee.infrastructure.databases.vector.exceptions import CollectionNotFoundError
|
||||
from cognee.infrastructure.databases.graph import get_graph_engine
|
||||
from cognee.infrastructure.databases.vector import get_vector_engine
|
||||
from cognee.modules.graph.cognee_graph.CogneeGraph import CogneeGraph
|
||||
from cognee.modules.users.methods import get_default_user
|
||||
from cognee.modules.users.models import User
|
||||
from cognee.shared.utils import send_telemetry
|
||||
from cognee.modules.retrieval.exceptions import CollectionDistancesNotFoundError
|
||||
|
||||
logger = get_logger(level=ERROR)
|
||||
|
||||
|
|
@ -62,11 +63,14 @@ async def get_memory_fragment(
|
|||
if properties_to_project is None:
|
||||
properties_to_project = ["id", "description", "name", "type", "text"]
|
||||
|
||||
await memory_fragment.project_graph_from_db(
|
||||
graph_engine,
|
||||
node_properties_to_project=properties_to_project,
|
||||
edge_properties_to_project=["relationship_name"],
|
||||
)
|
||||
try:
|
||||
await memory_fragment.project_graph_from_db(
|
||||
graph_engine,
|
||||
node_properties_to_project=properties_to_project,
|
||||
edge_properties_to_project=["relationship_name"],
|
||||
)
|
||||
except EntityNotFoundError:
|
||||
pass
|
||||
|
||||
return memory_fragment
|
||||
|
||||
|
|
@ -139,16 +143,21 @@ async def brute_force_search(
|
|||
|
||||
send_telemetry("cognee.brute_force_triplet_search EXECUTION STARTED", user.id)
|
||||
|
||||
async def search_in_collection(collection_name: str):
|
||||
try:
|
||||
return await vector_engine.search(
|
||||
collection_name=collection_name, query_text=query, limit=0
|
||||
)
|
||||
except CollectionNotFoundError:
|
||||
return []
|
||||
|
||||
try:
|
||||
results = await asyncio.gather(
|
||||
*[
|
||||
vector_engine.get_distance_from_collection_elements(collection, query_text=query)
|
||||
for collection in collections
|
||||
]
|
||||
*[search_in_collection(collection_name) for collection_name in collections]
|
||||
)
|
||||
|
||||
if all(not item for item in results):
|
||||
raise CollectionDistancesNotFoundError()
|
||||
return []
|
||||
|
||||
node_distances = {collection: result for collection, result in zip(collections, results)}
|
||||
|
||||
|
|
@ -161,6 +170,8 @@ async def brute_force_search(
|
|||
|
||||
return results
|
||||
|
||||
except CollectionNotFoundError:
|
||||
return []
|
||||
except Exception as error:
|
||||
logger.error(
|
||||
"Error during brute force search for user: %s, query: %s. Error: %s",
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
from enum import Enum
|
||||
from typing import Optional
|
||||
from pydantic import BaseModel
|
||||
from cognee.infrastructure.databases.vector import get_vectordb_config
|
||||
from cognee.infrastructure.llm import get_llm_config
|
||||
|
|
@ -20,8 +21,8 @@ class LLMConfig(BaseModel):
|
|||
api_key: str
|
||||
model: str
|
||||
provider: str
|
||||
endpoint: str
|
||||
api_version: str
|
||||
endpoint: Optional[str]
|
||||
api_version: Optional[str]
|
||||
models: dict[str, list[ConfigChoice]]
|
||||
providers: list[ConfigChoice]
|
||||
|
||||
|
|
|
|||
|
|
@ -350,11 +350,3 @@ class ChunkSummaries(BaseModel):
|
|||
"""Relevant summary and chunk id"""
|
||||
|
||||
summaries: List[ChunkSummary]
|
||||
|
||||
|
||||
class MonitoringTool(str, Enum):
|
||||
"""Monitoring tools"""
|
||||
|
||||
LANGFUSE = "langfuse"
|
||||
LLMLITE = "llmlite"
|
||||
LANGSMITH = "langsmith"
|
||||
|
|
|
|||
|
|
@ -312,7 +312,7 @@ def setup_logging(log_level=None, name=None):
|
|||
root_logger.addHandler(file_handler)
|
||||
root_logger.setLevel(log_level)
|
||||
|
||||
if log_level > logging.WARNING:
|
||||
if log_level > logging.DEBUG:
|
||||
import warnings
|
||||
from sqlalchemy.exc import SAWarning
|
||||
|
||||
|
|
|
|||
|
|
@ -104,7 +104,7 @@ async def ingest_data(
|
|||
db_engine = get_relational_engine()
|
||||
|
||||
async with db_engine.get_async_session() as session:
|
||||
dataset = await create_dataset(dataset_name, user.id, session)
|
||||
dataset = await create_dataset(dataset_name, user, session)
|
||||
|
||||
# Check to see if data should be updated
|
||||
data_point = (
|
||||
|
|
|
|||
|
|
@ -20,20 +20,23 @@ async def index_and_transform_graphiti_nodes_and_edges():
|
|||
logger.error("Failed to initialize engines: %s", e)
|
||||
raise RuntimeError("Initialization error") from e
|
||||
|
||||
await graph_engine.query("""MATCH (n) SET n.id = n.uuid RETURN n""")
|
||||
await graph_engine.query("""MATCH (n) SET n.id = n.uuid RETURN n""", params={})
|
||||
await graph_engine.query(
|
||||
"""MATCH (source)-[r]->(target) SET r.source_node_id = source.id,
|
||||
r.target_node_id = target.id,
|
||||
r.relationship_name = type(r) RETURN r"""
|
||||
r.relationship_name = type(r) RETURN r""",
|
||||
params={},
|
||||
)
|
||||
await graph_engine.query(
|
||||
"""MATCH (n) SET n.text = COALESCE(n.summary, n.content) RETURN n""", params={}
|
||||
)
|
||||
await graph_engine.query("""MATCH (n) SET n.text = COALESCE(n.summary, n.content) RETURN n""")
|
||||
|
||||
nodes_data, edges_data = await graph_engine.get_model_independent_graph_data()
|
||||
nodes_data, edges_data = await graph_engine.get_graph_data()
|
||||
|
||||
for node_data in nodes_data[0]["nodes"]:
|
||||
for node_id, node_data in nodes_data:
|
||||
graphiti_node = GraphitiNode(
|
||||
**{key: node_data[key] for key in ("content", "name", "summary") if key in node_data},
|
||||
id=node_data.get("uuid"),
|
||||
id=node_id,
|
||||
)
|
||||
|
||||
data_point_type = type(graphiti_node)
|
||||
|
|
@ -58,9 +61,8 @@ async def index_and_transform_graphiti_nodes_and_edges():
|
|||
await vector_engine.index_data_points(index_name, field_name, indexable_points)
|
||||
|
||||
edge_types = Counter(
|
||||
edge[1][1]
|
||||
for edge in edges_data[0]["elements"]
|
||||
if isinstance(edge, list) and len(edge) == 3
|
||||
edge[2] # The edge key (relationship name) is at index 2
|
||||
for edge in edges_data
|
||||
)
|
||||
|
||||
for text, count in edge_types.items():
|
||||
|
|
|
|||
|
|
@ -1,11 +0,0 @@
|
|||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True, scope="session")
|
||||
def copy_cognee_db_to_target_location():
|
||||
os.makedirs("cognee/.cognee_system/databases/", exist_ok=True)
|
||||
os.system(
|
||||
"cp cognee/tests/integration/run_toy_tasks/data/cognee_db cognee/.cognee_system/databases/cognee_db"
|
||||
)
|
||||
Binary file not shown.
107
cognee/tests/test_memgraph.py
Normal file
107
cognee/tests/test_memgraph.py
Normal file
|
|
@ -0,0 +1,107 @@
|
|||
import os
|
||||
|
||||
import pathlib
|
||||
import cognee
|
||||
from cognee.modules.search.operations import get_history
|
||||
from cognee.modules.users.methods import get_default_user
|
||||
from cognee.shared.logging_utils import get_logger
|
||||
from cognee.modules.search.types import SearchType
|
||||
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
|
||||
async def main():
|
||||
cognee.config.set_graph_database_provider("memgraph")
|
||||
data_directory_path = str(
|
||||
pathlib.Path(
|
||||
os.path.join(pathlib.Path(__file__).parent, ".data_storage/test_memgraph")
|
||||
).resolve()
|
||||
)
|
||||
cognee.config.data_root_directory(data_directory_path)
|
||||
cognee_directory_path = str(
|
||||
pathlib.Path(
|
||||
os.path.join(pathlib.Path(__file__).parent, ".cognee_system/test_memgraph")
|
||||
).resolve()
|
||||
)
|
||||
cognee.config.system_root_directory(cognee_directory_path)
|
||||
|
||||
await cognee.prune.prune_data()
|
||||
await cognee.prune.prune_system(metadata=True)
|
||||
|
||||
dataset_name = "cs_explanations"
|
||||
|
||||
explanation_file_path = os.path.join(
|
||||
pathlib.Path(__file__).parent, "test_data/Natural_language_processing.txt"
|
||||
)
|
||||
await cognee.add([explanation_file_path], dataset_name)
|
||||
|
||||
text = """A quantum computer is a computer that takes advantage of quantum mechanical phenomena.
|
||||
At small scales, physical matter exhibits properties of both particles and waves, and quantum computing leverages this behavior, specifically quantum superposition and entanglement, using specialized hardware that supports the preparation and manipulation of quantum states.
|
||||
Classical physics cannot explain the operation of these quantum devices, and a scalable quantum computer could perform some calculations exponentially faster (with respect to input size scaling) than any modern "classical" computer. In particular, a large-scale quantum computer could break widely used encryption schemes and aid physicists in performing physical simulations; however, the current state of the technology is largely experimental and impractical, with several obstacles to useful applications. Moreover, scalable quantum computers do not hold promise for many practical tasks, and for many important tasks quantum speedups are proven impossible.
|
||||
The basic unit of information in quantum computing is the qubit, similar to the bit in traditional digital electronics. Unlike a classical bit, a qubit can exist in a superposition of its two "basis" states. When measuring a qubit, the result is a probabilistic output of a classical bit, therefore making quantum computers nondeterministic in general. If a quantum computer manipulates the qubit in a particular way, wave interference effects can amplify the desired measurement results. The design of quantum algorithms involves creating procedures that allow a quantum computer to perform calculations efficiently and quickly.
|
||||
Physically engineering high-quality qubits has proven challenging. If a physical qubit is not sufficiently isolated from its environment, it suffers from quantum decoherence, introducing noise into calculations. Paradoxically, perfectly isolating qubits is also undesirable because quantum computations typically need to initialize qubits, perform controlled qubit interactions, and measure the resulting quantum states. Each of those operations introduces errors and suffers from noise, and such inaccuracies accumulate.
|
||||
In principle, a non-quantum (classical) computer can solve the same computational problems as a quantum computer, given enough time. Quantum advantage comes in the form of time complexity rather than computability, and quantum complexity theory shows that some quantum algorithms for carefully selected tasks require exponentially fewer computational steps than the best known non-quantum algorithms. Such tasks can in theory be solved on a large-scale quantum computer whereas classical computers would not finish computations in any reasonable amount of time. However, quantum speedup is not universal or even typical across computational tasks, since basic tasks such as sorting are proven to not allow any asymptotic quantum speedup. Claims of quantum supremacy have drawn significant attention to the discipline, but are demonstrated on contrived tasks, while near-term practical use cases remain limited.
|
||||
"""
|
||||
|
||||
await cognee.add([text], dataset_name)
|
||||
|
||||
await cognee.cognify([dataset_name])
|
||||
|
||||
from cognee.infrastructure.databases.vector import get_vector_engine
|
||||
|
||||
vector_engine = get_vector_engine()
|
||||
random_node = (await vector_engine.search("Entity_name", "Quantum computer"))[0]
|
||||
random_node_name = random_node.payload["text"]
|
||||
|
||||
search_results = await cognee.search(
|
||||
query_type=SearchType.INSIGHTS, query_text=random_node_name
|
||||
)
|
||||
assert len(search_results) != 0, "The search results list is empty."
|
||||
print("\n\nExtracted sentences are:\n")
|
||||
for result in search_results:
|
||||
print(f"{result}\n")
|
||||
|
||||
search_results = await cognee.search(query_type=SearchType.CHUNKS, query_text=random_node_name)
|
||||
assert len(search_results) != 0, "The search results list is empty."
|
||||
print("\n\nExtracted chunks are:\n")
|
||||
for result in search_results:
|
||||
print(f"{result}\n")
|
||||
|
||||
search_results = await cognee.search(
|
||||
query_type=SearchType.SUMMARIES, query_text=random_node_name
|
||||
)
|
||||
assert len(search_results) != 0, "Query related summaries don't exist."
|
||||
print("\nExtracted results are:\n")
|
||||
for result in search_results:
|
||||
print(f"{result}\n")
|
||||
|
||||
search_results = await cognee.search(
|
||||
query_type=SearchType.NATURAL_LANGUAGE,
|
||||
query_text=f"Find nodes connected to node with name {random_node_name}",
|
||||
)
|
||||
assert len(search_results) != 0, "Query related natural language don't exist."
|
||||
print("\nExtracted results are:\n")
|
||||
for result in search_results:
|
||||
print(f"{result}\n")
|
||||
|
||||
user = await get_default_user()
|
||||
history = await get_history(user.id)
|
||||
|
||||
assert len(history) == 8, "Search history is not correct."
|
||||
|
||||
await cognee.prune.prune_data()
|
||||
assert not os.path.isdir(data_directory_path), "Local data files are not deleted"
|
||||
|
||||
await cognee.prune.prune_system(metadata=True)
|
||||
from cognee.infrastructure.databases.graph import get_graph_engine
|
||||
|
||||
graph_engine = await get_graph_engine()
|
||||
nodes, edges = await graph_engine.get_graph_data()
|
||||
assert len(nodes) == 0 and len(edges) == 0, "Memgraph graph database is not empty"
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import asyncio
|
||||
|
||||
asyncio.run(main())
|
||||
|
|
@ -74,19 +74,20 @@ async def main():
|
|||
for result in search_results:
|
||||
print(f"{result}\n")
|
||||
|
||||
search_results = await cognee.search(
|
||||
query_type=SearchType.NATURAL_LANGUAGE,
|
||||
query_text=f"Find nodes connected to node with name {random_node_name}",
|
||||
)
|
||||
assert len(search_results) != 0, "Query related natural language don't exist."
|
||||
print("\nExtracted results are:\n")
|
||||
for result in search_results:
|
||||
print(f"{result}\n")
|
||||
# NOTE: Due to the test failing often on weak LLM models we've removed this test for now
|
||||
# search_results = await cognee.search(
|
||||
# query_type=SearchType.NATURAL_LANGUAGE,
|
||||
# query_text=f"Find nodes connected to node with name {random_node_name}",
|
||||
# )
|
||||
# assert len(search_results) != 0, "Query related natural language don't exist."
|
||||
# print("\nExtracted results are:\n")
|
||||
# for result in search_results:
|
||||
# print(f"{result}\n")
|
||||
|
||||
user = await get_default_user()
|
||||
history = await get_history(user.id)
|
||||
|
||||
assert len(history) == 8, "Search history is not correct."
|
||||
assert len(history) == 6, "Search history is not correct."
|
||||
|
||||
await cognee.prune.prune_data()
|
||||
assert not os.path.isdir(data_directory_path), "Local data files are not deleted"
|
||||
|
|
|
|||
|
|
@ -82,7 +82,7 @@ async def main():
|
|||
assert not os.path.isdir(data_directory_path), "Local data files are not deleted"
|
||||
|
||||
await cognee.prune.prune_system(metadata=True)
|
||||
collections = get_vector_engine().client.collections.list_all()
|
||||
collections = await get_vector_engine().client.collections.list_all()
|
||||
assert len(collections) == 0, "Weaviate vector database is not empty"
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -48,3 +48,7 @@ async def run_and_check_tasks():
|
|||
|
||||
def test_run_tasks():
|
||||
asyncio.run(run_and_check_tasks())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_run_tasks()
|
||||
|
|
@ -0,0 +1,47 @@
|
|||
import asyncio
|
||||
|
||||
import cognee
|
||||
from cognee.modules.pipelines.tasks.task import Task
|
||||
from cognee.modules.users.methods import get_default_user
|
||||
from cognee.modules.pipelines.operations.run_tasks import run_tasks_base
|
||||
from cognee.infrastructure.databases.relational import create_db_and_tables
|
||||
|
||||
|
||||
async def run_and_check_tasks():
|
||||
await cognee.prune.prune_data()
|
||||
await cognee.prune.prune_system(metadata=True)
|
||||
|
||||
def task_1(num, context):
|
||||
return num + context
|
||||
|
||||
def task_2(num):
|
||||
return num * 2
|
||||
|
||||
def task_3(num, context):
|
||||
return num**context
|
||||
|
||||
await create_db_and_tables()
|
||||
user = await get_default_user()
|
||||
|
||||
pipeline = run_tasks_base(
|
||||
[
|
||||
Task(task_1),
|
||||
Task(task_2),
|
||||
Task(task_3),
|
||||
],
|
||||
data=5,
|
||||
user=user,
|
||||
context=7,
|
||||
)
|
||||
|
||||
final_result = 4586471424
|
||||
async for result in pipeline:
|
||||
assert result == final_result
|
||||
|
||||
|
||||
def test_run_tasks():
|
||||
asyncio.run(run_and_check_tasks())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_run_tasks()
|
||||
|
|
@ -16,11 +16,11 @@ class TestChunksRetriever:
|
|||
@pytest.mark.asyncio
|
||||
async def test_chunk_context_simple(self):
|
||||
system_directory_path = os.path.join(
|
||||
pathlib.Path(__file__).parent, ".cognee_system/test_rag_context"
|
||||
pathlib.Path(__file__).parent, ".cognee_system/test_chunks_context_simple"
|
||||
)
|
||||
cognee.config.system_root_directory(system_directory_path)
|
||||
data_directory_path = os.path.join(
|
||||
pathlib.Path(__file__).parent, ".data_storage/test_rag_context"
|
||||
pathlib.Path(__file__).parent, ".data_storage/test_chunks_context_simple"
|
||||
)
|
||||
cognee.config.data_root_directory(data_directory_path)
|
||||
|
||||
|
|
@ -73,11 +73,11 @@ class TestChunksRetriever:
|
|||
@pytest.mark.asyncio
|
||||
async def test_chunk_context_complex(self):
|
||||
system_directory_path = os.path.join(
|
||||
pathlib.Path(__file__).parent, ".cognee_system/test_chunk_context"
|
||||
pathlib.Path(__file__).parent, ".cognee_system/test_chunk_context_complex"
|
||||
)
|
||||
cognee.config.system_root_directory(system_directory_path)
|
||||
data_directory_path = os.path.join(
|
||||
pathlib.Path(__file__).parent, ".data_storage/test_chunk_context"
|
||||
pathlib.Path(__file__).parent, ".data_storage/test_chunk_context_complex"
|
||||
)
|
||||
cognee.config.data_root_directory(data_directory_path)
|
||||
|
||||
|
|
@ -162,11 +162,11 @@ class TestChunksRetriever:
|
|||
@pytest.mark.asyncio
|
||||
async def test_chunk_context_on_empty_graph(self):
|
||||
system_directory_path = os.path.join(
|
||||
pathlib.Path(__file__).parent, ".cognee_system/test_chunk_context"
|
||||
pathlib.Path(__file__).parent, ".cognee_system/test_chunk_context_empty"
|
||||
)
|
||||
cognee.config.system_root_directory(system_directory_path)
|
||||
data_directory_path = os.path.join(
|
||||
pathlib.Path(__file__).parent, ".data_storage/test_chunk_context"
|
||||
pathlib.Path(__file__).parent, ".data_storage/test_chunk_context_empty"
|
||||
)
|
||||
cognee.config.data_root_directory(data_directory_path)
|
||||
|
||||
|
|
@ -190,6 +190,9 @@ if __name__ == "__main__":
|
|||
|
||||
test = TestChunksRetriever()
|
||||
|
||||
run(test.test_chunk_context_simple())
|
||||
run(test.test_chunk_context_complex())
|
||||
run(test.test_chunk_context_on_empty_graph())
|
||||
async def main():
|
||||
await test.test_chunk_context_simple()
|
||||
await test.test_chunk_context_complex()
|
||||
await test.test_chunk_context_on_empty_graph()
|
||||
|
||||
run(main())
|
||||
|
|
|
|||
|
|
@ -154,6 +154,9 @@ if __name__ == "__main__":
|
|||
|
||||
test = TestGraphCompletionRetriever()
|
||||
|
||||
run(test.test_graph_completion_context_simple())
|
||||
run(test.test_graph_completion_context_complex())
|
||||
run(test.test_get_graph_completion_context_on_empty_graph())
|
||||
async def main():
|
||||
await test.test_graph_completion_context_simple()
|
||||
await test.test_graph_completion_context_complex()
|
||||
await test.test_get_graph_completion_context_on_empty_graph()
|
||||
|
||||
run(main())
|
||||
|
|
|
|||
|
|
@ -127,7 +127,7 @@ class TextSummariesRetriever:
|
|||
|
||||
await add_data_points(entities)
|
||||
|
||||
retriever = SummariesRetriever(limit=20)
|
||||
retriever = SummariesRetriever(top_k=20)
|
||||
|
||||
context = await retriever.get_context("Christina")
|
||||
|
||||
|
|
|
|||
|
|
@ -1,44 +0,0 @@
|
|||
import pytest
|
||||
from unittest.mock import AsyncMock, patch
|
||||
from cognee.modules.users.models import User
|
||||
from cognee.modules.retrieval.exceptions import CollectionDistancesNotFoundError
|
||||
from cognee.modules.retrieval.utils.brute_force_triplet_search import (
|
||||
brute_force_search,
|
||||
brute_force_triplet_search,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine")
|
||||
async def test_brute_force_search_collection_not_found(mock_get_vector_engine):
|
||||
user = User(id="test_user")
|
||||
query = "test query"
|
||||
collections = ["nonexistent_collection"]
|
||||
top_k = 5
|
||||
mock_memory_fragment = AsyncMock()
|
||||
mock_vector_engine = AsyncMock()
|
||||
mock_vector_engine.get_distance_from_collection_elements.return_value = []
|
||||
mock_get_vector_engine.return_value = mock_vector_engine
|
||||
|
||||
with pytest.raises(CollectionDistancesNotFoundError):
|
||||
await brute_force_search(
|
||||
query, user, top_k, collections=collections, memory_fragment=mock_memory_fragment
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine")
|
||||
async def test_brute_force_triplet_search_collection_not_found(mock_get_vector_engine):
|
||||
user = User(id="test_user")
|
||||
query = "test query"
|
||||
collections = ["nonexistent_collection"]
|
||||
top_k = 5
|
||||
mock_memory_fragment = AsyncMock()
|
||||
mock_vector_engine = AsyncMock()
|
||||
mock_vector_engine.get_distance_from_collection_elements.return_value = []
|
||||
mock_get_vector_engine.return_value = mock_vector_engine
|
||||
|
||||
with pytest.raises(CollectionDistancesNotFoundError):
|
||||
await brute_force_triplet_search(
|
||||
query, user, top_k, collections=collections, memory_fragment=mock_memory_fragment
|
||||
)
|
||||
|
|
@ -13,7 +13,7 @@ echo "Environment: $ENVIRONMENT"
|
|||
# inconsistencies and should cause the startup to fail. This check allows for
|
||||
# smooth redeployments and container restarts while maintaining data integrity.
|
||||
echo "Running database migrations..."
|
||||
MIGRATION_OUTPUT=$(poetry run alembic upgrade head 2>&1) || {
|
||||
MIGRATION_OUTPUT=$(alembic upgrade head 2>&1) || {
|
||||
if [[ $MIGRATION_OUTPUT == *"UserAlreadyExists"* ]] || [[ $MIGRATION_OUTPUT == *"User default_user@example.com already exists"* ]]; then
|
||||
echo "Warning: Default user already exists, continuing startup..."
|
||||
else
|
||||
|
|
@ -22,8 +22,9 @@ MIGRATION_OUTPUT=$(poetry run alembic upgrade head 2>&1) || {
|
|||
exit 1
|
||||
fi
|
||||
}
|
||||
echo "Database migrations done."
|
||||
|
||||
echo "Starting Gunicorn"
|
||||
echo "Starting server..."
|
||||
|
||||
# Add startup delay to ensure DB is ready
|
||||
sleep 2
|
||||
|
|
@ -32,10 +33,10 @@ sleep 2
|
|||
if [ "$ENVIRONMENT" = "dev" ] || [ "$ENVIRONMENT" = "local" ]; then
|
||||
if [ "$DEBUG" = "true" ]; then
|
||||
echo "Waiting for the debugger to attach..."
|
||||
exec python -m debugpy --wait-for-client --listen 0.0.0.0:5678 -m gunicorn -w 3 -k uvicorn.workers.UvicornWorker -t 30000 --bind=0.0.0.0:8000 --log-level debug --reload cognee.api.client:app
|
||||
debugpy --wait-for-client --listen 0.0.0.0:5678 -m gunicorn -w 3 -k uvicorn.workers.UvicornWorker -t 30000 --bind=0.0.0.0:8000 --log-level debug --reload cognee.api.client:app
|
||||
else
|
||||
exec gunicorn -w 3 -k uvicorn.workers.UvicornWorker -t 30000 --bind=0.0.0.0:8000 --log-level debug --reload cognee.api.client:app
|
||||
gunicorn -w 3 -k uvicorn.workers.UvicornWorker -t 30000 --bind=0.0.0.0:8000 --log-level debug --reload cognee.api.client:app
|
||||
fi
|
||||
else
|
||||
exec gunicorn -w 3 -k uvicorn.workers.UvicornWorker -t 30000 --bind=0.0.0.0:8000 --log-level error cognee.api.client:app
|
||||
gunicorn -w 3 -k uvicorn.workers.UvicornWorker -t 30000 --bind=0.0.0.0:8000 --log-level error cognee.api.client:app
|
||||
fi
|
||||
|
|
|
|||
37
examples/data/car_and_tech_companies.txt
Normal file
37
examples/data/car_and_tech_companies.txt
Normal file
|
|
@ -0,0 +1,37 @@
|
|||
text_1 = """
|
||||
1. Audi
|
||||
Audi is known for its modern designs and advanced technology. Founded in the early 1900s, the brand has earned a reputation for precision engineering and innovation. With features like the Quattro all-wheel-drive system, Audi offers a range of vehicles from stylish sedans to high-performance sports cars.
|
||||
|
||||
2. BMW
|
||||
BMW, short for Bayerische Motoren Werke, is celebrated for its focus on performance and driving pleasure. The company's vehicles are designed to provide a dynamic and engaging driving experience, and their slogan, "The Ultimate Driving Machine," reflects that commitment. BMW produces a variety of cars that combine luxury with sporty performance.
|
||||
|
||||
3. Mercedes-Benz
|
||||
Mercedes-Benz is synonymous with luxury and quality. With a history dating back to the early 20th century, the brand is known for its elegant designs, innovative safety features, and high-quality engineering. Mercedes-Benz manufactures not only luxury sedans but also SUVs, sports cars, and commercial vehicles, catering to a wide range of needs.
|
||||
|
||||
4. Porsche
|
||||
Porsche is a name that stands for high-performance sports cars. Founded in 1931, the brand has become famous for models like the iconic Porsche 911. Porsche cars are celebrated for their speed, precision, and distinctive design, appealing to car enthusiasts who value both performance and style.
|
||||
|
||||
5. Volkswagen
|
||||
Volkswagen, which means "people's car" in German, was established with the idea of making affordable and reliable vehicles accessible to everyone. Over the years, Volkswagen has produced several iconic models, such as the Beetle and the Golf. Today, it remains one of the largest car manufacturers in the world, offering a wide range of vehicles that balance practicality with quality.
|
||||
|
||||
Each of these car manufacturer contributes to Germany's reputation as a leader in the global automotive industry, showcasing a blend of innovation, performance, and design excellence.
|
||||
"""
|
||||
|
||||
text_2 = """
|
||||
1. Apple
|
||||
Apple is renowned for its innovative consumer electronics and software. Its product lineup includes the iPhone, iPad, Mac computers, and wearables like the Apple Watch. Known for its emphasis on sleek design and user-friendly interfaces, Apple has built a loyal customer base and created a seamless ecosystem that integrates hardware, software, and services.
|
||||
|
||||
2. Google
|
||||
Founded in 1998, Google started as a search engine and quickly became the go-to resource for finding information online. Over the years, the company has diversified its offerings to include digital advertising, cloud computing, mobile operating systems (Android), and various web services like Gmail and Google Maps. Google's innovations have played a major role in shaping the internet landscape.
|
||||
|
||||
3. Microsoft
|
||||
Microsoft Corporation has been a dominant force in software for decades. Its Windows operating system and Microsoft Office suite are staples in both business and personal computing. In recent years, Microsoft has expanded into cloud computing with Azure, gaming with the Xbox platform, and even hardware through products like the Surface line. This evolution has helped the company maintain its relevance in a rapidly changing tech world.
|
||||
|
||||
4. Amazon
|
||||
What began as an online bookstore has grown into one of the largest e-commerce platforms globally. Amazon is known for its vast online marketplace, but its influence extends far beyond retail. With Amazon Web Services (AWS), the company has become a leader in cloud computing, offering robust solutions that power websites, applications, and businesses around the world. Amazon's constant drive for innovation continues to reshape both retail and technology sectors.
|
||||
|
||||
5. Meta
|
||||
Meta, originally known as Facebook, revolutionized social media by connecting billions of people worldwide. Beyond its core social networking service, Meta is investing in the next generation of digital experiences through virtual and augmented reality technologies, with projects like Oculus. The company's efforts signal a commitment to evolving digital interaction and building the metaverse—a shared virtual space where users can connect and collaborate.
|
||||
|
||||
Each of these companies has significantly impacted the technology landscape, driving innovation and transforming everyday life through their groundbreaking products and services.
|
||||
"""
|
||||
|
|
@ -14,6 +14,7 @@ from cognee.modules.retrieval.utils.brute_force_triplet_search import brute_forc
|
|||
from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever
|
||||
from cognee.infrastructure.llm.prompts import read_query_prompt, render_prompt
|
||||
from cognee.infrastructure.llm.get_llm_client import get_llm_client
|
||||
from cognee.modules.users.methods import get_default_user
|
||||
|
||||
text_list = [
|
||||
"Kamala Harris is the Attorney General of California. She was previously "
|
||||
|
|
@ -27,6 +28,9 @@ async def main():
|
|||
await cognee.prune.prune_system(metadata=True)
|
||||
await create_relational_db_and_tables()
|
||||
|
||||
# Initialize default user
|
||||
user = await get_default_user()
|
||||
|
||||
for text in text_list:
|
||||
await cognee.add(text)
|
||||
|
||||
|
|
@ -34,7 +38,7 @@ async def main():
|
|||
Task(build_graph_with_temporal_awareness, text_list=text_list),
|
||||
]
|
||||
|
||||
pipeline = run_tasks(tasks)
|
||||
pipeline = run_tasks(tasks, user=user)
|
||||
|
||||
async for result in pipeline:
|
||||
print(result)
|
||||
|
|
|
|||
|
|
@ -467,7 +467,7 @@
|
|||
"from cognee.modules.data.models import Dataset, Data\n",
|
||||
"from cognee.modules.data.methods.get_dataset_data import get_dataset_data\n",
|
||||
"from cognee.modules.cognify.config import get_cognify_config\n",
|
||||
"from cognee.modules.pipelines.tasks.Task import Task\n",
|
||||
"from cognee.modules.pipelines.tasks.task import Task\n",
|
||||
"from cognee.modules.pipelines import run_tasks\n",
|
||||
"from cognee.modules.users.models import User\n",
|
||||
"from cognee.tasks.documents import (\n",
|
||||
|
|
@ -505,7 +505,7 @@
|
|||
" Task(add_data_points, task_config={\"batch_size\": 10}),\n",
|
||||
" ]\n",
|
||||
"\n",
|
||||
" pipeline_run = run_tasks(tasks, dataset.id, data_documents, \"cognify_pipeline\")\n",
|
||||
" pipeline_run = run_tasks(tasks, dataset.id, data_documents, user, \"cognify_pipeline\")\n",
|
||||
" pipeline_run_status = None\n",
|
||||
"\n",
|
||||
" async for run_status in pipeline_run:\n",
|
||||
|
|
@ -529,8 +529,11 @@
|
|||
"source": [
|
||||
"from cognee.modules.users.methods import get_default_user\n",
|
||||
"from cognee.modules.data.methods import get_datasets_by_name\n",
|
||||
"from cognee.modules.users.methods import get_user\n",
|
||||
"\n",
|
||||
"user = await get_default_user()\n",
|
||||
"default_user = await get_default_user()\n",
|
||||
"\n",
|
||||
"user = await get_user(default_user.id)\n",
|
||||
"\n",
|
||||
"datasets = await get_datasets_by_name([\"example\"], user.id)\n",
|
||||
"\n",
|
||||
|
|
@ -604,39 +607,6 @@
|
|||
"visualization_server(port=8002)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "765bc42a143e98af",
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2025-02-09T21:46:07.783693Z",
|
||||
"start_time": "2025-02-09T21:46:07.780709Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "e1382358-433c-4cd0-8535-9e103f821034",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "6332d5bc-882f-49d5-8496-582e3954567a",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from IPython.display import IFrame, display, HTML\n",
|
||||
"\n",
|
||||
"IFrame(\"http://127.0.0.1:8002/.artifacts/graph_visualization.html\", width=800, height=600)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
|
|
@ -837,14 +807,6 @@
|
|||
"### Give us a star if you like it!\n",
|
||||
"https://github.com/topoteretes/cognee"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "3c081f2d53512199",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
|
|
|
|||
|
|
@ -15,7 +15,7 @@
|
|||
}
|
||||
},
|
||||
"source": [
|
||||
"First we import the necessary libaries"
|
||||
"First we import the necessary libraries"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
|
@ -24,9 +24,10 @@
|
|||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import asyncio\n",
|
||||
"\n",
|
||||
"import cognee\n",
|
||||
"from cognee.shared.logging_utils import get_logger, ERROR\n",
|
||||
"import warnings\n",
|
||||
"from cognee.modules.pipelines import Task, run_tasks\n",
|
||||
"from cognee.tasks.temporal_awareness import build_graph_with_temporal_awareness\n",
|
||||
"from cognee.infrastructure.databases.relational import (\n",
|
||||
|
|
@ -38,7 +39,8 @@
|
|||
"from cognee.modules.retrieval.utils.brute_force_triplet_search import brute_force_triplet_search\n",
|
||||
"from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever\n",
|
||||
"from cognee.infrastructure.llm.prompts import read_query_prompt, render_prompt\n",
|
||||
"from cognee.infrastructure.llm.get_llm_client import get_llm_client"
|
||||
"from cognee.infrastructure.llm.get_llm_client import get_llm_client\n",
|
||||
"from cognee.modules.users.methods import get_default_user"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
|
@ -126,33 +128,25 @@
|
|||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# 🔧 Setting Up Logging to Suppress Errors\n",
|
||||
"logger = get_logger(level=ERROR) # Keeping logs clean and focused\n",
|
||||
"\n",
|
||||
"# 🧹 Pruning Old Data and Metadata\n",
|
||||
"await cognee.prune.prune_data() # Removing outdated data\n",
|
||||
"await cognee.prune.prune_data()\n",
|
||||
"await cognee.prune.prune_system(metadata=True)\n",
|
||||
"\n",
|
||||
"# 🏗️ Creating Relational Database and Tables\n",
|
||||
"await create_relational_db_and_tables()\n",
|
||||
"\n",
|
||||
"# 📚 Adding Text Data to Cognee\n",
|
||||
"# Initialize default user\n",
|
||||
"user = await get_default_user()\n",
|
||||
"\n",
|
||||
"for text in text_list:\n",
|
||||
" await cognee.add(text)\n",
|
||||
"\n",
|
||||
"# 🕰️ Building Temporal-Aware Graphs\n",
|
||||
"tasks = [\n",
|
||||
" Task(build_graph_with_temporal_awareness, text_list=text_list),\n",
|
||||
"]\n",
|
||||
" ]\n",
|
||||
"\n",
|
||||
"# 🚀 Running the Task Pipeline\n",
|
||||
"pipeline = run_tasks(tasks)\n",
|
||||
"pipeline = run_tasks(tasks, user=user)\n",
|
||||
"\n",
|
||||
"# 🌟 Processing Pipeline Results\n",
|
||||
"async for result in pipeline:\n",
|
||||
" print(f\"✅ Result Processed: {result}\")\n",
|
||||
" print(result)\n",
|
||||
"\n",
|
||||
"# 🔄 Indexing and Transforming Graph Data\n",
|
||||
"await index_and_transform_graphiti_nodes_and_edges()"
|
||||
]
|
||||
},
|
||||
|
|
|
|||
|
|
@ -12,7 +12,9 @@
|
|||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": "%pip install llama-index-core\n"
|
||||
"source": [
|
||||
"%pip install llama-index-core\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
|
|
@ -116,10 +118,10 @@
|
|||
]
|
||||
},
|
||||
{
|
||||
"metadata": {},
|
||||
"cell_type": "code",
|
||||
"outputs": [],
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from typing import Union, BinaryIO\n",
|
||||
"\n",
|
||||
|
|
|
|||
|
|
@ -10,11 +10,13 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "982b897a29a26f7d",
|
||||
"metadata": {},
|
||||
"source": "!pip install cognee==0.1.36",
|
||||
"outputs": [],
|
||||
"execution_count": null
|
||||
"source": [
|
||||
"!pip install cognee==0.1.39"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
|
|
@ -28,15 +30,15 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "initial_id",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import os\n",
|
||||
"\n",
|
||||
"os.environ[\"LLM_API_KEY\"] = \"\""
|
||||
],
|
||||
"outputs": [],
|
||||
"execution_count": null
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
|
|
@ -48,14 +50,14 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "5805c346f03d8070",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"current_directory = os.getcwd()\n",
|
||||
"file_path = os.path.join(current_directory, \"data\", \"alice_in_wonderland.txt\")"
|
||||
],
|
||||
"outputs": [],
|
||||
"execution_count": null
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
|
|
@ -67,15 +69,15 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "875763366723ee48",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import cognee\n",
|
||||
"await cognee.add(file_path)\n",
|
||||
"await cognee.cognify()"
|
||||
],
|
||||
"outputs": [],
|
||||
"execution_count": null
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
|
|
@ -87,33 +89,33 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "29b3a1e3279100d2",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"await cognee.search(\"List me all the influential characters in Alice in Wonderland.\")"
|
||||
],
|
||||
"outputs": [],
|
||||
"execution_count": null
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "883ce50d2d9dc584",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"await cognee.search(\"How did Alice end up in Wonderland?\")"
|
||||
],
|
||||
"outputs": [],
|
||||
"execution_count": null
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "677e1bc52aa078b6",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"await cognee.search(\"Tell me about Alice's personality.\")"
|
||||
],
|
||||
"outputs": [],
|
||||
"execution_count": null
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
|
|
@ -125,8 +127,10 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "6effdae590b795d3",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import webbrowser\n",
|
||||
"import os\n",
|
||||
|
|
@ -136,9 +140,7 @@
|
|||
"html_file = os.path.join(home_dir, \"graph_visualization.html\")\n",
|
||||
"display(html_file)\n",
|
||||
"webbrowser.open(f\"file://{html_file}\")"
|
||||
],
|
||||
"outputs": [],
|
||||
"execution_count": null
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
|
|
|
|||
|
|
@ -1,128 +0,0 @@
|
|||
|
||||
<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<meta charset="utf-8">
|
||||
<script src="https://d3js.org/d3.v5.min.js"></script>
|
||||
<style>
|
||||
body, html { margin: 0; padding: 0; width: 100%; height: 100%; overflow: hidden; background: linear-gradient(90deg, #101010, #1a1a2e); color: white; font-family: 'Inter', sans-serif; }
|
||||
|
||||
svg { width: 100vw; height: 100vh; display: block; }
|
||||
.links line { stroke: rgba(255, 255, 255, 0.4); stroke-width: 2px; }
|
||||
.nodes circle { stroke: white; stroke-width: 0.5px; filter: drop-shadow(0 0 5px rgba(255,255,255,0.3)); }
|
||||
.node-label { font-size: 5px; font-weight: bold; fill: white; text-anchor: middle; dominant-baseline: middle; font-family: 'Inter', sans-serif; pointer-events: none; }
|
||||
.edge-label { font-size: 3px; fill: rgba(255, 255, 255, 0.7); text-anchor: middle; dominant-baseline: middle; font-family: 'Inter', sans-serif; pointer-events: none; }
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<svg></svg>
|
||||
<script>
|
||||
var nodes = [];
|
||||
var links = [];
|
||||
|
||||
var svg = d3.select("svg"),
|
||||
width = window.innerWidth,
|
||||
height = window.innerHeight;
|
||||
|
||||
var container = svg.append("g");
|
||||
|
||||
var simulation = d3.forceSimulation(nodes)
|
||||
.force("link", d3.forceLink(links).id(d => d.id).strength(0.1))
|
||||
.force("charge", d3.forceManyBody().strength(-275))
|
||||
.force("center", d3.forceCenter(width / 2, height / 2))
|
||||
.force("x", d3.forceX().strength(0.1).x(width / 2))
|
||||
.force("y", d3.forceY().strength(0.1).y(height / 2));
|
||||
|
||||
var link = container.append("g")
|
||||
.attr("class", "links")
|
||||
.selectAll("line")
|
||||
.data(links)
|
||||
.enter().append("line")
|
||||
.attr("stroke-width", 2);
|
||||
|
||||
var edgeLabels = container.append("g")
|
||||
.attr("class", "edge-labels")
|
||||
.selectAll("text")
|
||||
.data(links)
|
||||
.enter().append("text")
|
||||
.attr("class", "edge-label")
|
||||
.text(d => d.relation);
|
||||
|
||||
var nodeGroup = container.append("g")
|
||||
.attr("class", "nodes")
|
||||
.selectAll("g")
|
||||
.data(nodes)
|
||||
.enter().append("g");
|
||||
|
||||
var node = nodeGroup.append("circle")
|
||||
.attr("r", 13)
|
||||
.attr("fill", d => d.color)
|
||||
.call(d3.drag()
|
||||
.on("start", dragstarted)
|
||||
.on("drag", dragged)
|
||||
.on("end", dragended));
|
||||
|
||||
nodeGroup.append("text")
|
||||
.attr("class", "node-label")
|
||||
.attr("dy", 4)
|
||||
.attr("text-anchor", "middle")
|
||||
.text(d => d.name);
|
||||
|
||||
node.append("title").text(d => JSON.stringify(d));
|
||||
|
||||
simulation.on("tick", function() {
|
||||
link.attr("x1", d => d.source.x)
|
||||
.attr("y1", d => d.source.y)
|
||||
.attr("x2", d => d.target.x)
|
||||
.attr("y2", d => d.target.y);
|
||||
|
||||
edgeLabels
|
||||
.attr("x", d => (d.source.x + d.target.x) / 2)
|
||||
.attr("y", d => (d.source.y + d.target.y) / 2 - 5);
|
||||
|
||||
node.attr("cx", d => d.x)
|
||||
.attr("cy", d => d.y);
|
||||
|
||||
nodeGroup.select("text")
|
||||
.attr("x", d => d.x)
|
||||
.attr("y", d => d.y)
|
||||
.attr("dy", 4)
|
||||
.attr("text-anchor", "middle");
|
||||
});
|
||||
|
||||
svg.call(d3.zoom().on("zoom", function() {
|
||||
container.attr("transform", d3.event.transform);
|
||||
}));
|
||||
|
||||
function dragstarted(d) {
|
||||
if (!d3.event.active) simulation.alphaTarget(0.3).restart();
|
||||
d.fx = d.x;
|
||||
d.fy = d.y;
|
||||
}
|
||||
|
||||
function dragged(d) {
|
||||
d.fx = d3.event.x;
|
||||
d.fy = d3.event.y;
|
||||
}
|
||||
|
||||
function dragended(d) {
|
||||
if (!d3.event.active) simulation.alphaTarget(0);
|
||||
d.fx = null;
|
||||
d.fy = null;
|
||||
}
|
||||
|
||||
window.addEventListener("resize", function() {
|
||||
width = window.innerWidth;
|
||||
height = window.innerHeight;
|
||||
svg.attr("width", width).attr("height", height);
|
||||
simulation.force("center", d3.forceCenter(width / 2, height / 2));
|
||||
simulation.alpha(1).restart();
|
||||
});
|
||||
</script>
|
||||
|
||||
<svg style="position: fixed; bottom: 10px; right: 10px; width: 150px; height: auto; z-index: 9999;" viewBox="0 0 158 44" fill="none" xmlns="http://www.w3.org/2000/svg">
|
||||
<path fill-rule="evenodd" clip-rule="evenodd" d="M11.7496 4.92654C7.83308 4.92654 4.8585 7.94279 4.8585 11.3612V14.9304C4.8585 18.3488 7.83308 21.3651 11.7496 21.3651C13.6831 21.3651 15.0217 20.8121 16.9551 19.3543C18.0458 18.5499 19.5331 18.8013 20.3263 19.9072C21.1195 21.0132 20.8717 22.5213 19.781 23.3257C17.3518 25.0851 15.0217 26.2414 11.7 26.2414C5.35425 26.2414 0 21.2646 0 14.9304V11.3612C0 4.97681 5.35425 0.0502739 11.7 0.0502739C15.0217 0.0502739 17.3518 1.2065 19.781 2.96598C20.8717 3.77032 21.1195 5.27843 20.3263 6.38439C19.5331 7.49035 18.0458 7.69144 16.9551 6.93737C15.0217 5.52979 13.6831 4.92654 11.7496 4.92654ZM35.5463 4.92654C31.7289 4.92654 28.6552 8.04333 28.6552 11.8639V14.478C28.6552 18.2986 31.7289 21.4154 35.5463 21.4154C39.3141 21.4154 42.3878 18.2986 42.3878 14.478V11.8639C42.3878 8.04333 39.3141 4.92654 35.5463 4.92654ZM23.7967 11.8639C23.7967 5.32871 29.0518 0 35.5463 0C42.0408 0 47.2463 5.32871 47.2463 11.8639V14.478C47.2463 21.0132 42.0408 26.3419 35.5463 26.3419C29.0518 26.3419 23.7967 21.0635 23.7967 14.478V11.8639ZM63.3091 5.07736C59.4917 5.07736 56.418 8.19415 56.418 12.0147C56.418 15.8353 59.4917 18.9521 63.3091 18.9521C67.1265 18.9521 70.1506 15.8856 70.1506 12.0147C70.1506 8.14388 67.0769 5.07736 63.3091 5.07736ZM51.5595 11.9645C51.5595 5.42925 56.8146 0.150814 63.3091 0.150814C66.0854 0.150814 68.5642 1.10596 70.5968 2.71463L72.4311 0.904876C73.3731 -0.0502693 74.9099 -0.0502693 75.8519 0.904876C76.7938 1.86002 76.7938 3.41841 75.8519 4.37356L73.7201 6.53521C74.5629 8.19414 75.0587 10.0542 75.0587 12.0147C75.0587 18.4997 69.8532 23.8284 63.3587 23.8284C63.3091 23.8284 63.2099 23.8284 63.1603 23.8284H58.0044C57.1616 23.8284 56.4675 24.5322 56.4675 25.3868C56.4675 26.2414 57.1616 26.9452 58.0044 26.9452H64.6476H66.7794C68.5146 26.9452 70.3489 27.4479 71.7866 28.6041C73.2739 29.8106 74.2159 31.5701 74.4142 33.7317C74.7116 37.6026 72.0345 40.2166 69.8532 41.0713L63.8048 43.7859C62.5654 44.3389 61.1277 43.7859 60.6319 42.5291C60.0866 
41.2723 60.6319 39.8648 61.8714 39.3118L68.0188 36.5972C68.0684 36.5972 68.118 36.5469 68.1675 36.5469C68.4154 36.4463 68.8616 36.1447 69.2087 35.6923C69.5061 35.2398 69.7044 34.7371 69.6548 34.1339C69.6053 33.229 69.2582 32.7263 68.8616 32.4247C68.4154 32.0728 67.7214 31.8214 66.8786 31.8214H58.2027C58.1531 31.8214 58.1531 31.8214 58.1035 31.8214H58.054C54.534 31.8214 51.6586 28.956 51.6586 25.3868C51.6586 23.0743 52.8485 21.0635 54.6828 19.9072C52.6997 17.7959 51.5595 15.031 51.5595 11.9645ZM90.8736 5.07736C87.0562 5.07736 83.9824 8.19415 83.9824 12.0147V23.9289C83.9824 25.2862 82.8917 26.3922 81.5532 26.3922C80.2146 26.3922 79.1239 25.2862 79.1239 23.9289V11.9645C79.1239 5.42925 84.379 0.150814 90.824 0.150814C97.2689 0.150814 102.524 5.42925 102.524 11.9645V23.8786C102.524 25.2359 101.433 26.3419 100.095 26.3419C98.7562 26.3419 97.6655 25.2359 97.6655 23.8786V11.9645C97.7647 8.14387 94.6414 5.07736 90.8736 5.07736ZM119.43 5.07736C115.513 5.07736 112.39 8.24441 112.39 12.065V14.5785C112.39 18.4494 115.513 21.5662 119.43 21.5662C120.768 21.5662 122.057 21.164 123.098 20.5105C124.238 19.8067 125.726 20.1586 126.42 21.3148C127.114 22.4711 126.767 23.9792 125.627 24.683C123.842 25.7889 121.71 26.4425 119.43 26.4425C112.885 26.4425 107.581 21.1137 107.581 14.5785V12.065C107.581 5.47952 112.935 0.201088 119.43 0.201088C125.032 0.201088 129.692 4.07194 130.931 9.3001L131.427 11.3612L121.115 15.584C119.876 16.0867 118.488 15.4834 117.942 14.2266C117.447 12.9699 118.041 11.5623 119.281 11.0596L125.478 8.54604C124.238 6.43466 122.008 5.07736 119.43 5.07736ZM146.003 5.07736C142.086 5.07736 138.963 8.24441 138.963 12.065V14.5785C138.963 18.4494 142.086 21.5662 146.003 21.5662C147.341 21.5662 148.63 21.164 149.671 20.5105C150.217 20.1586 150.663 19.8067 151.109 19.304C152.001 18.2986 153.538 18.2483 154.53 19.2034C155.521 20.1083 155.571 21.6667 154.629 22.6721C153.935 23.4262 153.092 24.13 152.2 24.683C150.415 25.7889 148.283 26.4425 146.003 26.4425C139.458 26.4425 134.154 
21.1137 134.154 14.5785V12.065C134.154 5.47952 139.508 0.201088 146.003 0.201088C151.605 0.201088 156.265 4.07194 157.504 9.3001L158 11.3612L147.688 15.584C146.449 16.0867 145.061 15.4834 144.515 14.2266C144.019 12.9699 144.614 11.5623 145.854 11.0596L152.051 8.54604C150.762 6.43466 148.58 5.07736 146.003 5.07736Z" fill="white"/>
|
||||
</svg>
|
||||
</body>
|
||||
</html>
|
||||
|
||||
|
|
@ -1,8 +1,8 @@
|
|||
{
|
||||
"cells": [
|
||||
{
|
||||
"metadata": {},
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Cognee GraphRAG\n",
|
||||
"\n",
|
||||
|
|
@ -48,15 +48,19 @@
|
|||
]
|
||||
},
|
||||
{
|
||||
"metadata": {},
|
||||
"cell_type": "code",
|
||||
"source": "!pip install cognee==0.1.24",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"execution_count": null
|
||||
"source": [
|
||||
"!pip install cognee==0.1.39"
|
||||
]
|
||||
},
|
||||
{
|
||||
"metadata": {},
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import os\n",
|
||||
"import cognee\n",
|
||||
|
|
@ -66,13 +70,11 @@
|
|||
"\n",
|
||||
"if \"OPENAI_API_KEY\" not in os.environ:\n",
|
||||
" os.environ[\"OPENAI_API_KEY\"] = \"\""
|
||||
],
|
||||
"outputs": [],
|
||||
"execution_count": null
|
||||
]
|
||||
},
|
||||
{
|
||||
"metadata": {},
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Ensure you’ve set up your API keys and installed necessary dependencies.\n",
|
||||
"\n",
|
||||
|
|
@ -82,19 +84,19 @@
|
|||
]
|
||||
},
|
||||
{
|
||||
"metadata": {},
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"documents = [\"Jessica Miller, Experienced Sales Manager with a strong track record in building high-performing teams.\",\n",
|
||||
" \"David Thompson, Creative Graphic Designer with over 8 years of experience in visual design and branding.\"\n",
|
||||
" ]"
|
||||
],
|
||||
"outputs": [],
|
||||
"execution_count": null
|
||||
]
|
||||
},
|
||||
{
|
||||
"metadata": {},
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### 3. Adding Data to Cognee\n",
|
||||
"\n",
|
||||
|
|
@ -102,15 +104,17 @@
|
|||
]
|
||||
},
|
||||
{
|
||||
"metadata": {},
|
||||
"cell_type": "code",
|
||||
"source": "await cognee.add(documents)",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"execution_count": null
|
||||
"source": [
|
||||
"await cognee.add(documents)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"metadata": {},
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"This step prepares the data for graph-based processing.\n",
|
||||
"\n",
|
||||
|
|
@ -120,15 +124,17 @@
|
|||
]
|
||||
},
|
||||
{
|
||||
"metadata": {},
|
||||
"cell_type": "code",
|
||||
"source": "await cognee.cognify()",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"execution_count": null
|
||||
"source": [
|
||||
"await cognee.cognify()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"metadata": {},
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"The graph now contains nodes and relationships derived from the dataset, creating a powerful structure for exploration.\n",
|
||||
"\n",
|
||||
|
|
@ -138,45 +144,49 @@
|
|||
]
|
||||
},
|
||||
{
|
||||
"metadata": {},
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from cognee.modules.search.types import SearchType\n",
|
||||
"search_results = await cognee.search(SearchType.GRAPH_COMPLETION, \"Tell me who are the people mentioned?\")\n",
|
||||
"from cognee.api.v1.search import SearchType\n",
|
||||
"search_results = await cognee.search(query_type=SearchType.GRAPH_COMPLETION, query_text=\"Tell me who are the people mentioned?\")\n",
|
||||
"\n",
|
||||
"print(\"\\n\\nAnswer based on knowledge graph:\\n\")\n",
|
||||
"for result in search_results:\n",
|
||||
" print(f\"{result}\\n\")"
|
||||
],
|
||||
"outputs": [],
|
||||
"execution_count": null
|
||||
]
|
||||
},
|
||||
{
|
||||
"metadata": {},
|
||||
"cell_type": "markdown",
|
||||
"source": "### Answer prompt based on RAG approach:"
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Answer prompt based on RAG approach:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"metadata": {},
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"search_results = await cognee.search(SearchType.COMPLETION, \"Tell me who are the people mentioned?\")\n",
|
||||
"search_results = await cognee.search(query_type=SearchType.RAG_COMPLETION, query_text=\"Tell me who are the people mentioned?\")\n",
|
||||
"\n",
|
||||
"print(\"\\n\\nAnswer based on RAG:\\n\")\n",
|
||||
"for result in search_results:\n",
|
||||
" print(f\"{result}\\n\")"
|
||||
],
|
||||
"outputs": [],
|
||||
"execution_count": null
|
||||
]
|
||||
},
|
||||
{
|
||||
"metadata": {},
|
||||
"cell_type": "markdown",
|
||||
"source": "In conclusion, the results demonstrate a significant advantage of the knowledge graph-based approach (Graphrag) over the RAG approach. Graphrag successfully identified all the mentioned individuals across multiple documents, showcasing its ability to aggregate and infer information from a global context. In contrast, the RAG approach was limited to identifying individuals within a single document due to its chunking-based processing constraints. This highlights Graphrag's superior capability in comprehensively resolving queries that span across a broader corpus of interconnected data."
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"In conclusion, the results demonstrate a significant advantage of the knowledge graph-based approach (Graphrag) over the RAG approach. Graphrag successfully identified all the mentioned individuals across multiple documents, showcasing its ability to aggregate and infer information from a global context. In contrast, the RAG approach was limited to identifying individuals within a single document due to its chunking-based processing constraints. This highlights Graphrag's superior capability in comprehensively resolving queries that span across a broader corpus of interconnected data."
|
||||
]
|
||||
},
|
||||
{
|
||||
"metadata": {},
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### 7. Finding Related Nodes\n",
|
||||
"\n",
|
||||
|
|
@ -184,21 +194,21 @@
|
|||
]
|
||||
},
|
||||
{
|
||||
"metadata": {},
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"related_nodes = await cognee.search(SearchType.INSIGHTS, \"person\")\n",
|
||||
"related_nodes = await cognee.search(query_type=SearchType.INSIGHTS, query_text=\"person\")\n",
|
||||
"\n",
|
||||
"print(\"\\n\\nRelated nodes are:\\n\")\n",
|
||||
"for node in related_nodes:\n",
|
||||
" print(f\"{node}\\n\")"
|
||||
],
|
||||
"outputs": [],
|
||||
"execution_count": null
|
||||
]
|
||||
},
|
||||
{
|
||||
"metadata": {},
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Why Choose Cognee?\n",
|
||||
"\n",
|
||||
|
|
@ -233,9 +243,9 @@
|
|||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"name": "python3",
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"display_name": "Python 3 (ipykernel)"
|
||||
"name": "python3"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
|
|
|||
|
|
@ -1,978 +0,0 @@
|
|||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "d35ac8ce-0f92-46f5-9ba4-a46970f0ce19",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Cognee - Get Started"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "074f0ea8-c659-4736-be26-be4b0e5ac665",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Demo time"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "0587d91d",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"#### First let's define some data that we will cognify and perform a search on"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "df16431d0f48b006",
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2024-09-20T14:02:48.519686Z",
|
||||
"start_time": "2024-09-20T14:02:48.515589Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"job_position = \"\"\"Senior Data Scientist (Machine Learning)\n",
|
||||
"\n",
|
||||
"Company: TechNova Solutions\n",
|
||||
"Location: San Francisco, CA\n",
|
||||
"\n",
|
||||
"Job Description:\n",
|
||||
"\n",
|
||||
"TechNova Solutions is seeking a Senior Data Scientist specializing in Machine Learning to join our dynamic analytics team. The ideal candidate will have a strong background in developing and deploying machine learning models, working with large datasets, and translating complex data into actionable insights.\n",
|
||||
"\n",
|
||||
"Responsibilities:\n",
|
||||
"\n",
|
||||
"Develop and implement advanced machine learning algorithms and models.\n",
|
||||
"Analyze large, complex datasets to extract meaningful patterns and insights.\n",
|
||||
"Collaborate with cross-functional teams to integrate predictive models into products.\n",
|
||||
"Stay updated with the latest advancements in machine learning and data science.\n",
|
||||
"Mentor junior data scientists and provide technical guidance.\n",
|
||||
"Qualifications:\n",
|
||||
"\n",
|
||||
"Master’s or Ph.D. in Data Science, Computer Science, Statistics, or a related field.\n",
|
||||
"5+ years of experience in data science and machine learning.\n",
|
||||
"Proficient in Python, R, and SQL.\n",
|
||||
"Experience with deep learning frameworks (e.g., TensorFlow, PyTorch).\n",
|
||||
"Strong problem-solving skills and attention to detail.\n",
|
||||
"Candidate CVs\n",
|
||||
"\"\"\""
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "9086abf3af077ab4",
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2024-09-20T14:02:49.120838Z",
|
||||
"start_time": "2024-09-20T14:02:49.118294Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"job_1 = \"\"\"\n",
|
||||
"CV 1: Relevant\n",
|
||||
"Name: Dr. Emily Carter\n",
|
||||
"Contact Information:\n",
|
||||
"\n",
|
||||
"Email: emily.carter@example.com\n",
|
||||
"Phone: (555) 123-4567\n",
|
||||
"Summary:\n",
|
||||
"\n",
|
||||
"Senior Data Scientist with over 8 years of experience in machine learning and predictive analytics. Expertise in developing advanced algorithms and deploying scalable models in production environments.\n",
|
||||
"\n",
|
||||
"Education:\n",
|
||||
"\n",
|
||||
"Ph.D. in Computer Science, Stanford University (2014)\n",
|
||||
"B.S. in Mathematics, University of California, Berkeley (2010)\n",
|
||||
"Experience:\n",
|
||||
"\n",
|
||||
"Senior Data Scientist, InnovateAI Labs (2016 – Present)\n",
|
||||
"Led a team in developing machine learning models for natural language processing applications.\n",
|
||||
"Implemented deep learning algorithms that improved prediction accuracy by 25%.\n",
|
||||
"Collaborated with cross-functional teams to integrate models into cloud-based platforms.\n",
|
||||
"Data Scientist, DataWave Analytics (2014 – 2016)\n",
|
||||
"Developed predictive models for customer segmentation and churn analysis.\n",
|
||||
"Analyzed large datasets using Hadoop and Spark frameworks.\n",
|
||||
"Skills:\n",
|
||||
"\n",
|
||||
"Programming Languages: Python, R, SQL\n",
|
||||
"Machine Learning: TensorFlow, Keras, Scikit-Learn\n",
|
||||
"Big Data Technologies: Hadoop, Spark\n",
|
||||
"Data Visualization: Tableau, Matplotlib\n",
|
||||
"\"\"\""
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "a9de0cc07f798b7f",
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2024-09-20T14:02:49.675003Z",
|
||||
"start_time": "2024-09-20T14:02:49.671615Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"job_2 = \"\"\"\n",
|
||||
"CV 2: Relevant\n",
|
||||
"Name: Michael Rodriguez\n",
|
||||
"Contact Information:\n",
|
||||
"\n",
|
||||
"Email: michael.rodriguez@example.com\n",
|
||||
"Phone: (555) 234-5678\n",
|
||||
"Summary:\n",
|
||||
"\n",
|
||||
"Data Scientist with a strong background in machine learning and statistical modeling. Skilled in handling large datasets and translating data into actionable business insights.\n",
|
||||
"\n",
|
||||
"Education:\n",
|
||||
"\n",
|
||||
"M.S. in Data Science, Carnegie Mellon University (2013)\n",
|
||||
"B.S. in Computer Science, University of Michigan (2011)\n",
|
||||
"Experience:\n",
|
||||
"\n",
|
||||
"Senior Data Scientist, Alpha Analytics (2017 – Present)\n",
|
||||
"Developed machine learning models to optimize marketing strategies.\n",
|
||||
"Reduced customer acquisition cost by 15% through predictive modeling.\n",
|
||||
"Data Scientist, TechInsights (2013 – 2017)\n",
|
||||
"Analyzed user behavior data to improve product features.\n",
|
||||
"Implemented A/B testing frameworks to evaluate product changes.\n",
|
||||
"Skills:\n",
|
||||
"\n",
|
||||
"Programming Languages: Python, Java, SQL\n",
|
||||
"Machine Learning: Scikit-Learn, XGBoost\n",
|
||||
"Data Visualization: Seaborn, Plotly\n",
|
||||
"Databases: MySQL, MongoDB\n",
|
||||
"\"\"\""
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "185ff1c102d06111",
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2024-09-20T14:02:50.286828Z",
|
||||
"start_time": "2024-09-20T14:02:50.284369Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"job_3 = \"\"\"\n",
|
||||
"CV 3: Relevant\n",
|
||||
"Name: Sarah Nguyen\n",
|
||||
"Contact Information:\n",
|
||||
"\n",
|
||||
"Email: sarah.nguyen@example.com\n",
|
||||
"Phone: (555) 345-6789\n",
|
||||
"Summary:\n",
|
||||
"\n",
|
||||
"Data Scientist specializing in machine learning with 6 years of experience. Passionate about leveraging data to drive business solutions and improve product performance.\n",
|
||||
"\n",
|
||||
"Education:\n",
|
||||
"\n",
|
||||
"M.S. in Statistics, University of Washington (2014)\n",
|
||||
"B.S. in Applied Mathematics, University of Texas at Austin (2012)\n",
|
||||
"Experience:\n",
|
||||
"\n",
|
||||
"Data Scientist, QuantumTech (2016 – Present)\n",
|
||||
"Designed and implemented machine learning algorithms for financial forecasting.\n",
|
||||
"Improved model efficiency by 20% through algorithm optimization.\n",
|
||||
"Junior Data Scientist, DataCore Solutions (2014 – 2016)\n",
|
||||
"Assisted in developing predictive models for supply chain optimization.\n",
|
||||
"Conducted data cleaning and preprocessing on large datasets.\n",
|
||||
"Skills:\n",
|
||||
"\n",
|
||||
"Programming Languages: Python, R\n",
|
||||
"Machine Learning Frameworks: PyTorch, Scikit-Learn\n",
|
||||
"Statistical Analysis: SAS, SPSS\n",
|
||||
"Cloud Platforms: AWS, Azure\n",
|
||||
"\"\"\""
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "d55ce4c58f8efb67",
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2024-09-20T14:02:50.950343Z",
|
||||
"start_time": "2024-09-20T14:02:50.946378Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"job_4 = \"\"\"\n",
|
||||
"CV 4: Not Relevant\n",
|
||||
"Name: David Thompson\n",
|
||||
"Contact Information:\n",
|
||||
"\n",
|
||||
"Email: david.thompson@example.com\n",
|
||||
"Phone: (555) 456-7890\n",
|
||||
"Summary:\n",
|
||||
"\n",
|
||||
"Creative Graphic Designer with over 8 years of experience in visual design and branding. Proficient in Adobe Creative Suite and passionate about creating compelling visuals.\n",
|
||||
"\n",
|
||||
"Education:\n",
|
||||
"\n",
|
||||
"B.F.A. in Graphic Design, Rhode Island School of Design (2012)\n",
|
||||
"Experience:\n",
|
||||
"\n",
|
||||
"Senior Graphic Designer, CreativeWorks Agency (2015 – Present)\n",
|
||||
"Led design projects for clients in various industries.\n",
|
||||
"Created branding materials that increased client engagement by 30%.\n",
|
||||
"Graphic Designer, Visual Innovations (2012 – 2015)\n",
|
||||
"Designed marketing collateral, including brochures, logos, and websites.\n",
|
||||
"Collaborated with the marketing team to develop cohesive brand strategies.\n",
|
||||
"Skills:\n",
|
||||
"\n",
|
||||
"Design Software: Adobe Photoshop, Illustrator, InDesign\n",
|
||||
"Web Design: HTML, CSS\n",
|
||||
"Specialties: Branding and Identity, Typography\n",
|
||||
"\"\"\""
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "ca4ecc32721ad332",
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2024-09-20T14:02:51.548191Z",
|
||||
"start_time": "2024-09-20T14:02:51.545520Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"job_5 = \"\"\"\n",
|
||||
"CV 5: Not Relevant\n",
|
||||
"Name: Jessica Miller\n",
|
||||
"Contact Information:\n",
|
||||
"\n",
|
||||
"Email: jessica.miller@example.com\n",
|
||||
"Phone: (555) 567-8901\n",
|
||||
"Summary:\n",
|
||||
"\n",
|
||||
"Experienced Sales Manager with a strong track record in driving sales growth and building high-performing teams. Excellent communication and leadership skills.\n",
|
||||
"\n",
|
||||
"Education:\n",
|
||||
"\n",
|
||||
"B.A. in Business Administration, University of Southern California (2010)\n",
|
||||
"Experience:\n",
|
||||
"\n",
|
||||
"Sales Manager, Global Enterprises (2015 – Present)\n",
|
||||
"Managed a sales team of 15 members, achieving a 20% increase in annual revenue.\n",
|
||||
"Developed sales strategies that expanded customer base by 25%.\n",
|
||||
"Sales Representative, Market Leaders Inc. (2010 – 2015)\n",
|
||||
"Consistently exceeded sales targets and received the 'Top Salesperson' award in 2013.\n",
|
||||
"Skills:\n",
|
||||
"\n",
|
||||
"Sales Strategy and Planning\n",
|
||||
"Team Leadership and Development\n",
|
||||
"CRM Software: Salesforce, Zoho\n",
|
||||
"Negotiation and Relationship Building\n",
|
||||
"\"\"\""
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "4415446a",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"#### Please add the necessary environment information bellow:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "bce39dc6",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import os\n",
|
||||
"\n",
|
||||
"# Setting environment variables\n",
|
||||
"if \"GRAPHISTRY_USERNAME\" not in os.environ:\n",
|
||||
" os.environ[\"GRAPHISTRY_USERNAME\"] = \"\"\n",
|
||||
"\n",
|
||||
"if \"GRAPHISTRY_PASSWORD\" not in os.environ:\n",
|
||||
" os.environ[\"GRAPHISTRY_PASSWORD\"] = \"\"\n",
|
||||
"\n",
|
||||
"if \"LLM_API_KEY\" not in os.environ:\n",
|
||||
" os.environ[\"LLM_API_KEY\"] = \"\"\n",
|
||||
"\n",
|
||||
"# \"neo4j\" or \"networkx\"\n",
|
||||
"os.environ[\"GRAPH_DATABASE_PROVIDER\"] = \"networkx\"\n",
|
||||
"# Not needed if using networkx\n",
|
||||
"# os.environ[\"GRAPH_DATABASE_URL\"]=\"\"\n",
|
||||
"# os.environ[\"GRAPH_DATABASE_USERNAME\"]=\"\"\n",
|
||||
"# os.environ[\"GRAPH_DATABASE_PASSWORD\"]=\"\"\n",
|
||||
"\n",
|
||||
"# \"pgvector\", \"qdrant\", \"weaviate\" or \"lancedb\"\n",
|
||||
"os.environ[\"VECTOR_DB_PROVIDER\"] = \"lancedb\"\n",
|
||||
"# Not needed if using \"lancedb\" or \"pgvector\"\n",
|
||||
"# os.environ[\"VECTOR_DB_URL\"]=\"\"\n",
|
||||
"# os.environ[\"VECTOR_DB_KEY\"]=\"\"\n",
|
||||
"\n",
|
||||
"# Relational Database provider \"sqlite\" or \"postgres\"\n",
|
||||
"os.environ[\"DB_PROVIDER\"] = \"sqlite\"\n",
|
||||
"\n",
|
||||
"# Database name\n",
|
||||
"os.environ[\"DB_NAME\"] = \"cognee_db\"\n",
|
||||
"\n",
|
||||
"# Postgres specific parameters (Only if Postgres or PGVector is used)\n",
|
||||
"# os.environ[\"DB_HOST\"]=\"127.0.0.1\"\n",
|
||||
"# os.environ[\"DB_PORT\"]=\"5432\"\n",
|
||||
"# os.environ[\"DB_USERNAME\"]=\"cognee\"\n",
|
||||
"# os.environ[\"DB_PASSWORD\"]=\"cognee\""
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "9f1a1dbd",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Reset the cognee system with the following command:\n",
|
||||
"\n",
|
||||
"import cognee\n",
|
||||
"\n",
|
||||
"await cognee.prune.prune_data()\n",
|
||||
"await cognee.prune.prune_system(metadata=True)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "383d6971",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"#### After we have defined and gathered our data let's add it to cognee "
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "904df61ba484a8e5",
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2024-09-20T14:02:54.243987Z",
|
||||
"start_time": "2024-09-20T14:02:52.498195Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import cognee\n",
|
||||
"\n",
|
||||
"await cognee.add([job_1, job_2, job_3, job_4, job_5, job_position], \"example\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "0f15c5b1",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"#### All good, let's cognify it."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "7c431fdef4921ae0",
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2024-09-20T14:02:57.925667Z",
|
||||
"start_time": "2024-09-20T14:02:57.922353Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from cognee.shared.data_models import KnowledgeGraph\n",
|
||||
"from cognee.modules.data.models import Dataset, Data\n",
|
||||
"from cognee.modules.data.methods.get_dataset_data import get_dataset_data\n",
|
||||
"from cognee.modules.cognify.config import get_cognify_config\n",
|
||||
"from cognee.modules.pipelines.tasks.Task import Task\n",
|
||||
"from cognee.modules.pipelines import run_tasks\n",
|
||||
"from cognee.modules.users.models import User\n",
|
||||
"from cognee.tasks.documents import (\n",
|
||||
" check_permissions_on_documents,\n",
|
||||
" classify_documents,\n",
|
||||
" extract_chunks_from_documents,\n",
|
||||
")\n",
|
||||
"from cognee.tasks.graph import extract_graph_from_data\n",
|
||||
"from cognee.tasks.storage import add_data_points\n",
|
||||
"from cognee.tasks.summarization import summarize_text\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"async def run_cognify_pipeline(dataset: Dataset, user: User = None):\n",
|
||||
" data_documents: list[Data] = await get_dataset_data(dataset_id=dataset.id)\n",
|
||||
"\n",
|
||||
" try:\n",
|
||||
" cognee_config = get_cognify_config()\n",
|
||||
"\n",
|
||||
" tasks = [\n",
|
||||
" Task(classify_documents),\n",
|
||||
" Task(check_permissions_on_documents, user=user, permissions=[\"write\"]),\n",
|
||||
" Task(extract_chunks_from_documents), # Extract text chunks based on the document type.\n",
|
||||
" Task(\n",
|
||||
" extract_graph_from_data, graph_model=KnowledgeGraph,\n",
|
||||
" task_config={\"batch_size\": 10}\n",
|
||||
" ), # Generate knowledge graphs from the document chunks.\n",
|
||||
" Task(\n",
|
||||
" summarize_text,\n",
|
||||
" summarization_model=cognee_config.summarization_model,\n",
|
||||
" task_config={\"batch_size\": 10},\n",
|
||||
" ),\n",
|
||||
" Task(add_data_points, task_config={\"batch_size\": 10}),\n",
|
||||
" ]\n",
|
||||
"\n",
|
||||
" pipeline = run_tasks(tasks, data_documents)\n",
|
||||
"\n",
|
||||
" async for result in pipeline:\n",
|
||||
" print(result)\n",
|
||||
" except Exception as error:\n",
|
||||
" raise error"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "f0a91b99c6215e09",
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2024-09-20T14:02:58.905774Z",
|
||||
"start_time": "2024-09-20T14:02:58.625915Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from cognee.modules.users.methods import get_default_user\n",
|
||||
"from cognee.modules.data.methods import get_datasets_by_name\n",
|
||||
"\n",
|
||||
"user = await get_default_user()\n",
|
||||
"\n",
|
||||
"datasets = await get_datasets_by_name([\"example\"], user.id)\n",
|
||||
"\n",
|
||||
"await run_cognify_pipeline(datasets[0], user)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "219a6d41",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"#### We get the url to the graph on graphistry in the notebook cell bellow, showing nodes and connections made by the cognify process."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "080389e5",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import os\n",
|
||||
"from cognee.shared.utils import render_graph\n",
|
||||
"from cognee.infrastructure.databases.graph import get_graph_engine\n",
|
||||
"import graphistry\n",
|
||||
"\n",
|
||||
"graphistry.login(\n",
|
||||
" username=os.getenv(\"GRAPHISTRY_USERNAME\"), password=os.getenv(\"GRAPHISTRY_PASSWORD\")\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"graph_engine = await get_graph_engine()\n",
|
||||
"\n",
|
||||
"graph_url = await render_graph(graph_engine.graph)\n",
|
||||
"print(graph_url)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "59e6c3c3",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"#### We can also do a search on the data to explore the knowledge."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "e5e7dfc8",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"async def search(\n",
|
||||
" vector_engine,\n",
|
||||
" collection_name: str,\n",
|
||||
" query_text: str = None,\n",
|
||||
"):\n",
|
||||
" query_vector = (await vector_engine.embedding_engine.embed_text([query_text]))[0]\n",
|
||||
"\n",
|
||||
" connection = await vector_engine.get_connection()\n",
|
||||
" collection = await connection.open_table(collection_name)\n",
|
||||
"\n",
|
||||
" results = await collection.vector_search(query_vector).limit(10).to_pandas()\n",
|
||||
"\n",
|
||||
" result_values = list(results.to_dict(\"index\").values())\n",
|
||||
"\n",
|
||||
" return [\n",
|
||||
" dict(\n",
|
||||
" id=str(result[\"id\"]),\n",
|
||||
" payload=result[\"payload\"],\n",
|
||||
" score=result[\"_distance\"],\n",
|
||||
" )\n",
|
||||
" for result in result_values\n",
|
||||
" ]\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"from cognee.infrastructure.databases.vector import get_vector_engine\n",
|
||||
"\n",
|
||||
"vector_engine = get_vector_engine()\n",
|
||||
"results = await search(vector_engine, \"Entity_name\", \"sarah.nguyen@example.com\")\n",
|
||||
"for result in results:\n",
|
||||
" print(result)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "81fa2b00",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"#### We normalize search output scores so the lower the score of the search result is the higher the chance that it's what you're looking for. In the example above we have searched for node entities in the knowledge graph related to \"sarah.nguyen@example.com\""
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "1b94ff96",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"#### In the example bellow we'll use cognee search to summarize information regarding the node most related to \"sarah.nguyen@example.com\" in the knowledge graph"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "21a3e9a6",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from cognee.api.v1.search import SearchType\n",
|
||||
"\n",
|
||||
"node = (await vector_engine.search(\"Entity_name\", \"sarah.nguyen@example.com\"))[0]\n",
|
||||
"node_name = node.payload[\"text\"]\n",
|
||||
"\n",
|
||||
"search_results = await cognee.search(query_type=SearchType.SUMMARIES, query_text=node_name)\n",
|
||||
"print(\"\\n\\Extracted summaries are:\\n\")\n",
|
||||
"for result in search_results:\n",
|
||||
" print(f\"{result}\\n\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "fd6e5fe2",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"#### In this example we'll use cognee search to find chunks in which the node most related to \"sarah.nguyen@example.com\" is a part of"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "c7a8abff",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"search_results = await cognee.search(query_type=SearchType.CHUNKS, query_text=node_name)\n",
|
||||
"print(\"\\n\\nExtracted chunks are:\\n\")\n",
|
||||
"for result in search_results:\n",
|
||||
" print(f\"{result}\\n\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "47f0112f",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"#### In this example we'll use cognee search to give us insights from the knowledge graph related to the node most related to \"sarah.nguyen@example.com\""
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "706a3954",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"search_results = await cognee.search(query_type=SearchType.INSIGHTS, query_text=node_name)\n",
|
||||
"print(\"\\n\\nExtracted sentences are:\\n\")\n",
|
||||
"for result in search_results:\n",
|
||||
" print(f\"{result}\\n\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "e519e30c0423c2a",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Let's add evals"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "3845443e",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"!pip install \"cognee[deepeval]\""
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "7a2c3c70",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from evals.eval_on_hotpot import deepeval_answers, answer_qa_instance\n",
|
||||
"from evals.qa_dataset_utils import load_qa_dataset\n",
|
||||
"from evals.qa_metrics_utils import get_metrics\n",
|
||||
"from evals.qa_context_provider_utils import qa_context_providers\n",
|
||||
"from pathlib import Path\n",
|
||||
"from tqdm import tqdm\n",
|
||||
"import statistics\n",
|
||||
"import random"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "53a609d8",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"num_samples = 10 # With cognee, it takes ~1m10s per sample\n",
|
||||
"dataset_name_or_filename = \"hotpotqa\"\n",
|
||||
"dataset = load_qa_dataset(dataset_name_or_filename)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "7351ab8f",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"context_provider_name = \"cognee\"\n",
|
||||
"context_provider = qa_context_providers[context_provider_name]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "9346115b",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"random.seed(42)\n",
|
||||
"instances = dataset if not num_samples else random.sample(dataset, num_samples)\n",
|
||||
"\n",
|
||||
"out_path = \"out\"\n",
|
||||
"if not Path(out_path).exists():\n",
|
||||
" Path(out_path).mkdir()\n",
|
||||
"contexts_filename = out_path / Path(\n",
|
||||
" f\"contexts_{dataset_name_or_filename.split('.')[0]}_{context_provider_name}.json\"\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"answers = []\n",
|
||||
"for instance in tqdm(instances, desc=\"Getting answers\"):\n",
|
||||
" answer = await answer_qa_instance(instance, context_provider, contexts_filename)\n",
|
||||
" answers.append(answer)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "1e7d872d",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"#### Define Metrics for Evaluation and Calculate Score\n",
|
||||
"**Options**: \n",
|
||||
"- **Correctness**: Is the actual output factually correct based on the expected output?\n",
|
||||
"- **Comprehensiveness**: How much detail does the answer provide to cover all aspects and details of the question?\n",
|
||||
"- **Diversity**: How varied and rich is the answer in providing different perspectives and insights on the question?\n",
|
||||
"- **Empowerment**: How well does the answer help the reader understand and make informed judgements about the topic?\n",
|
||||
"- **Directness**: How specifically and clearly does the answer address the question?\n",
|
||||
"- **F1 Score**: the harmonic mean of the precision and recall, using word-level Exact Match\n",
|
||||
"- **EM Score**: the rate at which the predicted strings exactly match their references, ignoring white spaces and capitalization."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "c81e2b46",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"##### Calculate `\"Correctness\"`"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "ae728344",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"metric_name_list = [\"Correctness\"]\n",
|
||||
"eval_metrics = get_metrics(metric_name_list)\n",
|
||||
"eval_results = await deepeval_answers(instances, answers, eval_metrics[\"deepeval_metrics\"])"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "764aac6d",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"Correctness = statistics.mean(\n",
|
||||
" [result.metrics_data[0].score for result in eval_results.test_results]\n",
|
||||
")\n",
|
||||
"print(Correctness)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "6d3bbdc5",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"##### Calculating `\"Comprehensiveness\"`"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "9793ef78",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"metric_name_list = [\"Comprehensiveness\"]\n",
|
||||
"eval_metrics = get_metrics(metric_name_list)\n",
|
||||
"eval_results = await deepeval_answers(instances, answers, eval_metrics[\"deepeval_metrics\"])"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "9add448a",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"Comprehensiveness = statistics.mean(\n",
|
||||
" [result.metrics_data[0].score for result in eval_results.test_results]\n",
|
||||
")\n",
|
||||
"print(Comprehensiveness)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "bce2fa25",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"##### Calculating `\"Diversity\"`"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "f60a179e",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"metric_name_list = [\"Diversity\"]\n",
|
||||
"eval_metrics = get_metrics(metric_name_list)\n",
|
||||
"eval_results = await deepeval_answers(instances, answers, eval_metrics[\"deepeval_metrics\"])"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "7ccbd0ab",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"Diversity = statistics.mean([result.metrics_data[0].score for result in eval_results.test_results])\n",
|
||||
"print(Diversity)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "191cab63",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"##### Calculating`\"Empowerment\"`"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "66bec0bf",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"metric_name_list = [\"Empowerment\"]\n",
|
||||
"eval_metrics = get_metrics(metric_name_list)\n",
|
||||
"eval_results = await deepeval_answers(instances, answers, eval_metrics[\"deepeval_metrics\"])"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "1b043a8f",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"Empowerment = statistics.mean(\n",
|
||||
" [result.metrics_data[0].score for result in eval_results.test_results]\n",
|
||||
")\n",
|
||||
"print(Empowerment)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "2cac3be9",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"##### Calculating `\"Directness\"`"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "adaa17c0",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"metric_name_list = [\"Directness\"]\n",
|
||||
"eval_metrics = get_metrics(metric_name_list)\n",
|
||||
"eval_results = await deepeval_answers(instances, answers, eval_metrics[\"deepeval_metrics\"])"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "3a8f97c9",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"Directness = statistics.mean([result.metrics_data[0].score for result in eval_results.test_results])\n",
|
||||
"print(Directness)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "1ad6feb8",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"##### Calculating `\"F1\"`"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "bdc48259",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"metric_name_list = [\"F1\"]\n",
|
||||
"eval_metrics = get_metrics(metric_name_list)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "c43c17c8",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"eval_results = await deepeval_answers(instances, answers, eval_metrics[\"deepeval_metrics\"])"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "8bfcc46d",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"F1_score = statistics.mean([result.metrics_data[0].score for result in eval_results.test_results])\n",
|
||||
"print(F1_score)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "2583f948",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"##### Calculating `\"EM\"`"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "90a8f630",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"metric_name_list = [\"EM\"]\n",
|
||||
"eval_metrics = get_metrics(metric_name_list)\n",
|
||||
"eval_results = await deepeval_answers(instances, answers, eval_metrics[\"deepeval_metrics\"])"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "8d1b1ea1",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"EM = statistics.mean([result.metrics_data[0].score for result in eval_results.test_results])\n",
|
||||
"print(EM)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "288ab570",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Give us a star if you like it!\n",
|
||||
"https://github.com/topoteretes/cognee"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "cognee-c83GrcRT-py3.11",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.11.10"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
||||
|
|
@ -3,7 +3,9 @@
|
|||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": "[](https://colab.research.google.com/drive/1EpokQ8Y_5jIJ7HdixZms81Oqgh2sp7-E?usp=sharing)"
|
||||
"source": [
|
||||
"[](https://colab.research.google.com/drive/1EpokQ8Y_5jIJ7HdixZms81Oqgh2sp7-E?usp=sharing)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
|
|
@ -57,7 +59,9 @@
|
|||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": "!pip install llama-index-graph-rag-cognee==0.1.3"
|
||||
"source": [
|
||||
"!pip install llama-index-graph-rag-cognee==0.1.3"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
|
|
@ -192,7 +196,9 @@
|
|||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": "### Answer prompt based on RAG approach:"
|
||||
"source": [
|
||||
"### Answer prompt based on RAG approach:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
|
|
@ -210,7 +216,9 @@
|
|||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": "In conclusion, the results demonstrate a significant advantage of the knowledge graph-based approach (Graphrag) over the RAG approach. Graphrag successfully identified all the mentioned individuals across multiple documents, showcasing its ability to aggregate and infer information from a global context. In contrast, the RAG approach was limited to identifying individuals within a single document due to its chunking-based processing constraints. This highlights Graphrag's superior capability in comprehensively resolving queries that span across a broader corpus of interconnected data."
|
||||
"source": [
|
||||
"In conclusion, the results demonstrate a significant advantage of the knowledge graph-based approach (Graphrag) over the RAG approach. Graphrag successfully identified all the mentioned individuals across multiple documents, showcasing its ability to aggregate and infer information from a global context. In contrast, the RAG approach was limited to identifying individuals within a single document due to its chunking-based processing constraints. This highlights Graphrag's superior capability in comprehensively resolving queries that span across a broader corpus of interconnected data."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
|
|
@ -271,7 +279,8 @@
|
|||
],
|
||||
"metadata": {
|
||||
"language_info": {
|
||||
"name": "python"
|
||||
"name": "python",
|
||||
"version": "3.12.9"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
|
|
|||
634
notebooks/node_scores.ipynb
Normal file
634
notebooks/node_scores.ipynb
Normal file
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
2783
poetry.lock
generated
2783
poetry.lock
generated
File diff suppressed because it is too large
Load diff
250
pyproject.toml
250
pyproject.toml
|
|
@ -1,12 +1,14 @@
|
|||
[tool.poetry]
|
||||
[project]
|
||||
name = "cognee"
|
||||
version = "0.1.39"
|
||||
version = "0.1.40"
|
||||
description = "Cognee - is a library for enriching LLM context with a semantic layer for better understanding and reasoning."
|
||||
authors = ["Vasilije Markovic", "Boris Arzentar"]
|
||||
authors = [
|
||||
{ name = "Vasilije Markovic" },
|
||||
{ name = "Boris Arzentar" },
|
||||
]
|
||||
requires-python = ">=3.10,<=3.13"
|
||||
readme = "README.md"
|
||||
license = "Apache-2.0"
|
||||
homepage = "https://www.cognee.ai"
|
||||
repository = "https://github.com/topoteretes/cognee"
|
||||
classifiers = [
|
||||
"Development Status :: 4 - Beta",
|
||||
"Intended Audience :: Developers",
|
||||
|
|
@ -14,129 +16,127 @@ classifiers = [
|
|||
"Topic :: Software Development :: Libraries",
|
||||
"Operating System :: MacOS :: MacOS X",
|
||||
"Operating System :: POSIX :: Linux",
|
||||
"Operating System :: Microsoft :: Windows"
|
||||
"Operating System :: Microsoft :: Windows",
|
||||
]
|
||||
dependencies = [
|
||||
"openai>=1.59.4,<2",
|
||||
"python-dotenv>=1.0.1",
|
||||
"pydantic==2.10.5",
|
||||
"pydantic-settings>=2.2.1,<3",
|
||||
"typing_extensions==4.12.2",
|
||||
"nltk==3.9.1",
|
||||
"numpy>=1.26.4, <=2.1",
|
||||
"pandas>=2.2.2",
|
||||
# Note: New s3fs and boto3 versions don't work well together
|
||||
# Always use comaptible fixed versions of these two dependencies
|
||||
"s3fs[boto3]==2025.3.2",
|
||||
"sqlalchemy==2.0.39",
|
||||
"aiosqlite>=0.20.0,<0.21",
|
||||
"tiktoken<=0.9.0",
|
||||
"litellm>=1.57.4",
|
||||
"instructor==1.7.2",
|
||||
"langfuse>=2.32.0",
|
||||
"filetype>=1.2.0",
|
||||
"aiohttp>=3.11.14",
|
||||
"aiofiles>=23.2.1",
|
||||
"owlready2>=0.47,<0.48",
|
||||
"graphistry>=0.33.5,<0.34",
|
||||
"pypdf>=4.1.0,<6.0.0",
|
||||
"jinja2>=3.1.3,<4",
|
||||
"matplotlib>=3.8.3,<4",
|
||||
"networkx>=3.4.2,<4",
|
||||
"lancedb==0.21.0",
|
||||
"alembic>=1.13.3,<2",
|
||||
"pre-commit>=4.0.1,<5",
|
||||
"scikit-learn>=1.6.1,<2",
|
||||
"limits>=4.4.1,<5",
|
||||
"fastapi==0.115.7",
|
||||
"python-multipart==0.0.20",
|
||||
"fastapi-users[sqlalchemy]==14.0.1",
|
||||
"dlt[sqlalchemy]>=1.9.0,<2",
|
||||
"sentry-sdk[fastapi]>=2.9.0,<3",
|
||||
"structlog>=25.2.0,<26",
|
||||
"onnxruntime<=1.21.1",
|
||||
"pylance==0.22.0",
|
||||
]
|
||||
|
||||
[tool.poetry.dependencies]
|
||||
python = ">=3.10,<=3.13"
|
||||
openai = "^1.59.4"
|
||||
python-dotenv = "1.0.1"
|
||||
pydantic = "2.10.5"
|
||||
pydantic-settings = "^2.2.1"
|
||||
typing_extensions = "4.12.2"
|
||||
nltk = "3.9.1"
|
||||
numpy = ">=1.26.4, <=2.1"
|
||||
pandas = "2.2.3"
|
||||
boto3 = "^1.26.125"
|
||||
botocore="^1.35.54"
|
||||
sqlalchemy = "2.0.39"
|
||||
aiosqlite = "^0.20.0"
|
||||
tiktoken = "<=0.9.0"
|
||||
litellm = ">=1.57.4"
|
||||
instructor = "1.7.2"
|
||||
langfuse = "^2.32.0"
|
||||
filetype = "^1.2.0"
|
||||
aiohttp = "^3.11.14"
|
||||
aiofiles = "^23.2.1"
|
||||
owlready2 = "^0.47"
|
||||
graphistry = "^0.33.5"
|
||||
pypdf = ">=4.1.0,<6.0.0"
|
||||
jinja2 = "^3.1.3"
|
||||
matplotlib = "^3.8.3"
|
||||
networkx = "^3.2.1"
|
||||
lancedb = "0.16.0"
|
||||
alembic = "^1.13.3"
|
||||
pre-commit = "^4.0.1"
|
||||
scikit-learn = "^1.6.1"
|
||||
limits = "^4.4.1"
|
||||
fastapi = {version = "0.115.7"}
|
||||
python-multipart = "0.0.20"
|
||||
fastapi-users = {version = "14.0.1", extras = ["sqlalchemy"]}
|
||||
uvicorn = {version = "0.34.0", optional = true}
|
||||
gunicorn = {version = "^20.1.0", optional = true}
|
||||
dlt = {extras = ["sqlalchemy"], version = "^1.9.0"}
|
||||
qdrant-client = {version = "^1.9.0", optional = true}
|
||||
weaviate-client = {version = "4.9.6", optional = true}
|
||||
neo4j = {version = "^5.20.0", optional = true}
|
||||
falkordb = {version = "1.0.9", optional = true}
|
||||
kuzu = {version = "0.8.2", optional = true}
|
||||
chromadb = {version = "^0.6.0", optional = true}
|
||||
langchain_text_splitters = {version = "0.3.2", optional = true}
|
||||
langsmith = {version = "0.2.3", optional = true}
|
||||
posthog = {version = "^3.5.0", optional = true}
|
||||
groq = {version = "0.8.0", optional = true}
|
||||
anthropic = {version = "^0.26.1", optional = true}
|
||||
sentry-sdk = {extras = ["fastapi"], version = "^2.9.0"}
|
||||
asyncpg = {version = "0.30.0", optional = true}
|
||||
pgvector = {version = "^0.3.5", optional = true}
|
||||
psycopg2 = {version = "^2.9.10", optional = true}
|
||||
llama-index-core = {version = "^0.12.11", optional = true}
|
||||
deepeval = {version = "^2.0.1", optional = true}
|
||||
transformers = {version = "^4.46.3", optional = true}
|
||||
pymilvus = {version = "^2.5.0", optional = true}
|
||||
unstructured = { extras = ["csv", "doc", "docx", "epub", "md", "odt", "org", "ppt", "pptx", "rst", "rtf", "tsv", "xlsx"], version = "^0.16.13", optional = true }
|
||||
mistral-common = {version = "^1.5.2", optional = true}
|
||||
fastembed = {version = "<=0.6.0", optional = true, markers = "python_version < '3.13'"}
|
||||
tree-sitter = {version = "^0.24.0", optional = true}
|
||||
tree-sitter-python = {version = "^0.23.6", optional = true}
|
||||
plotly = {version = "^6.0.0", optional = true}
|
||||
gdown = {version = "^5.2.0", optional = true}
|
||||
qasync = {version = "^0.27.1", optional = true}
|
||||
graphiti-core = {version = "^0.7.0", optional = true}
|
||||
structlog = "^25.2.0"
|
||||
pyside6 = {version = "^6.8.3", optional = true}
|
||||
google-generativeai = {version = "^0.8.4", optional = true}
|
||||
notebook = {version = "^7.1.0", optional = true}
|
||||
s3fs = "^2025.3.2"
|
||||
[project.optional-dependencies]
|
||||
api = [
|
||||
"uvicorn==0.34.0",
|
||||
"gunicorn>=20.1.0,<21",
|
||||
]
|
||||
weaviate = ["weaviate-client==4.9.6"]
|
||||
qdrant = ["qdrant-client>=1.9.0,<2"]
|
||||
neo4j = ["neo4j>=5.20.0,<6"]
|
||||
postgres = [
|
||||
"psycopg2>=2.9.10,<3",
|
||||
"pgvector>=0.3.5,<0.4",
|
||||
"asyncpg==0.30.0",
|
||||
]
|
||||
notebook = ["notebook>=7.1.0,<8"]
|
||||
langchain = [
|
||||
"langsmith==0.2.3",
|
||||
"langchain_text_splitters==0.3.2",
|
||||
]
|
||||
llama-index = ["llama-index-core>=0.12.11,<0.13"]
|
||||
gemini = ["google-generativeai>=0.8.4,<0.9"]
|
||||
huggingface = ["transformers>=4.46.3,<5"]
|
||||
ollama = ["transformers>=4.46.3,<5"]
|
||||
mistral = ["mistral-common>=1.5.2,<2"]
|
||||
anthropic = ["anthropic>=0.26.1,<0.27"]
|
||||
deepeval = ["deepeval>=2.0.1,<3"]
|
||||
posthog = ["posthog>=3.5.0,<4"]
|
||||
falkordb = ["falkordb==1.0.9"]
|
||||
kuzu = ["kuzu==0.8.2"]
|
||||
groq = ["groq==0.8.0"]
|
||||
milvus = ["pymilvus>=2.5.0,<3"]
|
||||
chromadb = [
|
||||
"chromadb>=0.3.0,<0.7",
|
||||
"pypika==0.48.8",
|
||||
]
|
||||
docs = ["unstructured[csv, doc, docx, epub, md, odt, org, ppt, pptx, rst, rtf, tsv, xlsx]>=0.16.13,<0.17"]
|
||||
codegraph = [
|
||||
"fastembed<=0.6.0 ; python_version < '3.13'",
|
||||
"transformers>=4.46.3,<5",
|
||||
"tree-sitter>=0.24.0,<0.25",
|
||||
"tree-sitter-python>=0.23.6,<0.24",
|
||||
]
|
||||
evals = [
|
||||
"plotly>=6.0.0,<7",
|
||||
"gdown>=5.2.0,<6",
|
||||
]
|
||||
gui = [
|
||||
"pyside6>=6.8.3,<7",
|
||||
"qasync>=0.27.1,<0.28",
|
||||
]
|
||||
graphiti = ["graphiti-core>=0.7.0,<0.8"]
|
||||
dev = [
|
||||
"pytest>=7.4.0,<8",
|
||||
"pytest-cov>=6.1.1",
|
||||
"pytest-asyncio>=0.21.1,<0.22",
|
||||
"coverage>=7.3.2,<8",
|
||||
"mypy>=1.7.1,<2",
|
||||
"notebook>=7.1.0,<8",
|
||||
"deptry>=0.20.0,<0.21",
|
||||
"pylint>=3.0.3,<4",
|
||||
"ruff>=0.9.2,<1.0.0",
|
||||
"tweepy==4.14.0",
|
||||
"gitpython>=3.1.43,<4",
|
||||
"mkdocs-material>=9.5.42,<10",
|
||||
"mkdocs-minify-plugin>=0.8.0,<0.9",
|
||||
"mkdocstrings[python]>=0.26.2,<0.27",
|
||||
]
|
||||
debug = ["debugpy==1.8.9"]
|
||||
|
||||
[project.urls]
|
||||
Homepage = "https://www.cognee.ai"
|
||||
Repository = "https://github.com/topoteretes/cognee"
|
||||
|
||||
[tool.poetry.extras]
|
||||
api = ["uvicorn", "gunicorn"]
|
||||
weaviate = ["weaviate-client"]
|
||||
qdrant = ["qdrant-client"]
|
||||
neo4j = ["neo4j"]
|
||||
postgres = ["psycopg2", "pgvector", "asyncpg"]
|
||||
notebook = ["notebook", "ipykernel", "overrides", "ipywidgets", "jupyterlab", "jupyterlab_widgets", "jupyterlab-server", "jupyterlab-git"]
|
||||
langchain = ["langsmith", "langchain_text_splitters"]
|
||||
llama-index = ["llama-index-core"]
|
||||
gemini = ["google-generativeai"]
|
||||
huggingface = ["transformers"]
|
||||
ollama = ["transformers"]
|
||||
mistral = ["mistral-common"]
|
||||
anthropic = ["anthropic"]
|
||||
deepeval = ["deepeval"]
|
||||
posthog = ["posthog"]
|
||||
falkordb = ["falkordb"]
|
||||
kuzu = ["kuzu"]
|
||||
groq = ["groq"]
|
||||
milvus = ["pymilvus"]
|
||||
chromadb = ["chromadb"]
|
||||
docs = ["unstructured"]
|
||||
codegraph = ["fastembed", "transformers", "tree-sitter", "tree-sitter-python"]
|
||||
evals = ["plotly", "gdown"]
|
||||
gui = ["pyside6", "qasync"]
|
||||
graphiti = ["graphiti-core"]
|
||||
[build-system]
|
||||
requires = ["hatchling"]
|
||||
build-backend = "hatchling.build"
|
||||
|
||||
[tool.poetry.group.dev.dependencies]
|
||||
pytest = "^7.4.0"
|
||||
pytest-asyncio = "^0.21.1"
|
||||
coverage = "^7.3.2"
|
||||
mypy = "^1.7.1"
|
||||
notebook = {version = "^7.1.0", optional = true}
|
||||
deptry = "^0.20.0"
|
||||
debugpy = "1.8.9"
|
||||
pylint = "^3.0.3"
|
||||
ruff = ">=0.9.2,<1.0.0"
|
||||
tweepy = "4.14.0"
|
||||
gitpython = "^3.1.43"
|
||||
pylance = "0.19.2"
|
||||
|
||||
[tool.poetry.group.docs.dependencies]
|
||||
mkdocs-material = "^9.5.42"
|
||||
mkdocs-minify-plugin = "^0.8.0"
|
||||
mkdocstrings = {extras = ["python"], version = "^0.26.2"}
|
||||
|
||||
[tool.ruff] # https://beta.ruff.rs/docs/
|
||||
[tool.ruff]
|
||||
line-length = 100
|
||||
exclude = [
|
||||
"migrations/", # Ignore migrations directory
|
||||
|
|
@ -151,7 +151,3 @@ exclude = [
|
|||
|
||||
[tool.ruff.lint]
|
||||
ignore = ["F401"]
|
||||
|
||||
[build-system]
|
||||
requires = ["poetry-core"]
|
||||
build-backend = "poetry.core.masonry.api"
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue