feat: s3 storage (#988)


## Description
Adds an S3 storage backend to cognee. A new `Storage` protocol is implemented by `LocalFileStorage` and `S3FileStorage` (backed by `s3fs`), a `get_file_storage` factory picks the backend from the path or the `STORAGE_BACKEND` setting, and an async `StorageManager` wraps both. The Kuzu graph adapter and the SQLite relational adapter gain `push_to_s3`/`pull_from_s3` helpers so their database files can live in S3, and existing call sites are migrated from the old `LocalStorage` helpers to the new storage API. The PR also pins dependency ranges in cognee-mcp, adds a `.gitguardian.yml` config, and removes unused graph export/import helpers.
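A minimal usage sketch of the new storage API for reviewers (the local directory here is hypothetical; call sites in the diff follow the same pattern):

```python
import asyncio
from cognee.infrastructure.files.storage import get_file_storage


async def main():
    # get_file_storage returns a StorageManager wrapping LocalFileStorage or
    # S3FileStorage, depending on the path and the STORAGE_BACKEND setting.
    storage = get_file_storage("/tmp/cognee_demo")  # hypothetical directory
    await storage.ensure_directory_exists()
    await storage.store("hello.txt", "hello cognee", overwrite=True)
    async with storage.open("hello.txt", mode="rb") as file:
        print(file.read().decode("utf-8"))
    await storage.remove("hello.txt")


asyncio.run(main())
```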

## DCO Affirmation
I affirm that all code in every commit of this pull request conforms to
the terms of the Topoteretes Developer Certificate of Origin.

---------

Co-authored-by: vasilije <vas.markovic@gmail.com>
Co-authored-by: Vasilije <8619304+Vasilije1990@users.noreply.github.com>
Commit 46c4463cb2 (parent 4bcb893a54), authored by Boris on 2025-07-14 21:47:08 +02:00, committed by GitHub.
102 changed files with 10328 additions and 9216 deletions

.gitguardian.yml (new file)

@ -0,0 +1,26 @@
# .gitguardian.yml
version: v1
secret-scan:
# Ignore specific files
excluded-paths:
- '.env.template'
- '.github/workflows/*.yml'
- 'examples/**'
- 'tests/**'
# Ignore specific patterns
excluded-detectors:
- 'Generic Password'
- 'Generic High Entropy Secret'
# Ignore by commit (if needed)
excluded-commits:
- '782bbb4'
# Custom rules for template files
paths-ignore:
- path: '.env.template'
comment: 'Template file with placeholder values'
- path: '.github/workflows/search_db_tests.yml'
comment: 'Test workflow with test credentials'


@ -44,7 +44,7 @@ jobs:
strategy:
matrix:
python-version: ${{ fromJSON(inputs.python-versions) }}
os: [ubuntu-22.04, macos-13, macos-15, windows-latest]
os: [ubuntu-22.04, macos-13, macos-15]
fail-fast: false
steps:
- name: Check out


@ -6,7 +6,7 @@ from sqlalchemy.engine import Connection
from sqlalchemy.ext.asyncio import async_engine_from_config
from cognee.infrastructure.databases.relational import Base
from alembic import context
from cognee.infrastructure.databases.relational import get_relational_engine, get_relational_config
from cognee.infrastructure.databases.relational import get_relational_engine
# this is the Alembic Config object, which provides
# access to the values within the .ini file in use.
@ -86,12 +86,6 @@ def run_migrations_online() -> None:
db_engine = get_relational_engine()
if db_engine.engine.dialect.name == "sqlite":
from cognee.infrastructure.files.storage import LocalStorage
db_config = get_relational_config()
LocalStorage.ensure_directory_exists(db_config.db_path)
print("Using database:", db_engine.db_uri)
config.set_section_option(


@ -8,10 +8,10 @@ requires-python = ">=3.10"
dependencies = [
# For local cognee repo usage, remove the comment below and add the absolute path to cognee. Then run `uv sync --reinstall` in the mcp folder after local cognee changes.
#"cognee[postgres,codegraph,gemini,huggingface,docs,neo4j] @ file:/Users/<username>/Desktop/cognee",
"cognee[postgres,codegraph,gemini,huggingface,docs,neo4j]==0.2.0",
"fastmcp>=1.0",
"mcp==1.5.0",
"uv>=0.6.3",
"cognee[postgres,codegraph,gemini,huggingface,docs,neo4j]>=0.2.0,<1.0.0",
"fastmcp>=1.0,<2.0.0",
"mcp>=1.11.0,<2.0.0",
"uv>=0.6.3,<1.0.0",
]
authors = [
@ -29,7 +29,7 @@ packages = ["src"]
[dependency-groups]
dev = [
"debugpy>=1.8.12",
"debugpy>=1.8.12,<2.0.0",
]
[tool.hatch.metadata]

cognee-mcp/uv.lock (generated)

File diff suppressed because it is too large.


@ -7,5 +7,5 @@ readme = "README.md"
requires-python = ">=3.10, <=3.13"
dependencies = [
"cognee>=0.1.38",
"cognee>=0.1.38,<1.0.0",
]


@ -3,7 +3,6 @@ import asyncio
import pathlib
from cognee import config, add, cognify, search, SearchType, prune, visualize_graph
# from cognee.shared.utils import render_graph
from cognee.low_level import DataPoint
@ -50,10 +49,6 @@ async def main():
# Cognify the text data.
await cognify(graph_model=ProgrammingLanguage)
# # Get a graphistry url (Register for a free account at https://www.graphistry.com)
# url = await render_graph()
# print(f"Graphistry URL: {url}")
# Or use our simple graph preview
graph_file_path = str(
pathlib.Path(


@ -2,7 +2,6 @@ import os
import asyncio
import pathlib
from cognee import config, add, cognify, search, SearchType, prune, visualize_graph
# from cognee.shared.utils import render_graph
async def main():
@ -30,10 +29,6 @@ async def main():
# Cognify the text data.
await cognify()
# # Get a graphistry url (Register for a free account at https://www.graphistry.com)
# url = await render_graph()
# print(f"Graphistry URL: {url}")
# Or use our simple graph preview
graph_file_path = str(
pathlib.Path(


@ -9,21 +9,28 @@ from cognee.infrastructure.databases.vector import get_vectordb_config
from cognee.infrastructure.databases.graph.config import get_graph_config
from cognee.infrastructure.llm.config import get_llm_config
from cognee.infrastructure.databases.relational import get_relational_config, get_migration_config
from cognee.infrastructure.files.storage import LocalStorage
class config:
@staticmethod
def system_root_directory(system_root_directory: str):
databases_directory_path = os.path.join(system_root_directory, "databases")
base_config = get_base_config()
base_config.system_root_directory = os.path.join(system_root_directory, ".cognee_system")
databases_directory_path = os.path.join(base_config.system_root_directory, "databases")
relational_config = get_relational_config()
relational_config.db_path = databases_directory_path
LocalStorage.ensure_directory_exists(databases_directory_path)
graph_config = get_graph_config()
graph_file_name = graph_config.graph_filename
graph_config.graph_file_path = os.path.join(databases_directory_path, graph_file_name)
# For Kuzu v0.11.0+, use single-file database with .kuzu extension
if graph_config.graph_database_provider.lower() == "kuzu":
graph_config.graph_file_path = os.path.join(
databases_directory_path, f"{graph_file_name}.kuzu"
)
else:
graph_config.graph_file_path = os.path.join(databases_directory_path, graph_file_name)
vector_config = get_vectordb_config()
if vector_config.vector_db_provider == "lancedb":
@ -32,7 +39,7 @@ class config:
@staticmethod
def data_root_directory(data_root_directory: str):
base_config = get_base_config()
base_config.data_root_directory = data_root_directory
base_config.data_root_directory = os.path.join(data_root_directory, ".data_storage")
@staticmethod
def monitoring_tool(monitoring_tool: object):


@ -1,22 +1,25 @@
from typing import Union, BinaryIO, List
from cognee.modules.ingestion import classify
from cognee.infrastructure.databases.relational import get_relational_engine
from sqlalchemy import select
from sqlalchemy.sql import delete as sql_delete
from cognee.modules.data.models import Data, DatasetData, Dataset
from cognee.infrastructure.databases.graph import get_graph_engine
from io import BytesIO
import os
import hashlib
from uuid import UUID
from cognee.modules.users.models import User
from cognee.infrastructure.databases.vector import get_vector_engine
from io import BytesIO
from sqlalchemy import select
from typing import Union, BinaryIO, List
from sqlalchemy.sql import delete as sql_delete
from cognee.infrastructure.engine import DataPoint
from cognee.infrastructure.files.storage import get_file_storage
from cognee.infrastructure.databases.graph import get_graph_engine
from cognee.infrastructure.databases.vector import get_vector_engine
from cognee.infrastructure.databases.relational import get_relational_engine
from cognee.modules.ingestion import classify
from cognee.modules.users.models import User
from cognee.shared.logging_utils import get_logger
from cognee.modules.data.models import Data, DatasetData, Dataset
from cognee.modules.graph.utils.convert_node_to_data_point import get_all_subclasses
from cognee.modules.users.methods import get_default_user
from cognee.modules.data.methods import get_authorized_existing_datasets
from cognee.context_global_variables import set_database_global_context_variables
from .exceptions import DocumentNotFoundError, DatasetNotFoundError, DocumentSubgraphNotFoundError
from cognee.shared.logging_utils import get_logger
logger = get_logger()
@ -56,7 +59,14 @@ async def delete(
# Handle different input types
if isinstance(data, str):
if data.startswith("file://") or data.startswith("/"): # It's a file path
with open(data.replace("file://", ""), mode="rb") as file:
full_file_path = data.replace("file://", "")
file_dir = os.path.dirname(full_file_path)
file_path = os.path.basename(full_file_path)
file_storage = get_file_storage(file_dir)
async with file_storage.open(file_path, mode="rb") as file:
classified_data = classify(file)
content_hash = classified_data.get_metadata()["content_hash"]
return await delete_single_document(content_hash, dataset[0].id, mode)


@ -1,14 +1,14 @@
from cognee.modules.data.deletion import prune_system, prune_data
from cognee.modules.data.deletion import prune_system as _prune_system, prune_data as _prune_data
class prune:
@staticmethod
async def prune_data():
await prune_data()
await _prune_data()
@staticmethod
async def prune_system(graph=True, vector=True, metadata=False):
await prune_system(graph, vector, metadata)
await _prune_system(graph, vector, metadata)
if __name__ == "__main__":


@ -8,6 +8,7 @@ from pydantic_settings import BaseSettings, SettingsConfigDict
class BaseConfig(BaseSettings):
data_root_directory: str = get_absolute_path(".data_storage")
system_root_directory: str = get_absolute_path(".cognee_system")
monitoring_tool: object = Observer.LANGFUSE
graphistry_username: Optional[str] = os.getenv("GRAPHISTRY_USERNAME")
graphistry_password: Optional[str] = os.getenv("GRAPHISTRY_PASSWORD")


@ -1,11 +1,12 @@
import os
import pathlib
from contextvars import ContextVar
from typing import Union
from uuid import UUID
from cognee.base_config import get_base_config
from cognee.infrastructure.databases.utils import get_or_create_dataset_database
from cognee.modules.users.methods import get_user
from cognee.infrastructure.files.storage.config import file_storage_config
# Note: ContextVar allows us to use different graph db configurations in Cognee
# for different async tasks, threads and processes
@ -32,6 +33,8 @@ async def set_database_global_context_variables(dataset: Union[str, UUID], user_
"""
base_config = get_base_config()
if not os.getenv("ENABLE_BACKEND_ACCESS_CONTROL", "false").lower() == "true":
return
@ -40,16 +43,16 @@ async def set_database_global_context_variables(dataset: Union[str, UUID], user_
# To ensure permissions are enforced properly all datasets will have their own databases
dataset_database = await get_or_create_dataset_database(dataset, user)
# TODO: Find better location for database files
cognee_directory_path = str(
pathlib.Path(
os.path.join(pathlib.Path(__file__).parent, f".cognee_system/databases/{user.id}")
).resolve()
data_root_directory = os.path.join(
base_config.data_root_directory, str(user.tenant_id or user.id)
)
system_directory_path = os.path.join(
base_config.system_root_directory, "databases", str(user.id)
)
# Set vector and graph database configuration based on dataset database information
vector_config = {
"vector_db_url": os.path.join(cognee_directory_path, dataset_database.vector_database_name),
"vector_db_url": os.path.join(system_directory_path, dataset_database.vector_database_name),
"vector_db_key": "",
"vector_db_provider": "lancedb",
}
@ -57,11 +60,16 @@ async def set_database_global_context_variables(dataset: Union[str, UUID], user_
graph_config = {
"graph_database_provider": "kuzu",
"graph_file_path": os.path.join(
cognee_directory_path, dataset_database.graph_database_name
system_directory_path, f"{dataset_database.graph_database_name}.kuzu"
),
}
storage_config = {
"data_root_directory": data_root_directory,
}
# Use ContextVar so these graph and vector configurations are used
# in the current async context across Cognee
graph_db_config.set(graph_config)
vector_db_config.set(vector_config)
file_storage_config.set(storage_config)


@ -5,7 +5,7 @@ from cognee.eval_framework.answer_generation.answer_generation_executor import (
AnswerGeneratorExecutor,
retriever_options,
)
from cognee.infrastructure.files.storage import LocalStorage
from cognee.infrastructure.files.storage import get_file_storage
from cognee.infrastructure.databases.relational.get_relational_engine import (
get_relational_engine,
get_relational_config,
@ -22,7 +22,7 @@ async def create_and_insert_answers_table(questions_payload):
relational_engine = get_relational_engine()
if relational_engine.engine.dialect.name == "sqlite":
LocalStorage.ensure_directory_exists(relational_config.db_path)
await get_file_storage(relational_config.db_path).ensure_directory_exists()
async with relational_engine.engine.begin() as connection:
if len(AnswersBase.metadata.tables.keys()) > 0:


@ -2,7 +2,7 @@ from cognee.shared.logging_utils import get_logger, ERROR
import json
from typing import List, Optional
from cognee.infrastructure.files.storage import LocalStorage
from cognee.infrastructure.files.storage import get_file_storage
from cognee.eval_framework.corpus_builder.corpus_builder_executor import CorpusBuilderExecutor
from cognee.modules.data.models.questions_base import QuestionsBase
from cognee.modules.data.models.questions_data import Questions
@ -21,7 +21,7 @@ async def create_and_insert_questions_table(questions_payload):
relational_engine = get_relational_engine()
if relational_engine.engine.dialect.name == "sqlite":
LocalStorage.ensure_directory_exists(relational_config.db_path)
await get_file_storage(relational_config.db_path).ensure_directory_exists()
async with relational_engine.engine.begin() as connection:
if len(QuestionsBase.metadata.tables.keys()) > 0:


@ -4,7 +4,7 @@ from typing import List
from cognee.eval_framework.evaluation.evaluation_executor import EvaluationExecutor
from cognee.eval_framework.analysis.metrics_calculator import calculate_metrics_statistics
from cognee.eval_framework.analysis.dashboard_generator import create_dashboard
from cognee.infrastructure.files.storage import LocalStorage
from cognee.infrastructure.files.storage import get_file_storage
from cognee.infrastructure.databases.relational.get_relational_engine import (
get_relational_engine,
get_relational_config,
@ -21,7 +21,7 @@ async def create_and_insert_metrics_table(questions_payload):
relational_engine = get_relational_engine()
if relational_engine.engine.dialect.name == "sqlite":
LocalStorage.ensure_directory_exists(relational_config.db_path)
await get_file_storage(relational_config.db_path).ensure_directory_exists()
async with relational_engine.engine.begin() as connection:
if len(MetricsBase.metadata.tables.keys()) > 0:


@ -5,8 +5,8 @@ from functools import lru_cache
from pydantic_settings import BaseSettings, SettingsConfigDict
import pydantic
from pydantic import Field
from cognee.base_config import get_base_config
from cognee.shared.data_models import KnowledgeGraph
from cognee.root_dir import get_absolute_path
class GraphConfig(BaseSettings):
@ -55,8 +55,14 @@ class GraphConfig(BaseSettings):
values.graph_filename = f"cognee_graph_{provider}"
# Set file path based on graph database provider if no file path is provided
if not values.graph_file_path:
base = os.path.join(get_absolute_path(".cognee_system"), "databases")
values.graph_file_path = os.path.join(base, values.graph_filename)
base_config = get_base_config()
base = os.path.join(base_config.system_root_directory, "databases")
# For Kuzu v0.11.0+, use single-file database with .kuzu extension
if provider == "kuzu":
values.graph_file_path = os.path.join(base, f"{values.graph_filename}.kuzu")
else:
values.graph_file_path = os.path.join(base, values.graph_filename)
return values
def to_dict(self) -> dict:


@ -1,20 +1,20 @@
"""Adapter for Kuzu graph database."""
from cognee.infrastructure.databases.exceptions.exceptions import NodesetFilterNotSupportedError
from cognee.shared.logging_utils import get_logger
import json
import os
import shutil
import json
import asyncio
from typing import Dict, Any, List, Union, Optional, Tuple, Type
from datetime import datetime, timezone
import tempfile
from uuid import UUID
from kuzu import Connection
from kuzu.database import Database
from datetime import datetime, timezone
from contextlib import asynccontextmanager
from concurrent.futures import ThreadPoolExecutor
from typing import Dict, Any, List, Union, Optional, Tuple, Type
import kuzu
from kuzu.database import Database
from kuzu import Connection
from cognee.shared.logging_utils import get_logger
from cognee.infrastructure.utils.run_sync import run_sync
from cognee.infrastructure.files.storage import get_file_storage
from cognee.infrastructure.databases.graph.graph_db_interface import (
GraphDBInterface,
record_graph_changes,
@ -46,9 +46,38 @@ class KuzuAdapter(GraphDBInterface):
def _initialize_connection(self) -> None:
"""Initialize the Kuzu database connection and schema."""
try:
os.makedirs(self.db_path, exist_ok=True)
if "s3://" in self.db_path:
with tempfile.NamedTemporaryFile(mode="w", delete=False) as temp_file:
self.temp_graph_file = temp_file.name
run_sync(self.pull_from_s3())
self.db = Database(
self.temp_graph_file,
buffer_pool_size=256 * 1024 * 1024, # 256MB buffer pool
max_db_size=1024 * 1024 * 1024,
)
else:
# Ensure the parent directory exists before creating the database
db_dir = os.path.dirname(self.db_path)
# If db_path is just a filename, db_dir will be empty string
# In this case, use the directory containing the db_path or current directory
if not db_dir:
# If no directory in path, use the absolute path's directory
abs_path = os.path.abspath(self.db_path)
db_dir = os.path.dirname(abs_path)
file_storage = get_file_storage(db_dir)
run_sync(file_storage.ensure_directory_exists())
self.db = Database(
self.db_path,
buffer_pool_size=256 * 1024 * 1024, # 256MB buffer pool
max_db_size=1024 * 1024 * 1024,
)
self.db = Database(self.db_path)
self.db.init_database()
self.connection = Connection(self.db)
# Create node table with essential fields and timestamp
@ -77,6 +106,22 @@ class KuzuAdapter(GraphDBInterface):
logger.error(f"Failed to initialize Kuzu database: {e}")
raise e
async def push_to_s3(self) -> None:
if os.getenv("STORAGE_BACKEND", "").lower() == "s3" and hasattr(self, "temp_graph_file"):
from cognee.infrastructure.files.storage.S3FileStorage import S3FileStorage
s3_file_storage = S3FileStorage("")
s3_file_storage.s3.put(self.temp_graph_file, self.db_path, recursive=True)
async def pull_from_s3(self) -> None:
from cognee.infrastructure.files.storage.S3FileStorage import S3FileStorage
s3_file_storage = S3FileStorage("")
try:
s3_file_storage.s3.get(self.db_path, self.temp_graph_file, recursive=True)
except FileNotFoundError:
pass
async def query(self, query: str, params: Optional[dict] = None) -> List[Tuple]:
"""
Execute a Kuzu query asynchronously with automatic reconnection.
@ -1385,41 +1430,11 @@ class KuzuAdapter(GraphDBInterface):
"relationship_types": [rel[0] for rel in rel_types],
}
async def get_node_labels_string(self) -> str:
"""
Get all node labels as a string.
This method aggregates all unique node labels from the graph into a single string
representation, which can be helpful for overview and debugging purposes.
Returns:
--------
- str: A string of all distinct node labels, separated by '|'.
"""
labels = await self.query("MATCH (n:Node) RETURN DISTINCT labels(n)")
return "|".join(sorted(set([label[0] for label in labels])))
async def get_relationship_labels_string(self) -> str:
"""
Get all relationship types as a string.
This method aggregates all unique relationship types from the graph into a single string
representation, providing an overview of the relationships defined in the graph.
Returns:
--------
- str: A string of all distinct relationship types, separated by '|'.
"""
types = await self.query("MATCH ()-[r:EDGE]->() RETURN DISTINCT r.relationship_name")
return "|".join(sorted(set([t[0] for t in types])))
async def delete_graph(self) -> None:
"""
Delete all data from the graph directory.
Delete all data from the graph database.
This method deletes all nodes and relationships from the graph directory
This method deletes all nodes and relationships from the graph database.
It raises exceptions for failures occurring during deletion processes.
"""
try:
@ -1432,8 +1447,13 @@ class KuzuAdapter(GraphDBInterface):
if self.db:
self.db.close()
self.db = None
if os.path.exists(self.db_path):
shutil.rmtree(self.db_path)
db_dir = os.path.dirname(self.db_path)
db_name = os.path.basename(self.db_path)
file_storage = get_file_storage(db_dir)
if await file_storage.file_exists(db_name):
await file_storage.remove_all()
logger.info(f"Deleted Kuzu database files at {self.db_path}")
except Exception as e:
@ -1454,9 +1474,15 @@ class KuzuAdapter(GraphDBInterface):
if self.db:
self.db.close()
self.db = None
if os.path.exists(self.db_path):
shutil.rmtree(self.db_path)
db_dir = os.path.dirname(self.db_path)
db_name = os.path.basename(self.db_path)
file_storage = get_file_storage(db_dir)
if await file_storage.file_exists(db_name):
await file_storage.remove_all()
logger.info(f"Deleted Kuzu database files at {self.db_path}")
# Reinitialize the database
self._initialize_connection()
# Verify the database is empty
@ -1472,61 +1498,6 @@ class KuzuAdapter(GraphDBInterface):
logger.error(f"Error during database clearing: {e}")
raise
async def save_graph_to_file(self, file_path: str) -> None:
"""
Export the Kuzu database to a file.
This method exports the entire Kuzu graph database to a specified file path, utilizing
Kuzu's native export command. Ensure the directory exists prior to attempting the
export, and manage related exceptions as they arise.
Parameters:
-----------
- file_path (str): Path where to export the database.
"""
try:
# Ensure directory exists
os.makedirs(os.path.dirname(file_path), exist_ok=True)
# Use Kuzu's native EXPORT command, output is Parquet
export_query = f"EXPORT DATABASE '{file_path}'"
await self.query(export_query)
logger.info(f"Graph exported to {file_path}")
except Exception as e:
logger.error(f"Failed to export graph to file: {e}")
raise
async def load_graph_from_file(self, file_path: str) -> None:
"""
Import a Kuzu database from a file.
This method imports a database from a specified file path, ensuring that the file exists
before attempting to import. Errors during the import process are managed accordingly,
allowing for smooth operation.
Parameters:
-----------
- file_path (str): Path to the exported database file.
"""
try:
if not os.path.exists(file_path):
logger.warning(f"File {file_path} not found")
return
# Use Kuzu's native IMPORT command
import_query = f"IMPORT DATABASE '{file_path}'"
await self.query(import_query)
logger.info(f"Graph imported from {file_path}")
except Exception as e:
logger.error(f"Failed to import graph from file: {e}")
raise
async def get_document_subgraph(self, content_hash: str):
"""
Get all nodes that should be deleted when removing a document.


@ -1,17 +1,17 @@
"""Adapter for NetworkX graph database."""
from datetime import datetime, timezone
import os
import json
import asyncio
import numpy as np
from uuid import UUID
import networkx as nx
from datetime import datetime, timezone
from typing import Dict, Any, List, Union, Type, Tuple
from cognee.infrastructure.databases.exceptions.exceptions import NodesetFilterNotSupportedError
from cognee.infrastructure.files.storage import get_file_storage
from cognee.shared.logging_utils import get_logger
from typing import Dict, Any, List, Union, Type, Tuple
from uuid import UUID
import aiofiles
import aiofiles.os as aiofiles_os
import networkx as nx
from cognee.infrastructure.databases.graph.graph_db_interface import (
GraphDBInterface,
record_graph_changes,
@ -19,7 +19,6 @@ from cognee.infrastructure.databases.graph.graph_db_interface import (
from cognee.infrastructure.engine import DataPoint
from cognee.infrastructure.engine.utils import parse_id
from cognee.modules.storage.utils import JSONEncoder
import numpy as np
logger = get_logger()
@ -551,11 +550,6 @@ class NetworkXAdapter(GraphDBInterface):
"""
self.graph = nx.MultiDiGraph()
# Only create directory if file_path contains a directory
file_dir = os.path.dirname(file_path)
if file_dir and not os.path.exists(file_dir):
os.makedirs(file_dir, exist_ok=True)
await self.save_graph_to_file(file_path)
async def save_graph_to_file(self, file_path: str = None) -> None:
@ -573,9 +567,14 @@ class NetworkXAdapter(GraphDBInterface):
graph_data = nx.readwrite.json_graph.node_link_data(self.graph, edges="links")
async with aiofiles.open(file_path, "w") as file:
json_data = json.dumps(graph_data, cls=JSONEncoder)
await file.write(json_data)
file_dir_path = os.path.dirname(file_path)
file_path = os.path.basename(file_path)
file_storage = get_file_storage(file_dir_path)
json_data = json.dumps(graph_data, cls=JSONEncoder)
await file_storage.store(file_path, json_data, overwrite=True)
async def load_graph_from_file(self, file_path: str = None):
"""
@ -590,9 +589,14 @@ class NetworkXAdapter(GraphDBInterface):
if not file_path:
file_path = self.filename
try:
if os.path.exists(file_path):
async with aiofiles.open(file_path, "r") as file:
graph_data = json.loads(await file.read())
file_dir_path = os.path.dirname(file_path)
file_name = os.path.basename(file_path)
file_storage = get_file_storage(file_dir_path)
if await file_storage.file_exists(file_name):
async with file_storage.open(file_name, "r") as file:
graph_data = json.loads(file.read())
for node in graph_data["nodes"]:
try:
if not isinstance(node["id"], UUID):
@ -674,8 +678,12 @@ class NetworkXAdapter(GraphDBInterface):
self.filename
) # Assuming self.filename is defined elsewhere and holds the default graph file path
try:
if os.path.exists(file_path):
await aiofiles_os.remove(file_path)
file_dir_path = os.path.dirname(file_path)
file_path = os.path.basename(file_path)
file_storage = get_file_storage(file_dir_path)
await file_storage.remove(file_path)
self.graph = None
logger.info("Graph deleted successfully.")


@ -1,17 +1,21 @@
import os
import asyncio
from os import path
import tempfile
from uuid import UUID
from typing import Optional
from typing import AsyncGenerator, List
from contextlib import asynccontextmanager
from sqlalchemy import text, select, MetaData, Table, delete, inspect
from sqlalchemy.orm import joinedload
from sqlalchemy.exc import NoResultFound
from sqlalchemy import NullPool, text, select, MetaData, Table, delete, inspect
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine, async_sessionmaker
from cognee.shared.logging_utils import get_logger
from cognee.infrastructure.databases.exceptions import EntityNotFoundError
from cognee.modules.data.models.Data import Data
from cognee.shared.logging_utils import get_logger
from cognee.infrastructure.utils.run_sync import run_sync
from cognee.infrastructure.databases.exceptions import EntityNotFoundError
from cognee.infrastructure.files.storage import get_file_storage, get_storage_config
from ..ModelBase import Base
@ -29,11 +33,42 @@ class SQLAlchemyAdapter:
self.db_path: str = None
self.db_uri: str = connection_string
self.engine = create_async_engine(connection_string)
if "sqlite" in connection_string:
[prefix, db_path] = connection_string.split("///")
self.db_path = db_path
if "s3://" in self.db_path:
db_dir_path = path.dirname(self.db_path)
file_storage = get_file_storage(db_dir_path)
run_sync(file_storage.ensure_directory_exists())
with tempfile.NamedTemporaryFile(mode="w", delete=False) as temp_file:
self.temp_db_file = temp_file.name
connection_string = prefix + "///" + self.temp_db_file
run_sync(self.pull_from_s3())
self.engine = create_async_engine(
connection_string, poolclass=NullPool if "sqlite" in connection_string else None
)
self.sessionmaker = async_sessionmaker(bind=self.engine, expire_on_commit=False)
if self.engine.dialect.name == "sqlite":
self.db_path = connection_string.split("///")[1]
async def push_to_s3(self) -> None:
if os.getenv("STORAGE_BACKEND", "").lower() == "s3" and hasattr(self, "temp_db_file"):
from cognee.infrastructure.files.storage.S3FileStorage import S3FileStorage
s3_file_storage = S3FileStorage("")
s3_file_storage.s3.put(self.temp_db_file, self.db_path, recursive=True)
async def pull_from_s3(self) -> None:
from cognee.infrastructure.files.storage.S3FileStorage import S3FileStorage
s3_file_storage = S3FileStorage("")
try:
s3_file_storage.s3.get(self.db_path, self.temp_db_file, recursive=True)
except FileNotFoundError:
pass
@asynccontextmanager
async def get_async_session(self) -> AsyncGenerator[AsyncSession, None]:
@ -249,13 +284,18 @@ class SQLAlchemyAdapter:
# Don't delete local file unless this is the only reference to the file in the database
if len(raw_data_location_entities) == 1:
# delete local file only if it's created by cognee
from cognee.base_config import get_base_config
storage_config = get_storage_config()
config = get_base_config()
if (
storage_config["data_root_directory"]
in raw_data_location_entities[0].raw_data_location
):
file_storage = get_file_storage(storage_config["data_root_directory"])
if config.data_root_directory in raw_data_location_entities[0].raw_data_location:
if os.path.exists(raw_data_location_entities[0].raw_data_location):
os.remove(raw_data_location_entities[0].raw_data_location)
file_path = os.path.basename(raw_data_location_entities[0].raw_data_location)
if await file_storage.file_exists(file_path):
await file_storage.remove(file_path)
else:
# Report bug as file should exist
logger.error("Local file which should exist can't be found.")
@ -434,11 +474,13 @@ class SQLAlchemyAdapter:
Create the database if it does not exist, ensuring necessary directories are in place
for SQLite.
"""
if self.engine.dialect.name == "sqlite" and not os.path.exists(self.db_path):
from cognee.infrastructure.files.storage import LocalStorage
if self.engine.dialect.name == "sqlite":
db_directory = path.dirname(self.db_path)
LocalStorage.ensure_directory_exists(db_directory)
db_name = path.basename(self.db_path)
file_storage = get_file_storage(db_directory)
if not await file_storage.file_exists(db_name):
await file_storage.ensure_directory_exists()
async with self.engine.begin() as connection:
if len(Base.metadata.tables.keys()) > 0:
@ -450,13 +492,13 @@ class SQLAlchemyAdapter:
"""
try:
if self.engine.dialect.name == "sqlite":
from cognee.infrastructure.files.storage import LocalStorage
await self.engine.dispose(close=True)
# Wait for the database connections to close and release the file (Windows)
await asyncio.sleep(2)
db_directory = path.dirname(self.db_path)
LocalStorage.ensure_directory_exists(db_directory)
with open(self.db_path, "w") as file:
file.write("")
file_path = path.basename(self.db_path)
file_storage = get_file_storage(db_directory)
await file_storage.remove(file_path)
else:
async with self.engine.begin() as connection:
# Create a MetaData instance to load table information


@ -1,4 +1,5 @@
import asyncio
from os import path
import lancedb
from pydantic import BaseModel
from lancedb.pydantic import LanceModel, Vector
@ -7,7 +8,7 @@ from typing import Generic, List, Optional, TypeVar, Union, get_args, get_origin
from cognee.exceptions import InvalidValueError
from cognee.infrastructure.engine import DataPoint
from cognee.infrastructure.engine.utils import parse_id
from cognee.infrastructure.files.storage import LocalStorage
from cognee.infrastructure.files.storage import get_file_storage
from cognee.modules.storage.utils import copy_model, get_own_properties
from cognee.infrastructure.databases.vector.exceptions import CollectionNotFoundError
@ -310,7 +311,9 @@ class LanceDBAdapter(VectorDBInterface):
await connection.drop_table(collection_name)
if self.url.startswith("/"):
LocalStorage.remove_all(self.url)
db_dir_path = path.dirname(self.url)
db_file_name = path.basename(self.url)
await get_file_storage(db_dir_path).remove_all(db_file_name)
def get_data_point_schema(self, model_type: BaseModel):
related_models_fields = []


@ -1,5 +1,6 @@
from __future__ import annotations
import asyncio
import os
from uuid import UUID
from typing import List, Optional
@ -7,6 +8,7 @@ from cognee.shared.logging_utils import get_logger
from cognee.infrastructure.engine import DataPoint
from cognee.infrastructure.engine.utils import parse_id
from cognee.infrastructure.databases.vector.exceptions import CollectionNotFoundError
from cognee.infrastructure.files.storage import get_file_storage
from ..embeddings.EmbeddingEngine import EmbeddingEngine
from ..models.ScoredResult import ScoredResult
@ -74,6 +76,34 @@ class MilvusAdapter(VectorDBInterface):
"""
from pymilvus import MilvusClient
# Ensure the parent directory exists for local file-based Milvus databases
if self.url and not self.url.startswith(("http://", "https://", "grpc://")):
# This is likely a local file path, ensure the directory exists
db_dir = os.path.dirname(self.url)
if db_dir and not os.path.exists(db_dir):
try:
file_storage = get_file_storage(db_dir)
if hasattr(file_storage, "ensure_directory_exists"):
if asyncio.iscoroutinefunction(file_storage.ensure_directory_exists):
# Run async function synchronously in this sync method
loop = asyncio.get_event_loop()
if loop.is_running():
# If we're already in an async context, we can't use run_sync easily
# Create the directory directly as a fallback
os.makedirs(db_dir, exist_ok=True)
else:
loop.run_until_complete(file_storage.ensure_directory_exists())
else:
file_storage.ensure_directory_exists()
else:
# Fallback to os.makedirs if file_storage doesn't have ensure_directory_exists
os.makedirs(db_dir, exist_ok=True)
except Exception as e:
logger.warning(
f"Could not create directory {db_dir} using file_storage, falling back to os.makedirs: {e}"
)
os.makedirs(db_dir, exist_ok=True)
if self.api_key:
client = MilvusClient(uri=self.url, token=self.api_key)
else:
@ -343,8 +373,6 @@ class MilvusAdapter(VectorDBInterface):
"""
from pymilvus import MilvusException, exceptions
if limit <= 0:
return []
client = self.get_milvus_client()
if query_text is None and query_vector is None:
raise ValueError("One of query_text or query_vector must be provided!")


@ -1,3 +1 @@
from .add_file_to_storage import add_file_to_storage
from .remove_file_from_storage import remove_file_from_storage
from .utils.get_file_metadata import get_file_metadata, FileMetadata


@ -1,22 +0,0 @@
from typing import BinaryIO
from cognee.root_dir import get_absolute_path
from .storage.StorageManager import StorageManager
from .storage.LocalStorage import LocalStorage
async def add_file_to_storage(file_path: str, file: BinaryIO):
"""
Store a file in local storage.
This function initializes a storage manager and uses it to store the provided file at
the specified file path.
Parameters:
-----------
- file_path (str): The path where the file will be stored.
- file (BinaryIO): The file object to be stored, which must be a binary file.
"""
storage_manager = StorageManager(LocalStorage(get_absolute_path("data/files")))
storage_manager.store(file_path, file)


@ -0,0 +1,13 @@
from fastapi import status
class FileContentHashingError(Exception):
"""Raised when the file content cannot be hashed."""
def __init__(
self,
message: str = "Failed to hash content of the file.",
name: str = "FileContentHashingError",
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
):
super().__init__(message, name, status_code)


@ -1,23 +0,0 @@
from cognee.root_dir import get_absolute_path
from .storage.StorageManager import StorageManager
from .storage.LocalStorage import LocalStorage
async def remove_file_from_storage(file_path: str):
"""
Remove a specified file from storage.
This function initializes a storage manager with a local storage instance and calls the
remove method of the storage manager to delete the file identified by the given path.
Ensure that the file exists in the specified storage before calling this function to
avoid
potential errors.
Parameters:
-----------
- file_path (str): The path of the file to remove from storage.
"""
storage_manager = StorageManager(LocalStorage(get_absolute_path("data/files")))
storage_manager.remove(file_path)


@ -0,0 +1,16 @@
from io import BufferedReader
class FileBufferedReader(BufferedReader):
def __init__(self, file_obj, name):
super().__init__(file_obj)
self._file = file_obj
self._name = name
@property
def name(self):
return self._name
def read(self, size: int = -1):
data = self._file.read(size)
return data


@ -0,0 +1,256 @@
import os
import shutil
from urllib.parse import urlparse
from contextlib import contextmanager
from typing import BinaryIO, Optional, Union
from .FileBufferedReader import FileBufferedReader
from .storage import Storage
def get_parsed_path(file_path: str) -> str:
# Check if this is actually a URL (has a scheme like file://, http://, etc.)
if "://" in file_path:
parsed_url = urlparse(file_path)
# Handle file:// URLs specially
if parsed_url.scheme == "file":
# On Windows, urlparse handles drive letters correctly
# Convert the path component to a proper file path
if os.name == "nt": # Windows
# Remove leading slash from Windows paths like /C:/Users/...
# but handle UNC paths like //server/share correctly
parsed_path = parsed_url.path
if parsed_path.startswith("/") and len(parsed_path) > 1 and parsed_path[2] == ":":
# This is a Windows drive path like /C:/Users/...
parsed_path = parsed_path[1:]
elif parsed_path.startswith("///"):
# This is a UNC path like ///server/share, convert to //server/share
parsed_path = parsed_path[1:]
else: # Unix-like systems
parsed_path = parsed_url.path
else:
# For non-file URLs, use the path as-is
parsed_path = parsed_url.path
if (
os.name == "nt"
and parsed_path.startswith("/")
and len(parsed_path) > 1
and parsed_path[2] == ":"
):
parsed_path = parsed_path[1:]
# Normalize path separators to ensure consistency
return os.path.normpath(parsed_path)
else:
# This is a regular file path, not a URL - normalize separators
return os.path.normpath(file_path)
class LocalFileStorage(Storage):
"""
Manage local file storage operations such as storing, retrieving, and managing files on
the filesystem.
"""
storage_path: Optional[str] = None
def __init__(self, storage_path: str):
self.storage_path = storage_path
def store(self, file_path: str, data: Union[BinaryIO, str], overwrite: bool = False) -> str:
"""
Store data into a specified file path. The data can be either a string or a binary
stream.
This method ensures that the storage directory exists before attempting to write the
data. If the provided data is a stream, it reads from the stream and writes to the file;
otherwise, it directly writes the provided data.
Parameters:
-----------
- file_path (str): The relative path of the file where the data will be stored.
- data (Union[BinaryIO, str]): The data to be stored, which can be a string or a
binary stream.
- overwrite (bool): If True, overwrite the existing file.
"""
parsed_storage_path = get_parsed_path(self.storage_path)
full_file_path = os.path.join(parsed_storage_path, file_path)
file_dir_path = os.path.dirname(full_file_path)
self.ensure_directory_exists(file_dir_path)
if overwrite or not os.path.exists(full_file_path):
with open(
full_file_path,
mode="w" if isinstance(data, str) else "wb",
encoding="utf-8" if isinstance(data, str) else None,
) as file:
if hasattr(data, "read"):
data.seek(0)
file.write(data.read())
else:
file.write(data)
file.close()
return "file://" + full_file_path
@contextmanager
def open(self, file_path: str, mode: str = "rb", *args, **kwargs):
"""
Open a file at the specified path and yield it wrapped in a FileBufferedReader. The function expects the
file to exist; if it does not, a FileNotFoundError will be raised.
Parameters:
-----------
- file_path (str): The relative path of the file to retrieve data from.
- mode (str): The mode to open the file, with "rb" as the default for reading binary
files. (default "rb")
Returns:
--------
A context manager yielding the opened file as a FileBufferedReader.
"""
parsed_storage_path = get_parsed_path(self.storage_path)
full_file_path = os.path.join(parsed_storage_path, file_path)
# Add debug information for Windows path issues
if not os.path.exists(full_file_path):
# Try to provide helpful debug information
if os.path.exists(parsed_storage_path):
available_files = []
try:
available_files = os.listdir(parsed_storage_path)
except (OSError, PermissionError):
available_files = ["<unable to list directory>"]
raise FileNotFoundError(
f"File not found: '{full_file_path}'\n"
f"Storage path: '{parsed_storage_path}'\n"
f"Requested file: '{file_path}'\n"
f"Storage path exists: {os.path.exists(parsed_storage_path)}\n"
f"Available files in storage: {available_files[:10]}..." # Limit to first 10 files
)
else:
raise FileNotFoundError(
f"Storage directory does not exist: '{parsed_storage_path}'\n"
f"Original storage path: '{self.storage_path}'\n"
f"Requested file: '{file_path}'"
)
with open(full_file_path, mode=mode, *args, **kwargs) as file:
file = FileBufferedReader(file, name="file://" + full_file_path)
try:
yield file
finally:
file.close()
def file_exists(self, file_path: str):
"""
Check if a specified file exists in the storage.
Parameters:
-----------
- file_path (str): The path of the file to check for existence.
Returns:
--------
- bool: True if the file exists, otherwise False.
"""
parsed_storage_path = get_parsed_path(self.storage_path)
return os.path.exists(os.path.join(parsed_storage_path, file_path))
def ensure_directory_exists(self, directory_path: str = None):
"""
Ensure that the specified directory exists, creating it if necessary.
If the directory already exists, no action is taken.
Parameters:
-----------
- directory_path (str): The path of the directory to check or create.
"""
if directory_path is None:
directory_path = get_parsed_path(self.storage_path)
elif not directory_path or directory_path.strip() == "":
# Handle empty string case - use current directory or storage path
directory_path = get_parsed_path(self.storage_path) or "."
if not os.path.exists(directory_path):
os.makedirs(directory_path, exist_ok=True)
def copy_file(self, source_file_path: str, destination_file_path: str):
"""
Copy a file from a source path to a destination path.
Files need to be in the same storage.
Parameters:
-----------
- source_file_path (str): The path of the file to be copied.
- destination_file_path (str): The path where the file will be copied to.
Returns:
--------
- str: The path to the copied file.
"""
parsed_storage_path = get_parsed_path(self.storage_path)
return shutil.copy2(
os.path.join(parsed_storage_path, source_file_path),
os.path.join(parsed_storage_path, destination_file_path),
)
def remove(self, file_path: str):
"""
Remove the specified file from the storage if it exists.
Parameters:
-----------
- file_path (str): The path of the file to be removed.
"""
parsed_storage_path = get_parsed_path(self.storage_path)
full_file_path = os.path.join(parsed_storage_path, file_path)
if os.path.exists(full_file_path):
os.remove(full_file_path)
def remove_all(self, tree_path: str = None):
"""
Remove an entire directory tree at the specified path, including all files and
subdirectories.
If the directory does not exist, no action is taken and no exception is raised.
Parameters:
-----------
- tree_path (str): The root path of the directory tree to be removed.
"""
parsed_storage_path = get_parsed_path(self.storage_path)
if tree_path is None:
tree_path = parsed_storage_path
else:
tree_path = os.path.join(parsed_storage_path, tree_path)
try:
return shutil.rmtree(tree_path)
except FileNotFoundError:
pass


@ -1,153 +0,0 @@
import os
import shutil
from typing import BinaryIO, Union
from .StorageManager import Storage
class LocalStorage(Storage):
"""
Manage local file storage operations such as storing, retrieving, and managing files on
the filesystem.
"""
storage_path: str = None
def __init__(self, storage_path: str):
self.storage_path = storage_path
def store(self, file_path: str, data: Union[BinaryIO, str]):
"""
Store data into a specified file path. The data can be either a string or a binary
stream.
This method ensures that the storage directory exists before attempting to write the
data. If the provided data is a stream, it reads from the stream and writes to the file;
otherwise, it directly writes the provided data.
Parameters:
-----------
- file_path (str): The relative path of the file where the data will be stored.
- data (Union[BinaryIO, str]): The data to be stored, which can be a string or a
binary stream.
"""
full_file_path = self.storage_path + "/" + file_path
LocalStorage.ensure_directory_exists(self.storage_path)
with open(
full_file_path,
mode="w" if isinstance(data, str) else "wb",
encoding="utf-8" if isinstance(data, str) else None,
) as f:
if hasattr(data, "read"):
data.seek(0)
f.write(data.read())
else:
f.write(data)
def retrieve(self, file_path: str, mode: str = "rb"):
"""
Retrieve data from a specified file path, returning the content as bytes.
This method opens the file in read mode and reads its content. The function expects the
file to exist; if it does not, a FileNotFoundError will be raised.
Parameters:
-----------
- file_path (str): The relative path of the file to retrieve data from.
- mode (str): The mode to open the file, with 'rb' as the default for reading binary
files. (default 'rb')
Returns:
--------
The content of the retrieved file as bytes.
"""
full_file_path = self.storage_path + "/" + file_path
with open(full_file_path, mode=mode) as f:
f.seek(0)
return f.read()
@staticmethod
def file_exists(file_path: str):
"""
Check if a specified file exists in the filesystem.
Parameters:
-----------
- file_path (str): The path of the file to check for existence.
Returns:
--------
- bool: True if the file exists, otherwise False.
"""
return os.path.exists(file_path)
@staticmethod
def ensure_directory_exists(file_path: str):
"""
Ensure that the specified directory exists, creating it if necessary.
If the directory already exists, no action is taken.
Parameters:
-----------
- file_path (str): The path of the directory to check or create.
"""
if not os.path.exists(file_path):
os.makedirs(file_path, exist_ok=True)
@staticmethod
def remove(file_path: str):
"""
Remove the specified file from the filesystem if it exists.
Parameters:
-----------
- file_path (str): The path of the file to be removed.
"""
if os.path.exists(file_path):
os.remove(file_path)
@staticmethod
def copy_file(source_file_path: str, destination_file_path: str):
"""
Copy a file from a source path to a destination path.
Parameters:
-----------
- source_file_path (str): The path of the file to be copied.
- destination_file_path (str): The path where the file will be copied to.
Returns:
--------
- str: The path to the copied file.
"""
return shutil.copy2(source_file_path, destination_file_path)
@staticmethod
def remove_all(tree_path: str):
"""
Remove an entire directory tree at the specified path, including all files and
subdirectories.
If the directory does not exist, no action is taken and no exception is raised.
Parameters:
-----------
- tree_path (str): The root path of the directory tree to be removed.
"""
try:
shutil.rmtree(tree_path)
except FileNotFoundError:
pass


@ -0,0 +1,215 @@
import os
import s3fs
from typing import BinaryIO, Union
from contextlib import asynccontextmanager
from cognee.api.v1.add.config import get_s3_config
from cognee.infrastructure.utils.run_async import run_async
from cognee.infrastructure.files.storage.FileBufferedReader import FileBufferedReader
from .storage import Storage
class S3FileStorage(Storage):
"""
Manage S3 file storage operations such as storing, retrieving, and managing files in an
S3 bucket.
"""
storage_path: str
s3: s3fs.S3FileSystem
def __init__(self, storage_path: str):
self.storage_path = storage_path
s3_config = get_s3_config()
if s3_config.aws_access_key_id is not None and s3_config.aws_secret_access_key is not None:
self.s3 = s3fs.S3FileSystem(
key=s3_config.aws_access_key_id,
secret=s3_config.aws_secret_access_key,
anon=False,
endpoint_url="https://s3-eu-west-1.amazonaws.com",
)
else:
raise ValueError("S3 credentials are not set in the configuration.")
async def store(
self, file_path: str, data: Union[BinaryIO, str], overwrite: bool = False
) -> str:
"""
Store data into a specified file path. The data can be either a string or a binary
stream.
This method ensures that the storage directory exists before attempting to write the
data. If the provided data is a stream, it reads from the stream and writes to the file;
otherwise, it directly writes the provided data.
Parameters:
-----------
- file_path (str): The relative path of the file where the data will be stored.
- data (Union[BinaryIO, str]): The data to be stored, which can be a string or a
binary stream.
- overwrite (bool): If True, overwrite the existing file.
"""
full_file_path = os.path.join(self.storage_path.replace("s3://", ""), file_path)
file_dir_path = os.path.dirname(full_file_path)
await self.ensure_directory_exists(file_dir_path)
if overwrite or not await self.file_exists(file_path):
def save_data_to_file():
with self.s3.open(
full_file_path,
mode="w" if isinstance(data, str) else "wb",
encoding="utf-8" if isinstance(data, str) else None,
) as file:
if hasattr(data, "read"):
data.seek(0)
file.write(data.read())
else:
file.write(data)
file.close()
await run_async(save_data_to_file)
return "s3://" + full_file_path
@asynccontextmanager
async def open(self, file_path: str, mode: str = "r"):
"""
Open a file at the specified path and yield it wrapped in a FileBufferedReader. The function expects the
file to exist; if it does not, a FileNotFoundError will be raised.
Parameters:
-----------
- file_path (str): The relative path of the file to retrieve data from.
- mode (str): The mode to open the file. (default "r")
Returns:
--------
A context manager yielding the opened file as a FileBufferedReader.
"""
full_file_path = os.path.join(self.storage_path.replace("s3://", ""), file_path)
def get_file():
return self.s3.open(full_file_path, mode=mode)
file = await run_async(get_file)
file = FileBufferedReader(file, name="s3://" + full_file_path)
try:
yield file
finally:
file.close()
async def file_exists(self, file_path: str):
"""
Check if a specified file exists in the filesystem.
Parameters:
-----------
- file_path (str): The path of the file to check for existence.
Returns:
--------
- bool: True if the file exists, otherwise False.
"""
return await run_async(
self.s3.exists, os.path.join(self.storage_path.replace("s3://", ""), file_path)
)
async def ensure_directory_exists(self, directory_path: str = None):
"""
Ensure that the specified directory exists, creating it if necessary.
If the directory already exists, no action is taken.
Parameters:
-----------
- directory_path (str): The path of the directory to check or create.
"""
if directory_path is None:
directory_path = self.storage_path.replace("s3://", "")
def ensure_directory():
if not self.s3.exists(directory_path):
self.s3.makedirs(directory_path, exist_ok=True)
await run_async(ensure_directory)
async def copy_file(self, source_file_path: str, destination_file_path: str):
"""
Copy a file from a source path to a destination path.
Parameters:
-----------
- source_file_path (str): The path of the file to be copied.
- destination_file_path (str): The path where the file will be copied to.
Returns:
--------
- str: The path to the copied file.
"""
def copy():
return self.s3.copy(
os.path.join(self.storage_path.replace("s3://", ""), source_file_path),
os.path.join(self.storage_path.replace("s3://", ""), destination_file_path),
recursive=True,
)
return await run_async(copy)
async def remove(self, file_path: str):
"""
Remove the specified file from the filesystem if it exists.
Parameters:
-----------
- file_path (str): The path of the file to be removed.
"""
full_file_path = os.path.join(self.storage_path.replace("s3://", ""), file_path)
def remove_file():
if self.s3.exists(full_file_path):
self.s3.rm_file(full_file_path)
await run_async(remove_file)
async def remove_all(self, tree_path: str):
"""
Remove an entire directory tree at the specified path, including all files and
subdirectories.
If the directory does not exist, no action is taken and no exception is raised.
Parameters:
-----------
- tree_path (str): The root path of the directory tree to be removed.
"""
if tree_path is None:
tree_path = self.storage_path.replace("s3://", "")
else:
tree_path = os.path.join(self.storage_path.replace("s3://", ""), tree_path)
try:
await run_async(self.s3.rm, tree_path, recursive=True)
except FileNotFoundError:
pass
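A rough sketch of using `S3FileStorage` directly, as the Kuzu and SQLAlchemy adapters do for their push/pull helpers. The bucket and prefix below are hypothetical, and the constructor raises unless `aws_access_key_id` and `aws_secret_access_key` are set in the S3 config:

```python
from cognee.infrastructure.files.storage.S3FileStorage import S3FileStorage


async def s3_demo():
    # Hypothetical bucket/prefix; storage_path is used as the S3 prefix for all operations.
    storage = S3FileStorage("s3://my-cognee-bucket/data")
    await storage.ensure_directory_exists()
    await storage.store("example.txt", "stored via s3fs", overwrite=True)
    assert await storage.file_exists("example.txt")
    await storage.remove("example.txt")
```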


@ -1,45 +1,8 @@
from typing import Protocol, BinaryIO
import inspect
from typing import BinaryIO
from contextlib import asynccontextmanager
class Storage(Protocol):
"""
Abstract interface for storage operations.
"""
def store(self, file_path: str, data: bytes):
"""
Store data at the specified file path.
Parameters:
-----------
- file_path (str): The path where the data will be stored.
- data (bytes): The binary data to be stored.
"""
pass
def retrieve(self, file_path: str):
"""
Retrieve data from the specified file path.
Parameters:
-----------
- file_path (str): The path from where the data will be retrieved.
"""
pass
@staticmethod
def remove(file_path: str):
"""
Remove the storage at the specified file path.
Parameters:
-----------
- file_path (str): The path of the file to be removed.
"""
pass
from .storage import Storage
class StorageManager:
@ -48,8 +11,9 @@ class StorageManager:
Public methods include:
- store: Store data in the specified path.
- retrieve: Retrieve data from the specified path.
- open: Open a file from the specified path.
- remove: Remove the file at the specified path.
- remove_all: Remove all files under the directory tree.
"""
storage: Storage = None
@ -57,7 +21,26 @@ class StorageManager:
def __init__(self, storage: Storage):
self.storage = storage
def store(self, file_path: str, data: BinaryIO):
async def file_exists(self, file_path: str):
"""
Check if a specified file exists in the storage.
Parameters:
-----------
- file_path (str): The path of the file to check for existence.
Returns:
--------
- bool: True if the file exists, otherwise False.
"""
if inspect.iscoroutinefunction(self.storage.file_exists):
return await self.storage.file_exists(file_path)
else:
return self.storage.file_exists(file_path)
async def store(self, file_path: str, data: BinaryIO, overwrite: bool = False) -> str:
"""
Store data at the specified file path.
@ -66,16 +49,20 @@ class StorageManager:
- file_path (str): The path where the data should be stored.
- data (BinaryIO): The data in a binary format that needs to be stored.
- overwrite (bool): If True, overwrite the existing file.
Returns:
--------
Returns the outcome of the store operation, as defined by the storage
implementation.
Returns the full path to the file.
"""
return self.storage.store(file_path, data)
if inspect.iscoroutinefunction(self.storage.store):
return await self.storage.store(file_path, data, overwrite)
else:
return self.storage.store(file_path, data, overwrite)
def retrieve(self, file_path: str):
@asynccontextmanager
async def open(self, file_path: str, encoding: str = None, *args, **kwargs):
"""
Retrieve data from the specified file path.
@ -89,9 +76,34 @@ class StorageManager:
Returns the retrieved data, as defined by the storage implementation.
"""
return self.storage.retrieve(file_path)
# Check the actual storage type by class name to determine if open() is async or sync
def remove(self, file_path: str):
if self.storage.__class__.__name__ == "S3FileStorage":
# S3FileStorage.open() is async
async with self.storage.open(file_path, *args, **kwargs) as file:
yield file
else:
# LocalFileStorage.open() is sync
with self.storage.open(file_path, *args, **kwargs) as file:
yield file
async def ensure_directory_exists(self, directory_path: str = None):
"""
Ensure that the specified directory exists, creating it if necessary.
If the directory already exists, no action is taken.
Parameters:
-----------
- directory_path (str): The path of the directory to check or create.
"""
if inspect.iscoroutinefunction(self.storage.ensure_directory_exists):
return await self.storage.ensure_directory_exists(directory_path)
else:
return self.storage.ensure_directory_exists(directory_path)
async def remove(self, file_path: str):
"""
Remove the file at the specified path.
@ -106,4 +118,24 @@ class StorageManager:
Returns the outcome of the remove operation, as defined by the storage
implementation.
"""
return self.storage.remove(file_path)
if inspect.iscoroutinefunction(self.storage.remove):
return await self.storage.remove(file_path)
else:
return self.storage.remove(file_path)
async def remove_all(self, tree_path: str = None):
"""
Remove an entire directory tree at the specified path, including all files and
subdirectories.
If the directory does not exist, no action is taken and no exception is raised.
Parameters:
-----------
- tree_path (str): The root path of the directory tree to be removed.
"""
if inspect.iscoroutinefunction(self.storage.remove_all):
return await self.storage.remove_all(tree_path)
else:
return self.storage.remove_all(tree_path)
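Because StorageManager probes each backend method with inspect.iscoroutinefunction, callers can always await the manager, whether the configured backend is the synchronous LocalFileStorage or the asynchronous S3FileStorage. A minimal usage sketch, assuming LocalFileStorage.store accepts text content the way the visualization helper later in this diff uses it (paths and file names are illustrative):
import asyncio
from cognee.infrastructure.files.storage import StorageManager
from cognee.infrastructure.files.storage.LocalFileStorage import LocalFileStorage

async def demo():
    # Sync backend, but the manager still exposes awaitable methods.
    manager = StorageManager(LocalFileStorage("/tmp/cognee_demo"))
    full_path = await manager.store("example.txt", "hello cognee", overwrite=True)
    assert await manager.file_exists("example.txt")
    await manager.remove("example.txt")

asyncio.run(demo())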

View file

@ -1 +1,3 @@
from .LocalStorage import LocalStorage
from .StorageManager import StorageManager
from .get_file_storage import get_file_storage
from .get_storage_config import get_storage_config

View file

@ -0,0 +1,4 @@
from contextvars import ContextVar
file_storage_config = ContextVar("file_storage_config", default=None)

View file

@ -0,0 +1,23 @@
import os
from cognee.base_config import get_base_config
from .StorageManager import StorageManager
def get_file_storage(storage_path: str) -> StorageManager:
base_config = get_base_config()
# Use S3FileStorage if the storage_path is an S3 URL or if configured for S3
if storage_path.startswith("s3://") or (
os.getenv("STORAGE_BACKEND") == "s3"
and "s3://" in base_config.system_root_directory
and "s3://" in base_config.data_root_directory
):
from cognee.infrastructure.files.storage.S3FileStorage import S3FileStorage
return StorageManager(S3FileStorage(storage_path))
else:
from cognee.infrastructure.files.storage.LocalFileStorage import LocalFileStorage
return StorageManager(LocalFileStorage(storage_path))
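get_file_storage chooses the backend from the path itself: s3:// locations (or an environment explicitly configured for S3) get S3FileStorage, everything else falls back to LocalFileStorage, and both come wrapped in a StorageManager. A hedged sketch of relying on that dispatch; the directory values are illustrative:
from cognee.infrastructure.files.storage import get_file_storage

async def write_report(directory: str, name: str, content: str) -> str:
    # Behaves the same whether directory is "/var/data/reports" or "s3://my-bucket/reports".
    storage = get_file_storage(directory)
    return await storage.store(name, content, overwrite=True)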

View file

@ -0,0 +1,18 @@
from cognee.base_config import get_base_config
from .config import file_storage_config
def get_global_storage_config():
base_config = get_base_config()
return {
"data_root_directory": base_config.data_root_directory,
}
def get_storage_config():
context_config = file_storage_config.get()
if context_config:
return context_config
return get_global_storage_config()
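Since file_storage_config is a ContextVar, a single task can be pointed at a different data root without mutating the global base config; get_storage_config simply falls back to the global value when no override is set for the current context. A sketch, assuming an override dict carries the same "data_root_directory" key:
from cognee.infrastructure.files.storage import get_storage_config
from cognee.infrastructure.files.storage.config import file_storage_config

token = file_storage_config.set({"data_root_directory": "s3://my-bucket/tenant-42"})
try:
    assert get_storage_config()["data_root_directory"] == "s3://my-bucket/tenant-42"
finally:
    # Restore the previous (global) configuration for this context.
    file_storage_config.reset(token)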

View file

@ -0,0 +1,105 @@
from typing import BinaryIO, Protocol, Union
class Storage(Protocol):
storage_path: str
"""
Abstract interface for storage operations.
"""
def file_exists(self, file_path: str) -> bool:
"""
Check if a specified file exists in the storage.
Parameters:
-----------
- file_path (str): The path of the file to check for existence.
Returns:
--------
- bool: True if the file exists, otherwise False.
"""
pass
def store(self, file_path: str, data: Union[BinaryIO, str], overwrite: bool):
"""
Store data at the specified file path.
Parameters:
-----------
- file_path (str): The path where the data will be stored.
- data (Union[BinaryIO, str]): The data to be stored, either a binary stream or text.
- overwrite (bool): If True, overwrite the existing file.
"""
pass
def open(self, file_path: str, mode: str = "r"):
"""
Open a file at the specified file path.
Parameters:
-----------
- file_path (str): The path of the file to open.
- mode (str): The mode in which to open the file; "r" (text read) is the default.
"""
pass
def copy_file(self, source_file_path: str, destination_file_path: str):
"""
Copy a file from a source path to a destination path.
Parameters:
-----------
- source_file_path (str): The path of the file to be copied.
- destination_file_path (str): The path where the file will be copied to.
Returns:
--------
- str: The path to the copied file.
"""
pass
def ensure_directory_exists(self, directory_path: str = None):
"""
Ensure that the specified directory exists, creating it if necessary.
If the directory already exists, no action is taken.
Parameters:
-----------
- directory_path (str): The path of the directory to check or create.
"""
pass
def remove(self, file_path: str):
"""
Remove the storage at the specified file path.
Parameters:
-----------
- file_path (str): The path of the file to be removed.
"""
pass
def remove_all(self, root_path: str = None):
"""
Remove an entire directory tree at the specified path, including all files and
subdirectories.
If the directory does not exist, no action is taken and no exception is raised.
Parameters:
-----------
- root_path (str): The root path of the directory tree to be removed.
"""
pass
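Because Storage is a typing.Protocol, backends never need to inherit from it; any class whose methods match structurally will type-check, and StorageManager's inspect-based dispatch handles sync and async variants alike. A toy in-memory backend for tests, not part of this PR (open, copy_file and ensure_directory_exists are omitted for brevity):
from typing import BinaryIO, Union

class InMemoryStorage:
    """Illustrative Storage-shaped backend that keeps files in a dict."""

    def __init__(self, storage_path: str):
        self.storage_path = storage_path
        self._files: dict[str, bytes] = {}

    def file_exists(self, file_path: str) -> bool:
        return file_path in self._files

    def store(self, file_path: str, data: Union[BinaryIO, str], overwrite: bool = False) -> str:
        if overwrite or file_path not in self._files:
            payload = data.encode("utf-8") if isinstance(data, str) else data.read()
            self._files[file_path] = payload
        return f"{self.storage_path}/{file_path}"

    def remove(self, file_path: str):
        self._files.pop(file_path, None)

    def remove_all(self, root_path: str = None):
        self._files.clear()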

View file

@ -0,0 +1,40 @@
import hashlib
import os
from os import path
from typing import BinaryIO, Union
from ..storage import get_file_storage
from ..exceptions import FileContentHashingError
async def get_file_content_hash(file_obj: Union[str, BinaryIO]) -> str:
h = hashlib.md5()
try:
if isinstance(file_obj, str):
# Normalize path separators to handle mixed separators on Windows
normalized_path = os.path.normpath(file_obj)
file_dir_path = path.dirname(normalized_path)
file_path = path.basename(normalized_path)
file_storage = get_file_storage(file_dir_path)
async with file_storage.open(file_path, "rb") as file:
while True:
# Reading is buffered, so we can read smaller chunks.
chunk = file.read(h.block_size)
if not chunk:
break
h.update(chunk)
else:
while True:
# Reading is buffered, so we can read smaller chunks.
chunk = file_obj.read(h.block_size)
if not chunk:
break
h.update(chunk)
return h.hexdigest()
except IOError as e:
raise FileContentHashingError(message=f"Failed to hash data from {file_obj}: {e}")
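The helper accepts either a path string, which is routed through the storage layer so s3:// locations hash the same way local files do, or an already open binary file object. A short usage sketch with illustrative paths:
import asyncio
from cognee.infrastructure.files.utils.get_file_content_hash import get_file_content_hash

async def main():
    # Path form: the directory portion selects the backend (local or S3).
    print(await get_file_content_hash("/tmp/example.txt"))
    # File-object form: hash an already open binary stream.
    with open("/tmp/example.txt", "rb") as handle:
        print(await get_file_content_hash(handle))

asyncio.run(main())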

View file

@ -3,7 +3,7 @@ import os.path
from typing import BinaryIO, TypedDict
from cognee.shared.logging_utils import get_logger
from cognee.shared.utils import get_file_content_hash
from cognee.infrastructure.files.utils.get_file_content_hash import get_file_content_hash
from .guess_file_type import guess_file_type
logger = get_logger("FileMetadata")
@ -26,7 +26,7 @@ class FileMetadata(TypedDict):
file_size: int
def get_file_metadata(file: BinaryIO) -> FileMetadata:
async def get_file_metadata(file: BinaryIO) -> FileMetadata:
"""
Retrieve metadata from a file object.
@ -47,7 +47,7 @@ def get_file_metadata(file: BinaryIO) -> FileMetadata:
"""
try:
file.seek(0)
content_hash = get_file_content_hash(file)
content_hash = await get_file_content_hash(file)
file.seek(0)
except io.UnsupportedOperation as error:
logger.error(f"Error retrieving content hash for file: {file.name} \n{str(error)}\n\n")

View file

@ -0,0 +1,65 @@
import os
from os import path
from urllib.parse import urlparse
from contextlib import asynccontextmanager
from cognee.infrastructure.files.storage import get_file_storage
@asynccontextmanager
async def open_data_file(file_path: str, mode: str = "rb", encoding: str = None, **kwargs):
# Check if this is a file URI BEFORE normalizing (which corrupts URIs)
if file_path.startswith("file://"):
# Normalize the file URI for Windows - replace backslashes with forward slashes
normalized_file_uri = file_path.replace("\\", "/")
parsed_url = urlparse(normalized_file_uri)
# Convert URI path to file system path
if os.name == "nt": # Windows
# Handle Windows drive letters correctly
fs_path = parsed_url.path
if fs_path.startswith("/") and len(fs_path) > 1 and fs_path[2] == ":":
fs_path = fs_path[1:] # Remove leading slash for Windows drive paths
else: # Unix-like systems
fs_path = parsed_url.path
# Now split the actual filesystem path
actual_fs_path = os.path.normpath(fs_path)
file_dir_path = path.dirname(actual_fs_path)
file_name = path.basename(actual_fs_path)
elif file_path.startswith("s3://"):
# Handle S3 URLs without normalization (which corrupts them)
parsed_url = urlparse(file_path)
# For S3, reconstruct the directory path and filename
s3_path = parsed_url.path.lstrip("/") # Remove leading slash
if "/" in s3_path:
s3_dir = "/".join(s3_path.split("/")[:-1])
s3_filename = s3_path.split("/")[-1]
else:
s3_dir = ""
s3_filename = s3_path
# Extract filesystem path from S3 URL structure
file_dir_path = (
f"s3://{parsed_url.netloc}/{s3_dir}" if s3_dir else f"s3://{parsed_url.netloc}"
)
file_name = s3_filename
else:
# Regular file path - normalize separators
normalized_path = os.path.normpath(file_path)
file_dir_path = path.dirname(normalized_path)
file_name = path.basename(normalized_path)
# Validate that we have a proper filename
if not file_name or file_name == "." or file_name == "..":
raise ValueError(f"Invalid filename extracted: '{file_name}' from path: '{file_path}'")
file_storage = get_file_storage(file_dir_path)
async with file_storage.open(file_name, mode=mode, encoding=encoding, **kwargs) as file:
yield file
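open_data_file normalizes the three location formats the pipeline produces, file:// URIs, s3:// URLs and plain paths, into a directory plus filename, then defers to whichever storage backend get_file_storage selects for that directory. A sketch of a call site (the locations are illustrative):
from cognee.infrastructure.files.utils.open_data_file import open_data_file

async def read_first_bytes(location: str) -> bytes:
    # location may be "file:///C:/data/a.pdf", "s3://bucket/data/a.pdf" or "/var/data/a.pdf".
    async with open_data_file(location, mode="rb") as file:
        return file.read(16)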

View file

@ -1,16 +1,15 @@
from typing import Type
from pydantic import BaseModel
import base64
import instructor
from typing import Type
from openai import OpenAI
from pydantic import BaseModel
from cognee.infrastructure.llm.llm_interface import LLMInterface
from cognee.infrastructure.llm.config import get_llm_config
from cognee.infrastructure.llm.rate_limiter import (
rate_limit_async,
rate_limit_sync,
sleep_and_retry_async,
)
from openai import OpenAI
import base64
import os
from cognee.infrastructure.files.utils.open_data_file import open_data_file
class OllamaAPIAdapter(LLMInterface):
@ -87,8 +86,8 @@ class OllamaAPIAdapter(LLMInterface):
return response
@rate_limit_sync
def create_transcript(self, input_file: str) -> str:
@rate_limit_async
async def create_transcript(self, input_file: str) -> str:
"""
Generate an audio transcript from a user query.
@ -107,10 +106,7 @@ class OllamaAPIAdapter(LLMInterface):
- str: The transcription of the audio as a string.
"""
if not os.path.isfile(input_file):
raise FileNotFoundError(f"The file {input_file} does not exist.")
with open(input_file, "rb") as audio_file:
async with open_data_file(input_file, mode="rb") as audio_file:
transcription = self.aclient.audio.transcriptions.create(
model="whisper-1", # Ensure the correct model for transcription
file=audio_file,
@ -123,8 +119,8 @@ class OllamaAPIAdapter(LLMInterface):
return transcription.text
@rate_limit_sync
def transcribe_image(self, input_file: str) -> str:
@rate_limit_async
async def transcribe_image(self, input_file: str) -> str:
"""
Transcribe content from an image using base64 encoding.
@ -144,10 +140,7 @@ class OllamaAPIAdapter(LLMInterface):
- str: The transcription of the image's content as a string.
"""
if not os.path.isfile(input_file):
raise FileNotFoundError(f"The file {input_file} does not exist.")
with open(input_file, "rb") as image_file:
async with open_data_file(input_file, mode="rb") as image_file:
encoded_image = base64.b64encode(image_file.read()).decode("utf-8")
response = self.aclient.chat.completions.create(

View file

@ -1,4 +1,3 @@
import os
import base64
import litellm
import instructor
@ -12,7 +11,7 @@ from cognee.exceptions import InvalidValueError
from cognee.infrastructure.llm.prompts import read_query_prompt
from cognee.infrastructure.llm.llm_interface import LLMInterface
from cognee.infrastructure.llm.exceptions import ContentPolicyFilterError
from cognee.modules.data.processing.document_types.open_data_file import open_data_file
from cognee.infrastructure.files.utils.open_data_file import open_data_file
from cognee.infrastructure.llm.rate_limiter import (
rate_limit_async,
rate_limit_sync,
@ -222,8 +221,8 @@ class OpenAIAdapter(LLMInterface):
max_retries=self.MAX_RETRIES,
)
@rate_limit_sync
def create_transcript(self, input):
@rate_limit_async
async def create_transcript(self, input):
"""
Generate an audio transcript from a user query.
@ -242,10 +241,7 @@ class OpenAIAdapter(LLMInterface):
The generated transcription of the audio file.
"""
if not input.startswith("s3://") and not os.path.isfile(input):
raise FileNotFoundError(f"The file {input} does not exist.")
with open_data_file(input, mode="rb") as audio_file:
async with open_data_file(input, mode="rb") as audio_file:
transcription = litellm.transcription(
model=self.transcription_model,
file=audio_file,
@ -257,8 +253,8 @@ class OpenAIAdapter(LLMInterface):
return transcription
@rate_limit_sync
def transcribe_image(self, input) -> BaseModel:
@rate_limit_async
async def transcribe_image(self, input) -> BaseModel:
"""
Generate a transcription of an image from a user query.
@ -276,7 +272,7 @@ class OpenAIAdapter(LLMInterface):
- BaseModel: A structured output generated by the model, returned as an instance of
BaseModel.
"""
with open_data_file(input, mode="rb") as image_file:
async with open_data_file(input, mode="rb") as image_file:
encoded_image = base64.b64encode(image_file.read()).decode("utf-8")
return litellm.completion(

View file

@ -0,0 +1,13 @@
import asyncio
from functools import partial
async def run_async(func, *args, loop=None, executor=None, **kwargs):
    if loop is None:
        try:
            running_loop = asyncio.get_running_loop()
        except RuntimeError:
            running_loop = asyncio.get_event_loop()
    else:
        # Use the explicitly supplied loop; without this branch a passed-in
        # loop would leave running_loop undefined and raise NameError.
        running_loop = loop
    pfunc = partial(func, *args, **kwargs)
    return await running_loop.run_in_executor(executor, pfunc)
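run_async hands a blocking callable to an executor so async code can await it without stalling the event loop. A sketch with a hypothetical blocking function; the module path cognee.infrastructure.utils.run_async is assumed from the package layout used elsewhere in this diff:
import asyncio
import time
from cognee.infrastructure.utils.run_async import run_async

def slow_checksum(path: str) -> int:
    time.sleep(1)  # stand-in for blocking I/O
    return len(path)

async def main():
    # The loop stays responsive while slow_checksum runs in the default executor.
    print(await run_async(slow_checksum, "/tmp/example.txt"))

asyncio.run(main())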

View file

@ -0,0 +1,25 @@
import asyncio
import threading
def run_sync(coro, timeout=None):
result = None
exception = None
def runner():
nonlocal result, exception
try:
result = asyncio.run(coro)
except Exception as e:
exception = e
thread = threading.Thread(target=runner)
thread.start()
thread.join(timeout)
if thread.is_alive():
raise asyncio.TimeoutError("Coroutine execution timed out.")
if exception:
raise exception
return result
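run_sync is the inverse: it drives a coroutine to completion from synchronous code by running it on a fresh event loop inside a worker thread, which is how BinaryData and S3BinaryData keep a synchronous get_metadata while ensure_metadata became async. A short sketch:
from cognee.infrastructure.utils.run_sync import run_sync

async def fetch_value() -> int:
    return 42

# Safe to call even if another event loop is already running in the caller's thread.
value = run_sync(fetch_value(), timeout=5)
assert value == 42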

View file

@ -33,8 +33,8 @@ class LangchainChunker(Chunker):
length_function=lambda text: len(text.split()),
)
def read(self):
for content_text in self.get_text():
async def read(self):
async for content_text in self.get_text():
for chunk in self.splitter.split_text(content_text):
embedding_engine = get_vector_engine().embedding_engine
token_count = embedding_engine.tokenizer.count_tokens(chunk)

View file

@ -9,9 +9,9 @@ logger = get_logger()
class TextChunker(Chunker):
def read(self):
async def read(self):
paragraph_chunks = []
for content_text in self.get_text():
async for content_text in self.get_text():
for chunk_data in chunk_by_paragraph(
content_text,
self.max_chunk_size,

View file

@ -1,8 +1,7 @@
from cognee.base_config import get_base_config
from cognee.infrastructure.files.storage import LocalStorage
from cognee.infrastructure.files.storage import get_file_storage, get_storage_config
async def prune_data():
base_config = get_base_config()
data_root_directory = base_config.data_root_directory
LocalStorage.remove_all(data_root_directory)
storage_config = get_storage_config()
data_root_directory = storage_config["data_root_directory"]
await get_file_storage(data_root_directory).remove_all()

View file

@ -7,15 +7,16 @@ from .Document import Document
class AudioDocument(Document):
type: str = "audio"
def create_transcript(self):
result = get_llm_client().create_transcript(self.raw_data_location)
async def create_transcript(self):
result = await get_llm_client().create_transcript(self.raw_data_location)
return result.text
def read(self, chunker_cls: Chunker, max_chunk_size: int):
# Transcribe the audio file
async def read(self, chunker_cls: Chunker, max_chunk_size: int):
async def get_text():
# Transcribe the audio file
yield await self.create_transcript()
text = self.create_transcript()
chunker = chunker_cls(self, max_chunk_size=max_chunk_size, get_text=get_text)
chunker = chunker_cls(self, max_chunk_size=max_chunk_size, get_text=lambda: [text])
yield from chunker.read()
async for chunk in chunker.read():
yield chunk
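With this change every document type exposes read() as an async generator: transcription or text extraction happens lazily inside an async get_text generator and chunks are streamed out with async for, instead of materializing the whole transcript before chunking. A sketch of consuming the new API (chunker and size mirror the tests further down):
from cognee.modules.chunking.TextChunker import TextChunker

async def collect_chunks(document) -> list:
    chunks = []
    # The same loop works for AudioDocument, ImageDocument, PdfDocument and TextDocument.
    async for chunk in document.read(chunker_cls=TextChunker, max_chunk_size=64):
        chunks.append(chunk)
    return chunks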

View file

@ -10,5 +10,5 @@ class Document(DataPoint):
mime_type: str
metadata: dict = {"index_fields": ["name"]}
def read(self, chunker_cls: Chunker, max_chunk_size: int) -> str:
async def read(self, chunker_cls: Chunker, max_chunk_size: int) -> str:
pass

View file

@ -7,14 +7,16 @@ from .Document import Document
class ImageDocument(Document):
type: str = "image"
def transcribe_image(self):
result = get_llm_client().transcribe_image(self.raw_data_location)
async def transcribe_image(self):
result = await get_llm_client().transcribe_image(self.raw_data_location)
return result.choices[0].message.content
def read(self, chunker_cls: Chunker, max_chunk_size: int):
# Transcribe the image file
text = self.transcribe_image()
async def read(self, chunker_cls: Chunker, max_chunk_size: int):
async def get_text():
# Transcribe the image file
yield await self.transcribe_image()
chunker = chunker_cls(self, get_text=lambda: [text], max_chunk_size=max_chunk_size)
chunker = chunker_cls(self, get_text=get_text, max_chunk_size=max_chunk_size)
yield from chunker.read()
async for chunk in chunker.read():
yield chunk

View file

@ -1,25 +1,27 @@
from pypdf import PdfReader
from cognee.modules.chunking.Chunker import Chunker
from .open_data_file import open_data_file
from .Document import Document
from cognee.infrastructure.files.utils.open_data_file import open_data_file
from cognee.shared.logging_utils import get_logger
from cognee.modules.data.processing.document_types.exceptions.exceptions import PyPdfInternalError
from .Document import Document
logger = get_logger("PDFDocument")
class PdfDocument(Document):
type: str = "pdf"
def read(self, chunker_cls: Chunker, max_chunk_size: int):
with open_data_file(self.raw_data_location, mode="rb") as stream:
async def read(self, chunker_cls: Chunker, max_chunk_size: int):
async with open_data_file(self.raw_data_location, mode="rb") as stream:
logger.info(f"Reading PDF:{self.raw_data_location}")
try:
file = PdfReader(stream, strict=False)
except Exception:
raise PyPdfInternalError()
def get_text():
async def get_text():
try:
for page in file.pages:
page_text = page.extract_text()
@ -29,4 +31,5 @@ class PdfDocument(Document):
chunker = chunker_cls(self, get_text=get_text, max_chunk_size=max_chunk_size)
yield from chunker.read()
async for chunk in chunker.read():
yield chunk

View file

@ -1,15 +1,15 @@
from .Document import Document
from cognee.modules.chunking.Chunker import Chunker
from .open_data_file import open_data_file
from cognee.infrastructure.files.utils.open_data_file import open_data_file
from .Document import Document
class TextDocument(Document):
type: str = "text"
mime_type: str = "text/plain"
def read(self, chunker_cls: Chunker, max_chunk_size: int):
def get_text():
with open_data_file(self.raw_data_location, mode="r", encoding="utf-8") as file:
async def read(self, chunker_cls: Chunker, max_chunk_size: int):
async def get_text():
async with open_data_file(self.raw_data_location, mode="r", encoding="utf-8") as file:
while True:
text = file.read(1000000)
if not text.strip():
@ -17,4 +17,6 @@ class TextDocument(Document):
yield text
chunker = chunker_cls(self, max_chunk_size=max_chunk_size, get_text=get_text)
yield from chunker.read()
async for chunk in chunker.read():
yield chunk

View file

@ -1,8 +1,9 @@
from io import StringIO
from typing import Any, AsyncGenerator
from cognee.modules.chunking.Chunker import Chunker
from cognee.modules.data.exceptions import UnstructuredLibraryImportError
from cognee.modules.data.processing.document_types.open_data_file import open_data_file
from cognee.infrastructure.files.utils.open_data_file import open_data_file
from .Document import Document
@ -10,15 +11,15 @@ from .Document import Document
class UnstructuredDocument(Document):
type: str = "unstructured"
def read(self, chunker_cls: Chunker, max_chunk_size: int) -> str:
def get_text():
async def read(self, chunker_cls: Chunker, max_chunk_size: int) -> AsyncGenerator[Any, Any]:
async def get_text():
try:
from unstructured.partition.auto import partition
except ModuleNotFoundError:
raise UnstructuredLibraryImportError
if self.raw_data_location.startswith("s3://"):
with open_data_file(self.raw_data_location, mode="rb") as f:
async with open_data_file(self.raw_data_location, mode="rb") as f:
elements = partition(file=f, content_type=self.mime_type)
else:
elements = partition(self.raw_data_location, content_type=self.mime_type)
@ -34,4 +35,5 @@ class UnstructuredDocument(Document):
chunker = chunker_cls(self, get_text=get_text, max_chunk_size=max_chunk_size)
yield from chunker.read()
async for chunk in chunker.read():
yield chunk

View file

@ -1,41 +0,0 @@
from typing import IO, Optional
from urllib.parse import urlparse
import os
from cognee.api.v1.add.config import get_s3_config
def open_data_file(
file_path: str, mode: str = "rb", encoding: Optional[str] = None, **kwargs
) -> IO:
if file_path.startswith("s3://"):
s3_config = get_s3_config()
if s3_config.aws_access_key_id is not None and s3_config.aws_secret_access_key is not None:
import s3fs
fs = s3fs.S3FileSystem(
key=s3_config.aws_access_key_id, secret=s3_config.aws_secret_access_key, anon=False
)
else:
raise ValueError("S3 credentials are not set in the configuration.")
if "b" in mode:
f = fs.open(file_path, mode=mode, **kwargs)
if not hasattr(f, "name") or not f.name:
f.name = file_path.split("/")[-1]
return f
else:
return fs.open(file_path, mode=mode, encoding=encoding, **kwargs)
elif file_path.startswith("file://"):
# Handle local file URLs by properly parsing the URI
parsed_url = urlparse(file_path)
# On Windows, urlparse handles drive letters correctly
# Convert the path component to a proper file path
if os.name == "nt": # Windows
# Remove leading slash from Windows paths like /C:/Users/...
local_path = parsed_url.path.lstrip("/")
else: # Unix-like systems
local_path = parsed_url.path
return open(local_path, mode=mode, encoding=encoding, **kwargs)
else:
return open(file_path, mode=mode, encoding=encoding, **kwargs)

View file

@ -1,29 +1,29 @@
from os import path
from io import BufferedReader
from typing import Union, BinaryIO, Optional, Any
from .data_types import TextData, BinaryData
from typing import Union, BinaryIO
from tempfile import SpooledTemporaryFile
from cognee.modules.ingestion.exceptions import IngestionError
try:
from s3fs.core import S3File
from cognee.modules.ingestion.data_types.S3BinaryData import S3BinaryData
except ImportError:
S3File = None
S3BinaryData = None
from .data_types import TextData, BinaryData, S3BinaryData
def classify(data: Union[str, BinaryIO], filename: str = None, s3fs: Optional[Any] = None):
def classify(data: Union[str, BinaryIO], filename: str = None):
if isinstance(data, str):
return TextData(data)
if isinstance(data, BufferedReader) or isinstance(data, SpooledTemporaryFile):
return BinaryData(data, str(data.name).split("/")[-1] if data.name else filename)
return BinaryData(
data, str(data.name).split("/")[-1] if hasattr(data, "name") else filename
)
try:
from s3fs import S3File
except ImportError:
S3File = None
if S3File is not None:
if isinstance(data, S3File):
derived_filename = str(data.full_name).split("/")[-1] if data.full_name else filename
return S3BinaryData(s3_path=data.full_name, name=derived_filename, s3=s3fs)
return S3BinaryData(s3_path=path.join("s3://", data.bucket, data.key), name=data.key)
raise IngestionError(
message=f"Type of data sent to classify(data: Union[str, BinaryIO) not supported or s3fs is not installed: {type(data)}"

View file

@ -1,5 +1,7 @@
from typing import BinaryIO
from contextlib import asynccontextmanager
from cognee.infrastructure.files import get_file_metadata, FileMetadata
from cognee.infrastructure.utils.run_sync import run_sync
from .IngestionData import IngestionData
@ -22,16 +24,17 @@ class BinaryData(IngestionData):
return metadata["content_hash"]
def get_metadata(self):
self.ensure_metadata()
run_sync(self.ensure_metadata())
return self.metadata
def ensure_metadata(self):
async def ensure_metadata(self):
if self.metadata is None:
self.metadata = get_file_metadata(self.data)
self.metadata = await get_file_metadata(self.data)
if self.metadata["name"] is None:
self.metadata["name"] = self.name
def get_data(self):
return self.data
@asynccontextmanager
async def get_data(self):
yield self.data

View file

@ -1,42 +1,55 @@
import os
from typing import Optional
import s3fs
from contextlib import asynccontextmanager
from cognee.infrastructure.files import get_file_metadata, FileMetadata
from cognee.infrastructure.utils import run_sync
from .IngestionData import IngestionData
def create_s3_binary_data(
s3_path: str, name: Optional[str] = None, s3: Optional[s3fs.S3FileSystem] = None
) -> "S3BinaryData":
return S3BinaryData(s3_path, name=name, s3=s3)
def create_s3_binary_data(s3_path: str, name: Optional[str] = None) -> "S3BinaryData":
return S3BinaryData(s3_path, name=name)
class S3BinaryData(IngestionData):
name: Optional[str] = None
s3_path: str = None
fs: s3fs.S3FileSystem = None
metadata: Optional[FileMetadata] = None
def __init__(
self, s3_path: str, name: Optional[str] = None, s3: Optional[s3fs.S3FileSystem] = None
):
def __init__(self, s3_path: str, name: Optional[str] = None):
self.s3_path = s3_path
self.name = name
self.fs = s3 if s3 is not None else s3fs.S3FileSystem()
def get_identifier(self):
metadata = self.get_metadata()
return metadata["content_hash"]
def get_metadata(self):
self.ensure_metadata()
run_sync(self.ensure_metadata())
return self.metadata
def ensure_metadata(self):
async def ensure_metadata(self):
if self.metadata is None:
with self.fs.open(self.s3_path, "rb") as f:
self.metadata = get_file_metadata(f)
if self.metadata.get("name") is None:
self.metadata["name"] = self.name or self.s3_path.split("/")[-1]
from cognee.infrastructure.files.storage.S3FileStorage import S3FileStorage
def get_data(self):
return self.fs.open(self.s3_path, "rb")
file_dir_path = os.path.dirname(self.s3_path)
file_path = os.path.basename(self.s3_path)
file_storage = S3FileStorage(file_dir_path)
async with file_storage.open(file_path, "rb") as file:
self.metadata = await get_file_metadata(file)
if self.metadata.get("name") is None:
self.metadata["name"] = self.name or file_path
@asynccontextmanager
async def get_data(self):
from cognee.infrastructure.files.storage.S3FileStorage import S3FileStorage
file_dir_path = os.path.dirname(self.s3_path)
file_path = os.path.basename(self.s3_path)
file_storage = S3FileStorage(file_dir_path)
async with file_storage.open(file_path, "rb") as file:
yield file

View file

@ -1,4 +1,5 @@
from typing import BinaryIO
from contextlib import asynccontextmanager
from cognee.infrastructure.data.utils.extract_keywords import extract_keywords
from .IngestionData import IngestionData
@ -28,5 +29,6 @@ class TextData(IngestionData):
if self.metadata is None:
self.metadata = {}
def get_data(self):
return self.data
@asynccontextmanager
async def get_data(self):
yield self.data

View file

@ -1,3 +1,4 @@
from .TextData import TextData, create_text_data
from .BinaryData import BinaryData, create_binary_data
from .S3BinaryData import S3BinaryData, create_s3_binary_data
from .IngestionData import IngestionData

View file

@ -1,29 +1,28 @@
import os.path
import hashlib
from typing import BinaryIO, Union
from cognee.base_config import get_base_config
from cognee.infrastructure.files.storage import LocalStorage
from cognee.infrastructure.files.storage import get_file_storage, get_storage_config
from .classify import classify
def save_data_to_file(data: Union[str, BinaryIO], filename: str = None):
base_config = get_base_config()
data_directory_path = base_config.data_root_directory
async def save_data_to_file(data: Union[str, BinaryIO], filename: str = None):
storage_config = get_storage_config()
data_root_directory = storage_config["data_root_directory"]
classified_data = classify(data, filename)
storage_path = os.path.join(data_directory_path, "data")
LocalStorage.ensure_directory_exists(storage_path)
file_metadata = classified_data.get_metadata()
if "name" not in file_metadata or file_metadata["name"] is None:
data_contents = classified_data.get_data().encode("utf-8")
hash_contents = hashlib.md5(data_contents).hexdigest()
file_metadata["name"] = "text_" + hash_contents + ".txt"
file_name = file_metadata["name"]
# Don't save file if it already exists
if not os.path.isfile(os.path.join(storage_path, file_name)):
LocalStorage(storage_path).store(file_name, classified_data.get_data())
async with classified_data.get_data() as data:
if "name" not in file_metadata or file_metadata["name"] is None:
data_contents = data.encode("utf-8")
hash_contents = hashlib.md5(data_contents).hexdigest()
file_metadata["name"] = "text_" + hash_contents + ".txt"
return "file://" + storage_path + "/" + file_name
file_name = file_metadata["name"]
storage = get_file_storage(data_root_directory)
full_file_path = await storage.store(file_name, data)
return full_file_path
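save_data_to_file now reads the data root from the (possibly context-scoped) storage config, derives a "text_<md5>.txt" name for unnamed text content, and returns whatever full path the backend reports, which may be a file:// style local path or an s3:// URL. A hedged usage sketch:
from cognee.modules.ingestion import save_data_to_file

async def ingest_snippet() -> str:
    # Unnamed text is stored under the configured data root with an md5-derived name.
    return await save_data_to_file("Cognee turns documents into knowledge graphs.")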

View file

@ -1,6 +1,6 @@
import asyncio
from uuid import UUID
from typing import Union
from uuid import NAMESPACE_OID, uuid5, UUID
from cognee.shared.logging_utils import get_logger
from cognee.modules.data.methods.get_dataset_data import get_dataset_data

View file

@ -3,6 +3,7 @@ from uuid import UUID
from typing import Any
from functools import wraps
from cognee.infrastructure.databases.graph import get_graph_engine
from cognee.infrastructure.databases.relational import get_relational_engine
from cognee.modules.pipelines.operations.run_tasks_distributed import run_tasks_distributed
from cognee.modules.users.models import User
@ -104,6 +105,14 @@ async def run_tasks(
dataset_name=dataset.name,
)
graph_engine = await get_graph_engine()
if hasattr(graph_engine, "push_to_s3"):
await graph_engine.push_to_s3()
relational_engine = get_relational_engine()
if hasattr(relational_engine, "push_to_s3"):
await relational_engine.push_to_s3()
except Exception as error:
await log_pipeline_run_error(
pipeline_run_id, pipeline_id, pipeline_name, dataset_id, data, error

View file

@ -3,7 +3,7 @@ import json
import networkx
from cognee.shared.logging_utils import get_logger
from cognee.infrastructure.files.storage import LocalStorage
from cognee.infrastructure.files.storage.LocalFileStorage import LocalFileStorage
logger = get_logger()
@ -308,10 +308,12 @@ async def cognee_network_visualization(graph_data, destination_file_path: str =
home_dir = os.path.expanduser("~")
destination_file_path = os.path.join(home_dir, "graph_visualization.html")
LocalStorage.ensure_directory_exists(os.path.dirname(destination_file_path))
dir_path = os.path.dirname(destination_file_path)
file_path = os.path.basename(destination_file_path)
with open(destination_file_path, "w") as f:
f.write(html_content)
file_storage = LocalFileStorage(dir_path)
file_storage.store(file_path, html_content, overwrite=True)
logger.info(f"Graph visualization saved as {destination_file_path}")

View file

@ -1,10 +1,7 @@
"""This module contains utility functions for the cognee."""
import os
from typing import BinaryIO, Union
import requests
import hashlib
from datetime import datetime, timezone
import graphistry
import networkx as nx
@ -13,14 +10,12 @@ import matplotlib.pyplot as plt
import http.server
import socketserver
from threading import Thread
import sys
import pathlib
from uuid import uuid4
from cognee.base_config import get_base_config
from cognee.infrastructure.databases.graph import get_graph_engine
from uuid import uuid4
import pathlib
from cognee.shared.exceptions import IngestionError
# Analytics Proxy Url, currently hosted by Vercel
proxy_url = "https://test.prometh.ai"
@ -102,182 +97,6 @@ def send_telemetry(event_name: str, user_id, additional_properties: dict = {}):
print(f"Error sending telemetry through proxy: {response.status_code}")
def get_file_content_hash(file_obj: Union[str, BinaryIO]) -> str:
h = hashlib.md5()
try:
if isinstance(file_obj, str):
with open(file_obj, "rb") as file:
while True:
# Reading is buffered, so we can read smaller chunks.
chunk = file.read(h.block_size)
if not chunk:
break
h.update(chunk)
else:
while True:
# Reading is buffered, so we can read smaller chunks.
chunk = file_obj.read(h.block_size)
if not chunk:
break
h.update(chunk)
return h.hexdigest()
except IOError as e:
raise IngestionError(message=f"Failed to load data from {file}: {e}")
def generate_color_palette(unique_layers):
colormap = plt.cm.get_cmap("viridis", len(unique_layers))
colors = [colormap(i) for i in range(len(unique_layers))]
hex_colors = [
"#%02x%02x%02x" % (int(rgb[0] * 255), int(rgb[1] * 255), int(rgb[2] * 255))
for rgb in colors
]
return dict(zip(unique_layers, hex_colors))
async def register_graphistry():
config = get_base_config()
graphistry.register(
api=3, username=config.graphistry_username, password=config.graphistry_password
)
def prepare_edges(graph, source, target, edge_key):
edge_list = [
{
source: str(edge[0]),
target: str(edge[1]),
edge_key: str(edge[2]),
}
for edge in graph.edges(keys=True, data=True)
]
return pd.DataFrame(edge_list)
def prepare_nodes(graph, include_size=False):
nodes_data = []
for node in graph.nodes:
node_info = graph.nodes[node]
if not node_info:
continue
node_data = {
**node_info,
"id": str(node),
"name": node_info["name"] if "name" in node_info else str(node),
}
if include_size:
default_size = 10 # Default node size
larger_size = 20 # Size for nodes with specific keywords in their ID
keywords = ["DOCUMENT", "User"]
node_size = (
larger_size if any(keyword in str(node) for keyword in keywords) else default_size
)
node_data["size"] = node_size
nodes_data.append(node_data)
return pd.DataFrame(nodes_data)
async def render_graph(
graph=None, include_nodes=True, include_color=False, include_size=False, include_labels=True
):
await register_graphistry()
if not isinstance(graph, nx.MultiDiGraph):
graph_engine = await get_graph_engine()
networkx_graph = nx.MultiDiGraph()
(nodes, edges) = await graph_engine.get_graph_data()
networkx_graph.add_nodes_from(nodes)
networkx_graph.add_edges_from(edges)
graph = networkx_graph
edges = prepare_edges(graph, "source_node", "target_node", "relationship_name")
plotter = graphistry.edges(edges, "source_node", "target_node")
plotter = plotter.bind(edge_label="relationship_name")
if include_nodes:
nodes = prepare_nodes(graph, include_size=include_size)
plotter = plotter.nodes(nodes, "id")
if include_size:
plotter = plotter.bind(point_size="size")
if include_color:
pass
# unique_layers = nodes["layer_description"].unique()
# color_palette = generate_color_palette(unique_layers)
# plotter = plotter.encode_point_color("layer_description", categorical_mapping=color_palette,
# default_mapping="silver")
if include_labels:
plotter = plotter.bind(point_label="name")
# Visualization
url = plotter.plot(render=False, as_files=True, memoize=False)
print(f"Graph is visualized at: {url}")
return url
# def sanitize_df(df):
# """Replace NaNs and infinities in a DataFrame with None, making it JSON compliant."""
# return df.replace([np.inf, -np.inf, np.nan], None)
async def convert_to_serializable_graph(G):
"""
Convert a graph into a serializable format with stringified node and edge attributes.
"""
(nodes, edges) = G
networkx_graph = nx.MultiDiGraph()
networkx_graph.add_nodes_from(nodes)
networkx_graph.add_edges_from(edges)
# Create a new graph to store the serializable version
new_G = nx.MultiDiGraph()
# Serialize nodes
for node, data in networkx_graph.nodes(data=True):
serializable_data = {k: str(v) for k, v in data.items()}
new_G.add_node(str(node), **serializable_data)
# Serialize edges
for u, v, data in networkx_graph.edges(data=True):
serializable_data = {k: str(v) for k, v in data.items()}
new_G.add_edge(str(u), str(v), **serializable_data)
return new_G
def generate_layout_positions(G, layout_func, layout_scale):
"""
Generate layout positions for the graph using the specified layout function.
"""
positions = layout_func(G)
return {str(node): (x * layout_scale, y * layout_scale) for node, (x, y) in positions.items()}
def assign_node_colors(G, node_attribute, palette):
"""
Assign colors to nodes based on a specified attribute and a given palette.
"""
unique_attrs = set(G.nodes[node].get(node_attribute, "Unknown") for node in G.nodes)
color_map = {attr: palette[i % len(palette)] for i, attr in enumerate(unique_attrs)}
return [color_map[G.nodes[node].get(node_attribute, "Unknown")] for node in G.nodes], color_map
def embed_logo(p, layout_scale, logo_alpha, position):
"""
Embed a logo into the graph visualization as a watermark.
@ -307,18 +126,6 @@ def embed_logo(p, layout_scale, logo_alpha, position):
)
def graph_to_tuple(graph):
"""
Converts a networkx graph to a tuple of (nodes, edges).
:param graph: A networkx graph.
:return: A tuple (nodes, edges).
"""
nodes = list(graph.nodes(data=True)) # Get nodes with attributes
edges = list(graph.edges(data=True)) # Get edges with attributes
return (nodes, edges)
def start_visualization_server(
host="0.0.0.0", port=8001, handler_class=http.server.SimpleHTTPRequestHandler
):

View file

@ -1,11 +1,11 @@
from uuid import UUID
from sqlalchemy import select
from typing import AsyncGenerator
from cognee.shared.logging_utils import get_logger
from cognee.modules.data.processing.document_types.Document import Document
from sqlalchemy import select
from cognee.modules.data.models import Data
from cognee.infrastructure.databases.relational import get_relational_engine
from uuid import UUID
from cognee.modules.chunking.TextChunker import TextChunker
from cognee.modules.chunking.Chunker import Chunker
from cognee.modules.data.processing.document_types.exceptions.exceptions import PyPdfInternalError
@ -41,7 +41,9 @@ async def extract_chunks_from_documents(
for document in documents:
document_token_count = 0
try:
for document_chunk in document.read(max_chunk_size=max_chunk_size, chunker_cls=chunker):
async for document_chunk in document.read(
max_chunk_size=max_chunk_size, chunker_cls=chunker
):
document_token_count += document_chunk.chunk_size
document_chunk.belongs_to_set = document.belongs_to_set
yield document_chunk

View file

@ -13,7 +13,7 @@ import aiofiles
import pandas as pd
from pydantic import BaseModel
from cognee.modules.graph.exceptions import EntityNotFoundError, EntityAlreadyExistsError
from cognee.modules.graph.exceptions import EntityNotFoundError
from cognee.modules.ingestion.exceptions import IngestionError
from cognee.infrastructure.llm.prompts import read_query_prompt
from cognee.infrastructure.llm.get_llm_client import get_llm_client

View file

@ -1,9 +1,11 @@
import json
import inspect
from os import path
from uuid import UUID
from typing import Union, BinaryIO, Any, List, Optional
import cognee.modules.ingestion as ingestion
from cognee.infrastructure.files.utils.open_data_file import open_data_file
from cognee.infrastructure.databases.relational import get_relational_engine
from cognee.modules.data.models import Data
from cognee.modules.users.models import User
@ -18,9 +20,6 @@ from cognee.modules.data.methods import (
from .save_data_item_to_storage import save_data_item_to_storage
from cognee.api.v1.add.config import get_s3_config
async def ingest_data(
data: Any,
dataset_name: str,
@ -31,23 +30,6 @@ async def ingest_data(
if not user:
user = await get_default_user()
s3_config = get_s3_config()
fs = None
if s3_config.aws_access_key_id is not None and s3_config.aws_secret_access_key is not None:
import s3fs
fs = s3fs.S3FileSystem(
key=s3_config.aws_access_key_id, secret=s3_config.aws_secret_access_key, anon=False
)
def open_data_file(file_path: str):
if file_path.startswith("s3://"):
return fs.open(file_path, mode="rb")
else:
local_path = file_path.replace("file://", "")
return open(local_path, mode="rb")
def get_external_metadata_dict(data_item: Union[BinaryIO, str, Any]) -> dict[str, Any]:
if hasattr(data_item, "dict") and inspect.ismethod(getattr(data_item, "dict")):
return {"metadata": data_item.dict(), "origin": str(type(data_item))}
@ -92,11 +74,11 @@ async def ingest_data(
dataset_data_map = {str(data.id): True for data in dataset_data}
for data_item in data:
file_path = await save_data_item_to_storage(data_item, dataset_name)
file_path = await save_data_item_to_storage(data_item)
# Ingest data and add metadata
with open_data_file(file_path) as file:
classified_data = ingestion.classify(file, s3fs=fs)
async with open_data_file(file_path) as file:
classified_data = ingestion.classify(file)
# data_id is the hash of file contents + owner id to avoid duplicate data
data_id = ingestion.identify(classified_data, user)

View file

@ -3,30 +3,49 @@ from typing import Union, BinaryIO, Any
from cognee.modules.ingestion.exceptions import IngestionError
from cognee.modules.ingestion import save_data_to_file
from pydantic_settings import BaseSettings, SettingsConfigDict
async def save_data_item_to_storage(data_item: Union[BinaryIO, str, Any], dataset_name: str) -> str:
class SaveDataSettings(BaseSettings):
accept_local_file_path: bool = True
model_config = SettingsConfigDict(env_file=".env", extra="allow")
settings = SaveDataSettings()
async def save_data_item_to_storage(data_item: Union[BinaryIO, str, Any]) -> str:
if "llama_index" in str(type(data_item)):
# Dynamic import is used because the llama_index module is optional.
from .transform_data import get_data_from_llama_index
file_path = get_data_from_llama_index(data_item, dataset_name)
file_path = await get_data_from_llama_index(data_item)
# data is a file object coming from upload.
elif hasattr(data_item, "file"):
file_path = save_data_to_file(data_item.file, filename=data_item.filename)
file_path = await save_data_to_file(data_item.file, filename=data_item.filename)
elif isinstance(data_item, str):
if data_item.startswith("s3://"):
# data is s3 file or local file path
if data_item.startswith("s3://") or data_item.startswith("file://"):
file_path = data_item
# data is a file path
elif data_item.startswith("file://") or data_item.startswith("/"):
if os.getenv("ACCEPT_LOCAL_FILE_PATH", "true").lower() == "false":
elif data_item.startswith("/") or (
os.name == "nt" and len(data_item) > 1 and data_item[1] == ":"
):
# Handle both Unix absolute paths (/path) and Windows absolute paths (C:\path)
if settings.accept_local_file_path:
# Normalize path separators before creating file URL
normalized_path = os.path.normpath(data_item)
# Use forward slashes in file URLs for consistency
url_path = normalized_path.replace(os.sep, "/")
file_path = "file://" + url_path
else:
raise IngestionError(message="Local files are not accepted.")
file_path = data_item.replace("file://", "")
# data is text
else:
file_path = save_data_to_file(data_item)
file_path = await save_data_to_file(data_item)
else:
raise IngestionError(message=f"Data type not supported: {type(data_item)}")

View file

@ -4,7 +4,7 @@ from cognee.modules.ingestion import save_data_to_file
from typing import Union
def get_data_from_llama_index(data_point: Union[Document, ImageDocument], dataset_name: str) -> str:
async def get_data_from_llama_index(data_point: Union[Document, ImageDocument]) -> str:
"""
Retrieve the file path based on the data point type.
@ -17,7 +17,6 @@ def get_data_from_llama_index(data_point: Union[Document, ImageDocument], datase
- data_point (Union[Document, ImageDocument]): An instance of Document or
ImageDocument to extract data from.
- dataset_name (str): The name of the dataset associated with the data point.
Returns:
--------
@ -29,11 +28,11 @@ def get_data_from_llama_index(data_point: Union[Document, ImageDocument], datase
if isinstance(data_point, Document) and type(data_point) is Document:
file_path = data_point.metadata.get("file_path")
if file_path is None:
file_path = save_data_to_file(data_point.text)
file_path = await save_data_to_file(data_point.text)
return file_path
return file_path
elif isinstance(data_point, ImageDocument) and type(data_point) is ImageDocument:
if data_point.image_path is None:
file_path = save_data_to_file(data_point.text)
file_path = await save_data_to_file(data_point.text)
return file_path
return data_point.image_path

View file

@ -1,32 +0,0 @@
import asyncio
from cognee.shared.utils import render_graph
from cognee.infrastructure.databases.graph import get_graph_engine
if __name__ == "__main__":
async def main():
import os
import pathlib
import cognee
data_directory_path = str(
pathlib.Path(
os.path.join(pathlib.Path(__file__).parent, ".data_storage/test_library")
).resolve()
)
cognee.config.data_root_directory(data_directory_path)
cognee_directory_path = str(
pathlib.Path(
os.path.join(pathlib.Path(__file__).parent, ".cognee_system/test_library")
).resolve()
)
cognee.config.system_root_directory(cognee_directory_path)
graph_client = await get_graph_engine()
graph = graph_client.graph
graph_url = await render_graph(graph)
print(graph_url)
asyncio.run(main())

View file

@ -1,8 +1,11 @@
import sys
import uuid
import pytest
from unittest.mock import patch
from cognee.modules.chunking.TextChunker import TextChunker
from cognee.modules.data.processing.document_types.AudioDocument import AudioDocument
import sys
from cognee.tests.integration.documents.async_gen_zip import async_gen_zip
chunk_by_sentence_module = sys.modules.get("cognee.tasks.chunks.chunk_by_sentence")
@ -38,7 +41,8 @@ TEST_TEXT = """
@patch.object(
chunk_by_sentence_module, "get_embedding_engine", side_effect=mock_get_embedding_engine
)
def test_AudioDocument(mock_engine):
@pytest.mark.asyncio
async def test_AudioDocument(mock_engine):
document = AudioDocument(
id=uuid.uuid4(),
name="audio-dummy-test",
@ -47,7 +51,7 @@ def test_AudioDocument(mock_engine):
mime_type="",
)
with patch.object(AudioDocument, "create_transcript", return_value=TEST_TEXT):
for ground_truth, paragraph_data in zip(
async for ground_truth, paragraph_data in async_gen_zip(
GROUND_TRUTH,
document.read(chunker_cls=TextChunker, max_chunk_size=64),
):

View file

@ -1,9 +1,12 @@
import sys
import uuid
import pytest
from unittest.mock import patch
from cognee.modules.chunking.TextChunker import TextChunker
from cognee.modules.data.processing.document_types.ImageDocument import ImageDocument
from cognee.tests.integration.documents.AudioDocument_test import mock_get_embedding_engine
import sys
from cognee.tests.integration.documents.async_gen_zip import async_gen_zip
chunk_by_sentence_module = sys.modules.get("cognee.tasks.chunks.chunk_by_sentence")
@ -21,7 +24,8 @@ The commotion has attracted an audience: a murder of crows has gathered in the l
@patch.object(
chunk_by_sentence_module, "get_embedding_engine", side_effect=mock_get_embedding_engine
)
def test_ImageDocument(mock_engine):
@pytest.mark.asyncio
async def test_ImageDocument(mock_engine):
document = ImageDocument(
id=uuid.uuid4(),
name="image-dummy-test",
@ -30,7 +34,7 @@ def test_ImageDocument(mock_engine):
mime_type="",
)
with patch.object(ImageDocument, "transcribe_image", return_value=TEST_TEXT):
for ground_truth, paragraph_data in zip(
async for ground_truth, paragraph_data in async_gen_zip(
GROUND_TRUTH,
document.read(chunker_cls=TextChunker, max_chunk_size=64),
):

View file

@ -1,10 +1,13 @@
import os
import sys
import uuid
import pytest
from unittest.mock import patch
from cognee.modules.chunking.TextChunker import TextChunker
from cognee.modules.data.processing.document_types.PdfDocument import PdfDocument
from cognee.tests.integration.documents.AudioDocument_test import mock_get_embedding_engine
from unittest.mock import patch
import sys
from cognee.tests.integration.documents.async_gen_zip import async_gen_zip
chunk_by_sentence_module = sys.modules.get("cognee.tasks.chunks.chunk_by_sentence")
@ -18,7 +21,8 @@ GROUND_TRUTH = [
@patch.object(
chunk_by_sentence_module, "get_embedding_engine", side_effect=mock_get_embedding_engine
)
def test_PdfDocument(mock_engine):
@pytest.mark.asyncio
async def test_PdfDocument(mock_engine):
test_file_path = os.path.join(
os.sep,
*(os.path.dirname(__file__).split(os.sep)[:-2]),
@ -33,7 +37,7 @@ def test_PdfDocument(mock_engine):
mime_type="",
)
for ground_truth, paragraph_data in zip(
async for ground_truth, paragraph_data in async_gen_zip(
GROUND_TRUTH, document.read(chunker_cls=TextChunker, max_chunk_size=1024)
):
assert ground_truth["word_count"] == paragraph_data.chunk_size, (

View file

@ -1,12 +1,13 @@
import os
import sys
import uuid
import pytest
from unittest.mock import patch
from cognee.modules.chunking.TextChunker import TextChunker
from cognee.modules.data.processing.document_types.TextDocument import TextDocument
from unittest.mock import patch
from cognee.tests.integration.documents.AudioDocument_test import mock_get_embedding_engine
import sys
from cognee.tests.integration.documents.async_gen_zip import async_gen_zip
chunk_by_sentence_module = sys.modules.get("cognee.tasks.chunks.chunk_by_sentence")
@ -30,7 +31,8 @@ GROUND_TRUTH = {
@patch.object(
chunk_by_sentence_module, "get_embedding_engine", side_effect=mock_get_embedding_engine
)
def test_TextDocument(mock_engine, input_file, chunk_size):
@pytest.mark.asyncio
async def test_TextDocument(mock_engine, input_file, chunk_size):
test_file_path = os.path.join(
os.sep,
*(os.path.dirname(__file__).split(os.sep)[:-2]),
@ -45,7 +47,7 @@ def test_TextDocument(mock_engine, input_file, chunk_size):
mime_type="",
)
for ground_truth, paragraph_data in zip(
async for ground_truth, paragraph_data in async_gen_zip(
GROUND_TRUTH[input_file],
document.read(chunker_cls=TextChunker, max_chunk_size=chunk_size),
):

View file

@ -1,10 +1,12 @@
import os
import sys
import uuid
import pytest
from unittest.mock import patch
from cognee.modules.chunking.TextChunker import TextChunker
from cognee.modules.data.processing.document_types.UnstructuredDocument import UnstructuredDocument
from cognee.tests.integration.documents.AudioDocument_test import mock_get_embedding_engine
import sys
chunk_by_sentence_module = sys.modules.get("cognee.tasks.chunks.chunk_by_sentence")
@ -12,7 +14,8 @@ chunk_by_sentence_module = sys.modules.get("cognee.tasks.chunks.chunk_by_sentenc
@patch.object(
chunk_by_sentence_module, "get_embedding_engine", side_effect=mock_get_embedding_engine
)
def test_UnstructuredDocument(mock_engine):
@pytest.mark.asyncio
async def test_UnstructuredDocument(mock_engine):
# Define file paths of test data
pptx_file_path = os.path.join(
os.sep,
@ -76,7 +79,7 @@ def test_UnstructuredDocument(mock_engine):
)
# Test PPTX
for paragraph_data in pptx_document.read(chunker_cls=TextChunker, max_chunk_size=1024):
async for paragraph_data in pptx_document.read(chunker_cls=TextChunker, max_chunk_size=1024):
assert 19 == paragraph_data.chunk_size, f" 19 != {paragraph_data.chunk_size = }"
assert 104 == len(paragraph_data.text), f" 104 != {len(paragraph_data.text) = }"
assert "sentence_cut" == paragraph_data.cut_type, (
@ -84,7 +87,7 @@ def test_UnstructuredDocument(mock_engine):
)
# Test DOCX
for paragraph_data in docx_document.read(chunker_cls=TextChunker, max_chunk_size=1024):
async for paragraph_data in docx_document.read(chunker_cls=TextChunker, max_chunk_size=1024):
assert 16 == paragraph_data.chunk_size, f" 16 != {paragraph_data.chunk_size = }"
assert 145 == len(paragraph_data.text), f" 145 != {len(paragraph_data.text) = }"
assert "sentence_end" == paragraph_data.cut_type, (
@ -92,7 +95,7 @@ def test_UnstructuredDocument(mock_engine):
)
# TEST CSV
for paragraph_data in csv_document.read(chunker_cls=TextChunker, max_chunk_size=1024):
async for paragraph_data in csv_document.read(chunker_cls=TextChunker, max_chunk_size=1024):
assert 15 == paragraph_data.chunk_size, f" 15 != {paragraph_data.chunk_size = }"
assert "A A A A A A A A A,A A A A A A,A A" == paragraph_data.text, (
f"Read text doesn't match expected text: {paragraph_data.text}"
@ -102,7 +105,7 @@ def test_UnstructuredDocument(mock_engine):
)
# Test XLSX
for paragraph_data in xlsx_document.read(chunker_cls=TextChunker, max_chunk_size=1024):
async for paragraph_data in xlsx_document.read(chunker_cls=TextChunker, max_chunk_size=1024):
assert 36 == paragraph_data.chunk_size, f" 36 != {paragraph_data.chunk_size = }"
assert 171 == len(paragraph_data.text), f" 171 != {len(paragraph_data.text) = }"
assert "sentence_cut" == paragraph_data.cut_type, (

View file

@ -0,0 +1,12 @@
async def async_gen_zip(iterable1, async_iterable2):
it1 = iter(iterable1)
it2 = async_iterable2.__aiter__()
while True:
try:
val1 = next(it1)
val2 = await it2.__anext__()
yield val1, val2
except (StopIteration, StopAsyncIteration):
break
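async_gen_zip pairs a plain iterable of expected values with an async generator of produced chunks and stops as soon as either side is exhausted, mirroring built-in zip(); that is what lets the document tests above keep their zip-style assertions against the now-async read(). A toy usage sketch:
import asyncio
from cognee.tests.integration.documents.async_gen_zip import async_gen_zip

async def numbers():
    for n in (1, 2, 3):
        yield n

async def main():
    # Stops at the shorter side, like zip("abcd", [1, 2, 3]).
    async for letter, number in async_gen_zip("abcd", numbers()):
        print(letter, number)

asyncio.run(main())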

View file

@ -1,8 +1,9 @@
import os
from cognee.shared.logging_utils import get_logger
import pathlib
import cognee
import cognee
from cognee.shared.logging_utils import get_logger
from cognee.infrastructure.files.storage import get_storage_config
from cognee.modules.data.models import Data
from cognee.modules.users.methods import get_default_user
from cognee.modules.search.types import SearchType
@ -24,12 +25,12 @@ async def test_local_file_deletion(data_text, file_location):
data_hash = hashlib.md5(encoded_text).hexdigest()
# Get data entry from database based on hash contents
data = (await session.scalars(select(Data).where(Data.content_hash == data_hash))).one()
assert os.path.isfile(data.raw_data_location), (
assert os.path.isfile(data.raw_data_location.replace("file://", "")), (
f"Data location doesn't exist: {data.raw_data_location}"
)
# Test deletion of data along with local files created by cognee
await engine.delete_data_entity(data.id)
assert not os.path.exists(data.raw_data_location), (
assert not os.path.exists(data.raw_data_location.replace("file://", "")), (
f"Data location still exists after deletion: {data.raw_data_location}"
)
@ -38,12 +39,12 @@ async def test_local_file_deletion(data_text, file_location):
data = (
await session.scalars(select(Data).where(Data.raw_data_location == file_location))
).one()
assert os.path.isfile(data.raw_data_location), (
assert os.path.isfile(data.raw_data_location.replace("file://", "")), (
f"Data location doesn't exist: {data.raw_data_location}"
)
# Test local files not created by cognee won't get deleted
await engine.delete_data_entity(data.id)
assert os.path.exists(data.raw_data_location), (
assert os.path.exists(data.raw_data_location.replace("file://", "")), (
f"Data location doesn't exists: {data.raw_data_location}"
)
@ -157,7 +158,8 @@ async def main():
assert len(history) == 8, "Search history is not correct."
await cognee.prune.prune_data()
assert not os.path.isdir(data_directory_path), "Local data files are not deleted"
data_root_directory = get_storage_config()["data_root_directory"]
assert not os.path.isdir(data_root_directory), "Local data files are not deleted"
await cognee.prune.prune_system(metadata=True)
tables_in_database = await vector_engine.get_collection_names()

View file

@ -5,7 +5,6 @@ from cognee.modules.search.operations import get_history
from cognee.modules.users.methods import get_default_user
from cognee.shared.logging_utils import get_logger
from cognee.modules.search.types import SearchType
from cognee.shared.utils import render_graph
from cognee.low_level import DataPoint
logger = get_logger()
@ -57,9 +56,6 @@ async def main():
await cognee.cognify(graph_model=ProgrammingLanguage)
url = await render_graph()
print(f"Graphistry URL: {url}")
graph_file_path = str(
pathlib.Path(
os.path.join(

View file

@ -2,6 +2,7 @@ import hashlib
import os
from cognee.shared.logging_utils import get_logger
import pathlib
import pytest
import cognee
from cognee.infrastructure.databases.relational import get_relational_engine
@ -101,6 +102,7 @@ async def test_deduplication():
await cognee.prune.prune_system(metadata=True)
@pytest.mark.asyncio
async def test_deduplication_postgres():
cognee.config.set_vector_db_config(
{"vector_db_url": "", "vector_db_key": "", "vector_db_provider": "pgvector"}
@ -119,6 +121,7 @@ async def test_deduplication_postgres():
await test_deduplication()
@pytest.mark.asyncio
async def test_deduplication_sqlite():
cognee.config.set_vector_db_config(
{"vector_db_url": "", "vector_db_key": "", "vector_db_provider": "lancedb"}

View file

@ -1,11 +1,11 @@
import os
import cognee
import pathlib
from cognee.infrastructure.files.storage import get_storage_config
from cognee.modules.search.operations import get_history
from cognee.modules.users.methods import get_default_user
from cognee.shared.logging_utils import get_logger
from cognee.modules.search.types import SearchType
# from cognee.shared.utils import render_graph
logger = get_logger()
@ -44,8 +44,6 @@ async def main():
await cognee.cognify([dataset_name])
# await render_graph(None, include_labels = True, include_nodes = True)
from cognee.infrastructure.databases.vector import get_vector_engine
vector_engine = get_vector_engine()
@ -81,7 +79,8 @@ async def main():
# Assert local data files are cleaned properly
await cognee.prune.prune_data()
assert not os.path.isdir(data_directory_path), "Local data files are not deleted"
data_root_directory = get_storage_config()["data_root_directory"]
assert not os.path.isdir(data_root_directory), "Local data files are not deleted"
# Assert relational, vector and graph databases have been cleaned properly
await cognee.prune.prune_system(metadata=True)

View file

@ -3,6 +3,7 @@ import shutil
import cognee
import pathlib
from cognee.infrastructure.files.storage import get_storage_config
from cognee.modules.engine.models import NodeSet
from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever
from cognee.shared.logging_utils import get_logger
@ -112,7 +113,8 @@ async def main():
)
await cognee.prune.prune_data()
assert not os.path.isdir(data_directory_path), "Local data files are not deleted"
data_root_directory = get_storage_config()["data_root_directory"]
assert not os.path.isdir(data_root_directory), "Local data files are not deleted"
await cognee.prune.prune_system(metadata=True)
from cognee.infrastructure.databases.graph import get_graph_engine

View file

@ -1,6 +1,7 @@
import os
import pathlib
import cognee
from cognee.infrastructure.files.storage import get_file_storage, get_storage_config
from cognee.modules.search.operations import get_history
from cognee.modules.users.methods import get_default_user
from cognee.shared.logging_utils import get_logger
@ -78,7 +79,8 @@ async def main():
# Assert local data files are cleaned properly
await cognee.prune.prune_data()
assert not os.path.isdir(data_directory_path), "Local data files are not deleted"
data_root_directory = get_storage_config()["data_root_directory"]
assert not os.path.isdir(data_root_directory), "Local data files are not deleted"
# Assert relational, vector and graph databases have been cleaned properly
await cognee.prune.prune_system(metadata=True)
@ -89,16 +91,28 @@ async def main():
from cognee.infrastructure.databases.relational import get_relational_engine
with open(get_relational_engine().db_path, "r") as file:
content = file.read()
assert content == "", "SQLite relational database is not empty"
db_path = get_relational_engine().db_path
dir_path = os.path.dirname(db_path)
file_path = os.path.basename(db_path)
file_storage = get_file_storage(dir_path)
assert not await file_storage.file_exists(file_path), (
"SQLite relational database is not deleted"
)
from cognee.infrastructure.databases.graph import get_graph_config
graph_config = get_graph_config()
assert not os.path.exists(graph_config.graph_file_path) or not os.listdir(
graph_config.graph_file_path
), "Kuzu graph directory is not empty"
# For Kuzu v0.11.0+, check if database file doesn't exist (single-file format with .kuzu extension)
# For older versions or other providers, check if directory is empty
if graph_config.graph_database_provider.lower() == "kuzu":
assert not os.path.exists(graph_config.graph_file_path), (
"Kuzu graph database file still exists"
)
else:
assert not os.path.exists(graph_config.graph_file_path) or not os.listdir(
graph_config.graph_file_path
), "Graph database directory is not empty"
if __name__ == "__main__":
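Rather than opening the SQLite file directly, the test now asks the file-storage abstraction whether the database file still exists, which holds for both local and S3-backed storage; the Kuzu branch reflects its single-file database format. A minimal sketch assembled from the calls in this hunk (the helper name is illustrative):

```python
import os

from cognee.infrastructure.databases.graph import get_graph_config
from cognee.infrastructure.databases.relational import get_relational_engine
from cognee.infrastructure.files.storage import get_file_storage


async def assert_system_pruned() -> None:
    # The relational engine exposes the SQLite path; check for the file through the storage layer.
    db_path = get_relational_engine().db_path
    file_storage = get_file_storage(os.path.dirname(db_path))
    assert not await file_storage.file_exists(os.path.basename(db_path)), (
        "SQLite relational database is not deleted"
    )

    # Kuzu 0.11+ stores the graph as a single file; other providers use a directory.
    graph_config = get_graph_config()
    if graph_config.graph_database_provider.lower() == "kuzu":
        assert not os.path.exists(graph_config.graph_file_path), (
            "Kuzu graph database file still exists"
        )
```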

View file

@ -2,6 +2,7 @@ import os
import pathlib
import cognee
from cognee.infrastructure.files.storage import get_storage_config
from cognee.modules.search.operations import get_history
from cognee.modules.users.methods import get_default_user
from cognee.shared.logging_utils import get_logger
@ -91,7 +92,8 @@ async def main():
assert len(history) == 8, "Search history is not correct."
await cognee.prune.prune_data()
assert not os.path.isdir(data_directory_path), "Local data files are not deleted"
data_root_directory = get_storage_config()["data_root_directory"]
assert not os.path.isdir(data_root_directory), "Local data files are not deleted"
await cognee.prune.prune_system(metadata=True)
from cognee.infrastructure.databases.graph import get_graph_engine

View file

@ -1,6 +1,7 @@
import os
import pathlib
import cognee
from cognee.infrastructure.files.storage import get_storage_config
from cognee.modules.search.operations import get_history
from cognee.modules.users.methods import get_default_user
from cognee.shared.logging_utils import get_logger
@ -87,7 +88,8 @@ async def main():
assert len(history) == 6, "Search history is not correct."
await cognee.prune.prune_data()
assert not os.path.isdir(data_directory_path), "Local data files are not deleted"
data_root_directory = get_storage_config()["data_root_directory"]
assert not os.path.isdir(data_root_directory), "Local data files are not deleted"
await cognee.prune.prune_system(metadata=True)
milvus_client = get_vector_engine().get_milvus_client()

View file

@ -1,6 +1,7 @@
import os
import pathlib
import cognee
from cognee.infrastructure.files.storage import get_storage_config
from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever
from cognee.modules.search.operations import get_history
from cognee.modules.users.methods import get_default_user
@ -116,7 +117,8 @@ async def main():
)
await cognee.prune.prune_data()
assert not os.path.isdir(data_directory_path), "Local data files are not deleted"
data_root_directory = get_storage_config()["data_root_directory"]
assert not os.path.isdir(data_root_directory), "Local data files are not deleted"
await cognee.prune.prune_system(metadata=True)
from cognee.infrastructure.databases.graph import get_graph_engine

View file

@ -1,6 +1,7 @@
import os
import pathlib
import cognee
from cognee.infrastructure.files.storage import get_storage_config
from cognee.modules.search.operations import get_history
from cognee.shared.logging_utils import get_logger
from cognee.modules.data.models import Data
@ -23,26 +24,28 @@ async def test_local_file_deletion(data_text, file_location):
data_hash = hashlib.md5(encoded_text).hexdigest()
# Get data entry from database based on hash contents
data = (await session.scalars(select(Data).where(Data.content_hash == data_hash))).one()
assert os.path.isfile(data.raw_data_location), (
assert os.path.isfile(data.raw_data_location.replace("file://", "")), (
f"Data location doesn't exist: {data.raw_data_location}"
)
# Test deletion of data along with local files created by cognee
await engine.delete_data_entity(data.id)
assert not os.path.exists(data.raw_data_location), (
assert not os.path.exists(data.raw_data_location.replace("file://", "")), (
f"Data location still exists after deletion: {data.raw_data_location}"
)
async with engine.get_async_session() as session:
# Get data entry from database based on file path
data = (
await session.scalars(select(Data).where(Data.raw_data_location == file_location))
await session.scalars(
select(Data).where(Data.raw_data_location == "file://" + file_location)
)
).one()
assert os.path.isfile(data.raw_data_location), (
assert os.path.isfile(data.raw_data_location.replace("file://", "")), (
f"Data location doesn't exist: {data.raw_data_location}"
)
# Test local files not created by cognee won't get deleted
await engine.delete_data_entity(data.id)
assert os.path.exists(data.raw_data_location), (
assert os.path.exists(data.raw_data_location.replace("file://", "")), (
f"Data location doesn't exists: {data.raw_data_location}"
)
@ -164,7 +167,8 @@ async def main():
await test_local_file_deletion(text, explanation_file_path)
await cognee.prune.prune_data()
assert not os.path.isdir(data_directory_path), "Local data files are not deleted"
data_root_directory = get_storage_config()["data_root_directory"]
assert not os.path.isdir(data_root_directory), "Local data files are not deleted"
await cognee.prune.prune_system(metadata=True)
tables_in_database = await vector_engine.get_table_names()
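Since raw_data_location is now stored as a file:// URI, the local-filesystem assertions in the hunks above strip the scheme before calling os.path. A minimal sketch of that normalization (the helper name is illustrative):

```python
import os


def to_local_path(raw_data_location: str) -> str:
    # Drop a single leading "file://" scheme; remote URIs (e.g. s3://...) are left
    # untouched because they cannot be checked with os.path anyway.
    return raw_data_location.replace("file://", "", 1)


assert to_local_path("file:///tmp/example.txt") == "/tmp/example.txt"
assert os.path.sep in to_local_path("file:///tmp/example.txt")
```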

View file

@ -1,6 +1,7 @@
import os
import pathlib
import cognee
from cognee.infrastructure.files.storage import get_storage_config
from cognee.modules.search.operations import get_history
from cognee.modules.users.methods import get_default_user
from cognee.shared.logging_utils import get_logger
@ -83,7 +84,8 @@ async def main():
assert len(history) == 6, "Search history is not correct."
await cognee.prune.prune_data()
assert not os.path.isdir(data_directory_path), "Local data files are not deleted"
data_root_directory = get_storage_config()["data_root_directory"]
assert not os.path.isdir(data_root_directory), "Local data files are not deleted"
await cognee.prune.prune_system(metadata=True)
qdrant_client = get_vector_engine().get_qdrant_client()

View file

@ -2,11 +2,11 @@ import os
import shutil
import cognee
import pathlib
from cognee.infrastructure.files.storage import get_storage_config
from cognee.shared.logging_utils import get_logger
from cognee.modules.search.types import SearchType
from cognee.modules.search.operations import get_history
from cognee.modules.users.methods import get_default_user
from cognee.infrastructure.databases.graph.config import get_graph_config
logger = get_logger()
@ -93,7 +93,8 @@ async def main():
assert len(history) == 6, "Search history is not correct."
await cognee.prune.prune_data()
assert not os.path.isdir(data_directory_path), "Local data files are not deleted"
data_root_directory = get_storage_config()["data_root_directory"]
assert not os.path.isdir(data_root_directory), "Local data files are not deleted"
await cognee.prune.prune_system(metadata=True)
from cognee.infrastructure.databases.graph import get_graph_engine

View file

@ -1,6 +1,7 @@
import os
import pathlib
import cognee
from cognee.infrastructure.files.storage import get_storage_config
from cognee.modules.search.operations import get_history
from cognee.modules.users.methods import get_default_user
from cognee.shared.logging_utils import get_logger
@ -79,7 +80,8 @@ async def main():
assert len(history) == 6, "Search history is not correct."
await cognee.prune.prune_data()
assert not os.path.isdir(data_directory_path), "Local data files are not deleted"
data_root_directory = get_storage_config()["data_root_directory"]
assert not os.path.isdir(data_root_directory), "Local data files are not deleted"
await cognee.prune.prune_system(metadata=True)
collections = await get_vector_engine().client.collections.list_all()

View file

@ -2,13 +2,15 @@ import os
import tempfile
import pytest
from pathlib import Path
from cognee.modules.data.processing.document_types.open_data_file import open_data_file
from cognee.infrastructure.files.utils.open_data_file import open_data_file
class TestOpenDataFile:
"""Test cases for open_data_file function with file:// URL handling."""
def test_regular_file_path(self):
@pytest.mark.asyncio
async def test_regular_file_path(self):
"""Test that regular file paths work as before."""
with tempfile.NamedTemporaryFile(mode="w", delete=False, suffix=".txt") as f:
test_content = "Test content for regular file path"
@ -16,13 +18,14 @@ class TestOpenDataFile:
temp_file_path = f.name
try:
with open_data_file(temp_file_path, mode="r") as f:
async with open_data_file(temp_file_path, mode="r") as f:
content = f.read()
assert content == test_content
finally:
os.unlink(temp_file_path)
def test_file_url_text_mode(self):
@pytest.mark.asyncio
async def test_file_url_text_mode(self):
"""Test that file:// URLs work correctly in text mode."""
with tempfile.NamedTemporaryFile(mode="w", delete=False, suffix=".txt") as f:
test_content = "Test content for file:// URL handling"
@ -32,13 +35,14 @@ class TestOpenDataFile:
try:
# Use pathlib.Path.as_uri() for proper cross-platform file URL creation
file_url = Path(temp_file_path).as_uri()
with open_data_file(file_url, mode="r") as f:
async with open_data_file(file_url, mode="r") as f:
content = f.read()
assert content == test_content
finally:
os.unlink(temp_file_path)
def test_file_url_binary_mode(self):
@pytest.mark.asyncio
async def test_file_url_binary_mode(self):
"""Test that file:// URLs work correctly in binary mode."""
with tempfile.NamedTemporaryFile(mode="w", delete=False, suffix=".txt") as f:
test_content = "Test content for binary mode"
@ -48,13 +52,14 @@ class TestOpenDataFile:
try:
# Use pathlib.Path.as_uri() for proper cross-platform file URL creation
file_url = Path(temp_file_path).as_uri()
with open_data_file(file_url, mode="rb") as f:
async with open_data_file(file_url, mode="rb") as f:
content = f.read()
assert content == test_content.encode()
finally:
os.unlink(temp_file_path)
def test_file_url_with_encoding(self):
@pytest.mark.asyncio
async def test_file_url_with_encoding(self):
"""Test that file:// URLs work with specific encoding."""
with tempfile.NamedTemporaryFile(
mode="w", delete=False, suffix=".txt", encoding="utf-8"
@ -66,20 +71,22 @@ class TestOpenDataFile:
try:
# Use pathlib.Path.as_uri() for proper cross-platform file URL creation
file_url = Path(temp_file_path).as_uri()
with open_data_file(file_url, mode="r", encoding="utf-8") as f:
async with open_data_file(file_url, mode="r", encoding="utf-8") as f:
content = f.read()
assert content == test_content
finally:
os.unlink(temp_file_path)
def test_file_url_nonexistent_file(self):
@pytest.mark.asyncio
async def test_file_url_nonexistent_file(self):
"""Test that file:// URLs raise appropriate error for nonexistent files."""
file_url = "file:///nonexistent/path/to/file.txt"
with pytest.raises(FileNotFoundError):
with open_data_file(file_url, mode="r") as f:
async with open_data_file(file_url, mode="r") as f:
f.read()
def test_multiple_file_prefixes(self):
@pytest.mark.asyncio
async def test_multiple_file_prefixes(self):
"""Test that multiple file:// prefixes are handled correctly."""
with tempfile.NamedTemporaryFile(mode="w", delete=False, suffix=".txt") as f:
test_content = "Test content"
@ -91,7 +98,7 @@ class TestOpenDataFile:
# Use proper file URL creation first
proper_file_url = Path(temp_file_path).as_uri()
file_url = f"file://{proper_file_url}"
with open_data_file(file_url, mode="r") as f:
async with open_data_file(file_url, mode="r") as f:
content = f.read()
# This should work because we only replace the first occurrence
assert content == test_content
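As these tests show, open_data_file now lives under cognee.infrastructure.files.utils and is used as an async context manager. A minimal sketch of reading a local file through it, using the same mode and encoding arguments exercised above:

```python
import asyncio
from pathlib import Path

from cognee.infrastructure.files.utils.open_data_file import open_data_file


async def read_text(path: str) -> str:
    # Both plain paths and file:// URIs are accepted; text mode returns str.
    async with open_data_file(Path(path).resolve().as_uri(), mode="r", encoding="utf-8") as f:
        return f.read()


# Example: asyncio.run(read_text("/tmp/example.txt"))
```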

View file

@ -16,11 +16,11 @@ class TestChunksRetriever:
@pytest.mark.asyncio
async def test_chunk_context_simple(self):
system_directory_path = os.path.join(
pathlib.Path(__file__).parent, ".cognee_system/test_chunks_context_simple"
pathlib.Path(__file__).parent, ".cognee_system/test_chunk_context_simple"
)
cognee.config.system_root_directory(system_directory_path)
data_directory_path = os.path.join(
pathlib.Path(__file__).parent, ".data_storage/test_chunks_context_simple"
pathlib.Path(__file__).parent, ".data_storage/test_chunk_context_simple"
)
cognee.config.data_root_directory(data_directory_path)
@ -162,11 +162,11 @@ class TestChunksRetriever:
@pytest.mark.asyncio
async def test_chunk_context_on_empty_graph(self):
system_directory_path = os.path.join(
pathlib.Path(__file__).parent, ".cognee_system/test_chunk_context_empty"
pathlib.Path(__file__).parent, ".cognee_system/test_chunk_context_on_empty_graph"
)
cognee.config.system_root_directory(system_directory_path)
data_directory_path = os.path.join(
pathlib.Path(__file__).parent, ".data_storage/test_chunk_context_empty"
pathlib.Path(__file__).parent, ".data_storage/test_chunk_context_on_empty_graph"
)
cognee.config.data_root_directory(data_directory_path)
@ -183,16 +183,3 @@ class TestChunksRetriever:
context = await retriever.get_context("Christina Mayer")
assert len(context) == 0, "Found chunks when none should exist"
if __name__ == "__main__":
from asyncio import run
test = TestChunksRetriever()
async def main():
await test.test_chunk_context_simple()
await test.test_chunk_context_complex()
await test.test_chunk_context_on_empty_graph()
run(main())
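This retriever test and the ones that follow stop sharing `.cognee_system` / `.data_storage` directories: each test now points cognee at a directory named after the test, so repeated or parallel runs cannot interfere with each other. A minimal sketch of that setup, assuming it lives in a test module:

```python
import os
import pathlib

import cognee


def isolate_test_storage(test_name: str) -> None:
    # Give every test its own system and data roots so tests never read
    # each other's databases or raw files.
    base = pathlib.Path(__file__).parent
    cognee.config.system_root_directory(os.path.join(base, f".cognee_system/{test_name}"))
    cognee.config.data_root_directory(os.path.join(base, f".data_storage/{test_name}"))
```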

View file

@ -12,15 +12,17 @@ from cognee.modules.retrieval.graph_completion_context_extension_retriever impor
)
class TestGraphCompletionRetriever:
class TestGraphCompletionWithContextExtensionRetriever:
@pytest.mark.asyncio
async def test_graph_completion_extension_context_simple(self):
system_directory_path = os.path.join(
pathlib.Path(__file__).parent, ".cognee_system/test_graph_context"
pathlib.Path(__file__).parent,
".cognee_system/test_graph_completion_extension_context_simple",
)
cognee.config.system_root_directory(system_directory_path)
data_directory_path = os.path.join(
pathlib.Path(__file__).parent, ".data_storage/test_graph_context"
pathlib.Path(__file__).parent,
".data_storage/test_graph_completion_extension_context_simple",
)
cognee.config.data_root_directory(data_directory_path)
@ -64,11 +66,13 @@ class TestGraphCompletionRetriever:
@pytest.mark.asyncio
async def test_graph_completion_extension_context_complex(self):
system_directory_path = os.path.join(
pathlib.Path(__file__).parent, ".cognee_system/test_graph_completion_context"
pathlib.Path(__file__).parent,
".cognee_system/test_graph_completion_extension_context_complex",
)
cognee.config.system_root_directory(system_directory_path)
data_directory_path = os.path.join(
pathlib.Path(__file__).parent, ".data_storage/test_graph_completion_context"
pathlib.Path(__file__).parent,
".data_storage/test_graph_completion_extension_context_complex",
)
cognee.config.data_root_directory(data_directory_path)
@ -143,11 +147,13 @@ class TestGraphCompletionRetriever:
@pytest.mark.asyncio
async def test_get_graph_completion_extension_context_on_empty_graph(self):
system_directory_path = os.path.join(
pathlib.Path(__file__).parent, ".cognee_system/test_graph_completion_context"
pathlib.Path(__file__).parent,
".cognee_system/test_get_graph_completion_extension_context_on_empty_graph",
)
cognee.config.system_root_directory(system_directory_path)
data_directory_path = os.path.join(
pathlib.Path(__file__).parent, ".data_storage/test_graph_completion_context"
pathlib.Path(__file__).parent,
".data_storage/test_get_graph_completion_extension_context_on_empty_graph",
)
cognee.config.data_root_directory(data_directory_path)
@ -170,16 +176,3 @@ class TestGraphCompletionRetriever:
assert all(isinstance(item, str) and item.strip() for item in answer), (
"Answer must contain only non-empty strings"
)
if __name__ == "__main__":
from asyncio import run
test = TestGraphCompletionRetriever()
async def main():
await test.test_graph_completion_context_simple()
await test.test_graph_completion_context_complex()
await test.test_get_graph_completion_context_on_empty_graph()
run(main())

View file

@ -10,15 +10,15 @@ from cognee.infrastructure.databases.exceptions import DatabaseNotCreatedError
from cognee.modules.retrieval.graph_completion_cot_retriever import GraphCompletionCotRetriever
class TestGraphCompletionRetriever:
class TestGraphCompletionCoTRetriever:
@pytest.mark.asyncio
async def test_graph_completion_cot_context_simple(self):
system_directory_path = os.path.join(
pathlib.Path(__file__).parent, ".cognee_system/test_graph_context"
pathlib.Path(__file__).parent, ".cognee_system/test_graph_completion_cot_context_simple"
)
cognee.config.system_root_directory(system_directory_path)
data_directory_path = os.path.join(
pathlib.Path(__file__).parent, ".data_storage/test_graph_context"
pathlib.Path(__file__).parent, ".data_storage/test_graph_completion_cot_context_simple"
)
cognee.config.data_root_directory(data_directory_path)
@ -62,11 +62,12 @@ class TestGraphCompletionRetriever:
@pytest.mark.asyncio
async def test_graph_completion_cot_context_complex(self):
system_directory_path = os.path.join(
pathlib.Path(__file__).parent, ".cognee_system/test_graph_completion_context"
pathlib.Path(__file__).parent,
".cognee_system/test_graph_completion_cot_context_complex",
)
cognee.config.system_root_directory(system_directory_path)
data_directory_path = os.path.join(
pathlib.Path(__file__).parent, ".data_storage/test_graph_completion_context"
pathlib.Path(__file__).parent, ".data_storage/test_graph_completion_cot_context_complex"
)
cognee.config.data_root_directory(data_directory_path)
@ -141,11 +142,13 @@ class TestGraphCompletionRetriever:
@pytest.mark.asyncio
async def test_get_graph_completion_cot_context_on_empty_graph(self):
system_directory_path = os.path.join(
pathlib.Path(__file__).parent, ".cognee_system/test_graph_completion_context"
pathlib.Path(__file__).parent,
".cognee_system/test_get_graph_completion_cot_context_on_empty_graph",
)
cognee.config.system_root_directory(system_directory_path)
data_directory_path = os.path.join(
pathlib.Path(__file__).parent, ".data_storage/test_graph_completion_context"
pathlib.Path(__file__).parent,
".data_storage/test_get_graph_completion_cot_context_on_empty_graph",
)
cognee.config.data_root_directory(data_directory_path)
@ -168,16 +171,3 @@ class TestGraphCompletionRetriever:
assert all(isinstance(item, str) and item.strip() for item in answer), (
"Answer must contain only non-empty strings"
)
if __name__ == "__main__":
from asyncio import run
test = TestGraphCompletionRetriever()
async def main():
await test.test_graph_completion_context_simple()
await test.test_graph_completion_context_complex()
await test.test_get_graph_completion_context_on_empty_graph()
run(main())

View file

@ -14,11 +14,11 @@ class TestGraphCompletionRetriever:
@pytest.mark.asyncio
async def test_graph_completion_context_simple(self):
system_directory_path = os.path.join(
pathlib.Path(__file__).parent, ".cognee_system/test_graph_context"
pathlib.Path(__file__).parent, ".cognee_system/test_graph_completion_context_simple"
)
cognee.config.system_root_directory(system_directory_path)
data_directory_path = os.path.join(
pathlib.Path(__file__).parent, ".data_storage/test_graph_context"
pathlib.Path(__file__).parent, ".data_storage/test_graph_completion_context_simple"
)
cognee.config.data_root_directory(data_directory_path)
@ -55,11 +55,11 @@ class TestGraphCompletionRetriever:
@pytest.mark.asyncio
async def test_graph_completion_context_complex(self):
system_directory_path = os.path.join(
pathlib.Path(__file__).parent, ".cognee_system/test_graph_completion_context"
pathlib.Path(__file__).parent, ".cognee_system/test_graph_completion_context_complex"
)
cognee.config.system_root_directory(system_directory_path)
data_directory_path = os.path.join(
pathlib.Path(__file__).parent, ".data_storage/test_graph_completion_context"
pathlib.Path(__file__).parent, ".data_storage/test_graph_completion_context_complex"
)
cognee.config.data_root_directory(data_directory_path)
@ -127,11 +127,13 @@ class TestGraphCompletionRetriever:
@pytest.mark.asyncio
async def test_get_graph_completion_context_on_empty_graph(self):
system_directory_path = os.path.join(
pathlib.Path(__file__).parent, ".cognee_system/test_graph_completion_context"
pathlib.Path(__file__).parent,
".cognee_system/test_get_graph_completion_context_on_empty_graph",
)
cognee.config.system_root_directory(system_directory_path)
data_directory_path = os.path.join(
pathlib.Path(__file__).parent, ".data_storage/test_graph_completion_context"
pathlib.Path(__file__).parent,
".data_storage/test_get_graph_completion_context_on_empty_graph",
)
cognee.config.data_root_directory(data_directory_path)
@ -147,16 +149,3 @@ class TestGraphCompletionRetriever:
context = await retriever.get_context("Who works at Figma?")
assert context == "", "Context should be empty on an empty graph"
if __name__ == "__main__":
from asyncio import run
test = TestGraphCompletionRetriever()
async def main():
await test.test_graph_completion_context_simple()
await test.test_graph_completion_context_complex()
await test.test_get_graph_completion_context_on_empty_graph()
run(main())

View file

@ -227,11 +227,11 @@ class TestInsightsRetriever:
@pytest.mark.asyncio
async def test_insights_context_on_empty_graph(self):
system_directory_path = os.path.join(
pathlib.Path(__file__).parent, ".cognee_system/test_graph_completion_context_empty"
pathlib.Path(__file__).parent, ".cognee_system/test_insights_context_on_empty_graph"
)
cognee.config.system_root_directory(system_directory_path)
data_directory_path = os.path.join(
pathlib.Path(__file__).parent, ".data_storage/test_graph_completion_context_empty"
pathlib.Path(__file__).parent, ".data_storage/test_insights_context_on_empty_graph"
)
cognee.config.data_root_directory(data_directory_path)
@ -249,13 +249,3 @@ class TestInsightsRetriever:
context = await retriever.get_context("Christina Mayer")
assert context == [], "Returned context should be empty on an empty graph"
if __name__ == "__main__":
from asyncio import run
test = TestInsightsRetriever()
run(test.test_insights_context_simple())
run(test.test_insights_context_complex())
run(test.test_insights_context_on_empty_graph())

View file

@ -16,11 +16,11 @@ class TestRAGCompletionRetriever:
@pytest.mark.asyncio
async def test_rag_completion_context_simple(self):
system_directory_path = os.path.join(
pathlib.Path(__file__).parent, ".cognee_system/test_rag_context"
pathlib.Path(__file__).parent, ".cognee_system/test_rag_completion_context_simple"
)
cognee.config.system_root_directory(system_directory_path)
data_directory_path = os.path.join(
pathlib.Path(__file__).parent, ".data_storage/test_rag_context"
pathlib.Path(__file__).parent, ".data_storage/test_rag_completion_context_simple"
)
cognee.config.data_root_directory(data_directory_path)
@ -73,11 +73,11 @@ class TestRAGCompletionRetriever:
@pytest.mark.asyncio
async def test_rag_completion_context_complex(self):
system_directory_path = os.path.join(
pathlib.Path(__file__).parent, ".cognee_system/test_graph_completion_context"
pathlib.Path(__file__).parent, ".cognee_system/test_rag_completion_context_complex"
)
cognee.config.system_root_directory(system_directory_path)
data_directory_path = os.path.join(
pathlib.Path(__file__).parent, ".data_storage/test_graph_completion_context"
pathlib.Path(__file__).parent, ".data_storage/test_rag_completion_context_complex"
)
cognee.config.data_root_directory(data_directory_path)
@ -163,11 +163,13 @@ class TestRAGCompletionRetriever:
@pytest.mark.asyncio
async def test_get_rag_completion_context_on_empty_graph(self):
system_directory_path = os.path.join(
pathlib.Path(__file__).parent, ".cognee_system/test_graph_completion_context"
pathlib.Path(__file__).parent,
".cognee_system/test_get_rag_completion_context_on_empty_graph",
)
cognee.config.system_root_directory(system_directory_path)
data_directory_path = os.path.join(
pathlib.Path(__file__).parent, ".data_storage/test_graph_completion_context"
pathlib.Path(__file__).parent,
".data_storage/test_get_rag_completion_context_on_empty_graph",
)
cognee.config.data_root_directory(data_directory_path)
@ -184,13 +186,3 @@ class TestRAGCompletionRetriever:
context = await retriever.get_context("Christina Mayer")
assert context == "", "Returned context should be empty on an empty graph"
if __name__ == "__main__":
from asyncio import run
test = TestRAGCompletionRetriever()
run(test.test_rag_completion_context_simple())
run(test.test_rag_completion_context_complex())
run(test.test_get_rag_completion_context_on_empty_graph())

View file

@ -17,11 +17,11 @@ class TextSummariesRetriever:
@pytest.mark.asyncio
async def test_chunk_context(self):
system_directory_path = os.path.join(
pathlib.Path(__file__).parent, ".cognee_system/test_summary_context"
pathlib.Path(__file__).parent, ".cognee_system/test_chunk_context"
)
cognee.config.system_root_directory(system_directory_path)
data_directory_path = os.path.join(
pathlib.Path(__file__).parent, ".data_storage/test_summary_context"
pathlib.Path(__file__).parent, ".data_storage/test_chunk_context"
)
cognee.config.data_root_directory(data_directory_path)
@ -136,11 +136,11 @@ class TextSummariesRetriever:
@pytest.mark.asyncio
async def test_chunk_context_on_empty_graph(self):
system_directory_path = os.path.join(
pathlib.Path(__file__).parent, ".cognee_system/test_summary_context"
pathlib.Path(__file__).parent, ".cognee_system/test_chunk_context_on_empty_graph"
)
cognee.config.system_root_directory(system_directory_path)
data_directory_path = os.path.join(
pathlib.Path(__file__).parent, ".data_storage/test_summary_context"
pathlib.Path(__file__).parent, ".data_storage/test_chunk_context_on_empty_graph"
)
cognee.config.data_root_directory(data_directory_path)
@ -157,12 +157,3 @@ class TextSummariesRetriever:
context = await retriever.get_context("Christina Mayer")
assert context == [], "Returned context should be empty on an empty graph"
if __name__ == "__main__":
from asyncio import run
test = TextSummariesRetriever()
run(test.test_chunk_context())
run(test.test_chunk_context_on_empty_graph())

View file

@ -1,4 +1,5 @@
import os
import tempfile
import pytest
import networkx as nx
import pandas as pd
@ -6,12 +7,9 @@ from unittest.mock import patch, mock_open
from io import BytesIO
from uuid import uuid4
from cognee.shared.utils import (
get_anonymous_id,
get_file_content_hash,
prepare_edges,
prepare_nodes,
)
from cognee.infrastructure.files.utils.get_file_content_hash import get_file_content_hash
from cognee.shared.utils import get_anonymous_id
@pytest.fixture
@ -28,48 +26,31 @@ def test_get_anonymous_id(mock_open_file, mock_makedirs, temp_dir):
assert len(anon_id) > 0
# @patch("requests.post")
# def test_send_telemetry(mock_post):
# mock_post.return_value.status_code = 200
#
# send_telemetry("test_event", "test_user", {"key": "value"})
# mock_post.assert_called_once()
#
# args, kwargs = mock_post.call_args
# assert kwargs["json"]["event_name"] == "test_event"
@pytest.mark.asyncio
async def test_get_file_content_hash_file():
temp_file_path = None
text_content = "Test content with UTF-8: café ☕"
with tempfile.NamedTemporaryFile(mode="w", delete=False, suffix=".txt", encoding="utf-8") as f:
test_content = text_content
f.write(test_content)
temp_file_path = f.name
@patch("builtins.open", new_callable=mock_open, read_data=b"test_data")
def test_get_file_content_hash_file(mock_open_file):
import hashlib
expected_hash = hashlib.md5(b"test_data").hexdigest()
result = get_file_content_hash("test_file.txt")
assert result == expected_hash
try:
expected_hash = hashlib.md5(text_content.encode("utf-8")).hexdigest()
result = await get_file_content_hash(temp_file_path)
assert result == expected_hash
finally:
os.unlink(temp_file_path)
def test_get_file_content_hash_stream():
@pytest.mark.asyncio
async def test_get_file_content_hash_stream():
stream = BytesIO(b"test_data")
import hashlib
expected_hash = hashlib.md5(b"test_data").hexdigest()
result = get_file_content_hash(stream)
result = await get_file_content_hash(stream)
assert result == expected_hash
def test_prepare_edges():
graph = nx.MultiDiGraph()
graph.add_edge("A", "B", key="AB", weight=1)
edges_df = prepare_edges(graph, "source", "target", "key")
assert isinstance(edges_df, pd.DataFrame)
assert len(edges_df) == 1
def test_prepare_nodes():
graph = nx.Graph()
graph.add_node(1, name="Node1")
nodes_df = prepare_nodes(graph)
assert isinstance(nodes_df, pd.DataFrame)
assert len(nodes_df) == 1
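get_file_content_hash moves out of cognee.shared.utils into the files utilities and becomes a coroutine that accepts either a path or a binary stream, returning the MD5 hex digest of the raw bytes. A minimal sketch of the stream form shown in this test:

```python
import asyncio
import hashlib
from io import BytesIO

from cognee.infrastructure.files.utils.get_file_content_hash import get_file_content_hash


async def demo() -> None:
    # Hash an in-memory stream; a filesystem path would be awaited the same way.
    stream_hash = await get_file_content_hash(BytesIO(b"test_data"))
    assert stream_hash == hashlib.md5(b"test_data").hexdigest()


# asyncio.run(demo())
```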

View file

@ -6,7 +6,6 @@ from cognee import visualize_graph
from cognee.low_level import setup, DataPoint
from cognee.pipelines import run_tasks, Task
from cognee.tasks.storage import add_data_points
from cognee.shared.utils import render_graph
class Person(DataPoint):
@ -80,9 +79,6 @@ async def main():
async for status in pipeline:
print(status)
# Get a graphistry url (Register for a free account at https://www.graphistry.com)
await render_graph()
# Or use our simple graph preview
graph_file_path = str(
os.path.join(os.path.dirname(__file__), ".artifacts/graph_visualization.html")
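The examples drop the Graphistry-based render_graph in favour of the built-in HTML preview. A minimal sketch, assuming visualize_graph accepts the destination path for the generated HTML (the .artifacts directory is an illustrative location taken from the example above):

```python
import os

from cognee import visualize_graph


async def preview_graph() -> None:
    # Assumption: visualize_graph writes a self-contained HTML file to the given path.
    graph_file_path = os.path.join(
        os.path.dirname(__file__), ".artifacts/graph_visualization.html"
    )
    await visualize_graph(graph_file_path)
```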

View file

@ -10,7 +10,6 @@ from cognee.infrastructure.databases.graph import get_graph_engine
from cognee.low_level import setup, DataPoint
from cognee.pipelines import run_tasks, Task
from cognee.tasks.storage import add_data_points
from cognee.shared.utils import render_graph
class Products(DataPoint):
@ -119,9 +118,6 @@ async def main():
async for status in pipeline:
print(status)
# Get a graphistry url (Register for a free account at https://www.graphistry.com)
await render_graph()
graph_engine = await get_graph_engine()
products_results = await graph_engine.query(

2557
poetry.lock generated

File diff suppressed because it is too large

Some files were not shown because too many files have changed in this diff