From c0cfcdd1f1a4c03db8bc91d5670063320de55ce6 Mon Sep 17 00:00:00 2001
From: hajdul88 <52442977+hajdul88@users.noreply.github.com>
Date: Fri, 28 Mar 2025 18:04:52 +0100
Subject: [PATCH] modal ollama POC

---
 ollama_modal_deployment/ollama.service |   2 +
 ollama_modal_deployment/ollama_api.py  | 109 +++++++++++++++----------
 2 files changed, 68 insertions(+), 43 deletions(-)

diff --git a/ollama_modal_deployment/ollama.service b/ollama_modal_deployment/ollama.service
index 9deee1da0..6359adb17 100644
--- a/ollama_modal_deployment/ollama.service
+++ b/ollama_modal_deployment/ollama.service
@@ -9,6 +9,8 @@ Group=ollama
 Restart=always
 RestartSec=3
 Environment="PATH=$PATH"
+Environment="OLLAMA_ORIGINS=*"
+Environment="OLLAMA_HOST=0.0.0.0:11434"
 
 [Install]
 WantedBy=default.target
diff --git a/ollama_modal_deployment/ollama_api.py b/ollama_modal_deployment/ollama_api.py
index 3bae77dd8..73f22eaf8 100644
--- a/ollama_modal_deployment/ollama_api.py
+++ b/ollama_modal_deployment/ollama_api.py
@@ -6,8 +6,13 @@ from fastapi import FastAPI, HTTPException
 from typing import List, Any, Optional, Dict
 from pydantic import BaseModel, Field
 import ollama
+from fastapi.middleware.cors import CORSMiddleware
 
-MODEL = os.environ.get("MODEL", "llama3.3:70b")
+import httpx
+from fastapi import Request, Response
+
+MODEL = os.environ.get("MODEL", "deepseek-r1:70b")
+EMBEDDING_MODEL = os.environ.get("EMBEDDING_MODEL", "avr/sfr-embedding-mistral")
 
 
 def pull() -> None:
@@ -16,6 +21,7 @@ def pull() -> None:
     subprocess.run(["systemctl", "start", "ollama"])
     wait_for_ollama()
     subprocess.run(["ollama", "pull", MODEL], stdout=subprocess.PIPE)
+    subprocess.run(["ollama", "pull", EMBEDDING_MODEL], stdout=subprocess.PIPE)
 
 
 def wait_for_ollama(timeout: int = 30, interval: int = 2) -> None:
@@ -55,48 +61,65 @@ app = modal.App(name="ollama", image=image)
 api = FastAPI()
 
 
-class ChatMessage(BaseModel):
-    role: str = Field(..., description="The role of the message sender (e.g. 'user', 'assistant')")
-    content: str = Field(..., description="The content of the message")
-
-
-class ChatCompletionRequest(BaseModel):
-    model: Optional[str] = Field(default=MODEL, description="The model to use for completion")
-    messages: List[ChatMessage] = Field(
-        ..., description="The messages to generate a completion for"
-    )
-    stream: bool = Field(default=False, description="Whether to stream the response")
-    format: Optional[Dict[str, Any]] = Field(
-        default=None,
-        description=(
-            "A JSON dictionary specifying any kind of structured output expected. "
-            "For example, it can define a JSON Schema to validate the response."
-        ),
-    )
-    options: Optional[Dict[str, Any]] = Field(
-        default=None, description="Additional options for the model (e.g., temperature, etc.)."
-    )
-
-
-@api.post("/v1/api/chat")
-async def v1_chat_completions(request: ChatCompletionRequest) -> Any:
-    try:
-        if not request.messages:
-            raise HTTPException(
-                status_code=400,
-                detail="Messages array is required and cannot be empty",
-            )
-        response = ollama.chat(
-            model=request.model,
-            messages=[msg for msg in request.messages],
-            stream=request.stream,
-            format=request.format,
-            options=request.options,
-        )
-        return response
-
-    except Exception as e:
-        raise HTTPException(status_code=500, detail=f"Error processing chat completion: {str(e)}")
+# Headers that describe the framing of one specific connection. They are not
+# relayed, because they would not match the message this proxy actually sends.
+_SKIPPED_REQUEST_HEADERS = frozenset(
+    {b"connection", b"content-length", b"host", b"transfer-encoding"}
+)
+_SKIPPED_RESPONSE_HEADERS = frozenset(
+    {"connection", "content-encoding", "content-length", "transfer-encoding"}
+)
+
+
+@api.api_route("/{full_path:path}", methods=["GET", "POST", "PUT", "DELETE"])
+async def proxy(full_path: str, request: Request) -> Response:
+    """Relay an incoming request to the local Ollama server and return its reply.
+
+    The upstream response is read completely before it is returned, so a
+    streamed Ollama reply reaches the caller in one piece.
+
+    Raises:
+        HTTPException: 504 if Ollama does not answer within the timeout, 502 if
+            the request to Ollama fails for any other transport-level reason.
+    """
+    # Construct the local Ollama endpoint URL
+    local_url = f"http://localhost:11434/{full_path}"
+    print(f"Forwarding {request.method} request to: {local_url}")  # Logging the target URL
+
+    # Let httpx set Host and Content-Length for the upstream connection.
+    forwarded_headers = [
+        (name, value)
+        for name, value in request.headers.raw
+        if name.lower() not in _SKIPPED_REQUEST_HEADERS
+    ]
+
+    # Forward the request
+    try:
+        async with httpx.AsyncClient(timeout=httpx.Timeout(180.0)) as client:
+            response = await client.request(
+                method=request.method,
+                url=local_url,
+                headers=forwarded_headers,
+                params=request.query_params,
+                content=await request.body(),
+            )
+    except httpx.TimeoutException as exc:
+        raise HTTPException(status_code=504, detail=f"Ollama request timed out: {exc}") from exc
+    except httpx.RequestError as exc:
+        raise HTTPException(status_code=502, detail=f"Ollama request failed: {exc}") from exc
+
+    print(f"Received response with status: {response.status_code}")  # Logging the response status
+
+    # response.content is already decoded and fully buffered, so the upstream
+    # encoding, length and framing headers no longer describe the body sent here.
+    response_headers = {
+        name: value
+        for name, value in response.headers.items()
+        if name.lower() not in _SKIPPED_RESPONSE_HEADERS
+    }
+    return Response(
+        content=response.content, status_code=response.status_code, headers=response_headers
+    )
 
 
 @app.cls(