Merge branch 'main' of github.com:topoteretes/cognee-private into COG-502-backend-error-handling

commit 04960eeb4e
33 changed files with 864 additions and 261 deletions
66  .github/workflows/reusable_notebook.yml (vendored, new file)

@@ -0,0 +1,66 @@
+name: test-notebook
+
+on:
+  workflow_call:
+    inputs:
+      notebook-location:
+        description: "Location of Jupyter notebook to run"
+        required: true
+        type: string
+    secrets:
+      GRAPHISTRY_USERNAME:
+        required: true
+      GRAPHISTRY_PASSWORD:
+        required: true
+      OPENAI_API_KEY:
+        required: true
+
+env:
+  RUNTIME__LOG_LEVEL: ERROR
+
+jobs:
+  get_docs_changes:
+    name: docs changes
+    uses: ./.github/workflows/get_docs_changes.yml
+
+  run_notebook_test:
+    name: test
+    needs: get_docs_changes
+    if: needs.get_docs_changes.outputs.changes_outside_docs == 'true' && ${{ github.event.label.name == 'run-checks' }}
+    runs-on: ubuntu-latest
+    defaults:
+      run:
+        shell: bash
+    steps:
+      - name: Check out
+        uses: actions/checkout@master
+
+      - name: Setup Python
+        uses: actions/setup-python@v5
+        with:
+          python-version: '3.11.x'
+
+      - name: Install Poetry
+        uses: snok/install-poetry@v1.3.2
+        with:
+          virtualenvs-create: true
+          virtualenvs-in-project: true
+          installer-parallel: true
+
+      - name: Install dependencies
+        run: |
+          poetry install --no-interaction --all-extras
+          poetry add jupyter --no-interaction
+
+      - name: Execute Jupyter Notebook
+        env:
+          ENV: 'dev'
+          LLM_API_KEY: ${{ secrets.OPENAI_API_KEY }}
+          GRAPHISTRY_USERNAME: ${{ secrets.GRAPHISTRY_USERNAME }}
+          GRAPHISTRY_PASSWORD: ${{ secrets.GRAPHISTRY_PASSWORD }}
+        run: |
+          poetry run jupyter nbconvert \
+            --to notebook \
+            --execute ${{ inputs.notebook-location }} \
+            --output executed_notebook.ipynb \
+            --ExecutePreprocessor.timeout=1200
60  .github/workflows/reusable_python_example.yml (vendored, new file)

@@ -0,0 +1,60 @@
+name: test-example
+
+on:
+  workflow_call:
+    inputs:
+      example-location:
+        description: "Location of example script to run"
+        required: true
+        type: string
+    secrets:
+      GRAPHISTRY_USERNAME:
+        required: true
+      GRAPHISTRY_PASSWORD:
+        required: true
+      OPENAI_API_KEY:
+        required: true
+
+env:
+  RUNTIME__LOG_LEVEL: ERROR
+
+jobs:
+  get_docs_changes:
+    name: docs changes
+    uses: ./.github/workflows/get_docs_changes.yml
+
+  run_notebook_test:
+    name: test
+    needs: get_docs_changes
+    if: needs.get_docs_changes.outputs.changes_outside_docs == 'true' && ${{ github.event.label.name == 'run-checks' }}
+    runs-on: ubuntu-latest
+    defaults:
+      run:
+        shell: bash
+    steps:
+      - name: Check out
+        uses: actions/checkout@master
+
+      - name: Setup Python
+        uses: actions/setup-python@v5
+        with:
+          python-version: '3.11.x'
+
+      - name: Install Poetry
+        uses: snok/install-poetry@v1.3.2
+        with:
+          virtualenvs-create: true
+          virtualenvs-in-project: true
+          installer-parallel: true
+
+      - name: Install dependencies
+        run: |
+          poetry install --no-interaction --all-extras
+
+      - name: Execute Python Example
+        env:
+          ENV: 'dev'
+          LLM_API_KEY: ${{ secrets.OPENAI_API_KEY }}
+          GRAPHISTRY_USERNAME: ${{ secrets.GRAPHISTRY_USERNAME }}
+          GRAPHISTRY_PASSWORD: ${{ secrets.GRAPHISTRY_PASSWORD }}
+        run: poetry run python ${{ inputs.example-location }}
@@ -7,57 +7,16 @@ on:
       - main
     types: [labeled, synchronize]


 concurrency:
   group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}
   cancel-in-progress: true

-env:
-  RUNTIME__LOG_LEVEL: ERROR
-
 jobs:
-  get_docs_changes:
-    name: docs changes
-    uses: ./.github/workflows/get_docs_changes.yml
-
   run_notebook_test:
-    name: test
-    needs: get_docs_changes
-    if: needs.get_docs_changes.outputs.changes_outside_docs == 'true' && github.event.label.name == 'run-checks'
-    runs-on: ubuntu-latest
-    defaults:
-      run:
-        shell: bash
-    steps:
-      - name: Check out
-        uses: actions/checkout@master
-
-      - name: Setup Python
-        uses: actions/setup-python@v5
-        with:
-          python-version: '3.11.x'
-
-      - name: Install Poetry
-        uses: snok/install-poetry@v1.3.2
-        with:
-          virtualenvs-create: true
-          virtualenvs-in-project: true
-          installer-parallel: true
-
-      - name: Install dependencies
-        run: |
-          poetry install --no-interaction --all-extras
-          poetry add jupyter --no-interaction
-
-      - name: Execute Jupyter Notebook
-        env:
-          ENV: 'dev'
-          LLM_API_KEY: ${{ secrets.OPENAI_API_KEY }}
-          GRAPHISTRY_USERNAME: ${{ secrets.GRAPHISTRY_USERNAME }}
-          GRAPHISTRY_PASSWORD: ${{ secrets.GRAPHISTRY_PASSWORD }}
-        run: |
-          poetry run jupyter nbconvert \
-            --to notebook \
-            --execute notebooks/cognee_llama_index.ipynb \
-            --output executed_notebook.ipynb \
-            --ExecutePreprocessor.timeout=1200
+    uses: ./.github/workflows/reusable_notebook.yml
+    with:
+      notebook-location: notebooks/cognee_llama_index.ipynb
+    secrets:
+      OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
+      GRAPHISTRY_USERNAME: ${{ secrets.GRAPHISTRY_USERNAME }}
+      GRAPHISTRY_PASSWORD: ${{ secrets.GRAPHISTRY_PASSWORD }}
@@ -7,57 +7,16 @@ on:
       - main
     types: [labeled, synchronize]


 concurrency:
   group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}
   cancel-in-progress: true

-env:
-  RUNTIME__LOG_LEVEL: ERROR
-
 jobs:
-  get_docs_changes:
-    name: docs changes
-    uses: ./.github/workflows/get_docs_changes.yml
-
   run_notebook_test:
-    name: test
-    needs: get_docs_changes
-    if: needs.get_docs_changes.outputs.changes_outside_docs == 'true' && ${{ github.event.label.name == 'run-checks' }}
-    runs-on: ubuntu-latest
-    defaults:
-      run:
-        shell: bash
-    steps:
-      - name: Check out
-        uses: actions/checkout@master
-
-      - name: Setup Python
-        uses: actions/setup-python@v5
-        with:
-          python-version: '3.11.x'
-
-      - name: Install Poetry
-        uses: snok/install-poetry@v1.3.2
-        with:
-          virtualenvs-create: true
-          virtualenvs-in-project: true
-          installer-parallel: true
-
-      - name: Install dependencies
-        run: |
-          poetry install --no-interaction
-          poetry add jupyter --no-interaction
-
-      - name: Execute Jupyter Notebook
-        env:
-          ENV: 'dev'
-          LLM_API_KEY: ${{ secrets.OPENAI_API_KEY }}
-          GRAPHISTRY_USERNAME: ${{ secrets.GRAPHISTRY_USERNAME }}
-          GRAPHISTRY_PASSWORD: ${{ secrets.GRAPHISTRY_PASSWORD }}
-        run: |
-          poetry run jupyter nbconvert \
-            --to notebook \
-            --execute notebooks/cognee_multimedia_demo.ipynb \
-            --output executed_notebook.ipynb \
-            --ExecutePreprocessor.timeout=1200
+    uses: ./.github/workflows/reusable_notebook.yml
+    with:
+      notebook-location: notebooks/cognee_multimedia_demo.ipynb
+    secrets:
+      OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
+      GRAPHISTRY_USERNAME: ${{ secrets.GRAPHISTRY_USERNAME }}
+      GRAPHISTRY_PASSWORD: ${{ secrets.GRAPHISTRY_PASSWORD }}
23  .github/workflows/test_dynamic_steps_example.yml (vendored, new file)

@@ -0,0 +1,23 @@
+name: test | dynamic steps example
+
+on:
+  workflow_dispatch:
+  pull_request:
+    branches:
+      - main
+    types: [labeled, synchronize]
+
+
+concurrency:
+  group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}
+  cancel-in-progress: true
+
+jobs:
+  run_dynamic_steps_example_test:
+    uses: ./.github/workflows/reusable_python_example.yml
+    with:
+      example-location: ./examples/python/dynamic_steps_example.py
+    secrets:
+      OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
+      GRAPHISTRY_USERNAME: ${{ secrets.GRAPHISTRY_USERNAME }}
+      GRAPHISTRY_PASSWORD: ${{ secrets.GRAPHISTRY_PASSWORD }}
23  .github/workflows/test_multimedia_example.yaml (vendored, new file)

@@ -0,0 +1,23 @@
+name: test | multimedia example
+
+on:
+  workflow_dispatch:
+  pull_request:
+    branches:
+      - main
+    types: [labeled, synchronize]
+
+
+concurrency:
+  group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}
+  cancel-in-progress: true
+
+jobs:
+  run_multimedia_example_test:
+    uses: ./.github/workflows/reusable_python_example.yml
+    with:
+      example-location: ./examples/python/multimedia_example.py
+    secrets:
+      OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
+      GRAPHISTRY_USERNAME: ${{ secrets.GRAPHISTRY_USERNAME }}
+      GRAPHISTRY_PASSWORD: ${{ secrets.GRAPHISTRY_PASSWORD }}
54  .github/workflows/test_notebook.yml (vendored)

@@ -12,52 +12,12 @@ concurrency:
   group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}
   cancel-in-progress: true

-env:
-  RUNTIME__LOG_LEVEL: ERROR
-
 jobs:
-  get_docs_changes:
-    name: docs changes
-    uses: ./.github/workflows/get_docs_changes.yml
-
   run_notebook_test:
-    name: test
-    needs: get_docs_changes
-    if: needs.get_docs_changes.outputs.changes_outside_docs == 'true' && ${{ github.event.label.name == 'run-checks' }}
-    runs-on: ubuntu-latest
-    defaults:
-      run:
-        shell: bash
-    steps:
-      - name: Check out
-        uses: actions/checkout@master
-
-      - name: Setup Python
-        uses: actions/setup-python@v5
-        with:
-          python-version: '3.11.x'
-
-      - name: Install Poetry
-        uses: snok/install-poetry@v1.3.2
-        with:
-          virtualenvs-create: true
-          virtualenvs-in-project: true
-          installer-parallel: true
-
-      - name: Install dependencies
-        run: |
-          poetry install --no-interaction
-          poetry add jupyter --no-interaction
-
-      - name: Execute Jupyter Notebook
-        env:
-          ENV: 'dev'
-          LLM_API_KEY: ${{ secrets.OPENAI_API_KEY }}
-          GRAPHISTRY_USERNAME: ${{ secrets.GRAPHISTRY_USERNAME }}
-          GRAPHISTRY_PASSWORD: ${{ secrets.GRAPHISTRY_PASSWORD }}
-        run: |
-          poetry run jupyter nbconvert \
-            --to notebook \
-            --execute notebooks/cognee_demo.ipynb \
-            --output executed_notebook.ipynb \
-            --ExecutePreprocessor.timeout=1200
+    uses: ./.github/workflows/reusable_notebook.yml
+    with:
+      notebook-location: notebooks/cognee_demo.ipynb
+    secrets:
+      OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
+      GRAPHISTRY_USERNAME: ${{ secrets.GRAPHISTRY_USERNAME }}
+      GRAPHISTRY_PASSWORD: ${{ secrets.GRAPHISTRY_PASSWORD }}
23  .github/workflows/test_simple_example.yml (vendored, new file)

@@ -0,0 +1,23 @@
+name: test | simple example
+
+on:
+  workflow_dispatch:
+  pull_request:
+    branches:
+      - main
+    types: [labeled, synchronize]
+
+
+concurrency:
+  group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}
+  cancel-in-progress: true
+
+jobs:
+  run_simple_example_test:
+    uses: ./.github/workflows/reusable_python_example.yml
+    with:
+      example-location: ./examples/python/simple_example.py
+    secrets:
+      OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
+      GRAPHISTRY_USERNAME: ${{ secrets.GRAPHISTRY_USERNAME }}
+      GRAPHISTRY_PASSWORD: ${{ secrets.GRAPHISTRY_PASSWORD }}
@@ -1,11 +1,28 @@
 """Factory function to get the appropriate graph client based on the graph type."""
+from functools import lru_cache
+
 from .config import get_graph_config
 from .graph_db_interface import GraphDBInterface


-async def get_graph_engine() -> GraphDBInterface :
+async def get_graph_engine() -> GraphDBInterface:
     """Factory function to get the appropriate graph client based on the graph type."""
+    graph_client = create_graph_engine()
+
+    # Async functions can't be cached. After creating and caching the graph engine,
+    # handle all necessary async operations for the different graph types below.
     config = get_graph_config()

+    # Handle loading of the graph for NetworkX
+    if config.graph_database_provider.lower() == "networkx" and graph_client.graph is None:
+        await graph_client.load_graph_from_file()
+
+    return graph_client
+
+
+@lru_cache
+def create_graph_engine() -> GraphDBInterface:
+    """Factory function to create the appropriate graph client based on the graph type."""
+    config = get_graph_config()
+
     if config.graph_database_provider == "neo4j":

@@ -15,9 +32,9 @@ async def get_graph_engine() -> GraphDBInterface:
         from .neo4j_driver.adapter import Neo4jAdapter

         return Neo4jAdapter(
-            graph_database_url = config.graph_database_url,
-            graph_database_username = config.graph_database_username,
-            graph_database_password = config.graph_database_password
+            graph_database_url=config.graph_database_url,
+            graph_database_username=config.graph_database_username,
+            graph_database_password=config.graph_database_password
         )

     elif config.graph_database_provider == "falkordb":

@@ -30,15 +47,12 @@ async def get_graph_engine() -> GraphDBInterface:
         embedding_engine = get_embedding_engine()

         return FalkorDBAdapter(
-            database_url = config.graph_database_url,
-            database_port = config.graph_database_port,
-            embedding_engine = embedding_engine,
+            database_url=config.graph_database_url,
+            database_port=config.graph_database_port,
+            embedding_engine=embedding_engine,
         )

     from .networkx.adapter import NetworkXAdapter
-    graph_client = NetworkXAdapter(filename = config.graph_file_path)
-
-    if graph_client.graph is None:
-        await graph_client.load_graph_from_file()
+    graph_client = NetworkXAdapter(filename=config.graph_file_path)

     return graph_client
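The split above (a synchronous create_graph_engine cached with lru_cache, wrapped by an async get_graph_engine that performs the awaitable setup) exists because functools.lru_cache cannot usefully wrap a coroutine function: every call returns a new coroutine object, and a cached coroutine can only be awaited once. A minimal standalone sketch of the same pattern; the names here are illustrative, not from the codebase:

    import asyncio
    from functools import lru_cache

    class Engine:
        def __init__(self):
            self.loaded = False

    @lru_cache
    def create_engine() -> Engine:
        # Cheap synchronous construction: safe to cache; always returns the same instance.
        return Engine()

    async def get_engine() -> Engine:
        engine = create_engine()
        if not engine.loaded:
            await asyncio.sleep(0)  # stand-in for async I/O, e.g. loading a graph from file
            engine.loaded = True
        return engine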
@@ -2,7 +2,7 @@
 import logging
 import asyncio
 from textwrap import dedent
-from typing import Optional, Any, List, Dict
+from typing import Optional, Any, List, Dict, Union
 from contextlib import asynccontextmanager
 from uuid import UUID
 from neo4j import AsyncSession

@@ -432,3 +432,49 @@ class Neo4jAdapter(GraphDBInterface):
         ) for record in result]

         return (nodes, edges)
+
+    async def get_filtered_graph_data(self, attribute_filters):
+        """
+        Fetches nodes and relationships filtered by specified attribute values.
+
+        Args:
+            attribute_filters (list of dict): A list of dictionaries where keys are
+                attributes and values are lists of values to filter on.
+                Example: [{"community": ["1", "2"]}]
+
+        Returns:
+            tuple: A tuple containing two lists: nodes and edges.
+        """
+        where_clauses = []
+        for attribute, values in attribute_filters[0].items():
+            values_str = ", ".join(f"'{value}'" if isinstance(value, str) else str(value) for value in values)
+            where_clauses.append(f"n.{attribute} IN [{values_str}]")
+
+        where_clause = " AND ".join(where_clauses)
+
+        query_nodes = f"""
+        MATCH (n)
+        WHERE {where_clause}
+        RETURN ID(n) AS id, labels(n) AS labels, properties(n) AS properties
+        """
+        result_nodes = await self.query(query_nodes)
+
+        nodes = [(
+            record["id"],
+            record["properties"],
+        ) for record in result_nodes]
+
+        query_edges = f"""
+        MATCH (n)-[r]->(m)
+        WHERE {where_clause} AND {where_clause.replace('n.', 'm.')}
+        RETURN ID(n) AS source, ID(m) AS target, TYPE(r) AS type, properties(r) AS properties
+        """
+        result_edges = await self.query(query_edges)
+
+        edges = [(
+            record["source"],
+            record["target"],
+            record["type"],
+            record["properties"],
+        ) for record in result_edges]
+
+        return (nodes, edges)
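Note that the filter above splices values directly into the Cypher string. A hedged alternative sketch using driver-level query parameters instead; this assumes direct access to a neo4j AsyncSession, and whether the adapter's own self.query helper forwards parameters is not shown in this diff:

    # Sketch only: the same node filter expressed with Cypher parameters.
    result = await session.run(
        "MATCH (n) WHERE n.community IN $values "
        "RETURN ID(n) AS id, labels(n) AS labels, properties(n) AS properties",
        values=["1", "2"],
    )
    records = [record async for record in result]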
@@ -6,7 +6,7 @@ import json
 import asyncio
 import logging
 from re import A
-from typing import Dict, Any, List
+from typing import Dict, Any, List, Union
 from uuid import UUID
 import aiofiles
 import aiofiles.os as aiofiles_os

@@ -301,3 +301,39 @@ class NetworkXAdapter(GraphDBInterface):
             logger.info("Graph deleted successfully.")
         except Exception as error:
             logger.error("Failed to delete graph: %s", error)
+
+    async def get_filtered_graph_data(self, attribute_filters: List[Dict[str, List[Union[str, int]]]]):
+        """
+        Fetches nodes and relationships filtered by specified attribute values.
+
+        Args:
+            attribute_filters (list of dict): A list of dictionaries where keys are
+                attributes and values are lists of values to filter on.
+                Example: [{"community": ["1", "2"]}]
+
+        Returns:
+            tuple: A tuple containing two lists:
+                - Nodes: List of tuples (node_id, node_properties).
+                - Edges: List of tuples (source_id, target_id, relationship_type, edge_properties).
+        """
+        # Create filters for nodes based on the attribute filters
+        where_clauses = []
+        for attribute, values in attribute_filters[0].items():
+            where_clauses.append((attribute, values))
+
+        # Filter nodes
+        filtered_nodes = [
+            (node, data) for node, data in self.graph.nodes(data=True)
+            if all(data.get(attr) in values for attr, values in where_clauses)
+        ]
+
+        # Filter edges where both source and target nodes satisfy the filters
+        filtered_edges = [
+            (source, target, data.get('relationship_type', 'UNKNOWN'), data)
+            for source, target, data in self.graph.edges(data=True)
+            if (
+                all(self.graph.nodes[source].get(attr) in values for attr, values in where_clauses) and
+                all(self.graph.nodes[target].get(attr) in values for attr, values in where_clauses)
+            )
+        ]
+
+        return filtered_nodes, filtered_edges
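Both adapters now expose the same call shape: a single-element list mapping attribute names to the values to keep. A usage sketch, where adapter stands for any configured graph engine instance:

    # Keep only nodes whose "community" attribute is "1" or "2",
    # plus the edges whose endpoints both survive the filter.
    nodes, edges = await adapter.get_filtered_graph_data(
        attribute_filters=[{"community": ["1", "2"]}]
    )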
@@ -1,7 +1,9 @@
+from functools import lru_cache
+
 from .config import get_relational_config
 from .create_relational_engine import create_relational_engine

+@lru_cache
 def get_relational_engine():
     relational_config = get_relational_config()
@@ -172,6 +172,27 @@ class SQLAlchemyAdapter():
         results = await connection.execute(query)
         return {result["data_id"]: result["status"] for result in results}

+    async def get_all_data_from_table(self, table_name: str, schema: str = "public"):
+        async with self.get_async_session() as session:
+            # Validate inputs to prevent SQL injection
+            if not table_name.isidentifier():
+                raise ValueError("Invalid table name")
+            if schema and not schema.isidentifier():
+                raise ValueError("Invalid schema name")
+
+            if self.engine.dialect.name == "sqlite":
+                table = await self.get_table(table_name)
+            else:
+                table = await self.get_table(table_name, schema)
+
+            # Query all data from the table
+            query = select(table)
+            result = await session.execute(query)
+
+            # Fetch all rows as a list of dictionaries
+            rows = result.mappings().all()
+            return rows
+
     async def execute_query(self, query):
         async with self.engine.begin() as connection:
             result = await connection.execute(text(query))

@@ -206,7 +227,6 @@ class SQLAlchemyAdapter():
         from cognee.infrastructure.files.storage import LocalStorage

         LocalStorage.remove(self.db_path)
-        self.db_path = None
     else:
         async with self.engine.begin() as connection:
             schema_list = await self.get_schema_list()
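A usage sketch of the new helper, assuming an initialized relational engine; result.mappings().all() means each row comes back as a dict-like mapping keyed by column name. The schema and column names below are illustrative, following the file_metadata table populated later in this commit:

    db_engine = get_relational_engine()
    rows = await db_engine.get_all_data_from_table("file_metadata", schema="my_dataset")
    for row in rows:
        print(row["id"], row["name"])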
@@ -12,6 +12,7 @@ from cognee.infrastructure.files.storage import LocalStorage
 from cognee.modules.storage.utils import copy_model, get_own_properties
 from ..models.ScoredResult import ScoredResult
 from ..vector_db_interface import VectorDBInterface
+from ..utils import normalize_distances
 from ..embeddings.EmbeddingEngine import EmbeddingEngine

 class IndexSchema(DataPoint):

@@ -143,6 +144,33 @@ class LanceDBAdapter(VectorDBInterface):
             score = 0,
         ) for result in results.to_dict("index").values()]

+    async def get_distance_from_collection_elements(
+        self,
+        collection_name: str,
+        query_text: str = None,
+        query_vector: List[float] = None
+    ):
+        if query_text is None and query_vector is None:
+            raise ValueError("One of query_text or query_vector must be provided!")
+
+        if query_text and not query_vector:
+            query_vector = (await self.embedding_engine.embed_text([query_text]))[0]
+
+        connection = await self.get_connection()
+        collection = await connection.open_table(collection_name)
+
+        results = await collection.vector_search(query_vector).to_pandas()
+
+        result_values = list(results.to_dict("index").values())
+
+        normalized_values = normalize_distances(result_values)
+
+        return [ScoredResult(
+            id=UUID(result["id"]),
+            payload=result["payload"],
+            score=normalized_values[value_index],
+        ) for value_index, result in enumerate(result_values)]
+
     async def search(
         self,
         collection_name: str,

@@ -150,6 +178,7 @@ class LanceDBAdapter(VectorDBInterface):
         query_vector: List[float] = None,
         limit: int = 5,
         with_vector: bool = False,
+        normalized: bool = True
     ):
         if query_text is None and query_vector is None:
             raise InvalidValueError(message="One of query_text or query_vector must be provided!")

@@ -164,26 +193,7 @@ class LanceDBAdapter(VectorDBInterface):

         result_values = list(results.to_dict("index").values())

-        min_value = 100
-        max_value = 0
-
-        for result in result_values:
-            value = float(result["_distance"])
-            if value > max_value:
-                max_value = value
-            if value < min_value:
-                min_value = value
-
-        normalized_values = []
-        min_value = min(result["_distance"] for result in result_values)
-        max_value = max(result["_distance"] for result in result_values)
-
-        if max_value == min_value:
-            # Avoid division by zero: Assign all normalized values to 0 (or any constant value like 1)
-            normalized_values = [0 for _ in result_values]
-        else:
-            normalized_values = [(result["_distance"] - min_value) / (max_value - min_value) for result in
-                                 result_values]
+        normalized_values = normalize_distances(result_values)

         return [ScoredResult(
             id = UUID(result["id"]),
@@ -13,6 +13,7 @@ from cognee.infrastructure.engine import DataPoint
 from .serialize_data import serialize_data
 from ..models.ScoredResult import ScoredResult
 from ..vector_db_interface import VectorDBInterface
+from ..utils import normalize_distances
 from ..embeddings.EmbeddingEngine import EmbeddingEngine
 from ...relational.sqlalchemy.SqlAlchemyAdapter import SQLAlchemyAdapter
 from ...relational.ModelBase import Base

@@ -24,6 +25,19 @@ class IndexSchema(DataPoint):
         "index_fields": ["text"]
     }

+def singleton(class_):
+    # Note: Using this singleton as a decorator to a class removes
+    # the option to use class methods for that class
+    instances = {}
+
+    def getinstance(*args, **kwargs):
+        if class_ not in instances:
+            instances[class_] = class_(*args, **kwargs)
+        return instances[class_]
+
+    return getinstance
+
+@singleton
 class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):

     def __init__(

@@ -164,6 +178,51 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
             ) for result in results
         ]

+    async def get_distance_from_collection_elements(
+        self,
+        collection_name: str,
+        query_text: str = None,
+        query_vector: List[float] = None,
+        with_vector: bool = False
+    ) -> List[ScoredResult]:
+        if query_text is None and query_vector is None:
+            raise ValueError("One of query_text or query_vector must be provided!")
+
+        if query_text and not query_vector:
+            query_vector = (await self.embedding_engine.embed_text([query_text]))[0]
+
+        # Get PGVectorDataPoint Table from database
+        PGVectorDataPoint = await self.get_table(collection_name)
+
+        # Use async session to connect to the database
+        async with self.get_async_session() as session:
+            # Find closest vectors to query_vector
+            closest_items = await session.execute(
+                select(
+                    PGVectorDataPoint,
+                    PGVectorDataPoint.c.vector.cosine_distance(query_vector).label(
+                        "similarity"
+                    ),
+                )
+                .order_by("similarity")
+            )
+
+            vector_list = []
+
+            # Extract distances and find min/max for normalization
+            for vector in closest_items:
+                # TODO: Add normalization of similarity score
+                vector_list.append(vector)
+
+            # Create and return ScoredResult objects
+            return [
+                ScoredResult(
+                    id = UUID(str(row.id)),
+                    payload = row.payload,
+                    score = row.similarity
+                ) for row in vector_list
+            ]
+
     async def search(
         self,
         collection_name: str,
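The singleton decorator keys its cache on the class object alone, so constructor arguments from the first call win and later calls get the same instance back; as the inline note warns, the decorated name is rebound to a factory function, so classmethods (and isinstance checks against the decorated name) stop working through it. A standalone demonstration:

    def singleton(class_):
        instances = {}

        def getinstance(*args, **kwargs):
            if class_ not in instances:
                instances[class_] = class_(*args, **kwargs)
            return instances[class_]

        return getinstance

    @singleton
    class Config:
        def __init__(self, url):
            self.url = url

    a = Config("postgres://first")
    b = Config("postgres://second")  # arguments ignored: the cached instance is returned
    assert a is b and b.url == "postgres://first"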
@@ -1,12 +1,12 @@
-from ...relational.ModelBase import Base
 from ..get_vector_engine import get_vector_engine, get_vectordb_config
 from sqlalchemy import text


 async def create_db_and_tables():
     vector_config = get_vectordb_config()
     vector_engine = get_vector_engine()

     if vector_config.vector_db_provider == "pgvector":
-        await vector_engine.create_database()
         async with vector_engine.engine.begin() as connection:
             await connection.execute(text("CREATE EXTENSION IF NOT EXISTS vector;"))
@@ -143,6 +143,41 @@ class QDrantAdapter(VectorDBInterface):
         await client.close()
         return results

+    async def get_distance_from_collection_elements(
+        self,
+        collection_name: str,
+        query_text: str = None,
+        query_vector: List[float] = None,
+        with_vector: bool = False
+    ) -> List[ScoredResult]:
+        if query_text is None and query_vector is None:
+            raise ValueError("One of query_text or query_vector must be provided!")
+
+        client = self.get_qdrant_client()
+
+        results = await client.search(
+            collection_name = collection_name,
+            query_vector = models.NamedVector(
+                name = "text",
+                vector = query_vector if query_vector is not None else (await self.embed_data([query_text]))[0],
+            ),
+            with_vectors = with_vector
+        )
+
+        await client.close()
+
+        return [
+            ScoredResult(
+                id = UUID(result.id),
+                payload = {
+                    **result.payload,
+                    "id": UUID(result.id),
+                },
+                score = 1 - result.score,
+            ) for result in results
+        ]
+
     async def search(
         self,
         collection_name: str,
16  cognee/infrastructure/databases/vector/utils.py (new file)

@@ -0,0 +1,16 @@
+from typing import List
+
+
+def normalize_distances(result_values: List[dict]) -> List[float]:
+    min_value = min(result["_distance"] for result in result_values)
+    max_value = max(result["_distance"] for result in result_values)
+
+    if max_value == min_value:
+        # Avoid division by zero: assign all normalized values to 0 (or any constant value like 1)
+        normalized_values = [0 for _ in result_values]
+    else:
+        normalized_values = [(result["_distance"] - min_value) / (max_value - min_value)
+                             for result in result_values]
+
+    return normalized_values
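Worked example: distances 2.0, 4.0 and 6.0 have min 2.0 and max 6.0, so each score becomes (d - 2.0) / 4.0. The import path follows this file's location:

    from cognee.infrastructure.databases.vector.utils import normalize_distances

    values = [{"_distance": 2.0}, {"_distance": 4.0}, {"_distance": 6.0}]
    print(normalize_distances(values))  # [0.0, 0.5, 1.0]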
@@ -154,6 +154,36 @@ class WeaviateAdapter(VectorDBInterface):
         return await future

+    async def get_distance_from_collection_elements(
+        self,
+        collection_name: str,
+        query_text: str = None,
+        query_vector: List[float] = None,
+        with_vector: bool = False
+    ) -> List[ScoredResult]:
+        import weaviate.classes as wvc
+
+        if query_text is None and query_vector is None:
+            raise ValueError("One of query_text or query_vector must be provided!")
+
+        if query_vector is None:
+            query_vector = (await self.embed_data([query_text]))[0]
+
+        search_result = self.get_collection(collection_name).query.hybrid(
+            query=None,
+            vector=query_vector,
+            include_vector=with_vector,
+            return_metadata=wvc.query.MetadataQuery(score=True),
+        )
+
+        return [
+            ScoredResult(
+                id=UUID(str(result.uuid)),
+                payload=result.properties,
+                score=1 - float(result.metadata.score)
+            ) for result in search_result.objects
+        ]
+
     async def search(
         self,
         collection_name: str,
@@ -1,3 +1,5 @@
+import numpy as np
+
 from typing import List, Dict, Union

 from cognee.exceptions import InvalidValueError

@@ -5,6 +7,8 @@ from cognee.modules.graph.exceptions import EntityNotFoundError, EntityAlreadyEx
 from cognee.infrastructure.databases.graph.graph_db_interface import GraphDBInterface
 from cognee.modules.graph.cognee_graph.CogneeGraphElements import Node, Edge
 from cognee.modules.graph.cognee_graph.CogneeAbstractGraph import CogneeAbstractGraph
+import heapq
+from graphistry import edges


 class CogneeGraph(CogneeAbstractGraph):

@@ -41,26 +45,33 @@ class CogneeGraph(CogneeAbstractGraph):
     def get_node(self, node_id: str) -> Node:
         return self.nodes.get(node_id, None)

-    def get_edges(self, node_id: str) -> List[Edge]:
+    def get_edges_from_node(self, node_id: str) -> List[Edge]:
         node = self.get_node(node_id)
         if node:
             return node.skeleton_edges
         else:
             raise EntityNotFoundError(message=f"Node with id {node_id} does not exist.")

+    def get_edges(self) -> List[Edge]:
+        return self.edges
+
     async def project_graph_from_db(self,
             adapter: Union[GraphDBInterface],
             node_properties_to_project: List[str],
             edge_properties_to_project: List[str],
             directed = True,
             node_dimension = 1,
-            edge_dimension = 1) -> None:
+            edge_dimension = 1,
+            memory_fragment_filter = []) -> None:

         if node_dimension < 1 or edge_dimension < 1:
             raise InvalidValueError(message="Dimensions must be positive integers")

         try:
-            nodes_data, edges_data = await adapter.get_graph_data()
+            if len(memory_fragment_filter) == 0:
+                nodes_data, edges_data = await adapter.get_graph_data()
+            else:
+                nodes_data, edges_data = await adapter.get_filtered_graph_data(attribute_filters = memory_fragment_filter)

             if not nodes_data:
                 raise EntityNotFoundError(message="No node data retrieved from the database.")

@@ -91,3 +102,81 @@ class CogneeGraph(CogneeAbstractGraph):
             print(f"Error projecting graph: {e}")
         except Exception as ex:
             print(f"Unexpected error: {ex}")
+
+    async def map_vector_distances_to_graph_nodes(self, node_distances) -> None:
+        for category, scored_results in node_distances.items():
+            for scored_result in scored_results:
+                node_id = str(scored_result.id)
+                score = scored_result.score
+                node = self.get_node(node_id)
+                if node:
+                    node.add_attribute("vector_distance", score)
+                else:
+                    print(f"Node with id {node_id} not found in the graph.")
+
+    # TODO: When we calculate edge embeddings in the vector db, change this similarly to the node mapping
+    async def map_vector_distances_to_graph_edges(self, vector_engine, query) -> None:
+        try:
+            # Step 1: Generate the query embedding
+            query_vector = await vector_engine.embed_data([query])
+            query_vector = query_vector[0]
+            if query_vector is None or len(query_vector) == 0:
+                raise ValueError("Failed to generate query embedding.")
+
+            # Step 2: Collect all unique relationship types
+            unique_relationship_types = set()
+            for edge in self.edges:
+                relationship_type = edge.attributes.get('relationship_type')
+                if relationship_type:
+                    unique_relationship_types.add(relationship_type)
+
+            # Step 3: Embed all unique relationship types
+            unique_relationship_types = list(unique_relationship_types)
+            relationship_type_embeddings = await vector_engine.embed_data(unique_relationship_types)
+
+            # Step 4: Map relationship types to their embeddings and calculate distances
+            embedding_map = {}
+            for relationship_type, embedding in zip(unique_relationship_types, relationship_type_embeddings):
+                edge_vector = np.array(embedding)
+
+                # Calculate cosine similarity
+                similarity = np.dot(query_vector, edge_vector) / (
+                    np.linalg.norm(query_vector) * np.linalg.norm(edge_vector)
+                )
+                distance = 1 - similarity
+
+                # Round the distance to 4 decimal places and store it
+                embedding_map[relationship_type] = round(distance, 4)
+
+            # Step 5: Assign precomputed distances to edges
+            for edge in self.edges:
+                relationship_type = edge.attributes.get('relationship_type')
+                if not relationship_type or relationship_type not in embedding_map:
+                    print(f"Edge {edge} has an unknown or missing relationship type.")
+                    continue
+
+                # Assign the precomputed distance
+                edge.attributes["vector_distance"] = embedding_map[relationship_type]
+
+        except Exception as ex:
+            print(f"Error mapping vector distances to edges: {ex}")
+
+    async def calculate_top_triplet_importances(self, k: int) -> List:
+        min_heap = []
+        for i, edge in enumerate(self.edges):
+            source_node = self.get_node(edge.node1.id)
+            target_node = self.get_node(edge.node2.id)
+
+            source_distance = source_node.attributes.get("vector_distance", 1) if source_node else 1
+            target_distance = target_node.attributes.get("vector_distance", 1) if target_node else 1
+            edge_distance = edge.attributes.get("vector_distance", 1)
+
+            total_distance = source_distance + target_distance + edge_distance
+
+            heapq.heappush(min_heap, (-total_distance, i, edge))
+            if len(min_heap) > k:
+                heapq.heappop(min_heap)
+
+        return [edge for _, _, edge in sorted(min_heap)]
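calculate_top_triplet_importances keeps the k triplets with the smallest combined node-plus-edge distance: pushing the negated total onto a min-heap and popping whenever the heap grows past k always evicts the most negative key, i.e. the largest distance, so only the k closest triplets survive. The bounded-heap idiom in isolation:

    import heapq

    def k_smallest(values, k):
        heap = []
        for i, value in enumerate(values):
            heapq.heappush(heap, (-value, i))  # negate so the largest value sits on top
            if len(heap) > k:
                heapq.heappop(heap)  # evict the current largest
        return sorted(-key for key, _ in heap)

    print(k_smallest([5, 1, 4, 2, 3], 2))  # [1, 2]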
@@ -1,5 +1,5 @@
 import numpy as np
-from typing import List, Dict, Optional, Any
+from typing import List, Dict, Optional, Any, Union

 from cognee.exceptions import InvalidValueError

@@ -24,6 +24,7 @@ class Node:
             raise InvalidValueError(message="Dimension must be a positive integer")
         self.id = node_id
         self.attributes = attributes if attributes is not None else {}
+        self.attributes["vector_distance"] = float('inf')
         self.skeleton_neighbours = []
         self.skeleton_edges = []
         self.status = np.ones(dimension, dtype=int)

@@ -58,6 +59,12 @@ class Node:
             raise InvalidValueError(message=f"Dimension {dimension} is out of range. Valid range is 0 to {len(self.status) - 1}.")
         return self.status[dimension] == 1

+    def add_attribute(self, key: str, value: Any) -> None:
+        self.attributes[key] = value
+
+    def get_attribute(self, key: str) -> Union[str, int, float]:
+        return self.attributes[key]
+
     def __repr__(self) -> str:
         return f"Node({self.id}, attributes={self.attributes})"

@@ -90,6 +97,7 @@ class Edge:
         self.node1 = node1
         self.node2 = node2
         self.attributes = attributes if attributes is not None else {}
+        self.attributes["vector_distance"] = float('inf')
         self.directed = directed
         self.status = np.ones(dimension, dtype=int)

@@ -98,6 +106,12 @@ class Edge:
             raise InvalidValueError(message=f"Dimension {dimension} is out of range. Valid range is 0 to {len(self.status) - 1}.")
         return self.status[dimension] == 1

+    def add_attribute(self, key: str, value: Any) -> None:
+        self.attributes[key] = value
+
+    def get_attribute(self, key: str) -> Union[str, int, float]:
+        return self.attributes[key]
+
     def __repr__(self) -> str:
         direction = "->" if self.directed else "--"
         return f"Edge({self.node1.id} {direction} {self.node2.id}, attributes={self.attributes})"
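Because both constructors now seed vector_distance with float('inf'), an element the vector search never scored sorts behind every scored one. A small sketch of the new accessors, assuming the constructor takes the node id first as the fields above suggest:

    node = Node("n1")
    print(node.get_attribute("vector_distance"))  # inf until a score is mapped

    node.add_attribute("vector_distance", 0.12)
    print(node.get_attribute("vector_distance"))  # 0.12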
0    cognee/modules/retrieval/__init__.py (new file)

150  cognee/modules/retrieval/brute_force_triplet_search.py (new file)
@@ -0,0 +1,150 @@
+import asyncio
+import logging
+from typing import List
+from cognee.modules.users.models import User
+from cognee.modules.users.methods import get_default_user
+from cognee.modules.graph.cognee_graph.CogneeGraph import CogneeGraph
+from cognee.infrastructure.databases.vector import get_vector_engine
+from cognee.infrastructure.databases.graph import get_graph_engine
+from cognee.shared.utils import send_telemetry
+
+
+def format_triplets(edges):
+    print("\n\n\n")
+
+    def filter_attributes(obj, attributes):
+        """Helper function to filter out non-None properties, including nested dicts."""
+        result = {}
+        for attr in attributes:
+            value = getattr(obj, attr, None)
+            if value is not None:
+                # If the value is a dict, extract relevant keys from it
+                if isinstance(value, dict):
+                    nested_values = {k: v for k, v in value.items() if k in attributes and v is not None}
+                    result[attr] = nested_values
+                else:
+                    result[attr] = value
+        return result
+
+    triplets = []
+    for edge in edges:
+        node1 = edge.node1
+        node2 = edge.node2
+        edge_attributes = edge.attributes
+        node1_attributes = node1.attributes
+        node2_attributes = node2.attributes
+
+        # Filter only non-None properties
+        node1_info = {key: value for key, value in node1_attributes.items() if value is not None}
+        node2_info = {key: value for key, value in node2_attributes.items() if value is not None}
+        edge_info = {key: value for key, value in edge_attributes.items() if value is not None}
+
+        # Create the formatted triplet
+        triplet = (
+            f"Node1: {node1_info}\n"
+            f"Edge: {edge_info}\n"
+            f"Node2: {node2_info}\n\n\n"
+        )
+        triplets.append(triplet)
+
+    return "".join(triplets)
+
+
+async def brute_force_triplet_search(query: str, user: User = None, top_k = 5) -> list:
+    if user is None:
+        user = await get_default_user()
+
+    if user is None:
+        raise PermissionError("No user found in the system. Please create a user.")
+
+    retrieved_results = await brute_force_search(query, user, top_k)
+
+    return retrieved_results
+
+
+# TODO: This is just for now, to work around vector db duplicates
+def delete_duplicated_vector_db_elements(collections, results):
+    results_dict = {}
+    for collection, results in zip(collections, results):
+        seen_ids = set()
+        unique_results = []
+        for result in results:
+            if result.id not in seen_ids:
+                unique_results.append(result)
+                seen_ids.add(result.id)
+            else:
+                print(f"Duplicate found in collection '{collection}': {result.id}")
+        results_dict[collection] = unique_results
+
+    return results_dict
+
+
+async def brute_force_search(
+    query: str,
+    user: User,
+    top_k: int,
+    collections: List[str] = None
+) -> list:
+    """
+    Performs a brute force search to retrieve the top triplets from the graph.
+
+    Args:
+        query (str): The search query.
+        user (User): The user performing the search.
+        top_k (int): The number of top results to retrieve.
+        collections (Optional[List[str]]): List of collections to query. Defaults to predefined collections.
+
+    Returns:
+        list: The top triplet results.
+    """
+    if not query or not isinstance(query, str):
+        raise ValueError("The query must be a non-empty string.")
+    if top_k <= 0:
+        raise ValueError("top_k must be a positive integer.")
+
+    if collections is None:
+        collections = ["entity_name", "text_summary_text", "entity_type_name", "document_chunk_text"]
+
+    try:
+        vector_engine = get_vector_engine()
+        graph_engine = await get_graph_engine()
+    except Exception as e:
+        logging.error("Failed to initialize engines: %s", e)
+        raise RuntimeError("Initialization error") from e
+
+    send_telemetry("cognee.brute_force_triplet_search EXECUTION STARTED", user.id)
+
+    try:
+        results = await asyncio.gather(
+            *[vector_engine.get_distance_from_collection_elements(collection, query_text=query) for collection in collections]
+        )
+
+        # TODO: Change when the vector db no longer contains duplicates
+        node_distances = delete_duplicated_vector_db_elements(collections, results)
+        # node_distances = {collection: result for collection, result in zip(collections, results)}
+
+        memory_fragment = CogneeGraph()
+
+        await memory_fragment.project_graph_from_db(
+            graph_engine,
+            node_properties_to_project=['id', 'description', 'name', 'type', 'text'],
+            edge_properties_to_project=['relationship_name'])
+
+        await memory_fragment.map_vector_distances_to_graph_nodes(node_distances=node_distances)
+
+        # TODO: Change when the vector db contains edge embeddings
+        await memory_fragment.map_vector_distances_to_graph_edges(vector_engine, query)
+
+        results = await memory_fragment.calculate_top_triplet_importances(k=top_k)
+
+        send_telemetry("cognee.brute_force_triplet_search EXECUTION COMPLETED", user.id)
+
+        # TODO: Once we have Edge pydantic models we should retrieve the exact edge and node objects from graph db
+        return results
+
+    except Exception as e:
+        logging.error("Error during brute force search for user: %s, query: %s. Error: %s", user.id, query, e)
+        send_telemetry("cognee.brute_force_triplet_search EXECUTION FAILED", user.id)
+        raise RuntimeError("An error occurred during brute force search") from e
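End to end, the retriever can be driven with just a query string; the default user and the default collection list are resolved internally. A usage sketch, with an illustrative query:

    import asyncio

    from cognee.modules.retrieval.brute_force_triplet_search import (
        brute_force_triplet_search,
        format_triplets,
    )

    async def main():
        triplets = await brute_force_triplet_search("What is stored in the graph?", top_k=5)
        print(format_triplets(triplets))

    asyncio.run(main())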
@@ -1,9 +1,12 @@
 import os
+from functools import lru_cache
+
 import dlt
 from typing import Union

 from cognee.infrastructure.databases.relational import get_relational_config

+@lru_cache
 def get_dlt_destination() -> Union[type[dlt.destinations.sqlalchemy], None]:
     """
     Handles propagation of the cognee database configuration to the dlt library
@ -1,6 +1,7 @@
|
||||||
import dlt
|
import dlt
|
||||||
import cognee.modules.ingestion as ingestion
|
import cognee.modules.ingestion as ingestion
|
||||||
|
|
||||||
|
from uuid import UUID
|
||||||
from cognee.shared.utils import send_telemetry
|
from cognee.shared.utils import send_telemetry
|
||||||
from cognee.modules.users.models import User
|
from cognee.modules.users.models import User
|
||||||
from cognee.infrastructure.databases.relational import get_relational_engine
|
from cognee.infrastructure.databases.relational import get_relational_engine
|
||||||
|
|
@@ -17,25 +18,33 @@ async def ingest_data(file_paths: list[str], dataset_name: str, user: User):
     )

     @dlt.resource(standalone = True, merge_key = "id")
-    async def data_resources(file_paths: str, user: User):
+    async def data_resources(file_paths: str):
         for file_path in file_paths:
             with open(file_path.replace("file://", ""), mode = "rb") as file:
                 classified_data = ingestion.classify(file)

                 data_id = ingestion.identify(classified_data)

                 file_metadata = classified_data.get_metadata()

+                yield {
+                    "id": data_id,
+                    "name": file_metadata["name"],
+                    "file_path": file_metadata["file_path"],
+                    "extension": file_metadata["extension"],
+                    "mime_type": file_metadata["mime_type"],
+                }

-    from sqlalchemy import select
-    from cognee.modules.data.models import Data
-
-    db_engine = get_relational_engine()
-
-    async with db_engine.get_async_session() as session:
+    async def data_storing(table_name, dataset_name, user: User):
+        db_engine = get_relational_engine()
+
+        async with db_engine.get_async_session() as session:
+            # Read metadata stored with dlt
+            files_metadata = await db_engine.get_all_data_from_table(table_name, dataset_name)
+            for file_metadata in files_metadata:
+                from sqlalchemy import select
+                from cognee.modules.data.models import Data

                 dataset = await create_dataset(dataset_name, user.id, session)

                 data = (await session.execute(
-                    select(Data).filter(Data.id == data_id)
+                    select(Data).filter(Data.id == UUID(file_metadata["id"]))
                 )).scalar_one_or_none()

                 if data is not None:
@@ -48,7 +57,7 @@ async def ingest_data(file_paths: list[str], dataset_name: str, user: User):
                     await session.commit()
                 else:
                     data = Data(
-                        id = data_id,
+                        id = UUID(file_metadata["id"]),
                         name = file_metadata["name"],
                         raw_data_location = file_metadata["file_path"],
                         extension = file_metadata["extension"],
@@ -58,25 +67,34 @@ async def ingest_data(file_paths: list[str], dataset_name: str, user: User):
                     dataset.data.append(data)
                     await session.commit()

-        yield {
-            "id": data_id,
-            "name": file_metadata["name"],
-            "file_path": file_metadata["file_path"],
-            "extension": file_metadata["extension"],
-            "mime_type": file_metadata["mime_type"],
-        }
-
-        await give_permission_on_document(user, data_id, "read")
-        await give_permission_on_document(user, data_id, "write")
+                await give_permission_on_document(user, UUID(file_metadata["id"]), "read")
+                await give_permission_on_document(user, UUID(file_metadata["id"]), "write")

     send_telemetry("cognee.add EXECUTION STARTED", user_id = user.id)

-    run_info = pipeline.run(
-        data_resources(file_paths, user),
-        table_name = "file_metadata",
-        dataset_name = dataset_name,
-        write_disposition = "merge",
-    )
+    db_engine = get_relational_engine()
+
+    # Note: the DLT pipeline has its own event loop, therefore objects created in another
+    # event loop can't be used inside the pipeline.
+    if db_engine.engine.dialect.name == "sqlite":
+        # To use sqlite with dlt, dataset_name must be set to "main".
+        # Sqlite doesn't support schemas.
+        run_info = pipeline.run(
+            data_resources(file_paths),
+            table_name = "file_metadata",
+            dataset_name = "main",
+            write_disposition = "merge",
+        )
+    else:
+        run_info = pipeline.run(
+            data_resources(file_paths),
+            table_name = "file_metadata",
+            dataset_name = dataset_name,
+            write_disposition = "merge",
+        )
+
+    await data_storing("file_metadata", dataset_name, user)

     send_telemetry("cognee.add EXECUTION COMPLETED", user_id = user.id)

     return run_info
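The sqlite branch above exists because SQLite has no schema support, so dlt's dataset has to land in the default "main" schema while other dialects can use the dataset name as a schema. A compact sketch of that dialect check using plain SQLAlchemy (the helper name is invented for illustration):

from sqlalchemy import create_engine

def resolve_dataset_name(engine, requested: str) -> str:
    # SQLite cannot create named schemas, so fall back to its default "main";
    # every other dialect keeps the caller's dataset name.
    return "main" if engine.dialect.name == "sqlite" else requested

engine = create_engine("sqlite:///:memory:")
assert resolve_dataset_name(engine, "my_dataset") == "main"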
@@ -4,6 +4,7 @@ import logging
 import pathlib
 import cognee
 from cognee.api.v1.search import SearchType
+from cognee.modules.retrieval.brute_force_triplet_search import brute_force_triplet_search

 logging.basicConfig(level=logging.DEBUG)
@@ -61,6 +62,9 @@ async def main():

     assert len(history) == 6, "Search history is not correct."

+    results = await brute_force_triplet_search('What is a quantum computer?')
+    assert len(results) > 0
+
     await cognee.prune.prune_data()
     assert not os.path.isdir(data_directory_path), "Local data files are not deleted"
@@ -3,6 +3,7 @@ import logging
 import pathlib
 import cognee
 from cognee.api.v1.search import SearchType
+from cognee.modules.retrieval.brute_force_triplet_search import brute_force_triplet_search

 logging.basicConfig(level=logging.DEBUG)
@@ -89,6 +90,9 @@ async def main():
     history = await cognee.get_search_history()
     assert len(history) == 6, "Search history is not correct."

+    results = await brute_force_triplet_search('What is a quantum computer?')
+    assert len(results) > 0
+
     await cognee.prune.prune_data()
     assert not os.path.isdir(data_directory_path), "Local data files are not deleted"
@@ -5,6 +5,7 @@ import logging
 import pathlib
 import cognee
 from cognee.api.v1.search import SearchType
+from cognee.modules.retrieval.brute_force_triplet_search import brute_force_triplet_search

 logging.basicConfig(level=logging.DEBUG)
@@ -61,6 +62,9 @@ async def main():
     history = await cognee.get_search_history()
     assert len(history) == 6, "Search history is not correct."

+    results = await brute_force_triplet_search('What is a quantum computer?')
+    assert len(results) > 0
+
     await cognee.prune.prune_data()
     assert not os.path.isdir(data_directory_path), "Local data files are not deleted"
@@ -3,6 +3,7 @@ import logging
 import pathlib
 import cognee
 from cognee.api.v1.search import SearchType
+from cognee.modules.retrieval.brute_force_triplet_search import brute_force_triplet_search

 logging.basicConfig(level=logging.DEBUG)
@@ -59,6 +60,9 @@ async def main():
     history = await cognee.get_search_history()
     assert len(history) == 6, "Search history is not correct."

+    results = await brute_force_triplet_search('What is a quantum computer?')
+    assert len(results) > 0
+
     await cognee.prune.prune_data()
     assert not os.path.isdir(data_directory_path), "Local data files are not deleted"
@@ -9,7 +9,7 @@ def test_node_initialization():
     """Test that a Node is initialized correctly."""
     node = Node("node1", {"attr1": "value1"}, dimension=2)
     assert node.id == "node1"
-    assert node.attributes == {"attr1": "value1"}
+    assert node.attributes == {"attr1": "value1", "vector_distance": np.inf}
     assert len(node.status) == 2
     assert np.all(node.status == 1)
@@ -96,7 +96,7 @@ def test_edge_initialization():
     edge = Edge(node1, node2, {"weight": 10}, directed=False, dimension=2)
     assert edge.node1 == node1
     assert edge.node2 == node2
-    assert edge.attributes == {"weight": 10}
+    assert edge.attributes == {"vector_distance": np.inf, "weight": 10}
     assert edge.directed is False
     assert len(edge.status) == 2
     assert np.all(edge.status == 1)
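Both test updates imply that Node and Edge now seed a vector_distance attribute of np.inf at construction, so elements that never receive a real distance sort behind anything that does. A minimal sketch of that pattern, assuming nothing about cognee's real classes beyond what the asserts show:

import numpy as np

class Node:
    # Illustrative stand-in for cognee's Node.
    def __init__(self, node_id: str, attributes: dict, dimension: int = 1):
        self.id = node_id
        # np.inf marks "no distance computed yet"; mapped distances overwrite it.
        self.attributes = {"vector_distance": np.inf, **attributes}
        self.status = np.ones(dimension, dtype=int)

node = Node("node1", {"attr1": "value1"}, dimension=2)
assert node.attributes == {"attr1": "value1", "vector_distance": np.inf}
assert len(node.status) == 2 and np.all(node.status == 1)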
@@ -1,6 +1,6 @@
 import pytest

-from cognee.exceptions import EntityNotFoundError, EntityAlreadyExistsError
+from cognee.modules.graph.exceptions import EntityNotFoundError, EntityAlreadyExistsError
 from cognee.modules.graph.cognee_graph.CogneeGraph import CogneeGraph
 from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge, Node
@@ -78,11 +78,11 @@ def test_get_edges_success(setup_graph):
     graph.add_node(node2)
     edge = Edge(node1, node2)
     graph.add_edge(edge)
-    assert edge in graph.get_edges("node1")
+    assert edge in graph.get_edges_from_node("node1")


 def test_get_edges_nonexistent_node(setup_graph):
     """Test retrieving edges for a nonexistent node raises an exception."""
     graph = setup_graph
     with pytest.raises(EntityNotFoundError, match="Node with id nonexistent does not exist."):
-        graph.get_edges("nonexistent")
+        graph.get_edges_from_node("nonexistent")
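The renamed method's contract, as these tests pin it down, is: return the node's edges, or raise EntityNotFoundError with exactly this message. A tiny stand-in showing that contract (the adjacency-map layout is an assumption, not CogneeGraph's internals):

class EntityNotFoundError(Exception):
    pass

class Graph:
    # Illustrative stand-in for CogneeGraph's edge lookup.
    def __init__(self):
        self.edges_by_node = {}  # node id -> list of edges (assumed layout)

    def get_edges_from_node(self, node_id: str):
        if node_id not in self.edges_by_node:
            raise EntityNotFoundError(f"Node with id {node_id} does not exist.")
        return self.edges_by_node[node_id]

graph = Graph()
graph.edges_by_node["node1"] = ["edge-1"]
assert graph.get_edges_from_node("node1") == ["edge-1"]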
@@ -46,7 +46,7 @@ services:
       - 7687:7687
     environment:
       - NEO4J_AUTH=neo4j/pleaseletmein
-      - NEO4J_PLUGINS=["apoc"]
+      - NEO4J_PLUGINS=["apoc", "graph-data-science"]
     networks:
      - cognee-network
@@ -1,32 +1,7 @@
 import cognee
 import asyncio
-from cognee.api.v1.search import SearchType
+from cognee.modules.retrieval.brute_force_triplet_search import brute_force_triplet_search
+from cognee.modules.retrieval.brute_force_triplet_search import format_triplets

-job_position = """0:Senior Data Scientist (Machine Learning)
-
-Company: TechNova Solutions
-Location: San Francisco, CA
-
-Job Description:
-
-TechNova Solutions is seeking a Senior Data Scientist specializing in Machine Learning to join our dynamic analytics team. The ideal candidate will have a strong background in developing and deploying machine learning models, working with large datasets, and translating complex data into actionable insights.
-
-Responsibilities:
-
-Develop and implement advanced machine learning algorithms and models.
-Analyze large, complex datasets to extract meaningful patterns and insights.
-Collaborate with cross-functional teams to integrate predictive models into products.
-Stay updated with the latest advancements in machine learning and data science.
-Mentor junior data scientists and provide technical guidance.
-Qualifications:
-
-Master’s or Ph.D. in Data Science, Computer Science, Statistics, or a related field.
-5+ years of experience in data science and machine learning.
-Proficient in Python, R, and SQL.
-Experience with deep learning frameworks (e.g., TensorFlow, PyTorch).
-Strong problem-solving skills and attention to detail.
-Candidate CVs
-"""
-
 job_1 = """
 CV 1: Relevant
@@ -195,7 +170,7 @@ async def main(enable_steps):

     # Step 2: Add text
     if enable_steps.get("add_text"):
-        text_list = [job_position, job_1, job_2, job_3, job_4, job_5]
+        text_list = [job_1, job_2, job_3, job_4, job_5]
         for text in text_list:
             await cognee.add(text)
             print(f"Added text: {text[:35]}...")
@@ -206,24 +181,21 @@ async def main(enable_steps):
         print("Knowledge graph created.")

     # Step 4: Query insights
-    if enable_steps.get("search_insights"):
-        search_results = await cognee.search(
-            SearchType.INSIGHTS,
-            {'query': 'Which applicant has the most relevant experience in data science?'}
-        )
-        print("Search results:")
-        for result_text in search_results:
-            print(result_text)
+    if enable_steps.get("retriever"):
+        results = await brute_force_triplet_search('Who has the most experience with graphic design?')
+        print(format_triplets(results))


 if __name__ == '__main__':
     # Flags to enable/disable steps
+    rebuild_kg = True
+    retrieve = True
     steps_to_enable = {
-        "prune_data": True,
-        "prune_system": True,
-        "add_text": True,
-        "cognify": True,
-        "search_insights": True
+        "prune_data": rebuild_kg,
+        "prune_system": rebuild_kg,
+        "add_text": rebuild_kg,
+        "cognify": rebuild_kg,
+        "retriever": retrieve
     }

     asyncio.run(main(steps_to_enable))
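format_triplets is imported above but implemented elsewhere in the repo. As a purely illustrative stand-in (not cognee's actual function), a formatter of this shape would produce the one-line-per-triplet output the retriever step prints:

def format_triplets_sketch(triplets) -> str:
    # Hypothetical formatter, NOT cognee's format_triplets: one line per triplet.
    return "\n".join(
        f"{source} -[{relationship}]-> {target}"
        for source, relationship, target in triplets
    )

print(format_triplets_sketch([("Alice", "knows", "Bob")]))
# Alice -[knows]-> Bob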