diff --git a/lightrag/llm/ollama.py b/lightrag/llm/ollama.py
index 28d91075..9c2d17ee 100644
--- a/lightrag/llm/ollama.py
+++ b/lightrag/llm/ollama.py
@@ -1,5 +1,6 @@
 from collections.abc import AsyncIterator
 import os
+import re
 
 import pipmaster as pm
 
@@ -23,10 +24,26 @@ from lightrag.exceptions import (
 from lightrag.api import __api_version__
 
 import numpy as np
-from typing import Union
+from typing import Optional, Union
 from lightrag.utils import logger
 
 
+_OLLAMA_CLOUD_HOST = "https://ollama.com"
+_CLOUD_MODEL_SUFFIX_PATTERN = re.compile(r"(?:-cloud|:cloud)$")
+
+
+def _coerce_host_for_cloud_model(host: Optional[str], model: object) -> Optional[str]:
+    if host:
+        return host
+    try:
+        model_name_str = str(model) if model is not None else ""
+    except Exception:
+        model_name_str = ""
+    if _CLOUD_MODEL_SUFFIX_PATTERN.search(model_name_str):
+        return _OLLAMA_CLOUD_HOST
+    return host
+
+
 @retry(
     stop=stop_after_attempt(3),
     wait=wait_exponential(multiplier=1, min=4, max=10),
@@ -64,15 +81,7 @@ async def _ollama_model_if_cache(
     if api_key:
         headers["Authorization"] = f"Bearer {api_key}"
 
-    # If this is a cloud model (names include '-cloud' or ':cloud'), default
-    # the host to Ollama cloud when no explicit host was provided.
-    try:
-        model_name_str = str(model) if model is not None else ""
-    except Exception:
-        model_name_str = ""
-
-    if host is None and ("-cloud" in model_name_str or ":cloud" in model_name_str):
-        host = "https://ollama.com"
+    host = _coerce_host_for_cloud_model(host, model)
 
     ollama_client = ollama.AsyncClient(host=host, timeout=timeout, headers=headers)
 
@@ -170,14 +179,7 @@ async def ollama_embed(texts: list[str], embed_model, **kwargs) -> np.ndarray:
     host = kwargs.pop("host", None)
     timeout = kwargs.pop("timeout", None)
-    # If embed_model targets Ollama cloud, default host when not provided
-    try:
-        embed_model_name = str(embed_model) if embed_model is not None else ""
-    except Exception:
-        embed_model_name = ""
-
-    if host is None and ("-cloud" in embed_model_name or ":cloud" in embed_model_name):
-        host = "https://ollama.com"
+    host = _coerce_host_for_cloud_model(host, embed_model)
 
     ollama_client = ollama.AsyncClient(host=host, timeout=timeout, headers=headers)
 
     try:
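
For reference, a minimal standalone sketch of how the new helper resolves the host. The snippet re-declares the helper outside the package purely for illustration, and the model names below are made-up examples, not part of the patch. It also shows one behavioral nuance of the refactor: the regex is anchored with `$`, so only a trailing `-cloud` or `:cloud` suffix routes to Ollama cloud, whereas the old substring check would also match names that merely contain `-cloud` somewhere in the middle. An explicitly supplied host is always left untouched.

import re
from typing import Optional

_OLLAMA_CLOUD_HOST = "https://ollama.com"
_CLOUD_MODEL_SUFFIX_PATTERN = re.compile(r"(?:-cloud|:cloud)$")


def _coerce_host_for_cloud_model(host: Optional[str], model: object) -> Optional[str]:
    # An explicitly provided host is never overridden.
    if host:
        return host
    try:
        model_name_str = str(model) if model is not None else ""
    except Exception:
        model_name_str = ""
    # Only a trailing '-cloud' or ':cloud' suffix routes to Ollama cloud.
    if _CLOUD_MODEL_SUFFIX_PATTERN.search(model_name_str):
        return _OLLAMA_CLOUD_HOST
    return host


# Hypothetical model names, for illustration only.
assert _coerce_host_for_cloud_model(None, "gpt-oss:120b-cloud") == _OLLAMA_CLOUD_HOST
assert _coerce_host_for_cloud_model(None, "llama3.1:cloud") == _OLLAMA_CLOUD_HOST
# The old substring check ('-cloud' in name) would have matched this; the anchored regex does not.
assert _coerce_host_for_cloud_model(None, "my-cloudy-model") is None
# A caller-supplied host wins even for cloud-suffixed models.
assert _coerce_host_for_cloud_model("http://localhost:11434", "llama3.1:cloud") == "http://localhost:11434"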