Add an environment-variable fallback (OLLAMA_API_KEY) for the API key and default the host to Ollama cloud for cloud models
This commit is contained in:
parent
fa9206d69a
commit
5127bf20ae
1 changed file with 25 additions and 0 deletions
|
|
@ -1,4 +1,5 @@
|
||||||
from collections.abc import AsyncIterator
|
from collections.abc import AsyncIterator
|
||||||
|
import os
|
||||||
|
|
||||||
import pipmaster as pm
|
import pipmaster as pm
|
||||||
|
|
||||||
|
|
@ -53,6 +54,9 @@ async def _ollama_model_if_cache(
|
||||||
timeout = None
|
timeout = None
|
||||||
kwargs.pop("hashing_kv", None)
|
kwargs.pop("hashing_kv", None)
|
||||||
api_key = kwargs.pop("api_key", None)
|
api_key = kwargs.pop("api_key", None)
|
||||||
|
# Fall back to the OLLAMA_API_KEY environment variable when api_key is not provided explicitly
|
||||||
|
if not api_key:
|
||||||
|
api_key = os.getenv("OLLAMA_API_KEY")
|
||||||
headers = {
|
headers = {
|
||||||
"Content-Type": "application/json",
|
"Content-Type": "application/json",
|
||||||
"User-Agent": f"LightRAG/{__api_version__}",
|
"User-Agent": f"LightRAG/{__api_version__}",
|
||||||
|
|
@ -60,6 +64,16 @@ async def _ollama_model_if_cache(
|
||||||
if api_key:
|
if api_key:
|
||||||
headers["Authorization"] = f"Bearer {api_key}"
|
headers["Authorization"] = f"Bearer {api_key}"
|
||||||
|
|
||||||
|
# If this is a cloud model (names include '-cloud' or ':cloud'), default
|
||||||
|
# the host to Ollama cloud when no explicit host was provided.
|
||||||
|
try:
|
||||||
|
model_name_str = str(model) if model is not None else ""
|
||||||
|
except Exception:
|
||||||
|
model_name_str = ""
|
||||||
|
|
||||||
|
if host is None and ("-cloud" in model_name_str or ":cloud" in model_name_str):
|
||||||
|
host = "https://ollama.com"
|
||||||
|
|
||||||
ollama_client = ollama.AsyncClient(host=host, timeout=timeout, headers=headers)
|
ollama_client = ollama.AsyncClient(host=host, timeout=timeout, headers=headers)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|
@ -144,6 +158,8 @@ async def ollama_model_complete(
|
||||||
|
|
||||||
async def ollama_embed(texts: list[str], embed_model, **kwargs) -> np.ndarray:
|
async def ollama_embed(texts: list[str], embed_model, **kwargs) -> np.ndarray:
|
||||||
api_key = kwargs.pop("api_key", None)
|
api_key = kwargs.pop("api_key", None)
|
||||||
|
if not api_key:
|
||||||
|
api_key = os.getenv("OLLAMA_API_KEY")
|
||||||
headers = {
|
headers = {
|
||||||
"Content-Type": "application/json",
|
"Content-Type": "application/json",
|
||||||
"User-Agent": f"LightRAG/{__api_version__}",
|
"User-Agent": f"LightRAG/{__api_version__}",
|
||||||
|
|
@ -154,6 +170,15 @@ async def ollama_embed(texts: list[str], embed_model, **kwargs) -> np.ndarray:
|
||||||
host = kwargs.pop("host", None)
|
host = kwargs.pop("host", None)
|
||||||
timeout = kwargs.pop("timeout", None)
|
timeout = kwargs.pop("timeout", None)
|
||||||
|
|
||||||
|
# If embed_model targets Ollama cloud, default host when not provided
|
||||||
|
try:
|
||||||
|
embed_model_name = str(embed_model) if embed_model is not None else ""
|
||||||
|
except Exception:
|
||||||
|
embed_model_name = ""
|
||||||
|
|
||||||
|
if host is None and ("-cloud" in embed_model_name or ":cloud" in embed_model_name):
|
||||||
|
host = "https://ollama.com"
|
||||||
|
|
||||||
ollama_client = ollama.AsyncClient(host=host, timeout=timeout, headers=headers)
|
ollama_client = ollama.AsyncClient(host=host, timeout=timeout, headers=headers)
|
||||||
try:
|
try:
|
||||||
options = kwargs.pop("options", {})
|
options = kwargs.pop("options", {})
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue