feat: COG-585 enable custom llm and embeding models
This commit is contained in:
parent
a8aefd57ef
commit
d1f8217320
11 changed files with 222 additions and 173 deletions
|
|
@ -9,18 +9,21 @@ async def get_graph_engine() -> GraphDBInterface :
|
||||||
config = get_graph_config()
|
config = get_graph_config()
|
||||||
|
|
||||||
if config.graph_database_provider == "neo4j":
|
if config.graph_database_provider == "neo4j":
|
||||||
try:
|
if not (config.graph_database_url and config.graph_database_username and config.graph_database_password):
|
||||||
from .neo4j_driver.adapter import Neo4jAdapter
|
raise EnvironmentError("Missing required Neo4j credentials.")
|
||||||
|
|
||||||
|
from .neo4j_driver.adapter import Neo4jAdapter
|
||||||
|
|
||||||
return Neo4jAdapter(
|
return Neo4jAdapter(
|
||||||
graph_database_url = config.graph_database_url,
|
graph_database_url = config.graph_database_url,
|
||||||
graph_database_username = config.graph_database_username,
|
graph_database_username = config.graph_database_username,
|
||||||
graph_database_password = config.graph_database_password
|
graph_database_password = config.graph_database_password
|
||||||
)
|
)
|
||||||
except:
|
|
||||||
pass
|
|
||||||
|
|
||||||
elif config.graph_database_provider == "falkordb":
|
elif config.graph_database_provider == "falkordb":
|
||||||
|
if not (config.graph_database_url and config.graph_database_username and config.graph_database_password):
|
||||||
|
raise EnvironmentError("Missing required FalkorDB credentials.")
|
||||||
|
|
||||||
from cognee.infrastructure.databases.vector.embeddings import get_embedding_engine
|
from cognee.infrastructure.databases.vector.embeddings import get_embedding_engine
|
||||||
from cognee.infrastructure.databases.hybrid.falkordb.FalkorDBAdapter import FalkorDBAdapter
|
from cognee.infrastructure.databases.hybrid.falkordb.FalkorDBAdapter import FalkorDBAdapter
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -10,26 +10,29 @@ def create_vector_engine(config: VectorConfig, embedding_engine):
|
||||||
if config["vector_db_provider"] == "weaviate":
|
if config["vector_db_provider"] == "weaviate":
|
||||||
from .weaviate_db import WeaviateAdapter
|
from .weaviate_db import WeaviateAdapter
|
||||||
|
|
||||||
if config["vector_db_url"] is None and config["vector_db_key"] is None:
|
if not (config["vector_db_url"] and config["vector_db_key"]):
|
||||||
raise EnvironmentError("Weaviate is not configured!")
|
raise EnvironmentError("Missing requred Weaviate credentials!")
|
||||||
|
|
||||||
return WeaviateAdapter(
|
return WeaviateAdapter(
|
||||||
config["vector_db_url"],
|
config["vector_db_url"],
|
||||||
config["vector_db_key"],
|
config["vector_db_key"],
|
||||||
embedding_engine = embedding_engine
|
embedding_engine = embedding_engine
|
||||||
)
|
)
|
||||||
elif config["vector_db_provider"] == "qdrant":
|
|
||||||
if config["vector_db_url"] and config["vector_db_key"]:
|
|
||||||
from .qdrant.QDrantAdapter import QDrantAdapter
|
|
||||||
|
|
||||||
return QDrantAdapter(
|
elif config["vector_db_provider"] == "qdrant":
|
||||||
url = config["vector_db_url"],
|
if not (config["vector_db_url"] and config["vector_db_key"]):
|
||||||
api_key = config["vector_db_key"],
|
raise EnvironmentError("Missing requred Qdrant credentials!")
|
||||||
embedding_engine = embedding_engine
|
|
||||||
)
|
from .qdrant.QDrantAdapter import QDrantAdapter
|
||||||
|
|
||||||
|
return QDrantAdapter(
|
||||||
|
url = config["vector_db_url"],
|
||||||
|
api_key = config["vector_db_key"],
|
||||||
|
embedding_engine = embedding_engine
|
||||||
|
)
|
||||||
|
|
||||||
elif config["vector_db_provider"] == "pgvector":
|
elif config["vector_db_provider"] == "pgvector":
|
||||||
from cognee.infrastructure.databases.relational import get_relational_config
|
from cognee.infrastructure.databases.relational import get_relational_config
|
||||||
from .pgvector.PGVectorAdapter import PGVectorAdapter
|
|
||||||
|
|
||||||
# Get configuration for postgres database
|
# Get configuration for postgres database
|
||||||
relational_config = get_relational_config()
|
relational_config = get_relational_config()
|
||||||
|
|
@ -39,16 +42,25 @@ def create_vector_engine(config: VectorConfig, embedding_engine):
|
||||||
db_port = relational_config.db_port
|
db_port = relational_config.db_port
|
||||||
db_name = relational_config.db_name
|
db_name = relational_config.db_name
|
||||||
|
|
||||||
|
if not (db_host and db_port and db_name and db_username and db_password):
|
||||||
|
raise EnvironmentError("Missing requred pgvector credentials!")
|
||||||
|
|
||||||
connection_string: str = (
|
connection_string: str = (
|
||||||
f"postgresql+asyncpg://{db_username}:{db_password}@{db_host}:{db_port}/{db_name}"
|
f"postgresql+asyncpg://{db_username}:{db_password}@{db_host}:{db_port}/{db_name}"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
from .pgvector.PGVectorAdapter import PGVectorAdapter
|
||||||
|
|
||||||
return PGVectorAdapter(
|
return PGVectorAdapter(
|
||||||
connection_string,
|
connection_string,
|
||||||
config["vector_db_key"],
|
config["vector_db_key"],
|
||||||
embedding_engine,
|
embedding_engine,
|
||||||
)
|
)
|
||||||
|
|
||||||
elif config["vector_db_provider"] == "falkordb":
|
elif config["vector_db_provider"] == "falkordb":
|
||||||
|
if not (config["vector_db_url"] and config["vector_db_key"]):
|
||||||
|
raise EnvironmentError("Missing requred FalkorDB credentials!")
|
||||||
|
|
||||||
from ..hybrid.falkordb.FalkorDBAdapter import FalkorDBAdapter
|
from ..hybrid.falkordb.FalkorDBAdapter import FalkorDBAdapter
|
||||||
|
|
||||||
return FalkorDBAdapter(
|
return FalkorDBAdapter(
|
||||||
|
|
@ -56,6 +68,7 @@ def create_vector_engine(config: VectorConfig, embedding_engine):
|
||||||
database_port = config["vector_db_port"],
|
database_port = config["vector_db_port"],
|
||||||
embedding_engine = embedding_engine,
|
embedding_engine = embedding_engine,
|
||||||
)
|
)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
from .lancedb.LanceDBAdapter import LanceDBAdapter
|
from .lancedb.LanceDBAdapter import LanceDBAdapter
|
||||||
|
|
||||||
|
|
@ -64,5 +77,3 @@ def create_vector_engine(config: VectorConfig, embedding_engine):
|
||||||
api_key = config["vector_db_key"],
|
api_key = config["vector_db_key"],
|
||||||
embedding_engine = embedding_engine,
|
embedding_engine = embedding_engine,
|
||||||
)
|
)
|
||||||
|
|
||||||
raise EnvironmentError(f"Vector provider not configured correctly: {config['vector_db_provider']}")
|
|
||||||
|
|
|
||||||
|
|
@ -1,32 +1,39 @@
|
||||||
import asyncio
|
import asyncio
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
import litellm
|
import litellm
|
||||||
from litellm import aembedding
|
|
||||||
from cognee.infrastructure.databases.vector.embeddings.EmbeddingEngine import EmbeddingEngine
|
from cognee.infrastructure.databases.vector.embeddings.EmbeddingEngine import EmbeddingEngine
|
||||||
|
|
||||||
litellm.set_verbose = False
|
litellm.set_verbose = False
|
||||||
|
|
||||||
class LiteLLMEmbeddingEngine(EmbeddingEngine):
|
class LiteLLMEmbeddingEngine(EmbeddingEngine):
|
||||||
api_key: str
|
api_key: str
|
||||||
embedding_model: str
|
endpoint: str
|
||||||
embedding_dimensions: int
|
api_version: str
|
||||||
|
model: str
|
||||||
|
dimensions: int
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
embedding_model: Optional[str] = "text-embedding-3-large",
|
model: Optional[str] = "text-embedding-3-large",
|
||||||
embedding_dimensions: Optional[int] = 3072,
|
dimensions: Optional[int] = 3072,
|
||||||
api_key: str = None,
|
api_key: str = None,
|
||||||
|
endpoint: str = None,
|
||||||
|
api_version: str = None,
|
||||||
):
|
):
|
||||||
self.api_key = api_key
|
self.api_key = api_key
|
||||||
self.embedding_model = embedding_model
|
self.endpoint = endpoint
|
||||||
self.embedding_dimensions = embedding_dimensions
|
self.api_version = api_version
|
||||||
|
self.model = model
|
||||||
|
self.dimensions = dimensions
|
||||||
|
|
||||||
async def embed_text(self, text: List[str]) -> List[List[float]]:
|
async def embed_text(self, text: List[str]) -> List[List[float]]:
|
||||||
async def get_embedding(text_):
|
async def get_embedding(text_):
|
||||||
response = await aembedding(
|
response = await litellm.aembedding(
|
||||||
self.embedding_model,
|
self.model,
|
||||||
input = text_,
|
input = text_,
|
||||||
api_key = self.api_key
|
api_key = self.api_key,
|
||||||
|
api_base = self.endpoint,
|
||||||
|
api_version = self.api_version
|
||||||
)
|
)
|
||||||
|
|
||||||
return response.data[0]["embedding"]
|
return response.data[0]["embedding"]
|
||||||
|
|
@ -36,4 +43,4 @@ class LiteLLMEmbeddingEngine(EmbeddingEngine):
|
||||||
return result
|
return result
|
||||||
|
|
||||||
def get_vector_size(self) -> int:
|
def get_vector_size(self) -> int:
|
||||||
return self.embedding_dimensions
|
return self.dimensions
|
||||||
|
|
|
||||||
|
|
@ -1,23 +1,16 @@
|
||||||
|
from typing import Optional
|
||||||
from functools import lru_cache
|
from functools import lru_cache
|
||||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||||
|
|
||||||
class EmbeddingConfig(BaseSettings):
|
class EmbeddingConfig(BaseSettings):
|
||||||
openai_embedding_model: str = "text-embedding-3-large"
|
embedding_model: Optional[str] = "text-embedding-3-large"
|
||||||
openai_embedding_dimensions: int = 3072
|
embedding_dimensions: Optional[int] = 3072
|
||||||
litellm_embedding_model: str = "BAAI/bge-large-en-v1.5"
|
embedding_endpoint: Optional[str] = None
|
||||||
litellm_embedding_dimensions: int = 1024
|
embedding_api_key: Optional[str] = None
|
||||||
# embedding_engine:object = DefaultEmbeddingEngine(embedding_model=litellm_embedding_model, embedding_dimensions=litellm_embedding_dimensions)
|
embedding_api_version: Optional[str] = None
|
||||||
|
|
||||||
model_config = SettingsConfigDict(env_file = ".env", extra = "allow")
|
model_config = SettingsConfigDict(env_file = ".env", extra = "allow")
|
||||||
|
|
||||||
def to_dict(self) -> dict:
|
|
||||||
return {
|
|
||||||
"openai_embedding_model": self.openai_embedding_model,
|
|
||||||
"openai_embedding_dimensions": self.openai_embedding_dimensions,
|
|
||||||
"litellm_embedding_model": self.litellm_embedding_model,
|
|
||||||
"litellm_embedding_dimensions": self.litellm_embedding_dimensions,
|
|
||||||
}
|
|
||||||
|
|
||||||
@lru_cache
|
@lru_cache
|
||||||
def get_embedding_config():
|
def get_embedding_config():
|
||||||
return EmbeddingConfig()
|
return EmbeddingConfig()
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,17 @@
|
||||||
from cognee.infrastructure.llm import get_llm_config
|
from cognee.infrastructure.databases.vector.embeddings.config import get_embedding_config
|
||||||
|
from cognee.infrastructure.llm.config import get_llm_config
|
||||||
from .EmbeddingEngine import EmbeddingEngine
|
from .EmbeddingEngine import EmbeddingEngine
|
||||||
from .LiteLLMEmbeddingEngine import LiteLLMEmbeddingEngine
|
from .LiteLLMEmbeddingEngine import LiteLLMEmbeddingEngine
|
||||||
|
|
||||||
def get_embedding_engine() -> EmbeddingEngine:
|
def get_embedding_engine() -> EmbeddingEngine:
|
||||||
|
config = get_embedding_config()
|
||||||
llm_config = get_llm_config()
|
llm_config = get_llm_config()
|
||||||
return LiteLLMEmbeddingEngine(api_key = llm_config.llm_api_key)
|
|
||||||
|
return LiteLLMEmbeddingEngine(
|
||||||
|
# If OpenAI API is used for embeddings, litellm needs only the api_key.
|
||||||
|
api_key = config.embedding_api_key or llm_config.llm_api_key,
|
||||||
|
endpoint = config.embedding_endpoint,
|
||||||
|
api_version = config.embedding_api_version,
|
||||||
|
model = config.embedding_model,
|
||||||
|
dimensions = config.embedding_dimensions,
|
||||||
|
)
|
||||||
|
|
|
||||||
|
|
@ -7,6 +7,7 @@ class LLMConfig(BaseSettings):
|
||||||
llm_model: str = "gpt-4o-mini"
|
llm_model: str = "gpt-4o-mini"
|
||||||
llm_endpoint: str = ""
|
llm_endpoint: str = ""
|
||||||
llm_api_key: Optional[str] = None
|
llm_api_key: Optional[str] = None
|
||||||
|
llm_api_version: Optional[str] = None
|
||||||
llm_temperature: float = 0.0
|
llm_temperature: float = 0.0
|
||||||
llm_streaming: bool = False
|
llm_streaming: bool = False
|
||||||
transcription_model: str = "whisper-1"
|
transcription_model: str = "whisper-1"
|
||||||
|
|
@ -19,6 +20,7 @@ class LLMConfig(BaseSettings):
|
||||||
"model": self.llm_model,
|
"model": self.llm_model,
|
||||||
"endpoint": self.llm_endpoint,
|
"endpoint": self.llm_endpoint,
|
||||||
"api_key": self.llm_api_key,
|
"api_key": self.llm_api_key,
|
||||||
|
"api_version": self.llm_api_version,
|
||||||
"temperature": self.llm_temperature,
|
"temperature": self.llm_temperature,
|
||||||
"streaming": self.llm_streaming,
|
"streaming": self.llm_streaming,
|
||||||
"transcription_model": self.transcription_model
|
"transcription_model": self.transcription_model
|
||||||
|
|
|
||||||
|
|
@ -20,21 +20,33 @@ def get_llm_client():
|
||||||
raise ValueError("LLM API key is not set.")
|
raise ValueError("LLM API key is not set.")
|
||||||
|
|
||||||
from .openai.adapter import OpenAIAdapter
|
from .openai.adapter import OpenAIAdapter
|
||||||
return OpenAIAdapter(api_key=llm_config.llm_api_key, model=llm_config.llm_model, transcription_model=llm_config.transcription_model, streaming=llm_config.llm_streaming)
|
|
||||||
|
return OpenAIAdapter(
|
||||||
|
api_key = llm_config.llm_api_key,
|
||||||
|
endpoint = llm_config.llm_endpoint,
|
||||||
|
api_version = llm_config.llm_api_version,
|
||||||
|
model = llm_config.llm_model,
|
||||||
|
transcription_model = llm_config.transcription_model,
|
||||||
|
streaming = llm_config.llm_streaming,
|
||||||
|
)
|
||||||
|
|
||||||
elif provider == LLMProvider.OLLAMA:
|
elif provider == LLMProvider.OLLAMA:
|
||||||
if llm_config.llm_api_key is None:
|
if llm_config.llm_api_key is None:
|
||||||
raise ValueError("LLM API key is not set.")
|
raise ValueError("LLM API key is not set.")
|
||||||
|
|
||||||
from .generic_llm_api.adapter import GenericAPIAdapter
|
from .generic_llm_api.adapter import GenericAPIAdapter
|
||||||
return GenericAPIAdapter(llm_config.llm_endpoint, llm_config.llm_api_key, llm_config.llm_model, "Ollama")
|
return GenericAPIAdapter(llm_config.llm_endpoint, llm_config.llm_api_key, llm_config.llm_model, "Ollama")
|
||||||
|
|
||||||
elif provider == LLMProvider.ANTHROPIC:
|
elif provider == LLMProvider.ANTHROPIC:
|
||||||
from .anthropic.adapter import AnthropicAdapter
|
from .anthropic.adapter import AnthropicAdapter
|
||||||
return AnthropicAdapter(llm_config.llm_model)
|
return AnthropicAdapter(llm_config.llm_model)
|
||||||
|
|
||||||
elif provider == LLMProvider.CUSTOM:
|
elif provider == LLMProvider.CUSTOM:
|
||||||
if llm_config.llm_api_key is None:
|
if llm_config.llm_api_key is None:
|
||||||
raise ValueError("LLM API key is not set.")
|
raise ValueError("LLM API key is not set.")
|
||||||
|
|
||||||
from .generic_llm_api.adapter import GenericAPIAdapter
|
from .generic_llm_api.adapter import GenericAPIAdapter
|
||||||
return GenericAPIAdapter(llm_config.llm_endpoint, llm_config.llm_api_key, llm_config.llm_model, "Custom")
|
return GenericAPIAdapter(llm_config.llm_endpoint, llm_config.llm_api_key, llm_config.llm_model, "Custom")
|
||||||
|
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unsupported LLM provider: {provider}")
|
raise ValueError(f"Unsupported LLM provider: {provider}")
|
||||||
|
|
|
||||||
|
|
@ -1,174 +1,121 @@
|
||||||
import asyncio
|
|
||||||
import base64
|
|
||||||
import os
|
import os
|
||||||
|
import base64
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import List, Type
|
from typing import Type
|
||||||
|
|
||||||
import openai
|
import litellm
|
||||||
import instructor
|
import instructor
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from tenacity import retry, stop_after_attempt
|
|
||||||
|
|
||||||
from cognee.base_config import get_base_config
|
|
||||||
from cognee.infrastructure.llm.llm_interface import LLMInterface
|
from cognee.infrastructure.llm.llm_interface import LLMInterface
|
||||||
from cognee.infrastructure.llm.prompts import read_query_prompt
|
from cognee.infrastructure.llm.prompts import read_query_prompt
|
||||||
# from cognee.shared.data_models import MonitoringTool
|
|
||||||
|
|
||||||
class OpenAIAdapter(LLMInterface):
|
class OpenAIAdapter(LLMInterface):
|
||||||
name = "OpenAI"
|
name = "OpenAI"
|
||||||
model: str
|
model: str
|
||||||
api_key: str
|
api_key: str
|
||||||
|
api_version: str
|
||||||
|
|
||||||
"""Adapter for OpenAI's GPT-3, GPT=4 API"""
|
"""Adapter for OpenAI's GPT-3, GPT=4 API"""
|
||||||
def __init__(self, api_key: str, model: str, transcription_model:str, streaming: bool = False):
|
def __init__(
|
||||||
base_config = get_base_config()
|
self,
|
||||||
|
api_key: str,
|
||||||
# if base_config.monitoring_tool == MonitoringTool.LANGFUSE:
|
endpoint: str,
|
||||||
# from langfuse.openai import AsyncOpenAI, OpenAI
|
api_version: str,
|
||||||
# elif base_config.monitoring_tool == MonitoringTool.LANGSMITH:
|
model: str,
|
||||||
# from langsmith import wrappers
|
transcription_model: str,
|
||||||
# from openai import AsyncOpenAI
|
streaming: bool = False,
|
||||||
# AsyncOpenAI = wrappers.wrap_openai(AsyncOpenAI())
|
):
|
||||||
# else:
|
self.aclient = instructor.from_litellm(litellm.acompletion)
|
||||||
from openai import AsyncOpenAI, OpenAI
|
self.client = instructor.from_litellm(litellm.completion)
|
||||||
|
self.transcription_model = transcription_model
|
||||||
self.aclient = instructor.from_openai(AsyncOpenAI(api_key = api_key))
|
|
||||||
self.client = instructor.from_openai(OpenAI(api_key = api_key))
|
|
||||||
self.base_openai_client = OpenAI(api_key = api_key)
|
|
||||||
self.transcription_model = "whisper-1"
|
|
||||||
self.model = model
|
self.model = model
|
||||||
self.api_key = api_key
|
self.api_key = api_key
|
||||||
|
self.endpoint = endpoint
|
||||||
|
self.api_version = api_version
|
||||||
self.streaming = streaming
|
self.streaming = streaming
|
||||||
@retry(stop = stop_after_attempt(5))
|
|
||||||
def completions_with_backoff(self, **kwargs):
|
|
||||||
"""Wrapper around ChatCompletion.create w/ backoff"""
|
|
||||||
return openai.chat.completions.create(**kwargs)
|
|
||||||
|
|
||||||
@retry(stop = stop_after_attempt(5))
|
|
||||||
async def acompletions_with_backoff(self,**kwargs):
|
|
||||||
"""Wrapper around ChatCompletion.acreate w/ backoff"""
|
|
||||||
return await openai.chat.completions.acreate(**kwargs)
|
|
||||||
|
|
||||||
@retry(stop = stop_after_attempt(5))
|
|
||||||
async def acreate_embedding_with_backoff(self, input: List[str], model: str = "text-embedding-3-large"):
|
|
||||||
"""Wrapper around Embedding.acreate w/ backoff"""
|
|
||||||
|
|
||||||
return await self.aclient.embeddings.create(input = input, model = model)
|
|
||||||
|
|
||||||
async def async_get_embedding_with_backoff(self, text, model="text-embedding-3-large"):
|
|
||||||
"""To get text embeddings, import/call this function
|
|
||||||
It specifies defaults + handles rate-limiting + is async"""
|
|
||||||
text = text.replace("\n", " ")
|
|
||||||
response = await self.aclient.embeddings.create(input = text, model = model)
|
|
||||||
embedding = response.data[0].embedding
|
|
||||||
return embedding
|
|
||||||
|
|
||||||
@retry(stop = stop_after_attempt(5))
|
|
||||||
def create_embedding_with_backoff(self, **kwargs):
|
|
||||||
"""Wrapper around Embedding.create w/ backoff"""
|
|
||||||
return openai.embeddings.create(**kwargs)
|
|
||||||
|
|
||||||
def get_embedding_with_backoff(self, text: str, model: str = "text-embedding-3-large"):
|
|
||||||
"""To get text embeddings, import/call this function
|
|
||||||
It specifies defaults + handles rate-limiting
|
|
||||||
:param text: str
|
|
||||||
:param model: str
|
|
||||||
"""
|
|
||||||
text = text.replace("\n", " ")
|
|
||||||
response = self.create_embedding_with_backoff(input=[text], model=model)
|
|
||||||
embedding = response.data[0].embedding
|
|
||||||
return embedding
|
|
||||||
|
|
||||||
async def async_get_batch_embeddings_with_backoff(self, texts: List[str], models: List[str]):
|
|
||||||
"""To get multiple text embeddings in parallel, import/call this function
|
|
||||||
It specifies defaults + handles rate-limiting + is async"""
|
|
||||||
# Collect all coroutines
|
|
||||||
coroutines = (self.async_get_embedding_with_backoff(text, model)
|
|
||||||
for text, model in zip(texts, models))
|
|
||||||
|
|
||||||
# Run the coroutines in parallel and gather the results
|
|
||||||
embeddings = await asyncio.gather(*coroutines)
|
|
||||||
|
|
||||||
return embeddings
|
|
||||||
|
|
||||||
@retry(stop = stop_after_attempt(5))
|
|
||||||
async def acreate_structured_output(self, text_input: str, system_prompt: str, response_model: Type[BaseModel]) -> BaseModel:
|
async def acreate_structured_output(self, text_input: str, system_prompt: str, response_model: Type[BaseModel]) -> BaseModel:
|
||||||
"""Generate a response from a user query."""
|
"""Generate a response from a user query."""
|
||||||
|
|
||||||
return await self.aclient.chat.completions.create(
|
return await self.aclient.chat.completions.create(
|
||||||
model = self.model,
|
model = self.model,
|
||||||
messages = [
|
messages = [{
|
||||||
{
|
"role": "user",
|
||||||
"role": "user",
|
"content": f"""Use the given format to
|
||||||
"content": f"""Use the given format to
|
extract information from the following input: {text_input}. """,
|
||||||
extract information from the following input: {text_input}. """,
|
}, {
|
||||||
},
|
"role": "system",
|
||||||
{"role": "system", "content": system_prompt},
|
"content": system_prompt,
|
||||||
],
|
}],
|
||||||
|
api_key = self.api_key,
|
||||||
|
api_base = self.endpoint,
|
||||||
|
api_version = self.api_version,
|
||||||
response_model = response_model,
|
response_model = response_model,
|
||||||
|
max_retries = 5,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@retry(stop = stop_after_attempt(5))
|
|
||||||
def create_structured_output(self, text_input: str, system_prompt: str, response_model: Type[BaseModel]) -> BaseModel:
|
def create_structured_output(self, text_input: str, system_prompt: str, response_model: Type[BaseModel]) -> BaseModel:
|
||||||
"""Generate a response from a user query."""
|
"""Generate a response from a user query."""
|
||||||
|
|
||||||
return self.client.chat.completions.create(
|
return self.client.chat.completions.create(
|
||||||
model = self.model,
|
model = self.model,
|
||||||
messages = [
|
messages = [{
|
||||||
{
|
"role": "user",
|
||||||
"role": "user",
|
"content": f"""Use the given format to
|
||||||
"content": f"""Use the given format to
|
extract information from the following input: {text_input}. """,
|
||||||
extract information from the following input: {text_input}. """,
|
}, {
|
||||||
},
|
"role": "system",
|
||||||
{"role": "system", "content": system_prompt},
|
"content": system_prompt,
|
||||||
],
|
}],
|
||||||
|
api_key = self.api_key,
|
||||||
|
api_base = self.endpoint,
|
||||||
|
api_version = self.api_version,
|
||||||
response_model = response_model,
|
response_model = response_model,
|
||||||
|
max_retries = 5,
|
||||||
)
|
)
|
||||||
|
|
||||||
@retry(stop = stop_after_attempt(5))
|
|
||||||
def create_transcript(self, input):
|
def create_transcript(self, input):
|
||||||
"""Generate a audio transcript from a user query."""
|
"""Generate a audio transcript from a user query."""
|
||||||
|
|
||||||
if not os.path.isfile(input):
|
if not os.path.isfile(input):
|
||||||
raise FileNotFoundError(f"The file {input} does not exist.")
|
raise FileNotFoundError(f"The file {input} does not exist.")
|
||||||
|
|
||||||
with open(input, 'rb') as audio_file:
|
# with open(input, 'rb') as audio_file:
|
||||||
audio_data = audio_file.read()
|
# audio_data = audio_file.read()
|
||||||
|
|
||||||
|
transcription = litellm.transcription(
|
||||||
|
model = self.transcription_model,
|
||||||
transcription = self.base_openai_client.audio.transcriptions.create(
|
file = Path(input),
|
||||||
model=self.transcription_model ,
|
max_retries = 5,
|
||||||
file=Path(input),
|
)
|
||||||
)
|
|
||||||
|
|
||||||
return transcription
|
return transcription
|
||||||
|
|
||||||
|
|
||||||
@retry(stop = stop_after_attempt(5))
|
|
||||||
def transcribe_image(self, input) -> BaseModel:
|
def transcribe_image(self, input) -> BaseModel:
|
||||||
with open(input, "rb") as image_file:
|
with open(input, "rb") as image_file:
|
||||||
encoded_image = base64.b64encode(image_file.read()).decode('utf-8')
|
encoded_image = base64.b64encode(image_file.read()).decode('utf-8')
|
||||||
|
|
||||||
return self.base_openai_client.chat.completions.create(
|
return litellm.completion(
|
||||||
model=self.model,
|
model = self.model,
|
||||||
messages=[
|
messages = [{
|
||||||
{
|
"role": "user",
|
||||||
"role": "user",
|
"content": [
|
||||||
"content": [
|
{
|
||||||
{"type": "text", "text": "What’s in this image?"},
|
"type": "text",
|
||||||
{
|
"text": "What’s in this image?",
|
||||||
"type": "image_url",
|
}, {
|
||||||
"image_url": {
|
"type": "image_url",
|
||||||
"url": f"data:image/jpeg;base64,{encoded_image}",
|
"image_url": {
|
||||||
},
|
"url": f"data:image/jpeg;base64,{encoded_image}",
|
||||||
},
|
},
|
||||||
],
|
},
|
||||||
}
|
],
|
||||||
],
|
}],
|
||||||
max_tokens=300,
|
max_tokens = 300,
|
||||||
|
max_retries = 5,
|
||||||
)
|
)
|
||||||
|
|
||||||
def show_prompt(self, text_input: str, system_prompt: str) -> str:
|
def show_prompt(self, text_input: str, system_prompt: str) -> str:
|
||||||
"""Format and display the prompt for a user query."""
|
"""Format and display the prompt for a user query."""
|
||||||
if not text_input:
|
if not text_input:
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,7 @@
|
||||||
|
import json
|
||||||
import inspect
|
import inspect
|
||||||
import logging
|
import logging
|
||||||
|
from cognee.modules.settings import get_current_settings
|
||||||
from cognee.shared.utils import send_telemetry
|
from cognee.shared.utils import send_telemetry
|
||||||
from cognee.modules.users.models import User
|
from cognee.modules.users.models import User
|
||||||
from cognee.modules.users.methods import get_default_user
|
from cognee.modules.users.methods import get_default_user
|
||||||
|
|
@ -157,7 +159,7 @@ async def run_tasks_base(tasks: list[Task], data = None, user: User = None):
|
||||||
})
|
})
|
||||||
raise error
|
raise error
|
||||||
|
|
||||||
async def run_tasks(tasks: [Task], data = None, pipeline_name: str = "default_pipeline"):
|
async def run_tasks_with_telemetry(tasks: list[Task], data, pipeline_name: str):
|
||||||
user = await get_default_user()
|
user = await get_default_user()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|
@ -185,3 +187,10 @@ async def run_tasks(tasks: [Task], data = None, pipeline_name: str = "default_pi
|
||||||
})
|
})
|
||||||
|
|
||||||
raise error
|
raise error
|
||||||
|
|
||||||
|
async def run_tasks(tasks: list[Task], data = None, pipeline_name: str = "default_pipeline"):
|
||||||
|
config = get_current_settings()
|
||||||
|
logger.debug("\nRunning pipeline with configuration:\n%s\n", json.dumps(config, indent = 1))
|
||||||
|
|
||||||
|
async for result in run_tasks_with_telemetry(tasks, data, pipeline_name):
|
||||||
|
yield result
|
||||||
|
|
|
||||||
|
|
@ -1,3 +1,4 @@
|
||||||
|
from .get_current_settings import get_current_settings
|
||||||
from .get_settings import get_settings, SettingsDict
|
from .get_settings import get_settings, SettingsDict
|
||||||
from .save_llm_config import save_llm_config
|
from .save_llm_config import save_llm_config
|
||||||
from .save_vector_db_config import save_vector_db_config
|
from .save_vector_db_config import save_vector_db_config
|
||||||
|
|
|
||||||
54
cognee/modules/settings/get_current_settings.py
Normal file
54
cognee/modules/settings/get_current_settings.py
Normal file
|
|
@ -0,0 +1,54 @@
|
||||||
|
from typing import TypedDict
|
||||||
|
from cognee.infrastructure.llm import get_llm_config
|
||||||
|
from cognee.infrastructure.databases.graph import get_graph_config
|
||||||
|
from cognee.infrastructure.databases.vector import get_vectordb_config
|
||||||
|
from cognee.infrastructure.databases.relational.config import get_relational_config
|
||||||
|
|
||||||
|
class LLMConfig(TypedDict):
|
||||||
|
model: str
|
||||||
|
provider: str
|
||||||
|
|
||||||
|
class VectorDBConfig(TypedDict):
|
||||||
|
url: str
|
||||||
|
provider: str
|
||||||
|
|
||||||
|
class GraphDBConfig(TypedDict):
|
||||||
|
url: str
|
||||||
|
provider: str
|
||||||
|
|
||||||
|
class RelationalConfig(TypedDict):
|
||||||
|
url: str
|
||||||
|
provider: str
|
||||||
|
|
||||||
|
class SettingsDict(TypedDict):
|
||||||
|
llm: LLMConfig
|
||||||
|
graph: GraphDBConfig
|
||||||
|
vector: VectorDBConfig
|
||||||
|
relational: RelationalConfig
|
||||||
|
|
||||||
|
def get_current_settings() -> SettingsDict:
|
||||||
|
llm_config = get_llm_config()
|
||||||
|
graph_config = get_graph_config()
|
||||||
|
vector_config = get_vectordb_config()
|
||||||
|
relational_config = get_relational_config()
|
||||||
|
|
||||||
|
return dict(
|
||||||
|
llm = {
|
||||||
|
"provider": llm_config.llm_provider,
|
||||||
|
"model": llm_config.llm_model,
|
||||||
|
},
|
||||||
|
graph = {
|
||||||
|
"provider": graph_config.graph_database_provider,
|
||||||
|
"url": graph_config.graph_database_url or graph_config.graph_file_path,
|
||||||
|
},
|
||||||
|
vector = {
|
||||||
|
"provider": vector_config.vector_db_provider,
|
||||||
|
"url": vector_config.vector_db_url,
|
||||||
|
},
|
||||||
|
relational = {
|
||||||
|
"provider": relational_config.db_provider,
|
||||||
|
"url": f"{relational_config.db_host}:{relational_config.db_port}" \
|
||||||
|
if relational_config.db_host \
|
||||||
|
else f"{relational_config.db_path}/{relational_config.db_name}",
|
||||||
|
},
|
||||||
|
)
|
||||||
Loading…
Add table
Reference in a new issue