fix: set config before using it

This commit is contained in:
Boris Arzentar 2024-05-27 14:18:39 +02:00
parent 624a0ac873
commit aef78c4a8f
42 changed files with 220 additions and 503 deletions

View file

@ -1,4 +1,4 @@
export default function TextLogo({ width = 285, height = 81, color = 'currentColor' }) { export default function TextLogo({ width = 285, height = 81, color = 'white' }) {
return ( return (
<svg width={width} height={height} viewBox="0 0 285 81" fill="none" xmlns="http://www.w3.org/2000/svg"> <svg width={width} height={height} viewBox="0 0 285 81" fill="none" xmlns="http://www.w3.org/2000/svg">
<path d="M42.0964 46.4597C39.6678 49.6212 36.2632 51.8922 32.4114 52.92C28.5596 53.9479 24.4762 53.6749 20.7954 52.1436C17.1147 50.6123 14.0426 47.9083 12.0565 44.4517C10.0704 40.9951 9.2813 36.9793 9.81189 33.0282C10.3425 29.0771 12.163 25.4118 14.9907 22.6016C17.8184 19.7914 21.4949 17.9937 25.4493 17.4877C29.4036 16.9816 33.4144 17.7956 36.8586 19.8032" stroke={color} strokeWidth="6.03175" strokeLinecap="round"/> <path d="M42.0964 46.4597C39.6678 49.6212 36.2632 51.8922 32.4114 52.92C28.5596 53.9479 24.4762 53.6749 20.7954 52.1436C17.1147 50.6123 14.0426 47.9083 12.0565 44.4517C10.0704 40.9951 9.2813 36.9793 9.81189 33.0282C10.3425 29.0771 12.163 25.4118 14.9907 22.6016C17.8184 19.7914 21.4949 17.9937 25.4493 17.4877C29.4036 16.9816 33.4144 17.7956 36.8586 19.8032" stroke={color} strokeWidth="6.03175" strokeLinecap="round"/>

View file

@ -1,6 +1,7 @@
.wizardContent { .wizardContent {
width: 100%; width: 100%;
max-width: 400px; max-width: 400px;
height: max-content;
background: linear-gradient(90deg, #D82EB5 0.52%, #9245FD 103.83%); background: linear-gradient(90deg, #D82EB5 0.52%, #9245FD 103.83%);
padding: 24px; padding: 24px;
margin: 0 auto; margin: 0 auto;

View file

@ -1,6 +1,6 @@
from .api.v1.config.config import config # from .api.v1.config.config import config
from .api.v1.add.add import add # from .api.v1.add.add import add
from .api.v1.cognify.cognify import cognify # from .api.v1.cognify.cognify import cognify
from .api.v1.datasets.datasets import datasets # from .api.v1.datasets.datasets import datasets
from .api.v1.search.search import search, SearchType # from .api.v1.search.search import search, SearchType
from .api.v1.prune import prune # from .api.v1.prune import prune

View file

@ -19,12 +19,7 @@ logging.basicConfig(
) )
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
from cognee.config import Config app = FastAPI(debug = True)
config = Config()
config.load()
app = FastAPI(debug=True)
origins = [ origins = [
"http://frontend:3000", "http://frontend:3000",
@ -59,12 +54,12 @@ class Payload(BaseModel):
@app.get("/datasets", response_model=list) @app.get("/datasets", response_model=list)
async def get_datasets(): async def get_datasets():
from cognee import datasets from cognee.api.v1.datasets.datasets import datasets
return datasets.list_datasets() return datasets.list_datasets()
@app.delete("/datasets/{dataset_id}", response_model=dict) @app.delete("/datasets/{dataset_id}", response_model=dict)
async def delete_dataset(dataset_id: str): async def delete_dataset(dataset_id: str):
from cognee import datasets from cognee.api.v1.datasets.datasets import datasets
datasets.delete_dataset(dataset_id) datasets.delete_dataset(dataset_id)
return JSONResponse( return JSONResponse(
status_code=200, status_code=200,
@ -73,22 +68,23 @@ async def delete_dataset(dataset_id: str):
@app.get("/datasets/{dataset_id}/graph", response_model=list) @app.get("/datasets/{dataset_id}/graph", response_model=list)
async def get_dataset_graph(dataset_id: str): async def get_dataset_graph(dataset_id: str):
from cognee import utils from cognee.utils import render_graph
from cognee.infrastructure import infrastructure_config from cognee.infrastructure.databases.graph import get_graph_config
from cognee.infrastructure.databases.graph.get_graph_client import get_graph_client from cognee.infrastructure.databases.graph.get_graph_client import get_graph_client
graph_engine = infrastructure_config.get_config()["graph_engine"] graph_config = get_graph_config()
graph_engine = graph_config.graph_engine
graph_client = await get_graph_client(graph_engine) graph_client = await get_graph_client(graph_engine)
graph_url = await utils.render_graph(graph_client.graph) graph_url = await render_graph(graph_client.graph)
return JSONResponse( return JSONResponse(
status_code=200, status_code = 200,
content=str(graph_url), content = str(graph_url),
) )
@app.get("/datasets/{dataset_id}/data", response_model=list) @app.get("/datasets/{dataset_id}/data", response_model=list)
async def get_dataset_data(dataset_id: str): async def get_dataset_data(dataset_id: str):
from cognee import datasets from cognee.api.v1.datasets.datasets import datasets
dataset_data = datasets.list_data(dataset_id) dataset_data = datasets.list_data(dataset_id)
if dataset_data is None: if dataset_data is None:
raise HTTPException(status_code=404, detail=f"Dataset ({dataset_id}) not found.") raise HTTPException(status_code=404, detail=f"Dataset ({dataset_id}) not found.")
@ -105,7 +101,7 @@ async def get_dataset_data(dataset_id: str):
@app.get("/datasets/status", response_model=dict) @app.get("/datasets/status", response_model=dict)
async def get_dataset_status(datasets: Annotated[List[str], Query(alias="dataset")] = None): async def get_dataset_status(datasets: Annotated[List[str], Query(alias="dataset")] = None):
from cognee import datasets as cognee_datasets from cognee.api.v1.datasets.datasets import datasets as cognee_datasets
datasets_statuses = cognee_datasets.get_status(datasets) datasets_statuses = cognee_datasets.get_status(datasets)
return JSONResponse( return JSONResponse(
@ -115,7 +111,7 @@ async def get_dataset_status(datasets: Annotated[List[str], Query(alias="dataset
@app.get("/datasets/{dataset_id}/data/{data_id}/raw", response_class=FileResponse) @app.get("/datasets/{dataset_id}/data/{data_id}/raw", response_class=FileResponse)
async def get_raw_data(dataset_id: str, data_id: str): async def get_raw_data(dataset_id: str, data_id: str):
from cognee import datasets from cognee.api.v1.datasets.datasets import datasets
dataset_data = datasets.list_data(dataset_id) dataset_data = datasets.list_data(dataset_id)
if dataset_data is None: if dataset_data is None:
raise HTTPException(status_code=404, detail=f"Dataset ({dataset_id}) not found.") raise HTTPException(status_code=404, detail=f"Dataset ({dataset_id}) not found.")
@ -134,7 +130,7 @@ async def add(
data: List[UploadFile] = File(...), data: List[UploadFile] = File(...),
): ):
""" This endpoint is responsible for adding data to the graph.""" """ This endpoint is responsible for adding data to the graph."""
from cognee import add as cognee_add from cognee.api.v1.add import add as cognee_add
try: try:
if isinstance(data, str) and data.startswith("http"): if isinstance(data, str) and data.startswith("http"):
if "github" in data: if "github" in data:
@ -178,7 +174,7 @@ class CognifyPayload(BaseModel):
@app.post("/cognify", response_model=dict) @app.post("/cognify", response_model=dict)
async def cognify(payload: CognifyPayload): async def cognify(payload: CognifyPayload):
""" This endpoint is responsible for the cognitive processing of the content.""" """ This endpoint is responsible for the cognitive processing of the content."""
from cognee import cognify as cognee_cognify from cognee.api.v1.cognify.cognify import cognify as cognee_cognify
try: try:
await cognee_cognify(payload.datasets) await cognee_cognify(payload.datasets)
return JSONResponse( return JSONResponse(
@ -197,7 +193,7 @@ class SearchPayload(BaseModel):
@app.post("/search", response_model=dict) @app.post("/search", response_model=dict)
async def search(payload: SearchPayload): async def search(payload: SearchPayload):
""" This endpoint is responsible for searching for nodes in the graph.""" """ This endpoint is responsible for searching for nodes in the graph."""
from cognee import search as cognee_search from cognee.api.v1.search import search as cognee_search
try: try:
search_type = payload.query_params["searchType"] search_type = payload.query_params["searchType"]
params = { params = {
@ -254,17 +250,21 @@ def start_api_server(host: str = "0.0.0.0", port: int = 8000):
port (int): The port for the server. port (int): The port for the server.
""" """
try: try:
logger.info(f"Starting server at {host}:{port}") logger.info("Starting server at %s:%s", host, port)
from cognee import config, prune
from cognee.base_config import get_base_config
base_config = get_base_config()
data_directory_path = os.path.abspath(".data_storage") data_directory_path = os.path.abspath(".data_storage")
config.data_root_directory(data_directory_path) base_config.data_root_directory = data_directory_path
cognee_directory_path = os.path.abspath(".cognee_system") cognee_directory_path = os.path.abspath(".cognee_system")
config.system_root_directory(cognee_directory_path) base_config.system_root_directory = cognee_directory_path
asyncio.run(prune.prune_system()) from cognee.modules.data.deletion import prune_system
asyncio.run(prune_system())
uvicorn.run(app, host=host, port=port) uvicorn.run(app, host = host, port = port)
except Exception as e: except Exception as e:
logger.exception(f"Failed to start server: {e}") logger.exception(f"Failed to start server: {e}")
# Here you could add any cleanup code or error recovery code. # Here you could add any cleanup code or error recovery code.

View file

@ -4,7 +4,6 @@ import asyncio
import dlt import dlt
import duckdb import duckdb
import cognee.modules.ingestion as ingestion import cognee.modules.ingestion as ingestion
from cognee.infrastructure import infrastructure_config
from cognee.infrastructure.files.storage import LocalStorage from cognee.infrastructure.files.storage import LocalStorage
from cognee.modules.discovery import discover_directory_datasets from cognee.modules.discovery import discover_directory_datasets
from cognee.utils import send_telemetry from cognee.utils import send_telemetry
@ -12,9 +11,6 @@ from cognee.base_config import get_base_config
base_config = get_base_config() base_config = get_base_config()
from cognee.infrastructure.databases.relational.config import get_relationaldb_config from cognee.infrastructure.databases.relational.config import get_relationaldb_config
relational_config = get_relationaldb_config()
async def add(data: Union[BinaryIO, List[BinaryIO], str, List[str]], dataset_name: str = None): async def add(data: Union[BinaryIO, List[BinaryIO], str, List[str]], dataset_name: str = None):
if isinstance(data, str): if isinstance(data, str):
# data is a data directory path # data is a data directory path
@ -54,8 +50,6 @@ async def add_files(file_paths: List[str], dataset_name: str):
# infra_config = infrastructure_config.get_config() # infra_config = infrastructure_config.get_config()
data_directory_path = base_config.data_root_directory data_directory_path = base_config.data_root_directory
LocalStorage.ensure_directory_exists(relational_config.database_directory_path)
processed_file_paths = [] processed_file_paths = []
for file_path in file_paths: for file_path in file_paths:
@ -73,6 +67,7 @@ async def add_files(file_paths: List[str], dataset_name: str):
else: else:
processed_file_paths.append(file_path) processed_file_paths.append(file_path)
relational_config = get_relationaldb_config()
db = duckdb.connect(relational_config.db_file_path) db = duckdb.connect(relational_config.db_file_path)
destination = dlt.destinations.duckdb( destination = dlt.destinations.duckdb(

View file

@ -1,55 +0,0 @@
import asyncio
from uuid import UUID, uuid4
from typing import Union, BinaryIO, List
import cognee.modules.ingestion as ingestion
from cognee.infrastructure import infrastructure_config
from cognee.infrastructure.databases.relational.config import get_relationaldb_config
relational_config = get_relationaldb_config()
class DatasetException(Exception):
message: str
def __init__(self, message: str):
self.message = message
async def add_standalone(
data: Union[str, BinaryIO, List[Union[str, BinaryIO]]],
dataset_id: UUID = uuid4(),
dataset_name: str = None
):
db_engine = relational_config.database_engine
if db_engine.is_db_done is not True:
await db_engine.ensure_tables()
if not data:
raise DatasetException("Data must be provided to cognee.add(data: str)")
if isinstance(data, list):
promises = []
for data_item in data:
promises.append(add_standalone(data_item, dataset_id, dataset_name))
results = await asyncio.gather(*promises)
return results
if is_data_path(data):
with open(data.replace("file://", ""), "rb") as file:
return await add_standalone(file, dataset_id, dataset_name)
classified_data = ingestion.classify(data)
data_id = ingestion.identify(classified_data)
await ingestion.save(dataset_id, dataset_name, data_id, classified_data)
return dataset_id
# await ingestion.vectorize(dataset_id, dataset_name, data_id, classified_data)
def is_data_path(data: str) -> bool:
return False if not isinstance(data, str) else data.startswith("file://")

View file

@ -5,9 +5,7 @@ import logging
import nltk import nltk
from nltk.corpus import stopwords from nltk.corpus import stopwords
from cognee.config import Config from cognee.config import Config
from cognee.infrastructure.data.chunking.LangchainChunkingEngine import LangchainChunkEngine
from cognee.infrastructure.databases.graph.config import get_graph_config from cognee.infrastructure.databases.graph.config import get_graph_config
from cognee.infrastructure.databases.vector.embeddings.DefaultEmbeddingEngine import LiteLLMEmbeddingEngine
from cognee.modules.cognify.graph.add_node_connections import group_nodes_by_layer, \ from cognee.modules.cognify.graph.add_node_connections import group_nodes_by_layer, \
graph_ready_output, connect_nodes_in_graph graph_ready_output, connect_nodes_in_graph
from cognee.modules.cognify.graph.add_data_chunks import add_data_chunks, add_data_chunks_basic_rag from cognee.modules.cognify.graph.add_data_chunks import add_data_chunks, add_data_chunks_basic_rag
@ -21,38 +19,23 @@ from cognee.modules.cognify.graph.add_cognitive_layers import add_cognitive_laye
# from cognee.modules.cognify.graph.initialize_graph import initialize_graph # from cognee.modules.cognify.graph.initialize_graph import initialize_graph
from cognee.infrastructure.files.utils.guess_file_type import guess_file_type, FileTypeException from cognee.infrastructure.files.utils.guess_file_type import guess_file_type, FileTypeException
from cognee.infrastructure.files.utils.extract_text_from_file import extract_text_from_file from cognee.infrastructure.files.utils.extract_text_from_file import extract_text_from_file
from cognee.infrastructure import infrastructure_config
from cognee.modules.data.get_content_categories import get_content_categories from cognee.modules.data.get_content_categories import get_content_categories
from cognee.modules.data.get_content_summary import get_content_summary from cognee.modules.data.get_content_summary import get_content_summary
from cognee.modules.data.get_cognitive_layers import get_cognitive_layers from cognee.modules.data.get_cognitive_layers import get_cognitive_layers
from cognee.modules.data.get_layer_graphs import get_layer_graphs from cognee.modules.data.get_layer_graphs import get_layer_graphs
from cognee.shared.data_models import ChunkStrategy, KnowledgeGraph from cognee.shared.data_models import KnowledgeGraph
from cognee.utils import send_telemetry from cognee.utils import send_telemetry
from cognee.modules.tasks import create_task_status_table, update_task_status from cognee.modules.tasks import create_task_status_table, update_task_status
from cognee.shared.SourceCodeGraph import SourceCodeGraph from cognee.shared.SourceCodeGraph import SourceCodeGraph
from asyncio import Lock from asyncio import Lock
from cognee.modules.tasks import get_task_status from cognee.modules.tasks import get_task_status
from cognee.base_config import get_base_config
from cognee.infrastructure.data.chunking.config import get_chunk_config from cognee.infrastructure.data.chunking.config import get_chunk_config
from cognee.modules.cognify.config import get_cognify_config from cognee.modules.cognify.config import get_cognify_config
from cognee.infrastructure.databases.vector.embeddings.config import get_embedding_config
from cognee.infrastructure.databases.relational.config import get_relationaldb_config from cognee.infrastructure.databases.relational.config import get_relationaldb_config
graph_config = get_graph_config()
config = Config() config = Config()
config.load() config.load()
relational_config = get_relationaldb_config()
cognify_config = get_cognify_config()
chunk_config = get_chunk_config()
base_config = get_base_config()
embedding_config = get_embedding_config()
# aclient = instructor.patch(OpenAI())
USER_ID = "default_user" USER_ID = "default_user"
logger = logging.getLogger("cognify") logger = logging.getLogger("cognify")
@ -66,10 +49,11 @@ async def cognify(datasets: Union[str, List[str]] = None):
stopwords.ensure_loaded() stopwords.ensure_loaded()
create_task_status_table() create_task_status_table()
graph_config = get_graph_config()
graph_db_type = graph_config.graph_engine graph_db_type = graph_config.graph_engine
graph_client = await get_graph_client(graph_db_type) graph_client = await get_graph_client(graph_db_type)
relational_config = get_relationaldb_config()
db_engine = relational_config.database_engine db_engine = relational_config.database_engine
if datasets is None or len(datasets) == 0: if datasets is None or len(datasets) == 0:
@ -108,7 +92,7 @@ async def cognify(datasets: Union[str, List[str]] = None):
if dataset_name in added_dataset: if dataset_name in added_dataset:
dataset_files.append((added_dataset, db_engine.get_files_metadata(added_dataset))) dataset_files.append((added_dataset, db_engine.get_files_metadata(added_dataset)))
chunk_config = get_chunk_config()
chunk_engine = chunk_config.chunk_engine chunk_engine = chunk_config.chunk_engine
chunk_strategy = chunk_config.chunk_strategy chunk_strategy = chunk_config.chunk_strategy
@ -190,10 +174,11 @@ async def cognify(datasets: Union[str, List[str]] = None):
async def process_text(chunk_collection: str, chunk_id: str, input_text: str, file_metadata: dict, document_id: str): async def process_text(chunk_collection: str, chunk_id: str, input_text: str, file_metadata: dict, document_id: str):
print(f"Processing chunk ({chunk_id}) from document ({file_metadata['id']}).") print(f"Processing chunk ({chunk_id}) from document ({file_metadata['id']}).")
graph_config = get_graph_config()
graph_client = await get_graph_client(graph_config.graph_engine) graph_client = await get_graph_client(graph_config.graph_engine)
print("graph_client", graph_client) cognify_config = get_cognify_config()
graph_topology = cognify_config.graph_model graph_topology = cognify_config.graph_model
if graph_topology == SourceCodeGraph: if graph_topology == SourceCodeGraph:
classified_categories = [{"data_type": "text", "category_name": "Code and functions"}] classified_categories = [{"data_type": "text", "category_name": "Code and functions"}]
elif graph_topology == KnowledgeGraph: elif graph_topology == KnowledgeGraph:
@ -227,7 +212,7 @@ async def process_text(chunk_collection: str, chunk_id: str, input_text: str, fi
if cognify_config.connect_documents is True: if cognify_config.connect_documents is True:
db_engine = relational_config.database_engine db_engine = get_relationaldb_config().database_engine
relevant_documents_to_connect = db_engine.fetch_cognify_data(excluded_document_id = document_id) relevant_documents_to_connect = db_engine.fetch_cognify_data(excluded_document_id = document_id)
list_of_nodes = [] list_of_nodes = []
@ -258,52 +243,52 @@ async def process_text(chunk_collection: str, chunk_id: str, input_text: str, fi
if __name__ == "__main__": # if __name__ == "__main__":
async def test(): # async def test():
# await prune.prune_system() # # await prune.prune_system()
# # # # #
# from cognee.api.v1.add import add # # from cognee.api.v1.add import add
# data_directory_path = os.path.abspath("../../../.data") # # data_directory_path = os.path.abspath("../../../.data")
# # print(data_directory_path) # # # print(data_directory_path)
# # config.data_root_directory(data_directory_path) # # # config.data_root_directory(data_directory_path)
# # cognee_directory_path = os.path.abspath("../.cognee_system") # # # cognee_directory_path = os.path.abspath("../.cognee_system")
# # config.system_root_directory(cognee_directory_path) # # # config.system_root_directory(cognee_directory_path)
# # #
# await add("data://" +data_directory_path, "example") # # await add("data://" +data_directory_path, "example")
text = """import subprocess # text = """import subprocess
def show_all_processes(): # def show_all_processes():
process = subprocess.Popen(['ps', 'aux'], stdout=subprocess.PIPE) # process = subprocess.Popen(['ps', 'aux'], stdout=subprocess.PIPE)
output, error = process.communicate() # output, error = process.communicate()
if error: # if error:
print(f"Error: {error}") # print(f"Error: {error}")
else: # else:
print(output.decode()) # print(output.decode())
show_all_processes()""" # show_all_processes()"""
from cognee.api.v1.add import add # from cognee.api.v1.add import add
await add([text], "example_dataset") # await add([text], "example_dataset")
infrastructure_config.set_config( {"chunk_engine": LangchainChunkEngine() , "chunk_strategy": ChunkStrategy.CODE,'embedding_engine': LiteLLMEmbeddingEngine() }) # infrastructure_config.set_config( {"chunk_engine": LangchainChunkEngine() , "chunk_strategy": ChunkStrategy.CODE,'embedding_engine': LiteLLMEmbeddingEngine() })
from cognee.shared.SourceCodeGraph import SourceCodeGraph # from cognee.shared.SourceCodeGraph import SourceCodeGraph
from cognee.api.v1.config import config # from cognee.api.v1.config import config
# config.set_graph_model(SourceCodeGraph) # # config.set_graph_model(SourceCodeGraph)
# config.set_classification_model(CodeContentPrediction) # # config.set_classification_model(CodeContentPrediction)
# graph = await cognify() # # graph = await cognify()
vector_client = infrastructure_config.get_config("vector_engine") # vector_client = infrastructure_config.get_config("vector_engine")
out = await vector_client.search(collection_name ="basic_rag", query_text="show_all_processes", limit=10) # out = await vector_client.search(collection_name ="basic_rag", query_text="show_all_processes", limit=10)
print("results", out) # print("results", out)
# # #
# from cognee.utils import render_graph # # from cognee.utils import render_graph
# # #
# await render_graph(graph, include_color=True, include_nodes=False, include_size=False) # # await render_graph(graph, include_color=True, include_nodes=False, include_size=False)
import asyncio # import asyncio
asyncio.run(test()) # asyncio.run(test())

View file

@ -4,80 +4,93 @@ from cognee.infrastructure.databases.graph.config import get_graph_config
from cognee.infrastructure.data.chunking.config import get_chunk_config from cognee.infrastructure.data.chunking.config import get_chunk_config
from cognee.modules.cognify.config import get_cognify_config from cognee.modules.cognify.config import get_cognify_config
cognify_config = get_cognify_config()
chunk_config = get_chunk_config()
graph_config = get_graph_config()
base_config = get_base_config()
class config(): class config():
@staticmethod @staticmethod
def system_root_directory(system_root_directory: str): def system_root_directory(system_root_directory: str):
base_config = get_base_config()
base_config.system_root_directory = system_root_directory base_config.system_root_directory = system_root_directory
@staticmethod @staticmethod
def data_root_directory(data_root_directory: str): def data_root_directory(data_root_directory: str):
base_config = get_base_config()
base_config.data_root_directory = data_root_directory base_config.data_root_directory = data_root_directory
@staticmethod @staticmethod
def monitoring_tool(monitoring_tool: object): def monitoring_tool(monitoring_tool: object):
base_config = get_base_config()
base_config.monitoring_tool = monitoring_tool base_config.monitoring_tool = monitoring_tool
@staticmethod @staticmethod
def set_classification_model(classification_model: object): def set_classification_model(classification_model: object):
cognify_config = get_cognify_config()
cognify_config.classification_model = classification_model cognify_config.classification_model = classification_model
@staticmethod @staticmethod
def set_summarization_model(summarization_model: object): def set_summarization_model(summarization_model: object):
cognify_config = get_cognify_config()
cognify_config.summarization_model=summarization_model cognify_config.summarization_model=summarization_model
@staticmethod @staticmethod
def set_labeling_model(labeling_model: object): def set_labeling_model(labeling_model: object):
cognify_config = get_cognify_config()
cognify_config.labeling_model =labeling_model cognify_config.labeling_model =labeling_model
@staticmethod @staticmethod
def set_graph_model(graph_model: object): def set_graph_model(graph_model: object):
graph_config.graph_model =graph_model graph_config = get_graph_config()
graph_config.graph_model = graph_model
@staticmethod @staticmethod
def set_cognitive_layer_model(cognitive_layer_model: object): def set_cognitive_layer_model(cognitive_layer_model: object):
cognify_config.cognitive_layer_model =cognitive_layer_model cognify_config = get_cognify_config()
cognify_config.cognitive_layer_model = cognitive_layer_model
@staticmethod @staticmethod
def set_graph_engine(graph_engine: object): def set_graph_engine(graph_engine: object):
graph_config.graph_engine =graph_engine graph_config = get_graph_config()
graph_config.graph_engine = graph_engine
@staticmethod @staticmethod
def llm_provider(llm_provider: str): def llm_provider(llm_provider: str):
graph_config = get_graph_config()
graph_config.llm_provider = llm_provider graph_config.llm_provider = llm_provider
@staticmethod @staticmethod
def llm_endpoint(llm_endpoint: str): def llm_endpoint(llm_endpoint: str):
graph_config = get_graph_config()
graph_config.llm_endpoint = llm_endpoint graph_config.llm_endpoint = llm_endpoint
@staticmethod @staticmethod
def llm_model(llm_model: str): def llm_model(llm_model: str):
graph_config = get_graph_config()
graph_config.llm_model = llm_model graph_config.llm_model = llm_model
@staticmethod @staticmethod
def intra_layer_score_treshold(intra_layer_score_treshold: str): def intra_layer_score_treshold(intra_layer_score_treshold: str):
cognify_config.intra_layer_score_treshold =intra_layer_score_treshold cognify_config = get_cognify_config()
cognify_config.intra_layer_score_treshold = intra_layer_score_treshold
@staticmethod @staticmethod
def connect_documents(connect_documents: bool): def connect_documents(connect_documents: bool):
cognify_config = get_cognify_config()
cognify_config.connect_documents = connect_documents cognify_config.connect_documents = connect_documents
@staticmethod @staticmethod
def set_chunk_strategy(chunk_strategy: object): def set_chunk_strategy(chunk_strategy: object):
chunk_config = get_chunk_config()
chunk_config.chunk_strategy = chunk_strategy chunk_config.chunk_strategy = chunk_strategy
@staticmethod @staticmethod
def set_graph_topology(graph_topology: object): def set_graph_topology(graph_topology: object):
get_cognify_config.graph_topology =graph_topology cognify_config = get_cognify_config()
cognify_config.graph_topology = graph_topology

View file

@ -3,12 +3,11 @@ from cognee.modules.discovery import discover_directory_datasets
from cognee.modules.tasks import get_task_status from cognee.modules.tasks import get_task_status
from cognee.infrastructure.databases.relational.config import get_relationaldb_config from cognee.infrastructure.databases.relational.config import get_relationaldb_config
relational_config = get_relationaldb_config()
class datasets(): class datasets():
@staticmethod @staticmethod
def list_datasets(): def list_datasets():
db = relational_config.db_engine relational_config = get_relationaldb_config()
db = relational_config.database_engine
return db.get_datasets() return db.get_datasets()
@staticmethod @staticmethod
@ -17,7 +16,8 @@ class datasets():
@staticmethod @staticmethod
def list_data(dataset_name: str): def list_data(dataset_name: str):
db = relational_config.db_engine relational_config = get_relationaldb_config()
db = relational_config.database_engine
try: try:
return db.get_files_metadata(dataset_name) return db.get_files_metadata(dataset_name)
except CatalogException: except CatalogException:
@ -32,7 +32,8 @@ class datasets():
@staticmethod @staticmethod
def delete_dataset(dataset_id: str): def delete_dataset(dataset_id: str):
db = relational_config.db_engine relational_config = get_relationaldb_config()
db = relational_config.database_engine
try: try:
return db.delete_table(dataset_id) return db.delete_table(dataset_id)
except CatalogException: except CatalogException:

View file

@ -1,11 +1,11 @@
from cognee.modules.data.deletion import prune_system from cognee.modules.data.deletion import prune_system
from cognee.base_config import get_base_config from cognee.base_config import get_base_config
from cognee.infrastructure.files.storage import LocalStorage from cognee.infrastructure.files.storage import LocalStorage
base_config = get_base_config()
class prune(): class prune():
@staticmethod @staticmethod
async def prune_data(): async def prune_data():
base_config = get_base_config()
data_root_directory = base_config.data_root_directory data_root_directory = base_config.data_root_directory
LocalStorage.remove_all(data_root_directory) LocalStorage.remove_all(data_root_directory)

View file

@ -11,21 +11,19 @@ from cognee.modules.search.graph.search_categories import search_categories
from cognee.modules.search.graph.search_neighbour import search_neighbour from cognee.modules.search.graph.search_neighbour import search_neighbour
from cognee.modules.search.graph.search_summary import search_summary from cognee.modules.search.graph.search_summary import search_summary
from cognee.infrastructure.databases.graph.get_graph_client import get_graph_client from cognee.infrastructure.databases.graph.get_graph_client import get_graph_client
from cognee.infrastructure import infrastructure_config
from cognee.utils import send_telemetry from cognee.utils import send_telemetry
from cognee.infrastructure.databases.graph.config import get_graph_config from cognee.infrastructure.databases.graph.config import get_graph_config
graph_config = get_graph_config()
class SearchType(Enum): class SearchType(Enum):
ADJACENT = 'ADJACENT' ADJACENT = "ADJACENT"
SIMILARITY = 'SIMILARITY' SIMILARITY = "SIMILARITY"
CATEGORIES = 'CATEGORIES' CATEGORIES = "CATEGORIES"
NEIGHBOR = 'NEIGHBOR' NEIGHBOR = "NEIGHBOR"
SUMMARY = 'SUMMARY' SUMMARY = "SUMMARY"
SUMMARY_CLASSIFICATION = 'SUMMARY_CLASSIFICATION' SUMMARY_CLASSIFICATION = "SUMMARY_CLASSIFICATION"
NODE_CLASSIFICATION = 'NODE_CLASSIFICATION' NODE_CLASSIFICATION = "NODE_CLASSIFICATION"
DOCUMENT_CLASSIFICATION = 'DOCUMENT_CLASSIFICATION', DOCUMENT_CLASSIFICATION = "DOCUMENT_CLASSIFICATION",
CYPHER = 'CYPHER' CYPHER = "CYPHER"
@staticmethod @staticmethod
def from_str(name: str): def from_str(name: str):
@ -38,7 +36,7 @@ class SearchParameters(BaseModel):
search_type: SearchType search_type: SearchType
params: Dict[str, Any] params: Dict[str, Any]
@field_validator('search_type', mode='before') @field_validator("search_type", mode="before")
def convert_string_to_enum(cls, value): def convert_string_to_enum(cls, value):
if isinstance(value, str): if isinstance(value, str):
return SearchType.from_str(value) return SearchType.from_str(value)
@ -51,6 +49,7 @@ async def search(search_type: str, params: Dict[str, Any]) -> List:
async def specific_search(query_params: List[SearchParameters]) -> List: async def specific_search(query_params: List[SearchParameters]) -> List:
graph_config = get_graph_config()
graph_client = await get_graph_client(graph_config.graph_engine) graph_client = await get_graph_client(graph_config.graph_engine)
graph = graph_client.graph graph = graph_client.graph

View file

@ -1,6 +1,5 @@
from functools import lru_cache from functools import lru_cache
from pydantic_settings import BaseSettings, SettingsConfigDict from pydantic_settings import BaseSettings, SettingsConfigDict
from cognee.root_dir import get_absolute_path from cognee.root_dir import get_absolute_path
from cognee.shared.data_models import MonitoringTool from cognee.shared.data_models import MonitoringTool

View file

@ -7,9 +7,9 @@ from typing import Optional, Dict, Any
from dataclasses import dataclass, field from dataclasses import dataclass, field
from pathlib import Path from pathlib import Path
from dotenv import load_dotenv from dotenv import load_dotenv
from cognee.root_dir import get_absolute_path
from cognee.shared.data_models import ChunkStrategy, DefaultGraphModel
logging.basicConfig(level=logging.DEBUG) logging.basicConfig(level=logging.DEBUG)
def load_dontenv(): def load_dontenv():
base_dir = Path(__file__).resolve().parent.parent base_dir = Path(__file__).resolve().parent.parent
# Load the .env file from the base directory # Load the .env file from the base directory
@ -33,27 +33,6 @@ class Config:
) )
) )
system_root_directory = get_absolute_path(".cognee_system")
logging.info("system_root_directory: %s", system_root_directory)
data_root_directory = os.getenv("DATA_PATH", get_absolute_path(".data"))
vectordb: str = os.getenv("VECTORDB", "weaviate")
qdrant_path: str = os.getenv("QDRANT_PATH", None)
qdrant_url: str = os.getenv("QDRANT_URL", None)
qdrant_api_key: str = os.getenv("QDRANT_API_KEY", None)
graph_filename = os.getenv("GRAPH_NAME", "cognee_graph.pkl")
# Model parameters
llm_provider: str = os.getenv("LLM_PROVIDER", "openai") #openai, or custom or ollama
llm_model: str = os.getenv("LLM_MODEL", "gpt-4")
llm_api_key: str = os.getenv("LLM_API_KEY", os.getenv("OPENAI_API_KEY"))
llm_endpoint: str = os.getenv("LLM_ENDPOINT", None)
# custom_model: str = os.getenv("CUSTOM_LLM_MODEL", "llama3-70b-8192") #"mistralai/Mixtral-8x7B-Instruct-v0.1" # custom_model: str = os.getenv("CUSTOM_LLM_MODEL", "llama3-70b-8192") #"mistralai/Mixtral-8x7B-Instruct-v0.1"
# custom_endpoint: str = os.getenv("CUSTOM_ENDPOINT", "https://api.endpoints.anyscale.com/v1") #"https://api.endpoints.anyscale.com/v1" # pass claude endpoint # custom_endpoint: str = os.getenv("CUSTOM_ENDPOINT", "https://api.endpoints.anyscale.com/v1") #"https://api.endpoints.anyscale.com/v1" # pass claude endpoint
# custom_key: Optional[str] = os.getenv("CUSTOM_LLM_API_KEY") # custom_key: Optional[str] = os.getenv("CUSTOM_LLM_API_KEY")
@ -63,7 +42,6 @@ class Config:
# openai_model: str = os.getenv("OPENAI_MODEL", "gpt-4o" ) #"gpt-4o" # openai_model: str = os.getenv("OPENAI_MODEL", "gpt-4o" ) #"gpt-4o"
# model_endpoint: str = "openai" # model_endpoint: str = "openai"
# llm_api_key: Optional[str] = os.getenv("OPENAI_API_KEY") # llm_api_key: Optional[str] = os.getenv("OPENAI_API_KEY")
openai_temperature: float = float(os.getenv("OPENAI_TEMPERATURE", 0.0))
# openai_embedding_model = "text-embedding-3-large" # openai_embedding_model = "text-embedding-3-large"
# openai_embedding_dimensions = 3072 # openai_embedding_dimensions = 3072
# litellm_embedding_model = "text-embedding-3-large" # litellm_embedding_model = "text-embedding-3-large"
@ -77,23 +55,9 @@ class Config:
embedding_dimensions: int = 1024 embedding_dimensions: int = 1024
connect_documents: bool = False connect_documents: bool = False
# Database parameters
graph_database_provider: str = os.getenv("GRAPH_DB_PROVIDER", "NETWORKX")
graph_topology:str = DefaultGraphModel
cognitive_layers_limit: int = 2
from cognee.shared.data_models import MonitoringTool
# Monitoring tool
monitoring_tool: str = os.getenv("MONITORING_TOOL", MonitoringTool.LANGFUSE)
weaviate_url: str = os.getenv("WEAVIATE_URL")
weaviate_api_key: str = os.getenv("WEAVIATE_API_KEY")
# Model parameters and configuration for interlayer scoring # Model parameters and configuration for interlayer scoring
intra_layer_score_treshold: float = 0.98 intra_layer_score_treshold: float = 0.98
# Client ID # Client ID
anon_clientid: Optional[str] = field(default_factory=lambda: uuid.uuid4().hex) anon_clientid: Optional[str] = field(default_factory=lambda: uuid.uuid4().hex)

View file

@ -1,29 +1,18 @@
import logging import logging
import os
from cognee.config import Config from cognee.config import Config
from .data.chunking.config import get_chunk_config from .data.chunking.config import get_chunk_config
from .databases.relational import DatabaseEngine
from .llm.llm_interface import LLMInterface from .llm.llm_interface import LLMInterface
from .llm.get_llm_client import get_llm_client from .llm.get_llm_client import get_llm_client
from .files.storage import LocalStorage
from ..shared.data_models import GraphDBType, DefaultContentPrediction, KnowledgeGraph, SummarizedContent, \ from ..shared.data_models import GraphDBType, DefaultContentPrediction, KnowledgeGraph, SummarizedContent, \
LabeledContent, DefaultCognitiveLayer LabeledContent, DefaultCognitiveLayer
logging.basicConfig(level=logging.DEBUG) logging.basicConfig(level=logging.DEBUG)
config = Config() config = Config()
config.load() config.load()
from cognee.infrastructure.databases.relational.config import get_relationaldb_config
from cognee.infrastructure.databases.vector.config import get_vectordb_config
vector_db_config = get_vectordb_config()
relational = get_relationaldb_config()
chunk_config = get_chunk_config() chunk_config = get_chunk_config()
class InfrastructureConfig(): class InfrastructureConfig():
system_root_directory: str = config.system_root_directory
data_root_directory: str = config.data_root_directory
llm_provider: str = config.llm_provider
database_engine: DatabaseEngine = None
graph_engine: GraphDBType = None graph_engine: GraphDBType = None
llm_engine: LLMInterface = None llm_engine: LLMInterface = None
classification_model = None classification_model = None
@ -34,28 +23,14 @@ class InfrastructureConfig():
intra_layer_score_treshold = None intra_layer_score_treshold = None
embedding_engine = None embedding_engine = None
connect_documents = config.connect_documents connect_documents = config.connect_documents
database_directory_path: str = None
database_file_path: str = None
chunk_strategy = chunk_config.chunk_strategy chunk_strategy = chunk_config.chunk_strategy
chunk_engine = None chunk_engine = None
graph_topology = config.graph_topology
monitoring_tool = config.monitoring_tool
llm_provider: str = None llm_provider: str = None
llm_model: str = None llm_model: str = None
llm_endpoint: str = None llm_endpoint: str = None
llm_api_key: str = None llm_api_key: str = None
def get_config(self, config_entity: str = None) -> dict: def get_config(self, config_entity: str = None) -> dict:
if (config_entity is None or config_entity == "database_engine") and self.database_engine is None:
db_path = os.path.join(self.system_root_directory,relational.db_path)
LocalStorage.ensure_directory_exists(db_path)
self.database_engine = relational.db_engine
if self.graph_engine is None: if self.graph_engine is None:
self.graph_engine = GraphDBType.NETWORKX self.graph_engine = GraphDBType.NETWORKX
@ -86,29 +61,14 @@ class InfrastructureConfig():
if self.chunk_engine is None: if self.chunk_engine is None:
self.chunk_engine = chunk_config.chunk_engine self.chunk_engine = chunk_config.chunk_engine
if self.graph_topology is None:
self.graph_topology = config.graph_topology
if (config_entity is None or config_entity == "llm_engine") and self.llm_engine is None: if (config_entity is None or config_entity == "llm_engine") and self.llm_engine is None:
self.llm_engine = get_llm_client() self.llm_engine = get_llm_client()
if (config_entity is None or config_entity == "database_directory_path") and self.database_directory_path is None:
self.database_directory_path = self.system_root_directory + "/" + relational.db_path
if self.database_directory_path is None:
self.database_directory_path = self.system_root_directory + "/" + relational.db_path
if (config_entity is None or config_entity == "database_file_path") and self.database_file_path is None:
self.database_file_path = self.system_root_directory + "/" + relational.db_path + "/" + relational.db_name
if config_entity is not None: if config_entity is not None:
return getattr(self, config_entity) return getattr(self, config_entity)
return { return {
"llm_engine": self.llm_engine, "llm_engine": self.llm_engine,
"database_engine": self.database_engine,
"system_root_directory": self.system_root_directory,
"data_root_directory": self.data_root_directory,
"graph_engine": self.graph_engine,
"classification_model": self.classification_model, "classification_model": self.classification_model,
"summarization_model": self.summarization_model, "summarization_model": self.summarization_model,
"labeling_model": self.labeling_model, "labeling_model": self.labeling_model,
@ -118,29 +78,11 @@ class InfrastructureConfig():
"intra_layer_score_treshold": self.intra_layer_score_treshold, "intra_layer_score_treshold": self.intra_layer_score_treshold,
"embedding_engine": self.embedding_engine, "embedding_engine": self.embedding_engine,
"connect_documents": self.connect_documents, "connect_documents": self.connect_documents,
"database_directory_path": self.database_directory_path,
"database_path": self.database_file_path,
"chunk_strategy": self.chunk_strategy, "chunk_strategy": self.chunk_strategy,
"chunk_engine": self.chunk_engine, "chunk_engine": self.chunk_engine,
"graph_topology": self.graph_topology
} }
def set_config(self, new_config: dict): def set_config(self, new_config: dict):
if "system_root_directory" in new_config:
self.system_root_directory = new_config["system_root_directory"]
if "data_root_directory" in new_config:
self.data_root_directory = new_config["data_root_directory"]
if "database_engine" in new_config:
self.database_engine = new_config["database_engine"]
if "llm_engine" in new_config:
self.llm_engine = new_config["llm_engine"]
if "graph_engine" in new_config:
self.graph_engine = new_config["graph_engine"]
if "classification_model" in new_config: if "classification_model" in new_config:
self.classification_model = new_config["classification_model"] self.classification_model = new_config["classification_model"]
@ -150,12 +92,6 @@ class InfrastructureConfig():
if "labeling_model" in new_config: if "labeling_model" in new_config:
self.labeling_model = new_config["labeling_model"] self.labeling_model = new_config["labeling_model"]
if "graph_model" in new_config:
self.graph_model = new_config["graph_model"]
if "llm_provider" in new_config:
self.llm_provider = new_config["llm_provider"]
if "cognitive_layer_model" in new_config: if "cognitive_layer_model" in new_config:
self.cognitive_layer_model = new_config["cognitive_layer_model"] self.cognitive_layer_model = new_config["cognitive_layer_model"]
@ -174,7 +110,4 @@ class InfrastructureConfig():
if "chunk_engine" in new_config: if "chunk_engine" in new_config:
self.chunk_engine = new_config["chunk_engine"] self.chunk_engine = new_config["chunk_engine"]
if "graph_topology" in new_config:
self.graph_topology = new_config["graph_topology"]
infrastructure_config = InfrastructureConfig() infrastructure_config = InfrastructureConfig()

View file

@ -0,0 +1 @@
from .config import get_graph_config

View file

@ -2,11 +2,8 @@
import os import os
from functools import lru_cache from functools import lru_cache
from pydantic_settings import BaseSettings, SettingsConfigDict from pydantic_settings import BaseSettings, SettingsConfigDict
from cognee.base_config import get_base_config
from cognee.infrastructure.databases.relational.config import get_relationaldb_config from cognee.infrastructure.databases.relational.config import get_relationaldb_config
from cognee.shared.data_models import DefaultGraphModel, GraphDBType from cognee.shared.data_models import DefaultGraphModel, GraphDBType
relational_config = get_relationaldb_config()
base_config = get_base_config()
class GraphConfig(BaseSettings): class GraphConfig(BaseSettings):
graph_filename: str = "cognee_graph.pkl" graph_filename: str = "cognee_graph.pkl"
@ -15,7 +12,7 @@ class GraphConfig(BaseSettings):
graph_database_username: str = "" graph_database_username: str = ""
graph_database_password: str = "" graph_database_password: str = ""
graph_database_port: int = 123 graph_database_port: int = 123
graph_file_path: str = os.path.join(relational_config.database_directory_path,graph_filename) graph_file_path: str = os.path.join(get_relationaldb_config().db_path, graph_filename)
graph_engine: object = GraphDBType.NETWORKX graph_engine: object = GraphDBType.NETWORKX
graph_model: object = DefaultGraphModel graph_model: object = DefaultGraphModel

View file

@ -4,11 +4,11 @@ from cognee.shared.data_models import GraphDBType
from .config import get_graph_config from .config import get_graph_config
from .graph_db_interface import GraphDBInterface from .graph_db_interface import GraphDBInterface
from .networkx.adapter import NetworkXAdapter from .networkx.adapter import NetworkXAdapter
config = get_graph_config()
async def get_graph_client(graph_type: GraphDBType, graph_file_name: str = None) -> GraphDBInterface : async def get_graph_client(graph_type: GraphDBType, graph_file_name: str = None) -> GraphDBInterface :
"""Factory function to get the appropriate graph client based on the graph type.""" """Factory function to get the appropriate graph client based on the graph type."""
config = get_graph_config()
if graph_type == GraphDBType.NEO4J: if graph_type == GraphDBType.NEO4J:
try: try:

View file

@ -1,28 +1,24 @@
import os import os
from functools import lru_cache from functools import lru_cache
from pydantic_settings import BaseSettings, SettingsConfigDict from pydantic_settings import BaseSettings, SettingsConfigDict
from cognee.infrastructure.databases.relational import DuckDBAdapter
from cognee.base_config import get_base_config from cognee.base_config import get_base_config
config = get_base_config() from .create_relational_engine import create_relational_engine
class RelationalConfig(BaseSettings): class RelationalConfig(BaseSettings):
db_path: str = os.path.join(config.system_root_directory, "databases") db_path: str = os.path.join(get_base_config().system_root_directory, "databases")
db_name: str = "cognee.db" db_name: str = "cognee.db"
db_host: str = "localhost" db_host: str = "localhost"
db_port: str = "5432" db_port: str = "5432"
db_user: str = "cognee" db_user: str = "cognee"
db_password: str = "cognee" db_password: str = "cognee"
db_engine: object = DuckDBAdapter( database_engine: object = create_relational_engine(db_path, db_name)
db_name=db_name, db_file_path: str = os.path.join(db_path, db_name)
db_path=db_path
)
database_engine: object = db_engine
db_file_path:str = os.path.join(db_path, db_name)
database_path: str = os.path.join(config.system_root_directory, "databases")
database_directory_path: str = db_path
model_config = SettingsConfigDict(env_file = ".env", extra = "allow") model_config = SettingsConfigDict(env_file = ".env", extra = "allow")
def create_engine(self):
return create_relational_engine(self.db_path, self.db_name)
def to_dict(self) -> dict: def to_dict(self) -> dict:
return { return {
"db_path": self.db_path, "db_path": self.db_path,
@ -31,10 +27,9 @@ class RelationalConfig(BaseSettings):
"db_port": self.db_port, "db_port": self.db_port,
"db_user": self.db_user, "db_user": self.db_user,
"db_password": self.db_password, "db_password": self.db_password,
"db_engine": self.db_engine, "db_engine": self.database_engine,
"database_path": self.database_path,
} }
@lru_cache @lru_cache
def get_relationaldb_config(): def get_relationaldb_config():
return RelationalConfig() return RelationalConfig()

View file

@ -0,0 +1,10 @@
from cognee.infrastructure.files.storage import LocalStorage
from cognee.infrastructure.databases.relational import DuckDBAdapter
def create_relational_engine(db_path: str, db_name: str):
LocalStorage.ensure_directory_exists(db_path)
return DuckDBAdapter(
db_name = db_name,
db_path = db_path,
)

View file

@ -5,33 +5,28 @@ from cognee.infrastructure.databases.relational.config import get_relationaldb_c
from cognee.infrastructure.databases.vector.embeddings.config import get_embedding_config from cognee.infrastructure.databases.vector.embeddings.config import get_embedding_config
from .create_vector_engine import create_vector_engine from .create_vector_engine import create_vector_engine
embeddings_config = get_embedding_config()
relational_config = get_relationaldb_config()
lancedb_path = os.path.join(relational_config.database_directory_path, "cognee.lancedb")
class VectorConfig(BaseSettings): class VectorConfig(BaseSettings):
vector_db_url: str = lancedb_path vector_db_url: str = os.path.join(get_relationaldb_config().db_path, "cognee.lancedb")
vector_db_key: str = "" vector_db_key: str = ""
vector_engine_provider: str = "lancedb" vector_engine_provider: str = "lancedb"
vector_engine: object = create_vector_engine( vector_engine: object = create_vector_engine(
{ {
"vector_db_key": None, "vector_db_key": None,
"vector_db_url": lancedb_path, "vector_db_url": vector_db_url,
"vector_db_provider": "lancedb", "vector_db_provider": "lancedb",
}, },
embeddings_config.embedding_engine, get_embedding_config().embedding_engine,
) )
model_config = SettingsConfigDict(env_file = ".env", extra = "allow") model_config = SettingsConfigDict(env_file = ".env", extra = "allow")
def create_engine(self): def create_engine(self):
if self.vector_engine_provider == "lancedb": if self.vector_engine_provider == "lancedb":
self.vector_db_url = lancedb_path self.vector_db_url = os.path.join(get_relationaldb_config().db_path, "cognee.lancedb")
self.vector_engine = create_vector_engine( self.vector_engine = create_vector_engine(
get_vectordb_config().to_dict(), get_vectordb_config().to_dict(),
embeddings_config.embedding_engine, get_embedding_config().embedding_engine,
) )
def to_dict(self) -> dict: def to_dict(self) -> dict:

View file

@ -35,7 +35,7 @@ class LocalStorage(Storage):
@staticmethod @staticmethod
def ensure_directory_exists(file_path: str): def ensure_directory_exists(file_path: str):
if not os.path.exists(file_path): if not os.path.exists(file_path):
os.makedirs(file_path) os.makedirs(file_path, exist_ok = True)
def remove(self, file_path: str): def remove(self, file_path: str):
os.remove(self.storage_path + "/" + file_path) os.remove(self.storage_path + "/" + file_path)

View file

@ -6,6 +6,7 @@ class LLMConfig(BaseSettings):
llm_model: str = "gpt-4o" llm_model: str = "gpt-4o"
llm_endpoint: str = "" llm_endpoint: str = ""
llm_api_key: str = "" llm_api_key: str = ""
llm_temperature: float = 0.0
model_config = SettingsConfigDict(env_file = ".env", extra = "allow") model_config = SettingsConfigDict(env_file = ".env", extra = "allow")

View file

@ -1,9 +1,6 @@
from functools import lru_cache from functools import lru_cache
from pydantic_settings import BaseSettings, SettingsConfigDict from pydantic_settings import BaseSettings, SettingsConfigDict
from cognee.shared.data_models import DefaultContentPrediction, LabeledContent, SummarizedContent, \
from cognee.root_dir import get_absolute_path
from cognee.shared.data_models import MonitoringTool, DefaultContentPrediction, LabeledContent, SummarizedContent, \
DefaultCognitiveLayer, DefaultGraphModel, KnowledgeGraph DefaultCognitiveLayer, DefaultGraphModel, KnowledgeGraph
@ -23,7 +20,6 @@ class CognifyConfig(BaseSettings):
graph_model:object = KnowledgeGraph graph_model:object = KnowledgeGraph
model_config = SettingsConfigDict(env_file = ".env", extra = "allow") model_config = SettingsConfigDict(env_file = ".env", extra = "allow")
def to_dict(self) -> dict: def to_dict(self) -> dict:

View file

@ -6,8 +6,7 @@ from cognee.infrastructure.databases.vector import DataPoint
# from cognee.utils import extract_pos_tags, extract_named_entities, extract_sentiment_vader # from cognee.utils import extract_pos_tags, extract_named_entities, extract_sentiment_vader
from cognee.infrastructure.databases.graph.config import get_graph_config from cognee.infrastructure.databases.graph.config import get_graph_config
from cognee.infrastructure.databases.vector.config import get_vectordb_config from cognee.infrastructure.databases.vector.config import get_vectordb_config
graph_config = get_graph_config()
vectordb_config = get_vectordb_config()
class GraphLike(TypedDict): class GraphLike(TypedDict):
nodes: List nodes: List
edges: List edges: List
@ -19,7 +18,10 @@ async def add_cognitive_layer_graphs(
chunk_id: str, chunk_id: str,
layer_graphs: List[Tuple[str, GraphLike]], layer_graphs: List[Tuple[str, GraphLike]],
): ):
vectordb_config = get_vectordb_config()
vector_client = vectordb_config.vector_engine vector_client = vectordb_config.vector_engine
graph_config = get_graph_config()
graph_model = graph_config.graph_model graph_model = graph_config.graph_model
for (layer_id, layer_graph) in layer_graphs: for (layer_id, layer_graph) in layer_graphs:

View file

@ -2,17 +2,15 @@
from typing import TypedDict from typing import TypedDict
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from cognee.infrastructure.databases.vector.config import get_vectordb_config from cognee.infrastructure.databases.vector.config import get_vectordb_config
from cognee.infrastructure.databases.vector import DataPoint from cognee.infrastructure.databases.vector import DataPoint
config = get_vectordb_config()
class TextChunk(TypedDict): class TextChunk(TypedDict):
text: str text: str
chunk_id: str chunk_id: str
file_metadata: dict file_metadata: dict
async def add_data_chunks(dataset_data_chunks: dict[str, list[TextChunk]]): async def add_data_chunks(dataset_data_chunks: dict[str, list[TextChunk]]):
config = get_vectordb_config()
vector_client = config.vector_engine vector_client = config.vector_engine
identified_chunks = [] identified_chunks = []
@ -55,6 +53,7 @@ async def add_data_chunks(dataset_data_chunks: dict[str, list[TextChunk]]):
async def add_data_chunks_basic_rag(dataset_data_chunks: dict[str, list[TextChunk]]): async def add_data_chunks_basic_rag(dataset_data_chunks: dict[str, list[TextChunk]]):
config = get_vectordb_config()
vector_client = config.vector_engine vector_client = config.vector_engine
identified_chunks = [] identified_chunks = []

View file

@ -4,11 +4,10 @@ from datetime import datetime
from pydantic import BaseModel from pydantic import BaseModel
from cognee.infrastructure.databases.vector import DataPoint from cognee.infrastructure.databases.vector import DataPoint
from cognee.infrastructure.databases.graph.config import get_graph_config
from cognee.infrastructure.databases.vector.config import get_vectordb_config from cognee.infrastructure.databases.vector.config import get_vectordb_config
graph_config = get_graph_config()
vectordb_config = get_vectordb_config()
async def add_label_nodes(graph_client, parent_node_id: str, keywords: List[str]) -> None: async def add_label_nodes(graph_client, parent_node_id: str, keywords: List[str]) -> None:
vectordb_config = get_vectordb_config()
vector_client = vectordb_config.vector_engine vector_client = vectordb_config.vector_engine
keyword_nodes = [] keyword_nodes = []

View file

@ -4,9 +4,6 @@ import uuid
from cognee.infrastructure.databases.graph.get_graph_client import get_graph_client from cognee.infrastructure.databases.graph.get_graph_client import get_graph_client
from cognee.shared.data_models import GraphDBType from cognee.shared.data_models import GraphDBType
from cognee.infrastructure.databases.graph.config import get_graph_config from cognee.infrastructure.databases.graph.config import get_graph_config
from cognee.infrastructure.databases.vector.config import get_vectordb_config
graph_config = get_graph_config()
vectordb_config = get_vectordb_config()
async def group_nodes_by_layer(node_descriptions): async def group_nodes_by_layer(node_descriptions):
@ -42,6 +39,7 @@ async def connect_nodes_in_graph(graph, relationship_dict, score_threshold=0.9):
for relationship in relationships: for relationship in relationships:
if relationship['score'] > score_threshold: if relationship['score'] > score_threshold:
graph_config = get_graph_config()
# For NetworkX # For NetworkX
if graph_config.graph_engine == GraphDBType.NETWORKX: if graph_config.graph_engine == GraphDBType.NETWORKX:

View file

@ -1,14 +1,9 @@
""" This module is responsible for creating a semantic graph """ """ This module is responsible for creating a semantic graph """
from typing import Optional, Any from typing import Optional, Any
from pydantic import BaseModel from pydantic import BaseModel
# from cognee.infrastructure import infrastructure_config
from cognee.shared.data_models import GraphDBType from cognee.shared.data_models import GraphDBType
from cognee.infrastructure.databases.graph.config import get_graph_config from cognee.infrastructure.databases.graph.config import get_graph_config
from cognee.infrastructure.databases.vector.config import get_vectordb_config
graph_config = get_graph_config()
vectordb_config = get_vectordb_config()
async def generate_node_id(instance: BaseModel) -> str: async def generate_node_id(instance: BaseModel) -> str:
for field in ["id", "doc_id", "location_id", "type_id", "node_id"]: for field in ["id", "doc_id", "location_id", "type_id", "node_id"]:
if hasattr(instance, field): if hasattr(instance, field):
@ -46,7 +41,8 @@ async def add_node(client, parent_id: Optional[str], node_id: str, node_data: di
# print('NODE ID', node_id) # print('NODE ID', node_id)
# print('NODE DATA', node_data) # print('NODE DATA', node_data)
result = await client.add_node(node_id, node_properties = node_data) result = await client.add_node(node_id, node_properties = node_data)
print("added node", result)
graph_config = get_graph_config()
# Add an edge if a parent ID is provided and the graph engine is NETWORKX # Add an edge if a parent ID is provided and the graph engine is NETWORKX
if parent_id and "default_relationship" in node_data and graph_config.graph_engine == GraphDBType.NETWORKX: if parent_id and "default_relationship" in node_data and graph_config.graph_engine == GraphDBType.NETWORKX:

View file

@ -1,8 +1,5 @@
from typing import Dict, List from typing import Dict, List
from cognee.infrastructure.databases.graph.config import get_graph_config
from cognee.infrastructure.databases.vector.config import get_vectordb_config from cognee.infrastructure.databases.vector.config import get_vectordb_config
graph_config = get_graph_config()
vectordb_config = get_vectordb_config()
async def resolve_cross_graph_references(nodes_by_layer: Dict): async def resolve_cross_graph_references(nodes_by_layer: Dict):
results = [] results = []
@ -19,6 +16,7 @@ async def resolve_cross_graph_references(nodes_by_layer: Dict):
return results return results
async def get_nodes_by_layer(layer_id: str, layer_nodes: List): async def get_nodes_by_layer(layer_id: str, layer_nodes: List):
vectordb_config = get_vectordb_config()
vector_engine = vectordb_config.vector_engine vector_engine = vectordb_config.vector_engine
score_points = await vector_engine.batch_search( score_points = await vector_engine.batch_search(

View file

@ -3,12 +3,9 @@ import dspy
import nltk import nltk
from nltk.corpus import stopwords from nltk.corpus import stopwords
from nltk.tokenize import word_tokenize from nltk.tokenize import word_tokenize
from cognee.config import Config from cognee.infrastructure.llm import get_llm_config
from cognee.shared.data_models import KnowledgeGraph, Node, Edge from cognee.shared.data_models import KnowledgeGraph, Node, Edge
from cognee.utils import num_tokens_from_string, trim_text_to_max_tokens from cognee.utils import trim_text_to_max_tokens
config = Config()
config.load()
# """Instructions: # """Instructions:
# You are a top-tier algorithm designed for extracting information from text in structured formats to build a knowledge graph. # You are a top-tier algorithm designed for extracting information from text in structured formats to build a knowledge graph.
@ -41,7 +38,9 @@ def are_all_nodes_connected(graph: KnowledgeGraph) -> bool:
class ExtractKnowledgeGraph(dspy.Module): class ExtractKnowledgeGraph(dspy.Module):
def __init__(self, lm = dspy.OpenAI(model = config.llm_model, api_key = config.llm_api_key, model_type = "chat", max_tokens = 4096)): llm_config = get_llm_config()
def __init__(self, lm = dspy.OpenAI(model = llm_config.llm_model, api_key = llm_config.llm_api_key, model_type = "chat", max_tokens = 4096)):
super().__init__() super().__init__()
self.lm = lm self.lm = lm
dspy.settings.configure(lm=self.lm) dspy.settings.configure(lm=self.lm)
@ -50,7 +49,7 @@ class ExtractKnowledgeGraph(dspy.Module):
def forward(self, context: str, question: str): def forward(self, context: str, question: str):
context = remove_stop_words(context) context = remove_stop_words(context)
context = trim_text_to_max_tokens(context, 1500, config.llm_model) context = trim_text_to_max_tokens(context, 1500, self.llm_config.llm_model)
with dspy.context(lm = self.lm): with dspy.context(lm = self.lm):
graph = self.generate_graph(text = context).graph graph = self.generate_graph(text = context).graph

View file

@ -1,47 +0,0 @@
import logging
from cognee.infrastructure import infrastructure_config
from cognee.infrastructure.data import Dataset, Data
from cognee.infrastructure.files import remove_file_from_storage
from cognee.infrastructure.databases.relational import DatabaseEngine
from cognee.infrastructure.databases.relational.config import get_relationaldb_config
config = get_relationaldb_config()
logger = logging.getLogger(__name__)
async def add_data_to_dataset(dataset: Dataset, data: Data):
db_engine: DatabaseEngine = config.database_engine
existing_dataset = (await db_engine.query_entity(dataset)).scalar()
existing_data = (await db_engine.query_entity(data)).scalar()
if existing_dataset:
if existing_data:
await remove_old_raw_data(existing_data.raw_data_location)
def update_raw_data():
existing_data.raw_data_location = data.raw_data_location
await db_engine.update(update_raw_data)
if existing_dataset.id == dataset.id and dataset.name is not None:
def update_name(): existing_dataset.name = dataset.name
await db_engine.update(update_name)
else:
await db_engine.update(lambda: existing_dataset.data.append(data))
else:
if existing_data:
await remove_old_raw_data(existing_data.raw_data_location)
existing_data.raw_data_location = data.raw_data_location
await db_engine.update(lambda: dataset.data.append(existing_data))
else:
await db_engine.update(lambda: dataset.data.append(data))
await db_engine.create(dataset)
async def remove_old_raw_data(data_location: str):
try:
await remove_file_from_storage(data_location)
except Exception:
logger.error("Failed to remove old raw data file: %s", data_location)

View file

@ -1,27 +0,0 @@
import asyncio
from uuid import UUID, uuid4
from cognee.infrastructure.data import Data, Dataset
from .add_data_to_dataset import add_data_to_dataset
from .data_types import IngestionData
async def save(dataset_id: UUID, dataset_name: str, data_id: UUID, data: IngestionData):
file_path = uuid4().hex + "." + data.get_extension()
promises = []
promises.append(
add_data_to_dataset(
Dataset(
id = dataset_id,
name = dataset_name if dataset_name else dataset_id.hex
),
Data(
id = data_id,
raw_data_location = file_path,
name = data.metadata["name"],
meta_data = data.metadata
)
)
)
await asyncio.gather(*promises)

View file

@ -5,7 +5,7 @@ from typing import Union, Dict
import networkx as nx import networkx as nx
from cognee.shared.data_models import GraphDBType from cognee.shared.data_models import GraphDBType
from cognee.infrastructure.databases.graph.config import get_graph_config from cognee.infrastructure.databases.graph.config import get_graph_config
graph_config = get_graph_config()
async def search_adjacent(graph: Union[nx.Graph, any], query: str, other_param: dict = None) -> Dict[str, str]: async def search_adjacent(graph: Union[nx.Graph, any], query: str, other_param: dict = None) -> Dict[str, str]:
""" """
Find the neighbours of a given node in the graph and return their descriptions. Find the neighbours of a given node in the graph and return their descriptions.
@ -23,7 +23,9 @@ async def search_adjacent(graph: Union[nx.Graph, any], query: str, other_param:
if node_id is None: if node_id is None:
return {} return {}
from cognee.infrastructure import infrastructure_config
graph_config = get_graph_config()
if graph_config.graph_engine == GraphDBType.NETWORKX: if graph_config.graph_engine == GraphDBType.NETWORKX:
if node_id not in graph: if node_id not in graph:
return {} return {}

View file

@ -1,19 +1,11 @@
from typing import Union, Dict
import re
from pydantic import BaseModel
from cognee.modules.search.llm.extraction.categorize_relevant_category import categorize_relevant_category
""" Search categories in the graph and return their summary attributes. """ """ Search categories in the graph and return their summary attributes. """
from typing import Union
from cognee.shared.data_models import GraphDBType, DefaultContentPrediction import re
import networkx as nx import networkx as nx
from pydantic import BaseModel
from cognee.shared.data_models import GraphDBType
from cognee.modules.search.llm.extraction.categorize_relevant_category import categorize_relevant_category
from cognee.infrastructure.databases.graph.config import get_graph_config from cognee.infrastructure.databases.graph.config import get_graph_config
graph_config = get_graph_config()
from cognee.infrastructure.databases.vector.config import get_vectordb_config
vector_config = get_vectordb_config()
def strip_exact_regex(s, substring): def strip_exact_regex(s, substring):
# Escaping substring to be used in a regex pattern # Escaping substring to be used in a regex pattern
@ -25,7 +17,7 @@ def strip_exact_regex(s, substring):
class DefaultResponseModel(BaseModel): class DefaultResponseModel(BaseModel):
document_id: str document_id: str
async def search_categories(query:str, graph: Union[nx.Graph, any], query_label: str=None, infrastructure_config: Dict=None): async def search_categories(query:str, graph: Union[nx.Graph, any], query_label: str=None):
""" """
Filter nodes in the graph that contain the specified label and return their summary attributes. Filter nodes in the graph that contain the specified label and return their summary attributes.
This function supports both NetworkX graphs and Neo4j graph databases. This function supports both NetworkX graphs and Neo4j graph databases.
@ -33,7 +25,6 @@ async def search_categories(query:str, graph: Union[nx.Graph, any], query_label:
Parameters: Parameters:
- graph (Union[nx.Graph, AsyncSession]): The graph object or Neo4j session. - graph (Union[nx.Graph, AsyncSession]): The graph object or Neo4j session.
- query_label (str): The label to filter nodes by. - query_label (str): The label to filter nodes by.
- infrastructure_config (Dict): Configuration that includes the graph engine type.
Returns: Returns:
- Union[Dict, List[Dict]]: For NetworkX, returns a dictionary where keys are node identifiers, - Union[Dict, List[Dict]]: For NetworkX, returns a dictionary where keys are node identifiers,
@ -41,21 +32,22 @@ async def search_categories(query:str, graph: Union[nx.Graph, any], query_label:
each representing a node with 'nodeId' and 'summary'. each representing a node with 'nodeId' and 'summary'.
""" """
# Determine which client is in use based on the configuration # Determine which client is in use based on the configuration
from cognee.infrastructure import infrastructure_config graph_config = get_graph_config()
if graph_config.graph_engine == GraphDBType.NETWORKX: if graph_config.graph_engine == GraphDBType.NETWORKX:
categories_and_ids = [ categories_and_ids = [
{'document_id': strip_exact_regex(_, "DATA_SUMMARY__"), 'Summary': data['summary']} {"document_id": strip_exact_regex(_, "DATA_SUMMARY__"), "Summary": data["summary"]}
for _, data in graph.nodes(data=True) for _, data in graph.nodes(data=True)
if 'summary' in data if "summary" in data
] ]
connected_nodes = [] connected_nodes = []
for id in categories_and_ids: for id in categories_and_ids:
print("id", id) print("id", id)
connected_nodes.append(list(graph.neighbors(id['document_id']))) connected_nodes.append(list(graph.neighbors(id["document_id"])))
check_relevant_category = await categorize_relevant_category(query, categories_and_ids, response_model=DefaultResponseModel ) check_relevant_category = await categorize_relevant_category(query, categories_and_ids, response_model=DefaultResponseModel )
connected_nodes = list(graph.neighbors(check_relevant_category['document_id'])) connected_nodes = list(graph.neighbors(check_relevant_category["document_id"]))
descriptions = {node: graph.nodes[node].get('description', 'No desc available') for node in connected_nodes} descriptions = {node: graph.nodes[node].get("description", "No desc available") for node in connected_nodes}
return descriptions return descriptions
elif graph_config.graph_engine == GraphDBType.NEO4J: elif graph_config.graph_engine == GraphDBType.NEO4J:

View file

@ -1,27 +1,18 @@
from typing import Union, Dict
import re
import networkx as nx import networkx as nx
from pydantic import BaseModel from typing import Union
from cognee.modules.search.llm.extraction.categorize_relevant_category import categorize_relevant_category
from cognee.shared.data_models import GraphDBType from cognee.shared.data_models import GraphDBType
from cognee.infrastructure.databases.graph.config import get_graph_config from cognee.infrastructure.databases.graph.config import get_graph_config
graph_config = get_graph_config()
from cognee.infrastructure.databases.vector.config import get_vectordb_config
vector_config = get_vectordb_config()
async def search_cypher(query: str, graph: Union[nx.Graph, any]):
    """Run a raw Cypher query against the graph and return the driver result.

    Only supported when the configured graph engine is Neo4j; any other
    engine raises ValueError.
    """
    graph_config = get_graph_config()

    # Guard clause: Cypher is Neo4j-specific.
    if graph_config.graph_engine != GraphDBType.NEO4J:
        raise ValueError("Unsupported graph engine type.")

    result = await graph.run(query)
    return result

View file

@ -1,15 +1,10 @@
""" Fetches the context of a given node in the graph""" """ Fetches the context of a given node in the graph"""
from typing import Union, Dict
from neo4j import AsyncSession
from cognee.infrastructure.databases.graph.get_graph_client import get_graph_client
import networkx as nx import networkx as nx
from typing import Union
from neo4j import AsyncSession
from cognee.shared.data_models import GraphDBType from cognee.shared.data_models import GraphDBType
from cognee.infrastructure.databases.graph.config import get_graph_config from cognee.infrastructure.databases.graph.config import get_graph_config
graph_config = get_graph_config()
from cognee.infrastructure.databases.vector.config import get_vectordb_config
vector_config = get_vectordb_config()
async def search_neighbour(graph: Union[nx.Graph, any], query: str, async def search_neighbour(graph: Union[nx.Graph, any], query: str,
other_param: dict = None): other_param: dict = None):
""" """
@ -25,19 +20,20 @@ async def search_neighbour(graph: Union[nx.Graph, any], query: str,
Returns: Returns:
- List[str]: A list of 'description' attributes of nodes that share the same 'layer_uuid' with the specified node. - List[str]: A list of 'description' attributes of nodes that share the same 'layer_uuid' with the specified node.
""" """
from cognee.infrastructure import infrastructure_config
node_id = other_param.get('node_id') if other_param else query node_id = other_param.get('node_id') if other_param else query
if node_id is None: if node_id is None:
return [] return []
graph_config = get_graph_config()
if graph_config.graph_engine == GraphDBType.NETWORKX: if graph_config.graph_engine == GraphDBType.NETWORKX:
relevant_context = [] relevant_context = []
target_layer_uuid = graph.nodes[node_id].get('layer_uuid') target_layer_uuid = graph.nodes[node_id].get("layer_uuid")
for n, attr in graph.nodes(data=True): for n, attr in graph.nodes(data=True):
if attr.get('layer_uuid') == target_layer_uuid and 'description' in attr: if attr.get("layer_uuid") == target_layer_uuid and "description" in attr:
relevant_context.append(attr['description']) relevant_context.append(attr["description"])
return relevant_context return relevant_context

View file

@ -1,23 +1,16 @@
import re
from typing import Union, Dict from typing import Union, Dict
import networkx as nx import networkx as nx
from cognee.infrastructure import infrastructure_config
from cognee.modules.search.llm.extraction.categorize_relevant_summary import categorize_relevant_summary from cognee.modules.search.llm.extraction.categorize_relevant_summary import categorize_relevant_summary
from cognee.shared.data_models import GraphDBType, ResponseSummaryModel from cognee.shared.data_models import GraphDBType, ResponseSummaryModel
from cognee.infrastructure.databases.graph.config import get_graph_config from cognee.infrastructure.databases.graph.config import get_graph_config
graph_config = get_graph_config()
from cognee.infrastructure.databases.vector.config import get_vectordb_config
vector_config = get_vectordb_config()
import re
def strip_exact_regex(s, substring): def strip_exact_regex(s, substring):
# Escaping substring to be used in a regex pattern # Escaping substring to be used in a regex pattern
pattern = re.escape(substring) pattern = re.escape(substring)
# Regex to match the exact substring at the start and end # Regex to match the exact substring at the start and end
return re.sub(f"^{pattern}|{pattern}$", "", s) return re.sub(f"^{pattern}|{pattern}$", "", s)
async def search_summary( query: str, graph: Union[nx.Graph, any]) -> Dict[str, str]: async def search_summary( query: str, graph: Union[nx.Graph, any]) -> Dict[str, str]:
""" """
Filter nodes based on a condition (such as containing 'SUMMARY' in their identifiers) and return their summary attributes. Filter nodes based on a condition (such as containing 'SUMMARY' in their identifiers) and return their summary attributes.
@ -32,13 +25,13 @@ async def search_summary( query: str, graph: Union[nx.Graph, any]) -> Dict[str,
Returns: Returns:
- Dict[str, str]: A dictionary where keys are node identifiers containing the query string, and values are their 'summary' attributes. - Dict[str, str]: A dictionary where keys are node identifiers containing the query string, and values are their 'summary' attributes.
""" """
graph_config = get_graph_config()
if graph_config.graph_engine == GraphDBType.NETWORKX: if graph_config.graph_engine == GraphDBType.NETWORKX:
print("graph", graph)
summaries_and_ids = [ summaries_and_ids = [
{'document_id': strip_exact_regex(_, "DATA_SUMMARY__"), 'Summary': data['summary']} {"document_id": strip_exact_regex(_, "DATA_SUMMARY__"), "Summary": data["summary"]}
for _, data in graph.nodes(data=True) for _, data in graph.nodes(data=True)
if 'summary' in data if "summary" in data
] ]
print("summaries_and_ids", summaries_and_ids) print("summaries_and_ids", summaries_and_ids)
check_relevant_summary = await categorize_relevant_summary(query, summaries_and_ids, response_model=ResponseSummaryModel) check_relevant_summary = await categorize_relevant_summary(query, summaries_and_ids, response_model=ResponseSummaryModel)

View file

@ -1,10 +1,10 @@
from cognee.infrastructure.databases.graph.get_graph_client import get_graph_client from cognee.infrastructure.databases.graph.get_graph_client import get_graph_client
from cognee.infrastructure.databases.graph.config import get_graph_config from cognee.infrastructure.databases.graph.config import get_graph_config
graph_config = get_graph_config()
from cognee.infrastructure.databases.vector.config import get_vectordb_config from cognee.infrastructure.databases.vector.config import get_vectordb_config
vector_config = get_vectordb_config()
async def search_similarity(query: str, graph): async def search_similarity(query: str, graph):
graph_config = get_graph_config()
graph_db_type = graph_config.graph_engine graph_db_type = graph_config.graph_engine
graph_client = await get_graph_client(graph_db_type) graph_client = await get_graph_client(graph_db_type)
@ -17,6 +17,8 @@ async def search_similarity(query: str, graph):
graph_nodes = [] graph_nodes = []
vector_config = get_vectordb_config()
for layer_id in unique_layer_uuids: for layer_id in unique_layer_uuids:
vector_engine = vector_config.vector_engine vector_engine = vector_config.vector_engine

View file

@ -1,10 +1,8 @@
from cognee.infrastructure.InfrastructureConfig import infrastructure_config
from cognee.infrastructure.databases.relational.config import get_relationaldb_config from cognee.infrastructure.databases.relational.config import get_relationaldb_config
config = get_relationaldb_config()
def create_task_status_table(): def create_task_status_table():
db_engine = config.db_engine config = get_relationaldb_config()
db_engine = config.database_engine
db_engine.create_table("cognee_task_status", [ db_engine.create_table("cognee_task_status", [
dict(name = "data_id", type = "STRING"), dict(name = "data_id", type = "STRING"),

View file

@ -2,7 +2,7 @@ from cognee.infrastructure.databases.relational.config import get_relationaldb_c
def get_task_status(data_ids: [str]): def get_task_status(data_ids: [str]):
relational_config = get_relationaldb_config() relational_config = get_relationaldb_config()
db_engine = relational_config.db_engine db_engine = relational_config.database_engine
formatted_data_ids = ", ".join([f"'{data_id}'" for data_id in data_ids]) formatted_data_ids = ", ".join([f"'{data_id}'" for data_id in data_ids])

View file

@ -1,8 +1,6 @@
from cognee.infrastructure.InfrastructureConfig import infrastructure_config
from cognee.infrastructure.databases.relational.config import get_relationaldb_config from cognee.infrastructure.databases.relational.config import get_relationaldb_config
config = get_relationaldb_config()
def update_task_status(data_id: str, status: str):
    """Record *status* for *data_id* in the cognee_task_status table."""
    # Resolve the configuration lazily so the engine reflects current settings.
    relational_config = get_relationaldb_config()
    engine = relational_config.database_engine
    engine.insert_data("cognee_task_status", [dict(data_id = data_id, status = status)])

View file

@ -2,9 +2,8 @@ from deepeval.dataset import EvaluationDataset
from pydantic import BaseModel from pydantic import BaseModel
from typing import List, Type from typing import List, Type, Dict
from deepeval.test_case import LLMTestCase from deepeval.test_case import LLMTestCase
from deepeval.dataset import Golden
import dotenv import dotenv
dotenv.load_dotenv() dotenv.load_dotenv()
@ -42,7 +41,6 @@ print(dataset)
import logging import logging
from typing import List, Dict
from cognee.infrastructure import infrastructure_config from cognee.infrastructure import infrastructure_config
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)