diff --git a/cognee/config.py b/cognee/config.py index a9a210f5a..2523715fd 100644 --- a/cognee/config.py +++ b/cognee/config.py @@ -43,10 +43,10 @@ class Config: graph_filename = os.getenv("GRAPH_NAME", "cognee_graph.pkl") # Model parameters - llm_provider: str = "anthropic" #openai, or custom or ollama - custom_model: str = "claude-3-haiku-20240307" - custom_endpoint: str = "" # pass claude endpoint - custom_key: Optional[str] = "custom" + llm_provider: str = "custom" #openai, or custom or ollama + custom_model: str = "mistralai/Mixtral-8x7B-Instruct-v0.1" + custom_endpoint: str = "https://api.endpoints.anyscale.com/v1" # pass claude endpoint + custom_key: Optional[str] = os.getenv("ANYSCALE_API_KEY") ollama_endpoint: str = "http://localhost:11434/v1" ollama_key: Optional[str] = "ollama" ollama_model: str = "mistral:instruct" diff --git a/cognee/infrastructure/llm/ollama/adapter.py b/cognee/infrastructure/llm/generic_llm_api/adapter.py similarity index 96% rename from cognee/infrastructure/llm/ollama/adapter.py rename to cognee/infrastructure/llm/generic_llm_api/adapter.py index 2372bd1c8..0dac5e76d 100644 --- a/cognee/infrastructure/llm/ollama/adapter.py +++ b/cognee/infrastructure/llm/generic_llm_api/adapter.py @@ -12,7 +12,7 @@ from cognee.infrastructure.llm.llm_interface import LLMInterface from cognee.infrastructure.llm.prompts import read_query_prompt -class OllamaAPIAdapter(LLMInterface): +class GenericAPIAdapter(LLMInterface): """Adapter for Ollama's API""" def __init__(self, ollama_endpoint, api_key: str, model: str): @@ -89,9 +89,8 @@ class OllamaAPIAdapter(LLMInterface): { "role": "user", "content": f"""Use the given format to - extract information from the following input: {text_input}. """, - }, - {"role": "system", "content": system_prompt}, + extract information from the following input: {text_input}. 
{system_prompt} """, + } ], response_model=response_model, ) diff --git a/cognee/infrastructure/llm/get_llm_client.py b/cognee/infrastructure/llm/get_llm_client.py index 7720f3a6e..37c685571 100644 --- a/cognee/infrastructure/llm/get_llm_client.py +++ b/cognee/infrastructure/llm/get_llm_client.py @@ -3,7 +3,7 @@ from enum import Enum from cognee.config import Config from .anthropic.adapter import AnthropicAdapter from .openai.adapter import OpenAIAdapter -from .ollama.adapter import OllamaAPIAdapter +from .generic_llm_api.adapter import GenericAPIAdapter import logging logging.basicConfig(level=logging.INFO) @@ -26,13 +26,13 @@ def get_llm_client(): return OpenAIAdapter(config.openai_key, config.model) elif provider == LLMProvider.OLLAMA: print("Using Ollama API") - return GenericAPIAdapter(config.ollama_endpoint, config.ollama_key, config.ollama_model) + return GenericAPIAdapter(config.ollama_endpoint, config.ollama_key, config.ollama_model) elif provider == LLMProvider.ANTHROPIC: print("Using Anthropic API") - return AnthropicAdapter(config.ollama_endpoint, config.ollama_key, config.custom_model) + return AnthropicAdapter(config.custom_endpoint, config.custom_key, config.custom_model) elif provider == LLMProvider.CUSTOM: print("Using Custom API") - return GenericAPIAdapter(config.custom_endpoint, config.custom_key, config.model) + return GenericAPIAdapter(config.custom_endpoint, config.custom_key, config.custom_model) # Add your custom LLM provider here else: raise ValueError(f"Unsupported LLM provider: {provider}")