Compare commits

4 commits

Author           SHA1          Message                                                                            Date
vasilije         a8378a429d    format                                                                             2025-08-12 15:59:39 +02:00
Vasilije         7b0879f88a    Merge branch 'dev' into refactor/remove-unnecessary-code                          2025-08-12 15:42:03 +02:00
Boris Arzentar   0696d4c340    Merge remote-tracking branch 'origin/dev' into refactor/remove-unnecessary-code   2025-07-25 12:01:45 +02:00
Boris Arzentar   d9761118b2    refactor: remove unnecessary methods and enforce config objects                   2025-07-25 12:01:38 +02:00
46 changed files with 232 additions and 541 deletions

View file

@ -7,13 +7,13 @@ from cognee.shared.logging_utils import get_logger
from cognee.shared.data_models import KnowledgeGraph
from cognee.infrastructure.llm import get_max_chunk_tokens
from cognee.modules.pipelines import cognee_pipeline
from cognee.modules.pipelines.tasks.task import Task
from cognee.modules.users.models import User
from cognee.modules.chunking.TextChunker import TextChunker
from cognee.modules.ontology.rdf_xml.OntologyResolver import OntologyResolver
from cognee.modules.pipelines.tasks.task import Task
from cognee.modules.pipelines import cognee_pipeline
from cognee.modules.pipelines.models.PipelineRunInfo import PipelineRunCompleted, PipelineRunErrored
from cognee.modules.pipelines.queues.pipeline_run_info_queues import push_to_queue
from cognee.modules.users.models import User
from cognee.tasks.documents import (
check_permissions_on_dataset,
@ -30,14 +30,14 @@ update_status_lock = asyncio.Lock()
async def cognify(
datasets: Union[str, list[str], list[UUID]] = None,
user: User = None,
graph_model: BaseModel = KnowledgeGraph,
datasets: Optional[Union[str, list[str], list[UUID]]] = None,
user: Optional[User] = None,
graph_model: Optional[BaseModel] = KnowledgeGraph,
chunker=TextChunker,
chunk_size: int = None,
chunk_size: Optional[int] = None,
ontology_file_path: Optional[str] = None,
vector_db_config: dict = None,
graph_db_config: dict = None,
vector_db_config: Optional[dict] = None,
graph_db_config: Optional[dict] = None,
run_in_background: bool = False,
incremental_loading: bool = True,
):
@ -212,8 +212,8 @@ async def run_cognify_blocking(
tasks,
user,
datasets,
graph_db_config: dict = None,
vector_db_config: dict = False,
graph_db_config=None,
vector_db_config=None,
incremental_loading: bool = True,
):
total_run_info = {}
@ -239,8 +239,8 @@ async def run_cognify_as_background_process(
tasks,
user,
datasets,
graph_db_config: dict = None,
vector_db_config: dict = False,
graph_db_config=None,
vector_db_config=None,
incremental_loading: bool = True,
):
# Convert dataset to list if it's a string
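For orientation, a minimal call against the updated cognify signature might look like the sketch below. The dataset name and chunk size are illustrative, every optional parameter can equally be omitted, and the sketch assumes data has already been added to that dataset.

import asyncio

import cognee

async def main():
    # Optional parameters now default to None; the values here are illustrative only.
    await cognee.cognify(
        datasets=["example_dataset"],
        chunk_size=1024,
        run_in_background=False,
    )

asyncio.run(main())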

View file

@ -2,7 +2,7 @@
import os
from cognee.base_config import get_base_config
from cognee.exceptions import InvalidValueError, InvalidAttributeError
from cognee.exceptions import InvalidAttributeError
from cognee.modules.cognify.config import get_cognify_config
from cognee.infrastructure.data.chunking.config import get_chunk_config
from cognee.infrastructure.databases.vector import get_vectordb_config
@ -107,12 +107,12 @@ class config:
chunk_config.chunk_engine = chunk_engine
@staticmethod
def set_chunk_overlap(chunk_overlap: object):
def set_chunk_overlap(chunk_overlap: int):
chunk_config = get_chunk_config()
chunk_config.chunk_overlap = chunk_overlap
@staticmethod
def set_chunk_size(chunk_size: object):
def set_chunk_size(chunk_size: int):
chunk_config = get_chunk_config()
chunk_config.chunk_size = chunk_size
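A quick sketch of the tightened setters, assuming the config class above is exposed as cognee.config; the values are illustrative.

import cognee

# Both setters now expect integers rather than arbitrary objects.
cognee.config.set_chunk_size(1024)
cognee.config.set_chunk_overlap(20)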

View file

@ -10,19 +10,10 @@ class BaseConfig(BaseSettings):
data_root_directory: str = get_absolute_path(".data_storage")
system_root_directory: str = get_absolute_path(".cognee_system")
monitoring_tool: object = Observer.LANGFUSE
langfuse_public_key: Optional[str] = os.getenv("LANGFUSE_PUBLIC_KEY")
langfuse_secret_key: Optional[str] = os.getenv("LANGFUSE_SECRET_KEY")
langfuse_host: Optional[str] = os.getenv("LANGFUSE_HOST")
default_user_email: Optional[str] = os.getenv("DEFAULT_USER_EMAIL")
default_user_password: Optional[str] = os.getenv("DEFAULT_USER_PASSWORD")
model_config = SettingsConfigDict(env_file=".env", extra="allow")
def to_dict(self) -> dict:
return {
"data_root_directory": self.data_root_directory,
"system_root_directory": self.system_root_directory,
"monitoring_tool": self.monitoring_tool,
}
model_config = SettingsConfigDict(env_file=".env", extra="allow")
@lru_cache
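Since BaseConfig no longer carries to_dict(), callers that need a plain dictionary can fall back on pydantic's model_dump(); a minimal sketch:

from cognee.base_config import get_base_config

base_config = get_base_config()
print(base_config.data_root_directory)
# to_dict() is gone; model_dump() from pydantic BaseSettings covers the same need.
print(base_config.model_dump())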

View file

@ -1,17 +1,19 @@
import os
from contextvars import ContextVar
from typing import Union
from typing import Optional, Union
from uuid import UUID
from cognee.base_config import get_base_config
from cognee.infrastructure.databases.graph.config import GraphConfig
from cognee.infrastructure.databases.vector.config import VectorConfig
from cognee.infrastructure.databases.utils import get_or_create_dataset_database
from cognee.infrastructure.files.storage.config import file_storage_config
from cognee.infrastructure.files.storage.config import StorageConfig, file_storage_config
from cognee.modules.users.methods import get_user
# Note: ContextVar allows us to use different graph db configurations in Cognee
# for different async tasks, threads and processes
vector_db_config = ContextVar("vector_db_config", default=None)
graph_db_config = ContextVar("graph_db_config", default=None)
vector_db_config = ContextVar[Optional[VectorConfig]]("vector_db_config", default=None)
graph_db_config = ContextVar[Optional[GraphConfig]]("graph_db_config", default=None)
async def set_database_global_context_variables(dataset: Union[str, UUID], user_id: UUID):
@ -51,24 +53,24 @@ async def set_database_global_context_variables(dataset: Union[str, UUID], user_
)
# Set vector and graph database configuration based on dataset database information
vector_config = {
"vector_db_url": os.path.join(
databases_directory_path, dataset_database.vector_database_name
vector_config = VectorConfig(
vector_db_url=os.path.join(
databases_directory_path, str(dataset_database.vector_database_name)
),
"vector_db_key": "",
"vector_db_provider": "lancedb",
}
vector_db_key="",
vector_db_provider="lancedb",
)
graph_config = {
"graph_database_provider": "kuzu",
"graph_file_path": os.path.join(
databases_directory_path, dataset_database.graph_database_name
graph_config = GraphConfig(
graph_database_provider="kuzu",
graph_file_path=os.path.join(
databases_directory_path, str(dataset_database.graph_database_name)
),
}
)
storage_config = {
"data_root_directory": data_root_directory,
}
storage_config = StorageConfig(
data_root_directory=data_root_directory,
)
# Use ContextVar so that these graph and vector configurations are used
# in the current async context across Cognee
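The ContextVar behaviour this comment relies on can be demonstrated without cognee: each asyncio task sees only the value set inside its own context, which is what lets different datasets point at different graph and vector databases concurrently. A standalone sketch with illustrative names:

import asyncio
from contextvars import ContextVar
from typing import Optional

# Stand-in for graph_db_config / vector_db_config above.
db_config: ContextVar[Optional[dict]] = ContextVar("db_config", default=None)

async def run_pipeline(dataset: str) -> None:
    # Setting the variable here only affects the current task's context.
    db_config.set({"graph_file_path": f"/tmp/{dataset}.kuzu"})
    await asyncio.sleep(0)  # yield so the tasks interleave
    print(dataset, "->", db_config.get())

async def main() -> None:
    await asyncio.gather(run_pipeline("dataset_a"), run_pipeline("dataset_b"))

asyncio.run(main())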

View file

@ -1,4 +1,5 @@
from functools import lru_cache
from deprecated import deprecated
from pydantic_settings import BaseSettings, SettingsConfigDict
from typing import List, Optional
@ -45,6 +46,7 @@ class EvalConfig(BaseSettings):
model_config = SettingsConfigDict(env_file=".env", extra="allow")
@deprecated(reason="Call model_dump() instead of calling to_dict() method.")
def to_dict(self) -> dict:
return {
"building_corpus_from_scratch": self.building_corpus_from_scratch,

View file

@ -1,7 +1,6 @@
from functools import lru_cache
from pydantic_settings import BaseSettings, SettingsConfigDict
from cognee.infrastructure.data.chunking.DefaultChunkEngine import DefaultChunkEngine
from cognee.shared.data_models import ChunkStrategy, ChunkEngine
@ -12,27 +11,11 @@ class ChunkConfig(BaseSettings):
chunk_size: int = 1500
chunk_overlap: int = 10
chunk_strategy: object = ChunkStrategy.PARAGRAPH
chunk_engine: object = ChunkEngine.DEFAULT_ENGINE
chunk_strategy: ChunkStrategy = ChunkStrategy.PARAGRAPH
chunk_engine: ChunkEngine = ChunkEngine.DEFAULT_ENGINE
model_config = SettingsConfigDict(env_file=".env", extra="allow")
def to_dict(self) -> dict:
"""
Convert the chunk settings to a dictionary format.
Returns:
--------
- dict: A dictionary representation of the chunk configuration settings.
"""
return {
"chunk_size": self.chunk_size,
"chunk_overlap": self.chunk_overlap,
"chunk_strategy": self.chunk_strategy,
"chunk_engine": self.chunk_engine,
}
@lru_cache
def get_chunk_config():

View file

@ -1,25 +1,12 @@
from typing import Dict
from cognee.shared.data_models import ChunkEngine, ChunkStrategy
from cognee.shared.data_models import ChunkEngine
class ChunkingConfig(Dict):
"""
Represent configuration settings for chunking operations, inheriting from the built-in
Dict class. The class contains the following public attributes:
- vector_db_url: A string representing the URL of the vector database.
- vector_db_key: A string representing the key for accessing the vector database.
- vector_db_provider: A string representing the provider of the vector database.
"""
vector_db_url: str
vector_db_key: str
vector_db_provider: str
def create_chunking_engine(config: ChunkingConfig):
def create_chunking_engine(
chunk_size: int,
chunk_overlap: int,
chunk_engine: ChunkEngine,
chunk_strategy: ChunkStrategy,
):
"""
Create a chunking engine based on the provided configuration.
@ -30,7 +17,7 @@ def create_chunking_engine(config: ChunkingConfig):
Parameters:
-----------
- config (ChunkingConfig): Configuration object containing the settings for the
- config (ChunkConfig): Configuration object containing the settings for the
chunking engine, including the engine type, chunk size, chunk overlap, and chunk
strategy.
@ -40,27 +27,27 @@ def create_chunking_engine(config: ChunkingConfig):
An instance of the selected chunking engine class (LangchainChunkEngine,
DefaultChunkEngine, or HaystackChunkEngine).
"""
if config["chunk_engine"] == ChunkEngine.LANGCHAIN_ENGINE:
if chunk_engine == ChunkEngine.LANGCHAIN_ENGINE:
from cognee.infrastructure.data.chunking.LangchainChunkingEngine import LangchainChunkEngine
return LangchainChunkEngine(
chunk_size=config["chunk_size"],
chunk_overlap=config["chunk_overlap"],
chunk_strategy=config["chunk_strategy"],
chunk_size=chunk_size,
chunk_overlap=chunk_overlap,
chunk_strategy=chunk_strategy,
)
elif config["chunk_engine"] == ChunkEngine.DEFAULT_ENGINE:
elif chunk_engine == ChunkEngine.DEFAULT_ENGINE:
from cognee.infrastructure.data.chunking.DefaultChunkEngine import DefaultChunkEngine
return DefaultChunkEngine(
chunk_size=config["chunk_size"],
chunk_overlap=config["chunk_overlap"],
chunk_strategy=config["chunk_strategy"],
chunk_size=chunk_size,
chunk_overlap=chunk_overlap,
chunk_strategy=chunk_strategy,
)
elif config["chunk_engine"] == ChunkEngine.HAYSTACK_ENGINE:
elif chunk_engine == ChunkEngine.HAYSTACK_ENGINE:
from cognee.infrastructure.data.chunking.HaystackChunkEngine import HaystackChunkEngine
return HaystackChunkEngine(
chunk_size=config["chunk_size"],
chunk_overlap=config["chunk_overlap"],
chunk_strategy=config["chunk_strategy"],
chunk_size=chunk_size,
chunk_overlap=chunk_overlap,
chunk_strategy=chunk_strategy,
)

View file

@ -13,4 +13,10 @@ def get_chunk_engine():
Returns an instance of the chunking engine created based on the configuration
settings.
"""
return create_chunking_engine(get_chunk_config().to_dict())
chunk_config = get_chunk_config()
return create_chunking_engine(
chunk_engine=chunk_config.chunk_engine,
chunk_size=chunk_config.chunk_size,
chunk_overlap=chunk_config.chunk_overlap,
chunk_strategy=chunk_config.chunk_strategy,
)
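With ChunkConfig.to_dict() removed, callers read the settings straight off the config object, or use pydantic's model_dump() when a plain dict is genuinely needed; a small sketch:

from cognee.infrastructure.data.chunking.config import get_chunk_config

chunk_config = get_chunk_config()
# Attribute access replaces the old dictionary lookups.
print(chunk_config.chunk_size, chunk_config.chunk_overlap)
print(chunk_config.chunk_strategy, chunk_config.chunk_engine)
print(chunk_config.model_dump())  # pydantic fallback when a plain dict is required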

View file

@ -1,10 +1,11 @@
"""This module contains the configuration for the graph database."""
import os
import pydantic
from typing import Optional
from functools import lru_cache
from pydantic_settings import BaseSettings, SettingsConfigDict
import pydantic
from pydantic import Field
from cognee.base_config import get_base_config
from cognee.shared.data_models import KnowledgeGraph
@ -14,10 +15,6 @@ class GraphConfig(BaseSettings):
Represents the configuration for a graph system, including parameters for graph file
storage and database connections.
Public methods:
- to_dict
- to_hashable_dict
Instance variables:
- graph_filename
- graph_database_provider
@ -31,10 +28,7 @@ class GraphConfig(BaseSettings):
- model_config
"""
# Using Field we are able to dynamically load current GRAPH_DATABASE_PROVIDER value in the model validator part
# and determine default graph db file and path based on this parameter if no values are provided
graph_database_provider: str = Field("kuzu", env="GRAPH_DATABASE_PROVIDER")
graph_database_provider: str = "kuzu"
graph_database_url: str = ""
graph_database_name: str = ""
graph_database_username: str = ""
@ -65,54 +59,6 @@ class GraphConfig(BaseSettings):
return values
def to_dict(self) -> dict:
"""
Return the configuration as a dictionary.
This dictionary contains all the configurations related to the graph, which includes
details for file storage and database connectivity.
Returns:
--------
- dict: A dictionary representation of the configuration settings.
"""
return {
"graph_filename": self.graph_filename,
"graph_database_provider": self.graph_database_provider,
"graph_database_url": self.graph_database_url,
"graph_database_username": self.graph_database_username,
"graph_database_password": self.graph_database_password,
"graph_database_port": self.graph_database_port,
"graph_file_path": self.graph_file_path,
"graph_model": self.graph_model,
"graph_topology": self.graph_topology,
"model_config": self.model_config,
}
def to_hashable_dict(self) -> dict:
"""
Return a hashable dictionary with essential database configuration parameters.
This dictionary excludes certain non-hashable objects and focuses on unique identifiers
for database configurations.
Returns:
--------
- dict: A dictionary representation of the essential database configuration
settings.
"""
return {
"graph_database_provider": self.graph_database_provider,
"graph_database_url": self.graph_database_url,
"graph_database_name": self.graph_database_name,
"graph_database_username": self.graph_database_username,
"graph_database_password": self.graph_database_password,
"graph_database_port": self.graph_database_port,
"graph_file_path": self.graph_file_path,
}
@lru_cache
def get_graph_config():
@ -128,15 +74,18 @@ def get_graph_config():
- GraphConfig: A GraphConfig instance containing the graph configuration settings.
"""
context_config = get_graph_context_config()
if context_config:
return context_config
return GraphConfig()
def get_graph_context_config():
def get_graph_context_config() -> Optional[GraphConfig]:
"""This function will get the appropriate graph db config based on async context.
This allows the use of multiple graph databases for different threads, async tasks and parallelization
"""
from cognee.context_global_variables import graph_db_config
if graph_db_config.get():
return graph_db_config.get()
return get_graph_config().to_hashable_dict()
return graph_db_config.get()
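The same attribute-access style applies to the graph settings now that to_dict() and to_hashable_dict() are gone; get_graph_config() first consults the per-context override and otherwise returns the cached GraphConfig. A minimal sketch:

from cognee.infrastructure.databases.graph.config import get_graph_config

graph_config = get_graph_config()  # context override if set, cached GraphConfig otherwise
print(graph_config.graph_database_provider)
print(graph_config.graph_file_path)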

View file

@ -2,7 +2,7 @@
from functools import lru_cache
from .config import get_graph_context_config
from .config import get_graph_config
from .graph_db_interface import GraphDBInterface
from .supported_databases import supported_databases
@ -10,20 +10,16 @@ from .supported_databases import supported_databases
async def get_graph_engine() -> GraphDBInterface:
"""Factory function to get the appropriate graph client based on the graph type."""
# Get appropriate graph configuration based on current async context
config = get_graph_context_config()
config = get_graph_config()
graph_client = create_graph_engine(**config)
# Async functions can't be cached. After creating and caching the graph engine
# handle all necessary async operations for different graph types below.
# Run any adapter-specific async initialization
if hasattr(graph_client, "initialize"):
await graph_client.initialize()
# Handle loading of graph for NetworkX
if config["graph_database_provider"].lower() == "networkx" and graph_client.graph is None:
await graph_client.load_graph_from_file()
graph_client = create_graph_engine(
graph_database_provider=config.graph_database_provider,
graph_file_path=config.graph_file_path,
graph_database_url=config.graph_database_url,
graph_database_username=config.graph_database_username,
graph_database_password=config.graph_database_password,
graph_database_port=config.graph_database_port,
)
return graph_client
@ -36,7 +32,7 @@ def create_graph_engine(
graph_database_name="",
graph_database_username="",
graph_database_password="",
graph_database_port="",
graph_database_port=None,
):
"""
Create a graph engine based on the specified provider type.

View file

@ -8,9 +8,10 @@ from neo4j import AsyncSession
from neo4j import AsyncGraphDatabase
from neo4j.exceptions import Neo4jError
from contextlib import asynccontextmanager
from typing import Optional, Any, List, Dict, Type, Tuple
from typing import AsyncGenerator, Optional, Any, List, Dict, Type, Tuple
from cognee.infrastructure.engine import DataPoint
from cognee.infrastructure.utils.run_sync import run_sync
from cognee.shared.logging_utils import get_logger, ERROR
from cognee.infrastructure.databases.graph.graph_db_interface import (
GraphDBInterface,
@ -68,16 +69,17 @@ class Neo4jAdapter(GraphDBInterface):
notifications_min_severity="OFF",
)
async def initialize(self) -> None:
"""
Initializes the database: adds uniqueness constraint on id and performs indexing
"""
await self.query(
(f"CREATE CONSTRAINT IF NOT EXISTS FOR (n:`{BASE_LABEL}`) REQUIRE n.id IS UNIQUE;")
run_sync(
self.query(
(f"CREATE CONSTRAINT IF NOT EXISTS FOR (n:`{BASE_LABEL}`) REQUIRE n.id IS UNIQUE;")
)
)
@asynccontextmanager
async def get_session(self) -> AsyncSession:
async def get_session(self) -> AsyncGenerator[AsyncSession]:
"""
Get a session for database operations.
"""

View file

@ -11,6 +11,7 @@ from typing import Dict, Any, List, Union, Type, Tuple
from cognee.infrastructure.databases.exceptions.exceptions import NodesetFilterNotSupportedError
from cognee.infrastructure.files.storage import get_file_storage
from cognee.infrastructure.utils.run_sync import run_sync
from cognee.shared.logging_utils import get_logger
from cognee.infrastructure.databases.graph.graph_db_interface import (
GraphDBInterface,
@ -42,6 +43,8 @@ class NetworkXAdapter(GraphDBInterface):
def __init__(self, filename="cognee_graph.pkl"):
self.filename = filename
run_sync(self.load_graph_from_file())
async def get_graph_data(self):
"""
Retrieve graph data including nodes and edges.
@ -576,7 +579,7 @@ class NetworkXAdapter(GraphDBInterface):
await file_storage.store(file_path, json_data, overwrite=True)
async def load_graph_from_file(self, file_path: str = None):
async def load_graph_from_file(self):
"""
Load graph data asynchronously from a specified file in JSON format.
@ -586,8 +589,8 @@ class NetworkXAdapter(GraphDBInterface):
- file_path (str): The file path from which to load the graph data; if None, loads
from the default filename. (default None)
"""
if not file_path:
file_path = self.filename
file_path = self.filename
try:
file_dir_path = os.path.dirname(file_path)
file_name = os.path.basename(file_path)

View file

@ -17,7 +17,7 @@ from cognee.infrastructure.databases.graph.graph_db_interface import (
EdgeData,
Node,
)
from cognee.infrastructure.databases.vector.embeddings import EmbeddingEngine
from cognee.infrastructure.databases.vector.embeddings.EmbeddingEngine import EmbeddingEngine
from cognee.infrastructure.databases.vector.vector_db_interface import VectorDBInterface
from cognee.infrastructure.engine import DataPoint
@ -80,7 +80,7 @@ class FalkorDBAdapter(VectorDBInterface, GraphDBInterface):
self,
database_url: str,
database_port: int,
embedding_engine=EmbeddingEngine,
embedding_engine: EmbeddingEngine,
):
self.driver = FalkorDB(
host=database_url,
@ -213,7 +213,7 @@ class FalkorDBAdapter(VectorDBInterface, GraphDBInterface):
A tuple containing the query string and parameters dictionary.
"""
node_label = type(data_point).__name__
property_names = DataPoint.get_embeddable_property_names(data_point)
property_names = data_point.get_embeddable_property_names()
properties = {
**data_point.model_dump(),
@ -357,7 +357,7 @@ class FalkorDBAdapter(VectorDBInterface, GraphDBInterface):
vector_map = {}
for data_point in data_points:
property_names = DataPoint.get_embeddable_property_names(data_point)
property_names = data_point.get_embeddable_property_names()
key = str(data_point.id)
vector_map[key] = {}
@ -377,7 +377,7 @@ class FalkorDBAdapter(VectorDBInterface, GraphDBInterface):
vectorized_values[vector_map[str(data_point.id)][property_name]]
if vector_map[str(data_point.id)][property_name] is not None
else None
for property_name in DataPoint.get_embeddable_property_names(data_point)
for property_name in data_point.get_embeddable_property_names()
]
query, params = await self.create_data_point_query(data_point, vectorized_data)

View file

@ -32,26 +32,6 @@ class RelationalConfig(BaseSettings):
return values
def to_dict(self) -> dict:
"""
Return the database configuration as a dictionary.
Returns:
--------
- dict: A dictionary containing database configuration settings including db_path,
db_name, db_host, db_port, db_username, db_password, and db_provider.
"""
return {
"db_path": self.db_path,
"db_name": self.db_name,
"db_host": self.db_host,
"db_port": self.db_port,
"db_username": self.db_username,
"db_password": self.db_password,
"db_provider": self.db_provider,
}
@lru_cache
def get_relational_config():
@ -75,9 +55,6 @@ class MigrationConfig(BaseSettings):
"""
Manage and configure migration settings for a database, inheriting from BaseSettings.
Public methods:
- to_dict: Convert the migration configuration to a dictionary format.
Instance variables:
- migration_db_path: Path to the migration database.
- migration_db_name: Name of the migration database.
@ -98,25 +75,6 @@ class MigrationConfig(BaseSettings):
model_config = SettingsConfigDict(env_file=".env", extra="allow")
def to_dict(self) -> dict:
"""
Convert migration configuration to dictionary format.
Returns:
--------
- dict: A dictionary containing the migration configuration details.
"""
return {
"migration_db_path": self.migration_db_path,
"migration_db_name": self.migration_db_name,
"migration_db_host": self.migration_db_host,
"migration_db_port": self.migration_db_port,
"migration_db_username": self.migration_db_username,
"migration_db_password": self.migration_db_password,
"migration_db_provider": self.migration_db_provider,
}
@lru_cache
def get_migration_config():

View file

@ -18,4 +18,12 @@ def get_relational_engine():
"""
relational_config = get_relational_config()
return create_relational_engine(**relational_config.to_dict())
return create_relational_engine(
db_path=relational_config.db_path,
db_name=relational_config.db_name,
db_host=relational_config.db_host,
db_port=relational_config.db_port,
db_username=relational_config.db_username,
db_password=relational_config.db_password,
db_provider=relational_config.db_provider,
)

View file

@ -295,10 +295,10 @@ class SQLAlchemyAdapter:
storage_config = get_storage_config()
if (
storage_config["data_root_directory"]
storage_config.data_root_directory
in raw_data_location_entities[0].raw_data_location
):
file_storage = get_file_storage(storage_config["data_root_directory"])
file_storage = get_file_storage(storage_config.data_root_directory)
file_path = os.path.basename(raw_data_location_entities[0].raw_data_location)

View file

@ -1,4 +1,5 @@
import os
from typing import Optional
import pydantic
from functools import lru_cache
from pydantic_settings import BaseSettings, SettingsConfigDict
@ -10,10 +11,6 @@ class VectorConfig(BaseSettings):
"""
Manage the configuration settings for the vector database.
Public methods:
- to_dict: Convert the configuration to a dictionary.
Instance variables:
- vector_db_url: The URL of the vector database.
@ -39,22 +36,6 @@ class VectorConfig(BaseSettings):
return values
def to_dict(self) -> dict:
"""
Convert the configuration settings to a dictionary.
Returns:
--------
- dict: A dictionary containing the vector database configuration settings.
"""
return {
"vector_db_url": self.vector_db_url,
"vector_db_port": self.vector_db_port,
"vector_db_key": self.vector_db_key,
"vector_db_provider": self.vector_db_provider,
}
@lru_cache
def get_vectordb_config():
@ -71,13 +52,16 @@ def get_vectordb_config():
- VectorConfig: An instance of `VectorConfig` containing the vector database
configuration.
"""
context_config = get_vectordb_context_config()
if context_config:
return context_config
return VectorConfig()
def get_vectordb_context_config():
def get_vectordb_context_config() -> Optional[VectorConfig]:
"""This function will get the appropriate vector db config based on async context."""
from cognee.context_global_variables import vector_db_config
if vector_db_config.get():
return vector_db_config.get()
return get_vectordb_config().to_dict()
return vector_db_config.get()

View file

@ -1,3 +1,4 @@
from typing import Optional
from .supported_databases import supported_databases
from .embeddings import get_embedding_engine
@ -8,7 +9,7 @@ from functools import lru_cache
def create_vector_engine(
vector_db_provider: str,
vector_db_url: str,
vector_db_port: str = "",
vector_db_port: Optional[int] = None,
vector_db_key: str = "",
):
"""
@ -26,7 +27,7 @@ def create_vector_engine(
-----------
- vector_db_url (str): The URL for the vector database instance.
- vector_db_port (str): The port for the vector database instance. Required for some
- vector_db_port (int): The port for the vector database instance. Required for some
providers.
- vector_db_key (str): The API key or access token for the vector database instance.
- vector_db_provider (str): The name of the vector database provider to use (e.g.,

View file

@ -7,9 +7,6 @@ class EmbeddingConfig(BaseSettings):
"""
Manage configuration settings for embedding operations, including provider, model
details, API configuration, and tokenizer settings.
Public methods:
- to_dict: Serialize the configuration settings to a dictionary.
"""
embedding_provider: Optional[str] = "openai"
@ -22,26 +19,6 @@ class EmbeddingConfig(BaseSettings):
huggingface_tokenizer: Optional[str] = None
model_config = SettingsConfigDict(env_file=".env", extra="allow")
def to_dict(self) -> dict:
"""
Serialize all embedding configuration settings to a dictionary.
Returns:
--------
- dict: A dictionary containing the embedding configuration settings.
"""
return {
"embedding_provider": self.embedding_provider,
"embedding_model": self.embedding_model,
"embedding_dimensions": self.embedding_dimensions,
"embedding_endpoint": self.embedding_endpoint,
"embedding_api_key": self.embedding_api_key,
"embedding_api_version": self.embedding_api_version,
"embedding_max_tokens": self.embedding_max_tokens,
"huggingface_tokenizer": self.huggingface_tokenizer,
}
@lru_cache
def get_embedding_config():

View file

@ -1,7 +1,14 @@
from .config import get_vectordb_context_config
from .config import get_vectordb_config
from .create_vector_engine import create_vector_engine
def get_vector_engine():
# Get appropriate vector db configuration based on current async context
return create_vector_engine(**get_vectordb_context_config())
vector_config = get_vectordb_config()
return create_vector_engine(
vector_db_provider=vector_config.vector_db_provider,
vector_db_url=vector_config.vector_db_url,
vector_db_port=vector_config.vector_db_port,
vector_db_key=vector_config.vector_db_key,
)

View file

@ -1,12 +1,14 @@
from sqlalchemy import text
from ..get_vector_engine import get_vector_engine, get_vectordb_context_config
from ..config import get_vectordb_config
from ..get_vector_engine import get_vector_engine
async def create_db_and_tables():
# Get appropriate vector db configuration based on current async context
vector_config = get_vectordb_context_config()
vector_config = get_vectordb_config()
vector_engine = get_vector_engine()
if vector_config["vector_db_provider"] == "pgvector":
if vector_config.vector_db_provider == "pgvector":
async with vector_engine.engine.begin() as connection:
await connection.execute(text("CREATE EXTENSION IF NOT EXISTS vector;"))

View file

@ -1,12 +1,10 @@
import pickle
from uuid import UUID, uuid4
from pydantic import BaseModel, Field, ConfigDict
from datetime import datetime, timezone
from typing import Optional, List
from typing_extensions import TypedDict
from typing import Optional, Any, Dict, List
from datetime import datetime, timezone
from pydantic import BaseModel, Field, ConfigDict
# Define metadata type
class MetaData(TypedDict):
"""
Represent a metadata structure with type and index fields.
@ -16,7 +14,6 @@ class MetaData(TypedDict):
index_fields: list[str]
# Updated DataPoint model with versioning and new fields
class DataPoint(BaseModel):
"""
Model representing a data point with versioning and metadata support.
@ -26,12 +23,6 @@ class DataPoint(BaseModel):
- get_embeddable_properties
- get_embeddable_property_names
- update_version
- to_json
- from_json
- to_pickle
- from_pickle
- to_dict
- from_dict
"""
model_config = ConfigDict(arbitrary_types_allowed=True)
@ -85,52 +76,39 @@ class DataPoint(BaseModel):
return attribute.strip()
return attribute
@classmethod
def get_embeddable_properties(self, data_point: "DataPoint"):
def get_embeddable_properties(self):
"""
Retrieve a list of embeddable properties from the data point.
This method returns a list of attribute values based on the index fields defined in the
data point's metadata. If there are no index fields, it returns an empty list.
Parameters:
-----------
- data_point ('DataPoint'): The DataPoint instance from which to retrieve embeddable
properties.
Returns:
--------
A list of embeddable property values, or an empty list if none exist.
"""
if data_point.metadata and len(data_point.metadata["index_fields"]) > 0:
return [
getattr(data_point, field, None) for field in data_point.metadata["index_fields"]
]
if self.metadata and len(self.metadata["index_fields"]) > 0:
return [getattr(self, field, None) for field in self.metadata["index_fields"]]
return []
@classmethod
def get_embeddable_property_names(self, data_point: "DataPoint"):
def get_embeddable_property_names(self):
"""
Retrieve the names of embeddable properties defined in the metadata.
If no index fields are defined in the metadata, this method will return an empty list.
Parameters:
-----------
- data_point ('DataPoint'): The DataPoint instance from which to retrieve embeddable
property names.
Returns:
--------
A list of property names corresponding to the index fields, or an empty list if none
exist.
"""
return data_point.metadata["index_fields"] or []
if self.metadata:
return self.metadata["index_fields"] or []
return []
def update_version(self):
"""
@ -141,80 +119,3 @@ class DataPoint(BaseModel):
"""
self.version += 1
self.updated_at = int(datetime.now(timezone.utc).timestamp() * 1000)
# JSON Serialization
def to_json(self) -> str:
"""
Serialize the DataPoint instance to a JSON string format.
This method uses the model's built-in serialization functionality to convert the
instance into a JSON-compatible string.
Returns:
--------
- str: The JSON string representation of the DataPoint instance.
"""
return self.json()
@classmethod
def from_json(self, json_str: str):
"""
Deserialize a DataPoint instance from a JSON string.
The method transforms the input JSON string back into a DataPoint instance using model
validation.
Parameters:
-----------
- json_str (str): The JSON string representation of a DataPoint instance to be
deserialized.
Returns:
--------
A new DataPoint instance created from the JSON data.
"""
return self.model_validate_json(json_str)
def to_dict(self, **kwargs) -> Dict[str, Any]:
"""
Convert the DataPoint instance to a dictionary representation.
This method uses the model's built-in functionality to serialize the instance attributes
to a dictionary, which can optionally include additional arguments.
Parameters:
-----------
- **kwargs: Additional keyword arguments for serialization options.
Returns:
--------
- Dict[str, Any]: A dictionary representation of the DataPoint instance.
"""
return self.model_dump(**kwargs)
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "DataPoint":
"""
Instantiate a DataPoint from a dictionary of attribute values.
The method validates the incoming dictionary data against the model's schema and
constructs a new DataPoint instance accordingly.
Parameters:
-----------
- data (Dict[str, Any]): A dictionary containing the attributes of a DataPoint
instance.
Returns:
--------
- 'DataPoint': A new DataPoint instance constructed from the provided dictionary
data.
"""
return cls.model_validate(data)
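The embeddable-property helpers shown earlier in this diff are now plain instance methods rather than classmethods taking a data_point argument. A minimal sketch of the new call style, assuming DataPoint's remaining fields keep their defaults and that metadata can be supplied at construction (the Document subclass and its text field are illustrative):

from cognee.infrastructure.engine import DataPoint

class Document(DataPoint):
    text: str

doc = Document(
    text="hello world",
    metadata={"index_fields": ["text"], "type": "Document"},
)

# Called on the instance directly; index_fields drives both helpers.
print(doc.get_embeddable_property_names())  # ["text"]
print(doc.get_embeddable_properties())      # ["hello world"]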

View file

@ -1,4 +1,16 @@
from typing import Optional
from contextvars import ContextVar
from pydantic_settings import BaseSettings, SettingsConfigDict
file_storage_config = ContextVar("file_storage_config", default=None)
class StorageConfig(BaseSettings):
"""
Manage configuration settings for file storage.
"""
data_root_directory: str = ""
model_config = SettingsConfigDict(env_file=".env", extra="allow")
file_storage_config = ContextVar[Optional[StorageConfig]]("file_storage_config", default=None)
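A short sketch of overriding the storage root through the new StorageConfig model and ContextVar; the path is illustrative, and resetting the token restores the previous value:

from cognee.infrastructure.files.storage.config import StorageConfig, file_storage_config

token = file_storage_config.set(StorageConfig(data_root_directory="/tmp/cognee_data"))
try:
    pass  # code here resolves the overridden root via get_storage_config()
finally:
    file_storage_config.reset(token)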

View file

@ -1,17 +1,18 @@
from cognee.base_config import get_base_config
from .config import file_storage_config
from .config import file_storage_config, StorageConfig
def get_global_storage_config():
base_config = get_base_config()
return {
"data_root_directory": base_config.data_root_directory,
}
return StorageConfig(
data_root_directory=base_config.data_root_directory,
)
def get_storage_config():
context_config = file_storage_config.get()
if context_config:
return context_config

View file

@ -30,7 +30,6 @@ class LLMConfig(BaseSettings):
Public methods include:
- ensure_env_vars_for_ollama
- to_dict
"""
structured_output_framework: str = "instructor"
@ -153,38 +152,6 @@ class LLMConfig(BaseSettings):
return self
def to_dict(self) -> dict:
"""
Convert the LLMConfig instance into a dictionary representation.
Returns:
--------
- dict: A dictionary containing the configuration settings of the LLMConfig
instance.
"""
return {
"provider": self.llm_provider,
"model": self.llm_model,
"endpoint": self.llm_endpoint,
"api_key": self.llm_api_key,
"api_version": self.llm_api_version,
"temperature": self.llm_temperature,
"streaming": self.llm_streaming,
"max_tokens": self.llm_max_tokens,
"transcription_model": self.transcription_model,
"graph_prompt_path": self.graph_prompt_path,
"rate_limit_enabled": self.llm_rate_limit_enabled,
"rate_limit_requests": self.llm_rate_limit_requests,
"rate_limit_interval": self.llm_rate_limit_interval,
"embedding_rate_limit_enabled": self.embedding_rate_limit_enabled,
"embedding_rate_limit_requests": self.embedding_rate_limit_requests,
"embedding_rate_limit_interval": self.embedding_rate_limit_interval,
"fallback_api_key": self.fallback_api_key,
"fallback_endpoint": self.fallback_endpoint,
"fallback_model": self.fallback_model,
}
@lru_cache
def get_llm_config():

View file

@ -1,8 +1,6 @@
from functools import lru_cache
from pydantic_settings import BaseSettings, SettingsConfigDict
from cognee.shared.data_models import DefaultContentPrediction, SummarizedContent
from typing import Optional
import os
class CognifyConfig(BaseSettings):
@ -10,12 +8,6 @@ class CognifyConfig(BaseSettings):
summarization_model: object = SummarizedContent
model_config = SettingsConfigDict(env_file=".env", extra="allow")
def to_dict(self) -> dict:
return {
"classification_model": self.classification_model,
"summarization_model": self.summarization_model,
}
@lru_cache
def get_cognify_config():

View file

@ -3,5 +3,5 @@ from cognee.infrastructure.files.storage import get_file_storage, get_storage_co
async def prune_data():
storage_config = get_storage_config()
data_root_directory = storage_config["data_root_directory"]
data_root_directory = storage_config.data_root_directory
await get_file_storage(data_root_directory).remove_all()

View file

@ -39,16 +39,3 @@ class Data(Base):
lazy="noload",
cascade="all, delete",
)
def to_json(self) -> dict:
return {
"id": str(self.id),
"name": self.name,
"extension": self.extension,
"mimeType": self.mime_type,
"rawDataLocation": self.raw_data_location,
"createdAt": self.created_at.isoformat(),
"updatedAt": self.updated_at.isoformat() if self.updated_at else None,
"nodeSet": self.node_set,
# "datasets": [dataset.to_json() for dataset in self.datasets]
}

View file

@ -28,13 +28,3 @@ class Dataset(Base):
lazy="noload",
cascade="all, delete",
)
def to_json(self) -> dict:
return {
"id": str(self.id),
"name": self.name,
"createdAt": self.created_at.isoformat(),
"updatedAt": self.updated_at.isoformat() if self.updated_at else None,
"ownerId": str(self.owner_id),
"data": [data.to_json() for data in self.data],
}

View file

@ -27,14 +27,3 @@ class GraphRelationshipLedger(Base):
Index("idx_graph_relationship_ledger_source_node_id", "source_node_id"),
Index("idx_graph_relationship_ledger_destination_node_id", "destination_node_id"),
)
def to_json(self) -> dict:
return {
"id": str(self.id),
"source_node_id": str(self.parent_id),
"destination_node_id": str(self.child_id),
"creator_function": self.creator_function,
"created_at": self.created_at.isoformat(),
"deleted_at": self.deleted_at.isoformat() if self.deleted_at else None,
"user_id": str(self.user_id),
}

View file

@ -7,7 +7,7 @@ from .classify import classify
async def save_data_to_file(data: Union[str, BinaryIO], filename: str = None):
storage_config = get_storage_config()
data_root_directory = storage_config["data_root_directory"]
data_root_directory = storage_config.data_root_directory
classified_data = classify(data, filename)

View file

@ -0,0 +1,36 @@
def log_cognee_configuration(logger):
"""Log the current database configuration for all database types"""
# NOTE: Has to be imported at runtime to avoid circular import
from cognee.infrastructure.databases.relational.config import get_relational_config
from cognee.infrastructure.databases.vector.config import get_vectordb_config
from cognee.infrastructure.databases.graph.config import get_graph_config
try:
# Log relational database configuration
relational_config = get_relational_config()
logger.info(f"Relational database: {relational_config.db_provider}")
if relational_config.db_provider == "postgres":
logger.info(f"Postgres host: {relational_config.db_host}:{relational_config.db_port}")
logger.info(f"Postgres database: {relational_config.db_name}")
elif relational_config.db_provider == "sqlite":
logger.info(f"SQLite path: {relational_config.db_path}/{relational_config.db_name}")
logger.info(f"SQLite database: {relational_config.db_name}")
# Log vector database configuration
vector_config = get_vectordb_config()
logger.info(f"Vector database: {vector_config.vector_db_provider}")
if vector_config.vector_db_provider == "lancedb":
logger.info(f"Vector database path: {vector_config.vector_db_url}")
else:
logger.info(f"Vector database URL: {vector_config.vector_db_url}")
# Log graph database configuration
graph_config = get_graph_config()
logger.info(f"Graph database: {graph_config.graph_database_provider}")
if graph_config.graph_database_provider == "kuzu":
logger.info(f"Graph database path: {graph_config.graph_file_path}")
else:
logger.info(f"Graph database URL: {graph_config.graph_database_url}")
except Exception as e:
logger.warning(f"Could not retrieve database configuration: {str(e)}")

View file

@ -1,7 +1,9 @@
import asyncio
from uuid import UUID
from typing import Union
from typing import Optional, Union
from cognee.infrastructure.databases.graph.config import GraphConfig
from cognee.infrastructure.databases.vector.config import VectorConfig
from cognee.shared.logging_utils import get_logger
from cognee.modules.data.methods.get_dataset_data import get_dataset_data
from cognee.modules.data.models import Data, Dataset

View file

@ -38,7 +38,7 @@ def get_current_settings() -> SettingsDict:
vector_config = get_vectordb_config()
relational_config = get_relational_config()
return dict(
return SettingsDict(
llm={
"provider": llm_config.llm_provider,
"model": llm_config.llm_model,

View file

@ -4,12 +4,12 @@ import logging
import structlog
import traceback
import platform
from datetime import datetime
from typing import Protocol
from pathlib import Path
import importlib.metadata
from datetime import datetime
from cognee import __version__ as cognee_version
from typing import Protocol
from cognee.modules.logging.log_cognee_configuration import log_cognee_configuration
# Configure external library logging
@ -165,44 +165,6 @@ def get_logger(name=None, level=None) -> LoggerInterface:
return logger
def log_database_configuration(logger):
"""Log the current database configuration for all database types"""
# NOTE: Has to be imported at runtime to avoid circular import
from cognee.infrastructure.databases.relational.config import get_relational_config
from cognee.infrastructure.databases.vector.config import get_vectordb_config
from cognee.infrastructure.databases.graph.config import get_graph_config
try:
# Log relational database configuration
relational_config = get_relational_config()
logger.info(f"Relational database: {relational_config.db_provider}")
if relational_config.db_provider == "postgres":
logger.info(f"Postgres host: {relational_config.db_host}:{relational_config.db_port}")
logger.info(f"Postgres database: {relational_config.db_name}")
elif relational_config.db_provider == "sqlite":
logger.info(f"SQLite path: {relational_config.db_path}")
logger.info(f"SQLite database: {relational_config.db_name}")
# Log vector database configuration
vector_config = get_vectordb_config()
logger.info(f"Vector database: {vector_config.vector_db_provider}")
if vector_config.vector_db_provider == "lancedb":
logger.info(f"Vector database path: {vector_config.vector_db_url}")
else:
logger.info(f"Vector database URL: {vector_config.vector_db_url}")
# Log graph database configuration
graph_config = get_graph_config()
logger.info(f"Graph database: {graph_config.graph_database_provider}")
if graph_config.graph_database_provider == "kuzu":
logger.info(f"Graph database path: {graph_config.graph_file_path}")
else:
logger.info(f"Graph database URL: {graph_config.graph_database_url}")
except Exception as e:
logger.warning(f"Could not retrieve database configuration: {str(e)}")
def cleanup_old_logs(logs_dir, max_files):
"""
Removes old log files, keeping only the most recent ones.
@ -396,9 +358,8 @@ def setup_logging(log_level=None, name=None):
)
logger.info("Want to learn more? Visit the Cognee documentation: https://docs.cognee.ai")
# Log database configuration
log_database_configuration(logger)
log_cognee_configuration(logger)
# Return the configured logger
return logger

View file

@ -1,18 +1,13 @@
"""This module contains utility functions for the cognee."""
import os
import pathlib
import requests
from datetime import datetime, timezone
import networkx as nx
import matplotlib.pyplot as plt
import http.server
import socketserver
from threading import Thread
import pathlib
from uuid import uuid4
from cognee.base_config import get_base_config
from cognee.infrastructure.databases.graph import get_graph_engine
from threading import Thread
from datetime import datetime, timezone
# Analytics Proxy Url, currently hosted by Vercel

View file

@ -158,7 +158,7 @@ async def main():
assert len(history) == 8, "Search history is not correct."
await cognee.prune.prune_data()
data_root_directory = get_storage_config()["data_root_directory"]
data_root_directory = get_storage_config().data_root_directory
assert not os.path.isdir(data_root_directory), "Local data files are not deleted"
await cognee.prune.prune_system(metadata=True)

View file

@ -121,7 +121,7 @@ async def main():
# Assert local data files are cleaned properly
await cognee.prune.prune_data()
data_root_directory = get_storage_config()["data_root_directory"]
data_root_directory = get_storage_config().data_root_directory
assert not os.path.isdir(data_root_directory), "Local data files are not deleted"
# Assert relational, vector and graph databases have been cleaned properly

View file

@ -113,7 +113,7 @@ async def main():
)
await cognee.prune.prune_data()
data_root_directory = get_storage_config()["data_root_directory"]
data_root_directory = get_storage_config().data_root_directory
assert not os.path.isdir(data_root_directory), "Local data files are not deleted"
await cognee.prune.prune_system(metadata=True)

View file

@ -79,7 +79,7 @@ async def main():
# Assert local data files are cleaned properly
await cognee.prune.prune_data()
data_root_directory = get_storage_config()["data_root_directory"]
data_root_directory = get_storage_config().data_root_directory
assert not os.path.isdir(data_root_directory), "Local data files are not deleted"
# Assert relational, vector and graph databases have been cleaned properly

View file

@ -92,7 +92,7 @@ async def main():
assert len(history) == 8, "Search history is not correct."
await cognee.prune.prune_data()
data_root_directory = get_storage_config()["data_root_directory"]
data_root_directory = get_storage_config().data_root_directory
assert not os.path.isdir(data_root_directory), "Local data files are not deleted"
await cognee.prune.prune_system(metadata=True)

View file

@ -117,7 +117,7 @@ async def main():
)
await cognee.prune.prune_data()
data_root_directory = get_storage_config()["data_root_directory"]
data_root_directory = get_storage_config().data_root_directory
assert not os.path.isdir(data_root_directory), "Local data files are not deleted"
await cognee.prune.prune_system(metadata=True)

View file

@ -167,7 +167,7 @@ async def main():
await test_local_file_deletion(text, explanation_file_path)
await cognee.prune.prune_data()
data_root_directory = get_storage_config()["data_root_directory"]
data_root_directory = get_storage_config().data_root_directory
assert not os.path.isdir(data_root_directory), "Local data files are not deleted"
await cognee.prune.prune_system(metadata=True)

View file

@ -84,7 +84,7 @@ async def main():
assert len(history) == 6, "Search history is not correct."
await cognee.prune.prune_data()
data_root_directory = get_storage_config()["data_root_directory"]
data_root_directory = get_storage_config().data_root_directory
assert not os.path.isdir(data_root_directory), "Local data files are not deleted"
await cognee.prune.prune_system(metadata=True)

View file

@ -93,7 +93,7 @@ async def main():
assert len(history) == 6, "Search history is not correct."
await cognee.prune.prune_data()
data_root_directory = get_storage_config()["data_root_directory"]
data_root_directory = get_storage_config().data_root_directory
assert not os.path.isdir(data_root_directory), "Local data files are not deleted"
await cognee.prune.prune_system(metadata=True)

View file

@ -75,7 +75,7 @@ async def main():
# Assert local data files are cleaned properly
await cognee.prune.prune_data()
data_root_directory = get_storage_config()["data_root_directory"]
data_root_directory = get_storage_config().data_root_directory
assert not os.path.isdir(data_root_directory), "Local data files are not deleted"
# Assert relational, vector and graph databases have been cleaned properly