Added anyscale support

This commit is contained in:
Vasilije 2024-03-28 11:54:23 +01:00
parent d1435f6cd3
commit 277acf081b
3 changed files with 11 additions and 12 deletions

View file

@ -43,10 +43,10 @@ class Config:
graph_filename = os.getenv("GRAPH_NAME", "cognee_graph.pkl")
# Model parameters
llm_provider: str = "anthropic" #openai, or custom or ollama
custom_model: str = "claude-3-haiku-20240307"
custom_endpoint: str = "" # pass claude endpoint
custom_key: Optional[str] = "custom"
llm_provider: str = "custom" #openai, or custom or ollama
custom_model: str = "mistralai/Mixtral-8x7B-Instruct-v0.1"
custom_endpoint: str = "https://api.endpoints.anyscale.com/v1" # pass claude endpoint
custom_key: Optional[str] = os.getenv("ANYSCALE_API_KEY")
ollama_endpoint: str = "http://localhost:11434/v1"
ollama_key: Optional[str] = "ollama"
ollama_model: str = "mistral:instruct"

View file

@ -12,7 +12,7 @@ from cognee.infrastructure.llm.llm_interface import LLMInterface
from cognee.infrastructure.llm.prompts import read_query_prompt
class OllamaAPIAdapter(LLMInterface):
class GenericAPIAdapter(LLMInterface):
"""Adapter for Ollama's API"""
def __init__(self, ollama_endpoint, api_key: str, model: str):
@ -89,9 +89,8 @@ class OllamaAPIAdapter(LLMInterface):
{
"role": "user",
"content": f"""Use the given format to
extract information from the following input: {text_input}. """,
},
{"role": "system", "content": system_prompt},
extract information from the following input: {text_input}. {system_prompt} """,
}
],
response_model=response_model,
)

View file

@ -3,7 +3,7 @@ from enum import Enum
from cognee.config import Config
from .anthropic.adapter import AnthropicAdapter
from .openai.adapter import OpenAIAdapter
from .ollama.adapter import OllamaAPIAdapter
from .generic_llm_api.adapter import GenericAPIAdapter
import logging
logging.basicConfig(level=logging.INFO)
@ -26,13 +26,13 @@ def get_llm_client():
return OpenAIAdapter(config.openai_key, config.model)
elif provider == LLMProvider.OLLAMA:
print("Using Ollama API")
return OllamaAPIAdapter(config.ollama_endpoint, config.ollama_key, config.ollama_model)
return GenericAPIAdapter(config.ollama_endpoint, config.ollama_key, config.ollama_model)
elif provider == LLMProvider.ANTHROPIC:
print("Using Anthropic API")
return AnthropicAdapter(config.ollama_endpoint, config.ollama_key, config.custom_model)
return AnthropicAdapter(config.custom_endpoint, config.custom_endpoint, config.custom_model)
elif provider == LLMProvider.CUSTOM:
print("Using Custom API")
return OllamaAPIAdapter(config.custom_endpoint, config.custom_key, config.model)
return GenericAPIAdapter(config.custom_endpoint, config.custom_key, config.custom_model)
# Add your custom LLM provider here
else:
raise ValueError(f"Unsupported LLM provider: {provider}")