feat: COG-585 enable custom llm and embedding models

Boris 2024-11-22 10:26:21 +01:00 committed by GitHub
parent a8aefd57ef
commit d1f8217320
11 changed files with 222 additions and 173 deletions

View file

@@ -9,18 +9,21 @@ async def get_graph_engine() -> GraphDBInterface :
config = get_graph_config()
if config.graph_database_provider == "neo4j":
try:
from .neo4j_driver.adapter import Neo4jAdapter
if not (config.graph_database_url and config.graph_database_username and config.graph_database_password):
raise EnvironmentError("Missing required Neo4j credentials.")
from .neo4j_driver.adapter import Neo4jAdapter
return Neo4jAdapter(
graph_database_url = config.graph_database_url,
graph_database_username = config.graph_database_username,
graph_database_password = config.graph_database_password
)
except:
pass
return Neo4jAdapter(
graph_database_url = config.graph_database_url,
graph_database_username = config.graph_database_username,
graph_database_password = config.graph_database_password
)
elif config.graph_database_provider == "falkordb":
if not (config.graph_database_url and config.graph_database_username and config.graph_database_password):
raise EnvironmentError("Missing required FalkorDB credentials.")
from cognee.infrastructure.databases.vector.embeddings import get_embedding_engine
from cognee.infrastructure.databases.hybrid.falkordb.FalkorDBAdapter import FalkorDBAdapter
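For context, a minimal sketch of the new fail-fast behaviour when Neo4j credentials are incomplete. The import path of get_graph_engine and the environment variable names are assumptions inferred from the settings field names; they are not spelled out in this hunk.

import os
import asyncio

# Path assumed from the package layout referenced elsewhere in this commit.
from cognee.infrastructure.databases.graph import get_graph_engine

# Env var names assumed from the graph_database_* settings fields.
os.environ["GRAPH_DATABASE_PROVIDER"] = "neo4j"
# GRAPH_DATABASE_URL / _USERNAME / _PASSWORD deliberately left unset.

async def main():
    try:
        await get_graph_engine()
    except EnvironmentError as error:
        # Missing credentials now surface immediately instead of being
        # swallowed by the old bare `except: pass`.
        print(error)

asyncio.run(main())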

View file

@@ -10,26 +10,29 @@ def create_vector_engine(config: VectorConfig, embedding_engine):
if config["vector_db_provider"] == "weaviate":
from .weaviate_db import WeaviateAdapter
if config["vector_db_url"] is None and config["vector_db_key"] is None:
raise EnvironmentError("Weaviate is not configured!")
if not (config["vector_db_url"] and config["vector_db_key"]):
raise EnvironmentError("Missing requred Weaviate credentials!")
return WeaviateAdapter(
config["vector_db_url"],
config["vector_db_key"],
embedding_engine = embedding_engine
)
elif config["vector_db_provider"] == "qdrant":
if config["vector_db_url"] and config["vector_db_key"]:
from .qdrant.QDrantAdapter import QDrantAdapter
return QDrantAdapter(
url = config["vector_db_url"],
api_key = config["vector_db_key"],
embedding_engine = embedding_engine
)
elif config["vector_db_provider"] == "qdrant":
if not (config["vector_db_url"] and config["vector_db_key"]):
raise EnvironmentError("Missing requred Qdrant credentials!")
from .qdrant.QDrantAdapter import QDrantAdapter
return QDrantAdapter(
url = config["vector_db_url"],
api_key = config["vector_db_key"],
embedding_engine = embedding_engine
)
elif config["vector_db_provider"] == "pgvector":
from cognee.infrastructure.databases.relational import get_relational_config
from .pgvector.PGVectorAdapter import PGVectorAdapter
# Get configuration for postgres database
relational_config = get_relational_config()
@@ -39,16 +42,25 @@ def create_vector_engine(config: VectorConfig, embedding_engine):
db_port = relational_config.db_port
db_name = relational_config.db_name
if not (db_host and db_port and db_name and db_username and db_password):
raise EnvironmentError("Missing requred pgvector credentials!")
connection_string: str = (
f"postgresql+asyncpg://{db_username}:{db_password}@{db_host}:{db_port}/{db_name}"
f"postgresql+asyncpg://{db_username}:{db_password}@{db_host}:{db_port}/{db_name}"
)
from .pgvector.PGVectorAdapter import PGVectorAdapter
return PGVectorAdapter(
connection_string,
config["vector_db_key"],
embedding_engine,
)
elif config["vector_db_provider"] == "falkordb":
if not (config["vector_db_url"] and config["vector_db_key"]):
raise EnvironmentError("Missing requred FalkorDB credentials!")
from ..hybrid.falkordb.FalkorDBAdapter import FalkorDBAdapter
return FalkorDBAdapter(
@@ -56,6 +68,7 @@ def create_vector_engine(config: VectorConfig, embedding_engine):
database_port = config["vector_db_port"],
embedding_engine = embedding_engine,
)
else:
from .lancedb.LanceDBAdapter import LanceDBAdapter
@@ -64,5 +77,3 @@ def create_vector_engine(config: VectorConfig, embedding_engine):
api_key = config["vector_db_key"],
embedding_engine = embedding_engine,
)
raise EnvironmentError(f"Vector provider not configured correctly: {config['vector_db_provider']}")

View file

@@ -1,32 +1,39 @@
import asyncio
from typing import List, Optional
import litellm
from litellm import aembedding
from cognee.infrastructure.databases.vector.embeddings.EmbeddingEngine import EmbeddingEngine
litellm.set_verbose = False
class LiteLLMEmbeddingEngine(EmbeddingEngine):
api_key: str
embedding_model: str
embedding_dimensions: int
endpoint: str
api_version: str
model: str
dimensions: int
def __init__(
self,
embedding_model: Optional[str] = "text-embedding-3-large",
embedding_dimensions: Optional[int] = 3072,
model: Optional[str] = "text-embedding-3-large",
dimensions: Optional[int] = 3072,
api_key: str = None,
endpoint: str = None,
api_version: str = None,
):
self.api_key = api_key
self.embedding_model = embedding_model
self.embedding_dimensions = embedding_dimensions
self.endpoint = endpoint
self.api_version = api_version
self.model = model
self.dimensions = dimensions
async def embed_text(self, text: List[str]) -> List[List[float]]:
async def get_embedding(text_):
response = await aembedding(
self.embedding_model,
response = await litellm.aembedding(
self.model,
input = text_,
api_key = self.api_key
api_key = self.api_key,
api_base = self.endpoint,
api_version = self.api_version
)
return response.data[0]["embedding"]
@@ -36,4 +43,4 @@ class LiteLLMEmbeddingEngine(EmbeddingEngine):
return result
def get_vector_size(self) -> int:
return self.embedding_dimensions
return self.dimensions
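A rough usage sketch of the renamed constructor arguments (model, dimensions, endpoint, api_version). The endpoint, key, and API version below are placeholders for any litellm-compatible deployment, not values taken from this change; the module path is inferred from the imports in this commit.

import asyncio

from cognee.infrastructure.databases.vector.embeddings.LiteLLMEmbeddingEngine import LiteLLMEmbeddingEngine

# Placeholder credentials and endpoint; any provider litellm supports should work.
engine = LiteLLMEmbeddingEngine(
    model = "azure/text-embedding-3-large",
    dimensions = 3072,
    api_key = "placeholder-key",
    endpoint = "https://example-resource.openai.azure.com",
    api_version = "2024-02-01",  # placeholder Azure-style API version
)

async def main():
    vectors = await engine.embed_text(["custom embedding endpoints are now configurable"])
    print(len(vectors), engine.get_vector_size())

asyncio.run(main())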

View file

@@ -1,23 +1,16 @@
from typing import Optional
from functools import lru_cache
from pydantic_settings import BaseSettings, SettingsConfigDict
class EmbeddingConfig(BaseSettings):
openai_embedding_model: str = "text-embedding-3-large"
openai_embedding_dimensions: int = 3072
litellm_embedding_model: str = "BAAI/bge-large-en-v1.5"
litellm_embedding_dimensions: int = 1024
# embedding_engine:object = DefaultEmbeddingEngine(embedding_model=litellm_embedding_model, embedding_dimensions=litellm_embedding_dimensions)
embedding_model: Optional[str] = "text-embedding-3-large"
embedding_dimensions: Optional[int] = 3072
embedding_endpoint: Optional[str] = None
embedding_api_key: Optional[str] = None
embedding_api_version: Optional[str] = None
model_config = SettingsConfigDict(env_file = ".env", extra = "allow")
def to_dict(self) -> dict:
return {
"openai_embedding_model": self.openai_embedding_model,
"openai_embedding_dimensions": self.openai_embedding_dimensions,
"litellm_embedding_model": self.litellm_embedding_model,
"litellm_embedding_dimensions": self.litellm_embedding_dimensions,
}
@lru_cache
def get_embedding_config():
return EmbeddingConfig()
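Illustrative only: with pydantic-settings, each new field can typically be driven by an upper-cased environment variable of the same name or by the .env file; that mapping is an assumption about the defaults rather than something shown in this diff, and the model and endpoint values are placeholders.

import os

from cognee.infrastructure.databases.vector.embeddings.config import get_embedding_config

# Placeholder values for a self-hosted, OpenAI-compatible embedding server;
# set them before the first (cached) get_embedding_config() call.
os.environ["EMBEDDING_MODEL"] = "openai/bge-large-en-v1.5"
os.environ["EMBEDDING_DIMENSIONS"] = "1024"
os.environ["EMBEDDING_ENDPOINT"] = "http://localhost:8000/v1"
os.environ["EMBEDDING_API_KEY"] = "placeholder-key"

config = get_embedding_config()
print(config.embedding_model, config.embedding_dimensions)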

View file

@@ -1,7 +1,17 @@
from cognee.infrastructure.llm import get_llm_config
from cognee.infrastructure.databases.vector.embeddings.config import get_embedding_config
from cognee.infrastructure.llm.config import get_llm_config
from .EmbeddingEngine import EmbeddingEngine
from .LiteLLMEmbeddingEngine import LiteLLMEmbeddingEngine
def get_embedding_engine() -> EmbeddingEngine:
config = get_embedding_config()
llm_config = get_llm_config()
return LiteLLMEmbeddingEngine(api_key = llm_config.llm_api_key)
return LiteLLMEmbeddingEngine(
# If OpenAI API is used for embeddings, litellm needs only the api_key.
api_key = config.embedding_api_key or llm_config.llm_api_key,
endpoint = config.embedding_endpoint,
api_version = config.embedding_api_version,
model = config.embedding_model,
dimensions = config.embedding_dimensions,
)

View file

@@ -7,6 +7,7 @@ class LLMConfig(BaseSettings):
llm_model: str = "gpt-4o-mini"
llm_endpoint: str = ""
llm_api_key: Optional[str] = None
llm_api_version: Optional[str] = None
llm_temperature: float = 0.0
llm_streaming: bool = False
transcription_model: str = "whisper-1"
@@ -19,6 +20,7 @@ class LLMConfig(BaseSettings):
"model": self.llm_model,
"endpoint": self.llm_endpoint,
"api_key": self.llm_api_key,
"api_version": self.llm_api_version,
"temperature": self.llm_temperature,
"streaming": self.llm_streaming,
"transcription_model": self.transcription_model

View file

@@ -20,21 +20,33 @@ def get_llm_client():
raise ValueError("LLM API key is not set.")
from .openai.adapter import OpenAIAdapter
return OpenAIAdapter(api_key=llm_config.llm_api_key, model=llm_config.llm_model, transcription_model=llm_config.transcription_model, streaming=llm_config.llm_streaming)
return OpenAIAdapter(
api_key = llm_config.llm_api_key,
endpoint = llm_config.llm_endpoint,
api_version = llm_config.llm_api_version,
model = llm_config.llm_model,
transcription_model = llm_config.transcription_model,
streaming = llm_config.llm_streaming,
)
elif provider == LLMProvider.OLLAMA:
if llm_config.llm_api_key is None:
raise ValueError("LLM API key is not set.")
from .generic_llm_api.adapter import GenericAPIAdapter
return GenericAPIAdapter(llm_config.llm_endpoint, llm_config.llm_api_key, llm_config.llm_model, "Ollama")
elif provider == LLMProvider.ANTHROPIC:
from .anthropic.adapter import AnthropicAdapter
return AnthropicAdapter(llm_config.llm_model)
elif provider == LLMProvider.CUSTOM:
if llm_config.llm_api_key is None:
raise ValueError("LLM API key is not set.")
from .generic_llm_api.adapter import GenericAPIAdapter
return GenericAPIAdapter(llm_config.llm_endpoint, llm_config.llm_api_key, llm_config.llm_model, "Custom")
else:
raise ValueError(f"Unsupported LLM provider: {provider}")

View file

@@ -1,174 +1,121 @@
import asyncio
import base64
import os
import base64
from pathlib import Path
from typing import List, Type
from typing import Type
import openai
import litellm
import instructor
from pydantic import BaseModel
from tenacity import retry, stop_after_attempt
from cognee.base_config import get_base_config
from cognee.infrastructure.llm.llm_interface import LLMInterface
from cognee.infrastructure.llm.prompts import read_query_prompt
# from cognee.shared.data_models import MonitoringTool
class OpenAIAdapter(LLMInterface):
name = "OpenAI"
model: str
api_key: str
api_version: str
"""Adapter for OpenAI's GPT-3, GPT=4 API"""
def __init__(self, api_key: str, model: str, transcription_model:str, streaming: bool = False):
base_config = get_base_config()
# if base_config.monitoring_tool == MonitoringTool.LANGFUSE:
# from langfuse.openai import AsyncOpenAI, OpenAI
# elif base_config.monitoring_tool == MonitoringTool.LANGSMITH:
# from langsmith import wrappers
# from openai import AsyncOpenAI
# AsyncOpenAI = wrappers.wrap_openai(AsyncOpenAI())
# else:
from openai import AsyncOpenAI, OpenAI
self.aclient = instructor.from_openai(AsyncOpenAI(api_key = api_key))
self.client = instructor.from_openai(OpenAI(api_key = api_key))
self.base_openai_client = OpenAI(api_key = api_key)
self.transcription_model = "whisper-1"
def __init__(
self,
api_key: str,
endpoint: str,
api_version: str,
model: str,
transcription_model: str,
streaming: bool = False,
):
self.aclient = instructor.from_litellm(litellm.acompletion)
self.client = instructor.from_litellm(litellm.completion)
self.transcription_model = transcription_model
self.model = model
self.api_key = api_key
self.endpoint = endpoint
self.api_version = api_version
self.streaming = streaming
@retry(stop = stop_after_attempt(5))
def completions_with_backoff(self, **kwargs):
"""Wrapper around ChatCompletion.create w/ backoff"""
return openai.chat.completions.create(**kwargs)
@retry(stop = stop_after_attempt(5))
async def acompletions_with_backoff(self,**kwargs):
"""Wrapper around ChatCompletion.acreate w/ backoff"""
return await openai.chat.completions.acreate(**kwargs)
@retry(stop = stop_after_attempt(5))
async def acreate_embedding_with_backoff(self, input: List[str], model: str = "text-embedding-3-large"):
"""Wrapper around Embedding.acreate w/ backoff"""
return await self.aclient.embeddings.create(input = input, model = model)
async def async_get_embedding_with_backoff(self, text, model="text-embedding-3-large"):
"""To get text embeddings, import/call this function
It specifies defaults + handles rate-limiting + is async"""
text = text.replace("\n", " ")
response = await self.aclient.embeddings.create(input = text, model = model)
embedding = response.data[0].embedding
return embedding
@retry(stop = stop_after_attempt(5))
def create_embedding_with_backoff(self, **kwargs):
"""Wrapper around Embedding.create w/ backoff"""
return openai.embeddings.create(**kwargs)
def get_embedding_with_backoff(self, text: str, model: str = "text-embedding-3-large"):
"""To get text embeddings, import/call this function
It specifies defaults + handles rate-limiting
:param text: str
:param model: str
"""
text = text.replace("\n", " ")
response = self.create_embedding_with_backoff(input=[text], model=model)
embedding = response.data[0].embedding
return embedding
async def async_get_batch_embeddings_with_backoff(self, texts: List[str], models: List[str]):
"""To get multiple text embeddings in parallel, import/call this function
It specifies defaults + handles rate-limiting + is async"""
# Collect all coroutines
coroutines = (self.async_get_embedding_with_backoff(text, model)
for text, model in zip(texts, models))
# Run the coroutines in parallel and gather the results
embeddings = await asyncio.gather(*coroutines)
return embeddings
@retry(stop = stop_after_attempt(5))
async def acreate_structured_output(self, text_input: str, system_prompt: str, response_model: Type[BaseModel]) -> BaseModel:
"""Generate a response from a user query."""
return await self.aclient.chat.completions.create(
model = self.model,
messages = [
{
"role": "user",
"content": f"""Use the given format to
extract information from the following input: {text_input}. """,
},
{"role": "system", "content": system_prompt},
],
messages = [{
"role": "user",
"content": f"""Use the given format to
extract information from the following input: {text_input}. """,
}, {
"role": "system",
"content": system_prompt,
}],
api_key = self.api_key,
api_base = self.endpoint,
api_version = self.api_version,
response_model = response_model,
max_retries = 5,
)
@retry(stop = stop_after_attempt(5))
def create_structured_output(self, text_input: str, system_prompt: str, response_model: Type[BaseModel]) -> BaseModel:
"""Generate a response from a user query."""
return self.client.chat.completions.create(
model = self.model,
messages = [
{
"role": "user",
"content": f"""Use the given format to
extract information from the following input: {text_input}. """,
},
{"role": "system", "content": system_prompt},
],
messages = [{
"role": "user",
"content": f"""Use the given format to
extract information from the following input: {text_input}. """,
}, {
"role": "system",
"content": system_prompt,
}],
api_key = self.api_key,
api_base = self.endpoint,
api_version = self.api_version,
response_model = response_model,
max_retries = 5,
)
@retry(stop = stop_after_attempt(5))
def create_transcript(self, input):
"""Generate a audio transcript from a user query."""
if not os.path.isfile(input):
raise FileNotFoundError(f"The file {input} does not exist.")
with open(input, 'rb') as audio_file:
audio_data = audio_file.read()
# with open(input, 'rb') as audio_file:
# audio_data = audio_file.read()
transcription = self.base_openai_client.audio.transcriptions.create(
model=self.transcription_model ,
file=Path(input),
)
transcription = litellm.transcription(
model = self.transcription_model,
file = Path(input),
max_retries = 5,
)
return transcription
@retry(stop = stop_after_attempt(5))
def transcribe_image(self, input) -> BaseModel:
with open(input, "rb") as image_file:
encoded_image = base64.b64encode(image_file.read()).decode('utf-8')
return self.base_openai_client.chat.completions.create(
model=self.model,
messages=[
{
"role": "user",
"content": [
{"type": "text", "text": "Whats in this image?"},
{
"type": "image_url",
"image_url": {
"url": f"data:image/jpeg;base64,{encoded_image}",
},
},
],
}
],
max_tokens=300,
return litellm.completion(
model = self.model,
messages = [{
"role": "user",
"content": [
{
"type": "text",
"text": "Whats in this image?",
}, {
"type": "image_url",
"image_url": {
"url": f"data:image/jpeg;base64,{encoded_image}",
},
},
],
}],
max_tokens = 300,
max_retries = 5,
)
def show_prompt(self, text_input: str, system_prompt: str) -> str:
"""Format and display the prompt for a user query."""
if not text_input:
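For illustration, a hedged sketch of calling the adapter's structured-output path, which now routes through instructor's litellm wrapper. The PersonInfo model and the input text are invented for the example; the get_llm_client import path is assumed as above.

from pydantic import BaseModel

from cognee.infrastructure.llm.get_llm_client import get_llm_client

class PersonInfo(BaseModel):
    name: str
    occupation: str

llm_client = get_llm_client()

person = llm_client.create_structured_output(
    text_input = "Ada Lovelace wrote the first published algorithm.",
    system_prompt = "Extract the person described in the text.",
    response_model = PersonInfo,
)
print(person.name, person.occupation)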

View file

@@ -1,5 +1,7 @@
import json
import inspect
import logging
from cognee.modules.settings import get_current_settings
from cognee.shared.utils import send_telemetry
from cognee.modules.users.models import User
from cognee.modules.users.methods import get_default_user
@@ -157,7 +159,7 @@ async def run_tasks_base(tasks: list[Task], data = None, user: User = None):
})
raise error
async def run_tasks(tasks: [Task], data = None, pipeline_name: str = "default_pipeline"):
async def run_tasks_with_telemetry(tasks: list[Task], data, pipeline_name: str):
user = await get_default_user()
try:
@@ -185,3 +187,10 @@ async def run_tasks(tasks: [Task], data = None, pipeline_name: str = "default_pi
})
raise error
async def run_tasks(tasks: list[Task], data = None, pipeline_name: str = "default_pipeline"):
config = get_current_settings()
logger.debug("\nRunning pipeline with configuration:\n%s\n", json.dumps(config, indent = 1))
async for result in run_tasks_with_telemetry(tasks, data, pipeline_name):
yield result

View file

@@ -1,3 +1,4 @@
from .get_current_settings import get_current_settings
from .get_settings import get_settings, SettingsDict
from .save_llm_config import save_llm_config
from .save_vector_db_config import save_vector_db_config

View file

@@ -0,0 +1,54 @@
from typing import TypedDict
from cognee.infrastructure.llm import get_llm_config
from cognee.infrastructure.databases.graph import get_graph_config
from cognee.infrastructure.databases.vector import get_vectordb_config
from cognee.infrastructure.databases.relational.config import get_relational_config
class LLMConfig(TypedDict):
model: str
provider: str
class VectorDBConfig(TypedDict):
url: str
provider: str
class GraphDBConfig(TypedDict):
url: str
provider: str
class RelationalConfig(TypedDict):
url: str
provider: str
class SettingsDict(TypedDict):
llm: LLMConfig
graph: GraphDBConfig
vector: VectorDBConfig
relational: RelationalConfig
def get_current_settings() -> SettingsDict:
llm_config = get_llm_config()
graph_config = get_graph_config()
vector_config = get_vectordb_config()
relational_config = get_relational_config()
return dict(
llm = {
"provider": llm_config.llm_provider,
"model": llm_config.llm_model,
},
graph = {
"provider": graph_config.graph_database_provider,
"url": graph_config.graph_database_url or graph_config.graph_file_path,
},
vector = {
"provider": vector_config.vector_db_provider,
"url": vector_config.vector_db_url,
},
relational = {
"provider": relational_config.db_provider,
"url": f"{relational_config.db_host}:{relational_config.db_port}" \
if relational_config.db_host \
else f"{relational_config.db_path}/{relational_config.db_name}",
},
)
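For reference, a minimal sketch of consuming the new helper; the commented example output is a made-up configuration, shown only to illustrate the shape of SettingsDict.

import json

from cognee.modules.settings import get_current_settings

settings = get_current_settings()
print(json.dumps(settings, indent = 1))

# Example shape (values depend entirely on the local configuration):
# {
#  "llm": {"provider": "custom", "model": "llama-3.1-70b-instruct"},
#  "graph": {"provider": "neo4j", "url": "bolt://localhost:7687"},
#  "vector": {"provider": "qdrant", "url": "https://qdrant.example.com"},
#  "relational": {"provider": "postgres", "url": "localhost:5432"}
# }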