modal ollama POC
This commit is contained in:
parent
2ab5683303
commit
c0cfcdd1f1
2 changed files with 26 additions and 42 deletions
|
|
@@ -9,6 +9,8 @@ Group=ollama
|
||||||
Restart=always
|
Restart=always
|
||||||
RestartSec=3
|
RestartSec=3
|
||||||
Environment="PATH=$PATH"
|
Environment="PATH=$PATH"
|
||||||
|
Environment="OLLAMA_ORIGINS=*"
|
||||||
|
Environment="OLLAMA_HOST=0.0.0.0:11434"
|
||||||
|
|
||||||
[Install]
|
[Install]
|
||||||
WantedBy=default.target
|
WantedBy=default.target
|
||||||
|
|
|
||||||
|
|
@@ -6,8 +6,13 @@ from fastapi import FastAPI, HTTPException
|
||||||
from typing import List, Any, Optional, Dict
|
from typing import List, Any, Optional, Dict
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
import ollama
|
import ollama
|
||||||
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
|
|
||||||
MODEL = os.environ.get("MODEL", "llama3.3:70b")
|
import httpx
|
||||||
|
from fastapi import Request, Response
|
||||||
|
|
||||||
|
MODEL = os.environ.get("MODEL", "deepseek-r1:70b")
|
||||||
|
EMBEDDING_MODEL = os.environ.get("EMBEDDING_MODEL", "avr/sfr-embedding-mistral")
|
||||||
|
|
||||||
|
|
||||||
def pull() -> None:
|
def pull() -> None:
|
||||||
|
|
@@ -16,6 +21,7 @@ def pull() -> None:
|
||||||
subprocess.run(["systemctl", "start", "ollama"])
|
subprocess.run(["systemctl", "start", "ollama"])
|
||||||
wait_for_ollama()
|
wait_for_ollama()
|
||||||
subprocess.run(["ollama", "pull", MODEL], stdout=subprocess.PIPE)
|
subprocess.run(["ollama", "pull", MODEL], stdout=subprocess.PIPE)
|
||||||
|
subprocess.run(["ollama", "pull", EMBEDDING_MODEL], stdout=subprocess.PIPE)
|
||||||
|
|
||||||
|
|
||||||
def wait_for_ollama(timeout: int = 30, interval: int = 2) -> None:
|
def wait_for_ollama(timeout: int = 30, interval: int = 2) -> None:
|
||||||
|
|
@@ -55,48 +61,24 @@ app = modal.App(name="ollama", image=image)
|
||||||
api = FastAPI()
|
api = FastAPI()
|
||||||
|
|
||||||
|
|
||||||
class ChatMessage(BaseModel):
|
@api.api_route("/{full_path:path}", methods=["GET", "POST", "PUT", "DELETE"])
|
||||||
role: str = Field(..., description="The role of the message sender (e.g. 'user', 'assistant')")
|
async def proxy(full_path: str, request: Request):
|
||||||
content: str = Field(..., description="The content of the message")
|
# Construct the local Ollama endpoint URL
|
||||||
|
local_url = f"http://localhost:11434/{full_path}"
|
||||||
|
print(f"Forwarding {request.method} request to: {local_url}") # Logging the target URL
|
||||||
class ChatCompletionRequest(BaseModel):
|
# Forward the request
|
||||||
model: Optional[str] = Field(default=MODEL, description="The model to use for completion")
|
async with httpx.AsyncClient(timeout=httpx.Timeout(180.0)) as client:
|
||||||
messages: List[ChatMessage] = Field(
|
response = await client.request(
|
||||||
..., description="The messages to generate a completion for"
|
method=request.method,
|
||||||
)
|
url=local_url,
|
||||||
stream: bool = Field(default=False, description="Whether to stream the response")
|
headers=request.headers.raw,
|
||||||
format: Optional[Dict[str, Any]] = Field(
|
params=request.query_params,
|
||||||
default=None,
|
content=await request.body(),
|
||||||
description=(
|
|
||||||
"A JSON dictionary specifying any kind of structured output expected. "
|
|
||||||
"For example, it can define a JSON Schema to validate the response."
|
|
||||||
),
|
|
||||||
)
|
|
||||||
options: Optional[Dict[str, Any]] = Field(
|
|
||||||
default=None, description="Additional options for the model (e.g., temperature, etc.)."
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@api.post("/v1/api/chat")
|
|
||||||
async def v1_chat_completions(request: ChatCompletionRequest) -> Any:
|
|
||||||
try:
|
|
||||||
if not request.messages:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=400,
|
|
||||||
detail="Messages array is required and cannot be empty",
|
|
||||||
)
|
|
||||||
response = ollama.chat(
|
|
||||||
model=request.model,
|
|
||||||
messages=[msg for msg in request.messages],
|
|
||||||
stream=request.stream,
|
|
||||||
format=request.format,
|
|
||||||
options=request.options,
|
|
||||||
)
|
)
|
||||||
return response
|
print(f"Received response with status: {response.status_code}") # Logging the response status
|
||||||
|
return Response(
|
||||||
except Exception as e:
|
content=response.content, status_code=response.status_code, headers=response.headers
|
||||||
raise HTTPException(status_code=500, detail=f"Error processing chat completion: {str(e)}")
|
)
|
||||||
|
|
||||||
|
|
||||||
@app.cls(
|
@app.cls(
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue