Merge branch 'main' of github.com:topoteretes/cognee-private into COG-502-backend-error-handling

This commit is contained in:
Igor Ilic 2024-12-02 13:12:20 +01:00
commit 04960eeb4e
33 changed files with 864 additions and 261 deletions

66
.github/workflows/reusable_notebook.yml vendored Normal file
View file

@ -0,0 +1,66 @@
# Reusable workflow: execute a Jupyter notebook end-to-end as a CI check.
# Callers pass the notebook path and the secrets the notebook needs.
name: test-notebook

on:
  workflow_call:
    inputs:
      notebook-location:
        description: "Location of Jupyter notebook to run"
        required: true
        type: string
    secrets:
      GRAPHISTRY_USERNAME:
        required: true
      GRAPHISTRY_PASSWORD:
        required: true
      OPENAI_API_KEY:
        required: true

env:
  RUNTIME__LOG_LEVEL: ERROR

jobs:
  get_docs_changes:
    name: docs changes
    uses: ./.github/workflows/get_docs_changes.yml

  run_notebook_test:
    name: test
    needs: get_docs_changes
    # `if:` is already an expression context. The original mixed a bare
    # expression with `${{ }}` (`... && ${{ github.event.label.name == 'run-checks' }}`),
    # which GitHub parses unpredictably; use one bare expression throughout
    # (matching the form the other workflows in this repo use).
    if: needs.get_docs_changes.outputs.changes_outside_docs == 'true' && github.event.label.name == 'run-checks'
    runs-on: ubuntu-latest
    defaults:
      run:
        shell: bash
    steps:
      - name: Check out
        # Pin to a released major tag; `@master` is a mutable ref and can
        # change underneath CI at any time.
        uses: actions/checkout@v4

      - name: Setup Python
        uses: actions/setup-python@v5
        with:
          python-version: '3.11.x'

      - name: Install Poetry
        uses: snok/install-poetry@v1.3.2
        with:
          virtualenvs-create: true
          virtualenvs-in-project: true
          installer-parallel: true

      - name: Install dependencies
        run: |
          poetry install --no-interaction --all-extras
          poetry add jupyter --no-interaction

      - name: Execute Jupyter Notebook
        env:
          ENV: 'dev'
          LLM_API_KEY: ${{ secrets.OPENAI_API_KEY }}
          GRAPHISTRY_USERNAME: ${{ secrets.GRAPHISTRY_USERNAME }}
          GRAPHISTRY_PASSWORD: ${{ secrets.GRAPHISTRY_PASSWORD }}
        run: |
          poetry run jupyter nbconvert \
            --to notebook \
            --execute ${{ inputs.notebook-location }} \
            --output executed_notebook.ipynb \
            --ExecutePreprocessor.timeout=1200

View file

@ -0,0 +1,60 @@
# Reusable workflow: run a Python example script as a CI check.
# Callers pass the script path and the secrets the example needs.
name: test-example

on:
  workflow_call:
    inputs:
      example-location:
        description: "Location of example script to run"
        required: true
        type: string
    secrets:
      GRAPHISTRY_USERNAME:
        required: true
      GRAPHISTRY_PASSWORD:
        required: true
      OPENAI_API_KEY:
        required: true

env:
  RUNTIME__LOG_LEVEL: ERROR

jobs:
  get_docs_changes:
    name: docs changes
    uses: ./.github/workflows/get_docs_changes.yml

  run_notebook_test:
    name: test
    needs: get_docs_changes
    # `if:` is already an expression context. The original mixed a bare
    # expression with `${{ }}`, which GitHub parses unpredictably; use one
    # bare expression throughout (matching the repo's other workflows).
    if: needs.get_docs_changes.outputs.changes_outside_docs == 'true' && github.event.label.name == 'run-checks'
    runs-on: ubuntu-latest
    defaults:
      run:
        shell: bash
    steps:
      - name: Check out
        # Pin to a released major tag; `@master` is a mutable ref.
        uses: actions/checkout@v4

      - name: Setup Python
        uses: actions/setup-python@v5
        with:
          python-version: '3.11.x'

      - name: Install Poetry
        uses: snok/install-poetry@v1.3.2
        with:
          virtualenvs-create: true
          virtualenvs-in-project: true
          installer-parallel: true

      - name: Install dependencies
        run: |
          poetry install --no-interaction --all-extras

      - name: Execute Python Example
        env:
          ENV: 'dev'
          LLM_API_KEY: ${{ secrets.OPENAI_API_KEY }}
          GRAPHISTRY_USERNAME: ${{ secrets.GRAPHISTRY_USERNAME }}
          GRAPHISTRY_PASSWORD: ${{ secrets.GRAPHISTRY_PASSWORD }}
        run: poetry run python ${{ inputs.example-location }}

View file

@ -7,57 +7,16 @@ on:
- main - main
types: [labeled, synchronize] types: [labeled, synchronize]
concurrency: concurrency:
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }} group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}
cancel-in-progress: true cancel-in-progress: true
env:
RUNTIME__LOG_LEVEL: ERROR
jobs: jobs:
get_docs_changes:
name: docs changes
uses: ./.github/workflows/get_docs_changes.yml
run_notebook_test: run_notebook_test:
name: test uses: ./.github/workflows/reusable_notebook.yml
needs: get_docs_changes with:
if: needs.get_docs_changes.outputs.changes_outside_docs == 'true' && github.event.label.name == 'run-checks' notebook-location: notebooks/cognee_llama_index.ipynb
runs-on: ubuntu-latest secrets:
defaults: OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
run: GRAPHISTRY_USERNAME: ${{ secrets.GRAPHISTRY_USERNAME }}
shell: bash GRAPHISTRY_PASSWORD: ${{ secrets.GRAPHISTRY_PASSWORD }}
steps:
- name: Check out
uses: actions/checkout@master
- name: Setup Python
uses: actions/setup-python@v5
with:
python-version: '3.11.x'
- name: Install Poetry
uses: snok/install-poetry@v1.3.2
with:
virtualenvs-create: true
virtualenvs-in-project: true
installer-parallel: true
- name: Install dependencies
run: |
poetry install --no-interaction --all-extras
poetry add jupyter --no-interaction
- name: Execute Jupyter Notebook
env:
ENV: 'dev'
LLM_API_KEY: ${{ secrets.OPENAI_API_KEY }}
GRAPHISTRY_USERNAME: ${{ secrets.GRAPHISTRY_USERNAME }}
GRAPHISTRY_PASSWORD: ${{ secrets.GRAPHISTRY_PASSWORD }}
run: |
poetry run jupyter nbconvert \
--to notebook \
--execute notebooks/cognee_llama_index.ipynb \
--output executed_notebook.ipynb \
--ExecutePreprocessor.timeout=1200

View file

@ -7,57 +7,16 @@ on:
- main - main
types: [labeled, synchronize] types: [labeled, synchronize]
concurrency: concurrency:
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }} group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}
cancel-in-progress: true cancel-in-progress: true
env:
RUNTIME__LOG_LEVEL: ERROR
jobs: jobs:
get_docs_changes:
name: docs changes
uses: ./.github/workflows/get_docs_changes.yml
run_notebook_test: run_notebook_test:
name: test uses: ./.github/workflows/reusable_notebook.yml
needs: get_docs_changes with:
if: needs.get_docs_changes.outputs.changes_outside_docs == 'true' && ${{ github.event.label.name == 'run-checks' }} notebook-location: notebooks/cognee_multimedia_demo.ipynb
runs-on: ubuntu-latest secrets:
defaults: OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
run: GRAPHISTRY_USERNAME: ${{ secrets.GRAPHISTRY_USERNAME }}
shell: bash GRAPHISTRY_PASSWORD: ${{ secrets.GRAPHISTRY_PASSWORD }}
steps:
- name: Check out
uses: actions/checkout@master
- name: Setup Python
uses: actions/setup-python@v5
with:
python-version: '3.11.x'
- name: Install Poetry
uses: snok/install-poetry@v1.3.2
with:
virtualenvs-create: true
virtualenvs-in-project: true
installer-parallel: true
- name: Install dependencies
run: |
poetry install --no-interaction
poetry add jupyter --no-interaction
- name: Execute Jupyter Notebook
env:
ENV: 'dev'
LLM_API_KEY: ${{ secrets.OPENAI_API_KEY }}
GRAPHISTRY_USERNAME: ${{ secrets.GRAPHISTRY_USERNAME }}
GRAPHISTRY_PASSWORD: ${{ secrets.GRAPHISTRY_PASSWORD }}
run: |
poetry run jupyter nbconvert \
--to notebook \
--execute notebooks/cognee_multimedia_demo.ipynb \
--output executed_notebook.ipynb \
--ExecutePreprocessor.timeout=1200

View file

@ -0,0 +1,23 @@
# CI entry point: runs examples/python/dynamic_steps_example.py via the
# shared reusable workflow (which itself gates on docs-only changes and the
# 'run-checks' label).
name: test | dynamic steps example

on:
  # Allow manual runs from the Actions tab.
  workflow_dispatch:
  pull_request:
    branches:
      - main
    # Fire when a label is added or new commits are pushed to the PR.
    types: [labeled, synchronize]

# One run at a time per workflow + PR (or ref); newer runs cancel older ones.
concurrency:
  group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}
  cancel-in-progress: true

jobs:
  run_dynamic_steps_example_test:
    uses: ./.github/workflows/reusable_python_example.yml
    with:
      example-location: ./examples/python/dynamic_steps_example.py
    # Secrets must be forwarded explicitly to a called reusable workflow.
    secrets:
      OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
      GRAPHISTRY_USERNAME: ${{ secrets.GRAPHISTRY_USERNAME }}
      GRAPHISTRY_PASSWORD: ${{ secrets.GRAPHISTRY_PASSWORD }}

View file

@ -0,0 +1,23 @@
# CI entry point: runs examples/python/multimedia_example.py via the shared
# reusable workflow (which itself gates on docs-only changes and the
# 'run-checks' label).
name: test | multimedia example

on:
  # Allow manual runs from the Actions tab.
  workflow_dispatch:
  pull_request:
    branches:
      - main
    # Fire when a label is added or new commits are pushed to the PR.
    types: [labeled, synchronize]

# One run at a time per workflow + PR (or ref); newer runs cancel older ones.
concurrency:
  group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}
  cancel-in-progress: true

jobs:
  run_multimedia_example_test:
    uses: ./.github/workflows/reusable_python_example.yml
    with:
      example-location: ./examples/python/multimedia_example.py
    # Secrets must be forwarded explicitly to a called reusable workflow.
    secrets:
      OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
      GRAPHISTRY_USERNAME: ${{ secrets.GRAPHISTRY_USERNAME }}
      GRAPHISTRY_PASSWORD: ${{ secrets.GRAPHISTRY_PASSWORD }}

View file

@ -12,52 +12,12 @@ concurrency:
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }} group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}
cancel-in-progress: true cancel-in-progress: true
env:
RUNTIME__LOG_LEVEL: ERROR
jobs: jobs:
get_docs_changes:
name: docs changes
uses: ./.github/workflows/get_docs_changes.yml
run_notebook_test: run_notebook_test:
name: test uses: ./.github/workflows/reusable_notebook.yml
needs: get_docs_changes with:
if: needs.get_docs_changes.outputs.changes_outside_docs == 'true' && ${{ github.event.label.name == 'run-checks' }} notebook-location: notebooks/cognee_demo.ipynb
runs-on: ubuntu-latest secrets:
defaults: OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
run: GRAPHISTRY_USERNAME: ${{ secrets.GRAPHISTRY_USERNAME }}
shell: bash GRAPHISTRY_PASSWORD: ${{ secrets.GRAPHISTRY_PASSWORD }}
steps:
- name: Check out
uses: actions/checkout@master
- name: Setup Python
uses: actions/setup-python@v5
with:
python-version: '3.11.x'
- name: Install Poetry
uses: snok/install-poetry@v1.3.2
with:
virtualenvs-create: true
virtualenvs-in-project: true
installer-parallel: true
- name: Install dependencies
run: |
poetry install --no-interaction
poetry add jupyter --no-interaction
- name: Execute Jupyter Notebook
env:
ENV: 'dev'
LLM_API_KEY: ${{ secrets.OPENAI_API_KEY }}
GRAPHISTRY_USERNAME: ${{ secrets.GRAPHISTRY_USERNAME }}
GRAPHISTRY_PASSWORD: ${{ secrets.GRAPHISTRY_PASSWORD }}
run: |
poetry run jupyter nbconvert \
--to notebook \
--execute notebooks/cognee_demo.ipynb \
--output executed_notebook.ipynb \
--ExecutePreprocessor.timeout=1200

View file

@ -0,0 +1,23 @@
# CI entry point: runs examples/python/simple_example.py via the shared
# reusable workflow (which itself gates on docs-only changes and the
# 'run-checks' label).
name: test | simple example

on:
  # Allow manual runs from the Actions tab.
  workflow_dispatch:
  pull_request:
    branches:
      - main
    # Fire when a label is added or new commits are pushed to the PR.
    types: [labeled, synchronize]

# One run at a time per workflow + PR (or ref); newer runs cancel older ones.
concurrency:
  group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}
  cancel-in-progress: true

jobs:
  run_simple_example_test:
    uses: ./.github/workflows/reusable_python_example.yml
    with:
      example-location: ./examples/python/simple_example.py
    # Secrets must be forwarded explicitly to a called reusable workflow.
    secrets:
      OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
      GRAPHISTRY_USERNAME: ${{ secrets.GRAPHISTRY_USERNAME }}
      GRAPHISTRY_PASSWORD: ${{ secrets.GRAPHISTRY_PASSWORD }}

View file

@ -1,11 +1,28 @@
"""Factory function to get the appropriate graph client based on the graph type.""" """Factory function to get the appropriate graph client based on the graph type."""
from functools import lru_cache
from .config import get_graph_config from .config import get_graph_config
from .graph_db_interface import GraphDBInterface from .graph_db_interface import GraphDBInterface
async def get_graph_engine() -> GraphDBInterface : async def get_graph_engine() -> GraphDBInterface:
"""Factory function to get the appropriate graph client based on the graph type.""" """Factory function to get the appropriate graph client based on the graph type."""
graph_client = create_graph_engine()
# Async functions can't be cached. After creating and caching the graph engine
# handle all necessary async operations for different graph types below.
config = get_graph_config()
# Handle loading of graph for NetworkX
if config.graph_database_provider.lower() == "networkx" and graph_client.graph is None:
await graph_client.load_graph_from_file()
return graph_client
@lru_cache
def create_graph_engine() -> GraphDBInterface:
"""Factory function to create the appropriate graph client based on the graph type."""
config = get_graph_config() config = get_graph_config()
if config.graph_database_provider == "neo4j": if config.graph_database_provider == "neo4j":
@ -15,9 +32,9 @@ async def get_graph_engine() -> GraphDBInterface :
from .neo4j_driver.adapter import Neo4jAdapter from .neo4j_driver.adapter import Neo4jAdapter
return Neo4jAdapter( return Neo4jAdapter(
graph_database_url = config.graph_database_url, graph_database_url=config.graph_database_url,
graph_database_username = config.graph_database_username, graph_database_username=config.graph_database_username,
graph_database_password = config.graph_database_password graph_database_password=config.graph_database_password
) )
elif config.graph_database_provider == "falkordb": elif config.graph_database_provider == "falkordb":
@ -30,15 +47,12 @@ async def get_graph_engine() -> GraphDBInterface :
embedding_engine = get_embedding_engine() embedding_engine = get_embedding_engine()
return FalkorDBAdapter( return FalkorDBAdapter(
database_url = config.graph_database_url, database_url=config.graph_database_url,
database_port = config.graph_database_port, database_port=config.graph_database_port,
embedding_engine = embedding_engine, embedding_engine=embedding_engine,
) )
from .networkx.adapter import NetworkXAdapter from .networkx.adapter import NetworkXAdapter
graph_client = NetworkXAdapter(filename = config.graph_file_path) graph_client = NetworkXAdapter(filename=config.graph_file_path)
if graph_client.graph is None:
await graph_client.load_graph_from_file()
return graph_client return graph_client

View file

@ -2,7 +2,7 @@
import logging import logging
import asyncio import asyncio
from textwrap import dedent from textwrap import dedent
from typing import Optional, Any, List, Dict from typing import Optional, Any, List, Dict, Union
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from uuid import UUID from uuid import UUID
from neo4j import AsyncSession from neo4j import AsyncSession
@ -432,3 +432,49 @@ class Neo4jAdapter(GraphDBInterface):
) for record in result] ) for record in result]
return (nodes, edges) return (nodes, edges)
async def get_filtered_graph_data(self, attribute_filters):
    """
    Fetches nodes and relationships filtered by specified attribute values.

    Only the first dictionary in `attribute_filters` is applied; its entries
    are AND-ed together.

    Args:
        attribute_filters (list of dict): A list of dictionaries where keys are
            attributes and values are lists of values to filter on.
            Example: [{"community": ["1", "2"]}]

    Returns:
        tuple: A tuple containing two lists:
            - nodes: (id, properties) pairs.
            - edges: (source, target, type, properties) tuples.

    NOTE(review): attribute names and values are interpolated straight into
    the Cypher text (no query parameters) — safe only for trusted, internal
    filter values; confirm callers never pass user-controlled input.
    """
    where_clauses = []
    # Build one `n.<attr> IN [...]` clause per attribute of the first filter.
    for attribute, values in attribute_filters[0].items():
        values_str = ", ".join(f"'{value}'" if isinstance(value, str) else str(value) for value in values)
        where_clauses.append(f"n.{attribute} IN [{values_str}]")

    where_clause = " AND ".join(where_clauses)

    query_nodes = f"""
    MATCH (n)
    WHERE {where_clause}
    RETURN ID(n) AS id, labels(n) AS labels, properties(n) AS properties
    """
    result_nodes = await self.query(query_nodes)

    nodes = [(
        record["id"],
        record["properties"],
    ) for record in result_nodes]

    # Edges are kept only when BOTH endpoints satisfy the same filters.
    # The 'm.'-side clause is derived via a plain string replace of 'n.' —
    # fragile if an attribute name or value ever contains "n.".
    query_edges = f"""
    MATCH (n)-[r]->(m)
    WHERE {where_clause} AND {where_clause.replace('n.', 'm.')}
    RETURN ID(n) AS source, ID(m) AS target, TYPE(r) AS type, properties(r) AS properties
    """
    result_edges = await self.query(query_edges)

    edges = [(
        record["source"],
        record["target"],
        record["type"],
        record["properties"],
    ) for record in result_edges]

    return (nodes, edges)

View file

@ -6,7 +6,7 @@ import json
import asyncio import asyncio
import logging import logging
from re import A from re import A
from typing import Dict, Any, List from typing import Dict, Any, List, Union
from uuid import UUID from uuid import UUID
import aiofiles import aiofiles
import aiofiles.os as aiofiles_os import aiofiles.os as aiofiles_os
@ -301,3 +301,39 @@ class NetworkXAdapter(GraphDBInterface):
logger.info("Graph deleted successfully.") logger.info("Graph deleted successfully.")
except Exception as error: except Exception as error:
logger.error("Failed to delete graph: %s", error) logger.error("Failed to delete graph: %s", error)
async def get_filtered_graph_data(self, attribute_filters: List[Dict[str, List[Union[str, int]]]]):
    """
    Fetches nodes and relationships filtered by specified attribute values.

    Only the first dictionary in `attribute_filters` is applied; its entries
    are AND-ed together, and a node matches when each attribute's value is
    contained in the corresponding allowed list.

    Args:
        attribute_filters (list of dict): A list of dictionaries where keys are
            attributes and values are lists of values to filter on.
            Example: [{"community": ["1", "2"]}]

    Returns:
        tuple: A tuple containing two lists:
            - Nodes: List of tuples (node_id, node_properties).
            - Edges: List of tuples (source_id, target_id, relationship_type, edge_properties).
    """
    # Create filters for nodes based on the attribute filters
    where_clauses = []
    for attribute, values in attribute_filters[0].items():
        where_clauses.append((attribute, values))

    # Filter nodes: every (attr, values) pair must match the node's data.
    filtered_nodes = [
        (node, data) for node, data in self.graph.nodes(data=True)
        if all(data.get(attr) in values for attr, values in where_clauses)
    ]

    # Filter edges where both source and target nodes satisfy the filters.
    # Falls back to 'UNKNOWN' when an edge has no relationship_type attribute.
    filtered_edges = [
        (source, target, data.get('relationship_type', 'UNKNOWN'), data)
        for source, target, data in self.graph.edges(data=True)
        if (
            all(self.graph.nodes[source].get(attr) in values for attr, values in where_clauses) and
            all(self.graph.nodes[target].get(attr) in values for attr, values in where_clauses)
        )
    ]

    return filtered_nodes, filtered_edges

View file

@ -1,7 +1,9 @@
from functools import lru_cache
from .config import get_relational_config from .config import get_relational_config
from .create_relational_engine import create_relational_engine from .create_relational_engine import create_relational_engine
@lru_cache
def get_relational_engine(): def get_relational_engine():
relational_config = get_relational_config() relational_config = get_relational_config()

View file

@ -172,6 +172,27 @@ class SQLAlchemyAdapter():
results = await connection.execute(query) results = await connection.execute(query)
return {result["data_id"]: result["status"] for result in results} return {result["data_id"]: result["status"] for result in results}
async def get_all_data_from_table(self, table_name: str, schema: str = "public"):
    """
    Return every row of `table_name` as a list of mapping objects.

    Args:
        table_name: Name of the table to read; must be a valid Python
            identifier (coarse SQL-injection guard).
        schema: Schema the table lives in; ignored on SQLite, which has
            no schema support. Same identifier restriction applies.

    Returns:
        List of row mappings (column name -> value) for all rows.

    Raises:
        ValueError: if `table_name` or `schema` fails the identifier check.
    """
    async with self.get_async_session() as session:
        # Validate inputs to prevent SQL injection
        if not table_name.isidentifier():
            raise ValueError("Invalid table name")
        if schema and not schema.isidentifier():
            raise ValueError("Invalid schema name")

        # SQLite has no schemas, so look the table up without one there.
        if self.engine.dialect.name == "sqlite":
            table = await self.get_table(table_name)
        else:
            table = await self.get_table(table_name, schema)

        # Query all data from the table
        query = select(table)
        result = await session.execute(query)

        # Fetch all rows as a list of dictionaries
        rows = result.mappings().all()
        return rows
async def execute_query(self, query): async def execute_query(self, query):
async with self.engine.begin() as connection: async with self.engine.begin() as connection:
result = await connection.execute(text(query)) result = await connection.execute(text(query))
@ -206,7 +227,6 @@ class SQLAlchemyAdapter():
from cognee.infrastructure.files.storage import LocalStorage from cognee.infrastructure.files.storage import LocalStorage
LocalStorage.remove(self.db_path) LocalStorage.remove(self.db_path)
self.db_path = None
else: else:
async with self.engine.begin() as connection: async with self.engine.begin() as connection:
schema_list = await self.get_schema_list() schema_list = await self.get_schema_list()

View file

@ -12,6 +12,7 @@ from cognee.infrastructure.files.storage import LocalStorage
from cognee.modules.storage.utils import copy_model, get_own_properties from cognee.modules.storage.utils import copy_model, get_own_properties
from ..models.ScoredResult import ScoredResult from ..models.ScoredResult import ScoredResult
from ..vector_db_interface import VectorDBInterface from ..vector_db_interface import VectorDBInterface
from ..utils import normalize_distances
from ..embeddings.EmbeddingEngine import EmbeddingEngine from ..embeddings.EmbeddingEngine import EmbeddingEngine
class IndexSchema(DataPoint): class IndexSchema(DataPoint):
@ -143,6 +144,33 @@ class LanceDBAdapter(VectorDBInterface):
score = 0, score = 0,
) for result in results.to_dict("index").values()] ) for result in results.to_dict("index").values()]
async def get_distance_from_collection_elements(
    self,
    collection_name: str,
    query_text: str = None,
    query_vector: List[float] = None
):
    """
    Score every element of a LanceDB collection against a query.

    Exactly one of `query_text` / `query_vector` must be provided; text is
    embedded via `self.embedding_engine`. Scores are the min-max normalized
    vector-search distances (lower = closer).

    Raises:
        ValueError: if neither `query_text` nor `query_vector` is provided.
    """
    if query_text is None and query_vector is None:
        raise ValueError("One of query_text or query_vector must be provided!")

    if query_text and not query_vector:
        query_vector = (await self.embedding_engine.embed_text([query_text]))[0]

    connection = await self.get_connection()
    collection = await connection.open_table(collection_name)

    # Vector search over the whole table, materialized as a pandas frame.
    results = await collection.vector_search(query_vector).to_pandas()

    result_values = list(results.to_dict("index").values())

    normalized_values = normalize_distances(result_values)

    return [ScoredResult(
        id=UUID(result["id"]),
        payload=result["payload"],
        score=normalized_values[value_index],
    ) for value_index, result in enumerate(result_values)]
async def search( async def search(
self, self,
collection_name: str, collection_name: str,
@ -150,6 +178,7 @@ class LanceDBAdapter(VectorDBInterface):
query_vector: List[float] = None, query_vector: List[float] = None,
limit: int = 5, limit: int = 5,
with_vector: bool = False, with_vector: bool = False,
normalized: bool = True
): ):
if query_text is None and query_vector is None: if query_text is None and query_vector is None:
raise InvalidValueError(message="One of query_text or query_vector must be provided!") raise InvalidValueError(message="One of query_text or query_vector must be provided!")
@ -164,26 +193,7 @@ class LanceDBAdapter(VectorDBInterface):
result_values = list(results.to_dict("index").values()) result_values = list(results.to_dict("index").values())
min_value = 100 normalized_values = normalize_distances(result_values)
max_value = 0
for result in result_values:
value = float(result["_distance"])
if value > max_value:
max_value = value
if value < min_value:
min_value = value
normalized_values = []
min_value = min(result["_distance"] for result in result_values)
max_value = max(result["_distance"] for result in result_values)
if max_value == min_value:
# Avoid division by zero: Assign all normalized values to 0 (or any constant value like 1)
normalized_values = [0 for _ in result_values]
else:
normalized_values = [(result["_distance"] - min_value) / (max_value - min_value) for result in
result_values]
return [ScoredResult( return [ScoredResult(
id = UUID(result["id"]), id = UUID(result["id"]),

View file

@ -13,6 +13,7 @@ from cognee.infrastructure.engine import DataPoint
from .serialize_data import serialize_data from .serialize_data import serialize_data
from ..models.ScoredResult import ScoredResult from ..models.ScoredResult import ScoredResult
from ..vector_db_interface import VectorDBInterface from ..vector_db_interface import VectorDBInterface
from ..utils import normalize_distances
from ..embeddings.EmbeddingEngine import EmbeddingEngine from ..embeddings.EmbeddingEngine import EmbeddingEngine
from ...relational.sqlalchemy.SqlAlchemyAdapter import SQLAlchemyAdapter from ...relational.sqlalchemy.SqlAlchemyAdapter import SQLAlchemyAdapter
from ...relational.ModelBase import Base from ...relational.ModelBase import Base
@ -24,6 +25,19 @@ class IndexSchema(DataPoint):
"index_fields": ["text"] "index_fields": ["text"]
} }
def singleton(class_):
    """Class decorator that makes `class_` lazily instantiate exactly once.

    The first call constructs the instance with whatever arguments were
    supplied; every subsequent call returns that same instance and ignores
    its arguments.

    Note: decorating a class with this replaces the class object with a
    factory function, so class-level attributes and classmethods are no
    longer reachable through the decorated name.
    """
    cache = {}

    def get_instance(*args, **kwargs):
        try:
            return cache[class_]
        except KeyError:
            cache[class_] = class_(*args, **kwargs)
            return cache[class_]

    return get_instance
@singleton
class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface): class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
def __init__( def __init__(
@ -164,6 +178,51 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
) for result in results ) for result in results
] ]
async def get_distance_from_collection_elements(
    self,
    collection_name: str,
    query_text: str = None,
    query_vector: List[float] = None,
    with_vector: bool = False
) -> List[ScoredResult]:
    """
    Score every element of a pgvector collection by cosine distance
    from the query.

    Exactly one of `query_text` / `query_vector` must be provided; text is
    embedded via `self.embedding_engine`. Results are ordered by ascending
    distance; scores are the raw cosine distances (not normalized — see
    TODO below). `with_vector` is currently unused.

    Raises:
        ValueError: if neither `query_text` nor `query_vector` is provided.
    """
    if query_text is None and query_vector is None:
        raise ValueError("One of query_text or query_vector must be provided!")

    if query_text and not query_vector:
        query_vector = (await self.embedding_engine.embed_text([query_text]))[0]

    # Get PGVectorDataPoint Table from database
    PGVectorDataPoint = await self.get_table(collection_name)

    # Use async session to connect to the database
    async with self.get_async_session() as session:
        # Find closest vectors to query_vector (cosine distance, ascending)
        closest_items = await session.execute(
            select(
                PGVectorDataPoint,
                PGVectorDataPoint.c.vector.cosine_distance(query_vector).label(
                    "similarity"
                ),
            )
            .order_by("similarity")
        )

        vector_list = []
        # Extract distances and find min/max for normalization
        for vector in closest_items:
            # TODO: Add normalization of similarity score
            vector_list.append(vector)

        # Create and return ScoredResult objects
        return [
            ScoredResult(
                id = UUID(str(row.id)),
                payload = row.payload,
                score = row.similarity
            ) for row in vector_list
        ]
async def search( async def search(
self, self,
collection_name: str, collection_name: str,

View file

@ -1,12 +1,12 @@
from ...relational.ModelBase import Base
from ..get_vector_engine import get_vector_engine, get_vectordb_config from ..get_vector_engine import get_vector_engine, get_vectordb_config
from sqlalchemy import text from sqlalchemy import text
async def create_db_and_tables(): async def create_db_and_tables():
vector_config = get_vectordb_config() vector_config = get_vectordb_config()
vector_engine = get_vector_engine() vector_engine = get_vector_engine()
if vector_config.vector_db_provider == "pgvector": if vector_config.vector_db_provider == "pgvector":
await vector_engine.create_database()
async with vector_engine.engine.begin() as connection: async with vector_engine.engine.begin() as connection:
await connection.execute(text("CREATE EXTENSION IF NOT EXISTS vector;")) await connection.execute(text("CREATE EXTENSION IF NOT EXISTS vector;"))

View file

@ -143,6 +143,41 @@ class QDrantAdapter(VectorDBInterface):
await client.close() await client.close()
return results return results
async def get_distance_from_collection_elements(
    self,
    collection_name: str,
    query_text: str = None,
    query_vector: List[float] = None,
    with_vector: bool = False
) -> List[ScoredResult]:
    """
    Score collection elements against a query via Qdrant vector search.

    Exactly one of `query_text` / `query_vector` must be provided; text is
    embedded with `self.embed_data`. The returned score is `1 - result.score`,
    converting Qdrant's similarity into a distance-like value so lower means
    closer, consistent with the other adapters.

    Raises:
        ValueError: if neither `query_text` nor `query_vector` is provided.
    """
    if query_text is None and query_vector is None:
        raise ValueError("One of query_text or query_vector must be provided!")

    client = self.get_qdrant_client()

    results = await client.search(
        collection_name = collection_name,
        query_vector = models.NamedVector(
            name = "text",
            vector = query_vector if query_vector is not None else (await self.embed_data([query_text]))[0],
        ),
        with_vectors = with_vector
    )

    await client.close()

    return [
        ScoredResult(
            id = UUID(result.id),
            payload = {
                **result.payload,
                "id": UUID(result.id),
            },
            # Qdrant returns similarity; invert so lower = closer.
            score = 1 - result.score,
        ) for result in results
    ]
async def search( async def search(
self, self,
collection_name: str, collection_name: str,

View file

@ -0,0 +1,16 @@
from typing import List
def normalize_distances(result_values: List[dict]) -> List[float]:
    """Min-max normalize the "_distance" field of each result to [0, 1].

    Args:
        result_values: dicts each containing a numeric "_distance" key.

    Returns:
        Normalized distances, in input order. An empty input yields an empty
        list; when all distances are equal every value maps to 0 (avoids
        division by zero).
    """
    if not result_values:
        # min()/max() raise ValueError on an empty sequence — guard it.
        return []

    distances = [result["_distance"] for result in result_values]
    min_value = min(distances)
    max_value = max(distances)

    if max_value == min_value:
        # Avoid division by zero: assign all normalized values to 0.
        return [0 for _ in distances]

    value_range = max_value - min_value
    return [(distance - min_value) / value_range for distance in distances]

View file

@ -154,6 +154,36 @@ class WeaviateAdapter(VectorDBInterface):
return await future return await future
async def get_distance_from_collection_elements(
    self,
    collection_name: str,
    query_text: str = None,
    query_vector: List[float] = None,
    with_vector: bool = False
) -> List[ScoredResult]:
    """
    Score collection elements against a query via Weaviate vector search.

    Exactly one of `query_text` / `query_vector` must be provided; text is
    embedded with `self.embed_data`. The hybrid endpoint is called with
    `query=None` so it behaves as a pure vector search. The returned score
    is `1 - metadata.score`, converting similarity into a distance-like
    value so lower means closer, consistent with the other adapters.

    Raises:
        ValueError: if neither `query_text` nor `query_vector` is provided.
    """
    import weaviate.classes as wvc

    if query_text is None and query_vector is None:
        raise ValueError("One of query_text or query_vector must be provided!")

    if query_vector is None:
        query_vector = (await self.embed_data([query_text]))[0]

    # query=None + vector -> pure vector search through the hybrid endpoint.
    search_result = self.get_collection(collection_name).query.hybrid(
        query=None,
        vector=query_vector,
        include_vector=with_vector,
        return_metadata=wvc.query.MetadataQuery(score=True),
    )

    return [
        ScoredResult(
            id=UUID(str(result.uuid)),
            payload=result.properties,
            # Weaviate reports similarity; invert so lower = closer.
            score=1 - float(result.metadata.score)
        ) for result in search_result.objects
    ]
async def search( async def search(
self, self,
collection_name: str, collection_name: str,

View file

@ -1,3 +1,5 @@
import numpy as np
from typing import List, Dict, Union from typing import List, Dict, Union
from cognee.exceptions import InvalidValueError from cognee.exceptions import InvalidValueError
@ -5,6 +7,8 @@ from cognee.modules.graph.exceptions import EntityNotFoundError, EntityAlreadyEx
from cognee.infrastructure.databases.graph.graph_db_interface import GraphDBInterface from cognee.infrastructure.databases.graph.graph_db_interface import GraphDBInterface
from cognee.modules.graph.cognee_graph.CogneeGraphElements import Node, Edge from cognee.modules.graph.cognee_graph.CogneeGraphElements import Node, Edge
from cognee.modules.graph.cognee_graph.CogneeAbstractGraph import CogneeAbstractGraph from cognee.modules.graph.cognee_graph.CogneeAbstractGraph import CogneeAbstractGraph
import heapq
from graphistry import edges
class CogneeGraph(CogneeAbstractGraph): class CogneeGraph(CogneeAbstractGraph):
@ -41,26 +45,33 @@ class CogneeGraph(CogneeAbstractGraph):
def get_node(self, node_id: str) -> Node: def get_node(self, node_id: str) -> Node:
return self.nodes.get(node_id, None) return self.nodes.get(node_id, None)
def get_edges(self, node_id: str) -> List[Edge]: def get_edges_from_node(self, node_id: str) -> List[Edge]:
node = self.get_node(node_id) node = self.get_node(node_id)
if node: if node:
return node.skeleton_edges return node.skeleton_edges
else: else:
raise EntityNotFoundError(message=f"Node with id {node_id} does not exist.") raise EntityNotFoundError(message=f"Node with id {node_id} does not exist.")
def get_edges(self)-> List[Edge]:
return self.edges
async def project_graph_from_db(self, async def project_graph_from_db(self,
adapter: Union[GraphDBInterface], adapter: Union[GraphDBInterface],
node_properties_to_project: List[str], node_properties_to_project: List[str],
edge_properties_to_project: List[str], edge_properties_to_project: List[str],
directed = True, directed = True,
node_dimension = 1, node_dimension = 1,
edge_dimension = 1) -> None: edge_dimension = 1,
memory_fragment_filter = []) -> None:
if node_dimension < 1 or edge_dimension < 1: if node_dimension < 1 or edge_dimension < 1:
raise InvalidValueError(message="Dimensions must be positive integers") raise InvalidValueError(message="Dimensions must be positive integers")
try: try:
nodes_data, edges_data = await adapter.get_graph_data() if len(memory_fragment_filter) == 0:
nodes_data, edges_data = await adapter.get_graph_data()
else:
nodes_data, edges_data = await adapter.get_filtered_graph_data(attribute_filters = memory_fragment_filter)
if not nodes_data: if not nodes_data:
raise EntityNotFoundError(message="No node data retrieved from the database.") raise EntityNotFoundError(message="No node data retrieved from the database.")
@ -91,3 +102,81 @@ class CogneeGraph(CogneeAbstractGraph):
print(f"Error projecting graph: {e}") print(f"Error projecting graph: {e}")
except Exception as ex: except Exception as ex:
print(f"Unexpected error: {ex}") print(f"Unexpected error: {ex}")
async def map_vector_distances_to_graph_nodes(self, node_distances) -> None:
    """
    Attach vector-search scores to their corresponding graph nodes.

    Args:
        node_distances: mapping of category name -> iterable of scored
            results, each exposing `.id` and `.score` (presumably
            ScoredResult objects from a vector engine — TODO confirm
            against callers).

    Each matched node gets a "vector_distance" attribute; ids with no
    corresponding node are reported on stdout and skipped.
    """
    for category, scored_results in node_distances.items():
        for scored_result in scored_results:
            node_id = str(scored_result.id)
            score = scored_result.score
            node = self.get_node(node_id)
            if node:
                node.add_attribute("vector_distance", score)
            else:
                print(f"Node with id {node_id} not found in the graph.")
async def map_vector_distances_to_graph_edges(self, vector_engine, query) -> None: # :TODO: When we calculate edge embeddings in vector db change this similarly to node mapping
    """Assign a "vector_distance" attribute to every edge based on its relationship type.

    Embeds `query` once, embeds each unique 'relationship_type' found on the
    edges, and stores the cosine distance (1 - cosine similarity, rounded to
    4 decimals) on each edge. Errors are printed, not raised, matching the
    error style of the surrounding class.
    """
    try:
        # Step 1: generate the query embedding (single call).
        query_vector = (await vector_engine.embed_data([query]))[0]
        if query_vector is None or len(query_vector) == 0:
            raise ValueError("Failed to generate query embedding.")

        # Step 2: collect all unique relationship types present on edges.
        unique_relationship_types = list({
            edge.attributes.get('relationship_type')
            for edge in self.edges
            if edge.attributes.get('relationship_type')
        })
        # Guard: nothing to embed — avoids calling embed_data with an empty list.
        if not unique_relationship_types:
            return

        # Step 3: embed all unique relationship types in one batch.
        relationship_type_embeddings = await vector_engine.embed_data(unique_relationship_types)

        # Step 4: map each relationship type to its cosine distance from the query.
        query_norm = np.linalg.norm(query_vector)
        embedding_map = {}
        for relationship_type, embedding in zip(unique_relationship_types, relationship_type_embeddings):
            edge_vector = np.array(embedding)
            denominator = query_norm * np.linalg.norm(edge_vector)
            # Guard against zero-norm vectors (would divide by zero).
            similarity = np.dot(query_vector, edge_vector) / denominator if denominator else 0.0
            # Round the distance to 4 decimal places and store it.
            embedding_map[relationship_type] = round(1 - similarity, 4)

        # Step 5: assign the precomputed distances to the edges.
        for edge in self.edges:
            relationship_type = edge.attributes.get('relationship_type')
            if not relationship_type or relationship_type not in embedding_map:
                print(f"Edge {edge} has an unknown or missing relationship type.")
                continue
            edge.attributes["vector_distance"] = embedding_map[relationship_type]

    except Exception as ex:
        print(f"Error mapping vector distances to edges: {ex}")
async def calculate_top_triplet_importances(self, k: int) -> List:
    """Select the k edges whose triplet (source, edge, target) has the
    smallest combined "vector_distance".

    Missing nodes or missing "vector_distance" attributes contribute a
    default distance of 1. The result is ordered by descending combined
    distance (ties broken by original edge index), matching the original
    heap-based implementation.
    """
    scored = []
    for index, edge in enumerate(self.edges):
        head = self.get_node(edge.node1.id)
        tail = self.get_node(edge.node2.id)
        head_distance = head.attributes.get("vector_distance", 1) if head else 1
        tail_distance = tail.attributes.get("vector_distance", 1) if tail else 1
        combined = head_distance + tail_distance + edge.attributes.get("vector_distance", 1)
        scored.append((-combined, index, edge))
    # k largest (-distance, index) tuples == k smallest combined distances.
    best = heapq.nlargest(k, scored)
    return [edge for _, _, edge in sorted(best)]

View file

@ -1,5 +1,5 @@
import numpy as np import numpy as np
from typing import List, Dict, Optional, Any from typing import List, Dict, Optional, Any, Union
from cognee.exceptions import InvalidValueError from cognee.exceptions import InvalidValueError
@ -24,6 +24,7 @@ class Node:
raise InvalidValueError(message="Dimension must be a positive integer") raise InvalidValueError(message="Dimension must be a positive integer")
self.id = node_id self.id = node_id
self.attributes = attributes if attributes is not None else {} self.attributes = attributes if attributes is not None else {}
self.attributes["vector_distance"] = float('inf')
self.skeleton_neighbours = [] self.skeleton_neighbours = []
self.skeleton_edges = [] self.skeleton_edges = []
self.status = np.ones(dimension, dtype=int) self.status = np.ones(dimension, dtype=int)
@ -58,6 +59,12 @@ class Node:
raise InvalidValueError(message=f"Dimension {dimension} is out of range. Valid range is 0 to {len(self.status) - 1}.") raise InvalidValueError(message=f"Dimension {dimension} is out of range. Valid range is 0 to {len(self.status) - 1}.")
return self.status[dimension] == 1 return self.status[dimension] == 1
def add_attribute(self, key: str, value: Any) -> None:
    """Set (or overwrite) attribute `key` on this node."""
    self.attributes[key] = value

def get_attribute(self, key: str) -> Union[str, int, float]:
    """Return the value stored under `key`; raises KeyError if absent."""
    return self.attributes[key]
def __repr__(self) -> str: def __repr__(self) -> str:
return f"Node({self.id}, attributes={self.attributes})" return f"Node({self.id}, attributes={self.attributes})"
@ -90,6 +97,7 @@ class Edge:
self.node1 = node1 self.node1 = node1
self.node2 = node2 self.node2 = node2
self.attributes = attributes if attributes is not None else {} self.attributes = attributes if attributes is not None else {}
self.attributes["vector_distance"] = float('inf')
self.directed = directed self.directed = directed
self.status = np.ones(dimension, dtype=int) self.status = np.ones(dimension, dtype=int)
@ -98,6 +106,12 @@ class Edge:
raise InvalidValueError(message=f"Dimension {dimension} is out of range. Valid range is 0 to {len(self.status) - 1}.") raise InvalidValueError(message=f"Dimension {dimension} is out of range. Valid range is 0 to {len(self.status) - 1}.")
return self.status[dimension] == 1 return self.status[dimension] == 1
def add_attribute(self, key: str, value: Any) -> None:
    """Set (or overwrite) attribute `key` on this edge."""
    self.attributes[key] = value

def get_attribute(self, key: str, value: Any = None) -> Union[str, int, float]:
    """Return the attribute stored under `key`, or `value` when absent.

    The `value` parameter was previously required but unused; it is now an
    optional fallback default, mirroring dict.get semantics. Lookups of
    existing keys behave exactly as before.
    """
    return self.attributes.get(key, value)
def __repr__(self) -> str: def __repr__(self) -> str:
direction = "->" if self.directed else "--" direction = "->" if self.directed else "--"
return f"Edge({self.node1.id} {direction} {self.node2.id}, attributes={self.attributes})" return f"Edge({self.node1.id} {direction} {self.node2.id}, attributes={self.attributes})"

View file

View file

@ -0,0 +1,150 @@
import asyncio
import logging
from typing import List
from cognee.modules.users.models import User
from cognee.modules.users.methods import get_default_user
from cognee.modules.graph.cognee_graph.CogneeGraph import CogneeGraph
from cognee.infrastructure.databases.vector import get_vector_engine
from cognee.infrastructure.databases.graph import get_graph_engine
from cognee.shared.utils import send_telemetry
def format_triplets(edges):
    """Render a list of graph edges as human-readable triplet strings.

    For each edge, prints its two endpoint nodes' and its own attributes
    (None-valued entries dropped) as "Node1 / Edge / Node2" paragraphs and
    returns them joined into a single string.

    Note: the original version defined an inner `filter_attributes` helper
    that was never called; it has been removed.
    """
    print("\n\n\n")  # visual separator on the console, kept from original behavior

    triplets = []
    for edge in edges:
        # Keep only non-None attribute values so the printout stays compact.
        node1_info = {key: value for key, value in edge.node1.attributes.items() if value is not None}
        node2_info = {key: value for key, value in edge.node2.attributes.items() if value is not None}
        edge_info = {key: value for key, value in edge.attributes.items() if value is not None}

        triplets.append(
            f"Node1: {node1_info}\n"
            f"Edge: {edge_info}\n"
            f"Node2: {node2_info}\n\n\n"
        )

    return "".join(triplets)
async def brute_force_triplet_search(query: str, user: User = None, top_k = 5) -> list:
    """Public entry point for triplet retrieval.

    Falls back to the default user when none is given; raises PermissionError
    if no user exists at all, then delegates to `brute_force_search`.
    """
    if user is None:
        user = await get_default_user()
    if user is None:
        raise PermissionError("No user found in the system. Please create a user.")
    return await brute_force_search(query, user, top_k)
def delete_duplicated_vector_db_elements(collections, results): #:TODO: This is just for now to fix vector db duplicates
    """Remove results with duplicate ids, per collection.

    `collections` and `results` are parallel sequences; returns a dict mapping
    each collection name to its result list with only the first occurrence of
    each id kept. Duplicates are reported on stdout.

    Fix: the original loop variable shadowed the `results` parameter
    (`for collection, results in zip(collections, results)`); renamed to
    `collection_results` for clarity and safety.
    """
    results_dict = {}
    for collection, collection_results in zip(collections, results):
        seen_ids = set()
        unique_results = []
        for result in collection_results:
            if result.id not in seen_ids:
                unique_results.append(result)
                seen_ids.add(result.id)
            else:
                print(f"Duplicate found in collection '{collection}': {result.id}")
        results_dict[collection] = unique_results
    return results_dict
async def brute_force_search(
        query: str,
        user: User,
        top_k: int,
        collections: List[str] = None
) -> list:
    """
    Performs a brute force search to retrieve the top triplets from the graph.

    Args:
        query (str): The search query.
        user (User): The user performing the search.
        top_k (int): The number of top results to retrieve.
        collections (Optional[List[str]]): List of collections to query. Defaults to predefined collections.

    Returns:
        list: The top triplet results.

    Raises:
        ValueError: If `query` is empty/non-string or `top_k` is not positive.
        RuntimeError: If engine initialization or the search itself fails.
    """
    if not query or not isinstance(query, str):
        raise ValueError("The query must be a non-empty string.")
    if top_k <= 0:
        raise ValueError("top_k must be a positive integer.")

    if collections is None:
        collections = ["entity_name", "text_summary_text", "entity_type_name", "document_chunk_text"]

    try:
        vector_engine = get_vector_engine()
        graph_engine = await get_graph_engine()
    except Exception as e:
        logging.error("Failed to initialize engines: %s", e)
        raise RuntimeError("Initialization error") from e

    send_telemetry("cognee.brute_force_triplet_search EXECUTION STARTED", user.id)

    try:
        # Query every collection concurrently for distances to the query text.
        results = await asyncio.gather(
            *[vector_engine.get_distance_from_collection_elements(collection, query_text=query) for collection in collections]
        )

        ############################################# :TODO: Change when vector db does not contain duplicates
        node_distances = delete_duplicated_vector_db_elements(collections, results)
        # node_distances = {collection: result for collection, result in zip(collections, results)}
        ##############################################

        memory_fragment = CogneeGraph()

        await memory_fragment.project_graph_from_db(graph_engine,
                                                    node_properties_to_project=['id',
                                                                                'description',
                                                                                'name',
                                                                                'type',
                                                                                'text'],
                                                    edge_properties_to_project=['relationship_name'])

        await memory_fragment.map_vector_distances_to_graph_nodes(node_distances=node_distances)

        #:TODO: Change when vectordb contains edge embeddings
        await memory_fragment.map_vector_distances_to_graph_edges(vector_engine, query)

        results = await memory_fragment.calculate_top_triplet_importances(k=top_k)

        # Fixed: this previously re-sent "EXECUTION STARTED"; the run is completing here.
        send_telemetry("cognee.brute_force_triplet_search EXECUTION COMPLETED", user.id)

        #:TODO: Once we have Edge pydantic models we should retrieve the exact edge and node objects from graph db
        return results

    except Exception as e:
        logging.error("Error during brute force search for user: %s, query: %s. Error: %s", user.id, query, e)
        send_telemetry("cognee.brute_force_triplet_search EXECUTION FAILED", user.id)
        raise RuntimeError("An error occurred during brute force search") from e

View file

@ -1,9 +1,12 @@
import os import os
from functools import lru_cache
import dlt import dlt
from typing import Union from typing import Union
from cognee.infrastructure.databases.relational import get_relational_config from cognee.infrastructure.databases.relational import get_relational_config
@lru_cache
def get_dlt_destination() -> Union[type[dlt.destinations.sqlalchemy], None]: def get_dlt_destination() -> Union[type[dlt.destinations.sqlalchemy], None]:
""" """
Handles propagation of the cognee database configuration to the dlt library Handles propagation of the cognee database configuration to the dlt library

View file

@ -1,6 +1,7 @@
import dlt import dlt
import cognee.modules.ingestion as ingestion import cognee.modules.ingestion as ingestion
from uuid import UUID
from cognee.shared.utils import send_telemetry from cognee.shared.utils import send_telemetry
from cognee.modules.users.models import User from cognee.modules.users.models import User
from cognee.infrastructure.databases.relational import get_relational_engine from cognee.infrastructure.databases.relational import get_relational_engine
@ -17,25 +18,33 @@ async def ingest_data(file_paths: list[str], dataset_name: str, user: User):
) )
@dlt.resource(standalone = True, merge_key = "id") @dlt.resource(standalone = True, merge_key = "id")
async def data_resources(file_paths: str, user: User): async def data_resources(file_paths: str):
for file_path in file_paths: for file_path in file_paths:
with open(file_path.replace("file://", ""), mode = "rb") as file: with open(file_path.replace("file://", ""), mode = "rb") as file:
classified_data = ingestion.classify(file) classified_data = ingestion.classify(file)
data_id = ingestion.identify(classified_data) data_id = ingestion.identify(classified_data)
file_metadata = classified_data.get_metadata() file_metadata = classified_data.get_metadata()
yield {
"id": data_id,
"name": file_metadata["name"],
"file_path": file_metadata["file_path"],
"extension": file_metadata["extension"],
"mime_type": file_metadata["mime_type"],
}
from sqlalchemy import select async def data_storing(table_name, dataset_name, user: User):
from cognee.modules.data.models import Data db_engine = get_relational_engine()
db_engine = get_relational_engine() async with db_engine.get_async_session() as session:
# Read metadata stored with dlt
async with db_engine.get_async_session() as session: files_metadata = await db_engine.get_all_data_from_table(table_name, dataset_name)
for file_metadata in files_metadata:
from sqlalchemy import select
from cognee.modules.data.models import Data
dataset = await create_dataset(dataset_name, user.id, session) dataset = await create_dataset(dataset_name, user.id, session)
data = (await session.execute( data = (await session.execute(
select(Data).filter(Data.id == data_id) select(Data).filter(Data.id == UUID(file_metadata["id"]))
)).scalar_one_or_none() )).scalar_one_or_none()
if data is not None: if data is not None:
@ -48,7 +57,7 @@ async def ingest_data(file_paths: list[str], dataset_name: str, user: User):
await session.commit() await session.commit()
else: else:
data = Data( data = Data(
id = data_id, id = UUID(file_metadata["id"]),
name = file_metadata["name"], name = file_metadata["name"],
raw_data_location = file_metadata["file_path"], raw_data_location = file_metadata["file_path"],
extension = file_metadata["extension"], extension = file_metadata["extension"],
@ -58,25 +67,34 @@ async def ingest_data(file_paths: list[str], dataset_name: str, user: User):
dataset.data.append(data) dataset.data.append(data)
await session.commit() await session.commit()
yield { await give_permission_on_document(user, UUID(file_metadata["id"]), "read")
"id": data_id, await give_permission_on_document(user, UUID(file_metadata["id"]), "write")
"name": file_metadata["name"],
"file_path": file_metadata["file_path"],
"extension": file_metadata["extension"],
"mime_type": file_metadata["mime_type"],
}
await give_permission_on_document(user, data_id, "read")
await give_permission_on_document(user, data_id, "write")
send_telemetry("cognee.add EXECUTION STARTED", user_id = user.id) send_telemetry("cognee.add EXECUTION STARTED", user_id = user.id)
run_info = pipeline.run(
data_resources(file_paths, user), db_engine = get_relational_engine()
table_name = "file_metadata",
dataset_name = dataset_name, # Note: DLT pipeline has its own event loop, therefore objects created in another event loop
write_disposition = "merge", # can't be used inside the pipeline
) if db_engine.engine.dialect.name == "sqlite":
# To use sqlite with dlt dataset_name must be set to "main".
# Sqlite doesn't support schemas
run_info = pipeline.run(
data_resources(file_paths),
table_name = "file_metadata",
dataset_name = "main",
write_disposition = "merge",
)
else:
run_info = pipeline.run(
data_resources(file_paths),
table_name="file_metadata",
dataset_name=dataset_name,
write_disposition="merge",
)
await data_storing("file_metadata", dataset_name, user)
send_telemetry("cognee.add EXECUTION COMPLETED", user_id = user.id) send_telemetry("cognee.add EXECUTION COMPLETED", user_id = user.id)
return run_info return run_info

View file

@ -4,6 +4,7 @@ import logging
import pathlib import pathlib
import cognee import cognee
from cognee.api.v1.search import SearchType from cognee.api.v1.search import SearchType
from cognee.modules.retrieval.brute_force_triplet_search import brute_force_triplet_search
logging.basicConfig(level=logging.DEBUG) logging.basicConfig(level=logging.DEBUG)
@ -61,6 +62,9 @@ async def main():
assert len(history) == 6, "Search history is not correct." assert len(history) == 6, "Search history is not correct."
results = await brute_force_triplet_search('What is a quantum computer?')
assert len(results) > 0
await cognee.prune.prune_data() await cognee.prune.prune_data()
assert not os.path.isdir(data_directory_path), "Local data files are not deleted" assert not os.path.isdir(data_directory_path), "Local data files are not deleted"

View file

@ -3,6 +3,7 @@ import logging
import pathlib import pathlib
import cognee import cognee
from cognee.api.v1.search import SearchType from cognee.api.v1.search import SearchType
from cognee.modules.retrieval.brute_force_triplet_search import brute_force_triplet_search
logging.basicConfig(level=logging.DEBUG) logging.basicConfig(level=logging.DEBUG)
@ -89,6 +90,9 @@ async def main():
history = await cognee.get_search_history() history = await cognee.get_search_history()
assert len(history) == 6, "Search history is not correct." assert len(history) == 6, "Search history is not correct."
results = await brute_force_triplet_search('What is a quantum computer?')
assert len(results) > 0
await cognee.prune.prune_data() await cognee.prune.prune_data()
assert not os.path.isdir(data_directory_path), "Local data files are not deleted" assert not os.path.isdir(data_directory_path), "Local data files are not deleted"

View file

@ -5,6 +5,7 @@ import logging
import pathlib import pathlib
import cognee import cognee
from cognee.api.v1.search import SearchType from cognee.api.v1.search import SearchType
from cognee.modules.retrieval.brute_force_triplet_search import brute_force_triplet_search
logging.basicConfig(level=logging.DEBUG) logging.basicConfig(level=logging.DEBUG)
@ -61,6 +62,9 @@ async def main():
history = await cognee.get_search_history() history = await cognee.get_search_history()
assert len(history) == 6, "Search history is not correct." assert len(history) == 6, "Search history is not correct."
results = await brute_force_triplet_search('What is a quantum computer?')
assert len(results) > 0
await cognee.prune.prune_data() await cognee.prune.prune_data()
assert not os.path.isdir(data_directory_path), "Local data files are not deleted" assert not os.path.isdir(data_directory_path), "Local data files are not deleted"

View file

@ -3,6 +3,7 @@ import logging
import pathlib import pathlib
import cognee import cognee
from cognee.api.v1.search import SearchType from cognee.api.v1.search import SearchType
from cognee.modules.retrieval.brute_force_triplet_search import brute_force_triplet_search
logging.basicConfig(level=logging.DEBUG) logging.basicConfig(level=logging.DEBUG)
@ -59,6 +60,9 @@ async def main():
history = await cognee.get_search_history() history = await cognee.get_search_history()
assert len(history) == 6, "Search history is not correct." assert len(history) == 6, "Search history is not correct."
results = await brute_force_triplet_search('What is a quantum computer?')
assert len(results) > 0
await cognee.prune.prune_data() await cognee.prune.prune_data()
assert not os.path.isdir(data_directory_path), "Local data files are not deleted" assert not os.path.isdir(data_directory_path), "Local data files are not deleted"

View file

@ -9,7 +9,7 @@ def test_node_initialization():
"""Test that a Node is initialized correctly.""" """Test that a Node is initialized correctly."""
node = Node("node1", {"attr1": "value1"}, dimension=2) node = Node("node1", {"attr1": "value1"}, dimension=2)
assert node.id == "node1" assert node.id == "node1"
assert node.attributes == {"attr1": "value1"} assert node.attributes == {"attr1": "value1", 'vector_distance': np.inf}
assert len(node.status) == 2 assert len(node.status) == 2
assert np.all(node.status == 1) assert np.all(node.status == 1)
@ -96,7 +96,7 @@ def test_edge_initialization():
edge = Edge(node1, node2, {"weight": 10}, directed=False, dimension=2) edge = Edge(node1, node2, {"weight": 10}, directed=False, dimension=2)
assert edge.node1 == node1 assert edge.node1 == node1
assert edge.node2 == node2 assert edge.node2 == node2
assert edge.attributes == {"weight": 10} assert edge.attributes == {'vector_distance': np.inf,"weight": 10}
assert edge.directed is False assert edge.directed is False
assert len(edge.status) == 2 assert len(edge.status) == 2
assert np.all(edge.status == 1) assert np.all(edge.status == 1)

View file

@ -1,6 +1,6 @@
import pytest import pytest
from cognee.exceptions import EntityNotFoundError, EntityAlreadyExistsError from cognee.modules.graph.exceptions import EntityNotFoundError, EntityAlreadyExistsError
from cognee.modules.graph.cognee_graph.CogneeGraph import CogneeGraph from cognee.modules.graph.cognee_graph.CogneeGraph import CogneeGraph
from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge, Node from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge, Node
@ -78,11 +78,11 @@ def test_get_edges_success(setup_graph):
graph.add_node(node2) graph.add_node(node2)
edge = Edge(node1, node2) edge = Edge(node1, node2)
graph.add_edge(edge) graph.add_edge(edge)
assert edge in graph.get_edges("node1") assert edge in graph.get_edges_from_node("node1")
def test_get_edges_nonexistent_node(setup_graph): def test_get_edges_nonexistent_node(setup_graph):
"""Test retrieving edges for a nonexistent node raises an exception.""" """Test retrieving edges for a nonexistent node raises an exception."""
graph = setup_graph graph = setup_graph
with pytest.raises(EntityNotFoundError, match="Node with id nonexistent does not exist."): with pytest.raises(EntityNotFoundError, match="Node with id nonexistent does not exist."):
graph.get_edges("nonexistent") graph.get_edges_from_node("nonexistent")

View file

@ -46,7 +46,7 @@ services:
- 7687:7687 - 7687:7687
environment: environment:
- NEO4J_AUTH=neo4j/pleaseletmein - NEO4J_AUTH=neo4j/pleaseletmein
- NEO4J_PLUGINS=["apoc"] - NEO4J_PLUGINS=["apoc", "graph-data-science"]
networks: networks:
- cognee-network - cognee-network

View file

@ -1,32 +1,7 @@
import cognee import cognee
import asyncio import asyncio
from cognee.api.v1.search import SearchType from cognee.modules.retrieval.brute_force_triplet_search import brute_force_triplet_search
from cognee.modules.retrieval.brute_force_triplet_search import format_triplets
job_position = """0:Senior Data Scientist (Machine Learning)
Company: TechNova Solutions
Location: San Francisco, CA
Job Description:
TechNova Solutions is seeking a Senior Data Scientist specializing in Machine Learning to join our dynamic analytics team. The ideal candidate will have a strong background in developing and deploying machine learning models, working with large datasets, and translating complex data into actionable insights.
Responsibilities:
Develop and implement advanced machine learning algorithms and models.
Analyze large, complex datasets to extract meaningful patterns and insights.
Collaborate with cross-functional teams to integrate predictive models into products.
Stay updated with the latest advancements in machine learning and data science.
Mentor junior data scientists and provide technical guidance.
Qualifications:
Masters or Ph.D. in Data Science, Computer Science, Statistics, or a related field.
5+ years of experience in data science and machine learning.
Proficient in Python, R, and SQL.
Experience with deep learning frameworks (e.g., TensorFlow, PyTorch).
Strong problem-solving skills and attention to detail.
Candidate CVs
"""
job_1 = """ job_1 = """
CV 1: Relevant CV 1: Relevant
@ -195,7 +170,7 @@ async def main(enable_steps):
# Step 2: Add text # Step 2: Add text
if enable_steps.get("add_text"): if enable_steps.get("add_text"):
text_list = [job_position, job_1, job_2, job_3, job_4, job_5] text_list = [job_1, job_2, job_3, job_4, job_5]
for text in text_list: for text in text_list:
await cognee.add(text) await cognee.add(text)
print(f"Added text: {text[:35]}...") print(f"Added text: {text[:35]}...")
@ -206,24 +181,21 @@ async def main(enable_steps):
print("Knowledge graph created.") print("Knowledge graph created.")
# Step 4: Query insights # Step 4: Query insights
if enable_steps.get("search_insights"): if enable_steps.get("retriever"):
search_results = await cognee.search( results = await brute_force_triplet_search('Who has the most experience with graphic design?')
SearchType.INSIGHTS, print(format_triplets(results))
{'query': 'Which applicant has the most relevant experience in data science?'}
)
print("Search results:")
for result_text in search_results:
print(result_text)
if __name__ == '__main__': if __name__ == '__main__':
# Flags to enable/disable steps # Flags to enable/disable steps
rebuild_kg = True
retrieve = True
steps_to_enable = { steps_to_enable = {
"prune_data": True, "prune_data": rebuild_kg,
"prune_system": True, "prune_system": rebuild_kg,
"add_text": True, "add_text": rebuild_kg,
"cognify": True, "cognify": rebuild_kg,
"search_insights": True "retriever": retrieve
} }
asyncio.run(main(steps_to_enable)) asyncio.run(main(steps_to_enable))