Add a better regex
This commit is contained in:
parent
f7432a260e
commit
844537e378
1 changed files with 20 additions and 18 deletions
|
|
@ -1,5 +1,6 @@
|
||||||
from collections.abc import AsyncIterator
|
from collections.abc import AsyncIterator
|
||||||
import os
|
import os
|
||||||
|
import re
|
||||||
|
|
||||||
import pipmaster as pm
|
import pipmaster as pm
|
||||||
|
|
||||||
|
|
@ -23,10 +24,26 @@ from lightrag.exceptions import (
|
||||||
from lightrag.api import __api_version__
|
from lightrag.api import __api_version__
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from typing import Union
|
from typing import Optional, Union
|
||||||
from lightrag.utils import logger
|
from lightrag.utils import logger
|
||||||
|
|
||||||
|
|
||||||
|
_OLLAMA_CLOUD_HOST = "https://ollama.com"
|
||||||
|
_CLOUD_MODEL_SUFFIX_PATTERN = re.compile(r"(?:-cloud|:cloud)$")
|
||||||
|
|
||||||
|
|
||||||
|
def _coerce_host_for_cloud_model(host: Optional[str], model: object) -> Optional[str]:
|
||||||
|
if host:
|
||||||
|
return host
|
||||||
|
try:
|
||||||
|
model_name_str = str(model) if model is not None else ""
|
||||||
|
except Exception:
|
||||||
|
model_name_str = ""
|
||||||
|
if _CLOUD_MODEL_SUFFIX_PATTERN.search(model_name_str):
|
||||||
|
return _OLLAMA_CLOUD_HOST
|
||||||
|
return host
|
||||||
|
|
||||||
|
|
||||||
@retry(
|
@retry(
|
||||||
stop=stop_after_attempt(3),
|
stop=stop_after_attempt(3),
|
||||||
wait=wait_exponential(multiplier=1, min=4, max=10),
|
wait=wait_exponential(multiplier=1, min=4, max=10),
|
||||||
|
|
@ -64,15 +81,7 @@ async def _ollama_model_if_cache(
|
||||||
if api_key:
|
if api_key:
|
||||||
headers["Authorization"] = f"Bearer {api_key}"
|
headers["Authorization"] = f"Bearer {api_key}"
|
||||||
|
|
||||||
# If this is a cloud model (names include '-cloud' or ':cloud'), default
|
host = _coerce_host_for_cloud_model(host, model)
|
||||||
# the host to Ollama cloud when no explicit host was provided.
|
|
||||||
try:
|
|
||||||
model_name_str = str(model) if model is not None else ""
|
|
||||||
except Exception:
|
|
||||||
model_name_str = ""
|
|
||||||
|
|
||||||
if host is None and ("-cloud" in model_name_str or ":cloud" in model_name_str):
|
|
||||||
host = "https://ollama.com"
|
|
||||||
|
|
||||||
ollama_client = ollama.AsyncClient(host=host, timeout=timeout, headers=headers)
|
ollama_client = ollama.AsyncClient(host=host, timeout=timeout, headers=headers)
|
||||||
|
|
||||||
|
|
@ -170,14 +179,7 @@ async def ollama_embed(texts: list[str], embed_model, **kwargs) -> np.ndarray:
|
||||||
host = kwargs.pop("host", None)
|
host = kwargs.pop("host", None)
|
||||||
timeout = kwargs.pop("timeout", None)
|
timeout = kwargs.pop("timeout", None)
|
||||||
|
|
||||||
# If embed_model targets Ollama cloud, default host when not provided
|
host = _coerce_host_for_cloud_model(host, embed_model)
|
||||||
try:
|
|
||||||
embed_model_name = str(embed_model) if embed_model is not None else ""
|
|
||||||
except Exception:
|
|
||||||
embed_model_name = ""
|
|
||||||
|
|
||||||
if host is None and ("-cloud" in embed_model_name or ":cloud" in embed_model_name):
|
|
||||||
host = "https://ollama.com"
|
|
||||||
|
|
||||||
ollama_client = ollama.AsyncClient(host=host, timeout=timeout, headers=headers)
|
ollama_client = ollama.AsyncClient(host=host, timeout=timeout, headers=headers)
|
||||||
try:
|
try:
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue