Added anyscale support
This commit is contained in:
parent
d1435f6cd3
commit
277acf081b
3 changed files with 11 additions and 12 deletions
|
|
@@ -43,10 +43,10 @@ class Config:
|
|||
graph_filename = os.getenv("GRAPH_NAME", "cognee_graph.pkl")
|
||||
|
||||
# Model parameters
|
||||
llm_provider: str = "anthropic" #openai, or custom or ollama
|
||||
custom_model: str = "claude-3-haiku-20240307"
|
||||
custom_endpoint: str = "" # pass claude endpoint
|
||||
custom_key: Optional[str] = "custom"
|
||||
llm_provider: str = "custom" #openai, or custom or ollama
|
||||
custom_model: str = "mistralai/Mixtral-8x7B-Instruct-v0.1"
|
||||
custom_endpoint: str = "https://api.endpoints.anyscale.com/v1" # pass claude endpoint
|
||||
custom_key: Optional[str] = os.getenv("ANYSCALE_API_KEY")
|
||||
ollama_endpoint: str = "http://localhost:11434/v1"
|
||||
ollama_key: Optional[str] = "ollama"
|
||||
ollama_model: str = "mistral:instruct"
|
||||
|
|
|
|||
|
|
@@ -12,7 +12,7 @@ from cognee.infrastructure.llm.llm_interface import LLMInterface
|
|||
from cognee.infrastructure.llm.prompts import read_query_prompt
|
||||
|
||||
|
||||
class OllamaAPIAdapter(LLMInterface):
|
||||
class GenericAPIAdapter(LLMInterface):
|
||||
"""Adapter for Ollama's API"""
|
||||
|
||||
def __init__(self, ollama_endpoint, api_key: str, model: str):
|
||||
|
|
@@ -89,9 +89,8 @@ class OllamaAPIAdapter(LLMInterface):
|
|||
{
|
||||
"role": "user",
|
||||
"content": f"""Use the given format to
|
||||
extract information from the following input: {text_input}. """,
|
||||
},
|
||||
{"role": "system", "content": system_prompt},
|
||||
extract information from the following input: {text_input}. {system_prompt} """,
|
||||
}
|
||||
],
|
||||
response_model=response_model,
|
||||
)
|
||||
|
|
@@ -3,7 +3,7 @@ from enum import Enum
|
|||
from cognee.config import Config
|
||||
from .anthropic.adapter import AnthropicAdapter
|
||||
from .openai.adapter import OpenAIAdapter
|
||||
from .ollama.adapter import OllamaAPIAdapter
|
||||
from .generic_llm_api.adapter import GenericAPIAdapter
|
||||
import logging
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
|
||||
|
|
@@ -26,13 +26,13 @@ def get_llm_client():
|
|||
return OpenAIAdapter(config.openai_key, config.model)
|
||||
elif provider == LLMProvider.OLLAMA:
|
||||
print("Using Ollama API")
|
||||
return OllamaAPIAdapter(config.ollama_endpoint, config.ollama_key, config.ollama_model)
|
||||
return GenericAPIAdapter(config.ollama_endpoint, config.ollama_key, config.ollama_model)
|
||||
elif provider == LLMProvider.ANTHROPIC:
|
||||
print("Using Anthropic API")
|
||||
return AnthropicAdapter(config.ollama_endpoint, config.ollama_key, config.custom_model)
|
||||
return AnthropicAdapter(config.custom_endpoint, config.custom_endpoint, config.custom_model)
|
||||
elif provider == LLMProvider.CUSTOM:
|
||||
print("Using Custom API")
|
||||
return OllamaAPIAdapter(config.custom_endpoint, config.custom_key, config.model)
|
||||
return GenericAPIAdapter(config.custom_endpoint, config.custom_key, config.custom_model)
|
||||
# Add your custom LLM provider here
|
||||
else:
|
||||
raise ValueError(f"Unsupported LLM provider: {provider}")
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue