fix: cognify status table update

This commit is contained in:
Boris Arzentar 2024-06-03 21:49:10 +02:00
parent 245ed66c9e
commit c9d9672fed
9 changed files with 41 additions and 28 deletions

View file

@ -59,7 +59,7 @@ async def add_files(file_paths: List[str], dataset_name: str):
if data_directory_path not in file_path: if data_directory_path not in file_path:
file_name = file_path.split("/")[-1] file_name = file_path.split("/")[-1]
file_directory_path = data_directory_path + "/" + (dataset_name.replace(".", "/") + "/" if dataset_name != "root" else "") file_directory_path = data_directory_path + "/" + (dataset_name.replace(".", "/") + "/" if dataset_name != "main_dataset" else "")
dataset_file_path = path.join(file_directory_path, file_name) dataset_file_path = path.join(file_directory_path, file_name)
LocalStorage.ensure_directory_exists(file_directory_path) LocalStorage.ensure_directory_exists(file_directory_path)

View file

@ -13,7 +13,7 @@ class DuckDBAdapter():
return list( return list(
filter( filter(
lambda table_name: table_name.endswith("staging") is False, lambda schema_name: not schema_name.endswith("staging") and schema_name != "cognee",
tables["schema_name"] tables["schema_name"]
) )
) )
@ -22,14 +22,18 @@ class DuckDBAdapter():
with self.get_connection() as connection: with self.get_connection() as connection:
return connection.sql(f"SELECT id, name, file_path, extension, mime_type, keywords FROM {dataset_name}.file_metadata;").to_df().to_dict("records") return connection.sql(f"SELECT id, name, file_path, extension, mime_type, keywords FROM {dataset_name}.file_metadata;").to_df().to_dict("records")
def create_table(self, table_name: str, table_config: list[dict]): def create_table(self, schema_name: str, table_name: str, table_config: list[dict]):
fields_query_parts = [] fields_query_parts = []
for table_config_item in table_config: for table_config_item in table_config:
fields_query_parts.append(f"{table_config_item['name']} {table_config_item['type']}") fields_query_parts.append(f"{table_config_item['name']} {table_config_item['type']}")
with self.get_connection() as connection: with self.get_connection() as connection:
query = f"CREATE TABLE IF NOT EXISTS {table_name} ({', '.join(fields_query_parts)});" query = f"CREATE SCHEMA IF NOT EXISTS {schema_name};"
connection.execute(query)
with self.get_connection() as connection:
query = f"CREATE TABLE IF NOT EXISTS {schema_name}.{table_name} ({', '.join(fields_query_parts)});"
connection.execute(query) connection.execute(query)
def delete_table(self, table_name: str): def delete_table(self, table_name: str):
@ -37,7 +41,7 @@ class DuckDBAdapter():
query = f"DROP TABLE IF EXISTS {table_name};" query = f"DROP TABLE IF EXISTS {table_name};"
connection.execute(query) connection.execute(query)
def insert_data(self, table_name: str, data: list[dict]): def insert_data(self, schema_name: str, table_name: str, data: list[dict]):
def get_values(data_entry: list): def get_values(data_entry: list):
return ", ".join([f"'{value}'" if isinstance(value, str) else value for value in data_entry]) return ", ".join([f"'{value}'" if isinstance(value, str) else value for value in data_entry])
@ -45,7 +49,7 @@ class DuckDBAdapter():
values = ", ".join([f"({get_values(data_entry.values())})" for data_entry in data]) values = ", ".join([f"({get_values(data_entry.values())})" for data_entry in data])
with self.get_connection() as connection: with self.get_connection() as connection:
query = f"INSERT INTO {table_name} ({columns}) VALUES {values};" query = f"INSERT INTO {schema_name}.{table_name} ({columns}) VALUES {values};"
connection.execute(query) connection.execute(query)
def get_data(self, table_name: str, filters: dict = None): def get_data(self, table_name: str, filters: dict = None):

View file

@ -1,18 +1,17 @@
import asyncio import asyncio
from typing import List, Optional from typing import List, Optional
from openai import AsyncOpenAI
from fastembed import TextEmbedding from fastembed import TextEmbedding
import litellm
from litellm import aembedding
from cognee.root_dir import get_absolute_path from cognee.root_dir import get_absolute_path
from cognee.infrastructure.databases.vector.embeddings.EmbeddingEngine import EmbeddingEngine from cognee.infrastructure.databases.vector.embeddings.EmbeddingEngine import EmbeddingEngine
from litellm import aembedding
import litellm
litellm.set_verbose = True litellm.set_verbose = True
class DefaultEmbeddingEngine(EmbeddingEngine): class DefaultEmbeddingEngine(EmbeddingEngine):
embedding_model: str embedding_model: str
embedding_dimensions: int embedding_dimensions: int
def __init__( def __init__(
self, self,
embedding_model: Optional[str], embedding_model: Optional[str],
@ -34,6 +33,7 @@ class DefaultEmbeddingEngine(EmbeddingEngine):
class LiteLLMEmbeddingEngine(EmbeddingEngine): class LiteLLMEmbeddingEngine(EmbeddingEngine):
embedding_model: str embedding_model: str
embedding_dimensions: int embedding_dimensions: int
def __init__( def __init__(
self, self,
embedding_model: Optional[str], embedding_model: Optional[str],
@ -41,8 +41,6 @@ class LiteLLMEmbeddingEngine(EmbeddingEngine):
): ):
self.embedding_model = embedding_model self.embedding_model = embedding_model
self.embedding_dimensions = embedding_dimensions self.embedding_dimensions = embedding_dimensions
import asyncio
from typing import List
async def embed_text(self, text: List[str]) -> List[List[float]]: async def embed_text(self, text: List[str]) -> List[List[float]]:
async def get_embedding(text_): async def get_embedding(text_):
@ -52,19 +50,20 @@ class LiteLLMEmbeddingEngine(EmbeddingEngine):
tasks = [get_embedding(text_) for text_ in text] tasks = [get_embedding(text_) for text_ in text]
result = await asyncio.gather(*tasks) result = await asyncio.gather(*tasks)
return result return result
def get_vector_size(self) -> int: def get_vector_size(self) -> int:
return self.embedding_dimensions return self.embedding_dimensions
if __name__ == "__main__": # if __name__ == "__main__":
async def gg(): # async def gg():
openai_embedding_engine = LiteLLMEmbeddingEngine() # openai_embedding_engine = LiteLLMEmbeddingEngine()
# print(openai_embedding_engine.embed_text(["Hello, how are you?"])) # # print(openai_embedding_engine.embed_text(["Hello, how are you?"]))
# print(openai_embedding_engine.get_vector_size()) # # print(openai_embedding_engine.get_vector_size())
# default_embedding_engine = DefaultEmbeddingEngine() # # default_embedding_engine = DefaultEmbeddingEngine()
sds = await openai_embedding_engine.embed_text(["Hello, sadasdas are you?"]) # sds = await openai_embedding_engine.embed_text(["Hello, sadasdas are you?"])
print(sds) # print(sds)
# print(default_embedding_engine.get_vector_size()) # # print(default_embedding_engine.get_vector_size())
asyncio.run(gg()) # asyncio.run(gg())

View file

@ -1,3 +1,4 @@
from typing import Optional
from functools import lru_cache from functools import lru_cache
from pydantic_settings import BaseSettings, SettingsConfigDict from pydantic_settings import BaseSettings, SettingsConfigDict
@ -5,7 +6,7 @@ class LLMConfig(BaseSettings):
llm_provider: str = "openai" llm_provider: str = "openai"
llm_model: str = "gpt-4o" llm_model: str = "gpt-4o"
llm_endpoint: str = "" llm_endpoint: str = ""
llm_api_key: str = "" llm_api_key: Optional[str] = None
llm_temperature: float = 0.0 llm_temperature: float = 0.0
llm_streaming: bool = False llm_streaming: bool = False

View file

@ -16,15 +16,24 @@ def get_llm_client():
provider = LLMProvider(llm_config.llm_provider) provider = LLMProvider(llm_config.llm_provider)
if provider == LLMProvider.OPENAI: if provider == LLMProvider.OPENAI:
if llm_config.llm_api_key is None:
raise ValueError("LLM API key is not set.")
from .openai.adapter import OpenAIAdapter from .openai.adapter import OpenAIAdapter
return OpenAIAdapter(llm_config.llm_api_key, llm_config.llm_model, llm_config.llm_streaming) return OpenAIAdapter(llm_config.llm_api_key, llm_config.llm_model, llm_config.llm_streaming)
elif provider == LLMProvider.OLLAMA: elif provider == LLMProvider.OLLAMA:
if llm_config.llm_api_key is None:
raise ValueError("LLM API key is not set.")
from .generic_llm_api.adapter import GenericAPIAdapter from .generic_llm_api.adapter import GenericAPIAdapter
return GenericAPIAdapter(llm_config.llm_endpoint, llm_config.llm_api_key, llm_config.llm_model, "Ollama") return GenericAPIAdapter(llm_config.llm_endpoint, llm_config.llm_api_key, llm_config.llm_model, "Ollama")
elif provider == LLMProvider.ANTHROPIC: elif provider == LLMProvider.ANTHROPIC:
from .anthropic.adapter import AnthropicAdapter from .anthropic.adapter import AnthropicAdapter
return AnthropicAdapter(llm_config.llm_model) return AnthropicAdapter(llm_config.llm_model)
elif provider == LLMProvider.CUSTOM: elif provider == LLMProvider.CUSTOM:
if llm_config.llm_api_key is None:
raise ValueError("LLM API key is not set.")
from .generic_llm_api.adapter import GenericAPIAdapter from .generic_llm_api.adapter import GenericAPIAdapter
return GenericAPIAdapter(llm_config.llm_endpoint, llm_config.llm_api_key, llm_config.llm_model, "Custom") return GenericAPIAdapter(llm_config.llm_endpoint, llm_config.llm_api_key, llm_config.llm_model, "Custom")
else: else:

View file

@ -4,7 +4,7 @@ def create_task_status_table():
config = get_relationaldb_config() config = get_relationaldb_config()
db_engine = config.database_engine db_engine = config.database_engine
db_engine.create_table("cognee_task_status", [ db_engine.create_table("cognee.cognee", "cognee_task_status", [
dict(name = "data_id", type = "STRING"), dict(name = "data_id", type = "STRING"),
dict(name = "status", type = "STRING"), dict(name = "status", type = "STRING"),
dict(name = "created_at", type = "TIMESTAMP DEFAULT CURRENT_TIMESTAMP"), dict(name = "created_at", type = "TIMESTAMP DEFAULT CURRENT_TIMESTAMP"),

View file

@ -10,10 +10,10 @@ def get_task_status(data_ids: [str]):
f"""SELECT data_id, status f"""SELECT data_id, status
FROM ( FROM (
SELECT data_id, status, ROW_NUMBER() OVER (PARTITION BY data_id ORDER BY created_at DESC) as rn SELECT data_id, status, ROW_NUMBER() OVER (PARTITION BY data_id ORDER BY created_at DESC) as rn
FROM cognee_task_status FROM cognee.cognee.cognee_task_status
WHERE data_id IN ({formatted_data_ids}) WHERE data_id IN ({formatted_data_ids})
) t ) t
WHERE rn = 1;""" WHERE rn = 1;"""
) )
return results return results[0] if len(results) > 0 else None

View file

@ -3,4 +3,4 @@ from cognee.infrastructure.databases.relational.config import get_relationaldb_c
def update_task_status(data_id: str, status: str): def update_task_status(data_id: str, status: str):
config = get_relationaldb_config() config = get_relationaldb_config()
db_engine = config.database_engine db_engine = config.database_engine
db_engine.insert_data("cognee_task_status", [dict(data_id = data_id, status = status)]) db_engine.insert_data("cognee.cognee", "cognee_task_status", [dict(data_id = data_id, status = status)])

View file

@ -1,6 +1,6 @@
[tool.poetry] [tool.poetry]
name = "cognee" name = "cognee"
version = "0.1.9" version = "0.1.11"
description = "Cognee - is a library for enriching LLM context with a semantic layer for better understanding and reasoning." description = "Cognee - is a library for enriching LLM context with a semantic layer for better understanding and reasoning."
authors = ["Vasilije Markovic", "Boris Arzentar"] authors = ["Vasilije Markovic", "Boris Arzentar"]
readme = "README.md" readme = "README.md"