feat: s3 storage (#988)
## Description

Adds S3-backed file storage and routes the graph, vector, and relational database adapters through the new storage abstraction.

## DCO Affirmation

I affirm that all code in every commit of this pull request conforms to the terms of the Topoteretes Developer Certificate of Origin.

Co-authored-by: vasilije <vas.markovic@gmail.com>
Co-authored-by: Vasilije <8619304+Vasilije1990@users.noreply.github.com>
commit 46c4463cb2
parent 4bcb893a54
102 changed files with 10328 additions and 9216 deletions
.gitguardian.yml (new file, 26 lines)
@@ -0,0 +1,26 @@
# .gitguardian.yml
version: v1

secret-scan:
  # Ignore specific files
  excluded-paths:
    - '.env.template'
    - '.github/workflows/*.yml'
    - 'examples/**'
    - 'tests/**'

  # Ignore specific patterns
  excluded-detectors:
    - 'Generic Password'
    - 'Generic High Entropy Secret'

  # Ignore by commit (if needed)
  excluded-commits:
    - '782bbb4'

  # Custom rules for template files
  paths-ignore:
    - path: '.env.template'
      comment: 'Template file with placeholder values'
    - path: '.github/workflows/search_db_tests.yml'
      comment: 'Test workflow with test credentials'
.github/workflows/python_version_tests.yml (vendored, 2 changed lines)

@@ -44,7 +44,7 @@ jobs:
    strategy:
      matrix:
        python-version: ${{ fromJSON(inputs.python-versions) }}
        os: [ubuntu-22.04, macos-13, macos-15, windows-latest]
        os: [ubuntu-22.04, macos-13, macos-15]
      fail-fast: false
    steps:
      - name: Check out

@ -6,7 +6,7 @@ from sqlalchemy.engine import Connection
|
|||
from sqlalchemy.ext.asyncio import async_engine_from_config
|
||||
from cognee.infrastructure.databases.relational import Base
|
||||
from alembic import context
|
||||
from cognee.infrastructure.databases.relational import get_relational_engine, get_relational_config
|
||||
from cognee.infrastructure.databases.relational import get_relational_engine
|
||||
|
||||
# this is the Alembic Config object, which provides
|
||||
# access to the values within the .ini file in use.
|
||||
|
|
@ -86,12 +86,6 @@ def run_migrations_online() -> None:
|
|||
|
||||
db_engine = get_relational_engine()
|
||||
|
||||
if db_engine.engine.dialect.name == "sqlite":
|
||||
from cognee.infrastructure.files.storage import LocalStorage
|
||||
|
||||
db_config = get_relational_config()
|
||||
LocalStorage.ensure_directory_exists(db_config.db_path)
|
||||
|
||||
print("Using database:", db_engine.db_uri)
|
||||
|
||||
config.set_section_option(
|
||||
|
|
|
|||
|
|
@ -8,10 +8,10 @@ requires-python = ">=3.10"
|
|||
dependencies = [
|
||||
# For local cognee repo usage remove comment below and add absolute path to cognee. Then run `uv sync --reinstall` in the mcp folder on local cognee changes.
|
||||
#"cognee[postgres,codegraph,gemini,huggingface,docs,neo4j] @ file:/Users/<username>/Desktop/cognee",
|
||||
"cognee[postgres,codegraph,gemini,huggingface,docs,neo4j]==0.2.0",
|
||||
"fastmcp>=1.0",
|
||||
"mcp==1.5.0",
|
||||
"uv>=0.6.3",
|
||||
"cognee[postgres,codegraph,gemini,huggingface,docs,neo4j]>=0.2.0,<1.0.0",
|
||||
"fastmcp>=1.0,<2.0.0",
|
||||
"mcp>=1.11.0,<2.0.0",
|
||||
"uv>=0.6.3,<1.0.0",
|
||||
]
|
||||
|
||||
authors = [
|
||||
|
|
@ -29,7 +29,7 @@ packages = ["src"]
|
|||
|
||||
[dependency-groups]
|
||||
dev = [
|
||||
"debugpy>=1.8.12",
|
||||
"debugpy>=1.8.12,<2.0.0",
|
||||
]
|
||||
|
||||
[tool.hatch.metadata]
|
||||
|
|
|
|||
cognee-mcp/uv.lock (generated, 5890 changed lines; diff suppressed because it is too large)
|
|
@ -7,5 +7,5 @@ readme = "README.md"
|
|||
requires-python = ">=3.10, <=3.13"
|
||||
|
||||
dependencies = [
|
||||
"cognee>=0.1.38",
|
||||
"cognee>=0.1.38,<1.0.0",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -3,7 +3,6 @@ import asyncio
|
|||
import pathlib
|
||||
from cognee import config, add, cognify, search, SearchType, prune, visualize_graph
|
||||
|
||||
# from cognee.shared.utils import render_graph
|
||||
from cognee.low_level import DataPoint
|
||||
|
||||
|
||||
|
|
@ -50,10 +49,6 @@ async def main():
|
|||
# Cognify the text data.
|
||||
await cognify(graph_model=ProgrammingLanguage)
|
||||
|
||||
# # Get a graphistry url (Register for a free account at https://www.graphistry.com)
|
||||
# url = await render_graph()
|
||||
# print(f"Graphistry URL: {url}")
|
||||
|
||||
# Or use our simple graph preview
|
||||
graph_file_path = str(
|
||||
pathlib.Path(
|
||||
|
|
|
|||
|
|
@ -2,7 +2,6 @@ import os
|
|||
import asyncio
|
||||
import pathlib
|
||||
from cognee import config, add, cognify, search, SearchType, prune, visualize_graph
|
||||
# from cognee.shared.utils import render_graph
|
||||
|
||||
|
||||
async def main():
|
||||
|
|
@ -30,10 +29,6 @@ async def main():
|
|||
# Cognify the text data.
|
||||
await cognify()
|
||||
|
||||
# # Get a graphistry url (Register for a free account at https://www.graphistry.com)
|
||||
# url = await render_graph()
|
||||
# print(f"Graphistry URL: {url}")
|
||||
|
||||
# Or use our simple graph preview
|
||||
graph_file_path = str(
|
||||
pathlib.Path(
|
||||
|
|
|
|||
|
|
@ -9,21 +9,28 @@ from cognee.infrastructure.databases.vector import get_vectordb_config
|
|||
from cognee.infrastructure.databases.graph.config import get_graph_config
|
||||
from cognee.infrastructure.llm.config import get_llm_config
|
||||
from cognee.infrastructure.databases.relational import get_relational_config, get_migration_config
|
||||
from cognee.infrastructure.files.storage import LocalStorage
|
||||
|
||||
|
||||
class config:
|
||||
@staticmethod
|
||||
def system_root_directory(system_root_directory: str):
|
||||
databases_directory_path = os.path.join(system_root_directory, "databases")
|
||||
base_config = get_base_config()
|
||||
base_config.system_root_directory = os.path.join(system_root_directory, ".cognee_system")
|
||||
|
||||
databases_directory_path = os.path.join(base_config.system_root_directory, "databases")
|
||||
|
||||
relational_config = get_relational_config()
|
||||
relational_config.db_path = databases_directory_path
|
||||
LocalStorage.ensure_directory_exists(databases_directory_path)
|
||||
|
||||
graph_config = get_graph_config()
|
||||
graph_file_name = graph_config.graph_filename
|
||||
graph_config.graph_file_path = os.path.join(databases_directory_path, graph_file_name)
|
||||
# For Kuzu v0.11.0+, use single-file database with .kuzu extension
|
||||
if graph_config.graph_database_provider.lower() == "kuzu":
|
||||
graph_config.graph_file_path = os.path.join(
|
||||
databases_directory_path, f"{graph_file_name}.kuzu"
|
||||
)
|
||||
else:
|
||||
graph_config.graph_file_path = os.path.join(databases_directory_path, graph_file_name)
|
||||
|
||||
vector_config = get_vectordb_config()
|
||||
if vector_config.vector_db_provider == "lancedb":
|
||||
|
|
@ -32,7 +39,7 @@ class config:
|
|||
@staticmethod
|
||||
def data_root_directory(data_root_directory: str):
|
||||
base_config = get_base_config()
|
||||
base_config.data_root_directory = data_root_directory
|
||||
base_config.data_root_directory = os.path.join(data_root_directory, ".data_storage")
|
||||
|
||||
@staticmethod
|
||||
def monitoring_tool(monitoring_tool: object):
|
||||
|
|
|
|||
|
|
@ -1,22 +1,25 @@
|
|||
from typing import Union, BinaryIO, List
|
||||
from cognee.modules.ingestion import classify
|
||||
from cognee.infrastructure.databases.relational import get_relational_engine
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.sql import delete as sql_delete
|
||||
from cognee.modules.data.models import Data, DatasetData, Dataset
|
||||
from cognee.infrastructure.databases.graph import get_graph_engine
|
||||
from io import BytesIO
|
||||
import os
|
||||
import hashlib
|
||||
from uuid import UUID
|
||||
from cognee.modules.users.models import User
|
||||
from cognee.infrastructure.databases.vector import get_vector_engine
|
||||
from io import BytesIO
|
||||
from sqlalchemy import select
|
||||
from typing import Union, BinaryIO, List
|
||||
from sqlalchemy.sql import delete as sql_delete
|
||||
|
||||
from cognee.infrastructure.engine import DataPoint
|
||||
from cognee.infrastructure.files.storage import get_file_storage
|
||||
from cognee.infrastructure.databases.graph import get_graph_engine
|
||||
from cognee.infrastructure.databases.vector import get_vector_engine
|
||||
from cognee.infrastructure.databases.relational import get_relational_engine
|
||||
from cognee.modules.ingestion import classify
|
||||
from cognee.modules.users.models import User
|
||||
from cognee.shared.logging_utils import get_logger
|
||||
from cognee.modules.data.models import Data, DatasetData, Dataset
|
||||
from cognee.modules.graph.utils.convert_node_to_data_point import get_all_subclasses
|
||||
from cognee.modules.users.methods import get_default_user
|
||||
from cognee.modules.data.methods import get_authorized_existing_datasets
|
||||
from cognee.context_global_variables import set_database_global_context_variables
|
||||
from .exceptions import DocumentNotFoundError, DatasetNotFoundError, DocumentSubgraphNotFoundError
|
||||
from cognee.shared.logging_utils import get_logger
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
|
|
@ -56,7 +59,14 @@ async def delete(
|
|||
# Handle different input types
|
||||
if isinstance(data, str):
|
||||
if data.startswith("file://") or data.startswith("/"): # It's a file path
|
||||
with open(data.replace("file://", ""), mode="rb") as file:
|
||||
full_file_path = data.replace("file://", "")
|
||||
|
||||
file_dir = os.path.dirname(full_file_path)
|
||||
file_path = os.path.basename(full_file_path)
|
||||
|
||||
file_storage = get_file_storage(file_dir)
|
||||
|
||||
async with file_storage.open(file_path, mode="rb") as file:
|
||||
classified_data = classify(file)
|
||||
content_hash = classified_data.get_metadata()["content_hash"]
|
||||
return await delete_single_document(content_hash, dataset[0].id, mode)
|
||||
|
|
|
|||
|
|
@ -1,14 +1,14 @@
|
|||
from cognee.modules.data.deletion import prune_system, prune_data
|
||||
from cognee.modules.data.deletion import prune_system as _prune_system, prune_data as _prune_data
|
||||
|
||||
|
||||
class prune:
|
||||
@staticmethod
|
||||
async def prune_data():
|
||||
await prune_data()
|
||||
await _prune_data()
|
||||
|
||||
@staticmethod
|
||||
async def prune_system(graph=True, vector=True, metadata=False):
|
||||
await prune_system(graph, vector, metadata)
|
||||
await _prune_system(graph, vector, metadata)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@ from pydantic_settings import BaseSettings, SettingsConfigDict
|
|||
|
||||
class BaseConfig(BaseSettings):
|
||||
data_root_directory: str = get_absolute_path(".data_storage")
|
||||
system_root_directory: str = get_absolute_path(".cognee_system")
|
||||
monitoring_tool: object = Observer.LANGFUSE
|
||||
graphistry_username: Optional[str] = os.getenv("GRAPHISTRY_USERNAME")
|
||||
graphistry_password: Optional[str] = os.getenv("GRAPHISTRY_PASSWORD")
|
||||
|
|
|
|||
|
|
@ -1,11 +1,12 @@
|
|||
import os
|
||||
import pathlib
|
||||
from contextvars import ContextVar
|
||||
from typing import Union
|
||||
from uuid import UUID
|
||||
|
||||
from cognee.base_config import get_base_config
|
||||
from cognee.infrastructure.databases.utils import get_or_create_dataset_database
|
||||
from cognee.modules.users.methods import get_user
|
||||
from cognee.infrastructure.files.storage.config import file_storage_config
|
||||
|
||||
# Note: ContextVar allows us to use different graph db configurations in Cognee
|
||||
# for different async tasks, threads and processes
|
||||
|
|
@ -32,6 +33,8 @@ async def set_database_global_context_variables(dataset: Union[str, UUID], user_
|
|||
|
||||
"""
|
||||
|
||||
base_config = get_base_config()
|
||||
|
||||
if not os.getenv("ENABLE_BACKEND_ACCESS_CONTROL", "false").lower() == "true":
|
||||
return
|
||||
|
||||
|
|
@ -40,16 +43,16 @@ async def set_database_global_context_variables(dataset: Union[str, UUID], user_
|
|||
# To ensure permissions are enforced properly all datasets will have their own databases
|
||||
dataset_database = await get_or_create_dataset_database(dataset, user)
|
||||
|
||||
# TODO: Find better location for database files
|
||||
cognee_directory_path = str(
|
||||
pathlib.Path(
|
||||
os.path.join(pathlib.Path(__file__).parent, f".cognee_system/databases/{user.id}")
|
||||
).resolve()
|
||||
data_root_directory = os.path.join(
|
||||
base_config.data_root_directory, str(user.tenant_id or user.id)
|
||||
)
|
||||
system_directory_path = os.path.join(
|
||||
base_config.system_root_directory, "databases", str(user.id)
|
||||
)
|
||||
|
||||
# Set vector and graph database configuration based on dataset database information
|
||||
vector_config = {
|
||||
"vector_db_url": os.path.join(cognee_directory_path, dataset_database.vector_database_name),
|
||||
"vector_db_url": os.path.join(system_directory_path, dataset_database.vector_database_name),
|
||||
"vector_db_key": "",
|
||||
"vector_db_provider": "lancedb",
|
||||
}
|
||||
|
|
@ -57,11 +60,16 @@ async def set_database_global_context_variables(dataset: Union[str, UUID], user_
|
|||
graph_config = {
|
||||
"graph_database_provider": "kuzu",
|
||||
"graph_file_path": os.path.join(
|
||||
cognee_directory_path, dataset_database.graph_database_name
|
||||
system_directory_path, f"{dataset_database.graph_database_name}.kuzu"
|
||||
),
|
||||
}
|
||||
|
||||
storage_config = {
|
||||
"data_root_directory": data_root_directory,
|
||||
}
|
||||
|
||||
# Use ContextVar so these graph and vector configurations are used
|
||||
# in the current async context across Cognee
|
||||
graph_db_config.set(graph_config)
|
||||
vector_db_config.set(vector_config)
|
||||
file_storage_config.set(storage_config)
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@ from cognee.eval_framework.answer_generation.answer_generation_executor import (
|
|||
AnswerGeneratorExecutor,
|
||||
retriever_options,
|
||||
)
|
||||
from cognee.infrastructure.files.storage import LocalStorage
|
||||
from cognee.infrastructure.files.storage import get_file_storage
|
||||
from cognee.infrastructure.databases.relational.get_relational_engine import (
|
||||
get_relational_engine,
|
||||
get_relational_config,
|
||||
|
|
@ -22,7 +22,7 @@ async def create_and_insert_answers_table(questions_payload):
|
|||
relational_engine = get_relational_engine()
|
||||
|
||||
if relational_engine.engine.dialect.name == "sqlite":
|
||||
LocalStorage.ensure_directory_exists(relational_config.db_path)
|
||||
await get_file_storage(relational_config.db_path).ensure_directory_exists()
|
||||
|
||||
async with relational_engine.engine.begin() as connection:
|
||||
if len(AnswersBase.metadata.tables.keys()) > 0:
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ from cognee.shared.logging_utils import get_logger, ERROR
|
|||
import json
|
||||
from typing import List, Optional
|
||||
|
||||
from cognee.infrastructure.files.storage import LocalStorage
|
||||
from cognee.infrastructure.files.storage import get_file_storage
|
||||
from cognee.eval_framework.corpus_builder.corpus_builder_executor import CorpusBuilderExecutor
|
||||
from cognee.modules.data.models.questions_base import QuestionsBase
|
||||
from cognee.modules.data.models.questions_data import Questions
|
||||
|
|
@ -21,7 +21,7 @@ async def create_and_insert_questions_table(questions_payload):
|
|||
relational_engine = get_relational_engine()
|
||||
|
||||
if relational_engine.engine.dialect.name == "sqlite":
|
||||
LocalStorage.ensure_directory_exists(relational_config.db_path)
|
||||
await get_file_storage(relational_config.db_path).ensure_directory_exists()
|
||||
|
||||
async with relational_engine.engine.begin() as connection:
|
||||
if len(QuestionsBase.metadata.tables.keys()) > 0:
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ from typing import List
|
|||
from cognee.eval_framework.evaluation.evaluation_executor import EvaluationExecutor
|
||||
from cognee.eval_framework.analysis.metrics_calculator import calculate_metrics_statistics
|
||||
from cognee.eval_framework.analysis.dashboard_generator import create_dashboard
|
||||
from cognee.infrastructure.files.storage import LocalStorage
|
||||
from cognee.infrastructure.files.storage import get_file_storage
|
||||
from cognee.infrastructure.databases.relational.get_relational_engine import (
|
||||
get_relational_engine,
|
||||
get_relational_config,
|
||||
|
|
@ -21,7 +21,7 @@ async def create_and_insert_metrics_table(questions_payload):
|
|||
relational_engine = get_relational_engine()
|
||||
|
||||
if relational_engine.engine.dialect.name == "sqlite":
|
||||
LocalStorage.ensure_directory_exists(relational_config.db_path)
|
||||
await get_file_storage(relational_config.db_path).ensure_directory_exists()
|
||||
|
||||
async with relational_engine.engine.begin() as connection:
|
||||
if len(MetricsBase.metadata.tables.keys()) > 0:
|
||||
|
|
|
|||
|
|
@ -5,8 +5,8 @@ from functools import lru_cache
|
|||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
import pydantic
|
||||
from pydantic import Field
|
||||
from cognee.base_config import get_base_config
|
||||
from cognee.shared.data_models import KnowledgeGraph
|
||||
from cognee.root_dir import get_absolute_path
|
||||
|
||||
|
||||
class GraphConfig(BaseSettings):
|
||||
|
|
@ -55,8 +55,14 @@ class GraphConfig(BaseSettings):
|
|||
values.graph_filename = f"cognee_graph_{provider}"
|
||||
# Set file path based on graph database provider if no file path is provided
|
||||
if not values.graph_file_path:
|
||||
base = os.path.join(get_absolute_path(".cognee_system"), "databases")
|
||||
values.graph_file_path = os.path.join(base, values.graph_filename)
|
||||
base_config = get_base_config()
|
||||
|
||||
base = os.path.join(base_config.system_root_directory, "databases")
|
||||
# For Kuzu v0.11.0+, use single-file database with .kuzu extension
|
||||
if provider == "kuzu":
|
||||
values.graph_file_path = os.path.join(base, f"{values.graph_filename}.kuzu")
|
||||
else:
|
||||
values.graph_file_path = os.path.join(base, values.graph_filename)
|
||||
return values
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
|
|
|
|||
|
|
@ -1,20 +1,20 @@
|
|||
"""Adapter for Kuzu graph database."""
|
||||
|
||||
from cognee.infrastructure.databases.exceptions.exceptions import NodesetFilterNotSupportedError
|
||||
from cognee.shared.logging_utils import get_logger
|
||||
import json
|
||||
import os
|
||||
import shutil
|
||||
import json
|
||||
import asyncio
|
||||
from typing import Dict, Any, List, Union, Optional, Tuple, Type
|
||||
from datetime import datetime, timezone
|
||||
import tempfile
|
||||
from uuid import UUID
|
||||
from kuzu import Connection
|
||||
from kuzu.database import Database
|
||||
from datetime import datetime, timezone
|
||||
from contextlib import asynccontextmanager
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from typing import Dict, Any, List, Union, Optional, Tuple, Type
|
||||
|
||||
import kuzu
|
||||
from kuzu.database import Database
|
||||
from kuzu import Connection
|
||||
from cognee.shared.logging_utils import get_logger
|
||||
from cognee.infrastructure.utils.run_sync import run_sync
|
||||
from cognee.infrastructure.files.storage import get_file_storage
|
||||
from cognee.infrastructure.databases.graph.graph_db_interface import (
|
||||
GraphDBInterface,
|
||||
record_graph_changes,
|
||||
|
|
@ -46,9 +46,38 @@ class KuzuAdapter(GraphDBInterface):
|
|||
def _initialize_connection(self) -> None:
|
||||
"""Initialize the Kuzu database connection and schema."""
|
||||
try:
|
||||
os.makedirs(self.db_path, exist_ok=True)
|
||||
if "s3://" in self.db_path:
|
||||
with tempfile.NamedTemporaryFile(mode="w", delete=False) as temp_file:
|
||||
self.temp_graph_file = temp_file.name
|
||||
|
||||
run_sync(self.pull_from_s3())
|
||||
|
||||
self.db = Database(
|
||||
self.temp_graph_file,
|
||||
buffer_pool_size=256 * 1024 * 1024, # 256MB buffer pool
|
||||
max_db_size=1024 * 1024 * 1024,
|
||||
)
|
||||
else:
|
||||
# Ensure the parent directory exists before creating the database
|
||||
db_dir = os.path.dirname(self.db_path)
|
||||
|
||||
# If db_path is just a filename, db_dir will be empty string
|
||||
# In this case, use the directory containing the db_path or current directory
|
||||
if not db_dir:
|
||||
# If no directory in path, use the absolute path's directory
|
||||
abs_path = os.path.abspath(self.db_path)
|
||||
db_dir = os.path.dirname(abs_path)
|
||||
|
||||
file_storage = get_file_storage(db_dir)
|
||||
|
||||
run_sync(file_storage.ensure_directory_exists())
|
||||
|
||||
self.db = Database(
|
||||
self.db_path,
|
||||
buffer_pool_size=256 * 1024 * 1024, # 256MB buffer pool
|
||||
max_db_size=1024 * 1024 * 1024,
|
||||
)
|
||||
|
||||
self.db = Database(self.db_path)
|
||||
self.db.init_database()
|
||||
self.connection = Connection(self.db)
|
||||
# Create node table with essential fields and timestamp
|
||||
|
|
@@ -77,6 +106,22 @@ class KuzuAdapter(GraphDBInterface):
            logger.error(f"Failed to initialize Kuzu database: {e}")
            raise e

    async def push_to_s3(self) -> None:
        if os.getenv("STORAGE_BACKEND", "").lower() == "s3" and hasattr(self, "temp_graph_file"):
            from cognee.infrastructure.files.storage.S3FileStorage import S3FileStorage

            s3_file_storage = S3FileStorage("")
            s3_file_storage.s3.put(self.temp_graph_file, self.db_path, recursive=True)

    async def pull_from_s3(self) -> None:
        from cognee.infrastructure.files.storage.S3FileStorage import S3FileStorage

        s3_file_storage = S3FileStorage("")
        try:
            s3_file_storage.s3.get(self.db_path, self.temp_graph_file, recursive=True)
        except FileNotFoundError:
            pass

    async def query(self, query: str, params: Optional[dict] = None) -> List[Tuple]:
        """
        Execute a Kuzu query asynchronously with automatic reconnection.
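The push_to_s3/pull_from_s3 pair above implements a simple synchronisation pattern: the Kuzu database file is pulled from S3 into a local temporary file when the adapter initializes, queries run against that local copy, and the copy is pushed back to S3 afterwards. A minimal standalone sketch of the same round trip using s3fs directly; the bucket and key below are hypothetical:

import tempfile

import s3fs

# Hypothetical S3 location of a Kuzu database file.
remote_db_path = "s3://example-bucket/databases/cognee_graph.kuzu"

fs = s3fs.S3FileSystem(anon=False)

# Pull: mirror the remote database into a local temporary file.
with tempfile.NamedTemporaryFile(delete=False) as temp_file:
    local_db_path = temp_file.name

try:
    fs.get(remote_db_path, local_db_path, recursive=True)
except FileNotFoundError:
    pass  # First run: nothing to pull yet.

# ... open the Kuzu database at local_db_path and run queries here ...

# Push: upload the (possibly modified) local copy back to S3.
fs.put(local_db_path, remote_db_path, recursive=True)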
|
@ -1385,41 +1430,11 @@ class KuzuAdapter(GraphDBInterface):
|
|||
"relationship_types": [rel[0] for rel in rel_types],
|
||||
}
|
||||
|
||||
async def get_node_labels_string(self) -> str:
|
||||
"""
|
||||
Get all node labels as a string.
|
||||
|
||||
This method aggregates all unique node labels from the graph into a single string
|
||||
representation, which can be helpful for overview and debugging purposes.
|
||||
|
||||
Returns:
|
||||
--------
|
||||
|
||||
- str: A string of all distinct node labels, separated by '|'.
|
||||
"""
|
||||
labels = await self.query("MATCH (n:Node) RETURN DISTINCT labels(n)")
|
||||
return "|".join(sorted(set([label[0] for label in labels])))
|
||||
|
||||
async def get_relationship_labels_string(self) -> str:
|
||||
"""
|
||||
Get all relationship types as a string.
|
||||
|
||||
This method aggregates all unique relationship types from the graph into a single string
|
||||
representation, providing an overview of the relationships defined in the graph.
|
||||
|
||||
Returns:
|
||||
--------
|
||||
|
||||
- str: A string of all distinct relationship types, separated by '|'.
|
||||
"""
|
||||
types = await self.query("MATCH ()-[r:EDGE]->() RETURN DISTINCT r.relationship_name")
|
||||
return "|".join(sorted(set([t[0] for t in types])))
|
||||
|
||||
async def delete_graph(self) -> None:
|
||||
"""
|
||||
Delete all data from the graph directory.
|
||||
Delete all data from the graph database.
|
||||
|
||||
This method deletes all nodes and relationships from the graph directory
|
||||
This method deletes all nodes and relationships from the graph database.
|
||||
It raises exceptions for failures occurring during deletion processes.
|
||||
"""
|
||||
try:
|
||||
|
|
@ -1432,8 +1447,13 @@ class KuzuAdapter(GraphDBInterface):
|
|||
if self.db:
|
||||
self.db.close()
|
||||
self.db = None
|
||||
if os.path.exists(self.db_path):
|
||||
shutil.rmtree(self.db_path)
|
||||
|
||||
db_dir = os.path.dirname(self.db_path)
|
||||
db_name = os.path.basename(self.db_path)
|
||||
file_storage = get_file_storage(db_dir)
|
||||
|
||||
if await file_storage.file_exists(db_name):
|
||||
await file_storage.remove_all()
|
||||
logger.info(f"Deleted Kuzu database files at {self.db_path}")
|
||||
|
||||
except Exception as e:
|
||||
|
|
@ -1454,9 +1474,15 @@ class KuzuAdapter(GraphDBInterface):
|
|||
if self.db:
|
||||
self.db.close()
|
||||
self.db = None
|
||||
if os.path.exists(self.db_path):
|
||||
shutil.rmtree(self.db_path)
|
||||
|
||||
db_dir = os.path.dirname(self.db_path)
|
||||
db_name = os.path.basename(self.db_path)
|
||||
file_storage = get_file_storage(db_dir)
|
||||
|
||||
if await file_storage.file_exists(db_name):
|
||||
await file_storage.remove_all()
|
||||
logger.info(f"Deleted Kuzu database files at {self.db_path}")
|
||||
|
||||
# Reinitialize the database
|
||||
self._initialize_connection()
|
||||
# Verify the database is empty
|
||||
|
|
@ -1472,61 +1498,6 @@ class KuzuAdapter(GraphDBInterface):
|
|||
logger.error(f"Error during database clearing: {e}")
|
||||
raise
|
||||
|
||||
async def save_graph_to_file(self, file_path: str) -> None:
|
||||
"""
|
||||
Export the Kuzu database to a file.
|
||||
|
||||
This method exports the entire Kuzu graph database to a specified file path, utilizing
|
||||
Kuzu's native export command. Ensure the directory exists prior to attempting the
|
||||
export, and manage related exceptions as they arise.
|
||||
|
||||
Parameters:
|
||||
-----------
|
||||
|
||||
- file_path (str): Path where to export the database.
|
||||
"""
|
||||
try:
|
||||
# Ensure directory exists
|
||||
os.makedirs(os.path.dirname(file_path), exist_ok=True)
|
||||
|
||||
# Use Kuzu's native EXPORT command, output is Parquet
|
||||
export_query = f"EXPORT DATABASE '{file_path}'"
|
||||
await self.query(export_query)
|
||||
|
||||
logger.info(f"Graph exported to {file_path}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to export graph to file: {e}")
|
||||
raise
|
||||
|
||||
async def load_graph_from_file(self, file_path: str) -> None:
|
||||
"""
|
||||
Import a Kuzu database from a file.
|
||||
|
||||
This method imports a database from a specified file path, ensuring that the file exists
|
||||
before attempting to import. Errors during the import process are managed accordingly,
|
||||
allowing for smooth operation.
|
||||
|
||||
Parameters:
|
||||
-----------
|
||||
|
||||
- file_path (str): Path to the exported database file.
|
||||
"""
|
||||
try:
|
||||
if not os.path.exists(file_path):
|
||||
logger.warning(f"File {file_path} not found")
|
||||
return
|
||||
|
||||
# Use Kuzu's native IMPORT command
|
||||
import_query = f"IMPORT DATABASE '{file_path}'"
|
||||
await self.query(import_query)
|
||||
|
||||
logger.info(f"Graph imported from {file_path}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to import graph from file: {e}")
|
||||
raise
|
||||
|
||||
async def get_document_subgraph(self, content_hash: str):
|
||||
"""
|
||||
Get all nodes that should be deleted when removing a document.
|
||||
|
|
|
|||
|
|
@ -1,17 +1,17 @@
|
|||
"""Adapter for NetworkX graph database."""
|
||||
|
||||
from datetime import datetime, timezone
|
||||
import os
|
||||
import json
|
||||
import asyncio
|
||||
import numpy as np
|
||||
from uuid import UUID
|
||||
import networkx as nx
|
||||
from datetime import datetime, timezone
|
||||
from typing import Dict, Any, List, Union, Type, Tuple
|
||||
|
||||
from cognee.infrastructure.databases.exceptions.exceptions import NodesetFilterNotSupportedError
|
||||
from cognee.infrastructure.files.storage import get_file_storage
|
||||
from cognee.shared.logging_utils import get_logger
|
||||
from typing import Dict, Any, List, Union, Type, Tuple
|
||||
from uuid import UUID
|
||||
import aiofiles
|
||||
import aiofiles.os as aiofiles_os
|
||||
import networkx as nx
|
||||
from cognee.infrastructure.databases.graph.graph_db_interface import (
|
||||
GraphDBInterface,
|
||||
record_graph_changes,
|
||||
|
|
@ -19,7 +19,6 @@ from cognee.infrastructure.databases.graph.graph_db_interface import (
|
|||
from cognee.infrastructure.engine import DataPoint
|
||||
from cognee.infrastructure.engine.utils import parse_id
|
||||
from cognee.modules.storage.utils import JSONEncoder
|
||||
import numpy as np
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
|
|
@ -551,11 +550,6 @@ class NetworkXAdapter(GraphDBInterface):
|
|||
"""
|
||||
self.graph = nx.MultiDiGraph()
|
||||
|
||||
# Only create directory if file_path contains a directory
|
||||
file_dir = os.path.dirname(file_path)
|
||||
if file_dir and not os.path.exists(file_dir):
|
||||
os.makedirs(file_dir, exist_ok=True)
|
||||
|
||||
await self.save_graph_to_file(file_path)
|
||||
|
||||
async def save_graph_to_file(self, file_path: str = None) -> None:
|
||||
|
|
@ -573,9 +567,14 @@ class NetworkXAdapter(GraphDBInterface):
|
|||
|
||||
graph_data = nx.readwrite.json_graph.node_link_data(self.graph, edges="links")
|
||||
|
||||
async with aiofiles.open(file_path, "w") as file:
|
||||
json_data = json.dumps(graph_data, cls=JSONEncoder)
|
||||
await file.write(json_data)
|
||||
file_dir_path = os.path.dirname(file_path)
|
||||
file_path = os.path.basename(file_path)
|
||||
|
||||
file_storage = get_file_storage(file_dir_path)
|
||||
|
||||
json_data = json.dumps(graph_data, cls=JSONEncoder)
|
||||
|
||||
await file_storage.store(file_path, json_data, overwrite=True)
|
||||
|
||||
async def load_graph_from_file(self, file_path: str = None):
|
||||
"""
|
||||
|
|
@ -590,9 +589,14 @@ class NetworkXAdapter(GraphDBInterface):
|
|||
if not file_path:
|
||||
file_path = self.filename
|
||||
try:
|
||||
if os.path.exists(file_path):
|
||||
async with aiofiles.open(file_path, "r") as file:
|
||||
graph_data = json.loads(await file.read())
|
||||
file_dir_path = os.path.dirname(file_path)
|
||||
file_name = os.path.basename(file_path)
|
||||
|
||||
file_storage = get_file_storage(file_dir_path)
|
||||
|
||||
if await file_storage.file_exists(file_name):
|
||||
async with file_storage.open(file_name, "r") as file:
|
||||
graph_data = json.loads(file.read())
|
||||
for node in graph_data["nodes"]:
|
||||
try:
|
||||
if not isinstance(node["id"], UUID):
|
||||
|
|
@ -674,8 +678,12 @@ class NetworkXAdapter(GraphDBInterface):
|
|||
self.filename
|
||||
) # Assuming self.filename is defined elsewhere and holds the default graph file path
|
||||
try:
|
||||
if os.path.exists(file_path):
|
||||
await aiofiles_os.remove(file_path)
|
||||
file_dir_path = os.path.dirname(file_path)
|
||||
file_path = os.path.basename(file_path)
|
||||
|
||||
file_storage = get_file_storage(file_dir_path)
|
||||
|
||||
await file_storage.remove(file_path)
|
||||
|
||||
self.graph = None
|
||||
logger.info("Graph deleted successfully.")
|
||||
|
|
|
|||
|
|
@ -1,17 +1,21 @@
|
|||
import os
|
||||
import asyncio
|
||||
from os import path
|
||||
import tempfile
|
||||
from uuid import UUID
|
||||
from typing import Optional
|
||||
from typing import AsyncGenerator, List
|
||||
from contextlib import asynccontextmanager
|
||||
from sqlalchemy import text, select, MetaData, Table, delete, inspect
|
||||
from sqlalchemy.orm import joinedload
|
||||
from sqlalchemy.exc import NoResultFound
|
||||
from sqlalchemy import NullPool, text, select, MetaData, Table, delete, inspect
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine, async_sessionmaker
|
||||
|
||||
from cognee.shared.logging_utils import get_logger
|
||||
from cognee.infrastructure.databases.exceptions import EntityNotFoundError
|
||||
from cognee.modules.data.models.Data import Data
|
||||
from cognee.shared.logging_utils import get_logger
|
||||
from cognee.infrastructure.utils.run_sync import run_sync
|
||||
from cognee.infrastructure.databases.exceptions import EntityNotFoundError
|
||||
from cognee.infrastructure.files.storage import get_file_storage, get_storage_config
|
||||
|
||||
from ..ModelBase import Base
|
||||
|
||||
|
|
@ -29,11 +33,42 @@ class SQLAlchemyAdapter:
|
|||
self.db_path: str = None
|
||||
self.db_uri: str = connection_string
|
||||
|
||||
self.engine = create_async_engine(connection_string)
|
||||
if "sqlite" in connection_string:
|
||||
[prefix, db_path] = connection_string.split("///")
|
||||
self.db_path = db_path
|
||||
|
||||
if "s3://" in self.db_path:
|
||||
db_dir_path = path.dirname(self.db_path)
|
||||
file_storage = get_file_storage(db_dir_path)
|
||||
|
||||
run_sync(file_storage.ensure_directory_exists())
|
||||
|
||||
with tempfile.NamedTemporaryFile(mode="w", delete=False) as temp_file:
|
||||
self.temp_db_file = temp_file.name
|
||||
connection_string = prefix + "///" + self.temp_db_file
|
||||
|
||||
run_sync(self.pull_from_s3())
|
||||
|
||||
self.engine = create_async_engine(
|
||||
connection_string, poolclass=NullPool if "sqlite" in connection_string else None
|
||||
)
|
||||
self.sessionmaker = async_sessionmaker(bind=self.engine, expire_on_commit=False)
|
||||
|
||||
if self.engine.dialect.name == "sqlite":
|
||||
self.db_path = connection_string.split("///")[1]
|
||||
async def push_to_s3(self) -> None:
|
||||
if os.getenv("STORAGE_BACKEND", "").lower() == "s3" and hasattr(self, "temp_db_file"):
|
||||
from cognee.infrastructure.files.storage.S3FileStorage import S3FileStorage
|
||||
|
||||
s3_file_storage = S3FileStorage("")
|
||||
s3_file_storage.s3.put(self.temp_db_file, self.db_path, recursive=True)
|
||||
|
||||
async def pull_from_s3(self) -> None:
|
||||
from cognee.infrastructure.files.storage.S3FileStorage import S3FileStorage
|
||||
|
||||
s3_file_storage = S3FileStorage("")
|
||||
try:
|
||||
s3_file_storage.s3.get(self.db_path, self.temp_db_file, recursive=True)
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
|
||||
@asynccontextmanager
|
||||
async def get_async_session(self) -> AsyncGenerator[AsyncSession, None]:
|
||||
|
|
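For SQLite, the constructor above detects an s3:// database path inside the connection string and redirects SQLAlchemy to a local temporary file that mirrors the remote copy. A small sketch of just the connection-string rewrite; the URI is a made-up example and the dialect prefix is assumed:

import tempfile

connection_string = "sqlite+aiosqlite:///s3://example-bucket/databases/cognee_db"

prefix, db_path = connection_string.split("///")

if "s3://" in db_path:
    # SQLAlchemy cannot open a database file that lives on S3 directly, so point
    # the engine at a local temporary file and sync it with S3 before and after use.
    with tempfile.NamedTemporaryFile(mode="w", delete=False) as temp_file:
        local_db_file = temp_file.name

    connection_string = prefix + "///" + local_db_file

print(connection_string)  # e.g. sqlite+aiosqlite:////tmp/tmpabcd1234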
@ -249,13 +284,18 @@ class SQLAlchemyAdapter:
|
|||
# Don't delete local file unless this is the only reference to the file in the database
|
||||
if len(raw_data_location_entities) == 1:
|
||||
# delete local file only if it's created by cognee
|
||||
from cognee.base_config import get_base_config
|
||||
storage_config = get_storage_config()
|
||||
|
||||
config = get_base_config()
|
||||
if (
|
||||
storage_config["data_root_directory"]
|
||||
in raw_data_location_entities[0].raw_data_location
|
||||
):
|
||||
file_storage = get_file_storage(storage_config["data_root_directory"])
|
||||
|
||||
if config.data_root_directory in raw_data_location_entities[0].raw_data_location:
|
||||
if os.path.exists(raw_data_location_entities[0].raw_data_location):
|
||||
os.remove(raw_data_location_entities[0].raw_data_location)
|
||||
file_path = os.path.basename(raw_data_location_entities[0].raw_data_location)
|
||||
|
||||
if await file_storage.file_exists(file_path):
|
||||
await file_storage.remove(file_path)
|
||||
else:
|
||||
# Report bug as file should exist
|
||||
logger.error("Local file which should exist can't be found.")
|
||||
|
|
@ -434,11 +474,13 @@ class SQLAlchemyAdapter:
|
|||
Create the database if it does not exist, ensuring necessary directories are in place
|
||||
for SQLite.
|
||||
"""
|
||||
if self.engine.dialect.name == "sqlite" and not os.path.exists(self.db_path):
|
||||
from cognee.infrastructure.files.storage import LocalStorage
|
||||
|
||||
if self.engine.dialect.name == "sqlite":
|
||||
db_directory = path.dirname(self.db_path)
|
||||
LocalStorage.ensure_directory_exists(db_directory)
|
||||
db_name = path.basename(self.db_path)
|
||||
file_storage = get_file_storage(db_directory)
|
||||
|
||||
if not await file_storage.file_exists(db_name):
|
||||
await file_storage.ensure_directory_exists()
|
||||
|
||||
async with self.engine.begin() as connection:
|
||||
if len(Base.metadata.tables.keys()) > 0:
|
||||
|
|
@ -450,13 +492,13 @@ class SQLAlchemyAdapter:
|
|||
"""
|
||||
try:
|
||||
if self.engine.dialect.name == "sqlite":
|
||||
from cognee.infrastructure.files.storage import LocalStorage
|
||||
|
||||
await self.engine.dispose(close=True)
|
||||
# Wait for the database connections to close and release the file (Windows)
|
||||
await asyncio.sleep(2)
|
||||
db_directory = path.dirname(self.db_path)
|
||||
LocalStorage.ensure_directory_exists(db_directory)
|
||||
with open(self.db_path, "w") as file:
|
||||
file.write("")
|
||||
file_path = path.basename(self.db_path)
|
||||
file_storage = get_file_storage(db_directory)
|
||||
await file_storage.remove(file_path)
|
||||
else:
|
||||
async with self.engine.begin() as connection:
|
||||
# Create a MetaData instance to load table information
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
import asyncio
|
||||
from os import path
|
||||
import lancedb
|
||||
from pydantic import BaseModel
|
||||
from lancedb.pydantic import LanceModel, Vector
|
||||
|
|
@ -7,7 +8,7 @@ from typing import Generic, List, Optional, TypeVar, Union, get_args, get_origin
|
|||
from cognee.exceptions import InvalidValueError
|
||||
from cognee.infrastructure.engine import DataPoint
|
||||
from cognee.infrastructure.engine.utils import parse_id
|
||||
from cognee.infrastructure.files.storage import LocalStorage
|
||||
from cognee.infrastructure.files.storage import get_file_storage
|
||||
from cognee.modules.storage.utils import copy_model, get_own_properties
|
||||
from cognee.infrastructure.databases.vector.exceptions import CollectionNotFoundError
|
||||
|
||||
|
|
@ -310,7 +311,9 @@ class LanceDBAdapter(VectorDBInterface):
|
|||
await connection.drop_table(collection_name)
|
||||
|
||||
if self.url.startswith("/"):
|
||||
LocalStorage.remove_all(self.url)
|
||||
db_dir_path = path.dirname(self.url)
|
||||
db_file_name = path.basename(self.url)
|
||||
await get_file_storage(db_dir_path).remove_all(db_file_name)
|
||||
|
||||
def get_data_point_schema(self, model_type: BaseModel):
|
||||
related_models_fields = []
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
from __future__ import annotations
|
||||
import asyncio
|
||||
import os
|
||||
from uuid import UUID
|
||||
from typing import List, Optional
|
||||
|
||||
|
|
@ -7,6 +8,7 @@ from cognee.shared.logging_utils import get_logger
|
|||
from cognee.infrastructure.engine import DataPoint
|
||||
from cognee.infrastructure.engine.utils import parse_id
|
||||
from cognee.infrastructure.databases.vector.exceptions import CollectionNotFoundError
|
||||
from cognee.infrastructure.files.storage import get_file_storage
|
||||
|
||||
from ..embeddings.EmbeddingEngine import EmbeddingEngine
|
||||
from ..models.ScoredResult import ScoredResult
|
||||
|
|
@ -74,6 +76,34 @@ class MilvusAdapter(VectorDBInterface):
|
|||
"""
|
||||
from pymilvus import MilvusClient
|
||||
|
||||
# Ensure the parent directory exists for local file-based Milvus databases
|
||||
if self.url and not self.url.startswith(("http://", "https://", "grpc://")):
|
||||
# This is likely a local file path, ensure the directory exists
|
||||
db_dir = os.path.dirname(self.url)
|
||||
if db_dir and not os.path.exists(db_dir):
|
||||
try:
|
||||
file_storage = get_file_storage(db_dir)
|
||||
if hasattr(file_storage, "ensure_directory_exists"):
|
||||
if asyncio.iscoroutinefunction(file_storage.ensure_directory_exists):
|
||||
# Run async function synchronously in this sync method
|
||||
loop = asyncio.get_event_loop()
|
||||
if loop.is_running():
|
||||
# If we're already in an async context, we can't use run_sync easily
|
||||
# Create the directory directly as a fallback
|
||||
os.makedirs(db_dir, exist_ok=True)
|
||||
else:
|
||||
loop.run_until_complete(file_storage.ensure_directory_exists())
|
||||
else:
|
||||
file_storage.ensure_directory_exists()
|
||||
else:
|
||||
# Fallback to os.makedirs if file_storage doesn't have ensure_directory_exists
|
||||
os.makedirs(db_dir, exist_ok=True)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Could not create directory {db_dir} using file_storage, falling back to os.makedirs: {e}"
|
||||
)
|
||||
os.makedirs(db_dir, exist_ok=True)
|
||||
|
||||
if self.api_key:
|
||||
client = MilvusClient(uri=self.url, token=self.api_key)
|
||||
else:
|
||||
|
|
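The directory-handling block above has to invoke an async ensure_directory_exists() from a synchronous client factory, falling back to os.makedirs() when an event loop is already running. The run_sync helper used by the other adapters in this PR is not shown in this excerpt; the following is only an assumed sketch of what such a helper can look like, not cognee's actual implementation:

import asyncio
import concurrent.futures


def run_sync(coroutine):
    # Drive a coroutine to completion from synchronous code.
    try:
        asyncio.get_running_loop()
    except RuntimeError:
        # No event loop is running in this thread, so run the coroutine directly.
        return asyncio.run(coroutine)

    # An event loop is already running here; run the coroutine on a fresh loop
    # in a worker thread and block until it finishes.
    with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
        return executor.submit(asyncio.run, coroutine).result()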
@ -343,8 +373,6 @@ class MilvusAdapter(VectorDBInterface):
|
|||
"""
|
||||
from pymilvus import MilvusException, exceptions
|
||||
|
||||
if limit <= 0:
|
||||
return []
|
||||
client = self.get_milvus_client()
|
||||
if query_text is None and query_vector is None:
|
||||
raise ValueError("One of query_text or query_vector must be provided!")
|
||||
|
|
|
|||
|
|
@@ -1,3 +1 @@
from .add_file_to_storage import add_file_to_storage
from .remove_file_from_storage import remove_file_from_storage
from .utils.get_file_metadata import get_file_metadata, FileMetadata
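Most call sites in this PR obtain their backend through get_file_storage() from cognee.infrastructure.files.storage. Its implementation is not part of this excerpt; a hypothetical sketch of the dispatch it is assumed to perform, returning the new S3FileStorage for s3:// paths and LocalFileStorage otherwise (both classes are introduced later in this diff):

from cognee.infrastructure.files.storage.LocalFileStorage import LocalFileStorage
from cognee.infrastructure.files.storage.S3FileStorage import S3FileStorage


def get_file_storage(storage_path: str):
    # Assumed behaviour: S3-backed storage for s3:// paths, local filesystem otherwise.
    if storage_path.startswith("s3://"):
        return S3FileStorage(storage_path)

    return LocalFileStorage(storage_path)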
|
|
|
|||
|
|
@ -1,22 +0,0 @@
|
|||
from typing import BinaryIO
|
||||
from cognee.root_dir import get_absolute_path
|
||||
from .storage.StorageManager import StorageManager
|
||||
from .storage.LocalStorage import LocalStorage
|
||||
|
||||
|
||||
async def add_file_to_storage(file_path: str, file: BinaryIO):
|
||||
"""
|
||||
Store a file in local storage.
|
||||
|
||||
This function initializes a storage manager and uses it to store the provided file at
|
||||
the specified file path.
|
||||
|
||||
Parameters:
|
||||
-----------
|
||||
|
||||
- file_path (str): The path where the file will be stored.
|
||||
- file (BinaryIO): The file object to be stored, which must be a binary file.
|
||||
"""
|
||||
storage_manager = StorageManager(LocalStorage(get_absolute_path("data/files")))
|
||||
|
||||
storage_manager.store(file_path, file)
|
||||
cognee/infrastructure/files/exceptions.py (new file, 13 lines)
@@ -0,0 +1,13 @@
from fastapi import status


class FileContentHashingError(Exception):
    """Raised when the file content cannot be hashed."""

    def __init__(
        self,
        message: str = "Failed to hash content of the file.",
        name: str = "FileContentHashingError",
        status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
    ):
        super().__init__(message, name, status_code)

@ -1,23 +0,0 @@
|
|||
from cognee.root_dir import get_absolute_path
|
||||
from .storage.StorageManager import StorageManager
|
||||
from .storage.LocalStorage import LocalStorage
|
||||
|
||||
|
||||
async def remove_file_from_storage(file_path: str):
|
||||
"""
|
||||
Remove a specified file from storage.
|
||||
|
||||
This function initializes a storage manager with a local storage instance and calls the
|
||||
remove method of the storage manager to delete the file identified by the given path.
|
||||
Ensure that the file exists in the specified storage before calling this function to
|
||||
avoid
|
||||
potential errors.
|
||||
|
||||
Parameters:
|
||||
-----------
|
||||
|
||||
- file_path (str): The path of the file to remove from storage.
|
||||
"""
|
||||
storage_manager = StorageManager(LocalStorage(get_absolute_path("data/files")))
|
||||
|
||||
storage_manager.remove(file_path)
|
||||
cognee/infrastructure/files/storage/FileBufferedReader.py (new file, 16 lines)
@@ -0,0 +1,16 @@
from io import BufferedReader


class FileBufferedReader(BufferedReader):
    def __init__(self, file_obj, name):
        super().__init__(file_obj)
        self._file = file_obj
        self._name = name

    @property
    def name(self):
        return self._name

    def read(self, size: int = -1):
        data = self._file.read(size)
        return data
cognee/infrastructure/files/storage/LocalFileStorage.py (new file, 256 lines)
|
|
@ -0,0 +1,256 @@
|
|||
import os
|
||||
import shutil
|
||||
from urllib.parse import urlparse
|
||||
from contextlib import contextmanager
|
||||
from typing import BinaryIO, Optional, Union
|
||||
|
||||
from .FileBufferedReader import FileBufferedReader
|
||||
from .storage import Storage
|
||||
|
||||
|
||||
def get_parsed_path(file_path: str) -> str:
|
||||
# Check if this is actually a URL (has a scheme like file://, http://, etc.)
|
||||
if "://" in file_path:
|
||||
parsed_url = urlparse(file_path)
|
||||
|
||||
# Handle file:// URLs specially
|
||||
if parsed_url.scheme == "file":
|
||||
# On Windows, urlparse handles drive letters correctly
|
||||
# Convert the path component to a proper file path
|
||||
if os.name == "nt": # Windows
|
||||
# Remove leading slash from Windows paths like /C:/Users/...
|
||||
# but handle UNC paths like //server/share correctly
|
||||
parsed_path = parsed_url.path
|
||||
if parsed_path.startswith("/") and len(parsed_path) > 1 and parsed_path[2] == ":":
|
||||
# This is a Windows drive path like /C:/Users/...
|
||||
parsed_path = parsed_path[1:]
|
||||
elif parsed_path.startswith("///"):
|
||||
# This is a UNC path like ///server/share, convert to //server/share
|
||||
parsed_path = parsed_path[1:]
|
||||
else: # Unix-like systems
|
||||
parsed_path = parsed_url.path
|
||||
else:
|
||||
# For non-file URLs, use the path as-is
|
||||
parsed_path = parsed_url.path
|
||||
if (
|
||||
os.name == "nt"
|
||||
and parsed_path.startswith("/")
|
||||
and len(parsed_path) > 1
|
||||
and parsed_path[2] == ":"
|
||||
):
|
||||
parsed_path = parsed_path[1:]
|
||||
|
||||
# Normalize path separators to ensure consistency
|
||||
return os.path.normpath(parsed_path)
|
||||
else:
|
||||
# This is a regular file path, not a URL - normalize separators
|
||||
return os.path.normpath(file_path)
|
||||
|
||||
|
||||
class LocalFileStorage(Storage):
|
||||
"""
|
||||
Manage local file storage operations such as storing, retrieving, and managing files on
|
||||
the filesystem.
|
||||
"""
|
||||
|
||||
storage_path: Optional[str] = None
|
||||
|
||||
def __init__(self, storage_path: str):
|
||||
self.storage_path = storage_path
|
||||
|
||||
def store(self, file_path: str, data: Union[BinaryIO, str], overwrite: bool = False) -> str:
|
||||
"""
|
||||
Store data into a specified file path. The data can be either a string or a binary
|
||||
stream.
|
||||
|
||||
This method ensures that the storage directory exists before attempting to write the
|
||||
data. If the provided data is a stream, it reads from the stream and writes to the file;
|
||||
otherwise, it directly writes the provided data.
|
||||
|
||||
Parameters:
|
||||
-----------
|
||||
|
||||
- file_path (str): The relative path of the file where the data will be stored.
|
||||
- data (Union[BinaryIO, str]): The data to be stored, which can be a string or a
|
||||
binary stream.
|
||||
- overwrite (bool): If True, overwrite the existing file.
|
||||
"""
|
||||
parsed_storage_path = get_parsed_path(self.storage_path)
|
||||
full_file_path = os.path.join(parsed_storage_path, file_path)
|
||||
file_dir_path = os.path.dirname(full_file_path)
|
||||
|
||||
self.ensure_directory_exists(file_dir_path)
|
||||
|
||||
if overwrite or not os.path.exists(full_file_path):
|
||||
with open(
|
||||
full_file_path,
|
||||
mode="w" if isinstance(data, str) else "wb",
|
||||
encoding="utf-8" if isinstance(data, str) else None,
|
||||
) as file:
|
||||
if hasattr(data, "read"):
|
||||
data.seek(0)
|
||||
file.write(data.read())
|
||||
else:
|
||||
file.write(data)
|
||||
|
||||
file.close()
|
||||
|
||||
return "file://" + full_file_path
|
||||
|
||||
@contextmanager
|
||||
def open(self, file_path: str, mode: str = "rb", *args, **kwargs):
|
||||
"""
|
||||
Retrieve data from a specified file path, returning the content as bytes.
|
||||
|
||||
This method opens the file in read mode and reads its content. The function expects the
|
||||
file to exist; if it does not, a FileNotFoundError will be raised.
|
||||
|
||||
Parameters:
|
||||
-----------
|
||||
|
||||
- file_path (str): The relative path of the file to retrieve data from.
|
||||
- mode (str): The mode to open the file, with "rb" as the default for reading binary
|
||||
files. (default "rb")
|
||||
|
||||
Returns:
|
||||
--------
|
||||
|
||||
The content of the retrieved file as bytes.
|
||||
"""
|
||||
parsed_storage_path = get_parsed_path(self.storage_path)
|
||||
|
||||
full_file_path = os.path.join(parsed_storage_path, file_path)
|
||||
|
||||
# Add debug information for Windows path issues
|
||||
if not os.path.exists(full_file_path):
|
||||
# Try to provide helpful debug information
|
||||
if os.path.exists(parsed_storage_path):
|
||||
available_files = []
|
||||
try:
|
||||
available_files = os.listdir(parsed_storage_path)
|
||||
except (OSError, PermissionError):
|
||||
available_files = ["<unable to list directory>"]
|
||||
|
||||
raise FileNotFoundError(
|
||||
f"File not found: '{full_file_path}'\n"
|
||||
f"Storage path: '{parsed_storage_path}'\n"
|
||||
f"Requested file: '{file_path}'\n"
|
||||
f"Storage path exists: {os.path.exists(parsed_storage_path)}\n"
|
||||
f"Available files in storage: {available_files[:10]}..." # Limit to first 10 files
|
||||
)
|
||||
else:
|
||||
raise FileNotFoundError(
|
||||
f"Storage directory does not exist: '{parsed_storage_path}'\n"
|
||||
f"Original storage path: '{self.storage_path}'\n"
|
||||
f"Requested file: '{file_path}'"
|
||||
)
|
||||
|
||||
with open(full_file_path, mode=mode, *args, **kwargs) as file:
|
||||
file = FileBufferedReader(file, name="file://" + full_file_path)
|
||||
|
||||
try:
|
||||
yield file
|
||||
finally:
|
||||
file.close()
|
||||
|
||||
def file_exists(self, file_path: str):
|
||||
"""
|
||||
Check if a specified file exists in the storage.
|
||||
|
||||
Parameters:
|
||||
-----------
|
||||
|
||||
- file_path (str): The path of the file to check for existence.
|
||||
|
||||
Returns:
|
||||
--------
|
||||
|
||||
- bool: True if the file exists, otherwise False.
|
||||
"""
|
||||
parsed_storage_path = get_parsed_path(self.storage_path)
|
||||
|
||||
return os.path.exists(os.path.join(parsed_storage_path, file_path))
|
||||
|
||||
def ensure_directory_exists(self, directory_path: str = None):
|
||||
"""
|
||||
Ensure that the specified directory exists, creating it if necessary.
|
||||
|
||||
If the directory already exists, no action is taken.
|
||||
|
||||
Parameters:
|
||||
-----------
|
||||
|
||||
- directory_path (str): The path of the directory to check or create.
|
||||
"""
|
||||
if directory_path is None:
|
||||
directory_path = get_parsed_path(self.storage_path)
|
||||
elif not directory_path or directory_path.strip() == "":
|
||||
# Handle empty string case - use current directory or storage path
|
||||
directory_path = get_parsed_path(self.storage_path) or "."
|
||||
|
||||
if not os.path.exists(directory_path):
|
||||
os.makedirs(directory_path, exist_ok=True)
|
||||
|
||||
def copy_file(self, source_file_path: str, destination_file_path: str):
|
||||
"""
|
||||
Copy a file from a source path to a destination path.
|
||||
Files need to be in the same storage.
|
||||
|
||||
Parameters:
|
||||
-----------
|
||||
|
||||
- source_file_path (str): The path of the file to be copied.
|
||||
- destination_file_path (str): The path where the file will be copied to.
|
||||
|
||||
Returns:
|
||||
--------
|
||||
|
||||
- str: The path to the copied file.
|
||||
"""
|
||||
parsed_storage_path = get_parsed_path(self.storage_path)
|
||||
|
||||
return shutil.copy2(
|
||||
os.path.join(parsed_storage_path, source_file_path),
|
||||
os.path.join(parsed_storage_path, destination_file_path),
|
||||
)
|
||||
|
||||
def remove(self, file_path: str):
|
||||
"""
|
||||
Remove the specified file from the storage if it exists.
|
||||
|
||||
Parameters:
|
||||
-----------
|
||||
|
||||
- file_path (str): The path of the file to be removed.
|
||||
"""
|
||||
parsed_storage_path = get_parsed_path(self.storage_path)
|
||||
full_file_path = os.path.join(parsed_storage_path, file_path)
|
||||
|
||||
if os.path.exists(full_file_path):
|
||||
os.remove(full_file_path)
|
||||
|
||||
def remove_all(self, tree_path: str = None):
|
||||
"""
|
||||
Remove an entire directory tree at the specified path, including all files and
|
||||
subdirectories.
|
||||
|
||||
If the directory does not exist, no action is taken and no exception is raised.
|
||||
|
||||
If directories don't exist in the storage we ignore it.
|
||||
|
||||
Parameters:
|
||||
-----------
|
||||
|
||||
- tree_path (str): The root path of the directory tree to be removed.
|
||||
"""
|
||||
parsed_storage_path = get_parsed_path(self.storage_path)
|
||||
|
||||
if tree_path is None:
|
||||
tree_path = parsed_storage_path
|
||||
else:
|
||||
tree_path = os.path.join(parsed_storage_path, tree_path)
|
||||
|
||||
try:
|
||||
return shutil.rmtree(tree_path)
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
|
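A short usage sketch for the LocalFileStorage class above; the storage directory and file name are arbitrary examples:

from cognee.infrastructure.files.storage.LocalFileStorage import LocalFileStorage

storage = LocalFileStorage("/tmp/cognee_example_storage")

# store() creates missing directories and returns a file:// URL for the stored file.
stored_url = storage.store("notes/hello.txt", "hello from local storage", overwrite=True)
print(stored_url)

# file_exists() and open() take paths relative to the storage root.
assert storage.file_exists("notes/hello.txt")

with storage.open("notes/hello.txt") as file:  # default mode is "rb"
    print(file.read().decode("utf-8"))

# remove() deletes a single file; remove_all() removes the whole storage tree.
storage.remove("notes/hello.txt")
storage.remove_all()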
|
@ -1,153 +0,0 @@
|
|||
import os
|
||||
import shutil
|
||||
from typing import BinaryIO, Union
|
||||
from .StorageManager import Storage
|
||||
|
||||
|
||||
class LocalStorage(Storage):
|
||||
"""
|
||||
Manage local file storage operations such as storing, retrieving, and managing files on
|
||||
the filesystem.
|
||||
"""
|
||||
|
||||
storage_path: str = None
|
||||
|
||||
def __init__(self, storage_path: str):
|
||||
self.storage_path = storage_path
|
||||
|
||||
def store(self, file_path: str, data: Union[BinaryIO, str]):
|
||||
"""
|
||||
Store data into a specified file path. The data can be either a string or a binary
|
||||
stream.
|
||||
|
||||
This method ensures that the storage directory exists before attempting to write the
|
||||
data. If the provided data is a stream, it reads from the stream and writes to the file;
|
||||
otherwise, it directly writes the provided data.
|
||||
|
||||
Parameters:
|
||||
-----------
|
||||
|
||||
- file_path (str): The relative path of the file where the data will be stored.
|
||||
- data (Union[BinaryIO, str]): The data to be stored, which can be a string or a
|
||||
binary stream.
|
||||
"""
|
||||
full_file_path = self.storage_path + "/" + file_path
|
||||
|
||||
LocalStorage.ensure_directory_exists(self.storage_path)
|
||||
|
||||
with open(
|
||||
full_file_path,
|
||||
mode="w" if isinstance(data, str) else "wb",
|
||||
encoding="utf-8" if isinstance(data, str) else None,
|
||||
) as f:
|
||||
if hasattr(data, "read"):
|
||||
data.seek(0)
|
||||
f.write(data.read())
|
||||
else:
|
||||
f.write(data)
|
||||
|
||||
def retrieve(self, file_path: str, mode: str = "rb"):
|
||||
"""
|
||||
Retrieve data from a specified file path, returning the content as bytes.
|
||||
|
||||
This method opens the file in read mode and reads its content. The function expects the
|
||||
file to exist; if it does not, a FileNotFoundError will be raised.
|
||||
|
||||
Parameters:
|
||||
-----------
|
||||
|
||||
- file_path (str): The relative path of the file to retrieve data from.
|
||||
- mode (str): The mode to open the file, with 'rb' as the default for reading binary
|
||||
files. (default 'rb')
|
||||
|
||||
Returns:
|
||||
--------
|
||||
|
||||
The content of the retrieved file as bytes.
|
||||
"""
|
||||
full_file_path = self.storage_path + "/" + file_path
|
||||
|
||||
with open(full_file_path, mode=mode) as f:
|
||||
f.seek(0)
|
||||
return f.read()
|
||||
|
||||
@staticmethod
|
||||
def file_exists(file_path: str):
|
||||
"""
|
||||
Check if a specified file exists in the filesystem.
|
||||
|
||||
Parameters:
|
||||
-----------
|
||||
|
||||
- file_path (str): The path of the file to check for existence.
|
||||
|
||||
Returns:
|
||||
--------
|
||||
|
||||
- bool: True if the file exists, otherwise False.
|
||||
"""
|
||||
return os.path.exists(file_path)
|
||||
|
||||
@staticmethod
|
||||
def ensure_directory_exists(file_path: str):
|
||||
"""
|
||||
Ensure that the specified directory exists, creating it if necessary.
|
||||
|
||||
If the directory already exists, no action is taken.
|
||||
|
||||
Parameters:
|
||||
-----------
|
||||
|
||||
- file_path (str): The path of the directory to check or create.
|
||||
"""
|
||||
if not os.path.exists(file_path):
|
||||
os.makedirs(file_path, exist_ok=True)
|
||||
|
||||
@staticmethod
|
||||
def remove(file_path: str):
|
||||
"""
|
||||
Remove the specified file from the filesystem if it exists.
|
||||
|
||||
Parameters:
|
||||
-----------
|
||||
|
||||
- file_path (str): The path of the file to be removed.
|
||||
"""
|
||||
if os.path.exists(file_path):
|
||||
os.remove(file_path)
|
||||
|
||||
@staticmethod
|
||||
def copy_file(source_file_path: str, destination_file_path: str):
|
||||
"""
|
||||
Copy a file from a source path to a destination path.
|
||||
|
||||
Parameters:
|
||||
-----------
|
||||
|
||||
- source_file_path (str): The path of the file to be copied.
|
||||
- destination_file_path (str): The path where the file will be copied to.
|
||||
|
||||
Returns:
|
||||
--------
|
||||
|
||||
- str: The path to the copied file.
|
||||
"""
|
||||
return shutil.copy2(source_file_path, destination_file_path)
|
||||
|
||||
@staticmethod
|
||||
def remove_all(tree_path: str):
|
||||
"""
|
||||
Remove an entire directory tree at the specified path, including all files and
|
||||
subdirectories.
|
||||
|
||||
If the directory does not exist, no action is taken and no exception is raised.
|
||||
|
||||
Parameters:
|
||||
-----------
|
||||
|
||||
- tree_path (str): The root path of the directory tree to be removed.
|
||||
"""
|
||||
try:
|
||||
shutil.rmtree(tree_path)
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
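A minimal usage sketch of the local storage class above (assumed to be LocalFileStorage, as imported elsewhere in this change; the paths are placeholders):

from cognee.infrastructure.files.storage.LocalFileStorage import LocalFileStorage

storage = LocalFileStorage("/tmp/cognee_data")            # assumed local root directory
storage.store("example.txt", "hello cognee")              # writes /tmp/cognee_data/example.txt
content = storage.retrieve("example.txt", mode="r")       # "hello cognee"
LocalFileStorage.remove("/tmp/cognee_data/example.txt")   # remove() expects a full path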
215
cognee/infrastructure/files/storage/S3FileStorage.py
Normal file
215
cognee/infrastructure/files/storage/S3FileStorage.py
Normal file
|
|
@ -0,0 +1,215 @@
|
|||
import os
|
||||
import s3fs
|
||||
from typing import BinaryIO, Union
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
from cognee.api.v1.add.config import get_s3_config
|
||||
from cognee.infrastructure.utils.run_async import run_async
|
||||
from cognee.infrastructure.files.storage.FileBufferedReader import FileBufferedReader
|
||||
from .storage import Storage
|
||||
|
||||
|
||||
class S3FileStorage(Storage):
|
||||
"""
|
||||
Manage S3 file storage operations such as storing, retrieving, and managing files in
|
||||
S3-compatible object storage.
|
||||
"""
|
||||
|
||||
storage_path: str
|
||||
s3: s3fs.S3FileSystem
|
||||
|
||||
def __init__(self, storage_path: str):
|
||||
self.storage_path = storage_path
|
||||
s3_config = get_s3_config()
|
||||
if s3_config.aws_access_key_id is not None and s3_config.aws_secret_access_key is not None:
|
||||
self.s3 = s3fs.S3FileSystem(
|
||||
key=s3_config.aws_access_key_id,
|
||||
secret=s3_config.aws_secret_access_key,
|
||||
anon=False,
|
||||
endpoint_url="https://s3-eu-west-1.amazonaws.com",
|
||||
)
|
||||
else:
|
||||
raise ValueError("S3 credentials are not set in the configuration.")
|
||||
|
||||
async def store(
|
||||
self, file_path: str, data: Union[BinaryIO, str], overwrite: bool = False
|
||||
) -> str:
|
||||
"""
|
||||
Store data into a specified file path. The data can be either a string or a binary
|
||||
stream.
|
||||
|
||||
This method ensures that the storage directory exists before attempting to write the
|
||||
data. If the provided data is a stream, it reads from the stream and writes to the file;
|
||||
otherwise, it directly writes the provided data.
|
||||
|
||||
Parameters:
|
||||
-----------
|
||||
|
||||
- file_path (str): The relative path of the file where the data will be stored.
|
||||
- data (Union[BinaryIO, str]): The data to be stored, which can be a string or a
|
||||
binary stream.
|
||||
- overwrite (bool): If True, overwrite the existing file.
|
||||
"""
|
||||
full_file_path = os.path.join(self.storage_path.replace("s3://", ""), file_path)
|
||||
|
||||
file_dir_path = os.path.dirname(full_file_path)
|
||||
|
||||
await self.ensure_directory_exists(file_dir_path)
|
||||
|
||||
if overwrite or not await self.file_exists(file_path):
|
||||
|
||||
def save_data_to_file():
|
||||
with self.s3.open(
|
||||
full_file_path,
|
||||
mode="w" if isinstance(data, str) else "wb",
|
||||
encoding="utf-8" if isinstance(data, str) else None,
|
||||
) as file:
|
||||
if hasattr(data, "read"):
|
||||
data.seek(0)
|
||||
file.write(data.read())
|
||||
else:
|
||||
file.write(data)
|
||||
|
||||
file.close()
|
||||
|
||||
await run_async(save_data_to_file)
|
||||
|
||||
return "s3://" + full_file_path
|
||||
|
||||
@asynccontextmanager
|
||||
async def open(self, file_path: str, mode: str = "r"):
|
||||
"""
|
||||
Open the file at the specified path and yield it as a file-like object.
|
||||
|
||||
This async context manager opens the file in the requested mode. It expects the
|
||||
file to exist; if it does not, a FileNotFoundError will be raised.
|
||||
|
||||
Parameters:
|
||||
-----------
|
||||
|
||||
- file_path (str): The relative path of the file to retrieve data from.
|
||||
- mode (str): The mode in which to open the file, with "r" as the default for reading text
|
||||
files. (default "r")
|
||||
|
||||
Returns:
|
||||
--------
|
||||
|
||||
A file-like object (FileBufferedReader) for the opened file.
|
||||
"""
|
||||
full_file_path = os.path.join(self.storage_path.replace("s3://", ""), file_path)
|
||||
|
||||
def get_file():
|
||||
return self.s3.open(full_file_path, mode=mode)
|
||||
|
||||
file = await run_async(get_file)
|
||||
file = FileBufferedReader(file, name="s3://" + full_file_path)
|
||||
|
||||
try:
|
||||
yield file
|
||||
finally:
|
||||
file.close()
|
||||
|
||||
async def file_exists(self, file_path: str):
|
||||
"""
|
||||
Check if a specified file exists in the S3 storage.
|
||||
|
||||
Parameters:
|
||||
-----------
|
||||
|
||||
- file_path (str): The path of the file to check for existence.
|
||||
|
||||
Returns:
|
||||
--------
|
||||
|
||||
- bool: True if the file exists, otherwise False.
|
||||
"""
|
||||
return await run_async(
|
||||
self.s3.exists, os.path.join(self.storage_path.replace("s3://", ""), file_path)
|
||||
)
|
||||
|
||||
async def ensure_directory_exists(self, directory_path: str = None):
|
||||
"""
|
||||
Ensure that the specified directory exists, creating it if necessary.
|
||||
|
||||
If the directory already exists, no action is taken.
|
||||
|
||||
Parameters:
|
||||
-----------
|
||||
|
||||
- directory_path (str): The path of the directory to check or create.
|
||||
"""
|
||||
if directory_path is None:
|
||||
directory_path = self.storage_path.replace("s3://", "")
|
||||
|
||||
def ensure_directory():
|
||||
if not self.s3.exists(directory_path):
|
||||
self.s3.makedirs(directory_path, exist_ok=True)
|
||||
|
||||
await run_async(ensure_directory)
|
||||
|
||||
async def copy_file(self, source_file_path: str, destination_file_path: str):
|
||||
"""
|
||||
Copy a file from a source path to a destination path.
|
||||
|
||||
Parameters:
|
||||
-----------
|
||||
|
||||
- source_file_path (str): The path of the file to be copied.
|
||||
- destination_file_path (str): The path where the file will be copied to.
|
||||
|
||||
Returns:
|
||||
--------
|
||||
|
||||
- str: The path to the copied file.
|
||||
"""
|
||||
|
||||
def copy():
|
||||
return self.s3.copy(
|
||||
os.path.join(self.storage_path.replace("s3://", ""), source_file_path),
|
||||
os.path.join(self.storage_path.replace("s3://", ""), destination_file_path),
|
||||
recursive=True,
|
||||
)
|
||||
|
||||
return await run_async(copy)
|
||||
|
||||
async def remove(self, file_path: str):
|
||||
"""
|
||||
Remove the specified file from the S3 storage if it exists.
|
||||
|
||||
Parameters:
|
||||
-----------
|
||||
|
||||
- file_path (str): The path of the file to be removed.
|
||||
"""
|
||||
full_file_path = os.path.join(self.storage_path.replace("s3://", ""), file_path)
|
||||
|
||||
def remove_file():
|
||||
if self.s3.exists(full_file_path):
|
||||
self.s3.rm_file(full_file_path)
|
||||
|
||||
await run_async(remove_file)
|
||||
|
||||
async def remove_all(self, tree_path: str = None):
|
||||
"""
|
||||
Remove an entire directory tree at the specified path, including all files and
|
||||
subdirectories.
|
||||
|
||||
If the directory does not exist, no action is taken and no exception is raised.
|
||||
|
||||
Parameters:
|
||||
-----------
|
||||
|
||||
- tree_path (str): The root path of the directory tree to be removed.
|
||||
"""
|
||||
if tree_path is None:
|
||||
tree_path = self.storage_path.replace("s3://", "")
|
||||
else:
|
||||
tree_path = os.path.join(self.storage_path.replace("s3://", ""), tree_path)
|
||||
|
||||
# async_remove_all = run_async(lambda: self.s3.rm(tree_path, recursive=True))
|
||||
|
||||
try:
|
||||
# await async_remove_all()
|
||||
await run_async(self.s3.rm, tree_path, recursive=True)
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
|
|
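A hedged usage sketch of S3FileStorage: store(), open(), and remove() are async and push the blocking s3fs calls onto a thread via run_async. The bucket, prefix, and file name below are placeholders, and valid credentials are assumed to be available through get_s3_config():

import asyncio
from cognee.infrastructure.files.storage.S3FileStorage import S3FileStorage

async def main():
    storage = S3FileStorage("s3://my-bucket/cognee-data")   # placeholder bucket/prefix
    await storage.store("example.txt", "hello cognee", overwrite=True)

    async with storage.open("example.txt", mode="r") as file:
        print(file.read())

    await storage.remove("example.txt")

asyncio.run(main())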
@ -1,45 +1,8 @@
|
|||
from typing import Protocol, BinaryIO
|
||||
import inspect
|
||||
from typing import BinaryIO
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
|
||||
class Storage(Protocol):
|
||||
"""
|
||||
Abstract interface for storage operations.
|
||||
"""
|
||||
|
||||
def store(self, file_path: str, data: bytes):
|
||||
"""
|
||||
Store data at the specified file path.
|
||||
|
||||
Parameters:
|
||||
-----------
|
||||
|
||||
- file_path (str): The path where the data will be stored.
|
||||
- data (bytes): The binary data to be stored.
|
||||
"""
|
||||
pass
|
||||
|
||||
def retrieve(self, file_path: str):
|
||||
"""
|
||||
Retrieve data from the specified file path.
|
||||
|
||||
Parameters:
|
||||
-----------
|
||||
|
||||
- file_path (str): The path from where the data will be retrieved.
|
||||
"""
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def remove(file_path: str):
|
||||
"""
|
||||
Remove the storage at the specified file path.
|
||||
|
||||
Parameters:
|
||||
-----------
|
||||
|
||||
- file_path (str): The path of the file to be removed.
|
||||
"""
|
||||
pass
|
||||
from .storage import Storage
|
||||
|
||||
|
||||
class StorageManager:
|
||||
|
|
@ -48,8 +11,9 @@ class StorageManager:
|
|||
|
||||
Public methods include:
|
||||
- store: Store data in the specified path.
|
||||
- retrieve: Retrieve data from the specified path.
|
||||
- open: Open a file from the specified path.
|
||||
- remove: Remove the file at the specified path.
|
||||
- remove_all: Remove all files under the directory tree.
|
||||
"""
|
||||
|
||||
storage: Storage = None
|
||||
|
|
@ -57,7 +21,26 @@ class StorageManager:
|
|||
def __init__(self, storage: Storage):
|
||||
self.storage = storage
|
||||
|
||||
def store(self, file_path: str, data: BinaryIO):
|
||||
async def file_exists(self, file_path: str):
|
||||
"""
|
||||
Check if a specified file exists in the storage.
|
||||
|
||||
Parameters:
|
||||
-----------
|
||||
|
||||
- file_path (str): The path of the file to check for existence.
|
||||
|
||||
Returns:
|
||||
--------
|
||||
|
||||
- bool: True if the file exists, otherwise False.
|
||||
"""
|
||||
if inspect.iscoroutinefunction(self.storage.file_exists):
|
||||
return await self.storage.file_exists(file_path)
|
||||
else:
|
||||
return self.storage.file_exists(file_path)
|
||||
|
||||
async def store(self, file_path: str, data: BinaryIO, overwrite: bool = False) -> str:
|
||||
"""
|
||||
Store data at the specified file path.
|
||||
|
||||
|
|
@ -66,16 +49,20 @@ class StorageManager:
|
|||
|
||||
- file_path (str): The path where the data should be stored.
|
||||
- data (BinaryIO): The data in a binary format that needs to be stored.
|
||||
- overwrite (bool): If True, overwrite the existing file.
|
||||
|
||||
Returns:
|
||||
--------
|
||||
|
||||
Returns the outcome of the store operation, as defined by the storage
|
||||
implementation.
|
||||
Returns the full path to the file.
|
||||
"""
|
||||
return self.storage.store(file_path, data)
|
||||
if inspect.iscoroutinefunction(self.storage.store):
|
||||
return await self.storage.store(file_path, data, overwrite)
|
||||
else:
|
||||
return self.storage.store(file_path, data, overwrite)
|
||||
|
||||
def retrieve(self, file_path: str):
|
||||
@asynccontextmanager
|
||||
async def open(self, file_path: str, encoding: str = None, *args, **kwargs):
|
||||
"""
|
||||
Retrieve data from the specified file path.
|
||||
|
||||
|
|
@ -89,9 +76,34 @@ class StorageManager:
|
|||
|
||||
Returns the retrieved data, as defined by the storage implementation.
|
||||
"""
|
||||
return self.storage.retrieve(file_path)
|
||||
# Check the actual storage type by class name to determine if open() is async or sync
|
||||
|
||||
def remove(self, file_path: str):
|
||||
if self.storage.__class__.__name__ == "S3FileStorage":
|
||||
# S3FileStorage.open() is async
|
||||
async with self.storage.open(file_path, *args, **kwargs) as file:
|
||||
yield file
|
||||
else:
|
||||
# LocalFileStorage.open() is sync
|
||||
with self.storage.open(file_path, *args, **kwargs) as file:
|
||||
yield file
|
||||
|
||||
async def ensure_directory_exists(self, directory_path: str = None):
|
||||
"""
|
||||
Ensure that the specified directory exists, creating it if necessary.
|
||||
|
||||
If the directory already exists, no action is taken.
|
||||
|
||||
Parameters:
|
||||
-----------
|
||||
|
||||
- directory_path (str): The path of the directory to check or create.
|
||||
"""
|
||||
if inspect.iscoroutinefunction(self.storage.ensure_directory_exists):
|
||||
return await self.storage.ensure_directory_exists(directory_path)
|
||||
else:
|
||||
return self.storage.ensure_directory_exists(directory_path)
|
||||
|
||||
async def remove(self, file_path: str):
|
||||
"""
|
||||
Remove the file at the specified path.
|
||||
|
||||
|
|
@ -106,4 +118,24 @@ class StorageManager:
|
|||
Returns the outcome of the remove operation, as defined by the storage
|
||||
implementation.
|
||||
"""
|
||||
return self.storage.remove(file_path)
|
||||
if inspect.iscoroutinefunction(self.storage.remove):
|
||||
return await self.storage.remove(file_path)
|
||||
else:
|
||||
return self.storage.remove(file_path)
|
||||
|
||||
async def remove_all(self, tree_path: str = None):
|
||||
"""
|
||||
Remove an entire directory tree at the specified path, including all files and
|
||||
subdirectories.
|
||||
|
||||
If the directory does not exist, no action is taken and no exception is raised.
|
||||
|
||||
Parameters:
|
||||
-----------
|
||||
|
||||
- tree_path (str): The root path of the directory tree to be removed.
|
||||
"""
|
||||
if inspect.iscoroutinefunction(self.storage.remove_all):
|
||||
return await self.storage.remove_all(tree_path)
|
||||
else:
|
||||
return self.storage.remove_all(tree_path)
|
||||
|
|
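StorageManager bridges sync and async backends by awaiting a backend method only when inspect.iscoroutinefunction reports it as a coroutine. A minimal sketch of that dispatch pattern with a toy in-memory backend (illustration only, not part of the change):

import inspect

class InMemoryStorage:
    """Toy synchronous backend used only to illustrate the dispatch."""

    def __init__(self):
        self.files = {}

    def store(self, file_path, data, overwrite=False):
        if overwrite or file_path not in self.files:
            self.files[file_path] = data
        return file_path

async def store_through_manager(storage, file_path, data, overwrite=False):
    # Mirror of StorageManager.store: await only if the backend method is async.
    if inspect.iscoroutinefunction(storage.store):
        return await storage.store(file_path, data, overwrite)
    return storage.store(file_path, data, overwrite)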
|
|||
|
|
@ -1 +1,3 @@
|
|||
from .LocalStorage import LocalStorage
|
||||
from .StorageManager import StorageManager
|
||||
from .get_file_storage import get_file_storage
|
||||
from .get_storage_config import get_storage_config
|
||||
|
|
|
|||
4
cognee/infrastructure/files/storage/config.py
Normal file
4
cognee/infrastructure/files/storage/config.py
Normal file
|
|
@ -0,0 +1,4 @@
|
|||
from contextvars import ContextVar
|
||||
|
||||
|
||||
file_storage_config = ContextVar("file_storage_config", default=None)
|
||||
23
cognee/infrastructure/files/storage/get_file_storage.py
Normal file
23
cognee/infrastructure/files/storage/get_file_storage.py
Normal file
|
|
@ -0,0 +1,23 @@
|
|||
import os
|
||||
|
||||
from cognee.base_config import get_base_config
|
||||
|
||||
from .StorageManager import StorageManager
|
||||
|
||||
|
||||
def get_file_storage(storage_path: str) -> StorageManager:
|
||||
base_config = get_base_config()
|
||||
|
||||
# Use S3FileStorage if the storage_path is an S3 URL or if configured for S3
|
||||
if storage_path.startswith("s3://") or (
|
||||
os.getenv("STORAGE_BACKEND") == "s3"
|
||||
and "s3://" in base_config.system_root_directory
|
||||
and "s3://" in base_config.data_root_directory
|
||||
):
|
||||
from cognee.infrastructure.files.storage.S3FileStorage import S3FileStorage
|
||||
|
||||
return StorageManager(S3FileStorage(storage_path))
|
||||
else:
|
||||
from cognee.infrastructure.files.storage.LocalFileStorage import LocalFileStorage
|
||||
|
||||
return StorageManager(LocalFileStorage(storage_path))
|
||||
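A sketch of how this dispatch might be used (paths and bucket names are placeholders; the s3:// branch additionally requires S3 credentials, since S3FileStorage raises ValueError without them):

from cognee.infrastructure.files.storage import get_file_storage

local_storage = get_file_storage("/tmp/cognee_data")         # -> StorageManager(LocalFileStorage(...))
s3_storage = get_file_storage("s3://my-bucket/cognee-data")  # -> StorageManager(S3FileStorage(...))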
18
cognee/infrastructure/files/storage/get_storage_config.py
Normal file
18
cognee/infrastructure/files/storage/get_storage_config.py
Normal file
|
|
@ -0,0 +1,18 @@
|
|||
from cognee.base_config import get_base_config
|
||||
from .config import file_storage_config
|
||||
|
||||
|
||||
def get_global_storage_config():
|
||||
base_config = get_base_config()
|
||||
|
||||
return {
|
||||
"data_root_directory": base_config.data_root_directory,
|
||||
}
|
||||
|
||||
|
||||
def get_storage_config():
|
||||
context_config = file_storage_config.get()
|
||||
if context_config:
|
||||
return context_config
|
||||
|
||||
return get_global_storage_config()
|
||||
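Because file_storage_config is a ContextVar, a caller can override the globally configured data root for the current context only. A hedged sketch (the override dict mirrors the shape returned by get_global_storage_config; the S3 path is a placeholder):

from cognee.infrastructure.files.storage import get_storage_config
from cognee.infrastructure.files.storage.config import file_storage_config

token = file_storage_config.set({"data_root_directory": "s3://my-bucket/cognee-data"})
try:
    assert get_storage_config()["data_root_directory"].startswith("s3://")
finally:
    file_storage_config.reset(token)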
105
cognee/infrastructure/files/storage/storage.py
Normal file
105
cognee/infrastructure/files/storage/storage.py
Normal file
|
|
@ -0,0 +1,105 @@
|
|||
from typing import BinaryIO, Protocol, Union
|
||||
|
||||
|
||||
class Storage(Protocol):
|
||||
storage_path: str
|
||||
|
||||
"""
|
||||
Abstract interface for storage operations.
|
||||
"""
|
||||
|
||||
def file_exists(self, file_path: str) -> bool:
|
||||
"""
|
||||
Check if a specified file exists in the storage.
|
||||
|
||||
Parameters:
|
||||
-----------
|
||||
|
||||
- file_path (str): The path of the file to check for existence.
|
||||
|
||||
Returns:
|
||||
--------
|
||||
|
||||
- bool: True if the file exists, otherwise False.
|
||||
"""
|
||||
pass
|
||||
|
||||
def store(self, file_path: str, data: Union[BinaryIO, str], overwrite: bool):
|
||||
"""
|
||||
Store data at the specified file path.
|
||||
|
||||
Parameters:
|
||||
-----------
|
||||
|
||||
- file_path (str): The path where the data will be stored.
|
||||
- data (Union[BinaryIO, str]): The data to be stored, either a string or a binary stream.
|
||||
- overwrite (bool): If True, overwrite the existing file.
|
||||
"""
|
||||
pass
|
||||
|
||||
def open(self, file_path: str, mode: str = "r"):
|
||||
"""
|
||||
Retrieve file from the specified file path.
|
||||
|
||||
Parameters:
|
||||
-----------
|
||||
|
||||
- file_path (str): The path from where the data will be retrieved.
|
||||
- mode (str): The mode in which to open the file, with "r" as the default for reading text files. (default "r")
|
||||
"""
|
||||
pass
|
||||
|
||||
def copy_file(self, source_file_path: str, destination_file_path: str):
|
||||
"""
|
||||
Copy a file from a source path to a destination path.
|
||||
|
||||
Parameters:
|
||||
-----------
|
||||
|
||||
- source_file_path (str): The path of the file to be copied.
|
||||
- destination_file_path (str): The path where the file will be copied to.
|
||||
|
||||
Returns:
|
||||
--------
|
||||
|
||||
- str: The path to the copied file.
|
||||
"""
|
||||
pass
|
||||
|
||||
def ensure_directory_exists(self, directory_path: str = None):
|
||||
"""
|
||||
Ensure that the specified directory exists, creating it if necessary.
|
||||
|
||||
If the directory already exists, no action is taken.
|
||||
|
||||
Parameters:
|
||||
-----------
|
||||
|
||||
- directory_path (str): The path of the directory to check or create.
|
||||
"""
|
||||
pass
|
||||
|
||||
def remove(self, file_path: str):
|
||||
"""
|
||||
Remove the storage at the specified file path.
|
||||
|
||||
Parameters:
|
||||
-----------
|
||||
|
||||
- file_path (str): The path of the file to be removed.
|
||||
"""
|
||||
pass
|
||||
|
||||
def remove_all(self, root_path: str = None):
|
||||
"""
|
||||
Remove an entire directory tree at the specified path, including all files and
|
||||
subdirectories.
|
||||
|
||||
If the directory does not exist, no action is taken and no exception is raised.
|
||||
|
||||
Parameters:
|
||||
-----------
|
||||
|
||||
- root_path (str): The root path of the directory tree to be removed.
|
||||
"""
|
||||
pass
|
||||
40
cognee/infrastructure/files/utils/get_file_content_hash.py
Normal file
40
cognee/infrastructure/files/utils/get_file_content_hash.py
Normal file
|
|
@ -0,0 +1,40 @@
|
|||
import hashlib
|
||||
import os
|
||||
from os import path
|
||||
from typing import BinaryIO, Union
|
||||
|
||||
from ..storage import get_file_storage
|
||||
from ..exceptions import FileContentHashingError
|
||||
|
||||
|
||||
async def get_file_content_hash(file_obj: Union[str, BinaryIO]) -> str:
|
||||
h = hashlib.md5()
|
||||
|
||||
try:
|
||||
if isinstance(file_obj, str):
|
||||
# Normalize path separators to handle mixed separators on Windows
|
||||
normalized_path = os.path.normpath(file_obj)
|
||||
|
||||
file_dir_path = path.dirname(normalized_path)
|
||||
file_path = path.basename(normalized_path)
|
||||
|
||||
file_storage = get_file_storage(file_dir_path)
|
||||
|
||||
async with file_storage.open(file_path, "rb") as file:
|
||||
while True:
|
||||
# Reading is buffered, so we can read smaller chunks.
|
||||
chunk = file.read(h.block_size)
|
||||
if not chunk:
|
||||
break
|
||||
h.update(chunk)
|
||||
else:
|
||||
while True:
|
||||
# Reading is buffered, so we can read smaller chunks.
|
||||
chunk = file_obj.read(h.block_size)
|
||||
if not chunk:
|
||||
break
|
||||
h.update(chunk)
|
||||
|
||||
return h.hexdigest()
|
||||
except IOError as e:
|
||||
raise FileContentHashingError(message=f"Failed to hash data from {file_obj}: {e}")
|
||||
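get_file_content_hash accepts either a path string (routed through get_file_storage, so S3-rooted paths also work) or an already-open binary stream. A small sketch with an in-memory stream:

import asyncio
from io import BytesIO
from cognee.infrastructure.files.utils.get_file_content_hash import get_file_content_hash

async def main():
    # A binary stream is hashed directly; a path string would be opened via get_file_storage.
    digest = await get_file_content_hash(BytesIO(b"hello cognee"))
    print(digest)  # md5 hex digest of the stream contents

asyncio.run(main())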
|
|
@ -3,7 +3,7 @@ import os.path
|
|||
from typing import BinaryIO, TypedDict
|
||||
|
||||
from cognee.shared.logging_utils import get_logger
|
||||
from cognee.shared.utils import get_file_content_hash
|
||||
from cognee.infrastructure.files.utils.get_file_content_hash import get_file_content_hash
|
||||
from .guess_file_type import guess_file_type
|
||||
|
||||
logger = get_logger("FileMetadata")
|
||||
|
|
@ -26,7 +26,7 @@ class FileMetadata(TypedDict):
|
|||
file_size: int
|
||||
|
||||
|
||||
def get_file_metadata(file: BinaryIO) -> FileMetadata:
|
||||
async def get_file_metadata(file: BinaryIO) -> FileMetadata:
|
||||
"""
|
||||
Retrieve metadata from a file object.
|
||||
|
||||
|
|
@ -47,7 +47,7 @@ def get_file_metadata(file: BinaryIO) -> FileMetadata:
|
|||
"""
|
||||
try:
|
||||
file.seek(0)
|
||||
content_hash = get_file_content_hash(file)
|
||||
content_hash = await get_file_content_hash(file)
|
||||
file.seek(0)
|
||||
except io.UnsupportedOperation as error:
|
||||
logger.error(f"Error retrieving content hash for file: {file.name} \n{str(error)}\n\n")
|
||||
|
|
|
|||
65
cognee/infrastructure/files/utils/open_data_file.py
Normal file
65
cognee/infrastructure/files/utils/open_data_file.py
Normal file
|
|
@ -0,0 +1,65 @@
|
|||
import os
|
||||
from os import path
|
||||
from urllib.parse import urlparse
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
from cognee.infrastructure.files.storage import get_file_storage
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def open_data_file(file_path: str, mode: str = "rb", encoding: str = None, **kwargs):
|
||||
# Check if this is a file URI BEFORE normalizing (which corrupts URIs)
|
||||
if file_path.startswith("file://"):
|
||||
# Normalize the file URI for Windows - replace backslashes with forward slashes
|
||||
normalized_file_uri = file_path.replace("\\", "/")
|
||||
|
||||
parsed_url = urlparse(normalized_file_uri)
|
||||
|
||||
# Convert URI path to file system path
|
||||
if os.name == "nt": # Windows
|
||||
# Handle Windows drive letters correctly
|
||||
fs_path = parsed_url.path
|
||||
if fs_path.startswith("/") and len(fs_path) > 2 and fs_path[2] == ":":
|
||||
fs_path = fs_path[1:] # Remove leading slash for Windows drive paths
|
||||
else: # Unix-like systems
|
||||
fs_path = parsed_url.path
|
||||
|
||||
# Now split the actual filesystem path
|
||||
actual_fs_path = os.path.normpath(fs_path)
|
||||
file_dir_path = path.dirname(actual_fs_path)
|
||||
file_name = path.basename(actual_fs_path)
|
||||
|
||||
elif file_path.startswith("s3://"):
|
||||
# Handle S3 URLs without normalization (which corrupts them)
|
||||
parsed_url = urlparse(file_path)
|
||||
|
||||
# For S3, reconstruct the directory path and filename
|
||||
s3_path = parsed_url.path.lstrip("/") # Remove leading slash
|
||||
|
||||
if "/" in s3_path:
|
||||
s3_dir = "/".join(s3_path.split("/")[:-1])
|
||||
s3_filename = s3_path.split("/")[-1]
|
||||
else:
|
||||
s3_dir = ""
|
||||
s3_filename = s3_path
|
||||
|
||||
# Extract filesystem path from S3 URL structure
|
||||
file_dir_path = (
|
||||
f"s3://{parsed_url.netloc}/{s3_dir}" if s3_dir else f"s3://{parsed_url.netloc}"
|
||||
)
|
||||
file_name = s3_filename
|
||||
|
||||
else:
|
||||
# Regular file path - normalize separators
|
||||
normalized_path = os.path.normpath(file_path)
|
||||
file_dir_path = path.dirname(normalized_path)
|
||||
file_name = path.basename(normalized_path)
|
||||
|
||||
# Validate that we have a proper filename
|
||||
if not file_name or file_name == "." or file_name == "..":
|
||||
raise ValueError(f"Invalid filename extracted: '{file_name}' from path: '{file_path}'")
|
||||
|
||||
file_storage = get_file_storage(file_dir_path)
|
||||
|
||||
async with file_storage.open(file_name, mode=mode, encoding=encoding, **kwargs) as file:
|
||||
yield file
|
||||
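A hedged sketch of calling open_data_file: plain paths, file:// URIs, and s3:// URLs are all split into a storage root (used to select the backend) and a file name. The paths below are placeholders, and the local branch assumes the local backend exposes a matching open():

import asyncio
from cognee.infrastructure.files.utils.open_data_file import open_data_file

async def main():
    async with open_data_file("/tmp/cognee_data/example.txt", mode="r", encoding="utf-8") as f:
        print(f.read())
    # file:// URIs and s3:// URLs take the other two branches, e.g.:
    #   open_data_file("file:///tmp/cognee_data/example.txt", mode="rb")
    #   open_data_file("s3://my-bucket/cognee-data/example.txt", mode="rb")

asyncio.run(main())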
|
|
@ -1,16 +1,15 @@
|
|||
from typing import Type
|
||||
from pydantic import BaseModel
|
||||
import base64
|
||||
import instructor
|
||||
from typing import Type
|
||||
from openai import OpenAI
|
||||
from pydantic import BaseModel
|
||||
|
||||
from cognee.infrastructure.llm.llm_interface import LLMInterface
|
||||
from cognee.infrastructure.llm.config import get_llm_config
|
||||
from cognee.infrastructure.llm.rate_limiter import (
|
||||
rate_limit_async,
|
||||
rate_limit_sync,
|
||||
sleep_and_retry_async,
|
||||
)
|
||||
from openai import OpenAI
|
||||
import base64
|
||||
import os
|
||||
from cognee.infrastructure.files.utils.open_data_file import open_data_file
|
||||
|
||||
|
||||
class OllamaAPIAdapter(LLMInterface):
|
||||
|
|
@ -87,8 +86,8 @@ class OllamaAPIAdapter(LLMInterface):
|
|||
|
||||
return response
|
||||
|
||||
@rate_limit_sync
|
||||
def create_transcript(self, input_file: str) -> str:
|
||||
@rate_limit_async
|
||||
async def create_transcript(self, input_file: str) -> str:
|
||||
"""
|
||||
Generate an audio transcript from a user query.
|
||||
|
||||
|
|
@ -107,10 +106,7 @@ class OllamaAPIAdapter(LLMInterface):
|
|||
- str: The transcription of the audio as a string.
|
||||
"""
|
||||
|
||||
if not os.path.isfile(input_file):
|
||||
raise FileNotFoundError(f"The file {input_file} does not exist.")
|
||||
|
||||
with open(input_file, "rb") as audio_file:
|
||||
async with open_data_file(input_file, mode="rb") as audio_file:
|
||||
transcription = self.aclient.audio.transcriptions.create(
|
||||
model="whisper-1", # Ensure the correct model for transcription
|
||||
file=audio_file,
|
||||
|
|
@ -123,8 +119,8 @@ class OllamaAPIAdapter(LLMInterface):
|
|||
|
||||
return transcription.text
|
||||
|
||||
@rate_limit_sync
|
||||
def transcribe_image(self, input_file: str) -> str:
|
||||
@rate_limit_async
|
||||
async def transcribe_image(self, input_file: str) -> str:
|
||||
"""
|
||||
Transcribe content from an image using base64 encoding.
|
||||
|
||||
|
|
@ -144,10 +140,7 @@ class OllamaAPIAdapter(LLMInterface):
|
|||
- str: The transcription of the image's content as a string.
|
||||
"""
|
||||
|
||||
if not os.path.isfile(input_file):
|
||||
raise FileNotFoundError(f"The file {input_file} does not exist.")
|
||||
|
||||
with open(input_file, "rb") as image_file:
|
||||
async with open_data_file(input_file, mode="rb") as image_file:
|
||||
encoded_image = base64.b64encode(image_file.read()).decode("utf-8")
|
||||
|
||||
response = self.aclient.chat.completions.create(
|
||||
|
|
|
|||
|
|
@ -1,4 +1,3 @@
|
|||
import os
|
||||
import base64
|
||||
import litellm
|
||||
import instructor
|
||||
|
|
@ -12,7 +11,7 @@ from cognee.exceptions import InvalidValueError
|
|||
from cognee.infrastructure.llm.prompts import read_query_prompt
|
||||
from cognee.infrastructure.llm.llm_interface import LLMInterface
|
||||
from cognee.infrastructure.llm.exceptions import ContentPolicyFilterError
|
||||
from cognee.modules.data.processing.document_types.open_data_file import open_data_file
|
||||
from cognee.infrastructure.files.utils.open_data_file import open_data_file
|
||||
from cognee.infrastructure.llm.rate_limiter import (
|
||||
rate_limit_async,
|
||||
rate_limit_sync,
|
||||
|
|
@ -222,8 +221,8 @@ class OpenAIAdapter(LLMInterface):
|
|||
max_retries=self.MAX_RETRIES,
|
||||
)
|
||||
|
||||
@rate_limit_sync
|
||||
def create_transcript(self, input):
|
||||
@rate_limit_async
|
||||
async def create_transcript(self, input):
|
||||
"""
|
||||
Generate an audio transcript from a user query.
|
||||
|
||||
|
|
@ -242,10 +241,7 @@ class OpenAIAdapter(LLMInterface):
|
|||
The generated transcription of the audio file.
|
||||
"""
|
||||
|
||||
if not input.startswith("s3://") and not os.path.isfile(input):
|
||||
raise FileNotFoundError(f"The file {input} does not exist.")
|
||||
|
||||
with open_data_file(input, mode="rb") as audio_file:
|
||||
async with open_data_file(input, mode="rb") as audio_file:
|
||||
transcription = litellm.transcription(
|
||||
model=self.transcription_model,
|
||||
file=audio_file,
|
||||
|
|
@ -257,8 +253,8 @@ class OpenAIAdapter(LLMInterface):
|
|||
|
||||
return transcription
|
||||
|
||||
@rate_limit_sync
|
||||
def transcribe_image(self, input) -> BaseModel:
|
||||
@rate_limit_async
|
||||
async def transcribe_image(self, input) -> BaseModel:
|
||||
"""
|
||||
Generate a transcription of an image from a user query.
|
||||
|
||||
|
|
@ -276,7 +272,7 @@ class OpenAIAdapter(LLMInterface):
|
|||
- BaseModel: A structured output generated by the model, returned as an instance of
|
||||
BaseModel.
|
||||
"""
|
||||
with open_data_file(input, mode="rb") as image_file:
|
||||
async with open_data_file(input, mode="rb") as image_file:
|
||||
encoded_image = base64.b64encode(image_file.read()).decode("utf-8")
|
||||
|
||||
return litellm.completion(
|
||||
|
|
|
|||
13
cognee/infrastructure/utils/run_async.py
Normal file
13
cognee/infrastructure/utils/run_async.py
Normal file
|
|
@ -0,0 +1,13 @@
|
|||
import asyncio
|
||||
from functools import partial
|
||||
|
||||
|
||||
async def run_async(func, *args, loop=None, executor=None, **kwargs):
|
||||
if loop is None:
|
||||
try:
|
||||
running_loop = asyncio.get_running_loop()
|
||||
except RuntimeError:
|
||||
running_loop = asyncio.get_event_loop()
|
||||
|
||||
pfunc = partial(func, *args, **kwargs)
|
||||
return await running_loop.run_in_executor(executor, pfunc)
|
||||
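run_async offloads a blocking callable to the running loop's executor so synchronous I/O (such as the s3fs calls above) does not stall the event loop. A small illustrative example:

import asyncio
import time
from cognee.infrastructure.utils.run_async import run_async

def slow_read(path: str) -> str:
    time.sleep(1)  # stands in for blocking file or network I/O
    return f"contents of {path}"

async def main():
    result = await run_async(slow_read, "/tmp/example.txt")  # placeholder path
    print(result)

asyncio.run(main())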
25
cognee/infrastructure/utils/run_sync.py
Normal file
25
cognee/infrastructure/utils/run_sync.py
Normal file
|
|
@ -0,0 +1,25 @@
|
|||
import asyncio
|
||||
import threading
|
||||
|
||||
|
||||
def run_sync(coro, timeout=None):
|
||||
result = None
|
||||
exception = None
|
||||
|
||||
def runner():
|
||||
nonlocal result, exception
|
||||
try:
|
||||
result = asyncio.run(coro)
|
||||
except Exception as e:
|
||||
exception = e
|
||||
|
||||
thread = threading.Thread(target=runner)
|
||||
thread.start()
|
||||
thread.join(timeout)
|
||||
|
||||
if thread.is_alive():
|
||||
raise asyncio.TimeoutError("Coroutine execution timed out.")
|
||||
if exception:
|
||||
raise exception
|
||||
|
||||
return result
|
||||
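run_sync is the inverse bridge: it runs a coroutine to completion on a fresh event loop in a worker thread, so async helpers can be called from synchronous code (as BinaryData.get_metadata does later in this change). An illustrative example:

import asyncio
from cognee.infrastructure.utils.run_sync import run_sync

async def fetch_value() -> int:
    await asyncio.sleep(0.1)
    return 42

value = run_sync(fetch_value(), timeout=5)  # blocks the calling thread, not an event loop
print(value)  # 42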
|
|
@ -33,8 +33,8 @@ class LangchainChunker(Chunker):
|
|||
length_function=lambda text: len(text.split()),
|
||||
)
|
||||
|
||||
def read(self):
|
||||
for content_text in self.get_text():
|
||||
async def read(self):
|
||||
async for content_text in self.get_text():
|
||||
for chunk in self.splitter.split_text(content_text):
|
||||
embedding_engine = get_vector_engine().embedding_engine
|
||||
token_count = embedding_engine.tokenizer.count_tokens(chunk)
|
||||
|
|
|
|||
|
|
@ -9,9 +9,9 @@ logger = get_logger()
|
|||
|
||||
|
||||
class TextChunker(Chunker):
|
||||
def read(self):
|
||||
async def read(self):
|
||||
paragraph_chunks = []
|
||||
for content_text in self.get_text():
|
||||
async for content_text in self.get_text():
|
||||
for chunk_data in chunk_by_paragraph(
|
||||
content_text,
|
||||
self.max_chunk_size,
|
||||
|
|
|
|||
|
|
@ -1,8 +1,7 @@
|
|||
from cognee.base_config import get_base_config
|
||||
from cognee.infrastructure.files.storage import LocalStorage
|
||||
from cognee.infrastructure.files.storage import get_file_storage, get_storage_config
|
||||
|
||||
|
||||
async def prune_data():
|
||||
base_config = get_base_config()
|
||||
data_root_directory = base_config.data_root_directory
|
||||
LocalStorage.remove_all(data_root_directory)
|
||||
storage_config = get_storage_config()
|
||||
data_root_directory = storage_config["data_root_directory"]
|
||||
await get_file_storage(data_root_directory).remove_all()
|
||||
|
|
|
|||
|
|
@ -7,15 +7,16 @@ from .Document import Document
|
|||
class AudioDocument(Document):
|
||||
type: str = "audio"
|
||||
|
||||
def create_transcript(self):
|
||||
result = get_llm_client().create_transcript(self.raw_data_location)
|
||||
async def create_transcript(self):
|
||||
result = await get_llm_client().create_transcript(self.raw_data_location)
|
||||
return result.text
|
||||
|
||||
def read(self, chunker_cls: Chunker, max_chunk_size: int):
|
||||
# Transcribe the audio file
|
||||
async def read(self, chunker_cls: Chunker, max_chunk_size: int):
|
||||
async def get_text():
|
||||
# Transcribe the audio file
|
||||
yield await self.create_transcript()
|
||||
|
||||
text = self.create_transcript()
|
||||
chunker = chunker_cls(self, max_chunk_size=max_chunk_size, get_text=get_text)
|
||||
|
||||
chunker = chunker_cls(self, max_chunk_size=max_chunk_size, get_text=lambda: [text])
|
||||
|
||||
yield from chunker.read()
|
||||
async for chunk in chunker.read():
|
||||
yield chunk
|
||||
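Document.read is now an async generator, so callers iterate with async for instead of a plain for loop. A hedged sketch of consuming it (the document instance and chunk size are placeholders):

from cognee.modules.chunking.TextChunker import TextChunker

async def collect_chunks(document):
    chunks = []
    async for chunk in document.read(chunker_cls=TextChunker, max_chunk_size=1024):
        chunks.append(chunk)
    return chunks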
|
|
|
|||
|
|
@ -10,5 +10,5 @@ class Document(DataPoint):
|
|||
mime_type: str
|
||||
metadata: dict = {"index_fields": ["name"]}
|
||||
|
||||
def read(self, chunker_cls: Chunker, max_chunk_size: int) -> str:
|
||||
async def read(self, chunker_cls: Chunker, max_chunk_size: int) -> str:
|
||||
pass
|
||||
|
|
|
|||
|
|
@ -7,14 +7,16 @@ from .Document import Document
|
|||
class ImageDocument(Document):
|
||||
type: str = "image"
|
||||
|
||||
def transcribe_image(self):
|
||||
result = get_llm_client().transcribe_image(self.raw_data_location)
|
||||
async def transcribe_image(self):
|
||||
result = await get_llm_client().transcribe_image(self.raw_data_location)
|
||||
return result.choices[0].message.content
|
||||
|
||||
def read(self, chunker_cls: Chunker, max_chunk_size: int):
|
||||
# Transcribe the image file
|
||||
text = self.transcribe_image()
|
||||
async def read(self, chunker_cls: Chunker, max_chunk_size: int):
|
||||
async def get_text():
|
||||
# Transcribe the image file
|
||||
yield await self.transcribe_image()
|
||||
|
||||
chunker = chunker_cls(self, get_text=lambda: [text], max_chunk_size=max_chunk_size)
|
||||
chunker = chunker_cls(self, get_text=get_text, max_chunk_size=max_chunk_size)
|
||||
|
||||
yield from chunker.read()
|
||||
async for chunk in chunker.read():
|
||||
yield chunk
|
||||
|
|
|
|||
|
|
@ -1,25 +1,27 @@
|
|||
from pypdf import PdfReader
|
||||
from cognee.modules.chunking.Chunker import Chunker
|
||||
from .open_data_file import open_data_file
|
||||
from .Document import Document
|
||||
from cognee.infrastructure.files.utils.open_data_file import open_data_file
|
||||
from cognee.shared.logging_utils import get_logger
|
||||
from cognee.modules.data.processing.document_types.exceptions.exceptions import PyPdfInternalError
|
||||
|
||||
from .Document import Document
|
||||
|
||||
logger = get_logger("PDFDocument")
|
||||
|
||||
|
||||
class PdfDocument(Document):
|
||||
type: str = "pdf"
|
||||
|
||||
def read(self, chunker_cls: Chunker, max_chunk_size: int):
|
||||
with open_data_file(self.raw_data_location, mode="rb") as stream:
|
||||
async def read(self, chunker_cls: Chunker, max_chunk_size: int):
|
||||
async with open_data_file(self.raw_data_location, mode="rb") as stream:
|
||||
logger.info(f"Reading PDF:{self.raw_data_location}")
|
||||
|
||||
try:
|
||||
file = PdfReader(stream, strict=False)
|
||||
except Exception:
|
||||
raise PyPdfInternalError()
|
||||
|
||||
def get_text():
|
||||
async def get_text():
|
||||
try:
|
||||
for page in file.pages:
|
||||
page_text = page.extract_text()
|
||||
|
|
@ -29,4 +31,5 @@ class PdfDocument(Document):
|
|||
|
||||
chunker = chunker_cls(self, get_text=get_text, max_chunk_size=max_chunk_size)
|
||||
|
||||
yield from chunker.read()
|
||||
async for chunk in chunker.read():
|
||||
yield chunk
|
||||
|
|
|
|||
|
|
@ -1,15 +1,15 @@
|
|||
from .Document import Document
|
||||
from cognee.modules.chunking.Chunker import Chunker
|
||||
from .open_data_file import open_data_file
|
||||
from cognee.infrastructure.files.utils.open_data_file import open_data_file
|
||||
from .Document import Document
|
||||
|
||||
|
||||
class TextDocument(Document):
|
||||
type: str = "text"
|
||||
mime_type: str = "text/plain"
|
||||
|
||||
def read(self, chunker_cls: Chunker, max_chunk_size: int):
|
||||
def get_text():
|
||||
with open_data_file(self.raw_data_location, mode="r", encoding="utf-8") as file:
|
||||
async def read(self, chunker_cls: Chunker, max_chunk_size: int):
|
||||
async def get_text():
|
||||
async with open_data_file(self.raw_data_location, mode="r", encoding="utf-8") as file:
|
||||
while True:
|
||||
text = file.read(1000000)
|
||||
if not text.strip():
|
||||
|
|
@ -17,4 +17,6 @@ class TextDocument(Document):
|
|||
yield text
|
||||
|
||||
chunker = chunker_cls(self, max_chunk_size=max_chunk_size, get_text=get_text)
|
||||
yield from chunker.read()
|
||||
|
||||
async for chunk in chunker.read():
|
||||
yield chunk
|
||||
|
|
|
|||
|
|
@ -1,8 +1,9 @@
|
|||
from io import StringIO
|
||||
from typing import Any, AsyncGenerator
|
||||
|
||||
from cognee.modules.chunking.Chunker import Chunker
|
||||
from cognee.modules.data.exceptions import UnstructuredLibraryImportError
|
||||
from cognee.modules.data.processing.document_types.open_data_file import open_data_file
|
||||
from cognee.infrastructure.files.utils.open_data_file import open_data_file
|
||||
|
||||
from .Document import Document
|
||||
|
||||
|
|
@ -10,15 +11,15 @@ from .Document import Document
|
|||
class UnstructuredDocument(Document):
|
||||
type: str = "unstructured"
|
||||
|
||||
def read(self, chunker_cls: Chunker, max_chunk_size: int) -> str:
|
||||
def get_text():
|
||||
async def read(self, chunker_cls: Chunker, max_chunk_size: int) -> AsyncGenerator[Any, Any]:
|
||||
async def get_text():
|
||||
try:
|
||||
from unstructured.partition.auto import partition
|
||||
except ModuleNotFoundError:
|
||||
raise UnstructuredLibraryImportError
|
||||
|
||||
if self.raw_data_location.startswith("s3://"):
|
||||
with open_data_file(self.raw_data_location, mode="rb") as f:
|
||||
async with open_data_file(self.raw_data_location, mode="rb") as f:
|
||||
elements = partition(file=f, content_type=self.mime_type)
|
||||
else:
|
||||
elements = partition(self.raw_data_location, content_type=self.mime_type)
|
||||
|
|
@ -34,4 +35,5 @@ class UnstructuredDocument(Document):
|
|||
|
||||
chunker = chunker_cls(self, get_text=get_text, max_chunk_size=max_chunk_size)
|
||||
|
||||
yield from chunker.read()
|
||||
async for chunk in chunker.read():
|
||||
yield chunk
|
||||
|
|
|
|||
|
|
@ -1,41 +0,0 @@
|
|||
from typing import IO, Optional
|
||||
from urllib.parse import urlparse
|
||||
import os
|
||||
from cognee.api.v1.add.config import get_s3_config
|
||||
|
||||
|
||||
def open_data_file(
|
||||
file_path: str, mode: str = "rb", encoding: Optional[str] = None, **kwargs
|
||||
) -> IO:
|
||||
if file_path.startswith("s3://"):
|
||||
s3_config = get_s3_config()
|
||||
if s3_config.aws_access_key_id is not None and s3_config.aws_secret_access_key is not None:
|
||||
import s3fs
|
||||
|
||||
fs = s3fs.S3FileSystem(
|
||||
key=s3_config.aws_access_key_id, secret=s3_config.aws_secret_access_key, anon=False
|
||||
)
|
||||
else:
|
||||
raise ValueError("S3 credentials are not set in the configuration.")
|
||||
|
||||
if "b" in mode:
|
||||
f = fs.open(file_path, mode=mode, **kwargs)
|
||||
if not hasattr(f, "name") or not f.name:
|
||||
f.name = file_path.split("/")[-1]
|
||||
return f
|
||||
else:
|
||||
return fs.open(file_path, mode=mode, encoding=encoding, **kwargs)
|
||||
elif file_path.startswith("file://"):
|
||||
# Handle local file URLs by properly parsing the URI
|
||||
parsed_url = urlparse(file_path)
|
||||
# On Windows, urlparse handles drive letters correctly
|
||||
# Convert the path component to a proper file path
|
||||
if os.name == "nt": # Windows
|
||||
# Remove leading slash from Windows paths like /C:/Users/...
|
||||
local_path = parsed_url.path.lstrip("/")
|
||||
else: # Unix-like systems
|
||||
local_path = parsed_url.path
|
||||
|
||||
return open(local_path, mode=mode, encoding=encoding, **kwargs)
|
||||
else:
|
||||
return open(file_path, mode=mode, encoding=encoding, **kwargs)
|
||||
|
|
@ -1,29 +1,29 @@
|
|||
from os import path
|
||||
from io import BufferedReader
|
||||
from typing import Union, BinaryIO, Optional, Any
|
||||
from .data_types import TextData, BinaryData
|
||||
from typing import Union, BinaryIO
|
||||
from tempfile import SpooledTemporaryFile
|
||||
|
||||
from cognee.modules.ingestion.exceptions import IngestionError
|
||||
|
||||
try:
|
||||
from s3fs.core import S3File
|
||||
from cognee.modules.ingestion.data_types.S3BinaryData import S3BinaryData
|
||||
except ImportError:
|
||||
S3File = None
|
||||
S3BinaryData = None
|
||||
from .data_types import TextData, BinaryData, S3BinaryData
|
||||
|
||||
|
||||
def classify(data: Union[str, BinaryIO], filename: str = None, s3fs: Optional[Any] = None):
|
||||
def classify(data: Union[str, BinaryIO], filename: str = None):
|
||||
if isinstance(data, str):
|
||||
return TextData(data)
|
||||
|
||||
if isinstance(data, BufferedReader) or isinstance(data, SpooledTemporaryFile):
|
||||
return BinaryData(data, str(data.name).split("/")[-1] if data.name else filename)
|
||||
return BinaryData(
|
||||
data, str(data.name).split("/")[-1] if hasattr(data, "name") else filename
|
||||
)
|
||||
|
||||
try:
|
||||
from s3fs import S3File
|
||||
except ImportError:
|
||||
S3File = None
|
||||
|
||||
if S3File is not None:
|
||||
if isinstance(data, S3File):
|
||||
derived_filename = str(data.full_name).split("/")[-1] if data.full_name else filename
|
||||
return S3BinaryData(s3_path=data.full_name, name=derived_filename, s3=s3fs)
|
||||
return S3BinaryData(s3_path=path.join("s3://", data.bucket, data.key), name=data.key)
|
||||
|
||||
raise IngestionError(
|
||||
message=f"Type of data sent to classify(data: Union[str, BinaryIO) not supported or s3fs is not installed: {type(data)}"
|
||||
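classify routes raw input to the matching IngestionData wrapper: TextData for strings, BinaryData for buffered readers and spooled temporary files, S3BinaryData for s3fs file objects. A hedged sketch (the module path and file name are assumptions):

from cognee.modules.ingestion.classify import classify  # assumed module path

text_data = classify("plain text payload")                 # -> TextData
with open("/tmp/example.pdf", "rb") as stream:             # placeholder file
    binary_data = classify(stream, filename="example.pdf") # -> BinaryData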
|
|
|
|||
|
|
@ -1,5 +1,7 @@
|
|||
from typing import BinaryIO
|
||||
from contextlib import asynccontextmanager
|
||||
from cognee.infrastructure.files import get_file_metadata, FileMetadata
|
||||
from cognee.infrastructure.utils.run_sync import run_sync
|
||||
from .IngestionData import IngestionData
|
||||
|
||||
|
||||
|
|
@ -22,16 +24,17 @@ class BinaryData(IngestionData):
|
|||
return metadata["content_hash"]
|
||||
|
||||
def get_metadata(self):
|
||||
self.ensure_metadata()
|
||||
run_sync(self.ensure_metadata())
|
||||
|
||||
return self.metadata
|
||||
|
||||
def ensure_metadata(self):
|
||||
async def ensure_metadata(self):
|
||||
if self.metadata is None:
|
||||
self.metadata = get_file_metadata(self.data)
|
||||
self.metadata = await get_file_metadata(self.data)
|
||||
|
||||
if self.metadata["name"] is None:
|
||||
self.metadata["name"] = self.name
|
||||
|
||||
def get_data(self):
|
||||
return self.data
|
||||
@asynccontextmanager
|
||||
async def get_data(self):
|
||||
yield self.data
|
||||
|
|
|
|||
|
|
@ -1,42 +1,55 @@
|
|||
import os
|
||||
from typing import Optional
|
||||
import s3fs
|
||||
from contextlib import asynccontextmanager
|
||||
from cognee.infrastructure.files import get_file_metadata, FileMetadata
|
||||
from cognee.infrastructure.utils import run_sync
|
||||
from .IngestionData import IngestionData
|
||||
|
||||
|
||||
def create_s3_binary_data(
|
||||
s3_path: str, name: Optional[str] = None, s3: Optional[s3fs.S3FileSystem] = None
|
||||
) -> "S3BinaryData":
|
||||
return S3BinaryData(s3_path, name=name, s3=s3)
|
||||
def create_s3_binary_data(s3_path: str, name: Optional[str] = None) -> "S3BinaryData":
|
||||
return S3BinaryData(s3_path, name=name)
|
||||
|
||||
|
||||
class S3BinaryData(IngestionData):
|
||||
name: Optional[str] = None
|
||||
s3_path: str = None
|
||||
fs: s3fs.S3FileSystem = None
|
||||
metadata: Optional[FileMetadata] = None
|
||||
|
||||
def __init__(
|
||||
self, s3_path: str, name: Optional[str] = None, s3: Optional[s3fs.S3FileSystem] = None
|
||||
):
|
||||
def __init__(self, s3_path: str, name: Optional[str] = None):
|
||||
self.s3_path = s3_path
|
||||
self.name = name
|
||||
self.fs = s3 if s3 is not None else s3fs.S3FileSystem()
|
||||
|
||||
def get_identifier(self):
|
||||
metadata = self.get_metadata()
|
||||
return metadata["content_hash"]
|
||||
|
||||
def get_metadata(self):
|
||||
self.ensure_metadata()
|
||||
run_sync(self.ensure_metadata())
|
||||
return self.metadata
|
||||
|
||||
def ensure_metadata(self):
|
||||
async def ensure_metadata(self):
|
||||
if self.metadata is None:
|
||||
with self.fs.open(self.s3_path, "rb") as f:
|
||||
self.metadata = get_file_metadata(f)
|
||||
if self.metadata.get("name") is None:
|
||||
self.metadata["name"] = self.name or self.s3_path.split("/")[-1]
|
||||
from cognee.infrastructure.files.storage.S3FileStorage import S3FileStorage
|
||||
|
||||
def get_data(self):
|
||||
return self.fs.open(self.s3_path, "rb")
|
||||
file_dir_path = os.path.dirname(self.s3_path)
|
||||
file_path = os.path.basename(self.s3_path)
|
||||
|
||||
file_storage = S3FileStorage(file_dir_path)
|
||||
|
||||
async with file_storage.open(file_path, "rb") as file:
|
||||
self.metadata = await get_file_metadata(file)
|
||||
|
||||
if self.metadata.get("name") is None:
|
||||
self.metadata["name"] = self.name or file_path
|
||||
|
||||
@asynccontextmanager
|
||||
async def get_data(self):
|
||||
from cognee.infrastructure.files.storage.S3FileStorage import S3FileStorage
|
||||
|
||||
file_dir_path = os.path.dirname(self.s3_path)
|
||||
file_path = os.path.basename(self.s3_path)
|
||||
|
||||
file_storage = S3FileStorage(file_dir_path)
|
||||
|
||||
async with file_storage.open(file_path, "rb") as file:
|
||||
yield file
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
from typing import BinaryIO
|
||||
from contextlib import asynccontextmanager
|
||||
from cognee.infrastructure.data.utils.extract_keywords import extract_keywords
|
||||
from .IngestionData import IngestionData
|
||||
|
||||
|
|
@ -28,5 +29,6 @@ class TextData(IngestionData):
|
|||
if self.metadata is None:
|
||||
self.metadata = {}
|
||||
|
||||
def get_data(self):
|
||||
return self.data
|
||||
@asynccontextmanager
|
||||
async def get_data(self):
|
||||
yield self.data
|
||||
|
|
|
|||
|
|
@ -1,3 +1,4 @@
|
|||
from .TextData import TextData, create_text_data
|
||||
from .BinaryData import BinaryData, create_binary_data
|
||||
from .S3BinaryData import S3BinaryData, create_s3_binary_data
|
||||
from .IngestionData import IngestionData
|
||||
|
|
|
|||
|
|
@ -1,29 +1,28 @@
|
|||
import os.path
|
||||
import hashlib
|
||||
from typing import BinaryIO, Union
|
||||
from cognee.base_config import get_base_config
|
||||
from cognee.infrastructure.files.storage import LocalStorage
|
||||
from cognee.infrastructure.files.storage import get_file_storage, get_storage_config
|
||||
from .classify import classify
|
||||
|
||||
|
||||
def save_data_to_file(data: Union[str, BinaryIO], filename: str = None):
|
||||
base_config = get_base_config()
|
||||
data_directory_path = base_config.data_root_directory
|
||||
async def save_data_to_file(data: Union[str, BinaryIO], filename: str = None):
|
||||
storage_config = get_storage_config()
|
||||
|
||||
data_root_directory = storage_config["data_root_directory"]
|
||||
|
||||
classified_data = classify(data, filename)
|
||||
|
||||
storage_path = os.path.join(data_directory_path, "data")
|
||||
LocalStorage.ensure_directory_exists(storage_path)
|
||||
|
||||
file_metadata = classified_data.get_metadata()
|
||||
if "name" not in file_metadata or file_metadata["name"] is None:
|
||||
data_contents = classified_data.get_data().encode("utf-8")
|
||||
hash_contents = hashlib.md5(data_contents).hexdigest()
|
||||
file_metadata["name"] = "text_" + hash_contents + ".txt"
|
||||
file_name = file_metadata["name"]
|
||||
|
||||
# Don't save file if it already exists
|
||||
if not os.path.isfile(os.path.join(storage_path, file_name)):
|
||||
LocalStorage(storage_path).store(file_name, classified_data.get_data())
|
||||
async with classified_data.get_data() as data:
|
||||
if "name" not in file_metadata or file_metadata["name"] is None:
|
||||
data_contents = data.encode("utf-8")
|
||||
hash_contents = hashlib.md5(data_contents).hexdigest()
|
||||
file_metadata["name"] = "text_" + hash_contents + ".txt"
|
||||
|
||||
return "file://" + storage_path + "/" + file_name
|
||||
file_name = file_metadata["name"]
|
||||
|
||||
storage = get_file_storage(data_root_directory)
|
||||
|
||||
full_file_path = await storage.store(file_name, data)
|
||||
|
||||
return full_file_path
|
||||
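save_data_to_file now resolves the data root from get_storage_config, names unnamed text content after its md5 hash, and writes through whichever backend get_file_storage selects. A hedged sketch (the module path is an assumption; the returned URI depends on the configured backend):

import asyncio
from cognee.modules.ingestion.save_data_to_file import save_data_to_file  # assumed module path

async def main():
    stored_path = await save_data_to_file("hello cognee")
    print(stored_path)  # e.g. file://.../text_<md5>.txt, or s3://... when S3 is configured

asyncio.run(main())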
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
import asyncio
|
||||
from uuid import UUID
|
||||
from typing import Union
|
||||
from uuid import NAMESPACE_OID, uuid5, UUID
|
||||
|
||||
from cognee.shared.logging_utils import get_logger
|
||||
from cognee.modules.data.methods.get_dataset_data import get_dataset_data
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@ from uuid import UUID
|
|||
from typing import Any
|
||||
from functools import wraps
|
||||
|
||||
from cognee.infrastructure.databases.graph import get_graph_engine
|
||||
from cognee.infrastructure.databases.relational import get_relational_engine
|
||||
from cognee.modules.pipelines.operations.run_tasks_distributed import run_tasks_distributed
|
||||
from cognee.modules.users.models import User
|
||||
|
|
@ -104,6 +105,14 @@ async def run_tasks(
|
|||
dataset_name=dataset.name,
|
||||
)
|
||||
|
||||
graph_engine = await get_graph_engine()
|
||||
if hasattr(graph_engine, "push_to_s3"):
|
||||
await graph_engine.push_to_s3()
|
||||
|
||||
relational_engine = get_relational_engine()
|
||||
if hasattr(relational_engine, "push_to_s3"):
|
||||
await relational_engine.push_to_s3()
|
||||
|
||||
except Exception as error:
|
||||
await log_pipeline_run_error(
|
||||
pipeline_run_id, pipeline_id, pipeline_name, dataset_id, data, error
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@ import json
|
|||
import networkx
|
||||
|
||||
from cognee.shared.logging_utils import get_logger
|
||||
from cognee.infrastructure.files.storage import LocalStorage
|
||||
from cognee.infrastructure.files.storage.LocalFileStorage import LocalFileStorage
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
|
|
@ -308,10 +308,12 @@ async def cognee_network_visualization(graph_data, destination_file_path: str =
|
|||
home_dir = os.path.expanduser("~")
|
||||
destination_file_path = os.path.join(home_dir, "graph_visualization.html")
|
||||
|
||||
LocalStorage.ensure_directory_exists(os.path.dirname(destination_file_path))
|
||||
dir_path = os.path.dirname(destination_file_path)
|
||||
file_path = os.path.basename(destination_file_path)
|
||||
|
||||
with open(destination_file_path, "w") as f:
|
||||
f.write(html_content)
|
||||
file_storage = LocalFileStorage(dir_path)
|
||||
|
||||
file_storage.store(file_path, html_content, overwrite=True)
|
||||
|
||||
logger.info(f"Graph visualization saved as {destination_file_path}")
|
||||
|
||||
|
|
|
|||
|
|
@ -1,10 +1,7 @@
|
|||
"""This module contains utility functions for the cognee."""
|
||||
|
||||
import os
|
||||
from typing import BinaryIO, Union
|
||||
|
||||
import requests
|
||||
import hashlib
|
||||
from datetime import datetime, timezone
|
||||
import graphistry
|
||||
import networkx as nx
|
||||
|
|
@ -13,14 +10,12 @@ import matplotlib.pyplot as plt
|
|||
import http.server
|
||||
import socketserver
|
||||
from threading import Thread
|
||||
import sys
|
||||
import pathlib
|
||||
from uuid import uuid4
|
||||
|
||||
from cognee.base_config import get_base_config
|
||||
from cognee.infrastructure.databases.graph import get_graph_engine
|
||||
|
||||
from uuid import uuid4
|
||||
import pathlib
|
||||
from cognee.shared.exceptions import IngestionError
|
||||
|
||||
# Analytics Proxy Url, currently hosted by Vercel
|
||||
proxy_url = "https://test.prometh.ai"
|
||||
|
|
@ -102,182 +97,6 @@ def send_telemetry(event_name: str, user_id, additional_properties: dict = {}):
|
|||
print(f"Error sending telemetry through proxy: {response.status_code}")
|
||||
|
||||
|
||||
def get_file_content_hash(file_obj: Union[str, BinaryIO]) -> str:
|
||||
h = hashlib.md5()
|
||||
|
||||
try:
|
||||
if isinstance(file_obj, str):
|
||||
with open(file_obj, "rb") as file:
|
||||
while True:
|
||||
# Reading is buffered, so we can read smaller chunks.
|
||||
chunk = file.read(h.block_size)
|
||||
if not chunk:
|
||||
break
|
||||
h.update(chunk)
|
||||
else:
|
||||
while True:
|
||||
# Reading is buffered, so we can read smaller chunks.
|
||||
chunk = file_obj.read(h.block_size)
|
||||
if not chunk:
|
||||
break
|
||||
h.update(chunk)
|
||||
|
||||
return h.hexdigest()
|
||||
except IOError as e:
|
||||
raise IngestionError(message=f"Failed to load data from {file}: {e}")
|
||||
|
||||
|
||||
def generate_color_palette(unique_layers):
|
||||
colormap = plt.cm.get_cmap("viridis", len(unique_layers))
|
||||
    colors = [colormap(i) for i in range(len(unique_layers))]
    hex_colors = [
        "#%02x%02x%02x" % (int(rgb[0] * 255), int(rgb[1] * 255), int(rgb[2] * 255))
        for rgb in colors
    ]

    return dict(zip(unique_layers, hex_colors))


async def register_graphistry():
    config = get_base_config()
    graphistry.register(
        api=3, username=config.graphistry_username, password=config.graphistry_password
    )


def prepare_edges(graph, source, target, edge_key):
    edge_list = [
        {
            source: str(edge[0]),
            target: str(edge[1]),
            edge_key: str(edge[2]),
        }
        for edge in graph.edges(keys=True, data=True)
    ]

    return pd.DataFrame(edge_list)


def prepare_nodes(graph, include_size=False):
    nodes_data = []
    for node in graph.nodes:
        node_info = graph.nodes[node]

        if not node_info:
            continue

        node_data = {
            **node_info,
            "id": str(node),
            "name": node_info["name"] if "name" in node_info else str(node),
        }

        if include_size:
            default_size = 10  # Default node size
            larger_size = 20  # Size for nodes with specific keywords in their ID
            keywords = ["DOCUMENT", "User"]
            node_size = (
                larger_size if any(keyword in str(node) for keyword in keywords) else default_size
            )
            node_data["size"] = node_size

        nodes_data.append(node_data)

    return pd.DataFrame(nodes_data)


async def render_graph(
    graph=None, include_nodes=True, include_color=False, include_size=False, include_labels=True
):
    await register_graphistry()

    if not isinstance(graph, nx.MultiDiGraph):
        graph_engine = await get_graph_engine()
        networkx_graph = nx.MultiDiGraph()

        (nodes, edges) = await graph_engine.get_graph_data()

        networkx_graph.add_nodes_from(nodes)
        networkx_graph.add_edges_from(edges)

        graph = networkx_graph

    edges = prepare_edges(graph, "source_node", "target_node", "relationship_name")
    plotter = graphistry.edges(edges, "source_node", "target_node")
    plotter = plotter.bind(edge_label="relationship_name")

    if include_nodes:
        nodes = prepare_nodes(graph, include_size=include_size)
        plotter = plotter.nodes(nodes, "id")

        if include_size:
            plotter = plotter.bind(point_size="size")

        if include_color:
            pass
            # unique_layers = nodes["layer_description"].unique()
            # color_palette = generate_color_palette(unique_layers)
            # plotter = plotter.encode_point_color("layer_description", categorical_mapping=color_palette,
            #                                      default_mapping="silver")

        if include_labels:
            plotter = plotter.bind(point_label="name")

    # Visualization
    url = plotter.plot(render=False, as_files=True, memoize=False)
    print(f"Graph is visualized at: {url}")
    return url


# def sanitize_df(df):
#     """Replace NaNs and infinities in a DataFrame with None, making it JSON compliant."""
#     return df.replace([np.inf, -np.inf, np.nan], None)


async def convert_to_serializable_graph(G):
    """
    Convert a graph into a serializable format with stringified node and edge attributes.
    """
    (nodes, edges) = G

    networkx_graph = nx.MultiDiGraph()
    networkx_graph.add_nodes_from(nodes)
    networkx_graph.add_edges_from(edges)

    # Create a new graph to store the serializable version
    new_G = nx.MultiDiGraph()

    # Serialize nodes
    for node, data in networkx_graph.nodes(data=True):
        serializable_data = {k: str(v) for k, v in data.items()}
        new_G.add_node(str(node), **serializable_data)

    # Serialize edges
    for u, v, data in networkx_graph.edges(data=True):
        serializable_data = {k: str(v) for k, v in data.items()}
        new_G.add_edge(str(u), str(v), **serializable_data)

    return new_G


def generate_layout_positions(G, layout_func, layout_scale):
    """
    Generate layout positions for the graph using the specified layout function.
    """
    positions = layout_func(G)
    return {str(node): (x * layout_scale, y * layout_scale) for node, (x, y) in positions.items()}


def assign_node_colors(G, node_attribute, palette):
    """
    Assign colors to nodes based on a specified attribute and a given palette.
    """
    unique_attrs = set(G.nodes[node].get(node_attribute, "Unknown") for node in G.nodes)
    color_map = {attr: palette[i % len(palette)] for i, attr in enumerate(unique_attrs)}
    return [color_map[G.nodes[node].get(node_attribute, "Unknown")] for node in G.nodes], color_map


def embed_logo(p, layout_scale, logo_alpha, position):
    """
    Embed a logo into the graph visualization as a watermark.

@ -307,18 +126,6 @@ def embed_logo(p, layout_scale, logo_alpha, position):
    )


def graph_to_tuple(graph):
    """
    Converts a networkx graph to a tuple of (nodes, edges).

    :param graph: A networkx graph.
    :return: A tuple (nodes, edges).
    """
    nodes = list(graph.nodes(data=True))  # Get nodes with attributes
    edges = list(graph.edges(data=True))  # Get edges with attributes
    return (nodes, edges)


def start_visualization_server(
    host="0.0.0.0", port=8001, handler_class=http.server.SimpleHTTPRequestHandler
):
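The block above is the graph-visualization helper code from cognee/shared/utils.py as rendered in this diff. A minimal illustrative sketch of the two pure helpers on a tiny in-memory graph, assuming they remain importable from cognee.shared.utils as they were before this change (the sketch itself is not part of the diff):

```python
import networkx as nx

from cognee.shared.utils import prepare_edges, prepare_nodes  # import path as used in the existing tests

graph = nx.MultiDiGraph()
graph.add_node("User_1", name="Alice")
graph.add_node("DOCUMENT_1", name="Report")
graph.add_edge("User_1", "DOCUMENT_1", key="owns")

edges_df = prepare_edges(graph, "source_node", "target_node", "relationship_name")
nodes_df = prepare_nodes(graph, include_size=True)

print(edges_df)  # one row: User_1 -> DOCUMENT_1 with relationship key "owns"
print(nodes_df)  # both nodes get size 20 because their ids contain "User" / "DOCUMENT"
```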
|
||||
|
|
|
|||
|
|
@ -1,11 +1,11 @@
|
|||
from uuid import UUID
|
||||
from sqlalchemy import select
|
||||
from typing import AsyncGenerator
|
||||
|
||||
from cognee.shared.logging_utils import get_logger
|
||||
from cognee.modules.data.processing.document_types.Document import Document
|
||||
from sqlalchemy import select
|
||||
from cognee.modules.data.models import Data
|
||||
from cognee.infrastructure.databases.relational import get_relational_engine
|
||||
from uuid import UUID
|
||||
from cognee.modules.chunking.TextChunker import TextChunker
|
||||
from cognee.modules.chunking.Chunker import Chunker
|
||||
from cognee.modules.data.processing.document_types.exceptions.exceptions import PyPdfInternalError
|
||||
|
|
@ -41,7 +41,9 @@ async def extract_chunks_from_documents(
      for document in documents:
          document_token_count = 0
          try:
-             for document_chunk in document.read(max_chunk_size=max_chunk_size, chunker_cls=chunker):
+             async for document_chunk in document.read(
+                 max_chunk_size=max_chunk_size, chunker_cls=chunker
+             ):
                  document_token_count += document_chunk.chunk_size
                  document_chunk.belongs_to_set = document.belongs_to_set
                  yield document_chunk
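The hunk above swaps a plain `for` for `async for` because `document.read(...)` now yields chunks asynchronously. A hedged consumer sketch of the same pattern (function and variable names are illustrative, not from the diff):

```python
async def count_tokens(documents, chunker, max_chunk_size=1024):
    # document.read is now an async generator, so it must be consumed with
    # `async for` from inside a coroutine.
    total = 0
    for document in documents:
        async for chunk in document.read(max_chunk_size=max_chunk_size, chunker_cls=chunker):
            total += chunk.chunk_size
    return total
```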
|
||||
|
|
|
|||
|
|
@ -13,7 +13,7 @@ import aiofiles
|
|||
import pandas as pd
|
||||
from pydantic import BaseModel
|
||||
|
||||
from cognee.modules.graph.exceptions import EntityNotFoundError, EntityAlreadyExistsError
|
||||
from cognee.modules.graph.exceptions import EntityNotFoundError
|
||||
from cognee.modules.ingestion.exceptions import IngestionError
|
||||
from cognee.infrastructure.llm.prompts import read_query_prompt
|
||||
from cognee.infrastructure.llm.get_llm_client import get_llm_client
|
||||
|
|
|
|||
|
|
@ -1,9 +1,11 @@
|
|||
import json
|
||||
import inspect
|
||||
from os import path
|
||||
from uuid import UUID
|
||||
from typing import Union, BinaryIO, Any, List, Optional
|
||||
|
||||
import cognee.modules.ingestion as ingestion
|
||||
from cognee.infrastructure.files.utils.open_data_file import open_data_file
|
||||
from cognee.infrastructure.databases.relational import get_relational_engine
|
||||
from cognee.modules.data.models import Data
|
||||
from cognee.modules.users.models import User
|
||||
|
|
@ -18,9 +20,6 @@ from cognee.modules.data.methods import (
|
|||
from .save_data_item_to_storage import save_data_item_to_storage
|
||||
|
||||
|
||||
from cognee.api.v1.add.config import get_s3_config
|
||||
|
||||
|
||||
async def ingest_data(
|
||||
data: Any,
|
||||
dataset_name: str,
|
||||
|
|
@ -31,23 +30,6 @@ async def ingest_data(
      if not user:
          user = await get_default_user()

-     s3_config = get_s3_config()
-
-     fs = None
-     if s3_config.aws_access_key_id is not None and s3_config.aws_secret_access_key is not None:
-         import s3fs
-
-         fs = s3fs.S3FileSystem(
-             key=s3_config.aws_access_key_id, secret=s3_config.aws_secret_access_key, anon=False
-         )
-
-     def open_data_file(file_path: str):
-         if file_path.startswith("s3://"):
-             return fs.open(file_path, mode="rb")
-         else:
-             local_path = file_path.replace("file://", "")
-             return open(local_path, mode="rb")
-
      def get_external_metadata_dict(data_item: Union[BinaryIO, str, Any]) -> dict[str, Any]:
          if hasattr(data_item, "dict") and inspect.ismethod(getattr(data_item, "dict")):
              return {"metadata": data_item.dict(), "origin": str(type(data_item))}
|
||||
|
|
@ -92,11 +74,11 @@ async def ingest_data(
|
|||
dataset_data_map = {str(data.id): True for data in dataset_data}
|
||||
|
||||
for data_item in data:
|
||||
file_path = await save_data_item_to_storage(data_item, dataset_name)
|
||||
file_path = await save_data_item_to_storage(data_item)
|
||||
|
||||
# Ingest data and add metadata
|
||||
with open_data_file(file_path) as file:
|
||||
classified_data = ingestion.classify(file, s3fs=fs)
|
||||
async with open_data_file(file_path) as file:
|
||||
classified_data = ingestion.classify(file)
|
||||
|
||||
# data_id is the hash of file contents + owner id to avoid duplicate data
|
||||
data_id = ingestion.identify(classified_data, user)
|
||||
|
|
|
|||
|
|
@ -3,30 +3,49 @@ from typing import Union, BinaryIO, Any

  from cognee.modules.ingestion.exceptions import IngestionError
  from cognee.modules.ingestion import save_data_to_file
+ from pydantic_settings import BaseSettings, SettingsConfigDict


- async def save_data_item_to_storage(data_item: Union[BinaryIO, str, Any], dataset_name: str) -> str:
+ class SaveDataSettings(BaseSettings):
+     accept_local_file_path: bool = True
+
+     model_config = SettingsConfigDict(env_file=".env", extra="allow")
+
+
+ settings = SaveDataSettings()
+
+
+ async def save_data_item_to_storage(data_item: Union[BinaryIO, str, Any]) -> str:
      if "llama_index" in str(type(data_item)):
          # Dynamic import is used because the llama_index module is optional.
          from .transform_data import get_data_from_llama_index

-         file_path = get_data_from_llama_index(data_item, dataset_name)
+         file_path = await get_data_from_llama_index(data_item)

      # data is a file object coming from upload.
      elif hasattr(data_item, "file"):
-         file_path = save_data_to_file(data_item.file, filename=data_item.filename)
+         file_path = await save_data_to_file(data_item.file, filename=data_item.filename)

      elif isinstance(data_item, str):
-         if data_item.startswith("s3://"):
+         # data is s3 file or local file path
+         if data_item.startswith("s3://") or data_item.startswith("file://"):
              file_path = data_item
-         # data is a file path
-         elif data_item.startswith("file://") or data_item.startswith("/"):
-             if os.getenv("ACCEPT_LOCAL_FILE_PATH", "true").lower() == "false":
+         elif data_item.startswith("/") or (
+             os.name == "nt" and len(data_item) > 1 and data_item[1] == ":"
+         ):
+             # Handle both Unix absolute paths (/path) and Windows absolute paths (C:\path)
+             if settings.accept_local_file_path:
+                 # Normalize path separators before creating file URL
+                 normalized_path = os.path.normpath(data_item)
+                 # Use forward slashes in file URLs for consistency
+                 url_path = normalized_path.replace(os.sep, "/")
+                 file_path = "file://" + url_path
+             else:
                  raise IngestionError(message="Local files are not accepted.")
-             file_path = data_item.replace("file://", "")
          # data is text
          else:
-             file_path = save_data_to_file(data_item)
+             file_path = await save_data_to_file(data_item)
      else:
          raise IngestionError(message=f"Data type not supported: {type(data_item)}")
|
||||
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ from cognee.modules.ingestion import save_data_to_file
|
|||
from typing import Union
|
||||
|
||||
|
||||
def get_data_from_llama_index(data_point: Union[Document, ImageDocument], dataset_name: str) -> str:
|
||||
async def get_data_from_llama_index(data_point: Union[Document, ImageDocument]) -> str:
|
||||
"""
|
||||
Retrieve the file path based on the data point type.
|
||||
|
||||
|
|
@ -17,7 +17,6 @@ def get_data_from_llama_index(data_point: Union[Document, ImageDocument], datase
|
|||
|
||||
- data_point (Union[Document, ImageDocument]): An instance of Document or
|
||||
ImageDocument to extract data from.
|
||||
- dataset_name (str): The name of the dataset associated with the data point.
|
||||
|
||||
Returns:
|
||||
--------
|
||||
|
|
@ -29,11 +28,11 @@ def get_data_from_llama_index(data_point: Union[Document, ImageDocument], datase
|
|||
if isinstance(data_point, Document) and type(data_point) is Document:
|
||||
file_path = data_point.metadata.get("file_path")
|
||||
if file_path is None:
|
||||
file_path = save_data_to_file(data_point.text)
|
||||
file_path = await save_data_to_file(data_point.text)
|
||||
return file_path
|
||||
return file_path
|
||||
elif isinstance(data_point, ImageDocument) and type(data_point) is ImageDocument:
|
||||
if data_point.image_path is None:
|
||||
file_path = save_data_to_file(data_point.text)
|
||||
file_path = await save_data_to_file(data_point.text)
|
||||
return file_path
|
||||
return data_point.image_path
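Since `get_data_from_llama_index` is now a coroutine, callers have to await it. A hedged sketch of the new call pattern (both import paths below are assumptions, and the in-memory Document is purely illustrative):

```python
import asyncio

from llama_index.core import Document  # import path assumed
from cognee.tasks.ingestion.transform_data import get_data_from_llama_index  # module path assumed

async def demo():
    doc = Document(text="In-memory text with no file_path in its metadata")
    # With no backing file, the text is persisted via save_data_to_file and
    # the resulting storage path comes back; note the call is now awaited.
    file_path = await get_data_from_llama_index(doc)
    print(file_path)

asyncio.run(demo())
```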
|
||||
|
|
|
|||
|
|
@ -1,32 +0,0 @@
|
|||
import asyncio
|
||||
from cognee.shared.utils import render_graph
|
||||
from cognee.infrastructure.databases.graph import get_graph_engine
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
async def main():
|
||||
import os
|
||||
import pathlib
|
||||
import cognee
|
||||
|
||||
data_directory_path = str(
|
||||
pathlib.Path(
|
||||
os.path.join(pathlib.Path(__file__).parent, ".data_storage/test_library")
|
||||
).resolve()
|
||||
)
|
||||
cognee.config.data_root_directory(data_directory_path)
|
||||
cognee_directory_path = str(
|
||||
pathlib.Path(
|
||||
os.path.join(pathlib.Path(__file__).parent, ".cognee_system/test_library")
|
||||
).resolve()
|
||||
)
|
||||
cognee.config.system_root_directory(cognee_directory_path)
|
||||
|
||||
graph_client = await get_graph_engine()
|
||||
graph = graph_client.graph
|
||||
|
||||
graph_url = await render_graph(graph)
|
||||
|
||||
print(graph_url)
|
||||
|
||||
asyncio.run(main())
|
||||
|
|
@ -1,8 +1,11 @@
|
|||
import sys
|
||||
import uuid
|
||||
import pytest
|
||||
from unittest.mock import patch
|
||||
|
||||
from cognee.modules.chunking.TextChunker import TextChunker
|
||||
from cognee.modules.data.processing.document_types.AudioDocument import AudioDocument
|
||||
import sys
|
||||
from cognee.tests.integration.documents.async_gen_zip import async_gen_zip
|
||||
|
||||
chunk_by_sentence_module = sys.modules.get("cognee.tasks.chunks.chunk_by_sentence")
|
||||
|
||||
|
|
@ -38,7 +41,8 @@ TEST_TEXT = """
|
|||
@patch.object(
|
||||
chunk_by_sentence_module, "get_embedding_engine", side_effect=mock_get_embedding_engine
|
||||
)
|
||||
def test_AudioDocument(mock_engine):
|
||||
@pytest.mark.asyncio
|
||||
async def test_AudioDocument(mock_engine):
|
||||
document = AudioDocument(
|
||||
id=uuid.uuid4(),
|
||||
name="audio-dummy-test",
|
||||
|
|
@ -47,7 +51,7 @@ def test_AudioDocument(mock_engine):
|
|||
mime_type="",
|
||||
)
|
||||
with patch.object(AudioDocument, "create_transcript", return_value=TEST_TEXT):
|
||||
for ground_truth, paragraph_data in zip(
|
||||
async for ground_truth, paragraph_data in async_gen_zip(
|
||||
GROUND_TRUTH,
|
||||
document.read(chunker_cls=TextChunker, max_chunk_size=64),
|
||||
):
|
||||
|
|
|
|||
|
|
@ -1,9 +1,12 @@
|
|||
import sys
|
||||
import uuid
|
||||
import pytest
|
||||
from unittest.mock import patch
|
||||
|
||||
from cognee.modules.chunking.TextChunker import TextChunker
|
||||
from cognee.modules.data.processing.document_types.ImageDocument import ImageDocument
|
||||
from cognee.tests.integration.documents.AudioDocument_test import mock_get_embedding_engine
|
||||
import sys
|
||||
from cognee.tests.integration.documents.async_gen_zip import async_gen_zip
|
||||
|
||||
chunk_by_sentence_module = sys.modules.get("cognee.tasks.chunks.chunk_by_sentence")
|
||||
|
||||
|
|
@ -21,7 +24,8 @@ The commotion has attracted an audience: a murder of crows has gathered in the l
|
|||
@patch.object(
|
||||
chunk_by_sentence_module, "get_embedding_engine", side_effect=mock_get_embedding_engine
|
||||
)
|
||||
def test_ImageDocument(mock_engine):
|
||||
@pytest.mark.asyncio
|
||||
async def test_ImageDocument(mock_engine):
|
||||
document = ImageDocument(
|
||||
id=uuid.uuid4(),
|
||||
name="image-dummy-test",
|
||||
|
|
@ -30,7 +34,7 @@ def test_ImageDocument(mock_engine):
|
|||
mime_type="",
|
||||
)
|
||||
with patch.object(ImageDocument, "transcribe_image", return_value=TEST_TEXT):
|
||||
for ground_truth, paragraph_data in zip(
|
||||
async for ground_truth, paragraph_data in async_gen_zip(
|
||||
GROUND_TRUTH,
|
||||
document.read(chunker_cls=TextChunker, max_chunk_size=64),
|
||||
):
|
||||
|
|
|
|||
|
|
@ -1,10 +1,13 @@
|
|||
import os
|
||||
import sys
|
||||
import uuid
|
||||
import pytest
|
||||
from unittest.mock import patch
|
||||
|
||||
from cognee.modules.chunking.TextChunker import TextChunker
|
||||
from cognee.modules.data.processing.document_types.PdfDocument import PdfDocument
|
||||
from cognee.tests.integration.documents.AudioDocument_test import mock_get_embedding_engine
|
||||
from unittest.mock import patch
|
||||
import sys
|
||||
from cognee.tests.integration.documents.async_gen_zip import async_gen_zip
|
||||
|
||||
chunk_by_sentence_module = sys.modules.get("cognee.tasks.chunks.chunk_by_sentence")
|
||||
|
||||
|
|
@ -18,7 +21,8 @@ GROUND_TRUTH = [
|
|||
@patch.object(
|
||||
chunk_by_sentence_module, "get_embedding_engine", side_effect=mock_get_embedding_engine
|
||||
)
|
||||
def test_PdfDocument(mock_engine):
|
||||
@pytest.mark.asyncio
|
||||
async def test_PdfDocument(mock_engine):
|
||||
test_file_path = os.path.join(
|
||||
os.sep,
|
||||
*(os.path.dirname(__file__).split(os.sep)[:-2]),
|
||||
|
|
@ -33,7 +37,7 @@ def test_PdfDocument(mock_engine):
|
|||
mime_type="",
|
||||
)
|
||||
|
||||
for ground_truth, paragraph_data in zip(
|
||||
async for ground_truth, paragraph_data in async_gen_zip(
|
||||
GROUND_TRUTH, document.read(chunker_cls=TextChunker, max_chunk_size=1024)
|
||||
):
|
||||
assert ground_truth["word_count"] == paragraph_data.chunk_size, (
|
||||
|
|
|
|||
|
|
@ -1,12 +1,13 @@
|
|||
import os
|
||||
import sys
|
||||
import uuid
|
||||
|
||||
import pytest
|
||||
from unittest.mock import patch
|
||||
|
||||
from cognee.modules.chunking.TextChunker import TextChunker
|
||||
from cognee.modules.data.processing.document_types.TextDocument import TextDocument
|
||||
from unittest.mock import patch
|
||||
from cognee.tests.integration.documents.AudioDocument_test import mock_get_embedding_engine
|
||||
import sys
|
||||
from cognee.tests.integration.documents.async_gen_zip import async_gen_zip
|
||||
|
||||
chunk_by_sentence_module = sys.modules.get("cognee.tasks.chunks.chunk_by_sentence")
|
||||
|
||||
|
|
@ -30,7 +31,8 @@ GROUND_TRUTH = {
|
|||
@patch.object(
|
||||
chunk_by_sentence_module, "get_embedding_engine", side_effect=mock_get_embedding_engine
|
||||
)
|
||||
def test_TextDocument(mock_engine, input_file, chunk_size):
|
||||
@pytest.mark.asyncio
|
||||
async def test_TextDocument(mock_engine, input_file, chunk_size):
|
||||
test_file_path = os.path.join(
|
||||
os.sep,
|
||||
*(os.path.dirname(__file__).split(os.sep)[:-2]),
|
||||
|
|
@ -45,7 +47,7 @@ def test_TextDocument(mock_engine, input_file, chunk_size):
|
|||
mime_type="",
|
||||
)
|
||||
|
||||
for ground_truth, paragraph_data in zip(
|
||||
async for ground_truth, paragraph_data in async_gen_zip(
|
||||
GROUND_TRUTH[input_file],
|
||||
document.read(chunker_cls=TextChunker, max_chunk_size=chunk_size),
|
||||
):
|
||||
|
|
|
|||
|
|
@ -1,10 +1,12 @@
|
|||
import os
|
||||
import sys
|
||||
import uuid
|
||||
import pytest
|
||||
from unittest.mock import patch
|
||||
|
||||
from cognee.modules.chunking.TextChunker import TextChunker
|
||||
from cognee.modules.data.processing.document_types.UnstructuredDocument import UnstructuredDocument
|
||||
from cognee.tests.integration.documents.AudioDocument_test import mock_get_embedding_engine
|
||||
import sys
|
||||
|
||||
chunk_by_sentence_module = sys.modules.get("cognee.tasks.chunks.chunk_by_sentence")
|
||||
|
||||
|
|
@ -12,7 +14,8 @@ chunk_by_sentence_module = sys.modules.get("cognee.tasks.chunks.chunk_by_sentenc
|
|||
@patch.object(
|
||||
chunk_by_sentence_module, "get_embedding_engine", side_effect=mock_get_embedding_engine
|
||||
)
|
||||
def test_UnstructuredDocument(mock_engine):
|
||||
@pytest.mark.asyncio
|
||||
async def test_UnstructuredDocument(mock_engine):
|
||||
# Define file paths of test data
|
||||
pptx_file_path = os.path.join(
|
||||
os.sep,
|
||||
|
|
@ -76,7 +79,7 @@ def test_UnstructuredDocument(mock_engine):
|
|||
)
|
||||
|
||||
# Test PPTX
|
||||
for paragraph_data in pptx_document.read(chunker_cls=TextChunker, max_chunk_size=1024):
|
||||
async for paragraph_data in pptx_document.read(chunker_cls=TextChunker, max_chunk_size=1024):
|
||||
assert 19 == paragraph_data.chunk_size, f" 19 != {paragraph_data.chunk_size = }"
|
||||
assert 104 == len(paragraph_data.text), f" 104 != {len(paragraph_data.text) = }"
|
||||
assert "sentence_cut" == paragraph_data.cut_type, (
|
||||
|
|
@ -84,7 +87,7 @@ def test_UnstructuredDocument(mock_engine):
|
|||
)
|
||||
|
||||
# Test DOCX
|
||||
for paragraph_data in docx_document.read(chunker_cls=TextChunker, max_chunk_size=1024):
|
||||
async for paragraph_data in docx_document.read(chunker_cls=TextChunker, max_chunk_size=1024):
|
||||
assert 16 == paragraph_data.chunk_size, f" 16 != {paragraph_data.chunk_size = }"
|
||||
assert 145 == len(paragraph_data.text), f" 145 != {len(paragraph_data.text) = }"
|
||||
assert "sentence_end" == paragraph_data.cut_type, (
|
||||
|
|
@ -92,7 +95,7 @@ def test_UnstructuredDocument(mock_engine):
|
|||
)
|
||||
|
||||
# TEST CSV
|
||||
for paragraph_data in csv_document.read(chunker_cls=TextChunker, max_chunk_size=1024):
|
||||
async for paragraph_data in csv_document.read(chunker_cls=TextChunker, max_chunk_size=1024):
|
||||
assert 15 == paragraph_data.chunk_size, f" 15 != {paragraph_data.chunk_size = }"
|
||||
assert "A A A A A A A A A,A A A A A A,A A" == paragraph_data.text, (
|
||||
f"Read text doesn't match expected text: {paragraph_data.text}"
|
||||
|
|
@ -102,7 +105,7 @@ def test_UnstructuredDocument(mock_engine):
|
|||
)
|
||||
|
||||
# Test XLSX
|
||||
for paragraph_data in xlsx_document.read(chunker_cls=TextChunker, max_chunk_size=1024):
|
||||
async for paragraph_data in xlsx_document.read(chunker_cls=TextChunker, max_chunk_size=1024):
|
||||
assert 36 == paragraph_data.chunk_size, f" 36 != {paragraph_data.chunk_size = }"
|
||||
assert 171 == len(paragraph_data.text), f" 171 != {len(paragraph_data.text) = }"
|
||||
assert "sentence_cut" == paragraph_data.cut_type, (
|
||||
|
|
|
|||
12
cognee/tests/integration/documents/async_gen_zip.py
Normal file
12
cognee/tests/integration/documents/async_gen_zip.py
Normal file
|
|
@ -0,0 +1,12 @@
async def async_gen_zip(iterable1, async_iterable2):
    it1 = iter(iterable1)
    it2 = async_iterable2.__aiter__()

    while True:
        try:
            val1 = next(it1)
            val2 = await it2.__anext__()

            yield val1, val2
        except (StopIteration, StopAsyncIteration):
            break
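The new `async_gen_zip` helper pairs a synchronous iterable with an async generator and stops at the shorter of the two. An illustrative usage sketch, importing it the same way the updated document tests do:

```python
import asyncio

from cognee.tests.integration.documents.async_gen_zip import async_gen_zip

async def numbers():
    for i in range(3):
        yield i

async def demo():
    # Walk a plain list alongside an async generator, as the tests do with
    # GROUND_TRUTH and document.read(...).
    async for expected, actual in async_gen_zip(["a", "b", "c"], numbers()):
        print(expected, actual)  # a 0, b 1, c 2

asyncio.run(demo())
```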
|
||||
|
|
@ -1,8 +1,9 @@
|
|||
import os
|
||||
from cognee.shared.logging_utils import get_logger
|
||||
import pathlib
|
||||
import cognee
|
||||
|
||||
import cognee
|
||||
from cognee.shared.logging_utils import get_logger
|
||||
from cognee.infrastructure.files.storage import get_storage_config
|
||||
from cognee.modules.data.models import Data
|
||||
from cognee.modules.users.methods import get_default_user
|
||||
from cognee.modules.search.types import SearchType
|
||||
|
|
@ -24,12 +25,12 @@ async def test_local_file_deletion(data_text, file_location):
|
|||
data_hash = hashlib.md5(encoded_text).hexdigest()
|
||||
# Get data entry from database based on hash contents
|
||||
data = (await session.scalars(select(Data).where(Data.content_hash == data_hash))).one()
|
||||
assert os.path.isfile(data.raw_data_location), (
|
||||
assert os.path.isfile(data.raw_data_location.replace("file://", "")), (
|
||||
f"Data location doesn't exist: {data.raw_data_location}"
|
||||
)
|
||||
# Test deletion of data along with local files created by cognee
|
||||
await engine.delete_data_entity(data.id)
|
||||
assert not os.path.exists(data.raw_data_location), (
|
||||
assert not os.path.exists(data.raw_data_location.replace("file://", "")), (
|
||||
f"Data location still exists after deletion: {data.raw_data_location}"
|
||||
)
|
||||
|
||||
|
|
@ -38,12 +39,12 @@ async def test_local_file_deletion(data_text, file_location):
|
|||
data = (
|
||||
await session.scalars(select(Data).where(Data.raw_data_location == file_location))
|
||||
).one()
|
||||
assert os.path.isfile(data.raw_data_location), (
|
||||
assert os.path.isfile(data.raw_data_location.replace("file://", "")), (
|
||||
f"Data location doesn't exist: {data.raw_data_location}"
|
||||
)
|
||||
# Test local files not created by cognee won't get deleted
|
||||
await engine.delete_data_entity(data.id)
|
||||
assert os.path.exists(data.raw_data_location), (
|
||||
assert os.path.exists(data.raw_data_location.replace("file://", "")), (
|
||||
f"Data location doesn't exists: {data.raw_data_location}"
|
||||
)
|
||||
|
||||
|
|
@ -157,7 +158,8 @@ async def main():
|
|||
assert len(history) == 8, "Search history is not correct."
|
||||
|
||||
await cognee.prune.prune_data()
|
||||
assert not os.path.isdir(data_directory_path), "Local data files are not deleted"
|
||||
data_root_directory = get_storage_config()["data_root_directory"]
|
||||
assert not os.path.isdir(data_root_directory), "Local data files are not deleted"
|
||||
|
||||
await cognee.prune.prune_system(metadata=True)
|
||||
tables_in_database = await vector_engine.get_collection_names()
|
||||
|
|
|
|||
|
|
@ -5,7 +5,6 @@ from cognee.modules.search.operations import get_history
|
|||
from cognee.modules.users.methods import get_default_user
|
||||
from cognee.shared.logging_utils import get_logger
|
||||
from cognee.modules.search.types import SearchType
|
||||
from cognee.shared.utils import render_graph
|
||||
from cognee.low_level import DataPoint
|
||||
|
||||
logger = get_logger()
|
||||
|
|
@ -57,9 +56,6 @@ async def main():
|
|||
|
||||
await cognee.cognify(graph_model=ProgrammingLanguage)
|
||||
|
||||
url = await render_graph()
|
||||
print(f"Graphistry URL: {url}")
|
||||
|
||||
graph_file_path = str(
|
||||
pathlib.Path(
|
||||
os.path.join(
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@ import hashlib
|
|||
import os
|
||||
from cognee.shared.logging_utils import get_logger
|
||||
import pathlib
|
||||
import pytest
|
||||
|
||||
import cognee
|
||||
from cognee.infrastructure.databases.relational import get_relational_engine
|
||||
|
|
@ -101,6 +102,7 @@ async def test_deduplication():
|
|||
await cognee.prune.prune_system(metadata=True)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_deduplication_postgres():
|
||||
cognee.config.set_vector_db_config(
|
||||
{"vector_db_url": "", "vector_db_key": "", "vector_db_provider": "pgvector"}
|
||||
|
|
@ -119,6 +121,7 @@ async def test_deduplication_postgres():
|
|||
await test_deduplication()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_deduplication_sqlite():
|
||||
cognee.config.set_vector_db_config(
|
||||
{"vector_db_url": "", "vector_db_key": "", "vector_db_provider": "lancedb"}
|
||||
|
|
|
|||
|
|
@ -1,11 +1,11 @@
|
|||
import os
|
||||
import cognee
|
||||
import pathlib
|
||||
from cognee.infrastructure.files.storage import get_storage_config
|
||||
from cognee.modules.search.operations import get_history
|
||||
from cognee.modules.users.methods import get_default_user
|
||||
from cognee.shared.logging_utils import get_logger
|
||||
from cognee.modules.search.types import SearchType
|
||||
# from cognee.shared.utils import render_graph
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
|
|
@ -44,8 +44,6 @@ async def main():
|
|||
|
||||
await cognee.cognify([dataset_name])
|
||||
|
||||
# await render_graph(None, include_labels = True, include_nodes = True)
|
||||
|
||||
from cognee.infrastructure.databases.vector import get_vector_engine
|
||||
|
||||
vector_engine = get_vector_engine()
|
||||
|
|
@ -81,7 +79,8 @@ async def main():
|
|||
|
||||
# Assert local data files are cleaned properly
|
||||
await cognee.prune.prune_data()
|
||||
assert not os.path.isdir(data_directory_path), "Local data files are not deleted"
|
||||
data_root_directory = get_storage_config()["data_root_directory"]
|
||||
assert not os.path.isdir(data_root_directory), "Local data files are not deleted"
|
||||
|
||||
# Assert relational, vector and graph databases have been cleaned properly
|
||||
await cognee.prune.prune_system(metadata=True)
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@ import shutil
|
|||
import cognee
|
||||
import pathlib
|
||||
|
||||
from cognee.infrastructure.files.storage import get_storage_config
|
||||
from cognee.modules.engine.models import NodeSet
|
||||
from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever
|
||||
from cognee.shared.logging_utils import get_logger
|
||||
|
|
@ -112,7 +113,8 @@ async def main():
|
|||
)
|
||||
|
||||
await cognee.prune.prune_data()
|
||||
assert not os.path.isdir(data_directory_path), "Local data files are not deleted"
|
||||
data_root_directory = get_storage_config()["data_root_directory"]
|
||||
assert not os.path.isdir(data_root_directory), "Local data files are not deleted"
|
||||
|
||||
await cognee.prune.prune_system(metadata=True)
|
||||
from cognee.infrastructure.databases.graph import get_graph_engine
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
import os
|
||||
import pathlib
|
||||
import cognee
|
||||
from cognee.infrastructure.files.storage import get_file_storage, get_storage_config
|
||||
from cognee.modules.search.operations import get_history
|
||||
from cognee.modules.users.methods import get_default_user
|
||||
from cognee.shared.logging_utils import get_logger
|
||||
|
|
@ -78,7 +79,8 @@ async def main():
|
|||
|
||||
# Assert local data files are cleaned properly
|
||||
await cognee.prune.prune_data()
|
||||
assert not os.path.isdir(data_directory_path), "Local data files are not deleted"
|
||||
data_root_directory = get_storage_config()["data_root_directory"]
|
||||
assert not os.path.isdir(data_root_directory), "Local data files are not deleted"
|
||||
|
||||
# Assert relational, vector and graph databases have been cleaned properly
|
||||
await cognee.prune.prune_system(metadata=True)
|
||||
|
|
@ -89,16 +91,28 @@ async def main():
|
|||
|
||||
from cognee.infrastructure.databases.relational import get_relational_engine
|
||||
|
||||
with open(get_relational_engine().db_path, "r") as file:
|
||||
content = file.read()
|
||||
assert content == "", "SQLite relational database is not empty"
|
||||
db_path = get_relational_engine().db_path
|
||||
dir_path = os.path.dirname(db_path)
|
||||
file_path = os.path.basename(db_path)
|
||||
file_storage = get_file_storage(dir_path)
|
||||
|
||||
assert not await file_storage.file_exists(file_path), (
|
||||
"SQLite relational database is not deleted"
|
||||
)
|
||||
|
||||
from cognee.infrastructure.databases.graph import get_graph_config
|
||||
|
||||
graph_config = get_graph_config()
|
||||
assert not os.path.exists(graph_config.graph_file_path) or not os.listdir(
|
||||
graph_config.graph_file_path
|
||||
), "Kuzu graph directory is not empty"
|
||||
# For Kuzu v0.11.0+, check if database file doesn't exist (single-file format with .kuzu extension)
|
||||
# For older versions or other providers, check if directory is empty
|
||||
if graph_config.graph_database_provider.lower() == "kuzu":
|
||||
assert not os.path.exists(graph_config.graph_file_path), (
|
||||
"Kuzu graph database file still exists"
|
||||
)
|
||||
else:
|
||||
assert not os.path.exists(graph_config.graph_file_path) or not os.listdir(
|
||||
graph_config.graph_file_path
|
||||
), "Graph database directory is not empty"
|
||||
|
||||
|
||||
if __name__ == "__main__":
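The updated assertion depends on the graph provider: Kuzu 0.11+ stores the whole database in a single file, while other providers keep a directory that should end up empty. An illustrative helper restating that check (the function name is mine, not the diff's):

```python
import os

def graph_store_removed(provider: str, graph_file_path: str) -> bool:
    # Kuzu >= 0.11 uses a single-file database; other providers use a directory.
    if provider.lower() == "kuzu":
        return not os.path.exists(graph_file_path)
    return not os.path.exists(graph_file_path) or not os.listdir(graph_file_path)
```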
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@ import os
|
|||
|
||||
import pathlib
|
||||
import cognee
|
||||
from cognee.infrastructure.files.storage import get_storage_config
|
||||
from cognee.modules.search.operations import get_history
|
||||
from cognee.modules.users.methods import get_default_user
|
||||
from cognee.shared.logging_utils import get_logger
|
||||
|
|
@ -91,7 +92,8 @@ async def main():
|
|||
assert len(history) == 8, "Search history is not correct."
|
||||
|
||||
await cognee.prune.prune_data()
|
||||
assert not os.path.isdir(data_directory_path), "Local data files are not deleted"
|
||||
data_root_directory = get_storage_config()["data_root_directory"]
|
||||
assert not os.path.isdir(data_root_directory), "Local data files are not deleted"
|
||||
|
||||
await cognee.prune.prune_system(metadata=True)
|
||||
from cognee.infrastructure.databases.graph import get_graph_engine
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
import os
|
||||
import pathlib
|
||||
import cognee
|
||||
from cognee.infrastructure.files.storage import get_storage_config
|
||||
from cognee.modules.search.operations import get_history
|
||||
from cognee.modules.users.methods import get_default_user
|
||||
from cognee.shared.logging_utils import get_logger
|
||||
|
|
@ -87,7 +88,8 @@ async def main():
|
|||
assert len(history) == 6, "Search history is not correct."
|
||||
|
||||
await cognee.prune.prune_data()
|
||||
assert not os.path.isdir(data_directory_path), "Local data files are not deleted"
|
||||
data_root_directory = get_storage_config()["data_root_directory"]
|
||||
assert not os.path.isdir(data_root_directory), "Local data files are not deleted"
|
||||
|
||||
await cognee.prune.prune_system(metadata=True)
|
||||
milvus_client = get_vector_engine().get_milvus_client()
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
import os
|
||||
import pathlib
|
||||
import cognee
|
||||
from cognee.infrastructure.files.storage import get_storage_config
|
||||
from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever
|
||||
from cognee.modules.search.operations import get_history
|
||||
from cognee.modules.users.methods import get_default_user
|
||||
|
|
@ -116,7 +117,8 @@ async def main():
|
|||
)
|
||||
|
||||
await cognee.prune.prune_data()
|
||||
assert not os.path.isdir(data_directory_path), "Local data files are not deleted"
|
||||
data_root_directory = get_storage_config()["data_root_directory"]
|
||||
assert not os.path.isdir(data_root_directory), "Local data files are not deleted"
|
||||
|
||||
await cognee.prune.prune_system(metadata=True)
|
||||
from cognee.infrastructure.databases.graph import get_graph_engine
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
import os
|
||||
import pathlib
|
||||
import cognee
|
||||
from cognee.infrastructure.files.storage import get_storage_config
|
||||
from cognee.modules.search.operations import get_history
|
||||
from cognee.shared.logging_utils import get_logger
|
||||
from cognee.modules.data.models import Data
|
||||
|
|
@ -23,26 +24,28 @@ async def test_local_file_deletion(data_text, file_location):
|
|||
data_hash = hashlib.md5(encoded_text).hexdigest()
|
||||
# Get data entry from database based on hash contents
|
||||
data = (await session.scalars(select(Data).where(Data.content_hash == data_hash))).one()
|
||||
assert os.path.isfile(data.raw_data_location), (
|
||||
assert os.path.isfile(data.raw_data_location.replace("file://", "")), (
|
||||
f"Data location doesn't exist: {data.raw_data_location}"
|
||||
)
|
||||
# Test deletion of data along with local files created by cognee
|
||||
await engine.delete_data_entity(data.id)
|
||||
assert not os.path.exists(data.raw_data_location), (
|
||||
assert not os.path.exists(data.raw_data_location.replace("file://", "")), (
|
||||
f"Data location still exists after deletion: {data.raw_data_location}"
|
||||
)
|
||||
|
||||
async with engine.get_async_session() as session:
|
||||
# Get data entry from database based on file path
|
||||
data = (
|
||||
await session.scalars(select(Data).where(Data.raw_data_location == file_location))
|
||||
await session.scalars(
|
||||
select(Data).where(Data.raw_data_location == "file://" + file_location)
|
||||
)
|
||||
).one()
|
||||
assert os.path.isfile(data.raw_data_location), (
|
||||
assert os.path.isfile(data.raw_data_location.replace("file://", "")), (
|
||||
f"Data location doesn't exist: {data.raw_data_location}"
|
||||
)
|
||||
# Test local files not created by cognee won't get deleted
|
||||
await engine.delete_data_entity(data.id)
|
||||
assert os.path.exists(data.raw_data_location), (
|
||||
assert os.path.exists(data.raw_data_location.replace("file://", "")), (
|
||||
f"Data location doesn't exists: {data.raw_data_location}"
|
||||
)
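Because `raw_data_location` is now persisted as a `file://` URL, every filesystem check in these tests strips the scheme first. A tiny illustrative helper for that pattern (not part of the diff):

```python
def as_local_path(raw_data_location: str) -> str:
    # Strip the file:// scheme (first occurrence only) before passing the
    # value to os.path.isfile / os.path.exists.
    return raw_data_location.replace("file://", "", 1)
```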
|
||||
|
||||
|
|
@ -164,7 +167,8 @@ async def main():
|
|||
await test_local_file_deletion(text, explanation_file_path)
|
||||
|
||||
await cognee.prune.prune_data()
|
||||
assert not os.path.isdir(data_directory_path), "Local data files are not deleted"
|
||||
data_root_directory = get_storage_config()["data_root_directory"]
|
||||
assert not os.path.isdir(data_root_directory), "Local data files are not deleted"
|
||||
|
||||
await cognee.prune.prune_system(metadata=True)
|
||||
tables_in_database = await vector_engine.get_table_names()
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
import os
|
||||
import pathlib
|
||||
import cognee
|
||||
from cognee.infrastructure.files.storage import get_storage_config
|
||||
from cognee.modules.search.operations import get_history
|
||||
from cognee.modules.users.methods import get_default_user
|
||||
from cognee.shared.logging_utils import get_logger
|
||||
|
|
@ -83,7 +84,8 @@ async def main():
|
|||
assert len(history) == 6, "Search history is not correct."
|
||||
|
||||
await cognee.prune.prune_data()
|
||||
assert not os.path.isdir(data_directory_path), "Local data files are not deleted"
|
||||
data_root_directory = get_storage_config()["data_root_directory"]
|
||||
assert not os.path.isdir(data_root_directory), "Local data files are not deleted"
|
||||
|
||||
await cognee.prune.prune_system(metadata=True)
|
||||
qdrant_client = get_vector_engine().get_qdrant_client()
|
||||
|
|
|
|||
|
|
@ -2,11 +2,11 @@ import os
|
|||
import shutil
|
||||
import cognee
|
||||
import pathlib
|
||||
from cognee.infrastructure.files.storage import get_storage_config
|
||||
from cognee.shared.logging_utils import get_logger
|
||||
from cognee.modules.search.types import SearchType
|
||||
from cognee.modules.search.operations import get_history
|
||||
from cognee.modules.users.methods import get_default_user
|
||||
from cognee.infrastructure.databases.graph.config import get_graph_config
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
|
|
@ -93,7 +93,8 @@ async def main():
|
|||
assert len(history) == 6, "Search history is not correct."
|
||||
|
||||
await cognee.prune.prune_data()
|
||||
assert not os.path.isdir(data_directory_path), "Local data files are not deleted"
|
||||
data_root_directory = get_storage_config()["data_root_directory"]
|
||||
assert not os.path.isdir(data_root_directory), "Local data files are not deleted"
|
||||
|
||||
await cognee.prune.prune_system(metadata=True)
|
||||
from cognee.infrastructure.databases.graph import get_graph_engine
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
import os
|
||||
import pathlib
|
||||
import cognee
|
||||
from cognee.infrastructure.files.storage import get_storage_config
|
||||
from cognee.modules.search.operations import get_history
|
||||
from cognee.modules.users.methods import get_default_user
|
||||
from cognee.shared.logging_utils import get_logger
|
||||
|
|
@ -79,7 +80,8 @@ async def main():
|
|||
assert len(history) == 6, "Search history is not correct."
|
||||
|
||||
await cognee.prune.prune_data()
|
||||
assert not os.path.isdir(data_directory_path), "Local data files are not deleted"
|
||||
data_root_directory = get_storage_config()["data_root_directory"]
|
||||
assert not os.path.isdir(data_root_directory), "Local data files are not deleted"
|
||||
|
||||
await cognee.prune.prune_system(metadata=True)
|
||||
collections = await get_vector_engine().client.collections.list_all()
|
||||
|
|
|
|||
|
|
@ -2,13 +2,15 @@ import os
|
|||
import tempfile
|
||||
import pytest
|
||||
from pathlib import Path
|
||||
from cognee.modules.data.processing.document_types.open_data_file import open_data_file
|
||||
|
||||
from cognee.infrastructure.files.utils.open_data_file import open_data_file
|
||||
|
||||
|
||||
class TestOpenDataFile:
|
||||
"""Test cases for open_data_file function with file:// URL handling."""
|
||||
|
||||
def test_regular_file_path(self):
|
||||
@pytest.mark.asyncio
|
||||
async def test_regular_file_path(self):
|
||||
"""Test that regular file paths work as before."""
|
||||
with tempfile.NamedTemporaryFile(mode="w", delete=False, suffix=".txt") as f:
|
||||
test_content = "Test content for regular file path"
|
||||
|
|
@ -16,13 +18,14 @@ class TestOpenDataFile:
|
|||
temp_file_path = f.name
|
||||
|
||||
try:
|
||||
with open_data_file(temp_file_path, mode="r") as f:
|
||||
async with open_data_file(temp_file_path, mode="r") as f:
|
||||
content = f.read()
|
||||
assert content == test_content
|
||||
finally:
|
||||
os.unlink(temp_file_path)
|
||||
|
||||
def test_file_url_text_mode(self):
|
||||
@pytest.mark.asyncio
|
||||
async def test_file_url_text_mode(self):
|
||||
"""Test that file:// URLs work correctly in text mode."""
|
||||
with tempfile.NamedTemporaryFile(mode="w", delete=False, suffix=".txt") as f:
|
||||
test_content = "Test content for file:// URL handling"
|
||||
|
|
@ -32,13 +35,14 @@ class TestOpenDataFile:
|
|||
try:
|
||||
# Use pathlib.Path.as_uri() for proper cross-platform file URL creation
|
||||
file_url = Path(temp_file_path).as_uri()
|
||||
with open_data_file(file_url, mode="r") as f:
|
||||
async with open_data_file(file_url, mode="r") as f:
|
||||
content = f.read()
|
||||
assert content == test_content
|
||||
finally:
|
||||
os.unlink(temp_file_path)
|
||||
|
||||
def test_file_url_binary_mode(self):
|
||||
@pytest.mark.asyncio
|
||||
async def test_file_url_binary_mode(self):
|
||||
"""Test that file:// URLs work correctly in binary mode."""
|
||||
with tempfile.NamedTemporaryFile(mode="w", delete=False, suffix=".txt") as f:
|
||||
test_content = "Test content for binary mode"
|
||||
|
|
@ -48,13 +52,14 @@ class TestOpenDataFile:
|
|||
try:
|
||||
# Use pathlib.Path.as_uri() for proper cross-platform file URL creation
|
||||
file_url = Path(temp_file_path).as_uri()
|
||||
with open_data_file(file_url, mode="rb") as f:
|
||||
async with open_data_file(file_url, mode="rb") as f:
|
||||
content = f.read()
|
||||
assert content == test_content.encode()
|
||||
finally:
|
||||
os.unlink(temp_file_path)
|
||||
|
||||
def test_file_url_with_encoding(self):
|
||||
@pytest.mark.asyncio
|
||||
async def test_file_url_with_encoding(self):
|
||||
"""Test that file:// URLs work with specific encoding."""
|
||||
with tempfile.NamedTemporaryFile(
|
||||
mode="w", delete=False, suffix=".txt", encoding="utf-8"
|
||||
|
|
@ -66,20 +71,22 @@ class TestOpenDataFile:
|
|||
try:
|
||||
# Use pathlib.Path.as_uri() for proper cross-platform file URL creation
|
||||
file_url = Path(temp_file_path).as_uri()
|
||||
with open_data_file(file_url, mode="r", encoding="utf-8") as f:
|
||||
async with open_data_file(file_url, mode="r", encoding="utf-8") as f:
|
||||
content = f.read()
|
||||
assert content == test_content
|
||||
finally:
|
||||
os.unlink(temp_file_path)
|
||||
|
||||
def test_file_url_nonexistent_file(self):
|
||||
@pytest.mark.asyncio
|
||||
async def test_file_url_nonexistent_file(self):
|
||||
"""Test that file:// URLs raise appropriate error for nonexistent files."""
|
||||
file_url = "file:///nonexistent/path/to/file.txt"
|
||||
with pytest.raises(FileNotFoundError):
|
||||
with open_data_file(file_url, mode="r") as f:
|
||||
async with open_data_file(file_url, mode="r") as f:
|
||||
f.read()
|
||||
|
||||
def test_multiple_file_prefixes(self):
|
||||
@pytest.mark.asyncio
|
||||
async def test_multiple_file_prefixes(self):
|
||||
"""Test that multiple file:// prefixes are handled correctly."""
|
||||
with tempfile.NamedTemporaryFile(mode="w", delete=False, suffix=".txt") as f:
|
||||
test_content = "Test content"
|
||||
|
|
@ -91,7 +98,7 @@ class TestOpenDataFile:
|
|||
# Use proper file URL creation first
|
||||
proper_file_url = Path(temp_file_path).as_uri()
|
||||
file_url = f"file://{proper_file_url}"
|
||||
with open_data_file(file_url, mode="r") as f:
|
||||
async with open_data_file(file_url, mode="r") as f:
|
||||
content = f.read()
|
||||
# This should work because we only replace the first occurrence
|
||||
assert content == test_content
|
||||
|
|
|
|||
|
|
@ -16,11 +16,11 @@ class TestChunksRetriever:
|
|||
@pytest.mark.asyncio
|
||||
async def test_chunk_context_simple(self):
|
||||
system_directory_path = os.path.join(
|
||||
pathlib.Path(__file__).parent, ".cognee_system/test_chunks_context_simple"
|
||||
pathlib.Path(__file__).parent, ".cognee_system/test_chunk_context_simple"
|
||||
)
|
||||
cognee.config.system_root_directory(system_directory_path)
|
||||
data_directory_path = os.path.join(
|
||||
pathlib.Path(__file__).parent, ".data_storage/test_chunks_context_simple"
|
||||
pathlib.Path(__file__).parent, ".data_storage/test_chunk_context_simple"
|
||||
)
|
||||
cognee.config.data_root_directory(data_directory_path)
|
||||
|
||||
|
|
@ -162,11 +162,11 @@ class TestChunksRetriever:
|
|||
@pytest.mark.asyncio
|
||||
async def test_chunk_context_on_empty_graph(self):
|
||||
system_directory_path = os.path.join(
|
||||
pathlib.Path(__file__).parent, ".cognee_system/test_chunk_context_empty"
|
||||
pathlib.Path(__file__).parent, ".cognee_system/test_chunk_context_on_empty_graph"
|
||||
)
|
||||
cognee.config.system_root_directory(system_directory_path)
|
||||
data_directory_path = os.path.join(
|
||||
pathlib.Path(__file__).parent, ".data_storage/test_chunk_context_empty"
|
||||
pathlib.Path(__file__).parent, ".data_storage/test_chunk_context_on_empty_graph"
|
||||
)
|
||||
cognee.config.data_root_directory(data_directory_path)
|
||||
|
||||
|
|
@ -183,16 +183,3 @@ class TestChunksRetriever:
|
|||
|
||||
context = await retriever.get_context("Christina Mayer")
|
||||
assert len(context) == 0, "Found chunks when none should exist"
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from asyncio import run
|
||||
|
||||
test = TestChunksRetriever()
|
||||
|
||||
async def main():
|
||||
await test.test_chunk_context_simple()
|
||||
await test.test_chunk_context_complex()
|
||||
await test.test_chunk_context_on_empty_graph()
|
||||
|
||||
run(main())
|
||||
|
|
|
|||
|
|
@ -12,15 +12,17 @@ from cognee.modules.retrieval.graph_completion_context_extension_retriever impor
|
|||
)
|
||||
|
||||
|
||||
class TestGraphCompletionRetriever:
|
||||
class TestGraphCompletionWithContextExtensionRetriever:
|
||||
@pytest.mark.asyncio
|
||||
async def test_graph_completion_extension_context_simple(self):
|
||||
system_directory_path = os.path.join(
|
||||
pathlib.Path(__file__).parent, ".cognee_system/test_graph_context"
|
||||
pathlib.Path(__file__).parent,
|
||||
".cognee_system/test_graph_completion_extension_context_simple",
|
||||
)
|
||||
cognee.config.system_root_directory(system_directory_path)
|
||||
data_directory_path = os.path.join(
|
||||
pathlib.Path(__file__).parent, ".data_storage/test_graph_context"
|
||||
pathlib.Path(__file__).parent,
|
||||
".data_storage/test_graph_completion_extension_context_simple",
|
||||
)
|
||||
cognee.config.data_root_directory(data_directory_path)
|
||||
|
||||
|
|
@ -64,11 +66,13 @@ class TestGraphCompletionRetriever:
|
|||
@pytest.mark.asyncio
|
||||
async def test_graph_completion_extension_context_complex(self):
|
||||
system_directory_path = os.path.join(
|
||||
pathlib.Path(__file__).parent, ".cognee_system/test_graph_completion_context"
|
||||
pathlib.Path(__file__).parent,
|
||||
".cognee_system/test_graph_completion_extension_context_complex",
|
||||
)
|
||||
cognee.config.system_root_directory(system_directory_path)
|
||||
data_directory_path = os.path.join(
|
||||
pathlib.Path(__file__).parent, ".data_storage/test_graph_completion_context"
|
||||
pathlib.Path(__file__).parent,
|
||||
".data_storage/test_graph_completion_extension_context_complex",
|
||||
)
|
||||
cognee.config.data_root_directory(data_directory_path)
|
||||
|
||||
|
|
@ -143,11 +147,13 @@ class TestGraphCompletionRetriever:
|
|||
@pytest.mark.asyncio
|
||||
async def test_get_graph_completion_extension_context_on_empty_graph(self):
|
||||
system_directory_path = os.path.join(
|
||||
pathlib.Path(__file__).parent, ".cognee_system/test_graph_completion_context"
|
||||
pathlib.Path(__file__).parent,
|
||||
".cognee_system/test_get_graph_completion_extension_context_on_empty_graph",
|
||||
)
|
||||
cognee.config.system_root_directory(system_directory_path)
|
||||
data_directory_path = os.path.join(
|
||||
pathlib.Path(__file__).parent, ".data_storage/test_graph_completion_context"
|
||||
pathlib.Path(__file__).parent,
|
||||
".data_storage/test_get_graph_completion_extension_context_on_empty_graph",
|
||||
)
|
||||
cognee.config.data_root_directory(data_directory_path)
|
||||
|
||||
|
|
@ -170,16 +176,3 @@ class TestGraphCompletionRetriever:
|
|||
assert all(isinstance(item, str) and item.strip() for item in answer), (
|
||||
"Answer must contain only non-empty strings"
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from asyncio import run
|
||||
|
||||
test = TestGraphCompletionRetriever()
|
||||
|
||||
async def main():
|
||||
await test.test_graph_completion_context_simple()
|
||||
await test.test_graph_completion_context_complex()
|
||||
await test.test_get_graph_completion_context_on_empty_graph()
|
||||
|
||||
run(main())
|
||||
|
|
|
|||
|
|
@ -10,15 +10,15 @@ from cognee.infrastructure.databases.exceptions import DatabaseNotCreatedError
|
|||
from cognee.modules.retrieval.graph_completion_cot_retriever import GraphCompletionCotRetriever
|
||||
|
||||
|
||||
class TestGraphCompletionRetriever:
|
||||
class TestGraphCompletionCoTRetriever:
|
||||
@pytest.mark.asyncio
|
||||
async def test_graph_completion_cot_context_simple(self):
|
||||
system_directory_path = os.path.join(
|
||||
pathlib.Path(__file__).parent, ".cognee_system/test_graph_context"
|
||||
pathlib.Path(__file__).parent, ".cognee_system/test_graph_completion_cot_context_simple"
|
||||
)
|
||||
cognee.config.system_root_directory(system_directory_path)
|
||||
data_directory_path = os.path.join(
|
||||
pathlib.Path(__file__).parent, ".data_storage/test_graph_context"
|
||||
pathlib.Path(__file__).parent, ".data_storage/test_graph_completion_cot_context_simple"
|
||||
)
|
||||
cognee.config.data_root_directory(data_directory_path)
|
||||
|
||||
|
|
@ -62,11 +62,12 @@ class TestGraphCompletionRetriever:
|
|||
@pytest.mark.asyncio
|
||||
async def test_graph_completion_cot_context_complex(self):
|
||||
system_directory_path = os.path.join(
|
||||
pathlib.Path(__file__).parent, ".cognee_system/test_graph_completion_context"
|
||||
pathlib.Path(__file__).parent,
|
||||
".cognee_system/test_graph_completion_cot_context_complex",
|
||||
)
|
||||
cognee.config.system_root_directory(system_directory_path)
|
||||
data_directory_path = os.path.join(
|
||||
pathlib.Path(__file__).parent, ".data_storage/test_graph_completion_context"
|
||||
pathlib.Path(__file__).parent, ".data_storage/test_graph_completion_cot_context_complex"
|
||||
)
|
||||
cognee.config.data_root_directory(data_directory_path)
|
||||
|
||||
|
|
@ -141,11 +142,13 @@ class TestGraphCompletionRetriever:
|
|||
@pytest.mark.asyncio
|
||||
async def test_get_graph_completion_cot_context_on_empty_graph(self):
|
||||
system_directory_path = os.path.join(
|
||||
pathlib.Path(__file__).parent, ".cognee_system/test_graph_completion_context"
|
||||
pathlib.Path(__file__).parent,
|
||||
".cognee_system/test_get_graph_completion_cot_context_on_empty_graph",
|
||||
)
|
||||
cognee.config.system_root_directory(system_directory_path)
|
||||
data_directory_path = os.path.join(
|
||||
pathlib.Path(__file__).parent, ".data_storage/test_graph_completion_context"
|
||||
pathlib.Path(__file__).parent,
|
||||
".data_storage/test_get_graph_completion_cot_context_on_empty_graph",
|
||||
)
|
||||
cognee.config.data_root_directory(data_directory_path)
|
||||
|
||||
|
|
@ -168,16 +171,3 @@ class TestGraphCompletionRetriever:
|
|||
assert all(isinstance(item, str) and item.strip() for item in answer), (
|
||||
"Answer must contain only non-empty strings"
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from asyncio import run
|
||||
|
||||
test = TestGraphCompletionRetriever()
|
||||
|
||||
async def main():
|
||||
await test.test_graph_completion_context_simple()
|
||||
await test.test_graph_completion_context_complex()
|
||||
await test.test_get_graph_completion_context_on_empty_graph()
|
||||
|
||||
run(main())
|
||||
|
|
|
|||
|
|
@ -14,11 +14,11 @@ class TestGraphCompletionRetriever:
|
|||
@pytest.mark.asyncio
|
||||
async def test_graph_completion_context_simple(self):
|
||||
system_directory_path = os.path.join(
|
||||
pathlib.Path(__file__).parent, ".cognee_system/test_graph_context"
|
||||
pathlib.Path(__file__).parent, ".cognee_system/test_graph_completion_context_simple"
|
||||
)
|
||||
cognee.config.system_root_directory(system_directory_path)
|
||||
data_directory_path = os.path.join(
|
||||
pathlib.Path(__file__).parent, ".data_storage/test_graph_context"
|
||||
pathlib.Path(__file__).parent, ".data_storage/test_graph_completion_context_simple"
|
||||
)
|
||||
cognee.config.data_root_directory(data_directory_path)
|
||||
|
||||
|
|
@ -55,11 +55,11 @@ class TestGraphCompletionRetriever:
|
|||
@pytest.mark.asyncio
|
||||
async def test_graph_completion_context_complex(self):
|
||||
system_directory_path = os.path.join(
|
||||
pathlib.Path(__file__).parent, ".cognee_system/test_graph_completion_context"
|
||||
pathlib.Path(__file__).parent, ".cognee_system/test_graph_completion_context_complex"
|
||||
)
|
||||
cognee.config.system_root_directory(system_directory_path)
|
||||
data_directory_path = os.path.join(
|
||||
pathlib.Path(__file__).parent, ".data_storage/test_graph_completion_context"
|
||||
pathlib.Path(__file__).parent, ".data_storage/test_graph_completion_context_complex"
|
||||
)
|
||||
cognee.config.data_root_directory(data_directory_path)
|
||||
|
||||
|
|
@ -127,11 +127,13 @@ class TestGraphCompletionRetriever:
|
|||
@pytest.mark.asyncio
|
||||
async def test_get_graph_completion_context_on_empty_graph(self):
|
||||
system_directory_path = os.path.join(
|
||||
pathlib.Path(__file__).parent, ".cognee_system/test_graph_completion_context"
|
||||
pathlib.Path(__file__).parent,
|
||||
".cognee_system/test_get_graph_completion_context_on_empty_graph",
|
||||
)
|
||||
cognee.config.system_root_directory(system_directory_path)
|
||||
data_directory_path = os.path.join(
|
||||
pathlib.Path(__file__).parent, ".data_storage/test_graph_completion_context"
|
||||
pathlib.Path(__file__).parent,
|
||||
".data_storage/test_get_graph_completion_context_on_empty_graph",
|
||||
)
|
||||
cognee.config.data_root_directory(data_directory_path)
|
||||
|
||||
|
|
@ -147,16 +149,3 @@ class TestGraphCompletionRetriever:
|
|||
|
||||
context = await retriever.get_context("Who works at Figma?")
|
||||
assert context == "", "Context should be empty on an empty graph"
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from asyncio import run
|
||||
|
||||
test = TestGraphCompletionRetriever()
|
||||
|
||||
async def main():
|
||||
await test.test_graph_completion_context_simple()
|
||||
await test.test_graph_completion_context_complex()
|
||||
await test.test_get_graph_completion_context_on_empty_graph()
|
||||
|
||||
run(main())
|
||||
|
|
|
|||
|
|
@ -227,11 +227,11 @@ class TestInsightsRetriever:
|
|||
@pytest.mark.asyncio
|
||||
async def test_insights_context_on_empty_graph(self):
|
||||
system_directory_path = os.path.join(
|
||||
pathlib.Path(__file__).parent, ".cognee_system/test_graph_completion_context_empty"
|
||||
pathlib.Path(__file__).parent, ".cognee_system/test_insights_context_on_empty_graph"
|
||||
)
|
||||
cognee.config.system_root_directory(system_directory_path)
|
||||
data_directory_path = os.path.join(
|
||||
pathlib.Path(__file__).parent, ".data_storage/test_graph_completion_context_empty"
|
||||
pathlib.Path(__file__).parent, ".data_storage/test_insights_context_on_empty_graph"
|
||||
)
|
||||
cognee.config.data_root_directory(data_directory_path)
|
||||
|
||||
|
|
@ -249,13 +249,3 @@ class TestInsightsRetriever:
|
|||
|
||||
context = await retriever.get_context("Christina Mayer")
|
||||
assert context == [], "Returned context should be empty on an empty graph"
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from asyncio import run
|
||||
|
||||
test = TestInsightsRetriever()
|
||||
|
||||
run(test.test_insights_context_simple())
|
||||
run(test.test_insights_context_complex())
|
||||
run(test.test_insights_context_on_empty_graph())
|
||||
|
|
|
|||
|
|
@ -16,11 +16,11 @@ class TestRAGCompletionRetriever:
|
|||
@pytest.mark.asyncio
|
||||
async def test_rag_completion_context_simple(self):
|
||||
system_directory_path = os.path.join(
|
||||
pathlib.Path(__file__).parent, ".cognee_system/test_rag_context"
|
||||
pathlib.Path(__file__).parent, ".cognee_system/test_rag_completion_context_simple"
|
||||
)
|
||||
cognee.config.system_root_directory(system_directory_path)
|
||||
data_directory_path = os.path.join(
|
||||
pathlib.Path(__file__).parent, ".data_storage/test_rag_context"
|
||||
pathlib.Path(__file__).parent, ".data_storage/test_rag_completion_context_simple"
|
||||
)
|
||||
cognee.config.data_root_directory(data_directory_path)
|
||||
|
||||
|
|
@ -73,11 +73,11 @@ class TestRAGCompletionRetriever:
|
|||
@pytest.mark.asyncio
|
||||
async def test_rag_completion_context_complex(self):
|
||||
system_directory_path = os.path.join(
|
||||
pathlib.Path(__file__).parent, ".cognee_system/test_graph_completion_context"
|
||||
pathlib.Path(__file__).parent, ".cognee_system/test_rag_completion_context_complex"
|
||||
)
|
||||
cognee.config.system_root_directory(system_directory_path)
|
||||
data_directory_path = os.path.join(
|
||||
pathlib.Path(__file__).parent, ".data_storage/test_graph_completion_context"
|
||||
pathlib.Path(__file__).parent, ".data_storage/test_rag_completion_context_complex"
|
||||
)
|
||||
cognee.config.data_root_directory(data_directory_path)
|
||||
|
||||
|
|
@@ -163,11 +163,13 @@ class TestRAGCompletionRetriever:
     @pytest.mark.asyncio
     async def test_get_rag_completion_context_on_empty_graph(self):
         system_directory_path = os.path.join(
-            pathlib.Path(__file__).parent, ".cognee_system/test_graph_completion_context"
+            pathlib.Path(__file__).parent,
+            ".cognee_system/test_get_rag_completion_context_on_empty_graph",
         )
         cognee.config.system_root_directory(system_directory_path)
         data_directory_path = os.path.join(
-            pathlib.Path(__file__).parent, ".data_storage/test_graph_completion_context"
+            pathlib.Path(__file__).parent,
+            ".data_storage/test_get_rag_completion_context_on_empty_graph",
         )
         cognee.config.data_root_directory(data_directory_path)
@@ -184,13 +186,3 @@ class TestRAGCompletionRetriever:

         context = await retriever.get_context("Christina Mayer")
         assert context == "", "Returned context should be empty on an empty graph"
-
-
-if __name__ == "__main__":
-    from asyncio import run
-
-    test = TestRAGCompletionRetriever()
-
-    run(test.test_rag_completion_context_simple())
-    run(test.test_rag_completion_context_complex())
-    run(test.test_get_rag_completion_context_on_empty_graph())
@@ -17,11 +17,11 @@ class TextSummariesRetriever:
     @pytest.mark.asyncio
     async def test_chunk_context(self):
         system_directory_path = os.path.join(
-            pathlib.Path(__file__).parent, ".cognee_system/test_summary_context"
+            pathlib.Path(__file__).parent, ".cognee_system/test_chunk_context"
         )
         cognee.config.system_root_directory(system_directory_path)
         data_directory_path = os.path.join(
-            pathlib.Path(__file__).parent, ".data_storage/test_summary_context"
+            pathlib.Path(__file__).parent, ".data_storage/test_chunk_context"
         )
         cognee.config.data_root_directory(data_directory_path)
@@ -136,11 +136,11 @@ class TextSummariesRetriever:
     @pytest.mark.asyncio
     async def test_chunk_context_on_empty_graph(self):
         system_directory_path = os.path.join(
-            pathlib.Path(__file__).parent, ".cognee_system/test_summary_context"
+            pathlib.Path(__file__).parent, ".cognee_system/test_chunk_context_on_empty_graph"
         )
         cognee.config.system_root_directory(system_directory_path)
         data_directory_path = os.path.join(
-            pathlib.Path(__file__).parent, ".data_storage/test_summary_context"
+            pathlib.Path(__file__).parent, ".data_storage/test_chunk_context_on_empty_graph"
         )
         cognee.config.data_root_directory(data_directory_path)
@@ -157,12 +157,3 @@ class TextSummariesRetriever:

         context = await retriever.get_context("Christina Mayer")
         assert context == [], "Returned context should be empty on an empty graph"
-
-
-if __name__ == "__main__":
-    from asyncio import run
-
-    test = TextSummariesRetriever()
-
-    run(test.test_chunk_context())
-    run(test.test_chunk_context_on_empty_graph())
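With the ad-hoc `if __name__ == "__main__":` runners removed from these retriever tests, the `@pytest.mark.asyncio` coroutines are expected to be driven by pytest directly. A minimal sketch of an equivalent manual invocation follows; the test path and `-k` filter are illustrative assumptions, not part of this commit:

```python
# Hypothetical runner: drives the async retriever tests through pytest
# instead of the removed "__main__" blocks. Path and filter are assumptions.
import sys

import pytest

if __name__ == "__main__":
    # pytest-asyncio picks up the @pytest.mark.asyncio coroutines directly.
    sys.exit(pytest.main(["-x", "-k", "retriever", "cognee/tests/unit/modules/retrieval"]))
```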
@@ -1,4 +1,5 @@
 import os
+import tempfile
 import pytest
 import networkx as nx
 import pandas as pd
@@ -6,12 +7,9 @@ from unittest.mock import patch, mock_open
 from io import BytesIO
 from uuid import uuid4

-from cognee.shared.utils import (
-    get_anonymous_id,
-    get_file_content_hash,
-    prepare_edges,
-    prepare_nodes,
-)
+from cognee.infrastructure.files.utils.get_file_content_hash import get_file_content_hash
+from cognee.shared.utils import get_anonymous_id


 @pytest.fixture
@@ -28,48 +26,31 @@ def test_get_anonymous_id(mock_open_file, mock_makedirs, temp_dir):
     assert len(anon_id) > 0


-# @patch("requests.post")
-# def test_send_telemetry(mock_post):
-#     mock_post.return_value.status_code = 200
-#
-#     send_telemetry("test_event", "test_user", {"key": "value"})
-#     mock_post.assert_called_once()
-#
-#     args, kwargs = mock_post.call_args
-#     assert kwargs["json"]["event_name"] == "test_event"
+@pytest.mark.asyncio
+async def test_get_file_content_hash_file():
+    temp_file_path = None
+    text_content = "Test content with UTF-8: café ☕"
+
+    with tempfile.NamedTemporaryFile(mode="w", delete=False, suffix=".txt", encoding="utf-8") as f:
+        test_content = text_content
+        f.write(test_content)
+        temp_file_path = f.name
+
-@patch("builtins.open", new_callable=mock_open, read_data=b"test_data")
-def test_get_file_content_hash_file(mock_open_file):
-    import hashlib
-
-    expected_hash = hashlib.md5(b"test_data").hexdigest()
-    result = get_file_content_hash("test_file.txt")
-    assert result == expected_hash
+    try:
+        expected_hash = hashlib.md5(text_content.encode("utf-8")).hexdigest()
+        result = await get_file_content_hash(temp_file_path)
+        assert result == expected_hash
+    finally:
+        os.unlink(temp_file_path)


-def test_get_file_content_hash_stream():
+@pytest.mark.asyncio
+async def test_get_file_content_hash_stream():
     stream = BytesIO(b"test_data")
     import hashlib

     expected_hash = hashlib.md5(b"test_data").hexdigest()
-    result = get_file_content_hash(stream)
+    result = await get_file_content_hash(stream)
     assert result == expected_hash
-
-
-def test_prepare_edges():
-    graph = nx.MultiDiGraph()
-    graph.add_edge("A", "B", key="AB", weight=1)
-    edges_df = prepare_edges(graph, "source", "target", "key")
-
-    assert isinstance(edges_df, pd.DataFrame)
-    assert len(edges_df) == 1
-
-
-def test_prepare_nodes():
-    graph = nx.Graph()
-    graph.add_node(1, name="Node1")
-    nodes_df = prepare_nodes(graph)
-
-    assert isinstance(nodes_df, pd.DataFrame)
-    assert len(nodes_df) == 1
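The hash helper is now imported from `cognee.infrastructure.files.utils.get_file_content_hash` and awaited, and the updated tests hash a real temporary file and a byte stream instead of mocking `open`. A minimal usage sketch under those same assumptions (not part of the commit):

```python
# Sketch only: exercises the now-async hash helper the way the updated
# tests do, over a temporary file and an in-memory byte stream.
import asyncio
import hashlib
import os
import tempfile
from io import BytesIO

from cognee.infrastructure.files.utils.get_file_content_hash import get_file_content_hash


async def main() -> None:
    payload = b"test_data"

    # Write the payload to a throwaway file, mirroring the file-based test.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".txt") as f:
        f.write(payload)
        path = f.name

    try:
        # Per the tests above, the helper returns an MD5 hexdigest of the raw content.
        assert await get_file_content_hash(path) == hashlib.md5(payload).hexdigest()
        assert await get_file_content_hash(BytesIO(payload)) == hashlib.md5(payload).hexdigest()
    finally:
        os.unlink(path)


asyncio.run(main())
```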
@@ -6,7 +6,6 @@ from cognee import visualize_graph
 from cognee.low_level import setup, DataPoint
 from cognee.pipelines import run_tasks, Task
 from cognee.tasks.storage import add_data_points
-from cognee.shared.utils import render_graph


 class Person(DataPoint):
@@ -80,9 +79,6 @@ async def main():
     async for status in pipeline:
         print(status)

-    # Get a graphistry url (Register for a free account at https://www.graphistry.com)
-    await render_graph()
-
     # Or use our simple graph preview
     graph_file_path = str(
         os.path.join(os.path.dirname(__file__), ".artifacts/graph_visualization.html")
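With `render_graph` (the Graphistry upload) dropped from the example, the local HTML preview shown in the surrounding context is the remaining visualization path. A minimal sketch, assuming `visualize_graph` accepts a destination file path as the example context suggests:

```python
# Sketch of the remaining local preview path (assumes visualize_graph
# accepts a destination file path, as the example context suggests).
import asyncio
import os

from cognee import visualize_graph


async def preview_graph() -> None:
    graph_file_path = str(
        os.path.join(os.path.dirname(__file__), ".artifacts/graph_visualization.html")
    )
    # Renders the current graph to a standalone HTML file for local viewing.
    await visualize_graph(graph_file_path)


asyncio.run(preview_graph())
```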
@@ -10,7 +10,6 @@ from cognee.infrastructure.databases.graph import get_graph_engine
 from cognee.low_level import setup, DataPoint
 from cognee.pipelines import run_tasks, Task
 from cognee.tasks.storage import add_data_points
-from cognee.shared.utils import render_graph


 class Products(DataPoint):
@@ -119,9 +118,6 @@ async def main():
     async for status in pipeline:
         print(status)

-    # Get a graphistry url (Register for a free account at https://www.graphistry.com)
-    await render_graph()
-
     graph_engine = await get_graph_engine()

     products_results = await graph_engine.query(
2557 poetry.lock generated
File diff suppressed because it is too large.
Some files were not shown because too many files have changed in this diff.