"""
|
|
Ollama Manager - Manages VLM model lifecycle via Ollama API
|
|
|
|
Handles:
|
|
- Loading/unloading models to/from VRAM
|
|
- Health checks and availability detection
|
|
- Keep-alive management for model persistence
|
|
"""
|
|
|
|
import asyncio
|
|
import logging
|
|
from typing import List, Optional
|
|
|
|
import aiohttp
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class OllamaManager:
|
|
"""
|
|
Manages Ollama VLM model lifecycle.
|
|
|
|
Uses Ollama's REST API to control model loading/unloading.
|
|
|
|
Example:
|
|
>>> manager = OllamaManager()
|
|
>>> await manager.load_model()
|
|
>>> is_loaded = await manager.is_model_loaded()
|
|
>>> await manager.unload_model()
|
|
"""
|
|
|
|
def __init__(
|
|
self,
|
|
endpoint: str = "http://localhost:11434",
|
|
model: str = "gemma4:e4b",
|
|
default_keep_alive: str = "5m"
|
|
):
|
|
"""
|
|
Initialize OllamaManager.
|
|
|
|
Args:
|
|
endpoint: Ollama API endpoint
|
|
model: Model name to manage
|
|
default_keep_alive: Default keep-alive duration
|
|
"""
|
|
self._endpoint = endpoint.rstrip("/")
|
|
self._model = model
|
|
self._default_keep_alive = default_keep_alive
|
|
self._session: Optional[aiohttp.ClientSession] = None
|
|
|
|
async def _get_session(self) -> aiohttp.ClientSession:
|
|
"""Get or create aiohttp session."""
|
|
if self._session is None or self._session.closed:
|
|
self._session = aiohttp.ClientSession(
|
|
timeout=aiohttp.ClientTimeout(total=60)
|
|
)
|
|
return self._session
|
|
|
|
async def close(self) -> None:
|
|
"""Close the HTTP session."""
|
|
if self._session and not self._session.closed:
|
|
await self._session.close()
|
|
|
|
|
|
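
    # A minimal convenience sketch (an addition, not in the original file):
    # async context-manager support so callers can guarantee close() runs,
    # e.g. `async with OllamaManager() as manager: ...`.
    async def __aenter__(self) -> "OllamaManager":
        """Enter the context manager and return this manager."""
        return self

    async def __aexit__(self, exc_type, exc, tb) -> None:
        """Close the underlying HTTP session on exit."""
        await self.close()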

    # =========================================================================
    # Health Check
    # =========================================================================

    def is_available(self) -> bool:
        """
        Check if the Ollama service is available (synchronous).

        Returns:
            True if Ollama is reachable
        """
        # Imported lazily so `requests` is only required for the sync path.
        import requests
        try:
            response = requests.get(f"{self._endpoint}/api/tags", timeout=5)
            return response.status_code == 200
        except Exception:
            return False

    async def is_available_async(self) -> bool:
        """
        Check if the Ollama service is available (async).

        Returns:
            True if Ollama is reachable
        """
        try:
            session = await self._get_session()
            async with session.get(f"{self._endpoint}/api/tags") as response:
                return response.status == 200
        except Exception:
            return False

    # =========================================================================
    # Model Management
    # =========================================================================

    async def load_model(self, keep_alive: Optional[str] = None) -> bool:
        """
        Load the model into VRAM.

        Uses a minimal generate request to trigger model loading.

        Args:
            keep_alive: How long to keep the model loaded (e.g., "5m", "1h")

        Returns:
            True if the model loaded successfully
        """
        keep_alive = keep_alive or self._default_keep_alive

        try:
            session = await self._get_session()

            # Send a minimal request to load the model.
            # For Qwen3, use /nothink to disable thinking mode.
            prompt = "/nothink " if "qwen" in self._model.lower() else ""

            payload = {
                "model": self._model,
                "prompt": prompt,
                "keep_alive": keep_alive,
                "stream": False,
                "options": {
                    "temperature": 0.0,  # Deterministic for classification
                    "top_k": 1  # Faster for classification tasks
                }
            }

            logger.debug(f"Loading model {self._model} with keep_alive={keep_alive}")

            async with session.post(
                f"{self._endpoint}/api/generate",
                json=payload
            ) as response:
                if response.status == 200:
                    logger.info(f"Model {self._model} loaded successfully")
                    return True
                else:
                    text = await response.text()
                    logger.error(f"Failed to load model: {response.status} - {text}")
                    return False

        except asyncio.TimeoutError:
            logger.error("Timeout loading model")
            return False
        except Exception as e:
            logger.error(f"Error loading model: {e}")
            return False

    async def unload_model(self) -> bool:
        """
        Unload the model from VRAM.

        Sets keep_alive to 0 to trigger immediate unload.

        Returns:
            True if model unloaded successfully
        """
        try:
            session = await self._get_session()

            # Send request with keep_alive=0 to unload
            payload = {
                "model": self._model,
                "prompt": "",
                "keep_alive": 0,
                "stream": False
            }

            logger.debug(f"Unloading model {self._model}")

            async with session.post(
                f"{self._endpoint}/api/generate",
                json=payload
            ) as response:
                if response.status == 200:
                    logger.info(f"Model {self._model} unloaded successfully")
                    return True
                else:
                    text = await response.text()
                    logger.error(f"Failed to unload model: {response.status} - {text}")
                    return False

        except asyncio.TimeoutError:
            logger.error("Timeout unloading model")
            return False
        except Exception as e:
            logger.error(f"Error unloading model: {e}")
            return False

    async def is_model_loaded(self) -> bool:
        """
        Check if the model is currently loaded in VRAM.

        Returns:
            True if model is loaded
        """
        try:
            session = await self._get_session()

            async with session.get(f"{self._endpoint}/api/ps") as response:
                if response.status == 200:
                    data = await response.json()
                    models = data.get("models", [])

                    # Compare on the base model name (before the ":tag") so
                    # variants such as "model:latest" still count as loaded.
                    for model_info in models:
                        if model_info.get("name", "").startswith(self._model.split(":")[0]):
                            return True

                    return False
                else:
                    logger.warning(f"Failed to check loaded models: {response.status}")
                    return False

        except Exception as e:
            logger.error(f"Error checking loaded models: {e}")
            return False

    async def list_loaded_models(self) -> List[str]:
        """
        List all currently loaded models.

        Returns:
            List of loaded model names
        """
        try:
            session = await self._get_session()

            async with session.get(f"{self._endpoint}/api/ps") as response:
                if response.status == 200:
                    data = await response.json()
                    models = data.get("models", [])
                    return [m.get("name", "") for m in models]
                else:
                    return []

        except Exception as e:
            logger.error(f"Error listing loaded models: {e}")
            return []

    async def list_available_models(self) -> List[str]:
        """
        List all available (downloaded) models.

        Returns:
            List of available model names
        """
        try:
            session = await self._get_session()

            async with session.get(f"{self._endpoint}/api/tags") as response:
                if response.status == 200:
                    data = await response.json()
                    models = data.get("models", [])
                    return [m.get("name", "") for m in models]
                else:
                    return []

        except Exception as e:
            logger.error(f"Error listing available models: {e}")
            return []
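

# A minimal usage sketch, assuming a local Ollama instance on the default
# endpoint with the configured model already pulled; adjust the
# OllamaManager(...) arguments to your setup.
if __name__ == "__main__":

    async def _demo() -> None:
        manager = OllamaManager()
        try:
            if not await manager.is_available_async():
                print("Ollama is not reachable")
                return
            print("Available models:", await manager.list_available_models())
            # Load with a short keep-alive, verify residency, then unload.
            if await manager.load_model(keep_alive="1m"):
                print("Model loaded:", await manager.is_model_loaded())
                await manager.unload_model()
        finally:
            await manager.close()

    asyncio.run(_demo())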