Added anyscale support
This commit is contained in:
parent
d1435f6cd3
commit
277acf081b
3 changed files with 11 additions and 12 deletions
|
|
@ -43,10 +43,10 @@ class Config:
|
||||||
graph_filename = os.getenv("GRAPH_NAME", "cognee_graph.pkl")
|
graph_filename = os.getenv("GRAPH_NAME", "cognee_graph.pkl")
|
||||||
|
|
||||||
# Model parameters
|
# Model parameters
|
||||||
llm_provider: str = "anthropic" #openai, or custom or ollama
|
llm_provider: str = "custom" #openai, or custom or ollama
|
||||||
custom_model: str = "claude-3-haiku-20240307"
|
custom_model: str = "mistralai/Mixtral-8x7B-Instruct-v0.1"
|
||||||
custom_endpoint: str = "" # pass claude endpoint
|
custom_endpoint: str = "https://api.endpoints.anyscale.com/v1" # pass claude endpoint
|
||||||
custom_key: Optional[str] = "custom"
|
custom_key: Optional[str] = os.getenv("ANYSCALE_API_KEY")
|
||||||
ollama_endpoint: str = "http://localhost:11434/v1"
|
ollama_endpoint: str = "http://localhost:11434/v1"
|
||||||
ollama_key: Optional[str] = "ollama"
|
ollama_key: Optional[str] = "ollama"
|
||||||
ollama_model: str = "mistral:instruct"
|
ollama_model: str = "mistral:instruct"
|
||||||
|
|
|
||||||
|
|
@ -12,7 +12,7 @@ from cognee.infrastructure.llm.llm_interface import LLMInterface
|
||||||
from cognee.infrastructure.llm.prompts import read_query_prompt
|
from cognee.infrastructure.llm.prompts import read_query_prompt
|
||||||
|
|
||||||
|
|
||||||
class OllamaAPIAdapter(LLMInterface):
|
class GenericAPIAdapter(LLMInterface):
|
||||||
"""Adapter for Ollama's API"""
|
"""Adapter for Ollama's API"""
|
||||||
|
|
||||||
def __init__(self, ollama_endpoint, api_key: str, model: str):
|
def __init__(self, ollama_endpoint, api_key: str, model: str):
|
||||||
|
|
@ -89,9 +89,8 @@ class OllamaAPIAdapter(LLMInterface):
|
||||||
{
|
{
|
||||||
"role": "user",
|
"role": "user",
|
||||||
"content": f"""Use the given format to
|
"content": f"""Use the given format to
|
||||||
extract information from the following input: {text_input}. """,
|
extract information from the following input: {text_input}. {system_prompt} """,
|
||||||
},
|
}
|
||||||
{"role": "system", "content": system_prompt},
|
|
||||||
],
|
],
|
||||||
response_model=response_model,
|
response_model=response_model,
|
||||||
)
|
)
|
||||||
|
|
@ -3,7 +3,7 @@ from enum import Enum
|
||||||
from cognee.config import Config
|
from cognee.config import Config
|
||||||
from .anthropic.adapter import AnthropicAdapter
|
from .anthropic.adapter import AnthropicAdapter
|
||||||
from .openai.adapter import OpenAIAdapter
|
from .openai.adapter import OpenAIAdapter
|
||||||
from .ollama.adapter import OllamaAPIAdapter
|
from .generic_llm_api.adapter import GenericAPIAdapter
|
||||||
import logging
|
import logging
|
||||||
logging.basicConfig(level=logging.INFO)
|
logging.basicConfig(level=logging.INFO)
|
||||||
|
|
||||||
|
|
@ -26,13 +26,13 @@ def get_llm_client():
|
||||||
return OpenAIAdapter(config.openai_key, config.model)
|
return OpenAIAdapter(config.openai_key, config.model)
|
||||||
elif provider == LLMProvider.OLLAMA:
|
elif provider == LLMProvider.OLLAMA:
|
||||||
print("Using Ollama API")
|
print("Using Ollama API")
|
||||||
return OllamaAPIAdapter(config.ollama_endpoint, config.ollama_key, config.ollama_model)
|
return GenericAPIAdapter(config.ollama_endpoint, config.ollama_key, config.ollama_model)
|
||||||
elif provider == LLMProvider.ANTHROPIC:
|
elif provider == LLMProvider.ANTHROPIC:
|
||||||
print("Using Anthropic API")
|
print("Using Anthropic API")
|
||||||
return AnthropicAdapter(config.ollama_endpoint, config.ollama_key, config.custom_model)
|
return AnthropicAdapter(config.custom_endpoint, config.custom_endpoint, config.custom_model)
|
||||||
elif provider == LLMProvider.CUSTOM:
|
elif provider == LLMProvider.CUSTOM:
|
||||||
print("Using Custom API")
|
print("Using Custom API")
|
||||||
return OllamaAPIAdapter(config.custom_endpoint, config.custom_key, config.model)
|
return GenericAPIAdapter(config.custom_endpoint, config.custom_key, config.custom_model)
|
||||||
# Add your custom LLM provider here
|
# Add your custom LLM provider here
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unsupported LLM provider: {provider}")
|
raise ValueError(f"Unsupported LLM provider: {provider}")
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue