From d1f82173206d7d9b415ba371da335fafba11a7e7 Mon Sep 17 00:00:00 2001
From: Boris
Date: Fri, 22 Nov 2024 10:26:21 +0100
Subject: [PATCH] feat: COG-585 enable custom llm and embedding models

---
 .../databases/graph/get_graph_engine.py        |  21 +-
 .../databases/vector/create_vector_engine.py   |  41 ++--
 .../embeddings/LiteLLMEmbeddingEngine.py       |  29 ++-
 .../databases/vector/embeddings/config.py      |  19 +-
 .../vector/embeddings/get_embedding_engine.py  |  14 +-
 cognee/infrastructure/llm/config.py            |   2 +
 cognee/infrastructure/llm/get_llm_client.py    |  14 +-
 cognee/infrastructure/llm/openai/adapter.py    | 189 +++++++-----------
 .../modules/pipelines/operations/run_tasks.py  |  11 +-
 cognee/modules/settings/__init__.py            |   1 +
 .../modules/settings/get_current_settings.py   |  54 +++++
 11 files changed, 222 insertions(+), 173 deletions(-)
 create mode 100644 cognee/modules/settings/get_current_settings.py

diff --git a/cognee/infrastructure/databases/graph/get_graph_engine.py b/cognee/infrastructure/databases/graph/get_graph_engine.py
index 038e878c0..5770bcda4 100644
--- a/cognee/infrastructure/databases/graph/get_graph_engine.py
+++ b/cognee/infrastructure/databases/graph/get_graph_engine.py
@@ -9,18 +9,21 @@ async def get_graph_engine() -> GraphDBInterface :
     config = get_graph_config()

     if config.graph_database_provider == "neo4j":
-        try:
-            from .neo4j_driver.adapter import Neo4jAdapter
+        if not (config.graph_database_url and config.graph_database_username and config.graph_database_password):
+            raise EnvironmentError("Missing required Neo4j credentials.")
+
+        from .neo4j_driver.adapter import Neo4jAdapter

-            return Neo4jAdapter(
-                graph_database_url = config.graph_database_url,
-                graph_database_username = config.graph_database_username,
-                graph_database_password = config.graph_database_password
-            )
-        except:
-            pass
+        return Neo4jAdapter(
+            graph_database_url = config.graph_database_url,
+            graph_database_username = config.graph_database_username,
+            graph_database_password = config.graph_database_password
+        )

     elif config.graph_database_provider == "falkordb":
+        if not (config.graph_database_url and config.graph_database_username and config.graph_database_password):
+            raise EnvironmentError("Missing required FalkorDB credentials.")
+
         from cognee.infrastructure.databases.vector.embeddings import get_embedding_engine
         from cognee.infrastructure.databases.hybrid.falkordb.FalkorDBAdapter import FalkorDBAdapter

diff --git a/cognee/infrastructure/databases/vector/create_vector_engine.py b/cognee/infrastructure/databases/vector/create_vector_engine.py
index db5ef3129..4b4799ee7 100644
--- a/cognee/infrastructure/databases/vector/create_vector_engine.py
+++ b/cognee/infrastructure/databases/vector/create_vector_engine.py
@@ -10,26 +10,29 @@ def create_vector_engine(config: VectorConfig, embedding_engine):
     if config["vector_db_provider"] == "weaviate":
         from .weaviate_db import WeaviateAdapter

-        if config["vector_db_url"] is None and config["vector_db_key"] is None:
-            raise EnvironmentError("Weaviate is not configured!")
+        if not (config["vector_db_url"] and config["vector_db_key"]):
+            raise EnvironmentError("Missing required Weaviate credentials!")

         return WeaviateAdapter(
             config["vector_db_url"],
             config["vector_db_key"],
             embedding_engine = embedding_engine
         )
-    elif config["vector_db_provider"] == "qdrant":
-        if config["vector_db_url"] and config["vector_db_key"]:
-            from .qdrant.QDrantAdapter import QDrantAdapter
-            return QDrantAdapter(
-                url = config["vector_db_url"],
-                api_key = config["vector_db_key"],
-                embedding_engine = embedding_engine
-            )
+    elif config["vector_db_provider"] == "qdrant":
+        if not (config["vector_db_url"] and config["vector_db_key"]):
+            raise EnvironmentError("Missing required Qdrant credentials!")
+
+        from .qdrant.QDrantAdapter import QDrantAdapter
+
+        return QDrantAdapter(
+            url = config["vector_db_url"],
+            api_key = config["vector_db_key"],
+            embedding_engine = embedding_engine
+        )
+
     elif config["vector_db_provider"] == "pgvector":
         from cognee.infrastructure.databases.relational import get_relational_config
-        from .pgvector.PGVectorAdapter import PGVectorAdapter

         # Get configuration for postgres database
         relational_config = get_relational_config()
@@ -39,16 +42,25 @@ def create_vector_engine(config: VectorConfig, embedding_engine):
         db_port = relational_config.db_port
         db_name = relational_config.db_name

+        if not (db_host and db_port and db_name and db_username and db_password):
+            raise EnvironmentError("Missing required pgvector credentials!")
+
         connection_string: str = (
-          f"postgresql+asyncpg://{db_username}:{db_password}@{db_host}:{db_port}/{db_name}"
+            f"postgresql+asyncpg://{db_username}:{db_password}@{db_host}:{db_port}/{db_name}"
         )
-
+
+        from .pgvector.PGVectorAdapter import PGVectorAdapter
+
         return PGVectorAdapter(
             connection_string,
             config["vector_db_key"],
             embedding_engine,
         )
+
     elif config["vector_db_provider"] == "falkordb":
+        if not (config["vector_db_url"] and config["vector_db_key"]):
+            raise EnvironmentError("Missing required FalkorDB credentials!")
+
         from ..hybrid.falkordb.FalkorDBAdapter import FalkorDBAdapter

         return FalkorDBAdapter(
@@ -56,6 +68,7 @@ def create_vector_engine(config: VectorConfig, embedding_engine):
             database_port = config["vector_db_port"],
             embedding_engine = embedding_engine,
         )
+
     else:
         from .lancedb.LanceDBAdapter import LanceDBAdapter

@@ -64,5 +77,3 @@ def create_vector_engine(config: VectorConfig, embedding_engine):
             api_key = config["vector_db_key"],
             embedding_engine = embedding_engine,
         )
-
-    raise EnvironmentError(f"Vector provider not configured correctly: {config['vector_db_provider']}")
diff --git a/cognee/infrastructure/databases/vector/embeddings/LiteLLMEmbeddingEngine.py b/cognee/infrastructure/databases/vector/embeddings/LiteLLMEmbeddingEngine.py
index a41618f18..617698fd1 100644
--- a/cognee/infrastructure/databases/vector/embeddings/LiteLLMEmbeddingEngine.py
+++ b/cognee/infrastructure/databases/vector/embeddings/LiteLLMEmbeddingEngine.py
@@ -1,32 +1,39 @@
 import asyncio
 from typing import List, Optional
 import litellm
-from litellm import aembedding
 from cognee.infrastructure.databases.vector.embeddings.EmbeddingEngine import EmbeddingEngine

 litellm.set_verbose = False

 class LiteLLMEmbeddingEngine(EmbeddingEngine):
     api_key: str
-    embedding_model: str
-    embedding_dimensions: int
+    endpoint: str
+    api_version: str
+    model: str
+    dimensions: int

     def __init__(
         self,
-        embedding_model: Optional[str] = "text-embedding-3-large",
-        embedding_dimensions: Optional[int] = 3072,
+        model: Optional[str] = "text-embedding-3-large",
+        dimensions: Optional[int] = 3072,
         api_key: str = None,
+        endpoint: str = None,
+        api_version: str = None,
     ):
         self.api_key = api_key
-        self.embedding_model = embedding_model
-        self.embedding_dimensions = embedding_dimensions
+        self.endpoint = endpoint
+        self.api_version = api_version
+        self.model = model
+        self.dimensions = dimensions

     async def embed_text(self, text: List[str]) -> List[List[float]]:
         async def get_embedding(text_):
-            response = await aembedding(
-                self.embedding_model,
+            response = await litellm.aembedding(
+                self.model,
                 input = text_,
-                api_key = self.api_key
+                api_key = self.api_key,
+                api_base = self.endpoint,
+                api_version = self.api_version
             )
             return response.data[0]["embedding"]

@@ -36,4 +43,4 @@ class LiteLLMEmbeddingEngine(EmbeddingEngine):
         return result

     def get_vector_size(self) -> int:
-        return self.embedding_dimensions
+        return self.dimensions
diff --git a/cognee/infrastructure/databases/vector/embeddings/config.py b/cognee/infrastructure/databases/vector/embeddings/config.py
index 8c03d389b..ecfb37204 100644
--- a/cognee/infrastructure/databases/vector/embeddings/config.py
+++ b/cognee/infrastructure/databases/vector/embeddings/config.py
@@ -1,23 +1,16 @@
+from typing import Optional
 from functools import lru_cache
 from pydantic_settings import BaseSettings, SettingsConfigDict

 class EmbeddingConfig(BaseSettings):
-    openai_embedding_model: str = "text-embedding-3-large"
-    openai_embedding_dimensions: int = 3072
-    litellm_embedding_model: str = "BAAI/bge-large-en-v1.5"
-    litellm_embedding_dimensions: int = 1024
-    # embedding_engine:object = DefaultEmbeddingEngine(embedding_model=litellm_embedding_model, embedding_dimensions=litellm_embedding_dimensions)
+    embedding_model: Optional[str] = "text-embedding-3-large"
+    embedding_dimensions: Optional[int] = 3072
+    embedding_endpoint: Optional[str] = None
+    embedding_api_key: Optional[str] = None
+    embedding_api_version: Optional[str] = None

     model_config = SettingsConfigDict(env_file = ".env", extra = "allow")

-    def to_dict(self) -> dict:
-        return {
-            "openai_embedding_model": self.openai_embedding_model,
-            "openai_embedding_dimensions": self.openai_embedding_dimensions,
-            "litellm_embedding_model": self.litellm_embedding_model,
-            "litellm_embedding_dimensions": self.litellm_embedding_dimensions,
-        }
-
 @lru_cache
 def get_embedding_config():
     return EmbeddingConfig()
diff --git a/cognee/infrastructure/databases/vector/embeddings/get_embedding_engine.py b/cognee/infrastructure/databases/vector/embeddings/get_embedding_engine.py
index a82876ef8..d2582fbf0 100644
--- a/cognee/infrastructure/databases/vector/embeddings/get_embedding_engine.py
+++ b/cognee/infrastructure/databases/vector/embeddings/get_embedding_engine.py
@@ -1,7 +1,17 @@
-from cognee.infrastructure.llm import get_llm_config
+from cognee.infrastructure.databases.vector.embeddings.config import get_embedding_config
+from cognee.infrastructure.llm.config import get_llm_config
 from .EmbeddingEngine import EmbeddingEngine
 from .LiteLLMEmbeddingEngine import LiteLLMEmbeddingEngine

 def get_embedding_engine() -> EmbeddingEngine:
+    config = get_embedding_config()
     llm_config = get_llm_config()
-    return LiteLLMEmbeddingEngine(api_key = llm_config.llm_api_key)
+
+    return LiteLLMEmbeddingEngine(
+        # If the OpenAI API is used for embeddings, litellm only needs the api_key.
+        api_key = config.embedding_api_key or llm_config.llm_api_key,
+        endpoint = config.embedding_endpoint,
+        api_version = config.embedding_api_version,
+        model = config.embedding_model,
+        dimensions = config.embedding_dimensions,
+    )
diff --git a/cognee/infrastructure/llm/config.py b/cognee/infrastructure/llm/config.py
index 37541adf2..d148042be 100644
--- a/cognee/infrastructure/llm/config.py
+++ b/cognee/infrastructure/llm/config.py
@@ -7,6 +7,7 @@ class LLMConfig(BaseSettings):
     llm_model: str = "gpt-4o-mini"
     llm_endpoint: str = ""
     llm_api_key: Optional[str] = None
+    llm_api_version: Optional[str] = None
     llm_temperature: float = 0.0
     llm_streaming: bool = False
     transcription_model: str = "whisper-1"
@@ -19,6 +20,7 @@ class LLMConfig(BaseSettings):
             "model": self.llm_model,
             "endpoint": self.llm_endpoint,
             "api_key": self.llm_api_key,
+            "api_version": self.llm_api_version,
             "temperature": self.llm_temperature,
             "streaming": self.llm_streaming,
             "transcription_model": self.transcription_model
diff --git a/cognee/infrastructure/llm/get_llm_client.py b/cognee/infrastructure/llm/get_llm_client.py
index 16ff5b320..1449d33b3 100644
--- a/cognee/infrastructure/llm/get_llm_client.py
+++ b/cognee/infrastructure/llm/get_llm_client.py
@@ -20,21 +20,33 @@ def get_llm_client():
             raise ValueError("LLM API key is not set.")

         from .openai.adapter import OpenAIAdapter
-        return OpenAIAdapter(api_key=llm_config.llm_api_key, model=llm_config.llm_model, transcription_model=llm_config.transcription_model, streaming=llm_config.llm_streaming)
+
+        return OpenAIAdapter(
+            api_key = llm_config.llm_api_key,
+            endpoint = llm_config.llm_endpoint,
+            api_version = llm_config.llm_api_version,
+            model = llm_config.llm_model,
+            transcription_model = llm_config.transcription_model,
+            streaming = llm_config.llm_streaming,
+        )
+
     elif provider == LLMProvider.OLLAMA:
         if llm_config.llm_api_key is None:
             raise ValueError("LLM API key is not set.")
         from .generic_llm_api.adapter import GenericAPIAdapter

         return GenericAPIAdapter(llm_config.llm_endpoint, llm_config.llm_api_key, llm_config.llm_model, "Ollama")
+
     elif provider == LLMProvider.ANTHROPIC:
         from .anthropic.adapter import AnthropicAdapter

         return AnthropicAdapter(llm_config.llm_model)
+
     elif provider == LLMProvider.CUSTOM:
         if llm_config.llm_api_key is None:
             raise ValueError("LLM API key is not set.")
         from .generic_llm_api.adapter import GenericAPIAdapter

         return GenericAPIAdapter(llm_config.llm_endpoint, llm_config.llm_api_key, llm_config.llm_model, "Custom")
+
     else:
         raise ValueError(f"Unsupported LLM provider: {provider}")
diff --git a/cognee/infrastructure/llm/openai/adapter.py b/cognee/infrastructure/llm/openai/adapter.py
index 2ad275e22..28cdfff4e 100644
--- a/cognee/infrastructure/llm/openai/adapter.py
+++ b/cognee/infrastructure/llm/openai/adapter.py
@@ -1,174 +1,121 @@
-import asyncio
-import base64
 import os
+import base64
 from pathlib import Path
-from typing import List, Type
+from typing import Type

-import openai
+import litellm
 import instructor
 from pydantic import BaseModel
-from tenacity import retry, stop_after_attempt

-from cognee.base_config import get_base_config
 from cognee.infrastructure.llm.llm_interface import LLMInterface
 from cognee.infrastructure.llm.prompts import read_query_prompt
-# from cognee.shared.data_models import MonitoringTool

 class OpenAIAdapter(LLMInterface):
     name = "OpenAI"
     model: str
     api_key: str
+    api_version: str

     """Adapter for OpenAI's GPT-3, GPT=4 API"""
-    def __init__(self, api_key: str, model: str, transcription_model:str, streaming: bool = False):
-        base_config = get_base_config()
-
-        # if base_config.monitoring_tool == MonitoringTool.LANGFUSE:
-        #     from langfuse.openai import AsyncOpenAI, OpenAI
-        # elif base_config.monitoring_tool == MonitoringTool.LANGSMITH:
-        #     from langsmith import wrappers
-        #     from openai import AsyncOpenAI
-        #     AsyncOpenAI = wrappers.wrap_openai(AsyncOpenAI())
-        # else:
-        from openai import AsyncOpenAI, OpenAI
-
-        self.aclient = instructor.from_openai(AsyncOpenAI(api_key = api_key))
-        self.client = instructor.from_openai(OpenAI(api_key = api_key))
-        self.base_openai_client = OpenAI(api_key = api_key)
-        self.transcription_model = "whisper-1"
+    def __init__(
+        self,
+        api_key: str,
+        endpoint: str,
+        api_version: str,
+        model: str,
+        transcription_model: str,
+        streaming: bool = False,
+    ):
+        self.aclient = instructor.from_litellm(litellm.acompletion)
+        self.client = instructor.from_litellm(litellm.completion)
+        self.transcription_model = transcription_model
         self.model = model
         self.api_key = api_key
+        self.endpoint = endpoint
+        self.api_version = api_version
         self.streaming = streaming

-    @retry(stop = stop_after_attempt(5))
-    def completions_with_backoff(self, **kwargs):
-        """Wrapper around ChatCompletion.create w/ backoff"""
-        return openai.chat.completions.create(**kwargs)
-
-    @retry(stop = stop_after_attempt(5))
-    async def acompletions_with_backoff(self,**kwargs):
-        """Wrapper around ChatCompletion.acreate w/ backoff"""
-        return await openai.chat.completions.acreate(**kwargs)
-
-    @retry(stop = stop_after_attempt(5))
-    async def acreate_embedding_with_backoff(self, input: List[str], model: str = "text-embedding-3-large"):
-        """Wrapper around Embedding.acreate w/ backoff"""
-
-        return await self.aclient.embeddings.create(input = input, model = model)
-
-    async def async_get_embedding_with_backoff(self, text, model="text-embedding-3-large"):
-        """To get text embeddings, import/call this function
-        It specifies defaults + handles rate-limiting + is async"""
-        text = text.replace("\n", " ")
-        response = await self.aclient.embeddings.create(input = text, model = model)
-        embedding = response.data[0].embedding
-        return embedding
-
-    @retry(stop = stop_after_attempt(5))
-    def create_embedding_with_backoff(self, **kwargs):
-        """Wrapper around Embedding.create w/ backoff"""
-        return openai.embeddings.create(**kwargs)
-
-    def get_embedding_with_backoff(self, text: str, model: str = "text-embedding-3-large"):
-        """To get text embeddings, import/call this function
-        It specifies defaults + handles rate-limiting
-        :param text: str
-        :param model: str
-        """
-        text = text.replace("\n", " ")
-        response = self.create_embedding_with_backoff(input=[text], model=model)
-        embedding = response.data[0].embedding
-        return embedding
-
-    async def async_get_batch_embeddings_with_backoff(self, texts: List[str], models: List[str]):
-        """To get multiple text embeddings in parallel, import/call this function
-        It specifies defaults + handles rate-limiting + is async"""
-        # Collect all coroutines
-        coroutines = (self.async_get_embedding_with_backoff(text, model)
-                      for text, model in zip(texts, models))
-
-        # Run the coroutines in parallel and gather the results
-        embeddings = await asyncio.gather(*coroutines)
-
-        return embeddings
-
-    @retry(stop = stop_after_attempt(5))
     async def acreate_structured_output(self, text_input: str, system_prompt: str, response_model: Type[BaseModel]) -> BaseModel:
         """Generate a response from a user query."""

         return await self.aclient.chat.completions.create(
             model = self.model,
-            messages = [
-                {
"role": "user", - "content": f"""Use the given format to - extract information from the following input: {text_input}. """, - }, - {"role": "system", "content": system_prompt}, - ], + messages = [{ + "role": "user", + "content": f"""Use the given format to + extract information from the following input: {text_input}. """, + }, { + "role": "system", + "content": system_prompt, + }], + api_key = self.api_key, + api_base = self.endpoint, + api_version = self.api_version, response_model = response_model, + max_retries = 5, ) - - @retry(stop = stop_after_attempt(5)) def create_structured_output(self, text_input: str, system_prompt: str, response_model: Type[BaseModel]) -> BaseModel: """Generate a response from a user query.""" return self.client.chat.completions.create( model = self.model, - messages = [ - { - "role": "user", - "content": f"""Use the given format to - extract information from the following input: {text_input}. """, - }, - {"role": "system", "content": system_prompt}, - ], + messages = [{ + "role": "user", + "content": f"""Use the given format to + extract information from the following input: {text_input}. """, + }, { + "role": "system", + "content": system_prompt, + }], + api_key = self.api_key, + api_base = self.endpoint, + api_version = self.api_version, response_model = response_model, + max_retries = 5, ) - @retry(stop = stop_after_attempt(5)) def create_transcript(self, input): """Generate a audio transcript from a user query.""" if not os.path.isfile(input): raise FileNotFoundError(f"The file {input} does not exist.") - with open(input, 'rb') as audio_file: - audio_data = audio_file.read() + # with open(input, 'rb') as audio_file: + # audio_data = audio_file.read() - - - transcription = self.base_openai_client.audio.transcriptions.create( - model=self.transcription_model , - file=Path(input), - ) + transcription = litellm.transcription( + model = self.transcription_model, + file = Path(input), + max_retries = 5, + ) return transcription - - @retry(stop = stop_after_attempt(5)) def transcribe_image(self, input) -> BaseModel: with open(input, "rb") as image_file: encoded_image = base64.b64encode(image_file.read()).decode('utf-8') - return self.base_openai_client.chat.completions.create( - model=self.model, - messages=[ - { - "role": "user", - "content": [ - {"type": "text", "text": "What’s in this image?"}, - { - "type": "image_url", - "image_url": { - "url": f"data:image/jpeg;base64,{encoded_image}", - }, - }, - ], - } - ], - max_tokens=300, + return litellm.completion( + model = self.model, + messages = [{ + "role": "user", + "content": [ + { + "type": "text", + "text": "What’s in this image?", + }, { + "type": "image_url", + "image_url": { + "url": f"data:image/jpeg;base64,{encoded_image}", + }, + }, + ], + }], + max_tokens = 300, + max_retries = 5, ) + def show_prompt(self, text_input: str, system_prompt: str) -> str: """Format and display the prompt for a user query.""" if not text_input: diff --git a/cognee/modules/pipelines/operations/run_tasks.py b/cognee/modules/pipelines/operations/run_tasks.py index 7058bdb69..205670b90 100644 --- a/cognee/modules/pipelines/operations/run_tasks.py +++ b/cognee/modules/pipelines/operations/run_tasks.py @@ -1,5 +1,7 @@ +import json import inspect import logging +from cognee.modules.settings import get_current_settings from cognee.shared.utils import send_telemetry from cognee.modules.users.models import User from cognee.modules.users.methods import get_default_user @@ -157,7 +159,7 @@ async def run_tasks_base(tasks: list[Task], 
             })

             raise error

-async def run_tasks(tasks: [Task], data = None, pipeline_name: str = "default_pipeline"):
+async def run_tasks_with_telemetry(tasks: list[Task], data, pipeline_name: str):
     user = await get_default_user()

     try:
@@ -185,3 +187,10 @@ async def run_tasks(tasks: [Task], data = None, pipeline_name: str = "default_pi
         })

         raise error
+
+async def run_tasks(tasks: list[Task], data = None, pipeline_name: str = "default_pipeline"):
+    config = get_current_settings()
+    logger.debug("\nRunning pipeline with configuration:\n%s\n", json.dumps(config, indent = 1))
+
+    async for result in run_tasks_with_telemetry(tasks, data, pipeline_name):
+        yield result
diff --git a/cognee/modules/settings/__init__.py b/cognee/modules/settings/__init__.py
index e705f8767..d7e67e73b 100644
--- a/cognee/modules/settings/__init__.py
+++ b/cognee/modules/settings/__init__.py
@@ -1,3 +1,4 @@
+from .get_current_settings import get_current_settings
 from .get_settings import get_settings, SettingsDict
 from .save_llm_config import save_llm_config
 from .save_vector_db_config import save_vector_db_config
diff --git a/cognee/modules/settings/get_current_settings.py b/cognee/modules/settings/get_current_settings.py
new file mode 100644
index 000000000..3d6bad896
--- /dev/null
+++ b/cognee/modules/settings/get_current_settings.py
@@ -0,0 +1,54 @@
+from typing import TypedDict
+from cognee.infrastructure.llm import get_llm_config
+from cognee.infrastructure.databases.graph import get_graph_config
+from cognee.infrastructure.databases.vector import get_vectordb_config
+from cognee.infrastructure.databases.relational.config import get_relational_config
+
+class LLMConfig(TypedDict):
+    model: str
+    provider: str
+
+class VectorDBConfig(TypedDict):
+    url: str
+    provider: str
+
+class GraphDBConfig(TypedDict):
+    url: str
+    provider: str
+
+class RelationalConfig(TypedDict):
+    url: str
+    provider: str
+
+class SettingsDict(TypedDict):
+    llm: LLMConfig
+    graph: GraphDBConfig
+    vector: VectorDBConfig
+    relational: RelationalConfig
+
+def get_current_settings() -> SettingsDict:
+    llm_config = get_llm_config()
+    graph_config = get_graph_config()
+    vector_config = get_vectordb_config()
+    relational_config = get_relational_config()
+
+    return dict(
+        llm = {
+            "provider": llm_config.llm_provider,
+            "model": llm_config.llm_model,
+        },
+        graph = {
+            "provider": graph_config.graph_database_provider,
+            "url": graph_config.graph_database_url or graph_config.graph_file_path,
+        },
+        vector = {
+            "provider": vector_config.vector_db_provider,
+            "url": vector_config.vector_db_url,
+        },
+        relational = {
+            "provider": relational_config.db_provider,
+            "url": f"{relational_config.db_host}:{relational_config.db_port}" \
+                if relational_config.db_host \
+                else f"{relational_config.db_path}/{relational_config.db_name}",
+        },
+    )
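
---

Usage sketch (reviewer note, not part of the patch): the new LLMConfig and EmbeddingConfig fields are plain pydantic-settings fields, so they should be settable via the environment or .env. The snippet below is a minimal sketch that assumes pydantic-settings' default field-to-env-var mapping (e.g. embedding_endpoint -> EMBEDDING_ENDPOINT) and uses a hypothetical OpenAI-compatible gateway URL and dummy key; neither the variable names nor the URL are confirmed by this diff.

    # Point cognee's embedding engine at a custom OpenAI-compatible endpoint.
    import os
    import asyncio

    # Assumed env var names: pydantic-settings maps settings field names to env
    # vars case-insensitively, so these should reach EmbeddingConfig.
    os.environ["EMBEDDING_MODEL"] = "text-embedding-3-large"
    os.environ["EMBEDDING_DIMENSIONS"] = "3072"
    os.environ["EMBEDDING_ENDPOINT"] = "https://llm-gateway.example.com/v1"  # hypothetical URL
    os.environ["EMBEDDING_API_KEY"] = "dummy-key"
    # EMBEDDING_API_VERSION is only needed for Azure-style APIs and is left unset here.

    from cognee.infrastructure.databases.vector.embeddings import get_embedding_engine

    async def main():
        # get_embedding_engine() now reads EmbeddingConfig and falls back to the
        # LLM api_key when no embedding key is configured.
        engine = get_embedding_engine()
        vectors = await engine.embed_text(["Cognee turns documents into a knowledge graph."])
        print(len(vectors), len(vectors[0]))  # one vector with EMBEDDING_DIMENSIONS floats

    asyncio.run(main())

The LLM side works the same way through LLMConfig (llm_endpoint, llm_api_version, etc.), which get_llm_client() now forwards to the litellm-backed OpenAIAdapter.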