add test for linter

Vasilije 2024-05-25 23:06:13 +02:00
parent a3e218e5a4
commit 630588bd46
11 changed files with 38 additions and 41 deletions

View file

@@ -98,9 +98,9 @@ class Config:
anon_clientid: Optional[str] = field(default_factory=lambda: uuid.uuid4().hex)
#Chunking parameters
chunk_size: int = 1500
chunk_overlap: int = 0
chunk_strategy: str = ChunkStrategy.PARAGRAPH
# chunk_size: int = 1500
# chunk_overlap: int = 0
# chunk_strategy: str = ChunkStrategy.PARAGRAPH
def load(self):
"""Loads the configuration from a file or environment variables."""

View file

@@ -1,10 +1,13 @@
import os
from functools import lru_cache
from pydantic_settings import BaseSettings, SettingsConfigDict
from cognee.base_config import get_base_config
base_config = get_base_config()
class VectorConfig(BaseSettings):
vector_db_url: str = ""
vector_db_key: str = ""
vector_db_path: str = ""
vector_db_path: str = os.path.join(base_config.database_directory_path, "cognee.lancedb")
vector_db_engine: object = ""
model_config = SettingsConfigDict(env_file = ".env", extra = "allow")
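
The hunk ends before any accessor, but the lru_cache import suggests the same cached-singleton accessor this commit uses for get_llm_config and get_base_config. A sketch continuing the module above; the function name is a guess:

```python
# Assumed continuation of the module above, following the
# get_llm_config / get_base_config cached-accessor pattern.
@lru_cache
def get_vectordb_config() -> VectorConfig:
    return VectorConfig()
```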

View file

@@ -2,6 +2,7 @@ from typing import BinaryIO
from pypdf import PdfReader
def extract_text_from_file(file: BinaryIO, file_type) -> str:
"""Extract text from a file"""
if file_type.extension == "pdf":
reader = PdfReader(stream = file)
pages = list(reader.pages[:3])
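
The hunk cuts off after collecting the first three pages. For reference, a standalone pypdf sketch of the same pattern, assuming a file path rather than this function's BinaryIO argument:

```python
# Minimal pypdf sketch (not the committed code): extract text
# from the first pages of a PDF opened in binary mode.
from pypdf import PdfReader

def first_pages_text(path: str, page_limit: int = 3) -> str:
    with open(path, "rb") as file:
        reader = PdfReader(stream = file)
        pages = list(reader.pages[:page_limit])
        return "\n".join(page.extract_text() or "" for page in pages)
```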

View file

@@ -11,6 +11,7 @@ class FileMetadata(TypedDict):
keywords: list[str]
def get_file_metadata(file: BinaryIO) -> FileMetadata:
"""Get metadata from a file"""
file.seek(0)
file_type = guess_file_type(file)
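
guess_file_type sniffs the leading bytes of the stream, which is why the file.seek(0) matters: a partially consumed handle would be sniffed from the wrong offset. A hypothetical call site, assuming only what the hunk shows:

```python
# Hypothetical usage: metadata for a file opened in binary mode.
with open("report.pdf", "rb") as file:
    metadata = get_file_metadata(file)
    print(metadata["keywords"])
```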

View file

@@ -1,4 +1,5 @@
import os
def get_file_size(file_path: str):
"""Get the size of a file"""
return os.path.getsize(file_path)

View file

@@ -9,6 +9,7 @@ class FileTypeException(Exception):
self.message = message
class TxtFileType(filetype.Type):
"""Text file type"""
MIME = "text/plain"
EXTENSION = "txt"

View file

@@ -1,4 +1,5 @@
def is_text_content(content):
"""Check if the content is text."""
# Check for null bytes
if b'\0' in content:
return False
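
A null byte is a cheap, reliable binary marker, since common text encodings other than UTF-16/32 essentially never produce one. A usage sketch with an assumed file name, sniffing only the first kilobyte:

```python
# Hypothetical usage: test the leading bytes before decoding a file as text.
with open("notes.txt", "rb") as f:
    sample = f.read(1024)

if is_text_content(sample):
    print("looks like text")
```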

View file

@@ -1,3 +1,4 @@
'''Adapter for generic API LLM providers'''
import asyncio
from typing import List, Type
from pydantic import BaseModel
@@ -5,18 +5,20 @@ import instructor
from tenacity import retry, stop_after_attempt
import openai
from cognee.config import Config
from cognee.infrastructure import infrastructure_config
from cognee.infrastructure.llm.llm_interface import LLMInterface
from cognee.infrastructure.llm.prompts import read_query_prompt
from cognee.shared.data_models import MonitoringTool
from cognee.base_config import get_base_config
from cognee.infrastructure.llm.config import get_llm_config
config = Config()
config.load()
llm_config = get_llm_config()
base_config = get_base_config()
if config.monitoring_tool == MonitoringTool.LANGFUSE:
if base_config.monitoring_tool == MonitoringTool.LANGFUSE:
from langfuse.openai import AsyncOpenAI, OpenAI
elif config.monitoring_tool == MonitoringTool.LANGSMITH:
elif base_config.monitoring_tool == MonitoringTool.LANGSMITH:
from langsmith import wrappers
from openai import AsyncOpenAI
AsyncOpenAI = wrappers.wrap_openai(AsyncOpenAI())
@@ -34,7 +37,7 @@ class GenericAPIAdapter(LLMInterface):
self.model = model
self.api_key = api_key
if infrastructure_config.get_config()["llm_provider"] == "groq":
if llm_config.llm_provider == "groq":
import groq
self.aclient = instructor.from_openai(
client = groq.Groq(

View file

@@ -5,6 +5,9 @@ import logging
# from cognee.infrastructure.llm import llm_config
from cognee.config import Config
from cognee.infrastructure.llm import get_llm_config
# Define an Enum for LLM Providers
class LLMProvider(Enum):
OPENAI = "openai"
@@ -12,24 +15,23 @@ class LLMProvider(Enum):
ANTHROPIC = "anthropic"
CUSTOM = "custom"
config = Config()
config.load()
llm_config = get_llm_config()
def get_llm_client():
"""Get the LLM client based on the configuration using Enums."""
# logging.error(json.dumps(llm_config.to_dict()))
provider = LLMProvider(config.llm_provider)
provider = LLMProvider(llm_config.llm_provider)
if provider == LLMProvider.OPENAI:
from .openai.adapter import OpenAIAdapter
return OpenAIAdapter(config.llm_api_key, config.llm_model)
return OpenAIAdapter(llm_config.llm_api_key, llm_config.llm_model)
elif provider == LLMProvider.OLLAMA:
from .generic_llm_api.adapter import GenericAPIAdapter
return GenericAPIAdapter(config.llm_endpoint, config.llm_api_key, config.llm_model, "Ollama")
return GenericAPIAdapter(llm_config.llm_endpoint, llm_config.llm_api_key, llm_config.llm_model, "Ollama")
elif provider == LLMProvider.ANTHROPIC:
from .anthropic.adapter import AnthropicAdapter
return AnthropicAdapter(config.llm_model)
return AnthropicAdapter(llm_config.llm_model)
elif provider == LLMProvider.CUSTOM:
from .generic_llm_api.adapter import GenericAPIAdapter
return GenericAPIAdapter(config.llm_endpoint, config.llm_api_key, config.llm_model, "Custom")
return GenericAPIAdapter(llm_config.llm_endpoint, llm_config.llm_api_key, llm_config.llm_model, "Custom")
else:
raise ValueError(f"Unsupported LLM provider: {provider}")

View file

@@ -6,26 +6,6 @@ from pydantic import BaseModel
class LLMInterface(Protocol):
""" LLM Interface """
# @abstractmethod
# async def async_get_embedding_with_backoff(self, text, model="text-embedding-ada-002"):
# """To get text embeddings, import/call this function"""
# raise NotImplementedError
#
# @abstractmethod
# def get_embedding_with_backoff(self, text: str, model: str = "text-embedding-ada-002"):
# """To get text embeddings, import/call this function"""
# raise NotImplementedError
#
# @abstractmethod
# async def async_get_batch_embeddings_with_backoff(self, texts: List[str], models: List[str]):
# """To get multiple text embeddings in parallel, import/call this function"""
# raise NotImplementedError
# """ Get completions """
# async def acompletions_with_backoff(self, **kwargs):
# raise NotImplementedError
#
""" Structured output """
@abstractmethod
async def acreate_structured_output(self,
text_input: str,

View file

@@ -5,20 +5,24 @@ import instructor
from pydantic import BaseModel
from tenacity import retry, stop_after_attempt
from cognee.base_config import get_base_config
from cognee.config import Config
from cognee.infrastructure.llm import get_llm_config
from cognee.infrastructure.llm.llm_interface import LLMInterface
from cognee.infrastructure.llm.prompts import read_query_prompt
from cognee.shared.data_models import MonitoringTool
config = Config()
config.load()
llm_config = get_llm_config()
base_config = get_base_config()
if config.monitoring_tool == MonitoringTool.LANGFUSE:
if base_config.monitoring_tool == MonitoringTool.LANGFUSE:
from langfuse.openai import AsyncOpenAI, OpenAI
elif config.monitoring_tool == MonitoringTool.LANGSMITH:
from langsmith import wrap_openai
elif base_config.monitoring_tool == MonitoringTool.LANGSMITH:
from langsmith import wrappers
from openai import AsyncOpenAI
AsyncOpenAI = wrap_openai(AsyncOpenAI())
AsyncOpenAI = wrappers.wrap_openai(AsyncOpenAI())
else:
from openai import AsyncOpenAI, OpenAI
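
Both adapters now branch on base_config.monitoring_tool before choosing the OpenAI import. Note that the LANGSMITH branch rebinds the AsyncOpenAI class name to a wrapped client instance; a sketch that keeps the two under separate names, assuming the langsmith wrapper shown in the diff:

```python
# Sketch, not the committed code: keep the traced client distinct
# from the AsyncOpenAI class so the class can still be instantiated later.
from langsmith import wrappers
from openai import AsyncOpenAI

traced_async_client = wrappers.wrap_openai(AsyncOpenAI())
# Pass traced_async_client to instructor / the adapter instead of AsyncOpenAI().
```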