""" CLIP Manager - Manages CLIP model device migration Handles: - CPU/GPU device migration for CLIP model - Pipeline reinitialization after device change - Graceful fallback on migration failures """ import asyncio import logging from typing import Any, Optional import torch logger = logging.getLogger(__name__) class CLIPManager: """ Manages CLIP model device migration between CPU and GPU. Coordinates with the embedding pipeline to ensure consistent device usage after migration. Example: >>> manager = CLIPManager() >>> await manager.migrate_to_device("cuda") >>> device = manager.get_current_device() """ def __init__(self, model_name: str = "ViT-B-32"): """ Initialize CLIPManager. Args: model_name: CLIP model variant to manage """ self._model_name = model_name self._current_device = "cpu" self._model: Optional[Any] = None self._preprocess: Optional[Any] = None self._initialized = False # Check CUDA availability self._cuda_available = torch.cuda.is_available() if not self._cuda_available: logger.warning("CUDA not available, CLIP will stay on CPU") def get_current_device(self) -> str: """ Get the current device for CLIP model. Returns: "cpu" or "cuda" """ return self._current_device def is_cuda_available(self) -> bool: """Check if CUDA is available for GPU migration.""" return self._cuda_available async def migrate_to_device(self, device: str) -> bool: """ Migrate CLIP model to specified device. Args: device: Target device ("cpu" or "cuda") Returns: True if migration successful """ if device not in ["cpu", "cuda"]: logger.error(f"Invalid device: {device}") return False if device == self._current_device: logger.debug(f"CLIP already on {device}") return True if device == "cuda" and not self._cuda_available: logger.warning("Cannot migrate to CUDA: not available") return False logger.info(f"Migrating CLIP from {self._current_device} to {device}") try: # Run migration in executor to avoid blocking loop = asyncio.get_event_loop() success = await loop.run_in_executor( None, self._do_migration, device ) if success: self._current_device = device logger.info(f"CLIP migrated to {device}") return True except Exception as e: logger.error(f"CLIP migration failed: {e}") return False def _do_migration(self, device: str) -> bool: """ Perform the actual device migration (blocking). Args: device: Target device Returns: True if successful """ try: # If model is loaded, move it if self._model is not None: self._model = self._model.to(device) logger.debug(f"Moved existing model to {device}") # Reinitialize pipeline with new device self.reinitialize_pipeline(device) return True except Exception as e: logger.error(f"Migration error: {e}") return False def reinitialize_pipeline(self, device: Optional[str] = None) -> None: """ Reinitialize the embedding pipeline with current/specified device. Args: device: Device to use (uses current if None) """ device = device or self._current_device try: # Try to notify FusionEngine about device change self._notify_fusion_engine(device) logger.debug(f"Pipeline reinitialized for {device}") except Exception as e: logger.warning(f"Pipeline reinitialization warning: {e}") def _notify_fusion_engine(self, device: str) -> None: """ Notify FusionEngine about device change. This allows the embedding system to update its device configuration. """ try: from core.embedding.fusion_engine import FusionEngine # FusionEngine is typically a singleton, try to get instance # and update its device configuration # This is a soft dependency - if it fails, we continue except ImportError: pass # FusionEngine not available, that's OK def get_model(self) -> Optional[Any]: """ Get the CLIP model instance. Returns: CLIP model or None if not loaded """ return self._model def load_model(self) -> bool: """ Load the CLIP model on current device. Returns: True if loaded successfully """ try: import open_clip model, _, preprocess = open_clip.create_model_and_transforms( self._model_name, pretrained='openai', device=self._current_device ) self._model = model self._preprocess = preprocess self._initialized = True logger.info(f"CLIP model {self._model_name} loaded on {self._current_device}") return True except Exception as e: logger.error(f"Failed to load CLIP model: {e}") return False def unload_model(self) -> None: """Unload the CLIP model to free memory.""" if self._model is not None: del self._model self._model = None self._preprocess = None self._initialized = False # Force garbage collection import gc gc.collect() if self._cuda_available: torch.cuda.empty_cache() logger.info("CLIP model unloaded") def encode_image(self, image) -> Optional[Any]: """ Encode an image using CLIP. Args: image: PIL Image or tensor Returns: Image embedding or None on error """ if not self._initialized or self._model is None: if not self.load_model(): return None try: import torch with torch.no_grad(): if self._preprocess: image_tensor = self._preprocess(image).unsqueeze(0) else: image_tensor = image image_tensor = image_tensor.to(self._current_device) embedding = self._model.encode_image(image_tensor) return embedding.cpu().numpy() except Exception as e: logger.error(f"Image encoding error: {e}") return None