Merge branch 'feat/add_cloud_ollama_support'

Commit: 1889301597
Author: yangdx
Date: 2025-11-13 20:41:58 +08:00


@@ -1,4 +1,6 @@
 from collections.abc import AsyncIterator
+import os
+import re
 import pipmaster as pm
@@ -22,10 +24,30 @@ from lightrag.exceptions import (
 from lightrag.api import __api_version__
 import numpy as np
-from typing import Union
+from typing import Optional, Union
 from lightrag.utils import logger
+_OLLAMA_CLOUD_HOST = "https://ollama.com"
+_CLOUD_MODEL_SUFFIX_PATTERN = re.compile(r"(?:-cloud|:cloud)$")
+def _coerce_host_for_cloud_model(host: Optional[str], model: object) -> Optional[str]:
+    if host:
+        return host
+    try:
+        model_name_str = str(model) if model is not None else ""
+    except (TypeError, ValueError, AttributeError) as e:
+        logger.warning(f"Failed to convert model to string: {e}, using empty string")
+        model_name_str = ""
+    if _CLOUD_MODEL_SUFFIX_PATTERN.search(model_name_str):
+        logger.debug(
+            f"Detected cloud model '{model_name_str}', using Ollama Cloud host"
+        )
+        return _OLLAMA_CLOUD_HOST
+    return host
 @retry(
     stop=stop_after_attempt(3),
     wait=wait_exponential(multiplier=1, min=4, max=10),
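For orientation (not part of the commit): a minimal sketch of how the new _coerce_host_for_cloud_model helper is expected to resolve the host. The model names below are only illustrative examples of the -cloud/:cloud naming convention.

# Cloud-suffixed model names with no explicit host are routed to Ollama Cloud.
assert _coerce_host_for_cloud_model(None, "gpt-oss:120b-cloud") == "https://ollama.com"
assert _coerce_host_for_cloud_model(None, "qwen3-coder:cloud") == "https://ollama.com"
# An explicitly supplied host always wins, even for cloud-suffixed models.
assert _coerce_host_for_cloud_model("http://localhost:11434", "llama3:cloud") == "http://localhost:11434"
# Plain local models leave the host untouched (None falls through to the client default).
assert _coerce_host_for_cloud_model(None, "llama3") is None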
@@ -53,6 +75,9 @@ async def _ollama_model_if_cache(
         timeout = None
     kwargs.pop("hashing_kv", None)
     api_key = kwargs.pop("api_key", None)
+    # fallback to environment variable when not provided explicitly
+    if not api_key:
+        api_key = os.getenv("OLLAMA_API_KEY")
     headers = {
         "Content-Type": "application/json",
         "User-Agent": f"LightRAG/{__api_version__}",
@@ -60,6 +85,8 @@ async def _ollama_model_if_cache(
     if api_key:
         headers["Authorization"] = f"Bearer {api_key}"
+    host = _coerce_host_for_cloud_model(host, model)
     ollama_client = ollama.AsyncClient(host=host, timeout=timeout, headers=headers)
     try:
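With the coercion in place, a cloud-suffixed model name alone is enough to route the request to https://ollama.com. A hedged usage sketch, assuming _ollama_model_if_cache keeps its existing (model, prompt, ...) calling convention and OLLAMA_API_KEY is set; the model name is illustrative:

import asyncio

async def demo_completion() -> None:
    # No host kwarg: the "-cloud" suffix triggers the Ollama Cloud host,
    # and the Bearer token comes from OLLAMA_API_KEY.
    reply = await _ollama_model_if_cache(
        "gpt-oss:120b-cloud",  # illustrative cloud model name
        "Give a one-sentence summary of LightRAG.",
    )
    print(reply)

asyncio.run(demo_completion())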
@@ -144,6 +171,8 @@ async def ollama_model_complete(
 async def ollama_embed(texts: list[str], embed_model, **kwargs) -> np.ndarray:
     api_key = kwargs.pop("api_key", None)
+    if not api_key:
+        api_key = os.getenv("OLLAMA_API_KEY")
     headers = {
         "Content-Type": "application/json",
         "User-Agent": f"LightRAG/{__api_version__}",
@@ -154,6 +183,8 @@ async def ollama_embed(texts: list[str], embed_model, **kwargs) -> np.ndarray:
     host = kwargs.pop("host", None)
     timeout = kwargs.pop("timeout", None)
+    host = _coerce_host_for_cloud_model(host, embed_model)
     ollama_client = ollama.AsyncClient(host=host, timeout=timeout, headers=headers)
     try:
         options = kwargs.pop("options", {})
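The embedding path mirrors the completion path: when embed_model carries a cloud suffix and no host kwarg is given, the client targets Ollama Cloud with the same environment-based key. A hedged sketch; the embedding model name is hypothetical:

import asyncio
import numpy as np

async def demo_embed() -> None:
    # Signature taken from the diff above: ollama_embed(texts, embed_model, **kwargs).
    vectors = await ollama_embed(
        ["first text chunk", "second text chunk"],
        embed_model="nomic-embed-text:cloud",  # hypothetical cloud embedding model
    )
    assert isinstance(vectors, np.ndarray) and vectors.shape[0] == 2

asyncio.run(demo_embed())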