diff --git a/lightrag/llm/ollama.py b/lightrag/llm/ollama.py index b013496e..670351bc 100644 --- a/lightrag/llm/ollama.py +++ b/lightrag/llm/ollama.py @@ -1,4 +1,6 @@ from collections.abc import AsyncIterator +import os +import re import pipmaster as pm @@ -22,10 +24,30 @@ from lightrag.exceptions import ( from lightrag.api import __api_version__ import numpy as np -from typing import Union +from typing import Optional, Union from lightrag.utils import logger +_OLLAMA_CLOUD_HOST = "https://ollama.com" +_CLOUD_MODEL_SUFFIX_PATTERN = re.compile(r"(?:-cloud|:cloud)$") + + +def _coerce_host_for_cloud_model(host: Optional[str], model: object) -> Optional[str]: + if host: + return host + try: + model_name_str = str(model) if model is not None else "" + except (TypeError, ValueError, AttributeError) as e: + logger.warning(f"Failed to convert model to string: {e}, using empty string") + model_name_str = "" + if _CLOUD_MODEL_SUFFIX_PATTERN.search(model_name_str): + logger.debug( + f"Detected cloud model '{model_name_str}', using Ollama Cloud host" + ) + return _OLLAMA_CLOUD_HOST + return host + + @retry( stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1, min=4, max=10), @@ -53,6 +75,9 @@ async def _ollama_model_if_cache( timeout = None kwargs.pop("hashing_kv", None) api_key = kwargs.pop("api_key", None) + # fallback to environment variable when not provided explicitly + if not api_key: + api_key = os.getenv("OLLAMA_API_KEY") headers = { "Content-Type": "application/json", "User-Agent": f"LightRAG/{__api_version__}", @@ -60,6 +85,8 @@ async def _ollama_model_if_cache( if api_key: headers["Authorization"] = f"Bearer {api_key}" + host = _coerce_host_for_cloud_model(host, model) + ollama_client = ollama.AsyncClient(host=host, timeout=timeout, headers=headers) try: @@ -144,6 +171,8 @@ async def ollama_model_complete( async def ollama_embed(texts: list[str], embed_model, **kwargs) -> np.ndarray: api_key = kwargs.pop("api_key", None) + if not api_key: + api_key = os.getenv("OLLAMA_API_KEY") headers = { "Content-Type": "application/json", "User-Agent": f"LightRAG/{__api_version__}", @@ -154,6 +183,8 @@ async def ollama_embed(texts: list[str], embed_model, **kwargs) -> np.ndarray: host = kwargs.pop("host", None) timeout = kwargs.pop("timeout", None) + host = _coerce_host_for_cloud_model(host, embed_model) + ollama_client = ollama.AsyncClient(host=host, timeout=timeout, headers=headers) try: options = kwargs.pop("options", {})