""" Ollama Manager - Manages VLM model lifecycle via Ollama API Handles: - Loading/unloading models to/from VRAM - Health checks and availability detection - Keep-alive management for model persistence """ import asyncio import logging from typing import List, Optional import aiohttp logger = logging.getLogger(__name__) class OllamaManager: """ Manages Ollama VLM model lifecycle. Uses Ollama's REST API to control model loading/unloading. Example: >>> manager = OllamaManager() >>> await manager.load_model() >>> is_loaded = await manager.is_model_loaded() >>> await manager.unload_model() """ def __init__( self, endpoint: str = "http://localhost:11434", model: str = "qwen3-vl:8b", default_keep_alive: str = "5m" ): """ Initialize OllamaManager. Args: endpoint: Ollama API endpoint model: Model name to manage default_keep_alive: Default keep-alive duration """ self._endpoint = endpoint.rstrip("/") self._model = model self._default_keep_alive = default_keep_alive self._session: Optional[aiohttp.ClientSession] = None async def _get_session(self) -> aiohttp.ClientSession: """Get or create aiohttp session.""" if self._session is None or self._session.closed: self._session = aiohttp.ClientSession( timeout=aiohttp.ClientTimeout(total=60) ) return self._session async def close(self) -> None: """Close the HTTP session.""" if self._session and not self._session.closed: await self._session.close() # ========================================================================= # Health Check # ========================================================================= def is_available(self) -> bool: """ Check if Ollama service is available (synchronous). Returns: True if Ollama is reachable """ import requests try: response = requests.get(f"{self._endpoint}/api/tags", timeout=5) return response.status_code == 200 except Exception: return False async def is_available_async(self) -> bool: """ Check if Ollama service is available (async). Returns: True if Ollama is reachable """ try: session = await self._get_session() async with session.get(f"{self._endpoint}/api/tags") as response: return response.status == 200 except Exception: return False # ========================================================================= # Model Management # ========================================================================= async def load_model(self, keep_alive: Optional[str] = None) -> bool: """ Load the model into VRAM. Uses a minimal generate request to trigger model loading. 

        Args:
            keep_alive: How long to keep the model loaded (e.g., "5m", "1h")

        Returns:
            True if the model loaded successfully
        """
        keep_alive = keep_alive or self._default_keep_alive

        try:
            session = await self._get_session()

            # Send a minimal request to load the model.
            # For Qwen3, use /nothink to disable thinking mode.
            prompt = "/nothink " if "qwen" in self._model.lower() else ""

            payload = {
                "model": self._model,
                "prompt": prompt,
                "keep_alive": keep_alive,
                "stream": False,
                "options": {
                    "temperature": 0.0,  # Deterministic output for classification
                    "top_k": 1  # Faster for classification tasks
                }
            }

            logger.debug(f"Loading model {self._model} with keep_alive={keep_alive}")

            async with session.post(
                f"{self._endpoint}/api/generate",
                json=payload
            ) as response:
                if response.status == 200:
                    logger.info(f"Model {self._model} loaded successfully")
                    return True
                else:
                    text = await response.text()
                    logger.error(f"Failed to load model: {response.status} - {text}")
                    return False

        except asyncio.TimeoutError:
            logger.error("Timeout loading model")
            return False
        except Exception as e:
            logger.error(f"Error loading model: {e}")
            return False

    async def unload_model(self) -> bool:
        """
        Unload the model from VRAM.

        Sets keep_alive to 0 to trigger immediate unload.

        Returns:
            True if the model unloaded successfully
        """
        try:
            session = await self._get_session()

            # Send a request with keep_alive=0 to unload
            payload = {
                "model": self._model,
                "prompt": "",
                "keep_alive": 0,
                "stream": False
            }

            logger.debug(f"Unloading model {self._model}")

            async with session.post(
                f"{self._endpoint}/api/generate",
                json=payload
            ) as response:
                if response.status == 200:
                    logger.info(f"Model {self._model} unloaded successfully")
                    return True
                else:
                    text = await response.text()
                    logger.error(f"Failed to unload model: {response.status} - {text}")
                    return False

        except asyncio.TimeoutError:
            logger.error("Timeout unloading model")
            return False
        except Exception as e:
            logger.error(f"Error unloading model: {e}")
            return False

    async def is_model_loaded(self) -> bool:
        """
        Check if the model is currently loaded in VRAM.

        Returns:
            True if the model is loaded
        """
        try:
            session = await self._get_session()
            async with session.get(f"{self._endpoint}/api/ps") as response:
                if response.status == 200:
                    data = await response.json()
                    models = data.get("models", [])
                    # Match on the base model name (e.g., "qwen3-vl" for "qwen3-vl:8b")
                    for model_info in models:
                        if model_info.get("name", "").startswith(self._model.split(":")[0]):
                            return True
                    return False
                else:
                    logger.warning(f"Failed to check loaded models: {response.status}")
                    return False
        except Exception as e:
            logger.error(f"Error checking loaded models: {e}")
            return False

    async def list_loaded_models(self) -> List[str]:
        """
        List all currently loaded models.

        Returns:
            List of loaded model names
        """
        try:
            session = await self._get_session()
            async with session.get(f"{self._endpoint}/api/ps") as response:
                if response.status == 200:
                    data = await response.json()
                    models = data.get("models", [])
                    return [m.get("name", "") for m in models]
                else:
                    return []
        except Exception as e:
            logger.error(f"Error listing loaded models: {e}")
            return []

    async def list_available_models(self) -> List[str]:
        """
        List all available (downloaded) models.

        Returns:
            List of available model names
        """
        try:
            session = await self._get_session()
            async with session.get(f"{self._endpoint}/api/tags") as response:
                if response.status == 200:
                    data = await response.json()
                    models = data.get("models", [])
                    return [m.get("name", "") for m in models]
                else:
                    return []
        except Exception as e:
            logger.error(f"Error listing available models: {e}")
            return []
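

# Minimal usage sketch, not part of the manager itself. Assumptions: an Ollama
# server is reachable at the default endpoint and the target model has already
# been pulled; adjust endpoint/model to your setup. It walks the intended
# lifecycle: health check -> load -> inspect -> unload -> close.
if __name__ == "__main__":
    async def _demo() -> None:
        manager = OllamaManager()
        try:
            # Bail out early if the Ollama service is unreachable
            if not await manager.is_available_async():
                print("Ollama is not reachable")
                return
            # Load the model briefly, then show what is resident in VRAM
            if await manager.load_model(keep_alive="1m"):
                print("Loaded models:", await manager.list_loaded_models())
            # Release VRAM immediately rather than waiting for keep-alive expiry
            await manager.unload_model()
        finally:
            await manager.close()

    asyncio.run(_demo())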