# Geniusia_v2/geniusia2/core/embedders/faiss_index.py

"""
FAISS index wrapper with proper dimension handling and persistence.
This module provides a robust wrapper around FAISS for storing and searching
image embeddings, with proper error handling for dimension mismatches and
reliable save/load functionality.
"""

import pickle
import logging
from pathlib import Path
from typing import List, Dict, Any

import numpy as np

try:
    import faiss
except ImportError:
    faiss = None

logger = logging.getLogger(__name__)

class FAISSIndex:
"""
Wrapper around FAISS index with metadata storage and dimension validation.
This class handles:
- Dimension validation on add/search operations
- Metadata storage alongside embeddings
- Reliable persistence (save/load)
- Automatic index rebuilding on dimension changes
"""
def __init__(self, dimension: int):
"""
Initialize a new FAISS index.
Args:
dimension: Embedding dimension (e.g., 512 for CLIP, 768 for Pix2Struct)
Raises:
ImportError: If FAISS is not installed
ValueError: If dimension is invalid
"""
if faiss is None:
raise ImportError(
"FAISS is not installed. "
"Install it with: pip install faiss-cpu or faiss-gpu"
)
if dimension <= 0:
raise ValueError(f"Dimension must be positive, got {dimension}")
self.dimension = dimension
self.index = faiss.IndexFlatL2(dimension)
self.metadata: List[Dict[str, Any]] = []
logger.info(f"FAISSIndex created with dimension={dimension}")

    def add(self, embeddings: np.ndarray, metadata: List[Dict[str, Any]]):
        """
        Add embeddings to the index with associated metadata.

        Args:
            embeddings: Array of shape (N, dimension) containing N embeddings
            metadata: List of N metadata dictionaries

        Raises:
            ValueError: If dimensions don't match or array shapes are invalid
        """
        # Validate input shape
        if embeddings.ndim == 1:
            # Single embedding, reshape to (1, dimension)
            embeddings = embeddings.reshape(1, -1)
        elif embeddings.ndim != 2:
            raise ValueError(
                f"Embeddings must be 1D or 2D array, got shape {embeddings.shape}"
            )

        # Validate dimension
        if embeddings.shape[1] != self.dimension:
            raise ValueError(
                f"Embedding dimension {embeddings.shape[1]} doesn't match "
                f"index dimension {self.dimension}"
            )

        # Validate metadata count
        if len(metadata) != embeddings.shape[0]:
            raise ValueError(
                f"Number of metadata entries ({len(metadata)}) doesn't match "
                f"number of embeddings ({embeddings.shape[0]})"
            )

        # Add to FAISS index
        self.index.add(embeddings.astype('float32'))

        # Store metadata
        self.metadata.extend(metadata)

        logger.debug(
            f"Added {embeddings.shape[0]} embeddings to index "
            f"(total: {self.index.ntotal})"
        )

    def search(
        self,
        query: np.ndarray,
        k: int = 5
    ) -> List[Dict[str, Any]]:
        """
        Search for the k most similar embeddings.

        Args:
            query: Query embedding of shape (dimension,) or (1, dimension)
            k: Number of results to return

        Returns:
            List of dicts with keys:
            - 'index': Position in the FAISS index
            - 'distance': L2 distance
            - 'similarity': Similarity score (1 / (1 + distance))
            - 'metadata': Associated metadata dict

        Raises:
            ValueError: If query dimension doesn't match index dimension
        """
        if self.index.ntotal == 0:
            logger.warning("Search called on empty index")
            return []

        # Reshape query if needed
        if query.ndim == 1:
            query = query.reshape(1, -1)
        elif query.ndim != 2:
            raise ValueError(
                f"Query must be 1D or 2D array, got shape {query.shape}"
            )

        # Validate dimension
        if query.shape[1] != self.dimension:
            raise ValueError(
                f"Query dimension {query.shape[1]} doesn't match "
                f"index dimension {self.dimension}"
            )

        # Limit k to the number of stored embeddings
        k = min(k, self.index.ntotal)

        # Search
        distances, indices = self.index.search(query.astype('float32'), k)

        # Format results
        results = []
        for dist, idx in zip(distances[0], indices[0]):
            # FAISS returns -1 when there are not enough results
            if 0 <= idx < len(self.metadata):
                results.append({
                    'index': int(idx),
                    'distance': float(dist),
                    'similarity': float(1.0 / (1.0 + dist)),
                    'metadata': self.metadata[idx]
                })
        return results

    def save(self, path: str):
        """
        Save index and metadata to disk.

        Args:
            path: Base path for saving (will create .index and .metadata files)

        Raises:
            RuntimeError: If the save operation fails
        """
        try:
            path_obj = Path(path)
            path_obj.parent.mkdir(parents=True, exist_ok=True)

            # Save FAISS index
            index_file = f"{path}.index"
            faiss.write_index(self.index, index_file)

            # Save metadata
            metadata_file = f"{path}.metadata"
            with open(metadata_file, 'wb') as f:
                pickle.dump({
                    'dimension': self.dimension,
                    'metadata': self.metadata
                }, f)

            logger.info(
                f"Saved index with {self.index.ntotal} embeddings to {path}"
            )
        except Exception as e:
            raise RuntimeError(f"Failed to save index: {e}") from e

    def load(self, path: str):
        """
        Load index and metadata from disk.

        Args:
            path: Base path for loading (will read .index and .metadata files)

        Raises:
            FileNotFoundError: If the files don't exist
            RuntimeError: If the load operation fails or the dimension doesn't match
        """
        try:
            index_file = f"{path}.index"
            metadata_file = f"{path}.metadata"

            # Check that both files exist
            if not Path(index_file).exists():
                raise FileNotFoundError(f"Index file not found: {index_file}")
            if not Path(metadata_file).exists():
                raise FileNotFoundError(f"Metadata file not found: {metadata_file}")

            # Load FAISS index
            loaded_index = faiss.read_index(index_file)

            # Load metadata
            with open(metadata_file, 'rb') as f:
                data = pickle.load(f)
            loaded_dimension = data['dimension']
            loaded_metadata = data['metadata']

            # Validate dimension
            if loaded_dimension != self.dimension:
                raise RuntimeError(
                    f"Loaded index dimension ({loaded_dimension}) doesn't match "
                    f"current dimension ({self.dimension}). "
                    f"Use rebuild_if_needed() to handle dimension changes."
                )

            # Update state
            self.index = loaded_index
            self.metadata = loaded_metadata

            logger.info(
                f"Loaded index with {self.index.ntotal} embeddings from {path}"
            )
        except Exception as e:
            if isinstance(e, (FileNotFoundError, RuntimeError)):
                raise
            raise RuntimeError(f"Failed to load index: {e}") from e

    def rebuild_if_needed(self, new_dimension: int) -> bool:
        """
        Rebuild index if dimension has changed.

        This creates a new empty index with the new dimension.
        Old embeddings are lost and need to be regenerated.

        Args:
            new_dimension: New embedding dimension

        Returns:
            bool: True if index was rebuilt, False if dimension unchanged
        """
        if new_dimension == self.dimension:
            return False

        logger.warning(
            f"Rebuilding FAISS index: dimension changed from "
            f"{self.dimension} to {new_dimension}. "
            f"Old embeddings ({self.index.ntotal}) will be lost."
        )

        # Create new index
        self.dimension = new_dimension
        self.index = faiss.IndexFlatL2(new_dimension)
        self.metadata = []
        return True

    def clear(self):
        """Clear all embeddings from the index."""
        self.index = faiss.IndexFlatL2(self.dimension)
        self.metadata = []
        logger.info("Index cleared")

    def get_stats(self) -> Dict[str, Any]:
        """
        Get index statistics.

        Returns:
            Dict with keys: num_embeddings, dimension, is_trained
        """
        return {
            'num_embeddings': self.index.ntotal,
            'dimension': self.dimension,
            'is_trained': self.index.is_trained
        }

    def __len__(self) -> int:
        """Return number of embeddings in the index."""
        return self.index.ntotal

    def __repr__(self) -> str:
        """String representation of the index."""
        return (
            f"FAISSIndex(dimension={self.dimension}, "
            f"num_embeddings={self.index.ntotal})"
        )
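

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the library API): exercises add/search,
# save/load, and dimension rebuilding with random vectors. It assumes faiss-cpu
# (or faiss-gpu) and numpy are installed; the dimension, the "demo_index" path,
# and the "image_id" metadata key below are illustrative placeholders only.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    dim = 512  # e.g., CLIP-sized embeddings; any positive value works
    index = FAISSIndex(dimension=dim)

    # Add a small batch of random embeddings with per-vector metadata
    rng = np.random.default_rng(seed=0)
    vectors = rng.random((10, dim), dtype=np.float32)
    index.add(vectors, [{"image_id": i} for i in range(10)])

    # Query with one of the stored vectors; it should come back as the top hit
    for hit in index.search(vectors[0], k=3):
        print(hit["index"], round(hit["similarity"], 4), hit["metadata"])

    # Round-trip through disk (writes demo_index.index and demo_index.metadata
    # in the current working directory)
    index.save("demo_index")
    restored = FAISSIndex(dimension=dim)
    restored.load("demo_index")
    print(restored)  # FAISSIndex(dimension=512, num_embeddings=10)

    # Switching to a model with a different embedding size drops old vectors
    restored.rebuild_if_needed(new_dimension=768)
    print(restored.get_stats())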