feat: COG-585 enable custom llm and embedding models

This commit is contained in:
Boris 2024-11-22 10:26:21 +01:00 committed by GitHub
parent a8aefd57ef
commit d1f8217320
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
11 changed files with 222 additions and 173 deletions

View file

@ -9,18 +9,21 @@ async def get_graph_engine() -> GraphDBInterface :
config = get_graph_config() config = get_graph_config()
if config.graph_database_provider == "neo4j": if config.graph_database_provider == "neo4j":
try: if not (config.graph_database_url and config.graph_database_username and config.graph_database_password):
from .neo4j_driver.adapter import Neo4jAdapter raise EnvironmentError("Missing required Neo4j credentials.")
from .neo4j_driver.adapter import Neo4jAdapter
return Neo4jAdapter( return Neo4jAdapter(
graph_database_url = config.graph_database_url, graph_database_url = config.graph_database_url,
graph_database_username = config.graph_database_username, graph_database_username = config.graph_database_username,
graph_database_password = config.graph_database_password graph_database_password = config.graph_database_password
) )
except:
pass
elif config.graph_database_provider == "falkordb": elif config.graph_database_provider == "falkordb":
if not (config.graph_database_url and config.graph_database_username and config.graph_database_password):
raise EnvironmentError("Missing required FalkorDB credentials.")
from cognee.infrastructure.databases.vector.embeddings import get_embedding_engine from cognee.infrastructure.databases.vector.embeddings import get_embedding_engine
from cognee.infrastructure.databases.hybrid.falkordb.FalkorDBAdapter import FalkorDBAdapter from cognee.infrastructure.databases.hybrid.falkordb.FalkorDBAdapter import FalkorDBAdapter

View file

@ -10,26 +10,29 @@ def create_vector_engine(config: VectorConfig, embedding_engine):
if config["vector_db_provider"] == "weaviate": if config["vector_db_provider"] == "weaviate":
from .weaviate_db import WeaviateAdapter from .weaviate_db import WeaviateAdapter
if config["vector_db_url"] is None and config["vector_db_key"] is None: if not (config["vector_db_url"] and config["vector_db_key"]):
raise EnvironmentError("Weaviate is not configured!") raise EnvironmentError("Missing required Weaviate credentials!")
return WeaviateAdapter( return WeaviateAdapter(
config["vector_db_url"], config["vector_db_url"],
config["vector_db_key"], config["vector_db_key"],
embedding_engine = embedding_engine embedding_engine = embedding_engine
) )
elif config["vector_db_provider"] == "qdrant":
if config["vector_db_url"] and config["vector_db_key"]:
from .qdrant.QDrantAdapter import QDrantAdapter
return QDrantAdapter( elif config["vector_db_provider"] == "qdrant":
url = config["vector_db_url"], if not (config["vector_db_url"] and config["vector_db_key"]):
api_key = config["vector_db_key"], raise EnvironmentError("Missing required Qdrant credentials!")
embedding_engine = embedding_engine
) from .qdrant.QDrantAdapter import QDrantAdapter
return QDrantAdapter(
url = config["vector_db_url"],
api_key = config["vector_db_key"],
embedding_engine = embedding_engine
)
elif config["vector_db_provider"] == "pgvector": elif config["vector_db_provider"] == "pgvector":
from cognee.infrastructure.databases.relational import get_relational_config from cognee.infrastructure.databases.relational import get_relational_config
from .pgvector.PGVectorAdapter import PGVectorAdapter
# Get configuration for postgres database # Get configuration for postgres database
relational_config = get_relational_config() relational_config = get_relational_config()
@ -39,16 +42,25 @@ def create_vector_engine(config: VectorConfig, embedding_engine):
db_port = relational_config.db_port db_port = relational_config.db_port
db_name = relational_config.db_name db_name = relational_config.db_name
if not (db_host and db_port and db_name and db_username and db_password):
raise EnvironmentError("Missing required pgvector credentials!")
connection_string: str = ( connection_string: str = (
f"postgresql+asyncpg://{db_username}:{db_password}@{db_host}:{db_port}/{db_name}" f"postgresql+asyncpg://{db_username}:{db_password}@{db_host}:{db_port}/{db_name}"
) )
from .pgvector.PGVectorAdapter import PGVectorAdapter
return PGVectorAdapter( return PGVectorAdapter(
connection_string, connection_string,
config["vector_db_key"], config["vector_db_key"],
embedding_engine, embedding_engine,
) )
elif config["vector_db_provider"] == "falkordb": elif config["vector_db_provider"] == "falkordb":
if not (config["vector_db_url"] and config["vector_db_key"]):
raise EnvironmentError("Missing required FalkorDB credentials!")
from ..hybrid.falkordb.FalkorDBAdapter import FalkorDBAdapter from ..hybrid.falkordb.FalkorDBAdapter import FalkorDBAdapter
return FalkorDBAdapter( return FalkorDBAdapter(
@ -56,6 +68,7 @@ def create_vector_engine(config: VectorConfig, embedding_engine):
database_port = config["vector_db_port"], database_port = config["vector_db_port"],
embedding_engine = embedding_engine, embedding_engine = embedding_engine,
) )
else: else:
from .lancedb.LanceDBAdapter import LanceDBAdapter from .lancedb.LanceDBAdapter import LanceDBAdapter
@ -64,5 +77,3 @@ def create_vector_engine(config: VectorConfig, embedding_engine):
api_key = config["vector_db_key"], api_key = config["vector_db_key"],
embedding_engine = embedding_engine, embedding_engine = embedding_engine,
) )
raise EnvironmentError(f"Vector provider not configured correctly: {config['vector_db_provider']}")

View file

@ -1,32 +1,39 @@
import asyncio import asyncio
from typing import List, Optional from typing import List, Optional
import litellm import litellm
from litellm import aembedding
from cognee.infrastructure.databases.vector.embeddings.EmbeddingEngine import EmbeddingEngine from cognee.infrastructure.databases.vector.embeddings.EmbeddingEngine import EmbeddingEngine
litellm.set_verbose = False litellm.set_verbose = False
class LiteLLMEmbeddingEngine(EmbeddingEngine): class LiteLLMEmbeddingEngine(EmbeddingEngine):
api_key: str api_key: str
embedding_model: str endpoint: str
embedding_dimensions: int api_version: str
model: str
dimensions: int
def __init__( def __init__(
self, self,
embedding_model: Optional[str] = "text-embedding-3-large", model: Optional[str] = "text-embedding-3-large",
embedding_dimensions: Optional[int] = 3072, dimensions: Optional[int] = 3072,
api_key: str = None, api_key: str = None,
endpoint: str = None,
api_version: str = None,
): ):
self.api_key = api_key self.api_key = api_key
self.embedding_model = embedding_model self.endpoint = endpoint
self.embedding_dimensions = embedding_dimensions self.api_version = api_version
self.model = model
self.dimensions = dimensions
async def embed_text(self, text: List[str]) -> List[List[float]]: async def embed_text(self, text: List[str]) -> List[List[float]]:
async def get_embedding(text_): async def get_embedding(text_):
response = await aembedding( response = await litellm.aembedding(
self.embedding_model, self.model,
input = text_, input = text_,
api_key = self.api_key api_key = self.api_key,
api_base = self.endpoint,
api_version = self.api_version
) )
return response.data[0]["embedding"] return response.data[0]["embedding"]
@ -36,4 +43,4 @@ class LiteLLMEmbeddingEngine(EmbeddingEngine):
return result return result
def get_vector_size(self) -> int: def get_vector_size(self) -> int:
return self.embedding_dimensions return self.dimensions

View file

@ -1,23 +1,16 @@
from typing import Optional
from functools import lru_cache from functools import lru_cache
from pydantic_settings import BaseSettings, SettingsConfigDict from pydantic_settings import BaseSettings, SettingsConfigDict
class EmbeddingConfig(BaseSettings): class EmbeddingConfig(BaseSettings):
openai_embedding_model: str = "text-embedding-3-large" embedding_model: Optional[str] = "text-embedding-3-large"
openai_embedding_dimensions: int = 3072 embedding_dimensions: Optional[int] = 3072
litellm_embedding_model: str = "BAAI/bge-large-en-v1.5" embedding_endpoint: Optional[str] = None
litellm_embedding_dimensions: int = 1024 embedding_api_key: Optional[str] = None
# embedding_engine:object = DefaultEmbeddingEngine(embedding_model=litellm_embedding_model, embedding_dimensions=litellm_embedding_dimensions) embedding_api_version: Optional[str] = None
model_config = SettingsConfigDict(env_file = ".env", extra = "allow") model_config = SettingsConfigDict(env_file = ".env", extra = "allow")
def to_dict(self) -> dict:
return {
"openai_embedding_model": self.openai_embedding_model,
"openai_embedding_dimensions": self.openai_embedding_dimensions,
"litellm_embedding_model": self.litellm_embedding_model,
"litellm_embedding_dimensions": self.litellm_embedding_dimensions,
}
@lru_cache @lru_cache
def get_embedding_config(): def get_embedding_config():
return EmbeddingConfig() return EmbeddingConfig()

View file

@ -1,7 +1,17 @@
from cognee.infrastructure.llm import get_llm_config from cognee.infrastructure.databases.vector.embeddings.config import get_embedding_config
from cognee.infrastructure.llm.config import get_llm_config
from .EmbeddingEngine import EmbeddingEngine from .EmbeddingEngine import EmbeddingEngine
from .LiteLLMEmbeddingEngine import LiteLLMEmbeddingEngine from .LiteLLMEmbeddingEngine import LiteLLMEmbeddingEngine
def get_embedding_engine() -> EmbeddingEngine: def get_embedding_engine() -> EmbeddingEngine:
config = get_embedding_config()
llm_config = get_llm_config() llm_config = get_llm_config()
return LiteLLMEmbeddingEngine(api_key = llm_config.llm_api_key)
return LiteLLMEmbeddingEngine(
# If OpenAI API is used for embeddings, litellm needs only the api_key.
api_key = config.embedding_api_key or llm_config.llm_api_key,
endpoint = config.embedding_endpoint,
api_version = config.embedding_api_version,
model = config.embedding_model,
dimensions = config.embedding_dimensions,
)

View file

@ -7,6 +7,7 @@ class LLMConfig(BaseSettings):
llm_model: str = "gpt-4o-mini" llm_model: str = "gpt-4o-mini"
llm_endpoint: str = "" llm_endpoint: str = ""
llm_api_key: Optional[str] = None llm_api_key: Optional[str] = None
llm_api_version: Optional[str] = None
llm_temperature: float = 0.0 llm_temperature: float = 0.0
llm_streaming: bool = False llm_streaming: bool = False
transcription_model: str = "whisper-1" transcription_model: str = "whisper-1"
@ -19,6 +20,7 @@ class LLMConfig(BaseSettings):
"model": self.llm_model, "model": self.llm_model,
"endpoint": self.llm_endpoint, "endpoint": self.llm_endpoint,
"api_key": self.llm_api_key, "api_key": self.llm_api_key,
"api_version": self.llm_api_version,
"temperature": self.llm_temperature, "temperature": self.llm_temperature,
"streaming": self.llm_streaming, "streaming": self.llm_streaming,
"transcription_model": self.transcription_model "transcription_model": self.transcription_model

View file

@ -20,21 +20,33 @@ def get_llm_client():
raise ValueError("LLM API key is not set.") raise ValueError("LLM API key is not set.")
from .openai.adapter import OpenAIAdapter from .openai.adapter import OpenAIAdapter
return OpenAIAdapter(api_key=llm_config.llm_api_key, model=llm_config.llm_model, transcription_model=llm_config.transcription_model, streaming=llm_config.llm_streaming)
return OpenAIAdapter(
api_key = llm_config.llm_api_key,
endpoint = llm_config.llm_endpoint,
api_version = llm_config.llm_api_version,
model = llm_config.llm_model,
transcription_model = llm_config.transcription_model,
streaming = llm_config.llm_streaming,
)
elif provider == LLMProvider.OLLAMA: elif provider == LLMProvider.OLLAMA:
if llm_config.llm_api_key is None: if llm_config.llm_api_key is None:
raise ValueError("LLM API key is not set.") raise ValueError("LLM API key is not set.")
from .generic_llm_api.adapter import GenericAPIAdapter from .generic_llm_api.adapter import GenericAPIAdapter
return GenericAPIAdapter(llm_config.llm_endpoint, llm_config.llm_api_key, llm_config.llm_model, "Ollama") return GenericAPIAdapter(llm_config.llm_endpoint, llm_config.llm_api_key, llm_config.llm_model, "Ollama")
elif provider == LLMProvider.ANTHROPIC: elif provider == LLMProvider.ANTHROPIC:
from .anthropic.adapter import AnthropicAdapter from .anthropic.adapter import AnthropicAdapter
return AnthropicAdapter(llm_config.llm_model) return AnthropicAdapter(llm_config.llm_model)
elif provider == LLMProvider.CUSTOM: elif provider == LLMProvider.CUSTOM:
if llm_config.llm_api_key is None: if llm_config.llm_api_key is None:
raise ValueError("LLM API key is not set.") raise ValueError("LLM API key is not set.")
from .generic_llm_api.adapter import GenericAPIAdapter from .generic_llm_api.adapter import GenericAPIAdapter
return GenericAPIAdapter(llm_config.llm_endpoint, llm_config.llm_api_key, llm_config.llm_model, "Custom") return GenericAPIAdapter(llm_config.llm_endpoint, llm_config.llm_api_key, llm_config.llm_model, "Custom")
else: else:
raise ValueError(f"Unsupported LLM provider: {provider}") raise ValueError(f"Unsupported LLM provider: {provider}")

View file

@ -1,174 +1,121 @@
import asyncio
import base64
import os import os
import base64
from pathlib import Path from pathlib import Path
from typing import List, Type from typing import Type
import openai import litellm
import instructor import instructor
from pydantic import BaseModel from pydantic import BaseModel
from tenacity import retry, stop_after_attempt
from cognee.base_config import get_base_config
from cognee.infrastructure.llm.llm_interface import LLMInterface from cognee.infrastructure.llm.llm_interface import LLMInterface
from cognee.infrastructure.llm.prompts import read_query_prompt from cognee.infrastructure.llm.prompts import read_query_prompt
# from cognee.shared.data_models import MonitoringTool
class OpenAIAdapter(LLMInterface): class OpenAIAdapter(LLMInterface):
name = "OpenAI" name = "OpenAI"
model: str model: str
api_key: str api_key: str
api_version: str
"""Adapter for OpenAI's GPT-3, GPT-4 API""" """Adapter for OpenAI's GPT-3, GPT-4 API"""
def __init__(self, api_key: str, model: str, transcription_model:str, streaming: bool = False): def __init__(
base_config = get_base_config() self,
api_key: str,
# if base_config.monitoring_tool == MonitoringTool.LANGFUSE: endpoint: str,
# from langfuse.openai import AsyncOpenAI, OpenAI api_version: str,
# elif base_config.monitoring_tool == MonitoringTool.LANGSMITH: model: str,
# from langsmith import wrappers transcription_model: str,
# from openai import AsyncOpenAI streaming: bool = False,
# AsyncOpenAI = wrappers.wrap_openai(AsyncOpenAI()) ):
# else: self.aclient = instructor.from_litellm(litellm.acompletion)
from openai import AsyncOpenAI, OpenAI self.client = instructor.from_litellm(litellm.completion)
self.transcription_model = transcription_model
self.aclient = instructor.from_openai(AsyncOpenAI(api_key = api_key))
self.client = instructor.from_openai(OpenAI(api_key = api_key))
self.base_openai_client = OpenAI(api_key = api_key)
self.transcription_model = "whisper-1"
self.model = model self.model = model
self.api_key = api_key self.api_key = api_key
self.endpoint = endpoint
self.api_version = api_version
self.streaming = streaming self.streaming = streaming
@retry(stop = stop_after_attempt(5))
def completions_with_backoff(self, **kwargs):
"""Wrapper around ChatCompletion.create w/ backoff"""
return openai.chat.completions.create(**kwargs)
@retry(stop = stop_after_attempt(5))
async def acompletions_with_backoff(self,**kwargs):
"""Wrapper around ChatCompletion.acreate w/ backoff"""
return await openai.chat.completions.acreate(**kwargs)
@retry(stop = stop_after_attempt(5))
async def acreate_embedding_with_backoff(self, input: List[str], model: str = "text-embedding-3-large"):
"""Wrapper around Embedding.acreate w/ backoff"""
return await self.aclient.embeddings.create(input = input, model = model)
async def async_get_embedding_with_backoff(self, text, model="text-embedding-3-large"):
"""To get text embeddings, import/call this function
It specifies defaults + handles rate-limiting + is async"""
text = text.replace("\n", " ")
response = await self.aclient.embeddings.create(input = text, model = model)
embedding = response.data[0].embedding
return embedding
@retry(stop = stop_after_attempt(5))
def create_embedding_with_backoff(self, **kwargs):
"""Wrapper around Embedding.create w/ backoff"""
return openai.embeddings.create(**kwargs)
def get_embedding_with_backoff(self, text: str, model: str = "text-embedding-3-large"):
"""To get text embeddings, import/call this function
It specifies defaults + handles rate-limiting
:param text: str
:param model: str
"""
text = text.replace("\n", " ")
response = self.create_embedding_with_backoff(input=[text], model=model)
embedding = response.data[0].embedding
return embedding
async def async_get_batch_embeddings_with_backoff(self, texts: List[str], models: List[str]):
"""To get multiple text embeddings in parallel, import/call this function
It specifies defaults + handles rate-limiting + is async"""
# Collect all coroutines
coroutines = (self.async_get_embedding_with_backoff(text, model)
for text, model in zip(texts, models))
# Run the coroutines in parallel and gather the results
embeddings = await asyncio.gather(*coroutines)
return embeddings
@retry(stop = stop_after_attempt(5))
async def acreate_structured_output(self, text_input: str, system_prompt: str, response_model: Type[BaseModel]) -> BaseModel: async def acreate_structured_output(self, text_input: str, system_prompt: str, response_model: Type[BaseModel]) -> BaseModel:
"""Generate a response from a user query.""" """Generate a response from a user query."""
return await self.aclient.chat.completions.create( return await self.aclient.chat.completions.create(
model = self.model, model = self.model,
messages = [ messages = [{
{ "role": "user",
"role": "user", "content": f"""Use the given format to
"content": f"""Use the given format to extract information from the following input: {text_input}. """,
extract information from the following input: {text_input}. """, }, {
}, "role": "system",
{"role": "system", "content": system_prompt}, "content": system_prompt,
], }],
api_key = self.api_key,
api_base = self.endpoint,
api_version = self.api_version,
response_model = response_model, response_model = response_model,
max_retries = 5,
) )
@retry(stop = stop_after_attempt(5))
def create_structured_output(self, text_input: str, system_prompt: str, response_model: Type[BaseModel]) -> BaseModel: def create_structured_output(self, text_input: str, system_prompt: str, response_model: Type[BaseModel]) -> BaseModel:
"""Generate a response from a user query.""" """Generate a response from a user query."""
return self.client.chat.completions.create( return self.client.chat.completions.create(
model = self.model, model = self.model,
messages = [ messages = [{
{ "role": "user",
"role": "user", "content": f"""Use the given format to
"content": f"""Use the given format to extract information from the following input: {text_input}. """,
extract information from the following input: {text_input}. """, }, {
}, "role": "system",
{"role": "system", "content": system_prompt}, "content": system_prompt,
], }],
api_key = self.api_key,
api_base = self.endpoint,
api_version = self.api_version,
response_model = response_model, response_model = response_model,
max_retries = 5,
) )
@retry(stop = stop_after_attempt(5))
def create_transcript(self, input): def create_transcript(self, input):
"""Generate a audio transcript from a user query.""" """Generate a audio transcript from a user query."""
if not os.path.isfile(input): if not os.path.isfile(input):
raise FileNotFoundError(f"The file {input} does not exist.") raise FileNotFoundError(f"The file {input} does not exist.")
with open(input, 'rb') as audio_file: # with open(input, 'rb') as audio_file:
audio_data = audio_file.read() # audio_data = audio_file.read()
transcription = litellm.transcription(
model = self.transcription_model,
transcription = self.base_openai_client.audio.transcriptions.create( file = Path(input),
model=self.transcription_model , max_retries = 5,
file=Path(input), )
)
return transcription return transcription
@retry(stop = stop_after_attempt(5))
def transcribe_image(self, input) -> BaseModel: def transcribe_image(self, input) -> BaseModel:
with open(input, "rb") as image_file: with open(input, "rb") as image_file:
encoded_image = base64.b64encode(image_file.read()).decode('utf-8') encoded_image = base64.b64encode(image_file.read()).decode('utf-8')
return self.base_openai_client.chat.completions.create( return litellm.completion(
model=self.model, model = self.model,
messages=[ messages = [{
{ "role": "user",
"role": "user", "content": [
"content": [ {
{"type": "text", "text": "Whats in this image?"}, "type": "text",
{ "text": "Whats in this image?",
"type": "image_url", }, {
"image_url": { "type": "image_url",
"url": f"data:image/jpeg;base64,{encoded_image}", "image_url": {
}, "url": f"data:image/jpeg;base64,{encoded_image}",
}, },
], },
} ],
], }],
max_tokens=300, max_tokens = 300,
max_retries = 5,
) )
def show_prompt(self, text_input: str, system_prompt: str) -> str: def show_prompt(self, text_input: str, system_prompt: str) -> str:
"""Format and display the prompt for a user query.""" """Format and display the prompt for a user query."""
if not text_input: if not text_input:

View file

@ -1,5 +1,7 @@
import json
import inspect import inspect
import logging import logging
from cognee.modules.settings import get_current_settings
from cognee.shared.utils import send_telemetry from cognee.shared.utils import send_telemetry
from cognee.modules.users.models import User from cognee.modules.users.models import User
from cognee.modules.users.methods import get_default_user from cognee.modules.users.methods import get_default_user
@ -157,7 +159,7 @@ async def run_tasks_base(tasks: list[Task], data = None, user: User = None):
}) })
raise error raise error
async def run_tasks(tasks: [Task], data = None, pipeline_name: str = "default_pipeline"): async def run_tasks_with_telemetry(tasks: list[Task], data, pipeline_name: str):
user = await get_default_user() user = await get_default_user()
try: try:
@ -185,3 +187,10 @@ async def run_tasks(tasks: [Task], data = None, pipeline_name: str = "default_pi
}) })
raise error raise error
async def run_tasks(tasks: list[Task], data = None, pipeline_name: str = "default_pipeline"):
config = get_current_settings()
logger.debug("\nRunning pipeline with configuration:\n%s\n", json.dumps(config, indent = 1))
async for result in run_tasks_with_telemetry(tasks, data, pipeline_name):
yield result

View file

@ -1,3 +1,4 @@
from .get_current_settings import get_current_settings
from .get_settings import get_settings, SettingsDict from .get_settings import get_settings, SettingsDict
from .save_llm_config import save_llm_config from .save_llm_config import save_llm_config
from .save_vector_db_config import save_vector_db_config from .save_vector_db_config import save_vector_db_config

View file

@ -0,0 +1,54 @@
from typing import TypedDict
from cognee.infrastructure.llm import get_llm_config
from cognee.infrastructure.databases.graph import get_graph_config
from cognee.infrastructure.databases.vector import get_vectordb_config
from cognee.infrastructure.databases.relational.config import get_relational_config
class LLMConfig(TypedDict):
model: str
provider: str
class VectorDBConfig(TypedDict):
url: str
provider: str
class GraphDBConfig(TypedDict):
url: str
provider: str
class RelationalConfig(TypedDict):
url: str
provider: str
class SettingsDict(TypedDict):
llm: LLMConfig
graph: GraphDBConfig
vector: VectorDBConfig
relational: RelationalConfig
def get_current_settings() -> SettingsDict:
llm_config = get_llm_config()
graph_config = get_graph_config()
vector_config = get_vectordb_config()
relational_config = get_relational_config()
return dict(
llm = {
"provider": llm_config.llm_provider,
"model": llm_config.llm_model,
},
graph = {
"provider": graph_config.graph_database_provider,
"url": graph_config.graph_database_url or graph_config.graph_file_path,
},
vector = {
"provider": vector_config.vector_db_provider,
"url": vector_config.vector_db_url,
},
relational = {
"provider": relational_config.db_provider,
"url": f"{relational_config.db_host}:{relational_config.db_port}" \
if relational_config.db_host \
else f"{relational_config.db_path}/{relational_config.db_name}",
},
)