""" FAISS index wrapper with proper dimension handling and persistence. This module provides a robust wrapper around FAISS for storing and searching image embeddings, with proper error handling for dimension mismatches and reliable save/load functionality. """ import pickle import logging from pathlib import Path from typing import List, Dict, Any, Optional import numpy as np try: import faiss except ImportError: faiss = None logger = logging.getLogger(__name__) class FAISSIndex: """ Wrapper around FAISS index with metadata storage and dimension validation. This class handles: - Dimension validation on add/search operations - Metadata storage alongside embeddings - Reliable persistence (save/load) - Automatic index rebuilding on dimension changes """ def __init__(self, dimension: int): """ Initialize a new FAISS index. Args: dimension: Embedding dimension (e.g., 512 for CLIP, 768 for Pix2Struct) Raises: ImportError: If FAISS is not installed ValueError: If dimension is invalid """ if faiss is None: raise ImportError( "FAISS is not installed. " "Install it with: pip install faiss-cpu or faiss-gpu" ) if dimension <= 0: raise ValueError(f"Dimension must be positive, got {dimension}") self.dimension = dimension self.index = faiss.IndexFlatL2(dimension) self.metadata: List[Dict[str, Any]] = [] logger.info(f"FAISSIndex created with dimension={dimension}") def add(self, embeddings: np.ndarray, metadata: List[Dict[str, Any]]): """ Add embeddings to the index with associated metadata. Args: embeddings: Array of shape (N, dimension) containing N embeddings metadata: List of N metadata dictionaries Raises: ValueError: If dimensions don't match or array shapes are invalid """ # Validate input shape if embeddings.ndim == 1: # Single embedding, reshape to (1, dimension) embeddings = embeddings.reshape(1, -1) elif embeddings.ndim != 2: raise ValueError( f"Embeddings must be 1D or 2D array, got shape {embeddings.shape}" ) # Validate dimension if embeddings.shape[1] != self.dimension: raise ValueError( f"Embedding dimension {embeddings.shape[1]} doesn't match " f"index dimension {self.dimension}" ) # Validate metadata count if len(metadata) != embeddings.shape[0]: raise ValueError( f"Number of metadata entries ({len(metadata)}) doesn't match " f"number of embeddings ({embeddings.shape[0]})" ) # Add to FAISS index self.index.add(embeddings.astype('float32')) # Store metadata self.metadata.extend(metadata) logger.debug( f"Added {embeddings.shape[0]} embeddings to index " f"(total: {self.index.ntotal})" ) def search( self, query: np.ndarray, k: int = 5 ) -> List[Dict[str, Any]]: """ Search for the k most similar embeddings. Args: query: Query embedding of shape (dimension,) or (1, dimension) k: Number of results to return Returns: List of dicts with keys: - 'index': Index in the FAISS index - 'distance': L2 distance - 'similarity': Similarity score (1 / (1 + distance)) - 'metadata': Associated metadata dict Raises: ValueError: If query dimension doesn't match index dimension """ if self.index.ntotal == 0: logger.warning("Search called on empty index") return [] # Reshape query if needed if query.ndim == 1: query = query.reshape(1, -1) elif query.ndim != 2: raise ValueError( f"Query must be 1D or 2D array, got shape {query.shape}" ) # Validate dimension if query.shape[1] != self.dimension: raise ValueError( f"Query dimension {query.shape[1]} doesn't match " f"index dimension {self.dimension}" ) # Limit k to available embeddings k = min(k, self.index.ntotal) # Search distances, indices = self.index.search(query.astype('float32'), k) # Format results results = [] for dist, idx in zip(distances[0], indices[0]): # FAISS returns -1 if not enough results if idx >= 0 and idx < len(self.metadata): results.append({ 'index': int(idx), 'distance': float(dist), 'similarity': float(1.0 / (1.0 + dist)), 'metadata': self.metadata[idx] }) return results def save(self, path: str): """ Save index and metadata to disk. Args: path: Base path for saving (will create .index and .metadata files) Raises: RuntimeError: If save operation fails """ try: path_obj = Path(path) path_obj.parent.mkdir(parents=True, exist_ok=True) # Save FAISS index index_file = f"{path}.index" faiss.write_index(self.index, index_file) # Save metadata metadata_file = f"{path}.metadata" with open(metadata_file, 'wb') as f: pickle.dump({ 'dimension': self.dimension, 'metadata': self.metadata }, f) logger.info( f"Saved index with {self.index.ntotal} embeddings to {path}" ) except Exception as e: raise RuntimeError(f"Failed to save index: {e}") def load(self, path: str): """ Load index and metadata from disk. Args: path: Base path for loading (will read .index and .metadata files) Raises: FileNotFoundError: If files don't exist RuntimeError: If load operation fails or dimension mismatch """ try: index_file = f"{path}.index" metadata_file = f"{path}.metadata" # Check files exist if not Path(index_file).exists(): raise FileNotFoundError(f"Index file not found: {index_file}") if not Path(metadata_file).exists(): raise FileNotFoundError(f"Metadata file not found: {metadata_file}") # Load FAISS index loaded_index = faiss.read_index(index_file) # Load metadata with open(metadata_file, 'rb') as f: data = pickle.load(f) loaded_dimension = data['dimension'] loaded_metadata = data['metadata'] # Validate dimension if loaded_dimension != self.dimension: raise RuntimeError( f"Loaded index dimension ({loaded_dimension}) doesn't match " f"current dimension ({self.dimension}). " f"Use rebuild_if_needed() to handle dimension changes." ) # Update state self.index = loaded_index self.metadata = loaded_metadata logger.info( f"Loaded index with {self.index.ntotal} embeddings from {path}" ) except Exception as e: if isinstance(e, (FileNotFoundError, RuntimeError)): raise raise RuntimeError(f"Failed to load index: {e}") def rebuild_if_needed(self, new_dimension: int) -> bool: """ Rebuild index if dimension has changed. This creates a new empty index with the new dimension. Old embeddings are lost and need to be regenerated. Args: new_dimension: New embedding dimension Returns: bool: True if index was rebuilt, False if dimension unchanged """ if new_dimension == self.dimension: return False logger.warning( f"Rebuilding FAISS index: dimension changed from " f"{self.dimension} to {new_dimension}. " f"Old embeddings ({self.index.ntotal}) will be lost." ) # Create new index self.dimension = new_dimension self.index = faiss.IndexFlatL2(new_dimension) self.metadata = [] return True def clear(self): """Clear all embeddings from the index.""" self.index = faiss.IndexFlatL2(self.dimension) self.metadata = [] logger.info("Index cleared") def get_stats(self) -> Dict[str, Any]: """ Get index statistics. Returns: Dict with keys: num_embeddings, dimension, is_trained """ return { 'num_embeddings': self.index.ntotal, 'dimension': self.dimension, 'is_trained': self.index.is_trained } def __len__(self) -> int: """Return number of embeddings in the index.""" return self.index.ntotal def __repr__(self) -> str: """String representation of the index.""" return ( f"FAISSIndex(dimension={self.dimension}, " f"num_embeddings={self.index.ntotal})" )