add test for linter
This commit is contained in:
parent
a3e218e5a4
commit
630588bd46
11 changed files with 38 additions and 41 deletions
|
|
@ -98,9 +98,9 @@ class Config:
|
||||||
anon_clientid: Optional[str] = field(default_factory=lambda: uuid.uuid4().hex)
|
anon_clientid: Optional[str] = field(default_factory=lambda: uuid.uuid4().hex)
|
||||||
|
|
||||||
#Chunking parameters
|
#Chunking parameters
|
||||||
chunk_size: int = 1500
|
# chunk_size: int = 1500
|
||||||
chunk_overlap: int = 0
|
# chunk_overlap: int = 0
|
||||||
chunk_strategy: str = ChunkStrategy.PARAGRAPH
|
# chunk_strategy: str = ChunkStrategy.PARAGRAPH
|
||||||
|
|
||||||
def load(self):
|
def load(self):
|
||||||
"""Loads the configuration from a file or environment variables."""
|
"""Loads the configuration from a file or environment variables."""
|
||||||
|
|
|
||||||
|
|
@ -1,10 +1,13 @@
|
||||||
|
import os
|
||||||
from functools import lru_cache
|
from functools import lru_cache
|
||||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||||
|
from cognee.base_config import get_base_config
|
||||||
|
base_config = get_base_config()
|
||||||
|
|
||||||
class VectorConfig(BaseSettings):
|
class VectorConfig(BaseSettings):
|
||||||
vector_db_url: str = ""
|
vector_db_url: str = ""
|
||||||
vector_db_key: str = ""
|
vector_db_key: str = ""
|
||||||
vector_db_path: str = ""
|
vector_db_path: str = os.path.join(base_config.database_directory_path + "cognee.lancedb")
|
||||||
vector_db_engine: object = ""
|
vector_db_engine: object = ""
|
||||||
|
|
||||||
model_config = SettingsConfigDict(env_file = ".env", extra = "allow")
|
model_config = SettingsConfigDict(env_file = ".env", extra = "allow")
|
||||||
|
|
|
||||||
|
|
@ -2,6 +2,7 @@ from typing import BinaryIO
|
||||||
from pypdf import PdfReader
|
from pypdf import PdfReader
|
||||||
|
|
||||||
def extract_text_from_file(file: BinaryIO, file_type) -> str:
|
def extract_text_from_file(file: BinaryIO, file_type) -> str:
|
||||||
|
"""Extract text from a file"""
|
||||||
if file_type.extension == "pdf":
|
if file_type.extension == "pdf":
|
||||||
reader = PdfReader(stream = file)
|
reader = PdfReader(stream = file)
|
||||||
pages = list(reader.pages[:3])
|
pages = list(reader.pages[:3])
|
||||||
|
|
|
||||||
|
|
@ -11,6 +11,7 @@ class FileMetadata(TypedDict):
|
||||||
keywords: list[str]
|
keywords: list[str]
|
||||||
|
|
||||||
def get_file_metadata(file: BinaryIO) -> FileMetadata:
|
def get_file_metadata(file: BinaryIO) -> FileMetadata:
|
||||||
|
"""Get metadata from a file"""
|
||||||
file.seek(0)
|
file.seek(0)
|
||||||
file_type = guess_file_type(file)
|
file_type = guess_file_type(file)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,5 @@
|
||||||
import os
|
import os
|
||||||
|
|
||||||
def get_file_size(file_path: str):
|
def get_file_size(file_path: str):
|
||||||
|
"""Get the size of a file"""
|
||||||
return os.path.getsize(file_path)
|
return os.path.getsize(file_path)
|
||||||
|
|
|
||||||
|
|
@ -9,6 +9,7 @@ class FileTypeException(Exception):
|
||||||
self.message = message
|
self.message = message
|
||||||
|
|
||||||
class TxtFileType(filetype.Type):
|
class TxtFileType(filetype.Type):
|
||||||
|
"""Text file type"""
|
||||||
MIME = "text/plain"
|
MIME = "text/plain"
|
||||||
EXTENSION = "txt"
|
EXTENSION = "txt"
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,5 @@
|
||||||
def is_text_content(content):
|
def is_text_content(content):
|
||||||
|
"""Check if the content is text."""
|
||||||
# Check for null bytes
|
# Check for null bytes
|
||||||
if b'\0' in content:
|
if b'\0' in content:
|
||||||
return False
|
return False
|
||||||
|
|
|
||||||
|
|
@ -1,3 +1,4 @@
|
||||||
|
'''Adapter for Generic API LLM provider API'''
|
||||||
import asyncio
|
import asyncio
|
||||||
from typing import List, Type
|
from typing import List, Type
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
@ -5,18 +6,20 @@ import instructor
|
||||||
from tenacity import retry, stop_after_attempt
|
from tenacity import retry, stop_after_attempt
|
||||||
import openai
|
import openai
|
||||||
|
|
||||||
from cognee.config import Config
|
|
||||||
from cognee.infrastructure import infrastructure_config
|
from cognee.infrastructure import infrastructure_config
|
||||||
from cognee.infrastructure.llm.llm_interface import LLMInterface
|
from cognee.infrastructure.llm.llm_interface import LLMInterface
|
||||||
from cognee.infrastructure.llm.prompts import read_query_prompt
|
from cognee.infrastructure.llm.prompts import read_query_prompt
|
||||||
from cognee.shared.data_models import MonitoringTool
|
from cognee.shared.data_models import MonitoringTool
|
||||||
|
from cognee.base_config import get_base_config
|
||||||
|
from cognee.infrastructure.llm.config import get_llm_config
|
||||||
|
|
||||||
config = Config()
|
llm_config = get_llm_config()
|
||||||
config.load()
|
base_config = get_base_config()
|
||||||
|
|
||||||
if config.monitoring_tool == MonitoringTool.LANGFUSE:
|
if base_config.monitoring_tool == MonitoringTool.LANGFUSE:
|
||||||
from langfuse.openai import AsyncOpenAI, OpenAI
|
from langfuse.openai import AsyncOpenAI, OpenAI
|
||||||
elif config.monitoring_tool == MonitoringTool.LANGSMITH:
|
elif base_config.monitoring_tool == MonitoringTool.LANGSMITH:
|
||||||
from langsmith import wrappers
|
from langsmith import wrappers
|
||||||
from openai import AsyncOpenAI
|
from openai import AsyncOpenAI
|
||||||
AsyncOpenAI = wrappers.wrap_openai(AsyncOpenAI())
|
AsyncOpenAI = wrappers.wrap_openai(AsyncOpenAI())
|
||||||
|
|
@ -34,7 +37,7 @@ class GenericAPIAdapter(LLMInterface):
|
||||||
self.model = model
|
self.model = model
|
||||||
self.api_key = api_key
|
self.api_key = api_key
|
||||||
|
|
||||||
if infrastructure_config.get_config()["llm_provider"] == "groq":
|
if llm_config.llm_provider == "groq":
|
||||||
from groq import groq
|
from groq import groq
|
||||||
self.aclient = instructor.from_openai(
|
self.aclient = instructor.from_openai(
|
||||||
client = groq.Groq(
|
client = groq.Groq(
|
||||||
|
|
|
||||||
|
|
@ -5,6 +5,9 @@ import logging
|
||||||
# from cognee.infrastructure.llm import llm_config
|
# from cognee.infrastructure.llm import llm_config
|
||||||
|
|
||||||
from cognee.config import Config
|
from cognee.config import Config
|
||||||
|
from cognee.infrastructure.llm import get_llm_config
|
||||||
|
|
||||||
|
|
||||||
# Define an Enum for LLM Providers
|
# Define an Enum for LLM Providers
|
||||||
class LLMProvider(Enum):
|
class LLMProvider(Enum):
|
||||||
OPENAI = "openai"
|
OPENAI = "openai"
|
||||||
|
|
@ -12,24 +15,23 @@ class LLMProvider(Enum):
|
||||||
ANTHROPIC = "anthropic"
|
ANTHROPIC = "anthropic"
|
||||||
CUSTOM = "custom"
|
CUSTOM = "custom"
|
||||||
|
|
||||||
config = Config()
|
llm_config = get_llm_config()
|
||||||
config.load()
|
|
||||||
def get_llm_client():
|
def get_llm_client():
|
||||||
"""Get the LLM client based on the configuration using Enums."""
|
"""Get the LLM client based on the configuration using Enums."""
|
||||||
# logging.error(json.dumps(llm_config.to_dict()))
|
# logging.error(json.dumps(llm_config.to_dict()))
|
||||||
provider = LLMProvider(config.llm_provider)
|
provider = LLMProvider(llm_config.llm_provider)
|
||||||
|
|
||||||
if provider == LLMProvider.OPENAI:
|
if provider == LLMProvider.OPENAI:
|
||||||
from .openai.adapter import OpenAIAdapter
|
from .openai.adapter import OpenAIAdapter
|
||||||
return OpenAIAdapter(config.llm_api_key, config.llm_model)
|
return OpenAIAdapter(llm_config.llm_api_key, llm_config.llm_model)
|
||||||
elif provider == LLMProvider.OLLAMA:
|
elif provider == LLMProvider.OLLAMA:
|
||||||
from .generic_llm_api.adapter import GenericAPIAdapter
|
from .generic_llm_api.adapter import GenericAPIAdapter
|
||||||
return GenericAPIAdapter(config.llm_endpoint, config.llm_api_key, config.llm_model, "Ollama")
|
return GenericAPIAdapter(llm_config.llm_endpoint, llm_config.llm_api_key, llm_config.llm_model, "Ollama")
|
||||||
elif provider == LLMProvider.ANTHROPIC:
|
elif provider == LLMProvider.ANTHROPIC:
|
||||||
from .anthropic.adapter import AnthropicAdapter
|
from .anthropic.adapter import AnthropicAdapter
|
||||||
return AnthropicAdapter(config.llm_model)
|
return AnthropicAdapter(llm_config.llm_model)
|
||||||
elif provider == LLMProvider.CUSTOM:
|
elif provider == LLMProvider.CUSTOM:
|
||||||
from .generic_llm_api.adapter import GenericAPIAdapter
|
from .generic_llm_api.adapter import GenericAPIAdapter
|
||||||
return GenericAPIAdapter(config.llm_endpoint, config.llm_api_key, config.llm_model, "Custom")
|
return GenericAPIAdapter(llm_config.llm_endpoint, llm_config.llm_api_key, llm_config.llm_model, "Custom")
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unsupported LLM provider: {provider}")
|
raise ValueError(f"Unsupported LLM provider: {provider}")
|
||||||
|
|
|
||||||
|
|
@ -6,26 +6,6 @@ from pydantic import BaseModel
|
||||||
class LLMInterface(Protocol):
|
class LLMInterface(Protocol):
|
||||||
""" LLM Interface """
|
""" LLM Interface """
|
||||||
|
|
||||||
# @abstractmethod
|
|
||||||
# async def async_get_embedding_with_backoff(self, text, model="text-embedding-ada-002"):
|
|
||||||
# """To get text embeddings, import/call this function"""
|
|
||||||
# raise NotImplementedError
|
|
||||||
#
|
|
||||||
# @abstractmethod
|
|
||||||
# def get_embedding_with_backoff(self, text: str, model: str = "text-embedding-ada-002"):
|
|
||||||
# """To get text embeddings, import/call this function"""
|
|
||||||
# raise NotImplementedError
|
|
||||||
#
|
|
||||||
# @abstractmethod
|
|
||||||
# async def async_get_batch_embeddings_with_backoff(self, texts: List[str], models: List[str]):
|
|
||||||
# """To get multiple text embeddings in parallel, import/call this function"""
|
|
||||||
# raise NotImplementedError
|
|
||||||
|
|
||||||
# """ Get completions """
|
|
||||||
# async def acompletions_with_backoff(self, **kwargs):
|
|
||||||
# raise NotImplementedError
|
|
||||||
#
|
|
||||||
""" Structured output """
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
async def acreate_structured_output(self,
|
async def acreate_structured_output(self,
|
||||||
text_input: str,
|
text_input: str,
|
||||||
|
|
|
||||||
|
|
@ -5,20 +5,24 @@ import instructor
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from tenacity import retry, stop_after_attempt
|
from tenacity import retry, stop_after_attempt
|
||||||
|
|
||||||
|
from cognee.base_config import get_base_config
|
||||||
from cognee.config import Config
|
from cognee.config import Config
|
||||||
|
from cognee.infrastructure.llm import get_llm_config
|
||||||
from cognee.infrastructure.llm.llm_interface import LLMInterface
|
from cognee.infrastructure.llm.llm_interface import LLMInterface
|
||||||
from cognee.infrastructure.llm.prompts import read_query_prompt
|
from cognee.infrastructure.llm.prompts import read_query_prompt
|
||||||
from cognee.shared.data_models import MonitoringTool
|
from cognee.shared.data_models import MonitoringTool
|
||||||
|
|
||||||
config = Config()
|
config = Config()
|
||||||
config.load()
|
config.load()
|
||||||
|
llm_config = get_llm_config()
|
||||||
|
base_config = get_base_config()
|
||||||
|
|
||||||
if config.monitoring_tool == MonitoringTool.LANGFUSE:
|
if base_config.monitoring_tool == MonitoringTool.LANGFUSE:
|
||||||
from langfuse.openai import AsyncOpenAI, OpenAI
|
from langfuse.openai import AsyncOpenAI, OpenAI
|
||||||
elif config.monitoring_tool == MonitoringTool.LANGSMITH:
|
elif base_config.monitoring_tool == MonitoringTool.LANGSMITH:
|
||||||
from langsmith import wrap_openai
|
from langsmith import wrappers
|
||||||
from openai import AsyncOpenAI
|
from openai import AsyncOpenAI
|
||||||
AsyncOpenAI = wrap_openai(AsyncOpenAI())
|
AsyncOpenAI = wrappers.wrap_openai(AsyncOpenAI())
|
||||||
else:
|
else:
|
||||||
from openai import AsyncOpenAI, OpenAI
|
from openai import AsyncOpenAI, OpenAI
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue