modal ollama POC

hajdul88 2025-03-28 18:04:52 +01:00
parent 2ab5683303
commit c0cfcdd1f1
2 changed files with 26 additions and 42 deletions

File 1 of 2 (ollama systemd unit)

@@ -9,6 +9,8 @@ Group=ollama
 Restart=always
 RestartSec=3
 Environment="PATH=$PATH"
+Environment="OLLAMA_ORIGINS=*"
+Environment="OLLAMA_HOST=0.0.0.0:11434"
 
 [Install]
 WantedBy=default.target
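The two new variables make the service reachable from outside: OLLAMA_ORIGINS=* disables Ollama's origin check so browser clients from any origin are accepted, and OLLAMA_HOST=0.0.0.0:11434 binds the listener to all interfaces instead of loopback only. A minimal smoke test, assuming the unit has been restarted and httpx is installed on the host:

import httpx

# /api/tags lists the locally pulled models; with OLLAMA_HOST=0.0.0.0:11434
# the server answers on any interface, not just 127.0.0.1.
resp = httpx.get("http://127.0.0.1:11434/api/tags", timeout=10.0)
resp.raise_for_status()
print(resp.json())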

File 2 of 2 (Modal ollama app)

@@ -6,8 +6,13 @@ from fastapi import FastAPI, HTTPException
 from typing import List, Any, Optional, Dict
 from pydantic import BaseModel, Field
 import ollama
+from fastapi.middleware.cors import CORSMiddleware
+import httpx
+from fastapi import Request, Response
 
-MODEL = os.environ.get("MODEL", "llama3.3:70b")
+MODEL = os.environ.get("MODEL", "deepseek-r1:70b")
+EMBEDDING_MODEL = os.environ.get("EMBEDDING_MODEL", "avr/sfr-embedding-mistral")
 
 
 def pull() -> None:
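CORSMiddleware is imported here, but its registration does not appear in the visible hunks; the usual FastAPI wiring, shown only as a sketch, would mirror the OLLAMA_ORIGINS=* setting on the systemd side:

from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware

api = FastAPI()
# Accept any origin, matching OLLAMA_ORIGINS=* in the service unit.
api.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)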
@@ -16,6 +21,7 @@ def pull() -> None:
     subprocess.run(["systemctl", "start", "ollama"])
     wait_for_ollama()
     subprocess.run(["ollama", "pull", MODEL], stdout=subprocess.PIPE)
+    subprocess.run(["ollama", "pull", EMBEDDING_MODEL], stdout=subprocess.PIPE)
 
 
 def wait_for_ollama(timeout: int = 30, interval: int = 2) -> None:
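pull() now fetches the embedding model alongside the chat model. Once both pulls finish, the embedding model can be exercised through Ollama's REST API; a sketch assuming the server is up on the default port:

import httpx

payload = {"model": "avr/sfr-embedding-mistral", "prompt": "hello world"}
resp = httpx.post("http://127.0.0.1:11434/api/embeddings", json=payload, timeout=60.0)
resp.raise_for_status()
print(len(resp.json()["embedding"]))  # dimensionality of the embedding vector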
@@ -55,48 +61,24 @@ app = modal.App(name="ollama", image=image)
 api = FastAPI()
 
-class ChatMessage(BaseModel):
-    role: str = Field(..., description="The role of the message sender (e.g. 'user', 'assistant')")
-    content: str = Field(..., description="The content of the message")
-
-
-class ChatCompletionRequest(BaseModel):
-    model: Optional[str] = Field(default=MODEL, description="The model to use for completion")
-    messages: List[ChatMessage] = Field(
-        ..., description="The messages to generate a completion for"
-    )
-    stream: bool = Field(default=False, description="Whether to stream the response")
-    format: Optional[Dict[str, Any]] = Field(
-        default=None,
-        description=(
-            "A JSON dictionary specifying any kind of structured output expected. "
-            "For example, it can define a JSON Schema to validate the response."
-        ),
-    )
-    options: Optional[Dict[str, Any]] = Field(
-        default=None, description="Additional options for the model (e.g., temperature, etc.)."
-    )
-
-
-@api.post("/v1/api/chat")
-async def v1_chat_completions(request: ChatCompletionRequest) -> Any:
-    try:
-        if not request.messages:
-            raise HTTPException(
-                status_code=400,
-                detail="Messages array is required and cannot be empty",
-            )
-        response = ollama.chat(
-            model=request.model,
-            messages=[msg for msg in request.messages],
-            stream=request.stream,
-            format=request.format,
-            options=request.options,
-        )
-        return response
-
-    except Exception as e:
-        raise HTTPException(status_code=500, detail=f"Error processing chat completion: {str(e)}")
+@api.api_route("/{full_path:path}", methods=["GET", "POST", "PUT", "DELETE"])
+async def proxy(full_path: str, request: Request):
+    # Construct the local Ollama endpoint URL
+    local_url = f"http://localhost:11434/{full_path}"
+    print(f"Forwarding {request.method} request to: {local_url}")  # Logging the target URL
+
+    # Forward the request
+    async with httpx.AsyncClient(timeout=httpx.Timeout(180.0)) as client:
+        response = await client.request(
+            method=request.method,
+            url=local_url,
+            headers=request.headers.raw,
+            params=request.query_params,
+            content=await request.body(),
+        )
+    print(f"Received response with status: {response.status_code}")  # Logging the response status
+    return Response(
+        content=response.content, status_code=response.status_code, headers=response.headers
+    )
 
 
 @app.cls(
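The typed /v1/api/chat endpoint is replaced by a catch-all route that forwards any method and path to the local Ollama daemon, so the Modal deployment now exposes the full Ollama API rather than a single chat handler. A client sketch against the proxy; the base URL is a placeholder for the deployed Modal endpoint:

import httpx

BASE = "https://example--ollama.modal.run"  # placeholder; substitute your deployment's URL

payload = {
    "model": "deepseek-r1:70b",
    "messages": [{"role": "user", "content": "Say hello."}],
    "stream": False,
}
# The proxy relays this verbatim to http://localhost:11434/api/chat.
resp = httpx.post(f"{BASE}/api/chat", json=payload, timeout=180.0)
resp.raise_for_status()
print(resp.json()["message"]["content"])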