From c9d9672fed337d8451d3f7067ea9ff2061082d80 Mon Sep 17 00:00:00 2001 From: Boris Arzentar Date: Mon, 3 Jun 2024 21:49:10 +0200 Subject: [PATCH] fix: cognify status table update --- cognee/api/v1/add/add.py | 2 +- .../relational/duckdb/DuckDBAdapter.py | 14 ++++++--- .../embeddings/DefaultEmbeddingEngine.py | 31 +++++++++---------- cognee/infrastructure/llm/config.py | 3 +- cognee/infrastructure/llm/get_llm_client.py | 9 ++++++ .../modules/tasks/create_task_status_table.py | 2 +- cognee/modules/tasks/get_task_status.py | 4 +-- cognee/modules/tasks/update_task_status.py | 2 +- pyproject.toml | 2 +- 9 files changed, 41 insertions(+), 28 deletions(-) diff --git a/cognee/api/v1/add/add.py b/cognee/api/v1/add/add.py index 66e831dde..9b1ecf5ba 100644 --- a/cognee/api/v1/add/add.py +++ b/cognee/api/v1/add/add.py @@ -59,7 +59,7 @@ async def add_files(file_paths: List[str], dataset_name: str): if data_directory_path not in file_path: file_name = file_path.split("/")[-1] - file_directory_path = data_directory_path + "/" + (dataset_name.replace(".", "/") + "/" if dataset_name != "root" else "") + file_directory_path = data_directory_path + "/" + (dataset_name.replace(".", "/") + "/" if dataset_name != "main_dataset" else "") dataset_file_path = path.join(file_directory_path, file_name) LocalStorage.ensure_directory_exists(file_directory_path) diff --git a/cognee/infrastructure/databases/relational/duckdb/DuckDBAdapter.py b/cognee/infrastructure/databases/relational/duckdb/DuckDBAdapter.py index 8a8cc3cf2..4e136b950 100644 --- a/cognee/infrastructure/databases/relational/duckdb/DuckDBAdapter.py +++ b/cognee/infrastructure/databases/relational/duckdb/DuckDBAdapter.py @@ -13,7 +13,7 @@ class DuckDBAdapter(): return list( filter( - lambda table_name: table_name.endswith("staging") is False, + lambda schema_name: not schema_name.endswith("staging") and schema_name != "cognee", tables["schema_name"] ) ) @@ -22,14 +22,18 @@ class DuckDBAdapter(): with self.get_connection() as connection: return connection.sql(f"SELECT id, name, file_path, extension, mime_type, keywords FROM {dataset_name}.file_metadata;").to_df().to_dict("records") - def create_table(self, table_name: str, table_config: list[dict]): + def create_table(self, schema_name: str, table_name: str, table_config: list[dict]): fields_query_parts = [] for table_config_item in table_config: fields_query_parts.append(f"{table_config_item['name']} {table_config_item['type']}") with self.get_connection() as connection: - query = f"CREATE TABLE IF NOT EXISTS {table_name} ({', '.join(fields_query_parts)});" + query = f"CREATE SCHEMA IF NOT EXISTS {schema_name};" + connection.execute(query) + + with self.get_connection() as connection: + query = f"CREATE TABLE IF NOT EXISTS {schema_name}.{table_name} ({', '.join(fields_query_parts)});" connection.execute(query) def delete_table(self, table_name: str): @@ -37,7 +41,7 @@ class DuckDBAdapter(): query = f"DROP TABLE IF EXISTS {table_name};" connection.execute(query) - def insert_data(self, table_name: str, data: list[dict]): + def insert_data(self, schema_name: str, table_name: str, data: list[dict]): def get_values(data_entry: list): return ", ".join([f"'{value}'" if isinstance(value, str) else value for value in data_entry]) @@ -45,7 +49,7 @@ class DuckDBAdapter(): values = ", ".join([f"({get_values(data_entry.values())})" for data_entry in data]) with self.get_connection() as connection: - query = f"INSERT INTO {table_name} ({columns}) VALUES {values};" + query = f"INSERT INTO {schema_name}.{table_name} ({columns}) VALUES {values};" connection.execute(query) def get_data(self, table_name: str, filters: dict = None): diff --git a/cognee/infrastructure/databases/vector/embeddings/DefaultEmbeddingEngine.py b/cognee/infrastructure/databases/vector/embeddings/DefaultEmbeddingEngine.py index 943351729..75fe0e18c 100644 --- a/cognee/infrastructure/databases/vector/embeddings/DefaultEmbeddingEngine.py +++ b/cognee/infrastructure/databases/vector/embeddings/DefaultEmbeddingEngine.py @@ -1,18 +1,17 @@ import asyncio from typing import List, Optional -from openai import AsyncOpenAI from fastembed import TextEmbedding - +import litellm +from litellm import aembedding from cognee.root_dir import get_absolute_path from cognee.infrastructure.databases.vector.embeddings.EmbeddingEngine import EmbeddingEngine -from litellm import aembedding -import litellm litellm.set_verbose = True class DefaultEmbeddingEngine(EmbeddingEngine): embedding_model: str embedding_dimensions: int + def __init__( self, embedding_model: Optional[str], @@ -34,6 +33,7 @@ class DefaultEmbeddingEngine(EmbeddingEngine): class LiteLLMEmbeddingEngine(EmbeddingEngine): embedding_model: str embedding_dimensions: int + def __init__( self, embedding_model: Optional[str], @@ -41,8 +41,6 @@ class LiteLLMEmbeddingEngine(EmbeddingEngine): ): self.embedding_model = embedding_model self.embedding_dimensions = embedding_dimensions - import asyncio - from typing import List async def embed_text(self, text: List[str]) -> List[List[float]]: async def get_embedding(text_): @@ -52,19 +50,20 @@ class LiteLLMEmbeddingEngine(EmbeddingEngine): tasks = [get_embedding(text_) for text_ in text] result = await asyncio.gather(*tasks) return result + def get_vector_size(self) -> int: return self.embedding_dimensions -if __name__ == "__main__": - async def gg(): - openai_embedding_engine = LiteLLMEmbeddingEngine() - # print(openai_embedding_engine.embed_text(["Hello, how are you?"])) - # print(openai_embedding_engine.get_vector_size()) - # default_embedding_engine = DefaultEmbeddingEngine() - sds = await openai_embedding_engine.embed_text(["Hello, sadasdas are you?"]) - print(sds) - # print(default_embedding_engine.get_vector_size()) +# if __name__ == "__main__": +# async def gg(): +# openai_embedding_engine = LiteLLMEmbeddingEngine() +# # print(openai_embedding_engine.embed_text(["Hello, how are you?"])) +# # print(openai_embedding_engine.get_vector_size()) +# # default_embedding_engine = DefaultEmbeddingEngine() +# sds = await openai_embedding_engine.embed_text(["Hello, sadasdas are you?"]) +# print(sds) +# # print(default_embedding_engine.get_vector_size()) - asyncio.run(gg()) +# asyncio.run(gg()) diff --git a/cognee/infrastructure/llm/config.py b/cognee/infrastructure/llm/config.py index 0663efea7..5b43aba96 100644 --- a/cognee/infrastructure/llm/config.py +++ b/cognee/infrastructure/llm/config.py @@ -1,3 +1,4 @@ +from typing import Optional from functools import lru_cache from pydantic_settings import BaseSettings, SettingsConfigDict @@ -5,7 +6,7 @@ class LLMConfig(BaseSettings): llm_provider: str = "openai" llm_model: str = "gpt-4o" llm_endpoint: str = "" - llm_api_key: str = "" + llm_api_key: Optional[str] = None llm_temperature: float = 0.0 llm_streaming: bool = False diff --git a/cognee/infrastructure/llm/get_llm_client.py b/cognee/infrastructure/llm/get_llm_client.py index e0127b8b5..c7481be01 100644 --- a/cognee/infrastructure/llm/get_llm_client.py +++ b/cognee/infrastructure/llm/get_llm_client.py @@ -16,15 +16,24 @@ def get_llm_client(): provider = LLMProvider(llm_config.llm_provider) if provider == LLMProvider.OPENAI: + if llm_config.llm_api_key is None: + raise ValueError("LLM API key is not set.") + from .openai.adapter import OpenAIAdapter return OpenAIAdapter(llm_config.llm_api_key, llm_config.llm_model, llm_config.llm_streaming) elif provider == LLMProvider.OLLAMA: + if llm_config.llm_api_key is None: + raise ValueError("LLM API key is not set.") + from .generic_llm_api.adapter import GenericAPIAdapter return GenericAPIAdapter(llm_config.llm_endpoint, llm_config.llm_api_key, llm_config.llm_model, "Ollama") elif provider == LLMProvider.ANTHROPIC: from .anthropic.adapter import AnthropicAdapter return AnthropicAdapter(llm_config.llm_model) elif provider == LLMProvider.CUSTOM: + if llm_config.llm_api_key is None: + raise ValueError("LLM API key is not set.") + from .generic_llm_api.adapter import GenericAPIAdapter return GenericAPIAdapter(llm_config.llm_endpoint, llm_config.llm_api_key, llm_config.llm_model, "Custom") else: diff --git a/cognee/modules/tasks/create_task_status_table.py b/cognee/modules/tasks/create_task_status_table.py index c223f1889..f92709f06 100644 --- a/cognee/modules/tasks/create_task_status_table.py +++ b/cognee/modules/tasks/create_task_status_table.py @@ -4,7 +4,7 @@ def create_task_status_table(): config = get_relationaldb_config() db_engine = config.database_engine - db_engine.create_table("cognee_task_status", [ + db_engine.create_table("cognee.cognee", "cognee_task_status", [ dict(name = "data_id", type = "STRING"), dict(name = "status", type = "STRING"), dict(name = "created_at", type = "TIMESTAMP DEFAULT CURRENT_TIMESTAMP"), diff --git a/cognee/modules/tasks/get_task_status.py b/cognee/modules/tasks/get_task_status.py index ea390d038..d2917687a 100644 --- a/cognee/modules/tasks/get_task_status.py +++ b/cognee/modules/tasks/get_task_status.py @@ -10,10 +10,10 @@ def get_task_status(data_ids: [str]): f"""SELECT data_id, status FROM ( SELECT data_id, status, ROW_NUMBER() OVER (PARTITION BY data_id ORDER BY created_at DESC) as rn - FROM cognee_task_status + FROM cognee.cognee.cognee_task_status WHERE data_id IN ({formatted_data_ids}) ) t WHERE rn = 1;""" ) - return results + return results[0] if len(results) > 0 else None diff --git a/cognee/modules/tasks/update_task_status.py b/cognee/modules/tasks/update_task_status.py index 09f7c6c10..1efb3823e 100644 --- a/cognee/modules/tasks/update_task_status.py +++ b/cognee/modules/tasks/update_task_status.py @@ -3,4 +3,4 @@ from cognee.infrastructure.databases.relational.config import get_relationaldb_c def update_task_status(data_id: str, status: str): config = get_relationaldb_config() db_engine = config.database_engine - db_engine.insert_data("cognee_task_status", [dict(data_id = data_id, status = status)]) + db_engine.insert_data("cognee.cognee", "cognee_task_status", [dict(data_id = data_id, status = status)]) diff --git a/pyproject.toml b/pyproject.toml index b04c29ede..ec05186ac 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "cognee" -version = "0.1.9" +version = "0.1.11" description = "Cognee - is a library for enriching LLM context with a semantic layer for better understanding and reasoning." authors = ["Vasilije Markovic", "Boris Arzentar"] readme = "README.md"